Source code for lories.core.register.registry

# -*- coding: utf-8 -*-
"""
lories.core.register.registry
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


"""

from __future__ import annotations

import builtins
from typing import Callable, Collection, Dict, Generic, Mapping, Optional, Type

from lories._core._registrator import Registrator  # noqa
from lories.core.register.registration import Registration, RegistrationError, RegistrationFactory

# FIXME: Remove this once Python >= 3.9 is a requirement
try:
    from typing import get_args

except ImportError:
    from typing_extensions import get_args


# noinspection PyShadowingBuiltins
class Registry(Generic[Registrator]):
    """
    Keyed collection of :class:`Registration` objects for one registrator type.

    Registrations are stored in a dictionary under their lower-cased key.
    Lookups by type string are delegated to ``Registration.is_type``.

    NOTE(review): ``register`` resolves the generic type argument through
    ``__orig_class__``, which is presumably only available on instances created
    from a parameterised alias, e.g. ``Registry[SomeRegistrator]()`` - confirm
    against callers.
    """

    __types: Mapping[str, Registration[Registrator]]

    def __init__(self) -> None:
        self.__types: Dict[str, Registration[Registrator]] = {}

    def register(
        self,
        cls: Type[Registrator],
        key: str,
        *alias: str,
        factory: Optional[RegistrationFactory] = None,
        replace: bool = False,
    ) -> None:
        """
        Register a class under a key and optional aliases.

        Args:
            cls: The class to register; must be a subclass of the generic type
                this registry was parameterised with.
            key: The registration key; stored lower-cased.
            *alias: Additional names passed through to the ``Registration``.
            factory: Optional factory passed through to the ``Registration``.
            replace: If ``True``, skip the duplicate check and store the new
                registration under ``key``, overwriting an entry with that key.

        Raises:
            RegistrationError: If ``key`` is not a string, or if a registration
                matching ``key`` already exists and ``replace`` is ``False``.
            ValueError: If ``cls`` is not a subclass of the registry's type.
        """
        if not isinstance(key, str):
            raise RegistrationError(f"Invalid '{builtins.type(key)}' registration type: {key}")
        key = key.lower()

        generic_type = self._get_generic_type()
        if not issubclass(cls, generic_type):
            raise ValueError(f"Can only register {generic_type} types")

        if not replace:
            # Look up the conflicting registration with the same predicate that
            # decides whether a conflict exists. Comparing against the primary
            # key only could find no match and leak a StopIteration out of the
            # error message instead of raising the RegistrationError.
            existing = next((r for r in self.__types.values() if r.is_type(key)), None)
            if existing is not None:
                raise RegistrationError(
                    f"Registration '{key}' does already exist: "
                    f"{existing.name}"
                )

        self.__types[key] = Registration[Registrator](cls, key, *alias, factory=factory)

    # noinspection PyShadowingBuiltins
    def filter(
        self,
        filter: Callable[[Registration[Registrator]], bool],
    ) -> Collection[Registration[Registrator]]:
        """Return a list of all registrations for which ``filter`` returns a truthy value."""
        return [c for c in self.__types.values() if filter(c)]

    def get_types(self) -> Collection[str]:
        """Return the (lower-cased) keys of all registrations as a dictionary keys view."""
        return self.__types.keys()

    def has_type(self, type: str) -> bool:
        """Return whether any registration reports ``is_type(type)`` as true."""
        return any(t.is_type(type) for t in self.__types.values())

    def from_type(self, type: str) -> Registration[Registrator]:
        """
        Return the first registration for which ``is_type(type)`` is true.

        Raises:
            RegistrationError: If no registration matches ``type``.
        """
        for registration in self.__types.values():
            if registration.is_type(type):
                return registration
        raise RegistrationError(f"Registration '{type}' does not exist")

    # noinspection PyTypeChecker, PyUnresolvedReferences
    def _get_generic_type(self) -> Type[Registrator]:
        """Return the first type argument of the alias this instance was created from."""
        # NOTE(review): presumably raises AttributeError for instances not
        # created through a subscripted alias such as ``Registry[T]()`` - confirm.
        return get_args(self.__orig_class__)[0]