"""Elementary computations for genno."""
# Notes:
# - To avoid ambiguity, computations should not have default arguments. Define default
# values for the corresponding methods on the Computer class.
import logging
import operator
import re
from itertools import chain
from os import PathLike
from pathlib import Path
from typing import (
Any,
Collection,
Hashable,
Iterable,
List,
Mapping,
Optional,
Union,
cast,
)
import pandas as pd
import pint
from xarray.core.types import InterpOptions
from xarray.core.utils import either_dict_or_kwargs
from genno.core.attrseries import AttrSeries, _multiindex_of
from genno.core.quantity import (
Quantity,
assert_quantity,
maybe_densify,
possible_scalar,
unwrap_scalar,
)
from genno.core.sparsedataarray import SparseDataArray
from genno.util import UnitLike, collect_units, filter_concat_args
__all__ = [
"add",
"aggregate",
"apply_units",
"assign_units",
"broadcast_map",
"combine",
"concat",
"convert_units",
"disaggregate_shares",
"div",
"drop_vars",
"group_sum",
"index_to",
"interpolate",
"load_file",
"mul",
"pow",
"product",
"ratio",
"relabel",
"rename_dims",
"round",
"select",
"sum",
"write_report",
]
import xarray as xr # noqa: E402
log = logging.getLogger(__name__)
# Carry unit attributes automatically
xr.set_options(keep_attrs=True)
def add(*quantities: Quantity, fill_value: float = 0.0) -> Quantity:
"""Sum across multiple `quantities`.
Raises
------
ValueError
if any of the `quantities` have incompatible units.
Returns
-------
.Quantity
Units are the same as the first of `quantities`.
"""
# Ensure arguments are all quantities
assert_quantity(*quantities)
if isinstance(quantities[0], AttrSeries):
# map() returns an iterable
q_iter = iter(quantities)
else:
# Use xarray's built-in broadcasting, return to Quantity class
q_iter = map(Quantity, xr.broadcast(*cast(xr.DataArray, quantities)))
# Initialize result values with first entry
result = next(q_iter)
ref_unit = collect_units(result)[0]
# Iterate over remaining entries
for q in q_iter:
u = collect_units(q)[0]
if not u.is_compatible_with(ref_unit):
raise ValueError(f"Units '{ref_unit:~}' and '{u:~}' are incompatible")
factor = u.from_(1.0, strict=False).to(ref_unit).magnitude
if isinstance(q, AttrSeries):
result = (
cast(AttrSeries, result).add(factor * q, fill_value=fill_value).dropna()
)
else:
result = result + factor * q
return result
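# Illustrative use of add(): a sketch, not executed at import. It assumes the
# application pint registry defines "kg" and "t", and that Quantity accepts a
# pandas.Series with a named index:
#
#     import pandas as pd
#     from genno import Quantity
#
#     idx = pd.Index(["a", "b"], name="x")
#     q1 = Quantity(pd.Series([1.0, 2.0], index=idx), units="kg")
#     q2 = Quantity(pd.Series([3.0, 4.0], index=idx), units="t")
#     add(q1, q2)  # magnitudes [3001.0, 4002.0]; units "kg", taken from the first argument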
def aggregate(
quantity: Quantity, groups: Mapping[str, Mapping], keep: bool
) -> Quantity:
"""Aggregate *quantity* by *groups*.
Parameters
----------
groups: dict of dict
Top-level keys are the names of dimensions in `quantity`. Second-level keys are
group names; second-level values are lists of labels along the dimension to sum
into a group.
keep : bool
If True, the members that are aggregated into a group are returned with the
group sums. If False, they are discarded.
Returns
-------
    :class:`.Quantity`
Same dimensionality as `quantity`.
"""
result = quantity
for dim, dim_groups in groups.items():
# Optionally keep the original values
values = [result] if keep else []
# Aggregate each group
for group, members in dim_groups.items():
if keep and group in values[0].coords[dim]:
log.warning(
f"{dim}={group!r} is already present in quantity {quantity.name!r} "
"with keep=True"
)
agg = result.sel({dim: members}).sum(dim=dim).expand_dims({dim: [group]})
if isinstance(agg, AttrSeries):
# .transpose() is necessary for AttrSeries
agg = agg.transpose(*quantity.dims)
else:
# Restore fill_value=NaN for compatibility
agg = agg._sda.convert()
values.append(agg)
# Reassemble to a single dataarray
result = concat(
*values, **({} if isinstance(quantity, AttrSeries) else {"dim": dim})
)
# Preserve attrs
result.attrs.update(quantity.attrs)
result.name = quantity.name
return result
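# Illustrative use of aggregate(): a sketch where `q` is a hypothetical Quantity
# with an "x" dimension carrying the labels "a", "b", and "c":
#
#     groups = {"x": {"ab": ["a", "b"]}}
#     aggregate(q, groups, keep=False)  # only the group label x="ab", with value q_a + q_b
#     aggregate(q, groups, keep=True)   # the original labels a, b, c plus the new "ab"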
def _unit_args(qty, units):
result = [pint.get_application_registry(), qty.attrs.get("_unit", None)]
return *result, getattr(result[1], "dimensionality", {}), result[0].Unit(units)
def apply_units(qty: Quantity, units: UnitLike) -> Quantity:
"""Apply `units` to `qty`.
If `qty` has existing units…
- …with compatible dimensionality to `units`, the magnitudes are adjusted, i.e.
behaves like :func:`convert_units`.
- …with incompatible dimensionality to `units`, the units attribute is overwritten
and magnitudes are not changed, i.e. like :func:`assign_units`, with a log message
on level ``WARNING``.
To avoid ambiguities between the two cases, use :func:`convert_units` or
:func:`assign_units` instead.
Parameters
----------
units : str or pint.Unit
Units to apply to `qty`.
"""
registry, existing, existing_dims, new_units = _unit_args(qty, units)
if len(existing_dims):
# Some existing dimensions: log a message either way
if existing_dims == new_units.dimensionality:
log.debug(f"Convert '{existing}' to '{new_units}'")
# NB use a factor because pint.Quantity cannot wrap AttrSeries
result = qty * registry.Quantity(1.0, existing).to(new_units).magnitude
else:
log.warning(f"Replace '{existing}' with incompatible '{new_units}'")
result = qty.copy()
else:
# No units, or dimensionless
result = qty.copy()
result.units = new_units
return result
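# Illustrative behaviour of apply_units() for a hypothetical Quantity `q` with
# existing units "km":
#
#     apply_units(q, "m")   # compatible dimensionality: magnitudes * 1000, units become "m"
#     apply_units(q, "kg")  # incompatible: units overwritten, magnitudes unchanged; WARNING logged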
def assign_units(qty: Quantity, units: UnitLike) -> Quantity:
"""Set the `units` of `qty` without changing magnitudes.
Logs on level ``INFO`` if `qty` has existing units.
Parameters
----------
units : str or pint.Unit
Units to assign to `qty`.
"""
registry, existing, existing_dims, new_units = _unit_args(qty, units)
if len(existing_dims):
msg = f"Replace '{existing}' with '{new_units}'"
# Some existing dimensions: log a message either way
if existing_dims == new_units.dimensionality:
# NB use a factor because pint.Quantity cannot wrap AttrSeries
if registry.Quantity(1.0, existing).to(new_units).magnitude != 1.0:
log.info(f"{msg} without altering magnitudes")
else:
log.info(f"{msg} with different dimensionality")
result = qty.copy()
result.units = new_units
return result
def broadcast_map(
quantity: Quantity, map: Quantity, rename: Mapping = {}, strict: bool = False
) -> Quantity:
"""Broadcast `quantity` using a `map`.
The `map` must be a 2-dimensional Quantity with dimensions (``d1``, ``d2``), such as
returned by :func:`map_as_qty`. `quantity` must also have a dimension ``d1``.
Typically ``len(d2) > len(d1)``.
`quantity` is 'broadcast' by multiplying it with `map`, and then summing on the
common dimension ``d1``. The result has the dimensions of `quantity`, but with
``d2`` in place of ``d1``.
Parameters
----------
rename : dict (str -> str), optional
Dimensions to rename on the result.
strict : bool, optional
Require that each element of ``d2`` is mapped from exactly 1 element of ``d1``.
"""
if strict and int(map.sum().item()) != len(map.coords[map.dims[1]]):
raise ValueError("invalid map")
return product(quantity, map).sum([map.dims[0]]).rename(rename)
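# Illustrative use of broadcast_map(): a sketch; constructing `m` from a pandas
# MultiIndex is an assumption about the Quantity/AttrSeries constructor.
#
#     # 2-D map with dims ("x", "y"): x1 maps to both y1 and y2
#     mi = pd.MultiIndex.from_tuples([("x1", "y1"), ("x1", "y2")], names=["x", "y"])
#     m = Quantity(pd.Series([1.0, 1.0], index=mi))
#     q = Quantity(pd.Series([10.0], index=pd.Index(["x1"], name="x")), units="kg")
#     broadcast_map(q, m)  # dims ("y",); value 10.0 at both y1 and y2; units "kg"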
def combine(
*quantities: Quantity,
select: Optional[List[Mapping]] = None,
weights: Optional[List[float]] = None,
) -> Quantity: # noqa: F811
"""Sum distinct *quantities* by *weights*.
Parameters
----------
*quantities : Quantity
The quantities to be added.
select : list of dict
Elements to be selected from each quantity. Must have the same number of
elements as `quantities`.
weights : list of float
Weight applied to each quantity. Must have the same number of elements as
`quantities`.
Raises
------
ValueError
If the *quantities* have mismatched units.
"""
# Handle arguments
if select is None:
select = [{}] * len(quantities)
weights = weights or len(quantities) * [1.0]
# Check units
units = collect_units(*quantities)
for u in units:
# TODO relax this condition: modify the weights with conversion factors if the
# units are compatible, but not the same
if u != units[0]:
raise ValueError(f"Cannot combine() units {units[0]} and {u}")
units = units[0]
args = []
for quantity, indexers, weight in zip(quantities, select, weights):
# Select data
temp = globals()["select"](quantity, indexers)
# Dimensions along which multiple values are selected
multi = [dim for dim, values in indexers.items() if isinstance(values, list)]
if len(multi):
# Sum along these dimensions
temp = temp.sum(dim=multi)
args.append(weight * temp)
result = add(*args)
result.attrs["_unit"] = units
return result
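# Illustrative use of combine(): a 50/50 blend of two hypothetical quantities `q1`
# and `q2` with identical units and dimensions:
#
#     combine(q1, q2, weights=[0.5, 0.5])
#
# With `select`, distinct subsets can be taken from each operand before weighting,
# e.g. select=[dict(x=["a"]), dict(x=["b"])]; list-valued indexers are summed over.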
def concat(*objs: Quantity, **kwargs) -> Quantity:
"""Concatenate Quantity `objs`.
Any strings included amongst `objs` are discarded, with a logged warning; these
usually indicate that a quantity is referenced which is not in the Computer.
"""
objs = tuple(filter_concat_args(objs))
if isinstance(objs[0], AttrSeries):
try:
# Retrieve a "dim" keyword argument
dim = kwargs.pop("dim")
except KeyError:
pass
else:
if isinstance(dim, pd.Index):
# Convert a pd.Index argument to names and keys
kwargs["names"] = [dim.name]
kwargs["keys"] = dim.values
else:
# Something else; warn and discard
log.warning(f"Ignore concat(…, dim={repr(dim)})")
# Ensure objects have aligned dimensions
_objs = [objs[0]]
_objs.extend(
map(lambda o: cast(AttrSeries, o).align_levels(_objs[0]), objs[1:])
)
return pd.concat(_objs, **kwargs)
else:
# Correct fill-values
# NB mypy here cannot tell that the returned DataArray has an accessor ._sda
return xr.concat(
cast(xr.DataArray, objs),
**kwargs,
)._sda.convert() # type: ignore[attr-defined]
def convert_units(qty: Quantity, units: UnitLike) -> Quantity:
"""Convert magnitude of `qty` from its current units to `units`.
Parameters
----------
units : str or pint.Unit
Units to assign to `qty`.
Raises
------
ValueError
if `units` does not match the dimensionality of the current units of `qty`.
"""
registry, existing, existing_dims, new_units = _unit_args(qty, units)
try:
# NB use a factor because pint.Quantity cannot wrap AttrSeries
factor = registry.Quantity(1.0, existing).to(new_units).magnitude
except pint.DimensionalityError:
raise ValueError(
f"Existing dimensionality {existing_dims!r} cannot be converted to {units} "
f"with dimensionality {new_units.dimensionality!r}"
) from None
result = qty * factor
result.units = new_units
return result
def disaggregate_shares(quantity: Quantity, shares: Quantity) -> Quantity:
"""Disaggregate *quantity* by *shares*."""
result = quantity * shares
result.attrs["_unit"] = collect_units(quantity)[0]
return result
def div(numerator: Union[Quantity, float], denominator: Quantity) -> Quantity:
"""Compute the ratio `numerator` / `denominator`.
Parameters
----------
numerator : .Quantity
denominator : .Quantity
"""
numerator = possible_scalar(numerator)
denominator = possible_scalar(denominator)
# Handle units
u_num, u_denom = collect_units(numerator, denominator)
if isinstance(numerator, AttrSeries):
result = unwrap_scalar(numerator) / cast(AttrSeries, denominator).align_levels(
numerator
)
else:
result = numerator / denominator
# This shouldn't be necessary; would instead prefer:
# result.attrs["_unit"] = u_num / u_denom
# … but is necessary to avoid an issue when the operands are different Unit classes
ureg = pint.get_application_registry()
result.attrs["_unit"] = ureg.Unit(u_num) / ureg.Unit(u_denom)
if isinstance(result, AttrSeries):
result.dropna(inplace=True)
return result
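# Illustrative use of div() with hypothetical quantities:
#
#     div(energy, hours)  # e.g. units "GWh" / "hour" give "GWh / hour"
#     div(1.0, q)         # a scalar numerator is wrapped via possible_scalar()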
#: Alias of :func:`div`, for backwards compatibility.
#:
#: .. note:: This may be deprecated and possibly removed in a future version.
ratio = div
def drop_vars(
qty: Quantity,
names: Union[Hashable, Iterable[Hashable]],
*,
errors="raise",
) -> Quantity:
"""Return a Quantity with dropped variables (coordinates).
Like :meth:`xarray.DataArray.drop_vars`.
"""
return qty.drop_vars(names)
def group_sum(qty: Quantity, group: str, sum: str) -> Quantity:
"""Group by dimension *group*, then sum across dimension *sum*.
The result drops the latter dimension.
"""
return concat(
*[values.sum(dim=[sum]) for _, values in qty.groupby(group)],
dim=group,
)
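# Illustrative use of group_sum() for a hypothetical `q` with dims ("g", "s"):
#
#     group_sum(q, group="g", sum="s")  # result has dimension "g" only; each value is
#                                       # the total over "s" within the respective group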
def index_to(
qty: Quantity,
dim_or_selector: Union[str, Mapping],
label: Optional[Hashable] = None,
) -> Quantity:
"""Compute an index of `qty` against certain of its values.
If the label is not provided, :func:`index_to` uses the label in the first position
along the identified dimension.
Parameters
----------
qty : :class:`~genno.Quantity`
dim_or_selector : str or mapping
If a string, the ID of the dimension to index along.
If a mapping, it must have only one element, mapping a dimension ID to a label.
label : Hashable
Label to select along the dimension, required if `dim_or_selector` is a string.
Raises
------
TypeError
if `dim_or_selector` is a mapping with length != 1.
"""
if isinstance(dim_or_selector, Mapping):
if len(dim_or_selector) != 1:
raise TypeError(
f"Got {dim_or_selector}; expected a mapping from 1 key to 1 value"
)
dim, label = dict(dim_or_selector).popitem()
else:
# Unwrap dask.core.literals
dim = getattr(dim_or_selector, "data", dim_or_selector)
label = getattr(label, "data", label)
if label is None:
# Choose a label on which to normalize
label = qty.coords[dim][0].item()
log.info(f"Normalize quantity {qty.name} on {dim}={label}")
return div(qty, qty.sel({dim: label}))
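# Illustrative use of index_to() for a hypothetical `q` with a "y" (year) dimension:
#
#     index_to(q, "y")          # normalize on the first label along "y"
#     index_to(q, "y", 2020)    # normalize on y=2020
#     index_to(q, {"y": 2020})  # equivalent mapping form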
@maybe_densify
def interpolate(
qty: Quantity,
coords: Optional[Mapping[Hashable, Any]] = None,
method: InterpOptions = "linear",
assume_sorted: bool = True,
kwargs: Optional[Mapping[str, Any]] = None,
**coords_kwargs: Any,
) -> Quantity:
"""Interpolate `qty`.
For the meaning of arguments, see :meth:`xarray.DataArray.interp`. When
:data:`.CLASS` is :class:`.AttrSeries`, only 1-dimensional interpolation (one key
in `coords`) is tested/supported.
"""
if assume_sorted is not True:
log.warning(f"interpolate(…, assume_sorted={assume_sorted}) ignored")
return qty.interp(coords, method, assume_sorted, kwargs, **coords_kwargs)
def load_file(
path: Path,
dims: Union[Collection[Hashable], Mapping[Hashable, Hashable]] = {},
units: UnitLike = None,
name: Optional[str] = None,
) -> Any:
"""Read the file at *path* and return its contents as a :class:`.Quantity`.
Some file formats are automatically converted into objects for direct use in genno
computations:
:file:`.csv`:
Converted to :class:`.Quantity`. CSV files must have a 'value' column; all others
are treated as indices, except as given by `dims`. Lines beginning with '#' are
ignored.
Parameters
----------
path : pathlib.Path
Path to the file to read.
dims : collections.abc.Collection or collections.abc.Mapping, optional
If a collection of names, other columns besides these and 'value' are discarded.
If a mapping, the keys are the column labels in `path`, and the values are the
target dimension names.
units : str or pint.Unit
Units to apply to the loaded Quantity.
name : str
Name for the loaded Quantity.
"""
# TODO optionally cache: if the same Computer is used repeatedly, then the file will
# be read each time; instead cache the contents in memory.
# TODO strip leading/trailing whitespace from column names
if path.suffix == ".csv":
return _load_file_csv(path, dims, units, name)
elif path.suffix in (".xls", ".xlsx"):
# TODO define expected Excel data input format
raise NotImplementedError # pragma: no cover
elif path.suffix == ".yaml":
# TODO define expected YAML data input format
raise NotImplementedError # pragma: no cover
else:
# Default
return open(path).read()
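# Illustrative CSV input for load_file() (hypothetical file "example.csv"; the
# "# Units:" header line is matched by UNITS_RE, defined below):
#
#     # Units: kg
#     x, y, value
#     a, p, 1.2
#     a, q, 3.4
#
# load_file(Path("example.csv"), name="mass") then returns a 2-dimensional
# Quantity with dims ("x", "y") and units "kg".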
UNITS_RE = re.compile(r"# Units?: (.*)\s+")
def _load_file_csv(
path: Path,
dims: Union[Collection[Hashable], Mapping[Hashable, Hashable]] = {},
units: UnitLike = None,
name: Optional[str] = None,
) -> Quantity:
# Peek at the header, if any, and match a units expression
with open(path, "r", encoding="utf-8") as f:
for line, match in map(lambda li: (li, UNITS_RE.fullmatch(li)), f):
if match:
if units:
log.warning(f"Replace {match.group(1)!r} from file with {units!r}")
else:
units = match.group(1)
break
elif not line.startswith("#"):
break # Give up at first non-commented line
# Read the data
data = pd.read_csv(path, comment="#", skipinitialspace=True)
# Index columns
index_columns = data.columns.tolist()
index_columns.remove("value")
try:
# Retrieve the unit column from the file
units_col = data.pop("unit").unique()
index_columns.remove("unit")
except KeyError:
pass # No such column; use None or argument value
else:
# Use a unique value for units of the quantity
if len(units_col) > 1:
raise ValueError(
f"Cannot load {path} with non-unique units {repr(units_col)}"
)
elif units and units not in units_col:
raise ValueError(
f"Explicit units {units} do not match {units_col[0]} in {path}"
)
units = units_col[0]
if dims:
# Convert a list, set, etc. to a dict
dims = dims if isinstance(dims, Mapping) else {d: d for d in dims}
# - Drop columns not mentioned in *dims*
# - Rename columns according to *dims*
data = data.drop(columns=set(index_columns) - set(dims.keys())).rename(
columns=dims
)
index_columns = list(data.columns)
index_columns.pop(index_columns.index("value"))
# Prepare a Quantity object with the (bare) units and any conversion factor
registry = pint.get_application_registry()
units = units or "1.0 dimensionless"
if isinstance(units, str):
uq = registry(units)
elif isinstance(units, pint.Unit):
uq = registry.Quantity(1.0, units)
else:
uq = units
return Quantity(
uq.magnitude * data.set_index(index_columns)["value"], units=uq.units, name=name
)
def mul(*quantities: Quantity) -> Quantity:
"""Compute the product of any number of *quantities*."""
# Iterator over (quantity, unit) tuples
items = zip(quantities, collect_units(*quantities))
# Initialize result values with first entry
result, u_result = next(items)
# Iterate over remaining entries
for q, u in items:
if isinstance(q, AttrSeries):
# Work around pandas-dev/pandas#25760; see attrseries.py
result = (result * q.align_levels(result)).dropna()
else:
result = result * q
u_result *= u
result.attrs["_unit"] = u_result
return result
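# Illustrative use of mul() with hypothetical quantities: units are multiplied
# along with the magnitudes, e.g.
#
#     mul(mass, distance)  # units "kg" * "m" give "kg * m"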
#: Alias of :func:`mul`, for backwards compatibility.
#:
#: .. note:: This may be deprecated and possibly removed in a future version.
product = mul
def pow(a: Quantity, b: Union[Quantity, int]) -> Quantity:
"""Compute `a` raised to the power of `b`.
.. todo:: Provide units on the result in the special case where `b` is a Quantity
but all its values are the same :class:`int`.
Returns
-------
.Quantity
If `b` is :class:`int`, then the quantity has the units of `a` raised to this
power; e.g. "kg²" → "kg⁴" if `b` is 2. In other cases, there are no meaningful
units, so the returned quantity is dimensionless.
"""
if isinstance(b, int):
unit_exponent = b
b = Quantity(float(b))
else:
unit_exponent = 0
u_a, u_b = collect_units(a, b)
if not u_b.dimensionless:
raise ValueError(f"Cannot raise to a power with units ({u_b:~})")
if isinstance(a, AttrSeries):
result = a ** cast(AttrSeries, b).align_levels(a)
else:
result = a**b
result.attrs["_unit"] = (
a.attrs["_unit"] ** unit_exponent
if unit_exponent
else pint.get_application_registry().dimensionless
)
return result
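# Illustrative use of pow() for a hypothetical `q` with units "kg":
#
#     pow(q, 2)              # magnitudes squared; units "kg ** 2"
#     pow(q, Quantity(2.0))  # same magnitudes, but a dimensionless result (see the todo above)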
def relabel(
qty: Quantity,
labels: Optional[Mapping[Hashable, Mapping]] = None,
**dim_labels: Mapping,
) -> Quantity:
"""Replace specific labels along dimensions of `qty`.
Parameters
----------
labels :
Keys are strings identifying dimensions of `qty`; values are further mappings
from original labels to new labels. Dimensions and labels not appearing in `qty`
have no effect.
dim_labels :
Mappings given as keyword arguments, where argument name is the dimension.
Raises
------
ValueError
if both `labels` and `dim_labels` are given.
"""
# NB pandas uses the term "levels [of a MultiIndex]"; xarray uses "coords [for a
# dimension]".
# TODO accept callables as values in `mapper`, as DataArray.assign_coords() does
maps = either_dict_or_kwargs(labels, dim_labels, "relabel")
# Iterate over (dim, label_map) for only dims included in `qty`
iter = filter(lambda kv: kv[0] in qty.dims, maps.items())
def map_labels(mapper, values):
"""Generate the new labels for a single dimension."""
return list(map(lambda label: mapper.get(label, label), values))
if isinstance(qty, AttrSeries):
# Prepare a new index
idx = _multiindex_of(qty)
for dim, label_map in iter:
# - Look up numerical index of the dimension in `idx`
# - Retrieve the existing levels.
# - Map to new levels.
# - Assign, creating a new index
idx = idx.set_levels(
map_labels(label_map, idx.levels[idx.names.index(dim)]), level=dim
)
# Assign the new index to a copy of qty
result = cast(AttrSeries, qty.copy())
result.index = idx
return result
else:
return cast(SparseDataArray, qty).assign_coords(
{dim: map_labels(m, qty.coords[dim].data) for dim, m in iter}
)
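# Illustrative use of relabel() for a hypothetical `q` with an "x" dimension:
#
#     relabel(q, {"x": {"a": "a-new"}})  # mapping form
#     relabel(q, x={"a": "a-new"})       # equivalent keyword form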
def rename_dims(
qty: Quantity,
    new_name_or_name_dict: Optional[Union[Hashable, Mapping[Any, Hashable]]] = None,
**names: Hashable,
) -> Quantity:
"""Rename the dimensions of `qty`.
Like :meth:`xarray.DataArray.rename`.
"""
return qty.rename(new_name_or_name_dict, **names)
def round(qty: Quantity, *args, **kwargs) -> Quantity:
"""Like :meth:`xarray.DataArray.round`."""
return qty.round(*args, **kwargs)
def select(
qty: Quantity,
indexers: Mapping[Hashable, Iterable[Hashable]],
*,
inverse: bool = False,
drop: bool = False,
) -> Quantity:
"""Select from `qty` based on `indexers`.
Parameters
----------
indexers : dict (str -> xarray.DataArray or list of str)
Elements to be selected from `qty`. Mapping from dimension names to coords along
the respective dimension of `qty`, or to xarray-style indexers. Values not
appearing in the dimension coords are silently ignored.
inverse : bool, optional
If :obj:`True`, *remove* the items in indexers instead of keeping them.
"""
# Identify the type of the first value in `indexers`
_t = type(next(chain(iter(indexers.values()), [None])))
if _t is xr.DataArray:
if inverse:
raise NotImplementedError("select(…, inverse=True) with DataArray indexers")
# Pass through
new_indexers = indexers
else:
# Predicate for containment
op = operator.not_ if inverse else operator.truth
# Use only the values from `indexers` (not) appearing in `qty.coords`
coords = qty.coords
new_indexers = {
dim: list(filter(lambda x: op(x in labels), coords[dim].data))
for dim, labels in indexers.items()
}
return qty.sel(new_indexers, drop=drop)
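# Illustrative use of select() for a hypothetical `q` with an "x" dimension:
#
#     select(q, {"x": ["a", "b"]})           # keep only x=a and x=b
#     select(q, {"x": ["a"]}, inverse=True)  # keep every label except x=a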
def sum(
quantity: Quantity,
weights: Optional[Quantity] = None,
dimensions: Optional[List[str]] = None,
) -> Quantity:
"""Sum *quantity* over *dimensions*, with optional *weights*.
Parameters
----------
weights : .Quantity, optional
If *dimensions* is given, *weights* must have at least these
dimensions. Otherwise, any dimensions are valid.
dimensions : list of str, optional
If not provided, sum over all dimensions. If provided, sum over these
dimensions.
"""
if weights is None:
_w = Quantity(1.0)
w_total = Quantity(1.0)
else:
_w = weights
w_total = weights.sum(dim=dimensions)
if 0 == len(w_total.dims):
w_total = w_total.item()
return div(mul(quantity, _w).sum(dim=dimensions), w_total)
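# Illustrative use of sum() for a hypothetical `q` with dims ("x", "y") and weights
# `w` with at least dim "y". Note that when weights are given, the result is divided
# by the total weight, i.e. a weighted average over the summed dimensions:
#
#     sum(q, dimensions=["y"])             # plain sum over "y"
#     sum(q, weights=w, dimensions=["y"])  # sum of q*w over "y", divided by sum of w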
def write_report(quantity: Quantity, path: Union[str, PathLike]) -> None:
"""Write a quantity to a file.
Parameters
----------
path : str or Path
Path to the file to be written.
"""
path = Path(path)
if path.suffix == ".csv":
quantity.to_dataframe().to_csv(path)
elif path.suffix == ".xlsx":
quantity.to_dataframe().to_excel(path, merge_cells=False)
else:
path.write_text(quantity) # type: ignore