# Source code for genno.testing

import contextlib
import logging
import sys
from copy import copy
from functools import partial
from itertools import chain, islice
from typing import Dict

import numpy as np
import pandas as pd
import pint
import pytest
import xarray as xr
from dask.core import quote
from pandas.testing import assert_series_equal

import genno.core.quantity
from genno import ComputationError, Computer, Key, Quantity
from genno.core.sparsedataarray import HAS_SPARSE

# Module-level logger, named after this module
log = logging.getLogger(__name__)

# Bind the name `importlib_resources` to exactly one implementation. Comparing the
# whole version tuple (rather than only `.minor`) stays correct across major versions.
if sys.version_info >= (3, 10):
    import importlib.resources as importlib_resources
else:
    # Use the backport to get identical behaviour
    import importlib_resources  # type: ignore [no-redef]

# Pytest hooks

def pytest_runtest_makereport(item, call):
    """Pytest hook to unwrap :class:`genno.ComputationError`.

    This allows to "xfail" tests more precisely on the underlying exception, rather
    than the ComputationError which wraps it.

    Parameters
    ----------
    item
        The test item; only its :meth:`iter_markers` method is used.
    call
        Information about the current test phase. Its :attr:`when` and
        :attr:`excinfo` attributes are read; :attr:`excinfo` is replaced when a
        matching "xfail" marker is found.

    Returns
    -------
    pytest.TestReport or None
        A report built from `item` and `call` when the "call" phase raised
        ComputationError; otherwise :obj:`None`.
    """
    # NOTE(review): the original indentation was lost; the final `return` is assumed
    # to belong inside this `if` block — confirm against upstream.
    if call.when == "call" and getattr(call.excinfo, "type", None) is ComputationError:
        # Retrieve the Exception wrapped by ComputationError
        e = call.excinfo.value.args[0]

        # Look for an "xfail" marker whose raises= class(es) match `e`
        for mark in filter(
            lambda m: m.name == "xfail" and isinstance(e, m.kwargs.get("raises", ())),
            item.iter_markers(),
        ):
            # Change the ExceptionInfo to describe `e`, which will match this mark
            # and produce an "xfail" report
            call.excinfo = pytest.ExceptionInfo(
                excinfo=(type(e), e, e.__traceback__), _ispytest=True
            )

        # Generate and return the report
        return pytest.TestReport.from_item_and_call(item, call)
def add_large_data(c: Computer, num_params, N_dims=6, N_data=0):
    """Add nodes to `c` that return large-ish data.

    The result is a matrix wherein the Cartesian product of all the keys is very
    large—about 2e17 elements for `N_dims` = 6—but the contents are very sparse. This
    can be handled by :class:`.SparseDataArray`, but not by :class:`xarray.DataArray`
    backed by :class:`np.array`.

    Parameters
    ----------
    c : Computer
        Computer to which the nodes are added.
    num_params : int
        Number of quantities to add, with keys named ``q_00``, ``q_01``, ….
    N_dims : int, optional
        Number of dimensions, named "a", "b", … in order; must be between 1 and 26.
    N_data : int, optional
        Number of data points per quantity. The length of the longest dimension is
        used if that is larger.

    Returns
    -------
    list of Key
        The keys added for the quantities, in order.

    Raises
    ------
    ValueError
        If `N_dims` is not between 1 and 26.
    """
    # One single-letter name is available per dimension
    if not 1 <= N_dims <= 26:
        raise ValueError(f"N_dims must be between 1 and 26; got {N_dims!r}")

    def _fib():
        """Yield dimensions and their lengths: Fibonacci numbers."""
        a, b = 233, 377
        dim_names = iter("abcdefghijklmnopqrstuvwxyz")
        yield next(dim_names), a
        while True:
            yield next(dim_names), b
            a, b = b, a + b

    # Dimensions and their lengths
    dims, sizes = zip(*islice(_fib(), N_dims))

    # Number of data points to generate
    N_data = max(int(N_data), sizes[-1])

    # commented; for debugging
    # # Output something like "True: 2584 values / 2.182437e+17 = 1.184e-12% full"
    # from math import prod
    #
    # total = prod(sizes)
    #
    # # Total elements must be less than the maximum value of np.intp
    # log.info(
    #     repr(total < np.iinfo(np.intp).max)
    #     + f": {max(sizes)} values / {total:3e} = {100 * max(sizes) / total:.3e}% full"
    # )

    # Labels like a_00000, a_00001, … along each dimension
    dtypes = {"value": float}
    for d, N in zip(dims, sizes):
        categories = [f"{d}_{i:05d}" for i in range(N)]
        # Add to Computer
        c.add(d, quote(categories))
        # Create a categorical dtype
        dtypes[d] = pd.CategoricalDtype(categories)

    # Random generator
    rng = np.random.default_rng()

    def get_large_quantity(name):
        """Make a Quantity of `N_data` random values with randomly drawn labels."""
        log.info(f"{N_data} values")

        # Allocate memory for the data frame using the given data types
        df = pd.DataFrame(
            index=pd.RangeIndex(N_data), columns=list(dims) + ["value"]
        ).astype(dtypes)

        # Fill values
        df.loc[:, "value"] = rng.random(N_data)

        # Fill labels
        for d in dims:
            df[d] = pd.Categorical.from_codes(
                rng.integers(0, len(dtypes[d].categories), N_data), dtype=dtypes[d]
            )

        return Quantity(
            df.set_index(list(dims)),
            units=pint.get_application_registry().kilogram,
            name=name,
        )

    # Fill the Computer with quantities named q_00, q_01, …
    keys = []
    for i in range(num_params):
        key = Key(f"q_{i:02d}", dims)
        c.add(key, (partial(get_large_quantity, key),))
        keys.append(key)

    return keys
def add_test_data(c: Computer):
    """Add test sets "t" and "y" and a random quantity "x" to the Computer `c`.

    Returns
    -------
    tuple
        ``(t, t_foo, t_bar, x)``: the full list of "t" labels, its "foo" and "bar"
        sub-lists, and the Quantity with dimensions (t, y).
    """
    # TODO combine with add_dantzig(), below
    # New sets
    t_foo = [f"foo{i}" for i in (1, 2, 3)]
    t_bar = [f"bar{i}" for i in (4, 5, 6)]
    t = t_foo + t_bar
    y = list(range(2000, 2051, 10))

    # Add to Computer
    c.add("t", quote(t))
    c.add("y", quote(y))

    # Data
    ureg = pint.get_application_registry()
    x = Quantity(
        xr.DataArray(np.random.rand(len(t), len(y)), coords=[("t", t), ("y", y)]),
        # NOTE(review): this argument was garbled in the source; `ureg.kg` is
        # reconstructed from the otherwise-unused `ureg` — confirm against upstream.
        units=ureg.kg,
        name="Quantity X",
    )

    # Add, including sums and to index
    c.add(Key("x", ("t", "y")), Quantity(x), sums=True)

    return t, t_foo, t_bar, x
# Labels for the "i" and "j" dimensions of the test data
_i = ["seattle", "san-diego"]
_j = ["new-york", "chicago", "topeka"]

# Mapping from Key to a tuple of (value, units expression)
_TEST_DATA = {
    Key(k): data
    for k, data in {
        "a:i": (xr.DataArray([350, 600], coords=[("i", _i)]), "cases"),
        # The dimension name must match the key ("b:j") and the `_j` labels
        "b:j": (xr.DataArray([325, 300, 275], coords=[("j", _j)]), "cases"),
        "d:i-j": (
            xr.DataArray(
                [[2.5, 1.7, 1.8], [2.5, 1.8, 1.4]], coords=[("i", _i), ("j", _j)]
            ),
            "km",
        ),
        "f:": (90.0, "USD/km"),
        # TODO complete the following
        # Decision variables and equations
        "x:i-j": (
            xr.DataArray([[0, 0, 0], [0, 0, 0]], coords=[("i", _i), ("j", _j)]),
            "cases",
        ),
        "z:": (0, "cases"),
        "cost:": (0, "USD"),
        "cost-margin:": (0, "USD"),
        "demand:j": (xr.DataArray([0, 0, 0], coords=[("j", _j)]), "cases"),
        "demand-margin:j": (xr.DataArray([0, 0, 0], coords=[("j", _j)]), "cases"),
        "supply:i": (xr.DataArray([0, 0], coords=[("i", _i)]), "cases"),
        "supply-margin:i": (xr.DataArray([0, 0], coords=[("i", _i)]), "cases"),
    }.items()
}
def get_test_quantity(key: Key) -> Quantity:
    """Computation that returns test data.

    Parameters
    ----------
    key : Key
        One of the keys of ``_TEST_DATA``.

    Returns
    -------
    Quantity
        The value stored for `key`, with the stored units.

    Raises
    ------
    KeyError
        If `key` is not in ``_TEST_DATA``.
    """
    value, unit = _TEST_DATA[key]
    # NOTE(review): the middle argument was garbled in the source; `name=key.name`
    # is a reconstruction — confirm against upstream.
    return Quantity(value, name=key.name, units=unit)
def add_dantzig(c: Computer):
    """Add contents analogous to the ixmp Dantzig scenario.

    The sets "i" and "j" are added first, then one computation (with sums) per key in
    ``_TEST_DATA``, and finally a node "all" holding the sorted list of those keys.
    """
    # Sets
    for set_name, members in (("i", _i), ("j", _j)):
        c.add(set_name, quote(members))

    # One node per test quantity, each deferring to get_test_quantity()
    added = []
    for key in _TEST_DATA:
        c.add(key, (partial(get_test_quantity, key),), sums=True)
        added.append(key)

    c.add("all", sorted(added))
@contextlib.contextmanager
def assert_logs(caplog, message_or_messages=None, at_level=None):
    """Assert that *message_or_messages* appear in logs.

    Use assert_logs as a context manager for a statement that is expected to trigger
    certain log messages. assert_logs checks that these messages are generated.

    Derived from :func:`ixmp.testing.assert_logs`.

    Example
    -------
    >>> def test_foo(caplog):
    ...     with assert_logs(caplog, 'a message'):
    ...         logging.getLogger(__name__).info('this is a message!')

    Parameters
    ----------
    caplog : object
        The pytest caplog fixture.
    message_or_messages : str or list of str, optional
        String(s) that must appear in log messages. If :obj:`None` (the default),
        no messages are required.
    at_level : int, optional
        Messages must appear on 'genno' or a sub-logger with at least this level.
    """
    __tracebackhide__ = True

    # Normalize to a list: nothing expected, a single string, or several strings.
    # Copying into a list also allows a one-shot iterable to be traversed safely.
    if message_or_messages is None:
        expected = []
    elif isinstance(message_or_messages, str):
        expected = [message_or_messages]
    else:
        expected = list(message_or_messages)

    # Record the number of records prior to the managed block
    first = len(caplog.records)

    if at_level is not None:
        # Use the pytest caplog fixture's built-in context manager to temporarily set
        # the level of the logger for the whole package (parent of the current module)
        ctx = caplog.at_level(at_level, logger=__name__.split(".")[0])
    else:
        # ctx does nothing
        ctx = contextlib.nullcontext()

    try:
        with ctx:
            yield  # Nothing provided to the managed block
    finally:
        # Messages logged since the managed block was entered
        messages = caplog.messages[first:]

        # Expected strings that are not contained in any of those messages
        missing = [e for e in expected if not any(e in msg for msg in messages)]

        if missing:
            # Format a description of the missing messages
            lines = chain(
                ["Did not log:"],
                [f" {repr(msg)}" for msg in missing],
                ["among:"],
                [" []"]
                if len(caplog.records) == first
                else [f" {repr(msg)}" for msg in messages],
            )
            pytest.fail("\n".join(lines))
def assert_qty_equal(
    a,
    b,
    check_type: bool = True,
    check_attrs: bool = True,
    ignore_extra_coords: bool = False,
    **kwargs,
):
    """Assert that objects `a` and `b` are equal.

    Parameters
    ----------
    check_type : bool, optional
        Assert that `a` and `b` are both :class:`.Quantity` instances. If
        :obj:`False`, the arguments are converted to Quantity.
    check_attrs : bool, optional
        Also assert that attributes are identical.
    ignore_extra_coords : bool, optional
        Ignore extra coords that are not dimensions. Only meaningful when Quantity is
        :class:`.SparseDataArray`.
    **kwargs
        Passed to :func:`pandas.testing.assert_series_equal` when Quantity is
        AttrSeries; otherwise none may be given.
    """
    __tracebackhide__ = True

    try:
        assert type(a) is type(b) and type(a).__name__ == genno.core.quantity.CLASS
    except AssertionError:
        if check_type:
            raise
        # Convert both arguments to Quantity
        a, b = Quantity(a), Quantity(b)

    if genno.core.quantity.CLASS == "AttrSeries":
        try:
            a = a.sort_index().dropna()
            b = b.sort_index().dropna()
        except TypeError:  # pragma: no cover
            pass

        assert_series_equal(a, b, check_dtype=False, **kwargs)
    else:
        from xarray.testing import assert_equal

        def _dims_only(qty):
            """Drop from `qty` the coords that are not dimensions."""
            extra = set(qty.coords.keys()) - set(qty.dims)
            return qty.reset_coords(extra, drop=True)

        if ignore_extra_coords:
            a = _dims_only(a)
            b = _dims_only(b)

        # No keyword arguments are supported in this branch
        assert len(kwargs) == 0
        assert_equal(a._sda.dense, b._sda.dense)

    # Check attributes are equal
    if check_attrs:
        assert a.attrs == b.attrs
def assert_qty_allclose(
    a,
    b,
    check_type: bool = True,
    check_attrs: bool = True,
    ignore_extra_coords: bool = False,
    **kwargs,
):
    """Assert that objects `a` and `b` have numerically close values.

    Parameters
    ----------
    check_type : bool, optional
        Assert that `a` and `b` are both :class:`.Quantity` instances. If
        :obj:`False`, the arguments are converted to Quantity.
    check_attrs : bool, optional
        Also assert that attributes are identical.
    ignore_extra_coords : bool, optional
        Ignore extra coords that are not dimensions. Only meaningful when Quantity is
        :class:`.SparseDataArray`.
    **kwargs
        Passed to :func:`pandas.testing.assert_series_equal` when Quantity is
        AttrSeries, or else to :func:`xarray.testing.assert_allclose` with any
        "check_dtype" entry removed.
    """
    __tracebackhide__ = True

    try:
        assert type(a) is type(b) and type(a).__name__ == genno.core.quantity.CLASS
    except AssertionError:
        if check_type:
            raise
        # Convert both arguments to Quantity
        a, b = Quantity(a), Quantity(b)

    if genno.core.quantity.CLASS == "AttrSeries":
        assert_series_equal(a.sort_index(), b.sort_index(), **kwargs)
    else:
        from xarray.testing import assert_allclose

        def _dims_only(qty):
            """Drop from `qty` the coords that are not dimensions."""
            extra = set(qty.coords.keys()) - set(qty.dims)
            return qty.reset_coords(extra, drop=True)

        if ignore_extra_coords:
            a = _dims_only(a)
            b = _dims_only(b)

        # Remove a kwarg not recognized by the xarray function
        kwargs.pop("check_dtype", None)
        assert_allclose(a._sda.dense, b._sda.dense, **kwargs)

    # Check attributes are equal
    if check_attrs:
        assert a.attrs == b.attrs
def assert_units(qty: Quantity, exp: str) -> None:
    """Assert that `qty` has units `exp`.

    `exp` is parsed with the registry of the units of `qty`; the check passes when
    the ratio of the two is dimensionless.
    """
    units = qty.units
    ratio = units / units._REGISTRY(exp)
    assert ratio.dimensionless, f"Units '{qty.units:~}'; expected {repr(exp)}"
def random_qty(shape: Dict[str, int], **kwargs):
    """Return a Quantity with `shape` and random contents.

    Parameters
    ----------
    shape : dict (str -> int)
        Mapping from dimension names to lengths along each dimension.
    **kwargs
        Other keyword arguments to :class:`Quantity`.

    Returns
    -------
    Quantity
        Random data with one dimension for each key in `shape`, and coords along
        those dimensions like "foo0", "foo1", with total length matching the value
        from `shape`. If `shape` is empty, a scalar (0-dimensional) Quantity.
    """
    if len(shape):
        values = np.random.rand(*shape.values())
    else:
        # Scalar: a single random number
        values = np.random.rand(1)[0]

    # One (name, labels) pair per dimension
    coords = []
    for dim, length in shape.items():
        labels = [f"{dim}{i}" for i in range(length)]
        coords.append((dim, labels))

    return Quantity(xr.DataArray(values, coords=coords), **kwargs)
# Fixtures
@pytest.fixture(scope="session")
def test_data_path():
    """Path to the directory containing test data."""
    # NOTE(review): the package name was lost from the source, which showed an empty
    # string; "genno.tests.data" is a reconstruction — confirm against upstream.
    return importlib_resources.files("genno.tests.data")
@pytest.fixture(scope="session")
def ureg():
    """Application-wide units registry.

    The units "USD" and "case" are defined, each with its own dimension, unless the
    registry reports them as already defined.
    """
    application_registry = pint.get_application_registry()

    # Used by .compat.ixmp, .compat.pyam
    for unit_name in ("USD", "case"):
        # A unit that is already defined is left as-is
        with contextlib.suppress(pint.RedefinitionError):
            application_registry.define(f"{unit_name} = [{unit_name}]")

    yield application_registry
@pytest.fixture(
    params=[(True, "AttrSeries"), (HAS_SPARSE, "SparseDataArray")],
    ids=["attrseries", "sparsedataarray"],
)
def parametrize_quantity_class(request):
    """Fixture to run tests twice, for both Quantity implementations.

    The SparseDataArray variant is skipped when `HAS_SPARSE` is false.
    """
    available, class_name = request.param
    if not available:
        pytest.skip(reason="`sparse` not available → can't test SparseDataArray")

    # Swap the class name for the duration of the test, then put back the prior value
    original = genno.core.quantity.CLASS
    genno.core.quantity.CLASS = class_name

    yield

    genno.core.quantity.CLASS = original


@pytest.fixture(scope="function")
def quantity_is_sparsedataarray(request):
    """Fixture to run a test with "SparseDataArray" as the Quantity implementation."""
    # Remember the prior value so it can be restored after the test
    original = copy(genno.core.quantity.CLASS)
    genno.core.quantity.CLASS = "SparseDataArray"

    yield

    genno.core.quantity.CLASS = original