import operator
from abc import abstractmethod
from numbers import Number
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    Hashable,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import numpy as np
import pandas as pd
import pint

if TYPE_CHECKING:
    from genno.types import Unit

    from .quantity import AnyQuantity
[docs]
class UnitsMixIn:
    """Object with :attr:`.units` and :meth:`._binary_op_units`.

    Units are stored in :attr:`attrs` under the key ``"_unit"``.
    """

    # Metadata mapping; only declared here, so the concrete class must supply it.
    attrs: Dict[Hashable, Any]

    @property
    def units(self):
        """Retrieve or set the units of the Quantity.

        If no units are stored yet, the dimensionless unit of the pint application
        registry is stored in :attr:`attrs` and returned.

        Examples
        --------
        Create a quantity without units:

        >>> qty = Quantity(...)

        Set using a string; automatically converted to pint.Unit:

        >>> qty.units = "kg"
        >>> qty.units
        <Unit('kilogram')>
        """
        stored = self.attrs.setdefault(
            "_unit", pint.get_application_registry().dimensionless
        )
        return stored

    @units.setter
    def units(self, value: Union["Unit", str]) -> None:
        # Always store an instance of the application registry's Unit class
        unit = pint.get_application_registry().Unit(value)
        self.attrs["_unit"] = unit

    def _binary_op_units(
        self, other: "UnitsMixIn", op, swap: bool
    ) -> Tuple["Unit", float]:
        """Determine result units for a binary operation between `self` and `other`.

        Returns:

        1. Result units.
        2. For rank-1 operations ('add', 'radd', 'rsub', 'sub'), a factor which,
           multiplied by magnitudes of `other`, makes them compatible with the units
           of `self`. For operations of any other rank, NaN.

        Raises :class:`ValueError` for a power operation whose exponent has
        non-dimensionless units.
        """
        other_units = other.units
        own_units = self.units

        # Mixing pint.Unit and pint.registry.Unit instances throws off pint's internal
        # logic, so coerce the units of `other` to the class of those of `self`
        unit_cls = own_units.__class__
        if other_units.__class__ is not unit_cls:
            other_units = unit_cls(other_units)

        op_rank = rank(op)

        if op_rank == 1:
            # Additive: result keeps the units of `self`; `other` must be converted
            factor = pint.Quantity(1.0, other_units).to(own_units).magnitude
            return own_units, factor

        if op_rank == 2:
            # Multiplicative: let pint combine the units, respecting operand order
            operands = (other_units, own_units) if swap else (own_units, other_units)
            return op(*operands), np.nan

        # Power: identify the exponent (with its units) and the base units
        if swap:
            exponent, exponent_units, base_units = self, own_units, other_units
        else:
            exponent, exponent_units, base_units = other, other_units, own_units

        if not exponent_units.dimensionless:
            raise ValueError(
                f"Cannot raise to a power with units {exponent_units:~}"
            )

        # Dense values of the exponent
        values = cast("AnyQuantity", exponent).to_series().values
        # Fractional part of each exponent; equal to {0} iff all are integer-valued
        remainders = set(np.mod(values, 1))
        distinct = np.unique(values)

        if remainders == {0.0} and len(distinct) == 1:
            # A single integer exponent applies everywhere: raise the base units to it
            return op(base_units, distinct[0]), np.nan

        return pint.get_application_registry().dimensionless, np.nan
[docs]
def make_binary_op(op, *, swap: bool):
    """Create a method implementing the binary operator `op`.

    Parameters
    ----------
    op :
        The operator function; one of those handled by :func:`rank`:
        :func:`operator.add`, :func:`operator.sub`, :func:`operator.mul`,
        :func:`operator.truediv`, or :func:`operator.pow`.
    swap : bool
        :obj:`True` for a reflected method such as ``__radd__``, in which the object
        the method is bound to is the *right* operand.

    Returns
    -------
    callable
        ``method(obj, other)``, suitable for use as a special method of a
        :class:`BaseQuantity` subclass. It raises :class:`TypeError` if `other` is
        neither a number nor an instance of the same class as `obj`.
    """

    def method(obj: "BaseQuantity", other: "BaseQuantity"):
        scalar_other = False
        if isinstance(other, Number):
            # Wrap a plain number in a Quantity of the same class as `obj`
            other = type(obj)(other)
            scalar_other = True
            if rank(op) == 1:
                # A number added to or subtracted from `obj` is interpreted in the
                # units of `obj`. Left dimensionless, the unit conversion done in
                # prepare_binary_op() would fail whenever `obj` has units.
                other.units = obj.units
        elif not (
            isinstance(other, type(obj))
            or getattr(other, "__thisclass__", None) is type(obj)  # super()
        ):
            raise TypeError(type(other))

        left, right, result_units, factor = prepare_binary_op(obj, other, op, swap)

        # Name and non-unit attrs of `obj` are kept only when `other` was a scalar.
        # If `other` was scalar and the operation is rank-1 (add, sub, etc.), the units
        # of `obj` carry to the result. Otherwise, use `result_units`.
        return obj._keep(
            obj._perform_binary_op(op, left, right, factor),
            name=scalar_other,
            attrs=scalar_other,
            units=obj.units if (scalar_other and rank(op) == 1) else result_units,
        )

    return method
# Type of the operands and result of BinaryOpsMixIn._perform_binary_op().
T = TypeVar("T")
[docs]
class BinaryOpsMixIn(Generic[T]):
    """Binary operations for :class:`Quantity`.

    Subclasses **must** implement :meth:`_perform_binary_op`.

    Several binary operations are provided with methods that:

    - Convert scalar operands to :class:`.Quantity`.
    - Determine result units.
    - Preserve name and non-unit attrs.
    """

    __add__ = make_binary_op(operator.add, swap=False)
    __mul__ = make_binary_op(operator.mul, swap=False)
    __pow__ = make_binary_op(operator.pow, swap=False)
    __radd__ = make_binary_op(operator.add, swap=True)
    __rmul__ = make_binary_op(operator.mul, swap=True)
    __rpow__ = make_binary_op(operator.pow, swap=True)
    __rsub__ = make_binary_op(operator.sub, swap=True)
    __rtruediv__ = make_binary_op(operator.truediv, swap=True)
    __sub__ = make_binary_op(operator.sub, swap=False)
    __truediv__ = make_binary_op(operator.truediv, swap=False)

    @staticmethod
    @abstractmethod
    def _perform_binary_op(
        name: Callable[[Any, Any], Any], left: T, right: T, factor: float
    ) -> T:
        """Compute ``name(left, right)`` on the underlying data.

        Parameters
        ----------
        name :
            The operator *function*, for instance :func:`operator.add`. Despite the
            parameter name, this is a callable, not a string.
        left, right :
            The operands, already ordered for reflected operations. For rank-1
            (additive) operations, the operand derived from `other` has already been
            multiplied by `factor`.
        factor :
            The unit-alignment factor; NaN for operations that are not rank-1.
        """
[docs]
class BaseQuantity(
    BinaryOpsMixIn,
    UnitsMixIn,
):
    """Common base for a class that behaves like :class:`xarray.DataArray`.

    The class has units and unit-aware binary operations.
    """

    name: Optional[Hashable]

    @abstractmethod
    def __init__(
        self,
        data: Any = None,
        coords: Union[Sequence[Tuple], Mapping[Hashable, Any], None] = None,
        dims: Union[str, Sequence[Hashable], None] = None,
        name: Hashable = None,
        attrs: Optional[Mapping] = None,
        # internal parameters
        indexes: Optional[Dict[Hashable, pd.Index]] = None,
        fastpath: bool = False,
        **kwargs,
    ): ...

    def _keep(
        self,
        target: "AnyQuantity",
        attrs: Optional[Any] = False,
        name: Optional[Any] = False,
        units: Optional[Any] = False,
    ) -> "AnyQuantity":
        """Preserve `name`, `units`, and/or other `attrs` from `self` to `target`.

        The action for each argument is:

        - :any:`False`: don't keep.
        - :any:`True`: keep the existing value from `self`. For `attrs`, this copies
          *all* of ``self.attrs``, including any stored units.
        - Anything else: use this value. For `name` and `units`, it is assigned; for
          `attrs`, it must be a mapping, whose items are merged into ``target.attrs``.

        `target` is modified in place and returned.

        Raises
        ------
        TypeError
            if `attrs` is neither :any:`True`/:any:`False` nor a mapping.
        """
        if name is not False:
            target.name = self.name if name is True else name

        if attrs is True:
            target.attrs.update(self.attrs)
        elif attrs is not False:
            # Explicit check rather than `assert`, which is stripped under `python -O`
            if not isinstance(attrs, Mapping):
                raise TypeError(f"attrs must be bool or Mapping; got {type(attrs)}")
            target.attrs.update(attrs)

        if units is not False:
            # Only units; not other attrs
            target.units = self.units if units is True else units

        return target
[docs]
def prepare_binary_op(
    obj: BaseQuantity, other, op, swap: bool
) -> Tuple[BaseQuantity, BaseQuantity, "Unit", float]:
    """Prepare inputs for a binary operation.

    Returns:

    1. The left operand (`obj` if `swap` is False else `other`).
    2. The right operand (`other` if `swap` is False else `obj`).
    3. Units for the result. In additive operations, the units of `obj` take precedence.
    4. Any scaling factor needed to make units of `other` compatible with `obj`.

    For rank-1 (additive) operations with a factor other than 1, the operand derived
    from `other` is scaled by the factor. When `swap` is True, this is the *left*
    operand.
    """
    result_units, factor = obj._binary_op_units(other, op, swap)

    needs_scaling = rank(op) == 1 and factor != 1.0
    if needs_scaling:
        # Multiply via the class after type(obj) in the MRO of `other`.
        # NOTE(review): presumably this avoids re-entering the unit-aware
        # make_binary_op() methods for this step; confirm against concrete classes.
        other = super(type(obj), other).__mul__(factor)

    # For __r*__ methods, `other` becomes the left operand
    if swap:
        return other, obj, result_units, factor
    return obj, other, result_units, factor
[docs]
def collect_attrs(
    data, attrs_arg: Optional[Mapping], kwargs: MutableMapping
) -> MutableMapping:
    """Handle `attrs` and 'units' `kwargs` to Quantity constructors.

    Parameters
    ----------
    data :
        Constructor data. If it has an ``attrs`` attribute, a copy of it is the
        starting point.
    attrs_arg :
        Explicit `attrs` argument, if any. Its items override those from `data`.
    kwargs :
        Remaining constructor keyword arguments. A "units" entry, if present, is
        **removed** from this mapping.

    Returns
    -------
    MutableMapping
        The new attrs. If "units" was given and not :obj:`None`, they are stored under
        the key ``"_unit"``.
    """
    # Use attrs, if any, from an existing object, if any; copy so `data` is untouched
    new_attrs = getattr(data, "attrs", dict()).copy()

    # Overwrite with values from an explicit attrs argument
    new_attrs.update(attrs_arg or dict())

    # Store the "units" keyword argument as an attr
    units = kwargs.pop("units", None)
    if units is not None:
        # Use the application registry's Unit class, as the UnitsMixIn.units setter
        # does, so this path does not produce a mix of pint.Unit and registry Unit
        # instances
        new_attrs["_unit"] = pint.get_application_registry().Unit(units)

    return new_attrs
[docs]
# Rank of each supported binary operator function; consulted by rank()
_OP_RANK = {
    operator.add: 1,
    operator.sub: 1,
    operator.mul: 2,
    operator.truediv: 2,
    operator.pow: 3,
}


def rank(op) -> int:
    """Rank of the binary operation `op`.

    See `‘Hyperoperation’ on Wikipedia <https://en.wikipedia.org/wiki/Hyperoperation>`_
    for the sense of this meaning of ‘rank’: addition and subtraction are rank 1,
    multiplication and division rank 2, and exponentiation rank 3.

    Raises :class:`KeyError` for any other `op`.
    """
    return _OP_RANK[op]
[docs]
def single_column_df(data, name: Hashable) -> Tuple[Any, Hashable]:
    """Handle `data` and `name` arguments to Quantity constructors.

    If `data` is a :class:`pandas.DataFrame` with exactly one column, that column is
    returned, together with `name`; if `name` is :obj:`None`, the column label is used
    instead. Any other `data` is returned unchanged, together with `name`.

    Raises
    ------
    TypeError
        if `data` is a data frame with zero or more than one column.
    """
    if not isinstance(data, pd.DataFrame):
        return data, name

    if len(data.columns) != 1:
        raise TypeError(
            f"Cannot instantiate Quantity from {len(data.columns)}-D data frame"
        )

    # Unpack the single column; use its label if not overridden by `name`. Compare
    # with None rather than testing truthiness, so explicit falsy names such as 0 or
    # "" are respected.
    return data.iloc[:, 0], (data.columns[0] if name is None else name)