Source code for genno.testing

import contextlib
import logging
import sys
from contextlib import nullcontext
from functools import partial
from itertools import chain, islice
from typing import TYPE_CHECKING, ContextManager, Dict

import numpy as np
import pandas as pd
import pint
import pytest
import xarray as xr
import xarray.testing
from dask.core import quote
from pandas.testing import assert_series_equal

import genno
from genno import ComputationError, Computer, Key, set_class
from genno.compat.pint import PintError
from genno.core.attrseries import AttrSeries
from genno.core.sparsedataarray import HAS_SPARSE, SparseDataArray

if sys.version_info.minor >= 10:
    import importlib.resources as importlib_resources
else:
    # Use the backport to get identical behaviour
    import importlib_resources  # type: ignore [no-redef]

if TYPE_CHECKING:
    from genno.core.quantity import AnyQuantity


log = logging.getLogger(__name__)

# Pytest hooks


def pytest_sessionstart(session):
    logging.getLogger("numba").setLevel(logging.INFO)


def pytest_runtest_makereport(item, call):
    """Pytest hook to unwrap :class:`genno.ComputationError`.

    This allows tests to be "xfail"ed more precisely on the underlying exception,
    rather than on the ComputationError which wraps it.
    """
    if call.when == "call" and getattr(call.excinfo, "type", None) is ComputationError:
        # Retrieve the Exception wrapped by ComputationError
        e = call.excinfo.value.args[0]

        # Look for an "xfail" marker whose raises= class(es) match `e`
        for mark in filter(
            lambda m: m.name == "xfail" and isinstance(e, m.kwargs.get("raises", ())),
            item.iter_markers(),
        ):
            # Change the ExceptionInfo to describe `e`, which will match this mark
            # and produce an "xfail" report
            call.excinfo = pytest.ExceptionInfo(
                excinfo=(type(e), e, e.__traceback__), _ispytest=True
            )

        # Generate and return the report
        return pytest.TestReport.from_item_and_call(item, call)
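

# An illustrative sketch (hypothetical; not part of genno.testing) of a test aided by
# the hook above: the xfail marker names the underlying ZeroDivisionError, not the
# ComputationError that wraps it. In a real test module, rename to ``test_*``.
@pytest.mark.xfail(raises=ZeroDivisionError)
def _example_xfail_on_wrapped_exception():
    c = Computer()
    c.add("boom", (lambda: 1 / 0,))
    c.get("boom")  # Raises ComputationError wrapping ZeroDivisionError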


def add_large_data(c: Computer, num_params, N_dims=6, N_data=0):
    """Add nodes to `c` that return large-ish data.

    The result is a matrix wherein the Cartesian product of all the keys is very
    large—about 2e17 elements for N_dims = 6—but the contents are very sparse. This
    can be handled by :class:`.SparseDataArray`, but not by :class:`xarray.DataArray`
    backed by :class:`numpy.ndarray`.
    """

    def _fib():
        """Yield dimensions and their lengths: Fibonacci numbers."""
        a, b = 233, 377
        dim_names = iter("abcdefghijklmnopqrstuvwxyz")
        yield next(dim_names), a
        while True:
            yield next(dim_names), b
            a, b = b, a + b

    # Dimensions and their lengths
    dims, sizes = zip(*islice(_fib(), N_dims))

    # Number of data points to generate
    N_data = max(int(N_data), sizes[-1])

    # commented; for debugging
    #
    # Output something like "True: 2584 values / 2.182437e+17 = 1.184e-12% full"
    # from math import prod
    #
    # total = prod(sizes)
    # log.info(
    #     # See https://github.com/pydata/sparse/issues/429; total elements must be
    #     # less than the maximum value of np.intp
    #     repr(total < np.iinfo(np.intp).max)
    #     + f": {max(sizes)} values / {total:3e} = {100 * max(sizes) / total:.3e}% full"
    # )

    # Labels like a_00000 … a_00232 along each dimension
    dtypes = {"value": float}
    for d, N in zip(dims, sizes):
        categories = [f"{d}_{i:05d}" for i in range(N)]
        # Add to Computer
        c.add(d, quote(categories))

        # Create a categorical dtype
        dtypes[d] = pd.CategoricalDtype(categories)

    # Random generator
    rng = np.random.default_rng()

    def get_large_quantity(name):
        """Make a DataFrame containing each label in *coords* ≥ 1 time."""
        log.info(f"{N_data} values")

        # Allocate memory for the data frame using the given data types
        df = pd.DataFrame(
            index=pd.RangeIndex(N_data), columns=list(dims) + ["value"]
        ).astype(dtypes)

        # Fill values
        df.loc[:, "value"] = rng.random(N_data)

        # Fill labels
        for d in dims:
            df[d] = pd.Categorical.from_codes(
                rng.integers(0, len(dtypes[d].categories), N_data), dtype=dtypes[d]
            )

        return genno.Quantity(
            df.set_index(list(dims)),
            units=pint.get_application_registry().kilogram,
            name=name,
        )

    # Add quantities named q_00 … q_{num_params - 1} to the Computer
    keys = []
    for i in range(num_params):
        key = Key(f"q_{i:02d}", dims)
        c.add(key, (partial(get_large_quantity, key),))
        keys.append(key)

    return keys
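

# Illustrative usage sketch (hypothetical; not part of genno.testing): populate a
# Computer with 2 sparse quantities of 3 dimensions each, then compute one.
def _example_add_large_data():  # pragma: no cover
    c = Computer()
    keys = add_large_data(c, num_params=2, N_dims=3)
    q = c.get(keys[0])  # e.g. <q_00:a-b-c>
    assert {"a", "b", "c"} == set(q.dims)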


def add_test_data(c: Computer):
    """:func:`add_test_data` operating on a Computer, not an ixmp.Scenario."""
    # TODO combine with add_dantzig(), below

    # New sets
    t_foo = ["foo{}".format(i) for i in (1, 2, 3)]
    t_bar = ["bar{}".format(i) for i in (4, 5, 6)]
    t = t_foo + t_bar
    y = list(range(2000, 2051, 10))

    # Add to Computer
    c.add("t", quote(t))
    c.add("y", quote(y))

    # Data
    ureg = pint.get_application_registry()
    x = genno.Quantity(
        np.random.rand(len(t), len(y)),
        coords={"t": t, "y": y},
        units=ureg.kg,
        name="Quantity X",
    )

    # Add, including sums and to index
    c.add(Key("x", ("t", "y")), genno.Quantity(x), sums=True)

    return t, t_foo, t_bar, x
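

# Illustrative usage sketch (hypothetical; not part of genno.testing).
def _example_add_test_data():  # pragma: no cover
    c = Computer()
    t, t_foo, t_bar, x = add_test_data(c)
    # sums=True also adds partial sums such as "x:t" (x summed over "y")
    assert ("t",) == c.get("x:t").dims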


_i = ["seattle", "san-diego"]
_j = ["new-york", "chicago", "topeka"]

_TEST_DATA = {
    Key(k): data
    for k, data in {
        "a:i": (xr.DataArray([350, 600], coords=[("i", _i)]), "cases"),
        "b:j": (xr.DataArray([325, 300, 275], coords=[("j", _j)]), "cases"),
        "d:i-j": (
            xr.DataArray(
                [[2.5, 1.7, 1.8], [2.5, 1.8, 1.4]], coords=[("i", _i), ("j", _j)]
            ),
            "km",
        ),
        "f:": (90.0, "USD/km"),
        # TODO complete the following
        # Decision variables and equations
        "x:i-j": (
            xr.DataArray([[0, 0, 0], [0, 0, 0]], coords=[("i", _i), ("j", _j)]),
            "cases",
        ),
        "z:": (0, "cases"),
        "cost:": (0, "USD"),
        "cost-margin:": (0, "USD"),
        "demand:j": (xr.DataArray([0, 0, 0], coords=[("j", _j)]), "cases"),
        "demand-margin:j": (xr.DataArray([0, 0, 0], coords=[("j", _j)]), "cases"),
        "supply:i": (xr.DataArray([0, 0], coords=[("i", _i)]), "cases"),
        "supply-margin:i": (xr.DataArray([0, 0], coords=[("i", _i)]), "cases"),
    }.items()
}


def get_test_quantity(key: Key) -> "AnyQuantity":
    """Computation that returns test data."""
    value, unit = _TEST_DATA[key]
    return genno.Quantity(value, name=key.name, units=unit)


def add_dantzig(c: Computer):
    """Add contents analogous to the ixmp Dantzig scenario."""
    c.add("i", quote(_i))
    c.add("j", quote(_j))

    _all = list()
    for key in _TEST_DATA.keys():
        c.add(key, (partial(get_test_quantity, key),), sums=True)
        _all.append(key)

    c.add("all", sorted(_all))
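

# Illustrative usage sketch (hypothetical; not part of genno.testing).
def _example_add_dantzig():  # pragma: no cover
    c = Computer()
    add_dantzig(c)
    # Quantities from _TEST_DATA are available, e.g. distances d:i-j in kilometres
    assert_units(c.get("d:i-j"), "km")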


@contextlib.contextmanager
def assert_logs(caplog, message_or_messages=None, at_level=None):
    """Assert that *message_or_messages* appear in logs.

    Use assert_logs as a context manager for a statement that is expected to trigger
    certain log messages. assert_logs checks that these messages are generated.

    Derived from :func:`ixmp.testing.assert_logs`.

    Example
    -------
    >>> def test_foo(caplog):
    ...     with assert_logs(caplog, 'a message'):
    ...         logging.getLogger(__name__).info('this is a message!')

    Parameters
    ----------
    caplog : object
        The pytest caplog fixture.
    message_or_messages : str or list of str
        String(s) that must appear in log messages.
    at_level : int, optional
        Messages must appear on 'genno' or a sub-logger with at least this level.
    """
    __tracebackhide__ = True

    # Wrap a string in a list
    expected = (
        [message_or_messages]
        if isinstance(message_or_messages, str)
        else message_or_messages
    )

    # Record the number of records prior to the managed block
    first = len(caplog.records)

    if at_level is not None:
        # Use the pytest caplog fixture's built-in context manager to temporarily set
        # the level of the logger for the whole package (parent of the current module)
        ctx = caplog.at_level(at_level, logger=__name__.split(".")[0])
    else:
        # ctx does nothing
        ctx = contextlib.nullcontext()

    try:
        with ctx:
            yield  # Nothing provided to the managed block
    finally:
        # List of bool indicating whether each of `expected` was found
        found = [any(e in msg for msg in caplog.messages[first:]) for e in expected]

        if not all(found):
            # Format a description of the missing messages
            lines = chain(
                ["Did not log:"],
                [f"  {repr(msg)}" for i, msg in enumerate(expected) if not found[i]],
                ["among:"],
                ["  []"]
                if len(caplog.records) == first
                else [f"  {repr(msg)}" for msg in caplog.messages[first:]],
            )
            pytest.fail("\n".join(lines))


def assert_qty_equal(
    a,
    b,
    check_type: bool = True,
    check_attrs: bool = True,
    ignore_extra_coords: bool = False,
    **kwargs,
):
    """Assert that objects `a` and `b` are equal.

    Parameters
    ----------
    check_type : bool, optional
        Assert that `a` and `b` are both :class:`.Quantity` instances. If
        :obj:`False`, the arguments are converted to Quantity.
    check_attrs : bool, optional
        Also assert that attributes are identical.
    ignore_extra_coords : bool, optional
        Ignore extra coords that are not dimensions. Only meaningful when Quantity is
        :class:`.SparseDataArray`.
    """
    __tracebackhide__ = True

    try:
        assert type(a) is type(b) is genno.Quantity
    except AssertionError:
        if check_type:
            raise
        else:
            # Convert both arguments to Quantity
            a = genno.Quantity(a)
            b = genno.Quantity(b)

    if genno.Quantity is AttrSeries:
        try:
            a = a.sort_index().dropna()
            b = b.sort_index().dropna()
        except TypeError:  # pragma: no cover
            pass
        assert_series_equal(a, b, check_dtype=False, **kwargs)
    else:
        import xarray.testing

        if ignore_extra_coords:
            a = a.reset_coords(set(a.coords.keys()) - set(a.dims), drop=True)
            b = b.reset_coords(set(b.coords.keys()) - set(b.dims), drop=True)

        assert 0 == len(kwargs)
        xarray.testing.assert_equal(a._sda.dense, b._sda.dense)

    # Check attributes are equal
    if check_attrs:
        assert a.attrs == b.attrs
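

# Illustrative usage sketch (hypothetical; not part of genno.testing).
def _example_assert_qty_equal():  # pragma: no cover
    a = genno.Quantity(1.2, units="kg")
    b = genno.Quantity(1.2, units="kg")
    assert_qty_equal(a, b)  # Equal values, units, and attributes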


def assert_qty_allclose(
    a,
    b,
    check_type: bool = True,
    check_attrs: bool = True,
    ignore_extra_coords: bool = False,
    **kwargs,
):
    """Assert that objects `a` and `b` have numerically close values.

    Parameters
    ----------
    check_type : bool, optional
        Assert that `a` and `b` are both :class:`.Quantity` instances. If
        :obj:`False`, the arguments are converted to Quantity.
    check_attrs : bool, optional
        Also assert that attributes are identical.
    ignore_extra_coords : bool, optional
        Ignore extra coords that are not dimensions. Only meaningful when Quantity is
        :class:`.SparseDataArray`.
    """
    __tracebackhide__ = True

    try:
        assert type(a) is type(b) is genno.Quantity
    except AssertionError:
        if check_type:
            raise
        else:
            # Convert both arguments to Quantity
            a = genno.Quantity(a)
            b = genno.Quantity(b)

    if genno.Quantity is AttrSeries:
        assert_series_equal(a.sort_index(), b.sort_index(), **kwargs)
    else:
        if ignore_extra_coords:
            a = a.reset_coords(set(a.coords.keys()) - set(a.dims), drop=True)
            b = b.reset_coords(set(b.coords.keys()) - set(b.dims), drop=True)

        # Remove a kwarg not recognized by the xarray function
        kwargs.pop("check_dtype", None)

        xarray.testing.assert_allclose(a._sda.dense, b._sda.dense, **kwargs)

    # Check attributes are equal
    if check_attrs:
        assert a.attrs == b.attrs
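

# Illustrative usage sketch (hypothetical; not part of genno.testing).
def _example_assert_qty_allclose():  # pragma: no cover
    a = genno.Quantity(1.0)
    b = genno.Quantity(1.0 + 1e-9)
    assert_qty_allclose(a, b)  # Passes: values are within the default tolerance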


def assert_units(qty: "AnyQuantity", exp: str) -> None:
    """Assert that `qty` has units `exp`."""
    assert (
        qty.units / qty.units._REGISTRY(exp)
    ).dimensionless, f"Units '{qty.units:~}'; expected {repr(exp)}"


def random_qty(shape: Dict[str, int], **kwargs) -> "AnyQuantity":
    """Return a Quantity with `shape` and random contents.

    Parameters
    ----------
    shape : dict
        Mapping from dimension names (:class:`str`) to lengths along each dimension
        (:class:`int`).
    **kwargs
        Other keyword arguments to :class:`.Quantity`.

    Returns
    -------
    .Quantity
        Random data with one dimension for each key in `shape`, and coords along those
        dimensions like "foo0", "foo1", with total length matching the value from
        `shape`. If `shape` is empty, a scalar (0-dimensional) Quantity.
    """
    return genno.Quantity(
        np.random.rand(*shape.values()) if len(shape) else np.random.rand(1)[0],
        coords={
            dim: [f"{dim}{i}" for i in range(length)] for dim, length in shape.items()
        },
        **kwargs,
    )
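

# Illustrative sketch (hypothetical; not part of genno.testing) combining
# random_qty() and assert_units().
def _example_random_qty():  # pragma: no cover
    q = random_qty(dict(x=3, y=2), units="km")
    assert ("x", "y") == q.dims and 6 == q.size
    assert_units(q, "km")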


def raises_or_warns(value, *args, **kwargs) -> ContextManager:
    """Context manager for tests that :func:`.pytest.raises` or :func:`.pytest.warns`.

    If `value` is a context manager, such as one returned by :func:`.pytest.raises`,
    it is used directly.

    Examples
    --------
    .. code-block:: python

       @pytest.mark.parametrize(
           "input, expected", (("FOO", 1), ("BAR", pytest.raises(ValueError)))
       )
       def test_myfunc0(input, expected):
           with raises_or_warns(expected, DeprecationWarning, match="FOO"):
               assert expected == myfunc(input)

    In this example:

    - :py:`myfunc("FOO")` is expected to emit :class:`DeprecationWarning` and return
      1.
    - :py:`myfunc("BAR")` is expected to raise :class:`ValueError` and issue no
      warning.

    .. code-block:: python

       @pytest.mark.parametrize(
           "input, expected", (("FOO", 1), ("BAR", pytest.raises(ValueError)))
       )
       def test_myfunc1(input, expected):
           with raises_or_warns(expected, None):
               assert expected == myfunc(input)

    In this example, no warnings are expected from :py:`myfunc("FOO")`.
    """
    if isinstance(value, ContextManager):
        return value
    elif args == (None,) and kwargs == {}:
        return nullcontext()
    else:
        return pytest.warns(*args, **kwargs)


# Fixtures


@pytest.fixture(scope="session")
def test_data_path():
    """Path to the directory containing test data."""
    return importlib_resources.files("genno.tests.data")


@pytest.fixture(scope="session")
def ureg():
    """Application-wide units registry."""
    registry = pint.get_application_registry()

    # Used by .compat.ixmp, .compat.pyam
    for name in ("USD", "case"):
        try:
            registry.define(f"{name} = [{name}]")
        except PintError:  # pragma: no cover
            # pint.RedefinitionError with pint 0.22 on Python ≤ 3.11
            # pint.DefinitionSyntaxError with pint 0.17 on Python 3.12
            pass

    yield registry


@pytest.fixture(
    params=[
        (True, "AttrSeries", AttrSeries),
        (HAS_SPARSE, "SparseDataArray", SparseDataArray),
    ],
    ids=["attrseries", "sparsedataarray"],
)
def parametrize_quantity_class(request):
    """Fixture to run tests twice, for both Quantity implementations."""
    from genno.core import quantity

    if not request.param[0]:  # pragma: no cover
        pytest.skip(reason="`sparse` not available → can't test SparseDataArray")

    pre = quantity.CLASS

    try:
        set_class(request.param[1])
        yield
    finally:
        set_class(pre)
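

# Illustrative sketch (hypothetical): in a real test module, rename to ``test_*`` so
# pytest collects it and injects the fixture, running the body once per
# implementation of Quantity.
def _example_parametrize_quantity_class(parametrize_quantity_class):
    q = random_qty(dict(x=2))
    assert isinstance(q, (AttrSeries, SparseDataArray))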


@pytest.fixture(params=[True, False], ids=["cow-true", "cow-false"])
def parametrize_copy_on_write(monkeypatch, request):
    """Fixture to run tests with pandas copy-on-write either enabled or disabled."""
    monkeypatch.setattr(pd.options.mode, "copy_on_write", request.param)
    yield


@pytest.fixture(scope="function")
def quantity_is_sparsedataarray(request):
    """Fixture to run a test with :class:`.SparseDataArray` as the Quantity class."""
    from genno.core import quantity

    pre = quantity.CLASS

    try:
        set_class("SparseDataArray")
        yield
    finally:
        set_class(pre)