Source code for lories.data.access

# -*- coding: utf-8 -*-
"""
lories.data.access
~~~~~~~~~~~~~~~~~~


"""

from __future__ import annotations

from typing import Any, Callable, Collection, Iterable, Optional, Type, overload

import pandas as pd
from lories._core import _Context, _Registrator  # noqa
from lories._core._data import DataContext, _DataContext, _DataManager  # noqa
from lories.core import Configurator, Constant, ResourceError
from lories.core.typing import ChannelsArgument, Registrator, Timestamp
from lories.data.channels import Channel, Channels
from lories.data.context import DataContext as _DataAccess
from lories.util import get_context, update_recursive

# FIXME: Remove this once Python >= 3.9 is a requirement
try:
    from typing import Literal

except ImportError:
    from typing_extensions import Literal


# noinspection PyProtectedMember, PyShadowingBuiltins
class DataAccess(_DataAccess, Configurator):
    """Channel collection bound to a single registrar.

    Channel ids that contain no ``"."`` separator are prefixed with the id of
    the registrar, and every channel that is set on or removed from this
    collection is mirrored to the data context resolved from the registrar.
    """

    __registrar: _Registrator
    __context: _DataContext

    def __init__(self, registrar: Registrator, **kwargs: Any) -> None:
        """Bind the access to *registrar* and resolve its data context.

        :param registrar: owner of the channels; must be a ``_Registrator``.
        :param kwargs: forwarded unchanged to the parent constructor.
        :raises TypeError: if the registrar or the resolved context is invalid.
        """
        registrar = self._assert_registrar(registrar)
        super().__init__(logger=registrar._logger, **kwargs)
        self.__registrar = registrar
        self.__context = self._assert_context(get_context(registrar, _DataManager))

    @classmethod
    def _assert_registrar(cls, registrar: Registrator) -> Registrator:
        """Return *registrar* unchanged, or raise ``TypeError`` if it is no ``_Registrator``."""
        # isinstance() is already False for None, no separate check required.
        if not isinstance(registrar, _Registrator):
            raise TypeError(f"Invalid '{cls.__name__}' registrator: {type(registrar)}")
        return registrar

    @classmethod
    def _assert_context(cls, context: DataContext) -> DataContext:
        """Return *context* unchanged, or raise ``TypeError`` if it is no ``_DataManager``."""
        if not isinstance(context, _DataManager):
            raise TypeError(f"Invalid '{cls.__name__}' context: {type(context)}")
        return context

    def __repr__(self) -> str:
        return f"{type(self).__name__}({', '.join(str(c.key) for c in self.values())})"

    def __str__(self) -> str:
        return f"{type(self).__name__}:\n\t" + "\n\t".join(f"{v.key} = {repr(v)}" for v in self.values())

    def __getitem__(self, id: Iterable[str] | str) -> Channel | Channels:
        """Look up a single channel by id, or a ``Channels`` collection for several ids.

        :raises KeyError: if *id* is neither a string nor an iterable.
        """
        # Strings are iterable as well, hence they have to be checked first.
        if isinstance(id, str):
            return self._get(id)
        if isinstance(id, Iterable):
            return Channels([self._get(i) for i in id])
        raise KeyError(id)

    # noinspection PyArgumentList
    def __getattr__(self, attr: str) -> Channel:
        """Resolve unknown attributes to the channel whose ``key`` equals *attr*.

        :raises AttributeError: if no channel has that key.
        """
        # Read the name-mangled private map of _Context directly through
        # _Context.__getattribute__, so this lookup cannot re-enter __getattr__.
        channels = _Context.__getattribute__(self, f"{_Context.__name__}__map")
        channels_by_key = {c.key: c for c in channels.values()}
        if attr in channels_by_key:
            return channels_by_key[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __validate_id(self, id: str) -> str:
        """Prefix *id* with the registrar id unless it already contains a ``"."``."""
        if "." not in id:
            id = f"{self.__registrar.id}.{id}"
        return id

    # noinspection PyArgumentList
    def _contains(self, __channel: str | Channel) -> bool:
        """Check membership by channel id (string) or by channel object."""
        channels = _Context.__getattribute__(self, f"{_Context.__name__}__map")
        if isinstance(__channel, str):
            __channel = self.__validate_id(__channel)
            return __channel in channels.keys()
        if isinstance(__channel, _Registrator):
            return __channel in channels.values()
        return False

    def _get(self, id: str) -> Channel:
        """Return the channel stored under the validated *id*."""
        return super()._get(self.__validate_id(id))

    def _set(self, id: str, channel: Channel) -> None:
        """Store *channel* under the validated *id*, in the data context and in this collection."""
        id = self.__validate_id(id)
        self.context._set(id, channel)
        super()._set(id, channel)

    def _create(self, id: str, key: str, type: Type, **configs: Any) -> Channel:
        """Delegate the channel creation to the data context."""
        return self.context._create(id=id, key=key, type=type, **configs)

    def _remove(self, *__objects: str | Channel) -> None:
        """Remove the passed ids or channels from the data context and from this collection."""
        for __object in __objects:
            if isinstance(__object, str):
                __object = self.__validate_id(__object)
            self.context._remove(__object)
            super()._remove(__object)

    @property
    def empty(self) -> bool:
        """Whether there are no channels, or ``to_frame(states=False)`` holds only all-NaN rows."""
        return len(self.values()) == 0 or self.to_frame(states=False).dropna(axis="index", how="all").empty

    @property
    def context(self) -> DataContext:
        """The data context that was resolved from the registrar."""
        return self.__context

    def load(self, sort: bool = True) -> Collection[Channel]:
        """Load the channels of the registrar from its configurations.

        Channels are read via ``_load_from_members`` from the ``Channels.TYPE``
        member of the configurations, if it exists, and via ``_load_from_file``
        from the ``<Channels.TYPE>.conf`` file of the configuration directories.

        :param sort: call ``self.sort()`` after loading.
        :returns: the loaded channels.
        """
        channels = []
        defaults = {}
        if self.configs.has_member(Channels.TYPE):
            configs = self.configs.get_member(Channels.TYPE)
            defaults = Channel._build_defaults(configs)
            channels.extend(self._load_from_members(self.__registrar, configs))

        channels.extend(
            self._load_from_file(self.__registrar, self.configs.dirs, f"{Channels.TYPE}.conf", defaults=defaults)
        )
        if sort:
            self.sort()
        return channels

    def add(self, key: str | Constant, **configs: Any) -> None:
        """Add or update the configuration of a channel.

        The configuration is written to the ``Channels.TYPE`` member of the
        configurations. If the registrar is already configured, the channel
        itself is created, or updated if it already exists.

        :param key: channel key, or a ``Constant`` whose ``to_dict()`` provides
            the key and default configurations; explicit *configs* take precedence.
        :param configs: channel configurations.
        """
        if isinstance(key, Constant):
            configs = {
                **key.to_dict(),
                **configs,
            }
            key = configs.pop("key")

        configs = Channel._build_configs(configs)
        channels = self.configs.get_member(Channels.TYPE, ensure_exists=True)
        if not channels.has_member(key):
            channels._add_member(key, configs)
        else:
            channel_configs = Channel._build_configs(channels[key])
            # NOTE(review): replace=False presumably keeps already configured values - confirm.
            channel_configs = update_recursive(channel_configs, configs, replace=False)
            channels[key] = channel_configs

        if self.__registrar.is_configured():
            channel_defaults = Channel._build_defaults(channels)
            channel_configs = Channel._build_configs(channel_defaults)

            # Be wary of the order. First, update the channel core with the default core
            # of the configuration file, then update the function arguments. Last, override
            # everything with the channel specific configurations of the file.
            channel_configs = update_recursive(channel_configs, configs)
            channel_configs = update_recursive(channel_configs, channels[key])

            channel_id = f"{self.__registrar.id}.{key}"
            if self._contains(channel_id):
                self._update(id=channel_id, key=key, **channel_configs)
            else:
                channel = self._create(id=channel_id, key=key, **channel_configs)
                self._add(channel)

    def register(
        self,
        function: Callable[[pd.DataFrame], None],
        channels: Optional[ChannelsArgument] = None,
        how: Literal["any", "all"] = "any",
        unique: bool = False,
    ) -> None:
        """Register *function* at the data context for the channels selected by *channels*."""
        channels = self._filter_by_args(channels)
        self.__context.register(function, channels=channels, how=how, unique=unique)

    def has_logged(
        self,
        channels: Optional[ChannelsArgument] = None,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
        timeout: Optional[float] = None,
    ) -> bool:
        """Return the result of ``has_logged`` of the data context for the selected channels."""
        channels = self._filter_by_args(channels)
        return self.__context.has_logged(channels=channels, start=start, end=end, timeout=timeout)

    @overload
    def from_logger(
        self,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
        timeout: Optional[float] = None,
        unique: bool = False,
    ) -> pd.DataFrame: ...

    @overload
    def from_logger(
        self,
        channels: ChannelsArgument,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
        timeout: Optional[float] = None,
        unique: bool = False,
    ) -> pd.DataFrame: ...

    def from_logger(
        self,
        channels: Optional[ChannelsArgument] = None,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
        timeout: Optional[float] = None,
        unique: bool = False,
    ) -> pd.DataFrame:
        """Alias of :meth:`read_logged`; all arguments are passed through."""
        return self.read_logged(channels=channels, start=start, end=end, timeout=timeout, unique=unique)

    def read_logged(
        self,
        channels: Optional[ChannelsArgument] = None,
        start: Optional[Timestamp] = None,
        end: Optional[Timestamp] = None,
        timeout: Optional[float] = None,
        unique: bool = False,
    ) -> pd.DataFrame:
        """Read logged data of the selected channels from the data context.

        :param unique: keep the channel ids as column names; otherwise the
            columns are renamed from channel id to channel key.
        """
        channels = self._filter_by_args(channels)
        data = self.__context.read_logged(channels=channels, start=start, end=end, timeout=timeout)
        if not unique:
            data.rename(columns={c.id: c.key for c in channels}, inplace=True)
        return data

    def read(
        self,
        channels: Optional[ChannelsArgument] = None,
        timeout: Optional[float] = None,
        unique: bool = False,
        **kwargs,
    ) -> pd.DataFrame:
        """Read data of the selected channels from the data context.

        :param unique: keep the channel ids as column names; otherwise the
            columns are renamed from channel id to channel key.
        :param kwargs: forwarded unchanged to ``read`` of the data context.
        """
        channels = self._filter_by_args(channels)
        data = self.__context.read(channels=channels, timeout=timeout, **kwargs)
        if not unique:
            data.rename(columns={c.id: c.key for c in channels}, inplace=True)
        return data

    def write(
        self,
        data: pd.DataFrame,
        channels: Optional[ChannelsArgument] = None,
        timeout: Optional[float] = None,
    ) -> None:
        """Write *data* to the selected channels through the data context.

        Columns named after a channel key are renamed to the channel id before
        the data is handed over. The passed frame itself is left untouched.

        :raises ResourceError: if *data* is ``None``.
        """
        if data is None:
            # NOTE(review): `self.id` is not defined in this class - confirm a parent provides it.
            raise ResourceError(f"Invalid data to write '{self.id}': {data}")

        # Resolve the channels first: the argument may be None (the default),
        # which cannot be iterated to build the column mapping.
        channels = self._filter_by_args(channels)

        # Rename into a new frame, to not alter the column names of the caller's data.
        data = data.rename(columns={c.key: c.id for c in channels})
        self.__context.write(data, channels=channels, timeout=timeout)