# -*- coding: utf-8 -*-
"""
lories.core.register._load
~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
from __future__ import annotations
import logging
import os
from abc import abstractmethod
from copy import deepcopy
from logging import Logger
from typing import Any, Collection, Dict, Generic, Mapping, Optional, Sequence, Type
from lories._core._configurator import Configurator # noqa
from lories._core._registrator import ( # noqa
Registrator,
RegistratorContext,
_RegistratorContext, # noqa
)
from lories.core.configs import Configurations, Directories, Directory
from lories.core.errors import ResourceError
from lories.util import update_recursive
# FIXME: Remove this once Python >= 3.9 is a requirement
try:
from typing import get_args, get_origin
except ImportError:
from typing_extensions import get_args, get_origin
# FIXME: Remove this once Python >= 3.12 is a requirement
try:
from typing import get_bound
except ImportError:
def get_bound(cls):
return getattr(cls, "__bound__", None)
# noinspection SpellCheckingInspection, PyAbstractClass, PyProtectedMember
[docs]
class _RegistratorLoader(_RegistratorContext[Registrator], Generic[Registrator]):
__registrator_class: Type[Registrator]
_logger: Logger
def __init__(
self,
logger: Optional[Logger] = None,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
if logger is None:
logger = self._build_logger()
self._logger = logger
self.__registrator_class = self._build_registrators_class()
@classmethod
def _build_logger(cls) -> Logger:
return logging.getLogger(cls.__module__)
@classmethod
def _build_registrators_class(cls) -> Type[Registrator]:
for _cls in cls.mro():
for _base in getattr(_cls, "__orig_bases__", ()):
if isinstance(_base, type):
continue
if issubclass(get_origin(_base), _RegistratorContext):
(_arg,) = get_args(_base)
if not isinstance(_arg, type):
_arg = get_bound(_arg)
return _arg
raise ResourceError("Unable to extract context registrator")
# noinspection PyUnresolvedReferences
def _build_registrator_id(
self,
context: RegistratorContext,
configs: Configurations,
) -> str:
return self._get_registator_class()._build_id(context=context, configs=configs)
# noinspection PyUnresolvedReferences
def _build_registrator_defaults(
self,
configs: Configurations,
includes: Optional[Collection[str]] = (),
) -> Dict[str, Any]:
return self._get_registator_class()._build_defaults(configs, includes)
# noinspection PyUnresolvedReferences
def _load(
self,
context: RegistratorContext,
configs: Configurations,
configs_file: Optional[str] = None,
configs_dir: Optional[str | Directory] = None,
includes: Optional[Collection[str]] = (),
defaults: Optional[Mapping[str, Any]] = None,
configure: bool = True,
sort: bool = True,
**kwargs: Any,
) -> Sequence[Registrator]:
registrators = []
if defaults is None:
defaults = {}
update_recursive(defaults, self._build_registrator_defaults(configs, includes))
configs_dirs = configs.dirs.copy()
if configs_dir is None:
configs_dir = configs.dirs.conf.joinpath(configs.name.replace(".conf", ".d"))
if isinstance(configs_dir, str) and not os.path.isabs(configs_dir):
configs_dir = configs_dirs.conf.joinpath(configs_dir)
configs_dirs.conf = configs_dir
configs_members = configs.get_members([s for s in configs.members if s not in defaults])
if configs_file and configs_file != configs.name:
registrators.extend(self._load_from_file(context, configs_file, configs.dirs, defaults, **kwargs))
registrators.extend(self._load_from_members(context, configs_members, defaults, **kwargs))
registrators.extend(self._load_from_dir(context, configs_dirs, defaults, **kwargs))
if sort:
self.sort()
if configure:
self.configure(registrators)
return registrators
def _load_from_configs(
self,
context: RegistratorContext,
configs: Configurations,
**kwargs: Any,
) -> Registrator:
registrator_id = self._build_registrator_id(context, configs)
if self._contains(registrator_id):
registrator = self._update(registrator_id, configs)
else:
registrator = self._create(context, configs, **kwargs)
self._add(registrator)
return registrator
# noinspection PyUnresolvedReferences
def _load_from_members(
self,
context: RegistratorContext,
configs: Configurations,
includes: Optional[Collection[str]] = (),
defaults: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> Sequence[Registrator]:
registrators = []
if defaults is None:
defaults = {}
update_recursive(defaults, self._build_registrator_defaults(configs))
for member_name in configs.members:
if member_name in includes:
continue
member_file = f"{member_name}.conf"
member_default = deepcopy(defaults)
update_recursive(member_default, configs.get(member_name))
member = Configurations.load(
member_file,
**configs.dirs.to_dict(),
**member_default,
require=False,
)
registrators.append(self._load_from_configs(context, member, **kwargs))
return registrators
def _load_from_file(
self,
context: RegistratorContext,
configs_file: str,
configs_dirs: Directories,
defaults: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> Sequence[Registrator]:
registrators = []
if configs_dirs.conf.joinpath(configs_file).is_file():
# Do not call .load() function here, as configs_dirs._conf may be None and would otherwise be overridden
# with the data directory
configs = Configurations(configs_file, deepcopy(configs_dirs))
configs._load()
registrators.extend(self._load_from_members(context, configs, defaults, **kwargs))
return registrators
def _load_from_dir(
self,
context: RegistratorContext,
configs_dirs: Directories,
defaults: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> Sequence[Registrator]:
registrators = []
if os.path.isdir(configs_dirs.conf):
for configs_entry in os.scandir(configs_dirs.conf):
if (
configs_entry.is_file()
and configs_entry.path.endswith(".conf")
and not configs_entry.path.endswith("default.conf")
and configs_entry.name
not in [
"settings.conf",
"system.conf",
"evaluations.conf",
"replications.conf",
"logging.conf",
]
):
configs = Configurations.load(
configs_entry.name,
**configs_dirs.to_dict(),
**defaults,
)
try:
registrators.append(self._load_from_configs(context, configs, **kwargs))
except ResourceError:
# Skip files with missing or unknown type
# TODO: Introduce debug logging here
pass
return registrators
# noinspection PyUnresolvedReferences, PyTypeChecker
def configure(self, configurators: Optional[Collection[Configurator]] = None) -> None:
if configurators is None:
configurators = self.values()
for configurator in configurators:
configurations = configurator.configs
if configurations is None or not configurations.enabled:
self._logger.debug(f"Skipping configuring disabled {type(configurator).__name__}")
continue
self._logger.debug(f"Configuring {type(self).__name__}: {configurations.path}")
configurator.configure(configurations)
if self._logger.getEffectiveLevel() <= logging.DEBUG:
self._logger.debug(f"Configured {configurator}")
# noinspection PyUnresolvedReferences, PyArgumentList, PyShadowingBuiltins
def _update(self, id: str, configs: Configurations) -> Registrator:
registrator = self._get(id)
if registrator.is_configured() and configs.enabled:
registrator_configs = registrator.configs.copy()
registrator_configs.update(configs)
registrator.update(registrator_configs)
else:
registrator.configs.update(configs)
return registrator
# noinspection PyUnresolvedReferences
def get_all(self, *types: Type[Registrator]) -> Collection[Registrator]:
if len(types) == 0:
return self.values()
return self.filter(lambda r: any(isinstance(r, t) and r.is_enabled() for t in types))
def get_first(self, *types: Optional[str | Type[Registrator]]) -> Optional[Registrator]:
registrators = self.get_all(*types)
return next(iter(registrators)) if len(registrators) > 0 else None
# noinspection PyTypeChecker
def get_last(self, *types: Optional[str | Type[Registrator]]) -> Optional[Registrator]:
registrators = self.get_all(*types)
return next(reversed(registrators)) if len(registrators) > 0 else None
def has_type(self, *types: str | Type[Registrator]) -> bool:
if len(types) == 0:
raise ValueError("At least one type to look up required")
return len(self.get_all(*types)) > 0
# noinspection PyTypeChecker, PyUnresolvedReferences
def _get_registator_class(self) -> Type[Registrator]:
return self.__registrator_class
@classmethod
def _get_registrators_type(cls) -> str:
return cls.TYPE
@abstractmethod
def _load_registrators_configs(self) -> Configurations: ...
@abstractmethod
def load(self, **kwargs: Any) -> Sequence[Registrator]:
pass