diff --git a/.flake8 b/.flake8 index 93ae223..db35656 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,6 @@ [flake8] exclude = __pycache__,built,build,venv -ignore = E203, E266, W503, E701, E704 +ignore = E203, E266, W503, E701, E704, C901 max-line-length = 88 max-complexity = 18 select = B,C,E,F,W,T4,B9 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2790c9b..f2d4e97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [2.1.0] - 2026-03-?? +## [2.1.0] - 2026-03-08 :woman: - Improve `resolve()` typing, by @sobolevn. - Use `Self` type for Container, by @sobolevn. @@ -35,6 +35,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Resolves [issue #43](https://github.com/Neoteroi/rodi/issues/43), reported by [@lucas-labs](https://github.com/lucas-labs). +- Add support for the [decorator pattern](https://en.wikipedia.org/wiki/Decorator_pattern) + via `Container.decorate(base_type, decorator_type)`. The decorator class must have an + `__init__` parameter whose type annotation matches the registered type; that parameter + receives the inner service instance, while all other parameters are resolved from the + container as usual. Decorators can be chained by calling `decorate()` multiple times — + each call wraps the previous registration: + + ```python + container.add_singleton(IGreeter, SimpleGreeter) + container.decorate(IGreeter, LoggingGreeter) # wraps SimpleGreeter + container.decorate(IGreeter, CachingGreeter) # wraps LoggingGreeter + # resolves as: CachingGreeter(LoggingGreeter(SimpleGreeter())) + ``` + + Resolves [issue #15](https://github.com/Neoteroi/rodi/issues/15), requested by @Eldar1205. ## [2.0.8] - 2025-04-12 diff --git a/examples/README.md b/examples/README.md index e81eff5..f0f0345 100644 --- a/examples/README.md +++ b/examples/README.md @@ -23,3 +23,8 @@ from exact implementations of data access logic). ## example-03.py This example illustrates how to configure a singleton object. + + +## example-04.py + +This example illustrates how to use the decorator pattern (available since `2.1.0`). diff --git a/examples/example-04.py b/examples/example-04.py new file mode 100644 index 0000000..15756db --- /dev/null +++ b/examples/example-04.py @@ -0,0 +1,90 @@ +""" +This example illustrates the decorator pattern using Container.decorate(). + +The decorator pattern lets you wrap a registered service with another implementation +of the same interface, transparently adding behaviour (logging, caching, retries, etc.) +without modifying the original class. + +Rules: +- The decorator class must implement (or be compatible with) the same interface. +- Its __init__ must have exactly one parameter whose type annotation matches the + registered base type; that parameter receives the inner service instance. +- All other __init__ parameters (and class-level annotations) are resolved from the + container as usual. +- Calling decorate() multiple times chains decorators — each call wraps the previous + registration, so the last registered decorator is the outermost one. +""" +from abc import ABC, abstractmethod + +from rodi import Container + + +# --- Domain interface --- + + +class MessageSender(ABC): + @abstractmethod + def send(self, message: str) -> None: + """Sends a message.""" + + +# --- Concrete implementation --- + + +class ConsoleSender(MessageSender): + """Sends messages by printing them to the console.""" + + def send(self, message: str) -> None: + print(f"[console] {message}") + + +# --- Decorator 1: logging --- + + +class LoggingMessageSender(MessageSender): + """Decorator that records every sent message before delegating.""" + + def __init__(self, inner: MessageSender) -> None: + self.inner = inner + self.log: list[str] = [] + + def send(self, message: str) -> None: + self.log.append(message) + self.inner.send(message) + + +# --- Decorator 2: prefixing (chained on top of the logging decorator) --- + + +class PrefixedMessageSender(MessageSender): + """Decorator that prepends a fixed prefix to every message.""" + + def __init__(self, inner: MessageSender) -> None: + self.inner = inner + + def send(self, message: str) -> None: + self.inner.send(f"[app] {message}") + + +# --- Wiring --- + +container = Container() + +container.add_singleton(MessageSender, ConsoleSender) +container.decorate(MessageSender, LoggingMessageSender) # wraps ConsoleSender +container.decorate(MessageSender, PrefixedMessageSender) # wraps LoggingMessageSender + +sender = container.resolve(MessageSender) + +# Resolution order: PrefixedMessageSender → LoggingMessageSender → ConsoleSender +assert isinstance(sender, PrefixedMessageSender) +assert isinstance(sender.inner, LoggingMessageSender) +assert isinstance(sender.inner.inner, ConsoleSender) + +sender.send("Hello, world") +# prints: [console] [app] Hello, world + +assert sender.inner.log == ["[app] Hello, world"] + +# Singleton: same instance every time +assert sender is container.resolve(MessageSender) diff --git a/rodi/__init__.py b/rodi/__init__.py index 84a3997..eda42ac 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -209,6 +209,27 @@ def __init__(self, _type): ) +class DecoratorRegistrationException(DIException): + """ + Exception raised when registering a decorator fails, either because the base type + is not registered or because the decorator class has no parameter matching the + base type. + """ + + def __init__(self, base_type, decorator_type): + if decorator_type is None: + super().__init__( + f"Cannot register a decorator for type '{class_name(base_type)}': " + f"the type is not registered in the container." + ) + else: + super().__init__( + f"Cannot register '{class_name(decorator_type)}' as a decorator for " + f"'{class_name(base_type)}': no __init__ parameter with a type " + f"annotation matching '{class_name(base_type)}' was found." + ) + + class ServiceLifeStyle(Enum): TRANSIENT = 1 SCOPED = 2 @@ -787,6 +808,143 @@ def __call__(self, context: ResolutionContext): return FactoryTypeProvider(self.concrete_type, self.factory) +def _get_resolver_lifestyle(resolver) -> "ServiceLifeStyle": + """Returns the ServiceLifeStyle of a resolver, defaulting to SINGLETON.""" + if isinstance(resolver, (DynamicResolver, FactoryResolver)): + return resolver.life_style + return ServiceLifeStyle.SINGLETON + + +class DecoratorResolver: + """ + Resolver that wraps an existing resolver with a decorator class. The decorator + must have an __init__ parameter whose type annotation matches (or is a supertype + of) the registered base type; that parameter receives the inner service instance. + All other __init__ parameters are resolved normally from the container. + """ + + __slots__ = ( + "_base_type", + "_decorator_type", + "_inner_resolver", + "services", + "life_style", + ) + + def __init__(self, base_type, decorator_type, inner_resolver, services, life_style): + self._base_type = base_type + self._decorator_type = decorator_type + self._inner_resolver = inner_resolver + self.services = services + self.life_style = life_style + + def _get_resolver(self, desired_type, context: ResolutionContext): + if desired_type in context.resolved: + return context.resolved[desired_type] + reg = self.services._map.get(desired_type) + assert ( + reg is not None + ), f"A resolver for type {class_name(desired_type)} is not configured" + resolver = reg(context) + context.resolved[desired_type] = resolver + return resolver + + def __call__(self, context: ResolutionContext): + inner_provider = self._inner_resolver(context) + + sig = Signature.from_callable(self._decorator_type.__init__) + params = { + key: Dependency(key, value.annotation) + for key, value in sig.parameters.items() + } + + globalns = dict(vars(sys.modules[self._decorator_type.__module__])) + globalns.update(_get_obj_globals(self._decorator_type)) + try: + annotations = get_type_hints( + self._decorator_type.__init__, + globalns, + _get_obj_locals(self._decorator_type), + ) + for key, value in params.items(): + if key in annotations: + value.annotation = annotations[key] + except Exception: + pass + + fns = [] + decoratee_found = False + + for param_name, dep in params.items(): + if param_name in ("self", "args", "kwargs"): + continue + + annotation = dep.annotation + if ( + annotation is not _empty + and isclass(annotation) + and annotation is not object + and issubclass(self._base_type, annotation) + ): + fns.append(inner_provider) + decoratee_found = True + else: + if annotation is _empty or annotation not in self.services._map: + raise CannotResolveParameterException( + param_name, self._decorator_type + ) + fns.append(self._get_resolver(annotation, context)) + + if not decoratee_found: + raise DecoratorRegistrationException(self._base_type, self._decorator_type) + + # Also resolve class-level annotations (property injection), excluding any + # names already covered by __init__ params or ClassVar / pre-initialised attrs. + init_param_names = set(params.keys()) + annotation_resolvers: dict[str, Callable] = {} + + if self._decorator_type.__annotations__: + class_hints = get_type_hints( + self._decorator_type, + { + **dict(vars(sys.modules[self._decorator_type.__module__])), + **_get_obj_globals(self._decorator_type), + }, + _get_obj_locals(self._decorator_type), + ) + for attr_name, attr_type in class_hints.items(): + if attr_name in init_param_names: + continue + is_classvar = getattr(attr_type, "__origin__", None) is ClassVar + is_initialized = ( + getattr(self._decorator_type, attr_name, None) is not None + ) + if is_classvar or is_initialized: + continue + if attr_type not in self.services._map: + raise CannotResolveParameterException( + attr_name, self._decorator_type + ) + annotation_resolvers[attr_name] = self._get_resolver(attr_type, context) + + decorator_type = self._decorator_type + + if annotation_resolvers: + + def factory(context, parent_type): + instance = decorator_type(*[fn(context, parent_type) for fn in fns]) + for name, resolver in annotation_resolvers.items(): + setattr(instance, name, resolver(context, parent_type)) + return instance + + else: + + def factory(context, parent_type): + return decorator_type(*[fn(context, parent_type) for fn in fns]) + + return FactoryResolver(decorator_type, factory, self.life_style)(context) + + first_cap_re = re.compile("(.)([A-Z][a-z]+)") all_cap_re = re.compile("([a-z0-9])([A-Z])") @@ -1227,6 +1385,38 @@ def add_transient( return self.bind_types(base_type, concrete_type, ServiceLifeStyle.TRANSIENT) + def decorate( + self: _ContainerSelf, + base_type: Type, + decorator_type: Type, + ) -> _ContainerSelf: + """ + Registers a decorator for an already-registered type. The decorator wraps the + existing service: when base_type is resolved, the decorator instance is returned + with the inner service injected as the decorated dependency. + + The decorator class must have an __init__ parameter whose type annotation is + base_type (or a supertype of it); that parameter receives the inner service. + All other __init__ parameters are resolved from the container as usual. + + Calling decorate() multiple times for the same type chains the decorators — + each wrapping the previous one (last registered = outermost decorator). + + :param base_type: the type being decorated (must already be registered) + :param decorator_type: the decorator class + :return: the service collection itself + """ + existing = self._map.get(base_type) + if existing is None: + raise DecoratorRegistrationException(base_type, None) + life_style = _get_resolver_lifestyle(existing) + self._map[base_type] = DecoratorResolver( + base_type, decorator_type, existing, self, life_style + ) + if self._provider is not None: + self._provider = None + return self + def _add_exact_singleton( self: _ContainerSelf, concrete_type: Type ) -> _ContainerSelf: diff --git a/tests/examples.py b/tests/examples.py index 77dfc76..b3df35b 100644 --- a/tests/examples.py +++ b/tests/examples.py @@ -327,3 +327,86 @@ class MixedAnnotationOverlapsInit: def __init__(self, dep1: MixedDep1) -> None: self.dep1 = dep1 + + +# Classes for testing the decorator pattern + + +class IGreeter: + def greet(self, name: str) -> str: + raise NotImplementedError + + +class SimpleGreeter(IGreeter): + def greet(self, name: str) -> str: + return f"Hello, {name}" + + +class LoggingGreeter(IGreeter): + """Decorator that logs calls before delegating to the inner greeter.""" + + def __init__(self, inner: IGreeter) -> None: + self.inner = inner + self.calls: list = [] + + def greet(self, name: str) -> str: + self.calls.append(name) + return self.inner.greet(name) + + +class ExclamatoryGreeter(IGreeter): + """Second decorator that adds an exclamation mark.""" + + def __init__(self, inner: IGreeter) -> None: + self.inner = inner + + def greet(self, name: str) -> str: + return self.inner.greet(name) + "!" + + +class Logger: + """A simple logger dependency for decorator tests.""" + + def __init__(self) -> None: + self.messages: list = [] + + def log(self, message: str) -> None: + self.messages.append(message) + + +class GreeterWithExtraDep(IGreeter): + """Decorator that has both the decorated service and an additional dependency.""" + + def __init__(self, inner: IGreeter, logger: Logger) -> None: + self.inner = inner + self.logger = logger + + def greet(self, name: str) -> str: + self.logger.log(f"greet({name})") + return self.inner.greet(name) + + +class DecoratorNoMatchingParam(IGreeter): + """Decorator whose __init__ has no parameter matching IGreeter — invalid.""" + + def __init__(self, logger: Logger) -> None: + self.logger = logger + + def greet(self, name: str) -> str: + return "" + + +class LoggingGreeterWithClassProp(IGreeter): + """ + Decorator with the decoratee in __init__ and an extra dependency as a + class-level annotation (property injection). + """ + + logger: Logger + + def __init__(self, inner: IGreeter) -> None: + self.inner = inner + + def greet(self, name: str) -> str: + self.logger.log(f"greet({name})") + return self.inner.greet(name) diff --git a/tests/test_services.py b/tests/test_services.py index 2dbdd8f..59db1d5 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -28,6 +28,7 @@ CircularDependencyException, Container, ContainerProtocol, + DecoratorRegistrationException, DynamicResolver, FactoryMissingContextException, InstanceResolver, @@ -42,7 +43,7 @@ inject, to_standard_param_name, ) -from tests.examples import ( +from tests.examples import ( # decorator pattern examples A, B, C, @@ -50,20 +51,26 @@ CatsController, Circle, Circle2, + DecoratorNoMatchingParam, + ExclamatoryGreeter, Foo, FooByParamName, FooDBCatsRepository, FooDBContext, GetCatRequestHandler, + GreeterWithExtraDep, IByParamName, ICatsRepository, ICircle, IdGetter, + IGreeter, InMemoryCatsRepository, IRequestContext, Jang, Jing, Ko, + Logger, + LoggingGreeter, MixedAnnotationOverlapsInit, MixedDep1, MixedDep2, @@ -80,6 +87,7 @@ ResolveThisByParameterName, ServiceSettings, Shape, + SimpleGreeter, TrickyCircle, TypeWithOptional, UfoFour, @@ -2912,3 +2920,145 @@ def test_mixed_annotation_overlaps_init_param(): instance = provider.get(MixedAnnotationOverlapsInit) assert isinstance(instance.dep1, MixedDep1) + + +# Tests for the decorator pattern (issue #15) + + +def test_decorator_basic_transient(): + """Basic decorator wrapping: resolving IGreeter returns the decorator instance.""" + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + container.decorate(IGreeter, LoggingGreeter) + provider = container.build_provider() + + instance = provider.get(IGreeter) + assert isinstance(instance, LoggingGreeter) + assert isinstance(instance.inner, SimpleGreeter) + assert instance.greet("World") == "Hello, World" + assert instance.calls == ["World"] + + +def test_decorator_transient_new_instance_each_time(): + """Each resolve of a transient-decorated type gives a fresh decorator instance.""" + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + container.decorate(IGreeter, LoggingGreeter) + provider = container.build_provider() + + a = provider.get(IGreeter) + b = provider.get(IGreeter) + assert a is not b + assert a.inner is not b.inner + + +def test_decorator_singleton(): + """Decorator respects singleton lifetime — same instance returned every time.""" + container = Container() + container.add_singleton(IGreeter, SimpleGreeter) + container.decorate(IGreeter, LoggingGreeter) + provider = container.build_provider() + + a = provider.get(IGreeter) + b = provider.get(IGreeter) + assert a is b + assert isinstance(a, LoggingGreeter) + assert isinstance(a.inner, SimpleGreeter) + + +def test_decorator_scoped(): + """Decorator respects scoped lifetime — same instance within a scope.""" + container = Container() + container.add_scoped(IGreeter, SimpleGreeter) + container.decorate(IGreeter, LoggingGreeter) + provider = container.build_provider() + + with provider.create_scope() as scope: + a = provider.get(IGreeter, scope) + b = provider.get(IGreeter, scope) + assert a is b + assert isinstance(a, LoggingGreeter) + + with provider.create_scope() as scope2: + c = provider.get(IGreeter, scope2) + assert c is not a + + +def test_decorator_with_additional_dependency(): + """Decorator that has both the inner service and another injected dependency.""" + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + container.add_transient(Logger) + container.decorate(IGreeter, GreeterWithExtraDep) + provider = container.build_provider() + + instance = provider.get(IGreeter) + assert isinstance(instance, GreeterWithExtraDep) + assert isinstance(instance.inner, SimpleGreeter) + assert isinstance(instance.logger, Logger) + instance.greet("Alice") + assert instance.logger.messages == ["greet(Alice)"] + + +def test_decorator_chaining(): + """Multiple decorate() calls chain decorators outermost-last.""" + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + container.decorate(IGreeter, LoggingGreeter) + container.decorate(IGreeter, ExclamatoryGreeter) + provider = container.build_provider() + + instance = provider.get(IGreeter) + # Resolution: ExclamatoryGreeter(LoggingGreeter(SimpleGreeter())) + assert isinstance(instance, ExclamatoryGreeter) + assert isinstance(instance.inner, LoggingGreeter) + assert isinstance(instance.inner.inner, SimpleGreeter) + assert instance.greet("Bob") == "Hello, Bob!" + + +def test_decorator_error_base_type_not_registered(): + """decorate() raises when the base type is not registered.""" + container = Container() + with raises(DecoratorRegistrationException): + container.decorate(IGreeter, LoggingGreeter) + + +def test_decorator_error_no_matching_param(): + """ + build_provider() raises when the decorator has no parameter matching base type. + """ + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + container.add_transient(Logger) + container.decorate(IGreeter, DecoratorNoMatchingParam) + with raises(DecoratorRegistrationException): + container.build_provider() + + +def test_decorator_fluent_chaining(): + """decorate() returns self for fluent chaining.""" + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + result = container.decorate(IGreeter, LoggingGreeter) + assert result is container + + +def test_decorator_class_property_injection(): + """ + Decorator with the decoratee in __init__ and an extra dep as a class annotation: + both should be injected (constructor injection + setattr). + """ + from tests.examples import LoggingGreeterWithClassProp + + container = Container() + container.add_transient(IGreeter, SimpleGreeter) + container.add_transient(Logger) + container.decorate(IGreeter, LoggingGreeterWithClassProp) + provider = container.build_provider() + + instance = provider.get(IGreeter) + assert isinstance(instance, LoggingGreeterWithClassProp) + assert isinstance(instance.inner, SimpleGreeter) + assert isinstance(instance.logger, Logger) + instance.greet("Alice") + assert instance.logger.messages == ["greet(Alice)"]