Source code for mini_arcade.utils.implementation_registry

"""
Typed, safe, and subclass-isolated implementation registry.
"""

from __future__ import annotations

import re
import threading
from abc import ABC
from typing import (
    Callable,
    ClassVar,
    Generic,
    Iterator,
    Mapping,
    MutableMapping,
    Optional,
    Type,
    TypeVar,
)

# pylint: disable=invalid-name
ImplementationType = TypeVar(
    "ImplementationType"
)  # implementation type (classes deriving from implementation_base)
# pylint: enable=invalid-name


[docs] class ImplementationRegistry(Generic[ImplementationType]): """ A registry of *classes* implementing a specific interface. Subclasses **must** set `implementation_base` to the ABC (or base class) that all implementations derive from. :ivar implementation_base: ClassVar[type]: The base class for registered implementations. :ivar _registry: ClassVar[MutableMapping[str, Type[ImplementationType]]]: Mapping of names to implementation classes. :ivar _lock: ClassVar[threading.RLock]: Lock for thread-safe operations. """ implementation_base: ClassVar[type] = ABC # override in subclasses _registry: ClassVar[MutableMapping[str, Type[ImplementationType]]] _lock: ClassVar[threading.RLock]
[docs] def __init_subclass__(cls, **kwargs): # type: ignore[override] super().__init_subclass__(**kwargs) cls._registry = {} cls._lock = threading.RLock()
@staticmethod def _infer_name(impl_class: type) -> str: return impl_class.__name__.lower()
[docs] @classmethod def register( cls, name: str, impl_class: Type[ImplementationType], *, replace: bool = False, ): """ Register an implementation implementation class. :param name: Name to register the implementation under. :type name: str :param impl_class: Implementation class to register. :type impl_class: Type[ImplementationType] :param replace: Whether to replace an existing registration. :type replace: bool :raises TypeError: If `impl_class` does not subclass `implementation_base`. :raises KeyError: If `name` is already registered and `replace` is False """ if not issubclass(impl_class, cls.implementation_base): raise TypeError( f"{impl_class.__qualname__} must subclass {cls.implementation_base.__qualname__}" ) with cls._lock: if not replace and name in cls._registry: raise KeyError(f"Implementation '{name}' already registered") cls._registry[name] = impl_class
[docs] @classmethod def unregister(cls, name: str): """ Unregister an implementation implementation by name. :param name: Name of the implementation to unregister. :type name: str """ with cls._lock: cls._registry.pop(name, None)
[docs] @classmethod def implementation( cls, name: Optional[str] = None, *, replace: bool = False ) -> Callable[[Type[ImplementationType]], Type[ImplementationType]]: """ Decorator to register an implementation implementation class. :param name: Name to register the implementation under. :type name: Optional[str] :param replace: Whether to replace an existing registration. :type replace: bool """ def decorator( impl_class: Type[ImplementationType], ) -> Type[ImplementationType]: cls.register( name or cls._infer_name(impl_class), impl_class, replace=replace, ) return impl_class return decorator
[docs] @classmethod def get(cls, name: str) -> Type[ImplementationType]: """ Get the implementation implementation class by name. :param name: Name of the implementation to retrieve. :type name: str :return: Implementation class. :rtype: Type[ImplementationType] :raises KeyError: If the name is not registered. """ try: return cls._registry[name] except KeyError as e: raise KeyError(f"Unknown implementation '{name}'") from e
[docs] @classmethod def try_get(cls, name: str) -> Optional[Type[ImplementationType]]: """ Try to get the implementation implementation class by name. :param name: Name of the implementation to retrieve. :type name: str :return: Implementation class or None if not found. :rtype: Optional[Type[ImplementationType]] """ return cls._registry.get(name, None)
[docs] @classmethod def contains(cls, name: str) -> bool: """ Check if an implementation implementation is registered by name. :param name: Name of the implementation to check. :type name: str :return: True if registered, False otherwise. :rtype: bool """ return name in cls._registry
[docs] @classmethod def all(cls) -> Mapping[str, Type[ImplementationType]]: """ Get a mapping of all registered implementation implementations. :return: Mapping of names to implementation classes. :rtype: Mapping[str, Type[ImplementationType]] """ return dict(cls._registry)
[docs] @classmethod def names(cls) -> list[str]: """ Get a list of all registered implementation implementation names. :return: List of implementation names. :rtype: list[str] """ return list(cls._registry.keys())
[docs] @classmethod def find_contains(cls, needle: str) -> list[Type[ImplementationType]]: """ Find all implementation implementations whose names contain the given substring (case-insensitive). :param needle: Substring to search for. :type needle: str :return: List of matching implementation classes. :rtype: list[Type[ImplementationType]] """ n = needle.lower() return [ impl for key, impl in cls._registry.items() if n in key.lower() ]
[docs] @classmethod def find_regex(cls, pattern: str) -> list[Type[ImplementationType]]: """ Find all implementation implementations whose names match the given regex pattern (case-insensitive). :param pattern: Regex pattern to search for. :type pattern: str :return: List of matching implementation classes. :rtype: list[Type[ImplementationType]] """ rx = re.compile(pattern, re.IGNORECASE) return [impl for key, impl in cls._registry.items() if rx.search(key)]
[docs] @classmethod def instantiate(cls, name: str, *args, **kwargs) -> ImplementationType: """ Instantiate an implementation implementation by name. :param name: Name of the implementation to instantiate. :type name: str :return: Instance of the implementation. :rtype: ImplementationType """ impl = cls.get(name) return impl(*args, **kwargs) # type: ignore[call-arg]
[docs] @classmethod def clear(cls): """ Clear all registered implementation implementations. """ with cls._lock: cls._registry.clear()
[docs] @classmethod def __iter__(cls) -> Iterator[tuple[str, Type[ImplementationType]]]: return iter(cls._registry.items())
[docs] @classmethod def __len__(cls) -> int: return len(cls._registry)