# Source code for genno.core.graph

from collections.abc import Generator, Iterable, Sequence
from itertools import chain, tee
from operator import itemgetter
from typing import Any, Optional, Union

from .key import Key, KeyLike

def _key_arg(key: KeyLike) -> Union[str, Key]:
    """Normalize `key` into the form used by the :class:`Graph` indices.

    Returns the result of :meth:`.Key.bare_name` when that is truthy (a plain
    name); otherwise returns `key` converted to a :class:`.Key`.
    """
    bare = Key.bare_name(key)
    if bare:
        return bare
    return Key(key)

class Graph(dict):
    """A dictionary for a graph indexed by :class:`.Key`.

    Graph maintains indexes on set/delete/pop/update operations that allow for fast
    lookups/member checks in certain special cases:

    .. autosummary::
       unsorted_key
       full_key

    These basic features are used to provide higher-level helpers for
    :class:`.Computer`:

    .. autosummary::
       infer
    """

    # Index: sorted form of each Key (or a bare name) → the key as stored.
    # These class-level values are only defaults; __init__() gives each instance
    # its own dicts. NOTE(review): an instance created without calling __init__()
    # would read and write these shared class-level dicts.
    _unsorted: dict[KeyLike, KeyLike] = dict()
    # Index: dimensionless Key → the stored key with the most dimensions.
    _full: dict[Key, Key] = dict()

    def __init__(self, *args, **kwargs):
        # Initialize members
        super().__init__(*args, **kwargs)

        # Initialize indices
        self._unsorted = dict()
        self._full = dict()

        # Index every key, including those given positionally (a mapping or an
        # iterable of pairs), not only keyword arguments. dict.__init__() does not
        # call __setitem__(), so this must be done explicitly. Iteration follows
        # insertion order, i.e. the order in which dict.__init__() stored them.
        for k in self:
            self._index(k)

    def _index(self, key: KeyLike) -> None:
        """Add `key` to the indices."""
        k = _key_arg(key)
        if isinstance(k, Key):
            self._unsorted[k.sorted] = k
            nodim = k.drop(True)
            # Keep the key with the most dims; on a tie, the latest one wins
            if len(k.dims) >= len(self._full.get(nodim, nodim).dims):
                self._full[nodim] = k
        else:
            self._unsorted[k] = key

    def _deindex(self, key: KeyLike) -> None:
        """Remove `key` from the indices.

        An index entry is removed only if it refers to `key` itself. Index slots are
        shared—keys with the same dims in a different order share one `_unsorted`
        entry, and keys with the same name share one `_full` entry—so removing one
        key must not discard an entry that points to a different, still-present key.
        """
        k = _key_arg(key)
        if isinstance(k, Key):
            sorted_k = k.sorted
            stored = self._unsorted.get(sorted_k)
            if stored is not None and stored == k:
                del self._unsorted[sorted_k]

            nodim = k.drop(True)
            stored = self._full.get(nodim)
            if stored is not None and stored == k:
                del self._full[nodim]
        else:
            self._unsorted.pop(k, None)

    def __setitem__(self, key: KeyLike, value: Any):
        """Set ``self[key] = value`` and add `key` to the indices."""
        super().__setitem__(key, value)
        self._index(key)

    def __delitem__(self, key: KeyLike):
        """Delete ``self[key]`` and remove `key` from the indices.

        Raises :class:`KeyError` (from :class:`dict`) if `key` is absent, in which
        case the indices are not touched.
        """
        super().__delitem__(key)
        self._deindex(key)

    def __contains__(self, item) -> bool:
        """:obj:`True` if `item` *or* a key with the same dims in a different order
        is present."""
        try:
            return super().__contains__(item) or bool(self.unsorted_key(item))
        except Exception:  # for instance, TypeError
            # Deliberate best-effort: an item that cannot be hashed or interpreted
            # as a key is simply "not contained"
            return False

    def pop(self, *args):
        """Overload :meth:`dict.pop` to also call :meth:`_deindex`.

        The indices are updated only when a key is actually removed. Popping a
        missing key with a default leaves the indices untouched, since they may
        hold entries for other keys with the same name or dimensions.
        """
        if args and super().__contains__(args[0]):
            value = super().pop(*args)
            self._deindex(args[0])
            return value

        # Key absent, or no arguments: defer entirely to dict.pop(), which returns
        # the default or raises KeyError/TypeError exactly as usual
        return super().pop(*args)

    def update(self, arg=None, **kwargs):
        """Overload :meth:`dict.update` to also call :meth:`_index`.

        As with :meth:`dict.update`, `arg` may be a mapping (any object with a
        ``keys()`` method) or any iterable of (key, value) pairs, including
        one-shot iterators and generators.
        """
        if arg is None:
            arg = dict()

        if hasattr(arg, "keys"):
            # Mapping: keys can be read without consuming anything
            items = arg
            arg_keys = arg.keys()
        else:
            # Iterable of pairs: duplicate it so it can be read twice—once here
            # for the keys to index, and once by dict.update()
            arg0, items = tee(arg)
            arg_keys = map(itemgetter(0), arg0)

        # Index in the same order dict.update() applies entries: `arg`, then kwargs
        for key in chain(arg_keys, kwargs.keys()):
            self._index(key)

        super().update(items, **kwargs)

    def unsorted_key(self, key: KeyLike) -> Optional[KeyLike]:
        """Return `key` with its original or unsorted dimensions.

        Returns :obj:`None` if no key with the same name and the same set of
        dimensions (in any order) is indexed.
        """
        k = _key_arg(key)
        return self._unsorted.get(k.sorted if isinstance(k, Key) else k)

    def full_key(self, name_or_key: KeyLike) -> Optional[KeyLike]:
        """Return `name_or_key` with its full dimensions.

        Returns :obj:`None` if no key with the same name is indexed.
        """
        return self._full.get(Key(name_or_key).drop_all())

    def infer(
        self, key: Union[str, Key], dims: Iterable[str] = ()
    ) -> Optional[KeyLike]:
        """Infer a `key`.

        Parameters
        ----------
        key : str or .Key
            Key or name to look up.
        dims : iterable of str, optional
            Drop all but these dimensions from the returned key(s).

        Returns
        -------
        str or .Key
            `key` unchanged, if `key` is a :class:`str` or a :class:`.Key` without
            dimensions and :meth:`full_key` finds no match.
        .Key
            Otherwise, the matching key—with its full dimensions (cf.
            :meth:`full_key`) when `key` is a :class:`str` or has no dimensions—or,
            if `dims` are given, with only these dims.
        """
        result = self.unsorted_key(key) or key

        if isinstance(key, str) or not key.dims:
            # Find the full-dimensional key
            result = self.full_key(result) or ""

        if not isinstance(result, Key):
            return result or key

        # Drop all but `dims`. Materialize once: a generator object is always
        # truthy, so testing `dims` directly would drop every dim when given an
        # empty iterator.
        keep = set(dims)
        if keep:
            result = result.drop(*(set(result.dims) - keep))

        return result