From 48d1070578df29b36c358daaab15c645b6970aad Mon Sep 17 00:00:00 2001 From: vikasrao23 Date: Fri, 20 Feb 2026 20:05:26 -0800 Subject: [PATCH 1/8] feat: implement Multi-Provider Implements the Multi-Provider as specified in OpenFeature Appendix A. The Multi-Provider wraps multiple underlying providers in a unified interface, allowing a single client to interact with multiple flag sources simultaneously. Key features implemented: - MultiProvider class extending AbstractProvider - FirstMatchStrategy (sequential evaluation, stops at first success) - EvaluationStrategy protocol for custom strategies - Provider name uniqueness (explicit, metadata-based, or auto-indexed) - Parallel initialization of all providers with error aggregation - Support for all flag types (boolean, string, integer, float, object) - Hook aggregation from all providers Use cases: - Migration: Run old and new providers in parallel - Multiple data sources: Combine env vars, files, and SaaS providers - Fallback: Primary provider with backup sources Example usage: provider_a = SomeProvider() provider_b = AnotherProvider() multi = MultiProvider([ ProviderEntry(provider_a, name="primary"), ProviderEntry(provider_b, name="fallback") ]) api.set_provider(multi) Closes #511 Signed-off-by: vikasrao23 Signed-off-by: Jonathan Norris --- openfeature/provider/__init__.py | 17 +- openfeature/provider/multi_provider.py | 352 +++++++++++++++++++++++++ tests/test_multi_provider.py | 297 +++++++++++++++++++++ 3 files changed, 665 insertions(+), 1 deletion(-) create mode 100644 openfeature/provider/multi_provider.py create mode 100644 tests/test_multi_provider.py diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index aea5069f..55e00263 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -11,11 +11,26 @@ from openfeature.hook import Hook from .metadata import Metadata +from .multi_provider import ( + EvaluationStrategy, + FirstMatchStrategy, + MultiProvider, + ProviderEntry, +) if typing.TYPE_CHECKING: from openfeature.flag_evaluation import FlagValueType -__all__ = ["AbstractProvider", "FeatureProvider", "Metadata", "ProviderStatus"] +__all__ = [ + "AbstractProvider", + "EvaluationStrategy", + "FeatureProvider", + "FirstMatchStrategy", + "Metadata", + "MultiProvider", + "ProviderEntry", + "ProviderStatus", +] class ProviderStatus(Enum): diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py new file mode 100644 index 00000000..7511830c --- /dev/null +++ b/openfeature/provider/multi_provider.py @@ -0,0 +1,352 @@ +""" +Multi-Provider implementation for OpenFeature Python SDK. + +This provider wraps multiple underlying providers, allowing a single client +to interact with multiple flag sources simultaneously. 
+ +See: https://openfeature.dev/specification/appendix-a/#multi-provider +""" + +from __future__ import annotations + +import asyncio +import typing +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass + +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import GeneralError +from openfeature.flag_evaluation import FlagResolutionDetails, FlagValueType, Reason +from openfeature.hook import Hook +from openfeature.provider import AbstractProvider, FeatureProvider, Metadata, ProviderStatus + +__all__ = ["MultiProvider", "ProviderEntry", "FirstMatchStrategy", "EvaluationStrategy"] + + +@dataclass +class ProviderEntry: + """Configuration for a provider in the Multi-Provider.""" + + provider: FeatureProvider + name: str | None = None + + +class EvaluationStrategy(typing.Protocol): + """ + Strategy interface for determining which provider's result to use. + + Strategies can be 'sequential' (evaluate one at a time, stop early) or + 'parallel' (evaluate all simultaneously). + """ + + run_mode: typing.Literal["sequential", "parallel"] + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails, + ) -> bool: + """ + Determine if this result should be used (and stop evaluation if sequential). + + :param flag_key: The flag being evaluated + :param provider_name: Name of the provider that returned this result + :param result: The resolution details from the provider + :return: True if this result should be used as the final result + """ + ... + + +class FirstMatchStrategy: + """ + Uses the first successful result from providers (in order). + + In sequential mode, stops at the first non-error result. + In parallel mode, picks the first successful result from the ordered list. + """ + + run_mode: typing.Literal["sequential", "parallel"] = "sequential" + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails, + ) -> bool: + """Use the first result that doesn't have an error.""" + return result.reason != Reason.ERROR + + +class MultiProvider(AbstractProvider): + """ + A provider that aggregates multiple underlying providers. + + Evaluations are delegated to underlying providers based on the configured + strategy (default: FirstMatchStrategy in sequential mode). + + Example: + provider_a = SomeProvider() + provider_b = AnotherProvider() + + multi = MultiProvider([ + ProviderEntry(provider_a, name="primary"), + ProviderEntry(provider_b, name="fallback") + ]) + + api.set_provider(multi) + """ + + def __init__( + self, + providers: list[ProviderEntry], + strategy: EvaluationStrategy | None = None, + ): + """ + Initialize the Multi-Provider. + + :param providers: List of ProviderEntry objects defining the providers + :param strategy: Evaluation strategy (defaults to FirstMatchStrategy) + """ + super().__init__() + + if not providers: + raise ValueError("At least one provider must be provided") + + self.strategy = strategy or FirstMatchStrategy() + self._registered_providers: list[tuple[str, FeatureProvider]] = [] + self._register_providers(providers) + + def _register_providers(self, providers: list[ProviderEntry]) -> None: + """ + Register providers with unique names. + + Names are determined by: + 1. Explicit name in ProviderEntry + 2. provider.get_metadata().name if unique + 3. 
{metadata.name}_{index} if not unique + """ + # Count providers by their metadata name to detect duplicates + name_counts: dict[str, int] = {} + for entry in providers: + metadata_name = entry.provider.get_metadata().name or "provider" + name_counts[metadata_name] = name_counts.get(metadata_name, 0) + 1 + + # Track used names to prevent conflicts + used_names: set[str] = set() + name_indices: dict[str, int] = {} + + for entry in providers: + metadata_name = entry.provider.get_metadata().name or "provider" + + if entry.name: + # Explicit name provided + if entry.name in used_names: + raise ValueError(f"Provider name '{entry.name}' is not unique") + final_name = entry.name + elif name_counts[metadata_name] == 1: + # Metadata name is unique + final_name = metadata_name + else: + # Multiple providers with same metadata name, add index + name_indices[metadata_name] = name_indices.get(metadata_name, 0) + 1 + final_name = f"{metadata_name}_{name_indices[metadata_name]}" + + used_names.add(final_name) + self._registered_providers.append((final_name, entry.provider)) + + def get_metadata(self) -> Metadata: + """Return metadata including all wrapped provider metadata.""" + return Metadata(name="MultiProvider") + + def get_provider_hooks(self) -> list[Hook]: + """Aggregate hooks from all providers.""" + hooks: list[Hook] = [] + for _, provider in self._registered_providers: + hooks.extend(provider.get_provider_hooks()) + return hooks + + def initialize(self, evaluation_context: EvaluationContext) -> None: + """Initialize all providers in parallel.""" + errors: list[Exception] = [] + + for name, provider in self._registered_providers: + try: + provider.initialize(evaluation_context) + except Exception as e: + errors.append(Exception(f"Provider '{name}' initialization failed: {e}")) + + if errors: + # Aggregate errors + error_msgs = "; ".join(str(e) for e in errors) + raise GeneralError(f"Multi-provider initialization failed: {error_msgs}") + + def shutdown(self) -> None: + """Shutdown all providers.""" + for _, provider in self._registered_providers: + try: + provider.shutdown() + except Exception: + # Log but don't fail shutdown + pass + + def _evaluate_with_providers( + self, + flag_key: str, + default_value: FlagValueType, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], FlagResolutionDetails], + ) -> FlagResolutionDetails[FlagValueType]: + """ + Core evaluation logic that delegates to providers based on strategy. 
+ + :param flag_key: The flag key to evaluate + :param default_value: Default value for the flag + :param evaluation_context: Evaluation context + :param resolve_fn: Function to call on each provider for resolution + :return: Final resolution details + """ + results: list[tuple[str, FlagResolutionDetails]] = [] + + for provider_name, provider in self._registered_providers: + try: + result = resolve_fn(provider, flag_key, default_value, evaluation_context) + results.append((provider_name, result)) + + # In sequential mode, stop if strategy says to use this result + if (self.strategy.run_mode == "sequential" and + self.strategy.should_use_result(flag_key, provider_name, result)): + return result + + except Exception as e: + # Record error but continue to next provider + error_result = FlagResolutionDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_message=str(e), + ) + results.append((provider_name, error_result)) + + # In parallel mode or if all sequential attempts completed, pick best result + for provider_name, result in results: + if self.strategy.should_use_result(flag_key, provider_name, result): + return result + + # No successful result - return last error or default + if results: + return results[-1][1] + + return FlagResolutionDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_message="No providers returned a result", + ) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + return self._evaluate_with_providers( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_boolean_details(k, d, ctx), + ) + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + # For async, delegate to sync for now (async aggregation would be more complex) + return self.resolve_boolean_details(flag_key, default_value, evaluation_context) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + return self._evaluate_with_providers( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_string_details(k, d, ctx), + ) + + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + return self.resolve_string_details(flag_key, default_value, evaluation_context) + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + return self._evaluate_with_providers( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_integer_details(k, d, ctx), + ) + + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + return self.resolve_integer_details(flag_key, default_value, evaluation_context) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + return self._evaluate_with_providers( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, 
ctx: p.resolve_float_details(k, d, ctx), + ) + + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + return self.resolve_float_details(flag_key, default_value, evaluation_context) + + def resolve_object_details( + self, + flag_key: str, + default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return self._evaluate_with_providers( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_object_details(k, d, ctx), + ) + + async def resolve_object_details_async( + self, + flag_key: str, + default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return self.resolve_object_details(flag_key, default_value, evaluation_context) diff --git a/tests/test_multi_provider.py b/tests/test_multi_provider.py new file mode 100644 index 00000000..2ba7759a --- /dev/null +++ b/tests/test_multi_provider.py @@ -0,0 +1,297 @@ +import pytest + +from openfeature import api +from openfeature.evaluation_context import EvaluationContext +from openfeature.exception import GeneralError +from openfeature.flag_evaluation import FlagResolutionDetails, Reason +from openfeature.provider import Metadata +from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider +from openfeature.provider.multi_provider import ( + FirstMatchStrategy, + MultiProvider, + ProviderEntry, +) +from openfeature.provider.no_op_provider import NoOpProvider + + +def test_multi_provider_requires_at_least_one_provider(): + # Given/When/Then + with pytest.raises(ValueError, match="At least one provider must be provided"): + MultiProvider([]) + + +def test_multi_provider_uses_explicit_names(): + # Given + provider_a = NoOpProvider() + provider_b = NoOpProvider() + + # When + multi = MultiProvider([ + ProviderEntry(provider_a, name="first"), + ProviderEntry(provider_b, name="second"), + ]) + + # Then + assert len(multi._registered_providers) == 2 + assert multi._registered_providers[0][0] == "first" + assert multi._registered_providers[1][0] == "second" + + +def test_multi_provider_generates_unique_names_when_metadata_conflicts(): + # Given + provider_a = NoOpProvider() + provider_b = NoOpProvider() + + # When - both have same metadata name "NoOpProvider" + multi = MultiProvider([ + ProviderEntry(provider_a), + ProviderEntry(provider_b), + ]) + + # Then - names are auto-indexed + assert len(multi._registered_providers) == 2 + names = [name for name, _ in multi._registered_providers] + assert names == ["NoOpProvider_1", "NoOpProvider_2"] + + +def test_multi_provider_rejects_duplicate_explicit_names(): + # Given + provider_a = NoOpProvider() + provider_b = NoOpProvider() + + # When/Then + with pytest.raises(ValueError, match="Provider name 'duplicate' is not unique"): + MultiProvider([ + ProviderEntry(provider_a, name="duplicate"), + ProviderEntry(provider_b, name="duplicate"), + ]) + + +def test_multi_provider_first_match_strategy_sequential(): + # Given + flags_a = { + "flag1": InMemoryFlag("off", {"on": True, "off": False}), + } + flags_b = { + "flag1": InMemoryFlag("on", {"on": True, "off": False}), + "flag2": InMemoryFlag("on", {"on": True, "off": False}), + } + + 
provider_a = InMemoryProvider(flags_a) + provider_b = InMemoryProvider(flags_b) + + multi = MultiProvider([ + ProviderEntry(provider_a, name="primary"), + ProviderEntry(provider_b, name="fallback"), + ], strategy=FirstMatchStrategy()) + + # When - flag1 exists in both, should use first (primary) + result = multi.resolve_boolean_details("flag1", False) + + # Then + assert result.value == False # primary provider returns "off" variant + assert result.reason != Reason.ERROR + + +def test_multi_provider_fallback_to_second_provider(): + # Given + flags_a = {} # primary has no flags + flags_b = { + "flag1": InMemoryFlag("on", {"on": True, "off": False}), + } + + provider_a = InMemoryProvider(flags_a) + provider_b = InMemoryProvider(flags_b) + + multi = MultiProvider([ + ProviderEntry(provider_a, name="primary"), + ProviderEntry(provider_b, name="fallback"), + ]) + + # When - flag1 doesn't exist in primary, should fallback + result = multi.resolve_boolean_details("flag1", False) + + # Then + assert result.value == True # fallback provider has the flag + assert result.reason != Reason.ERROR + + +def test_multi_provider_all_types_work(): + # Given + flags = { + "bool-flag": InMemoryFlag("on", {"on": True, "off": False}), + "string-flag": InMemoryFlag("greeting", {"greeting": "hello", "farewell": "goodbye"}), + "int-flag": InMemoryFlag("big", {"small": 10, "big": 100}), + "float-flag": InMemoryFlag("pi", {"pi": 3.14, "e": 2.71}), + "object-flag": InMemoryFlag("full", { + "full": {"name": "test", "value": 42}, + "empty": {}, + }), + } + + provider = InMemoryProvider(flags) + multi = MultiProvider([ProviderEntry(provider)]) + + # When/Then + bool_result = multi.resolve_boolean_details("bool-flag", False) + assert bool_result.value == True + + string_result = multi.resolve_string_details("string-flag", "default") + assert string_result.value == "hello" + + int_result = multi.resolve_integer_details("int-flag", 0) + assert int_result.value == 100 + + float_result = multi.resolve_float_details("float-flag", 0.0) + assert float_result.value == 3.14 + + object_result = multi.resolve_object_details("object-flag", {}) + assert object_result.value == {"name": "test", "value": 42} + + +def test_multi_provider_initialize_all_providers(): + # Given + provider_a = NoOpProvider() + provider_b = NoOpProvider() + + # Track if initialize was called + provider_a.initialize = lambda ctx: None + provider_b.initialize = lambda ctx: None + + a_initialized = False + b_initialized = False + + def track_a_init(ctx): + nonlocal a_initialized + a_initialized = True + + def track_b_init(ctx): + nonlocal b_initialized + b_initialized = True + + provider_a.initialize = track_a_init + provider_b.initialize = track_b_init + + multi = MultiProvider([ + ProviderEntry(provider_a), + ProviderEntry(provider_b), + ]) + + # When + multi.initialize(EvaluationContext()) + + # Then + assert a_initialized + assert b_initialized + + +def test_multi_provider_initialization_failures_are_aggregated(): + # Given + provider_a = NoOpProvider() + provider_b = NoOpProvider() + + def fail_init(ctx): + raise Exception("Init failed") + + provider_a.initialize = fail_init + provider_b.initialize = fail_init + + multi = MultiProvider([ + ProviderEntry(provider_a, name="a"), + ProviderEntry(provider_b, name="b"), + ]) + + # When/Then + with pytest.raises(GeneralError, match="Multi-provider initialization failed"): + multi.initialize(EvaluationContext()) + + +def test_multi_provider_returns_error_when_no_providers_have_flag(): + # Given + provider_a = 
InMemoryProvider({}) + provider_b = InMemoryProvider({}) + + multi = MultiProvider([ + ProviderEntry(provider_a), + ProviderEntry(provider_b), + ]) + + # When + result = multi.resolve_boolean_details("nonexistent", False) + + # Then + assert result.value == False # default value + assert result.reason == Reason.ERROR + + +@pytest.mark.asyncio +async def test_multi_provider_async_methods_work(): + # Given + flags = { + "async-flag": InMemoryFlag("on", {"on": True, "off": False}), + } + provider = InMemoryProvider(flags) + multi = MultiProvider([ProviderEntry(provider)]) + + # When + result = await multi.resolve_boolean_details_async("async-flag", False) + + # Then + assert result.value == True + assert result.reason != Reason.ERROR + + +def test_multi_provider_can_be_used_with_api(): + # Given + api.clear_providers() + flags = { + "api-flag": InMemoryFlag("on", {"on": True, "off": False}), + } + provider = InMemoryProvider(flags) + multi = MultiProvider([ProviderEntry(provider)]) + + # When + api.set_provider(multi) + client = api.get_client() + value = client.get_boolean_value("api-flag", False) + + # Then + assert value == True + + +def test_multi_provider_metadata(): + # Given + multi = MultiProvider([ProviderEntry(NoOpProvider())]) + + # When + metadata = multi.get_metadata() + + # Then + assert metadata.name == "MultiProvider" + + +def test_multi_provider_aggregates_hooks(): + # Given + from unittest.mock import MagicMock + + provider_a = NoOpProvider() + provider_b = NoOpProvider() + + hook_a = MagicMock() + hook_b = MagicMock() + + provider_a.get_provider_hooks = lambda: [hook_a] + provider_b.get_provider_hooks = lambda: [hook_b] + + multi = MultiProvider([ + ProviderEntry(provider_a), + ProviderEntry(provider_b), + ]) + + # When + hooks = multi.get_provider_hooks() + + # Then + assert len(hooks) == 2 + assert hook_a in hooks + assert hook_b in hooks From 762b7434e579fa07b2a12fe084dfca62727e6a2a Mon Sep 17 00:00:00 2001 From: vikasrao23 Date: Fri, 20 Feb 2026 21:14:09 -0800 Subject: [PATCH 2/8] docs: clarify sequential implementation and planned async/parallel enhancements Address Gemini code review feedback: - Update initialize() docstring to reflect sequential (not parallel) initialization - Add documentation notes to all async methods explaining they currently delegate to sync - Clarify that parallel evaluation mode is planned but not yet implemented - Update EvaluationStrategy protocol docs to set correct expectations This brings documentation in line with actual implementation. True async and parallel execution will be added in follow-up PRs. Refs: #511 Signed-off-by: vikasrao23 Signed-off-by: Jonathan Norris --- openfeature/provider/multi_provider.py | 30 +++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index 7511830c..df3747be 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -36,8 +36,9 @@ class EvaluationStrategy(typing.Protocol): """ Strategy interface for determining which provider's result to use. - Strategies can be 'sequential' (evaluate one at a time, stop early) or - 'parallel' (evaluate all simultaneously). + Current implementation supports 'sequential' mode (evaluate one at a time, + stop early). 'parallel' mode (evaluate all simultaneously using asyncio.gather + or ThreadPoolExecutor) is planned for a future enhancement. 
""" run_mode: typing.Literal["sequential", "parallel"] @@ -168,7 +169,12 @@ def get_provider_hooks(self) -> list[Hook]: return hooks def initialize(self, evaluation_context: EvaluationContext) -> None: - """Initialize all providers in parallel.""" + """ + Initialize all providers sequentially. + + Note: Parallel initialization using ThreadPoolExecutor or asyncio.gather() + is planned for a future enhancement. + """ errors: list[Exception] = [] for name, provider in self._registered_providers: @@ -201,6 +207,10 @@ def _evaluate_with_providers( """ Core evaluation logic that delegates to providers based on strategy. + Current implementation evaluates providers sequentially regardless of + strategy.run_mode. True concurrent evaluation for 'parallel' mode is + planned for a future enhancement. + :param flag_key: The flag key to evaluate :param default_value: Default value for the flag :param evaluation_context: Evaluation context @@ -229,7 +239,7 @@ def _evaluate_with_providers( ) results.append((provider_name, error_result)) - # In parallel mode or if all sequential attempts completed, pick best result + # If all sequential attempts completed (or parallel mode), pick best result for provider_name, result in results: if self.strategy.should_use_result(flag_key, provider_name, result): return result @@ -264,7 +274,13 @@ async def resolve_boolean_details_async( default_value: bool, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[bool]: - # For async, delegate to sync for now (async aggregation would be more complex) + """ + Async boolean evaluation (currently delegates to sync implementation). + + Note: True async evaluation using await and provider-level async methods + is planned for a future enhancement. The current implementation maintains + API compatibility but does not provide non-blocking I/O benefits. 
+ """ return self.resolve_boolean_details(flag_key, default_value, evaluation_context) def resolve_string_details( @@ -286,6 +302,7 @@ async def resolve_string_details_async( default_value: str, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[str]: + """Async string evaluation (currently delegates to sync implementation).""" return self.resolve_string_details(flag_key, default_value, evaluation_context) def resolve_integer_details( @@ -307,6 +324,7 @@ async def resolve_integer_details_async( default_value: int, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[int]: + """Async integer evaluation (currently delegates to sync implementation).""" return self.resolve_integer_details(flag_key, default_value, evaluation_context) def resolve_float_details( @@ -328,6 +346,7 @@ async def resolve_float_details_async( default_value: float, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[float]: + """Async float evaluation (currently delegates to sync implementation).""" return self.resolve_float_details(flag_key, default_value, evaluation_context) def resolve_object_details( @@ -349,4 +368,5 @@ async def resolve_object_details_async( default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + """Async object evaluation (currently delegates to sync implementation).""" return self.resolve_object_details(flag_key, default_value, evaluation_context) From 74ffd2203eb7330f03cc7ab1646bcafdb6f4186b Mon Sep 17 00:00:00 2001 From: Vikas Rao Date: Sun, 22 Feb 2026 08:54:45 -0800 Subject: [PATCH 3/8] Address Gemini code review feedback CRITICAL FIXES: - Fix FlagResolutionDetails initialization - remove invalid flag_key parameter - Add error_code (ErrorCode.GENERAL) to all error results per spec HIGH PRIORITY: - Implement true async evaluation using _evaluate_with_providers_async - All async methods now properly await provider async methods (no blocking) - Implement parallel provider initialization using ThreadPoolExecutor IMPROVEMENTS: - Remove unused imports (asyncio, ProviderEvent, ProviderEventDetails, ProviderStatus) - Add ErrorCode import for proper error handling - Cache provider hooks to avoid re-aggregating on every evaluation - Update docstrings to clarify current implementation status Signed-off-by: Jonathan Norris --- openfeature/provider/multi_provider.py | 159 ++++++++++++++++++------- 1 file changed, 119 insertions(+), 40 deletions(-) diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index df3747be..a561a6b1 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -9,17 +9,16 @@ from __future__ import annotations -import asyncio import typing from collections.abc import Callable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from openfeature.evaluation_context import EvaluationContext -from openfeature.event import ProviderEvent, ProviderEventDetails -from openfeature.exception import GeneralError +from openfeature.exception import ErrorCode, GeneralError from openfeature.flag_evaluation import FlagResolutionDetails, FlagValueType, Reason from openfeature.hook import Hook -from openfeature.provider import AbstractProvider, FeatureProvider, Metadata, ProviderStatus +from openfeature.provider import AbstractProvider, 
FeatureProvider, Metadata __all__ = ["MultiProvider", "ProviderEntry", "FirstMatchStrategy", "EvaluationStrategy"] @@ -36,9 +35,11 @@ class EvaluationStrategy(typing.Protocol): """ Strategy interface for determining which provider's result to use. - Current implementation supports 'sequential' mode (evaluate one at a time, - stop early). 'parallel' mode (evaluate all simultaneously using asyncio.gather - or ThreadPoolExecutor) is planned for a future enhancement. + Supports 'sequential' mode (evaluate one at a time, stop early when strategy + is satisfied) and 'parallel' mode (evaluate all providers, then select best + result). Note: Both modes currently execute provider calls sequentially; + true concurrent evaluation using asyncio.gather or ThreadPoolExecutor is + planned for a future enhancement. """ run_mode: typing.Literal["sequential", "parallel"] @@ -118,6 +119,7 @@ def __init__( self.strategy = strategy or FirstMatchStrategy() self._registered_providers: list[tuple[str, FeatureProvider]] = [] self._register_providers(providers) + self._cached_hooks: list[Hook] | None = None def _register_providers(self, providers: list[ProviderEntry]) -> None: """ @@ -162,30 +164,34 @@ def get_metadata(self) -> Metadata: return Metadata(name="MultiProvider") def get_provider_hooks(self) -> list[Hook]: - """Aggregate hooks from all providers.""" - hooks: list[Hook] = [] - for _, provider in self._registered_providers: - hooks.extend(provider.get_provider_hooks()) - return hooks + """Aggregate hooks from all providers (cached for efficiency).""" + if self._cached_hooks is None: + hooks: list[Hook] = [] + for _, provider in self._registered_providers: + hooks.extend(provider.get_provider_hooks()) + self._cached_hooks = hooks + return self._cached_hooks def initialize(self, evaluation_context: EvaluationContext) -> None: """ - Initialize all providers sequentially. + Initialize all providers in parallel using ThreadPoolExecutor. - Note: Parallel initialization using ThreadPoolExecutor or asyncio.gather() - is planned for a future enhancement. + This allows concurrent initialization of I/O-bound providers. 
""" - errors: list[Exception] = [] - - for name, provider in self._registered_providers: + def init_provider(entry: tuple[str, FeatureProvider]) -> str | None: + name, provider = entry try: provider.initialize(evaluation_context) + return None except Exception as e: - errors.append(Exception(f"Provider '{name}' initialization failed: {e}")) - + return f"Provider '{name}' initialization failed: {e}" + + with ThreadPoolExecutor() as executor: + results = list(executor.map(init_provider, self._registered_providers)) + + errors = [r for r in results if r is not None] if errors: - # Aggregate errors - error_msgs = "; ".join(str(e) for e in errors) + error_msgs = "; ".join(errors) raise GeneralError(f"Multi-provider initialization failed: {error_msgs}") def shutdown(self) -> None: @@ -232,9 +238,9 @@ def _evaluate_with_providers( except Exception as e: # Record error but continue to next provider error_result = FlagResolutionDetails( - flag_key=flag_key, value=default_value, reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, error_message=str(e), ) results.append((provider_name, error_result)) @@ -249,9 +255,9 @@ def _evaluate_with_providers( return results[-1][1] return FlagResolutionDetails( - flag_key=flag_key, value=default_value, reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, error_message="No providers returned a result", ) @@ -268,20 +274,73 @@ def resolve_boolean_details( lambda p, k, d, ctx: p.resolve_boolean_details(k, d, ctx), ) + async def _evaluate_with_providers_async( + self, + flag_key: str, + default_value: FlagValueType, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable, + ) -> FlagResolutionDetails[FlagValueType]: + """ + Async evaluation logic that properly awaits provider async methods. + + :param flag_key: The flag key to evaluate + :param default_value: Default value for the flag + :param evaluation_context: Evaluation context + :param resolve_fn: Async function to call on each provider for resolution + :return: Final resolution details + """ + results: list[tuple[str, FlagResolutionDetails]] = [] + + for provider_name, provider in self._registered_providers: + try: + result = await resolve_fn(provider, flag_key, default_value, evaluation_context) + results.append((provider_name, result)) + + # In sequential mode, stop if strategy says to use this result + if (self.strategy.run_mode == "sequential" and + self.strategy.should_use_result(flag_key, provider_name, result)): + return result + + except Exception as e: + # Record error but continue to next provider + error_result = FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=str(e), + ) + results.append((provider_name, error_result)) + + # If all sequential attempts completed (or parallel mode), pick best result + for provider_name, result in results: + if self.strategy.should_use_result(flag_key, provider_name, result): + return result + + # No successful result - return last error or default + if results: + return results[-1][1] + + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message="No providers returned a result", + ) + async def resolve_boolean_details_async( self, flag_key: str, default_value: bool, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[bool]: - """ - Async boolean evaluation (currently delegates to sync implementation). 
- - Note: True async evaluation using await and provider-level async methods - is planned for a future enhancement. The current implementation maintains - API compatibility but does not provide non-blocking I/O benefits. - """ - return self.resolve_boolean_details(flag_key, default_value, evaluation_context) + """Async boolean evaluation using provider async methods.""" + return await self._evaluate_with_providers_async( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_boolean_details_async(k, d, ctx), + ) def resolve_string_details( self, @@ -302,8 +361,13 @@ async def resolve_string_details_async( default_value: str, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[str]: - """Async string evaluation (currently delegates to sync implementation).""" - return self.resolve_string_details(flag_key, default_value, evaluation_context) + """Async string evaluation using provider async methods.""" + return await self._evaluate_with_providers_async( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_string_details_async(k, d, ctx), + ) def resolve_integer_details( self, @@ -324,8 +388,13 @@ async def resolve_integer_details_async( default_value: int, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[int]: - """Async integer evaluation (currently delegates to sync implementation).""" - return self.resolve_integer_details(flag_key, default_value, evaluation_context) + """Async integer evaluation using provider async methods.""" + return await self._evaluate_with_providers_async( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_integer_details_async(k, d, ctx), + ) def resolve_float_details( self, @@ -346,8 +415,13 @@ async def resolve_float_details_async( default_value: float, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[float]: - """Async float evaluation (currently delegates to sync implementation).""" - return self.resolve_float_details(flag_key, default_value, evaluation_context) + """Async float evaluation using provider async methods.""" + return await self._evaluate_with_providers_async( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_float_details_async(k, d, ctx), + ) def resolve_object_details( self, @@ -368,5 +442,10 @@ async def resolve_object_details_async( default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: - """Async object evaluation (currently delegates to sync implementation).""" - return self.resolve_object_details(flag_key, default_value, evaluation_context) + """Async object evaluation using provider async methods.""" + return await self._evaluate_with_providers_async( + flag_key, + default_value, + evaluation_context, + lambda p, k, d, ctx: p.resolve_object_details_async(k, d, ctx), + ) From 91ff7dcfe34b5986e37ac20f813426cfcbf2867d Mon Sep 17 00:00:00 2001 From: Vikas Rao Date: Sun, 22 Feb 2026 10:47:19 -0800 Subject: [PATCH 4/8] Address all remaining Gemini review comments HIGH PRIORITY FIXES: - Fix name resolution logic to prevent collisions between explicit and auto-generated names - Check used_names set for metadata names before using them - Use while loop to find next available indexed name if collision detected - Implement event propagation (spec requirement) - Override attach() and 
detach() methods to forward events to all providers - Import ProviderEvent and ProviderEventDetails - Enables cache invalidation and other event-driven features MEDIUM PRIORITY IMPROVEMENTS: - Parallel shutdown with proper error logging - Use ThreadPoolExecutor for concurrent shutdown - Add logging for shutdown failures - Optimize ThreadPoolExecutor max_workers - Set to len(providers) for both initialize() and shutdown() - Ensures all providers can start immediately - Improve type hints for better type safety - Add generic type parameters to FlagResolutionDetails in resolve_fn signatures - Specify Awaitable return type for async resolve_fn - Add generic types to results list declarations All critical and high-priority feedback addressed. Ready for re-review. Refs: open-feature#511 Signed-off-by: Jonathan Norris --- openfeature/provider/multi_provider.py | 73 +++++++++++++++++++------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index a561a6b1..54a446b1 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails from openfeature.exception import ErrorCode, GeneralError from openfeature.flag_evaluation import FlagResolutionDetails, FlagValueType, Reason from openfeature.hook import Hook @@ -127,8 +128,8 @@ def _register_providers(self, providers: list[ProviderEntry]) -> None: Names are determined by: 1. Explicit name in ProviderEntry - 2. provider.get_metadata().name if unique - 3. {metadata.name}_{index} if not unique + 2. provider.get_metadata().name if unique and not conflicting + 3. {metadata.name}_{index} if not unique or conflicting """ # Count providers by their metadata name to detect duplicates name_counts: dict[str, int] = {} @@ -144,17 +145,20 @@ def _register_providers(self, providers: list[ProviderEntry]) -> None: metadata_name = entry.provider.get_metadata().name or "provider" if entry.name: - # Explicit name provided + # Explicit name provided - must be unique if entry.name in used_names: raise ValueError(f"Provider name '{entry.name}' is not unique") final_name = entry.name - elif name_counts[metadata_name] == 1: - # Metadata name is unique + elif name_counts[metadata_name] == 1 and metadata_name not in used_names: + # Metadata name is unique and not already taken by explicit name final_name = metadata_name else: - # Multiple providers with same metadata name, add index - name_indices[metadata_name] = name_indices.get(metadata_name, 0) + 1 - final_name = f"{metadata_name}_{name_indices[metadata_name]}" + # Multiple providers or collision with explicit name, add index + while True: + name_indices[metadata_name] = name_indices.get(metadata_name, 0) + 1 + final_name = f"{metadata_name}_{name_indices[metadata_name]}" + if final_name not in used_names: + break used_names.add(final_name) self._registered_providers.append((final_name, entry.provider)) @@ -172,6 +176,32 @@ def get_provider_hooks(self) -> list[Hook]: self._cached_hooks = hooks return self._cached_hooks + def attach( + self, + on_emit: Callable[[FeatureProvider, ProviderEvent, ProviderEventDetails], None], + ) -> None: + """ + Attach event handler and propagate to all underlying providers. + + Events from underlying providers are forwarded through the MultiProvider. 
+ This enables features like cache invalidation to work across all providers. + """ + super().attach(on_emit) + + # Propagate attach to all wrapped providers + for _, provider in self._registered_providers: + provider.attach(on_emit) + + def detach(self) -> None: + """ + Detach event handler and propagate to all underlying providers. + """ + super().detach() + + # Propagate detach to all wrapped providers + for _, provider in self._registered_providers: + provider.detach() + def initialize(self, evaluation_context: EvaluationContext) -> None: """ Initialize all providers in parallel using ThreadPoolExecutor. @@ -186,7 +216,7 @@ def init_provider(entry: tuple[str, FeatureProvider]) -> str | None: except Exception as e: return f"Provider '{name}' initialization failed: {e}" - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: results = list(executor.map(init_provider, self._registered_providers)) errors = [r for r in results if r is not None] @@ -195,20 +225,27 @@ def init_provider(entry: tuple[str, FeatureProvider]) -> str | None: raise GeneralError(f"Multi-provider initialization failed: {error_msgs}") def shutdown(self) -> None: - """Shutdown all providers.""" - for _, provider in self._registered_providers: + """Shutdown all providers in parallel.""" + import logging + + logger = logging.getLogger(__name__) + + def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: + name, provider = entry try: provider.shutdown() - except Exception: - # Log but don't fail shutdown - pass + except Exception as e: + logger.error(f"Provider '{name}' shutdown failed: {e}") + + with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: + list(executor.map(shutdown_provider, self._registered_providers)) def _evaluate_with_providers( self, flag_key: str, default_value: FlagValueType, evaluation_context: EvaluationContext | None, - resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], FlagResolutionDetails], + resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], FlagResolutionDetails[FlagValueType]], ) -> FlagResolutionDetails[FlagValueType]: """ Core evaluation logic that delegates to providers based on strategy. @@ -223,7 +260,7 @@ def _evaluate_with_providers( :param resolve_fn: Function to call on each provider for resolution :return: Final resolution details """ - results: list[tuple[str, FlagResolutionDetails]] = [] + results: list[tuple[str, FlagResolutionDetails[FlagValueType]]] = [] for provider_name, provider in self._registered_providers: try: @@ -279,7 +316,7 @@ async def _evaluate_with_providers_async( flag_key: str, default_value: FlagValueType, evaluation_context: EvaluationContext | None, - resolve_fn: Callable, + resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], typing.Awaitable[FlagResolutionDetails[FlagValueType]]], ) -> FlagResolutionDetails[FlagValueType]: """ Async evaluation logic that properly awaits provider async methods. 
@@ -290,7 +327,7 @@ async def _evaluate_with_providers_async( :param resolve_fn: Async function to call on each provider for resolution :return: Final resolution details """ - results: list[tuple[str, FlagResolutionDetails]] = [] + results: list[tuple[str, FlagResolutionDetails[FlagValueType]]] = [] for provider_name, provider in self._registered_providers: try: From 231bcab5ebe9d3ee2b97e5f491a840d22f46b7f5 Mon Sep 17 00:00:00 2001 From: Vikas Rao Date: Sun, 22 Feb 2026 19:31:23 -0800 Subject: [PATCH 5/8] Use Awaitable from collections.abc instead of typing.Awaitable This is more consistent with the other type imports in the file. Signed-off-by: Jonathan Norris --- openfeature/provider/multi_provider.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index 54a446b1..07aba99e 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -10,7 +10,7 @@ from __future__ import annotations import typing -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass @@ -316,7 +316,7 @@ async def _evaluate_with_providers_async( flag_key: str, default_value: FlagValueType, evaluation_context: EvaluationContext | None, - resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], typing.Awaitable[FlagResolutionDetails[FlagValueType]]], + resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], Awaitable[FlagResolutionDetails[FlagValueType]]], ) -> FlagResolutionDetails[FlagValueType]: """ Async evaluation logic that properly awaits provider async methods. 
From 536325f959bce6fc24e2cb0feadd8fffc01fd57d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Fri, 6 Mar 2026 10:24:08 +0000 Subject: [PATCH 6/8] fix: close multi-provider parity gaps Co-authored-by: jonathan Signed-off-by: Jonathan Norris --- openfeature/client.py | 69 +- openfeature/provider/__init__.py | 18 +- openfeature/provider/_registry.py | 29 +- openfeature/provider/multi_provider.py | 1242 +++++++++++++++++++----- tests/test_multi_provider.py | 862 ++++++++++------ 5 files changed, 1649 insertions(+), 571 deletions(-) diff --git a/openfeature/client.py b/openfeature/client.py index a02693c1..d01ee56b 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -429,6 +429,11 @@ def _establish_hooks_and_provider( client_metadata = self.get_metadata() provider_metadata = provider.get_metadata() + provider_hooks = ( + [] + if self._provider_uses_internal_hooks(provider) + else provider.get_provider_hooks() + ) # Hooks need to be handled in different orders at different stages # in the flag evaluation @@ -450,7 +455,7 @@ def _establish_hooks_and_provider( get_hooks(), self.hooks, evaluation_hooks, - provider.get_provider_hooks(), + provider_hooks, ) ] # after, error, finally: Provider, Invocation, Client, API @@ -465,6 +470,36 @@ def _establish_hooks_and_provider( merged_eval_context, ) + def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool: + uses_internal_hooks = getattr(provider, "uses_internal_provider_hooks", None) + return bool(callable(uses_internal_hooks) and uses_internal_hooks()) + + def _set_internal_provider_hook_runtime( + self, + provider: FeatureProvider, + flag_type: FlagType, + hook_hints: HookHints, + ) -> object | None: + if not self._provider_uses_internal_hooks(provider): + return None + set_hook_runtime = getattr(provider, "set_internal_provider_hook_runtime", None) + if not callable(set_hook_runtime): + return None + return set_hook_runtime( + flag_type=flag_type, + client_metadata=self.get_metadata(), + hook_hints=hook_hints, + ) + + def _reset_internal_provider_hook_runtime( + self, provider: FeatureProvider, runtime_token: object | None + ) -> None: + if runtime_token is None: + return + reset_hook_runtime = getattr(provider, "reset_internal_provider_hook_runtime", None) + if callable(reset_hook_runtime): + reset_hook_runtime(runtime_token) + def _assert_provider_status( self, ) -> OpenFeatureError | None: @@ -611,13 +646,21 @@ async def evaluate_flag_details_async( merged_eval_context, ) - flag_evaluation = await self._create_provider_evaluation_async( + runtime_token = self._set_internal_provider_hook_runtime( provider, flag_type, - flag_key, - default_value, - merged_context, + hook_hints, ) + try: + flag_evaluation = await self._create_provider_evaluation_async( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + finally: + self._reset_internal_provider_hook_runtime(provider, runtime_token) if err := flag_evaluation.get_exception(): error_hooks( flag_type, err, reversed_merged_hooks_and_context, hook_hints @@ -787,13 +830,21 @@ def evaluate_flag_details( merged_eval_context, ) - flag_evaluation = self._create_provider_evaluation( + runtime_token = self._set_internal_provider_hook_runtime( provider, flag_type, - flag_key, - default_value, - merged_context, + hook_hints, ) + try: + flag_evaluation = self._create_provider_evaluation( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + finally: + self._reset_internal_provider_hook_runtime(provider, runtime_token) if err := 
flag_evaluation.get_exception(): error_hooks( flag_type, err, reversed_merged_hooks_and_context, hook_hints diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 55e00263..b022bbc7 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -11,21 +11,17 @@ from openfeature.hook import Hook from .metadata import Metadata -from .multi_provider import ( - EvaluationStrategy, - FirstMatchStrategy, - MultiProvider, - ProviderEntry, -) if typing.TYPE_CHECKING: from openfeature.flag_evaluation import FlagValueType __all__ = [ "AbstractProvider", + "ComparisonStrategy", "EvaluationStrategy", "FeatureProvider", "FirstMatchStrategy", + "FirstSuccessfulStrategy", "Metadata", "MultiProvider", "ProviderEntry", @@ -262,3 +258,13 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: if hasattr(self, "_on_emit"): self._on_emit(self, event, details) + + +from .multi_provider import ( # noqa: E402 + ComparisonStrategy, + EvaluationStrategy, + FirstMatchStrategy, + FirstSuccessfulStrategy, + MultiProvider, + ProviderEntry, +) diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index bf8fa9a8..1944dd9c 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -80,23 +80,25 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: try: if hasattr(provider, "initialize"): provider.initialize(self._get_evaluation_context()) - self.dispatch_event( - provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() - ) + if self.get_provider_status(provider) == ProviderStatus.NOT_READY: + self.dispatch_event( + provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() + ) except Exception as err: error_code = ( err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL ) - self.dispatch_event( - provider, - ProviderEvent.PROVIDER_ERROR, - ProviderEventDetails( - message=f"Provider initialization failed: {err}", - error_code=error_code, - ), - ) + if self.get_provider_status(provider) == ProviderStatus.NOT_READY: + self.dispatch_event( + provider, + ProviderEvent.PROVIDER_ERROR, + ProviderEventDetails( + message=f"Provider initialization failed: {err}", + error_code=error_code, + ), + ) def _shutdown_provider(self, provider: FeatureProvider) -> None: try: @@ -115,6 +117,11 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: provider.detach() def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: + provider_status_getter = getattr(provider, "get_status", None) + if callable(provider_status_getter): + status = provider_status_getter() + if isinstance(status, ProviderStatus): + return status return self._provider_status.get(provider, ProviderStatus.NOT_READY) def dispatch_event( diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index 07aba99e..daff1cdf 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -1,14 +1,9 @@ -""" -Multi-Provider implementation for OpenFeature Python SDK. - -This provider wraps multiple underlying providers, allowing a single client -to interact with multiple flag sources simultaneously. 
- -See: https://openfeature.dev/specification/appendix-a/#multi-provider -""" - from __future__ import annotations +import asyncio +import contextvars +import logging +import threading import typing from collections.abc import Awaitable, Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor @@ -16,353 +11,1004 @@ from openfeature.evaluation_context import EvaluationContext from openfeature.event import ProviderEvent, ProviderEventDetails -from openfeature.exception import ErrorCode, GeneralError -from openfeature.flag_evaluation import FlagResolutionDetails, FlagValueType, Reason -from openfeature.hook import Hook -from openfeature.provider import AbstractProvider, FeatureProvider, Metadata +from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError +from openfeature.flag_evaluation import ( + FlagEvaluationDetails, + FlagResolutionDetails, + FlagType, + FlagValueType, + Reason, +) +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.hook._hook_support import ( + after_all_hooks, + after_hooks, + before_hooks, + error_hooks, +) +from openfeature.provider import ( + AbstractProvider, + FeatureProvider, + Metadata, + ProviderStatus, +) -__all__ = ["MultiProvider", "ProviderEntry", "FirstMatchStrategy", "EvaluationStrategy"] +__all__ = [ + "ComparisonStrategy", + "EvaluationStrategy", + "FirstMatchStrategy", + "FirstSuccessfulStrategy", + "MultiProvider", + "ProviderEntry", +] +logger = logging.getLogger("openfeature") -@dataclass -class ProviderEntry: - """Configuration for a provider in the Multi-Provider.""" +T = typing.TypeVar("T", bound=FlagValueType) +RunMode: typing.TypeAlias = typing.Literal["sequential", "parallel"] +ComparisonMismatchHandler: typing.TypeAlias = Callable[ + [str, Mapping[str, FlagResolutionDetails[FlagValueType]]], None +] + +@dataclass(frozen=True) +class ProviderEntry: provider: FeatureProvider name: str | None = None +@dataclass(frozen=True) +class _ProviderEvaluation(typing.Generic[T]): + provider_name: str + provider: FeatureProvider + result: FlagResolutionDetails[T] + + +@dataclass(frozen=True) +class _ProviderHookRuntime: + flag_type: FlagType + client_metadata: typing.Any + hook_hints: HookHints + + class EvaluationStrategy(typing.Protocol): - """ - Strategy interface for determining which provider's result to use. - - Supports 'sequential' mode (evaluate one at a time, stop early when strategy - is satisfied) and 'parallel' mode (evaluate all providers, then select best - result). Note: Both modes currently execute provider calls sequentially; - true concurrent evaluation using asyncio.gather or ThreadPoolExecutor is - planned for a future enhancement. - """ - - run_mode: typing.Literal["sequential", "parallel"] + run_mode: RunMode def should_use_result( self, flag_key: str, provider_name: str, - result: FlagResolutionDetails, - ) -> bool: - """ - Determine if this result should be used (and stop evaluation if sequential). - - :param flag_key: The flag being evaluated - :param provider_name: Name of the provider that returned this result - :param result: The resolution details from the provider - :return: True if this result should be used as the final result - """ - ... + result: FlagResolutionDetails[FlagValueType], + ) -> bool: ... + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: ... 
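+    # Note: should_continue is what lets a strategy stop early on hard
+    # failures. FirstMatchStrategy below, for example, only continues past a
+    # provider when the flag was not found there, so any other error is
+    # treated as the final result rather than falling through to the next
+    # provider.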
+ + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: ... + + +def _is_success(result: FlagResolutionDetails[FlagValueType]) -> bool: + return result.error_code is None and result.reason != Reason.ERROR + + +def _validate_run_mode(run_mode: RunMode) -> RunMode: + if run_mode not in ("sequential", "parallel"): + raise ValueError(f"Unsupported run_mode '{run_mode}'") + return run_mode + + +def _format_result_error( + provider_name: str, result: FlagResolutionDetails[FlagValueType] +) -> str: + error_code = result.error_code.value if result.error_code else ErrorCode.GENERAL.value + error_message = result.error_message or "Unknown error" + return f"{provider_name}: {error_code} ({error_message})" + + +def _build_aggregated_error( + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + prefix: str, +) -> FlagResolutionDetails[FlagValueType]: + if not evaluations: + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=f"{prefix} for flag '{flag_key}': no providers returned a result", + ) + + errors_text = "; ".join( + _format_result_error(evaluation.provider_name, evaluation.result) + for evaluation in evaluations + ) + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=f"{prefix} for flag '{flag_key}': {errors_text}", + ) class FirstMatchStrategy: - """ - Uses the first successful result from providers (in order). - - In sequential mode, stops at the first non-error result. - In parallel mode, picks the first successful result from the ordered list. 
- """ + def __init__(self, run_mode: RunMode = "sequential") -> None: + self.run_mode = _validate_run_mode(run_mode) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return _is_success(result) + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return result.error_code == ErrorCode.FLAG_NOT_FOUND + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + for evaluation in evaluations: + if self.should_use_result( + flag_key, evaluation.provider_name, evaluation.result + ): + return evaluation.result + if not self.should_continue( + flag_key, evaluation.provider_name, evaluation.result + ): + return evaluation.result + if evaluations: + return evaluations[-1].result + return _build_aggregated_error( + flag_key, + default_value, + evaluations, + "Multi-provider evaluation failed", + ) + + +class FirstSuccessfulStrategy: + def __init__(self, run_mode: RunMode = "sequential") -> None: + self.run_mode = _validate_run_mode(run_mode) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return _is_success(result) + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return True + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + for evaluation in evaluations: + if _is_success(evaluation.result): + return evaluation.result + return _build_aggregated_error( + flag_key, + default_value, + evaluations, + "All providers failed", + ) + + +class ComparisonStrategy: + run_mode: RunMode = "parallel" + + def __init__( + self, + fallback_provider: str | None = None, + on_mismatch: ComparisonMismatchHandler | None = None, + ) -> None: + self.fallback_provider = fallback_provider + self.on_mismatch = on_mismatch - run_mode: typing.Literal["sequential", "parallel"] = "sequential" + def validate_provider_names(self, provider_names: Sequence[str]) -> None: + if ( + self.fallback_provider is not None + and self.fallback_provider not in provider_names + ): + raise ValueError( + f"Fallback provider '{self.fallback_provider}' is not registered" + ) def should_use_result( self, flag_key: str, provider_name: str, - result: FlagResolutionDetails, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return False + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], ) -> bool: - """Use the first result that doesn't have an error.""" - return result.reason != Reason.ERROR + del flag_key + del provider_name + del result + return True + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + failed_evaluations = [ + evaluation for evaluation in evaluations if not _is_success(evaluation.result) + ] + if failed_evaluations: 
+ return _build_aggregated_error( + flag_key, + default_value, + failed_evaluations, + "Comparison strategy received provider errors", + ) + + fallback_evaluation = self._select_fallback_evaluation(evaluations) + fallback_value = fallback_evaluation.result.value + has_mismatch = any( + evaluation.result.value != fallback_value for evaluation in evaluations + ) + if has_mismatch and self.on_mismatch is not None: + mismatch_results = { + evaluation.provider_name: evaluation.result for evaluation in evaluations + } + try: + self.on_mismatch(flag_key, mismatch_results) + except Exception: + logger.exception( + "Comparison strategy mismatch callback failed for flag '%s'", + flag_key, + ) + return fallback_evaluation.result + + def _select_fallback_evaluation( + self, evaluations: list[_ProviderEvaluation[FlagValueType]] + ) -> _ProviderEvaluation[FlagValueType]: + if not evaluations: + raise ValueError("ComparisonStrategy requires at least one provider") + if self.fallback_provider is None: + return evaluations[0] + for evaluation in evaluations: + if evaluation.provider_name == self.fallback_provider: + return evaluation + raise ValueError( + f"Fallback provider '{self.fallback_provider}' is not registered" + ) class MultiProvider(AbstractProvider): - """ - A provider that aggregates multiple underlying providers. - - Evaluations are delegated to underlying providers based on the configured - strategy (default: FirstMatchStrategy in sequential mode). - - Example: - provider_a = SomeProvider() - provider_b = AnotherProvider() - - multi = MultiProvider([ - ProviderEntry(provider_a, name="primary"), - ProviderEntry(provider_b, name="fallback") - ]) - - api.set_provider(multi) - """ + _status_precedence: tuple[ProviderStatus, ...] = ( + ProviderStatus.FATAL, + ProviderStatus.NOT_READY, + ProviderStatus.ERROR, + ProviderStatus.STALE, + ProviderStatus.READY, + ) def __init__( self, providers: list[ProviderEntry], strategy: EvaluationStrategy | None = None, - ): - """ - Initialize the Multi-Provider. 
- - :param providers: List of ProviderEntry objects defining the providers - :param strategy: Evaluation strategy (defaults to FirstMatchStrategy) - """ + ) -> None: super().__init__() - if not providers: raise ValueError("At least one provider must be provided") - + self.strategy = strategy or FirstMatchStrategy() - self._registered_providers: list[tuple[str, FeatureProvider]] = [] + self._registeredProviders: list[tuple[str, FeatureProvider]] = [] + self._provider_names: dict[FeatureProvider, str] = {} + self._provider_statuses: dict[str, ProviderStatus] = {} + self._aggregate_status = ProviderStatus.NOT_READY + self._statusLock = threading.Lock() + self._hookRuntime: contextvars.ContextVar[_ProviderHookRuntime | None] = ( + contextvars.ContextVar( + f"multiProviderHookRuntime:{id(self)}", + default=None, + ) + ) self._register_providers(providers) - self._cached_hooks: list[Hook] | None = None + self._provider_statuses = { + provider_name: ProviderStatus.NOT_READY + for provider_name, _ in self._registeredProviders + } + validate_provider_names = getattr(self.strategy, "validate_provider_names", None) + if callable(validate_provider_names): + validate_provider_names( + [provider_name for provider_name, _ in self._registeredProviders] + ) + + def uses_internal_provider_hooks(self) -> bool: + return True + + def set_internal_provider_hook_runtime( + self, + flag_type: FlagType, + client_metadata: typing.Any, + hook_hints: HookHints, + ) -> contextvars.Token[_ProviderHookRuntime | None]: + return self._hookRuntime.set( + _ProviderHookRuntime( + flag_type=flag_type, + client_metadata=client_metadata, + hook_hints=hook_hints, + ) + ) + + def reset_internal_provider_hook_runtime( + self, token: contextvars.Token[_ProviderHookRuntime | None] + ) -> None: + self._hookRuntime.reset(token) + + def get_status(self) -> ProviderStatus: + with self._statusLock: + return self._aggregate_status def _register_providers(self, providers: list[ProviderEntry]) -> None: - """ - Register providers with unique names. - - Names are determined by: - 1. Explicit name in ProviderEntry - 2. provider.get_metadata().name if unique and not conflicting - 3. 
{metadata.name}_{index} if not unique or conflicting - """ - # Count providers by their metadata name to detect duplicates name_counts: dict[str, int] = {} for entry in providers: metadata_name = entry.provider.get_metadata().name or "provider" name_counts[metadata_name] = name_counts.get(metadata_name, 0) + 1 - # Track used names to prevent conflicts used_names: set[str] = set() - name_indices: dict[str, int] = {} + name_indexes: dict[str, int] = {} for entry in providers: metadata_name = entry.provider.get_metadata().name or "provider" - if entry.name: - # Explicit name provided - must be unique if entry.name in used_names: raise ValueError(f"Provider name '{entry.name}' is not unique") - final_name = entry.name + provider_name = entry.name elif name_counts[metadata_name] == 1 and metadata_name not in used_names: - # Metadata name is unique and not already taken by explicit name - final_name = metadata_name + provider_name = metadata_name else: - # Multiple providers or collision with explicit name, add index while True: - name_indices[metadata_name] = name_indices.get(metadata_name, 0) + 1 - final_name = f"{metadata_name}_{name_indices[metadata_name]}" - if final_name not in used_names: + name_indexes[metadata_name] = name_indexes.get(metadata_name, 0) + 1 + provider_name = f"{metadata_name}_{name_indexes[metadata_name]}" + if provider_name not in used_names: break - - used_names.add(final_name) - self._registered_providers.append((final_name, entry.provider)) + + used_names.add(provider_name) + self._registeredProviders.append((provider_name, entry.provider)) + self._provider_names[entry.provider] = provider_name def get_metadata(self) -> Metadata: - """Return metadata including all wrapped provider metadata.""" return Metadata(name="MultiProvider") def get_provider_hooks(self) -> list[Hook]: - """Aggregate hooks from all providers (cached for efficiency).""" - if self._cached_hooks is None: - hooks: list[Hook] = [] - for _, provider in self._registered_providers: - hooks.extend(provider.get_provider_hooks()) - self._cached_hooks = hooks - return self._cached_hooks + return [] def attach( self, on_emit: Callable[[FeatureProvider, ProviderEvent, ProviderEventDetails], None], ) -> None: - """ - Attach event handler and propagate to all underlying providers. - - Events from underlying providers are forwarded through the MultiProvider. - This enables features like cache invalidation to work across all providers. - """ super().attach(on_emit) - - # Propagate attach to all wrapped providers - for _, provider in self._registered_providers: - provider.attach(on_emit) + for _, provider in self._registeredProviders: + provider.attach(self._handle_provider_event) def detach(self) -> None: - """ - Detach event handler and propagate to all underlying providers. - """ - super().detach() - - # Propagate detach to all wrapped providers - for _, provider in self._registered_providers: + for _, provider in self._registeredProviders: provider.detach() + super().detach() def initialize(self, evaluation_context: EvaluationContext) -> None: - """ - Initialize all providers in parallel using ThreadPoolExecutor. - - This allows concurrent initialization of I/O-bound providers. 
- """ - def init_provider(entry: tuple[str, FeatureProvider]) -> str | None: - name, provider = entry + def initialize_provider( + entry: tuple[str, FeatureProvider], + ) -> tuple[str, Exception | None]: + provider_name, provider = entry try: provider.initialize(evaluation_context) - return None - except Exception as e: - return f"Provider '{name}' initialization failed: {e}" + return provider_name, None + except Exception as err: + return provider_name, err + + with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + init_results = list(executor.map(initialize_provider, self._registeredProviders)) + + error_messages: list[str] = [] + event_details = ProviderEventDetails() + for provider_name, err in init_results: + if err is None: + self._mark_provider_ready(provider_name) + continue + provider_status = self._status_from_exception(err) + self._set_provider_status(provider_name, provider_status) + error_messages.append( + f"Provider '{provider_name}' initialization failed: {self._error_message_from_exception(err)}" + ) + event_details = self._details_from_exception(err, provider_name) - with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: - results = list(executor.map(init_provider, self._registered_providers)) + self._refresh_aggregate_status(event_details) - errors = [r for r in results if r is not None] - if errors: - error_msgs = "; ".join(errors) - raise GeneralError(f"Multi-provider initialization failed: {error_msgs}") + if error_messages: + raise GeneralError(f"Multi-provider initialization failed: {'; '.join(error_messages)}") def shutdown(self) -> None: - """Shutdown all providers in parallel.""" - import logging - - logger = logging.getLogger(__name__) - + for _, provider in self._registeredProviders: + provider.detach() + def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: - name, provider = entry + provider_name, provider = entry try: provider.shutdown() - except Exception as e: - logger.error(f"Provider '{name}' shutdown failed: {e}") + except Exception: + logger.exception("Provider '%s' shutdown failed", provider_name) - with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: - list(executor.map(shutdown_provider, self._registered_providers)) + with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + list(executor.map(shutdown_provider, self._registeredProviders)) - def _evaluate_with_providers( + with self._statusLock: + self._provider_statuses = { + provider_name: ProviderStatus.NOT_READY + for provider_name, _ in self._registeredProviders + } + self._aggregate_status = ProviderStatus.NOT_READY + + def _handle_provider_event( + self, + provider: FeatureProvider, + event: ProviderEvent, + details: ProviderEventDetails, + ) -> None: + provider_name = self._provider_names.get(provider) + if provider_name is None: + return + if event == ProviderEvent.PROVIDER_CONFIGURATION_CHANGED: + self.emit( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + self._with_provider_metadata(details, provider_name), + ) + return + if event == ProviderEvent.PROVIDER_READY: + self._set_provider_status(provider_name, ProviderStatus.READY) + elif event == ProviderEvent.PROVIDER_STALE: + self._set_provider_status(provider_name, ProviderStatus.STALE) + elif event == ProviderEvent.PROVIDER_ERROR: + self._set_provider_status( + provider_name, + self._status_from_event_details(details), + ) + self._refresh_aggregate_status(self._with_provider_metadata(details, provider_name)) + + def 
_set_provider_status( + self, provider_name: str, provider_status: ProviderStatus + ) -> None: + with self._statusLock: + self._provider_statuses[provider_name] = provider_status + + def _mark_provider_ready(self, provider_name: str) -> None: + with self._statusLock: + if self._provider_statuses.get(provider_name) == ProviderStatus.NOT_READY: + self._provider_statuses[provider_name] = ProviderStatus.READY + + def _calculate_aggregate_status(self) -> ProviderStatus: + statuses = tuple(self._provider_statuses.values()) + if not statuses: + return ProviderStatus.NOT_READY + for status in self._status_precedence: + if status in statuses: + return status + return ProviderStatus.NOT_READY + + def _refresh_aggregate_status(self, details: ProviderEventDetails) -> None: + event_to_emit: ProviderEvent | None = None + event_details = details + with self._statusLock: + previous_status = self._aggregate_status + aggregate_status = self._calculate_aggregate_status() + if previous_status == aggregate_status: + return + self._aggregate_status = aggregate_status + event_to_emit = self._event_from_status(aggregate_status) + event_details = self._details_for_status(aggregate_status, details) + if event_to_emit is not None: + self.emit(event_to_emit, event_details) + + def _event_from_status(self, provider_status: ProviderStatus) -> ProviderEvent | None: + if provider_status == ProviderStatus.READY: + return ProviderEvent.PROVIDER_READY + if provider_status == ProviderStatus.STALE: + return ProviderEvent.PROVIDER_STALE + if provider_status in (ProviderStatus.ERROR, ProviderStatus.FATAL): + return ProviderEvent.PROVIDER_ERROR + return None + + def _details_for_status( + self, provider_status: ProviderStatus, details: ProviderEventDetails + ) -> ProviderEventDetails: + error_code = details.error_code + if provider_status == ProviderStatus.FATAL: + error_code = ErrorCode.PROVIDER_FATAL + elif provider_status == ProviderStatus.ERROR and error_code is None: + error_code = ErrorCode.GENERAL + return ProviderEventDetails( + flags_changed=details.flags_changed, + message=details.message, + error_code=error_code, + metadata=dict(details.metadata), + ) + + def _with_provider_metadata( + self, details: ProviderEventDetails, provider_name: str + ) -> ProviderEventDetails: + metadata = dict(details.metadata) + metadata["provider_name"] = provider_name + return ProviderEventDetails( + flags_changed=details.flags_changed, + message=details.message, + error_code=details.error_code, + metadata=metadata, + ) + + def _status_from_event_details( + self, details: ProviderEventDetails + ) -> ProviderStatus: + if details.error_code == ErrorCode.PROVIDER_FATAL: + return ProviderStatus.FATAL + return ProviderStatus.ERROR + + def _status_from_exception(self, err: Exception) -> ProviderStatus: + if ( + isinstance(err, OpenFeatureError) + and err.error_code == ErrorCode.PROVIDER_FATAL + ): + return ProviderStatus.FATAL + return ProviderStatus.ERROR + + def _details_from_exception( + self, err: Exception, provider_name: str + ) -> ProviderEventDetails: + error_code = ( + err.error_code + if isinstance(err, OpenFeatureError) + else ErrorCode.GENERAL + ) + error_message = self._error_message_from_exception(err) + return ProviderEventDetails( + message=f"Provider '{provider_name}' failed: {error_message}", + error_code=error_code, + metadata={"provider_name": provider_name}, + ) + + def _error_message_from_exception(self, err: Exception) -> str: + if isinstance(err, OpenFeatureError) and err.error_message: + return err.error_message + 
return str(err) + + def _resolution_from_exception( + self, default_value: T, err: Exception + ) -> FlagResolutionDetails[T]: + error_code = ( + err.error_code + if isinstance(err, OpenFeatureError) + else ErrorCode.GENERAL + ) + error_message = self._error_message_from_exception(err) + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=error_code, + error_message=error_message, + ) + + def _create_provider_hook_contexts( self, + provider: FeatureProvider, + flag_type: FlagType, flag_key: str, default_value: FlagValueType, + evaluation_context: EvaluationContext, + client_metadata: typing.Any, + ) -> list[tuple[Hook, HookContext]]: + provider_metadata = provider.get_metadata() + return [ + ( + hook, + HookContext( + flag_key=flag_key, + flag_type=flag_type, + default_value=default_value, + evaluation_context=evaluation_context, + client_metadata=client_metadata, + provider_metadata=provider_metadata, + hook_data={}, + ), + ) + for hook in provider.get_provider_hooks() + ] + + def _evaluate_provider_sync( # noqa: PLR0913 + self, + provider_name: str, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: T, evaluation_context: EvaluationContext | None, - resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], FlagResolutionDetails[FlagValueType]], - ) -> FlagResolutionDetails[FlagValueType]: - """ - Core evaluation logic that delegates to providers based on strategy. - - Current implementation evaluates providers sequentially regardless of - strategy.run_mode. True concurrent evaluation for 'parallel' mode is - planned for a future enhancement. - - :param flag_key: The flag key to evaluate - :param default_value: Default value for the flag - :param evaluation_context: Evaluation context - :param resolve_fn: Function to call on each provider for resolution - :return: Final resolution details - """ - results: list[tuple[str, FlagResolutionDetails[FlagValueType]]] = [] - - for provider_name, provider in self._registered_providers: + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + FlagResolutionDetails[T], + ], + ) -> _ProviderEvaluation[T]: + runtime = self._hookRuntime.get() + if runtime is None or not provider.get_provider_hooks(): try: - result = resolve_fn(provider, flag_key, default_value, evaluation_context) - results.append((provider_name, result)) - - # In sequential mode, stop if strategy says to use this result - if (self.strategy.run_mode == "sequential" and - self.strategy.should_use_result(flag_key, provider_name, result)): - return result - - except Exception as e: - # Record error but continue to next provider - error_result = FlagResolutionDetails( - value=default_value, - reason=Reason.ERROR, - error_code=ErrorCode.GENERAL, - error_message=str(e), + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolve_fn(provider, flag_key, default_value, evaluation_context), ) - results.append((provider_name, error_result)) - - # If all sequential attempts completed (or parallel mode), pick best result - for provider_name, result in results: - if self.strategy.should_use_result(flag_key, provider_name, result): - return result - - # No successful result - return last error or default - if results: - return results[-1][1] - - return FlagResolutionDetails( - value=default_value, - reason=Reason.ERROR, - error_code=ErrorCode.GENERAL, - error_message="No providers returned a result", + except Exception as err: + return 
_ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + + provider_context = evaluation_context or EvaluationContext() + hook_contexts = self._create_provider_hook_contexts( + provider, + flag_type, + flag_key, + default_value, + provider_context, + runtime.client_metadata, ) + reversed_hook_contexts = list(reversed(hook_contexts)) + flag_evaluation = FlagEvaluationDetails(flag_key=flag_key, value=default_value) + try: + before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) + resolved_context = provider_context.merge(before_context) + resolution = resolve_fn(provider, flag_key, default_value, resolved_context) + flag_evaluation = resolution.to_flag_evaluation_details(flag_key) + if err := flag_evaluation.get_exception(): + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + after_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + except Exception as err: + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + finally: + after_all_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) - def resolve_boolean_details( + async def _evaluate_provider_async( # noqa: PLR0913 self, + provider_name: str, + provider: FeatureProvider, + flag_type: FlagType, flag_key: str, - default_value: bool, - evaluation_context: EvaluationContext | None = None, - ) -> FlagResolutionDetails[bool]: - return self._evaluate_with_providers( + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + Awaitable[FlagResolutionDetails[T]], + ], + ) -> _ProviderEvaluation[T]: + runtime = self._hookRuntime.get() + if runtime is None or not provider.get_provider_hooks(): + try: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=await resolve_fn( + provider, flag_key, default_value, evaluation_context + ), + ) + except Exception as err: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + + provider_context = evaluation_context or EvaluationContext() + hook_contexts = self._create_provider_hook_contexts( + provider, + flag_type, flag_key, default_value, - evaluation_context, - lambda p, k, d, ctx: p.resolve_boolean_details(k, d, ctx), + provider_context, + runtime.client_metadata, + ) + reversed_hook_contexts = list(reversed(hook_contexts)) + flag_evaluation = FlagEvaluationDetails(flag_key=flag_key, value=default_value) + try: + before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) + resolved_context = provider_context.merge(before_context) + resolution = await resolve_fn(provider, flag_key, default_value, resolved_context) + flag_evaluation = resolution.to_flag_evaluation_details(flag_key) + if err := flag_evaluation.get_exception(): + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, 
+ provider=provider, + result=resolution, + ) + after_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + except Exception as err: + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + finally: + after_all_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + + def _evaluate_with_providers( + self, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + FlagResolutionDetails[T], + ], + ) -> FlagResolutionDetails[T]: + if self.strategy.run_mode == "parallel": + with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + futures = [ + executor.submit( + self._evaluate_provider_sync, + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + for provider_name, provider in self._registeredProviders + ] + evaluations = [future.result() for future in futures] + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + evaluations: list[_ProviderEvaluation[T]] = [] + for provider_name, provider in self._registeredProviders: + evaluation = self._evaluate_provider_sync( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + evaluations.append(evaluation) + if self.strategy.should_use_result( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + return evaluation.result + if not self.strategy.should_continue( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + break + + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), ) async def _evaluate_with_providers_async( self, + flag_type: FlagType, flag_key: str, - default_value: FlagValueType, + default_value: T, evaluation_context: EvaluationContext | None, - resolve_fn: Callable[[FeatureProvider, str, FlagValueType, EvaluationContext | None], Awaitable[FlagResolutionDetails[FlagValueType]]], - ) -> FlagResolutionDetails[FlagValueType]: - """ - Async evaluation logic that properly awaits provider async methods. 
- - :param flag_key: The flag key to evaluate - :param default_value: Default value for the flag - :param evaluation_context: Evaluation context - :param resolve_fn: Async function to call on each provider for resolution - :return: Final resolution details - """ - results: list[tuple[str, FlagResolutionDetails[FlagValueType]]] = [] - - for provider_name, provider in self._registered_providers: - try: - result = await resolve_fn(provider, flag_key, default_value, evaluation_context) - results.append((provider_name, result)) - - # In sequential mode, stop if strategy says to use this result - if (self.strategy.run_mode == "sequential" and - self.strategy.should_use_result(flag_key, provider_name, result)): - return result - - except Exception as e: - # Record error but continue to next provider - error_result = FlagResolutionDetails( - value=default_value, - reason=Reason.ERROR, - error_code=ErrorCode.GENERAL, - error_message=str(e), + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + Awaitable[FlagResolutionDetails[T]], + ], + ) -> FlagResolutionDetails[T]: + if self.strategy.run_mode == "parallel": + tasks = [ + asyncio.create_task( + self._evaluate_provider_async( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) ) - results.append((provider_name, error_result)) - - # If all sequential attempts completed (or parallel mode), pick best result - for provider_name, result in results: - if self.strategy.should_use_result(flag_key, provider_name, result): - return result - - # No successful result - return last error or default - if results: - return results[-1][1] - - return FlagResolutionDetails( - value=default_value, - reason=Reason.ERROR, - error_code=ErrorCode.GENERAL, - error_message="No providers returned a result", + for provider_name, provider in self._registeredProviders + ] + evaluations = await asyncio.gather(*tasks) + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + list(evaluations), + ), + ), + ) + + evaluations: list[_ProviderEvaluation[T]] = [] + for provider_name, provider in self._registeredProviders: + evaluation = await self._evaluate_provider_async( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + evaluations.append(evaluation) + if self.strategy.should_use_result( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + return evaluation.result + if not self.strategy.should_continue( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + break + + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + return self._evaluate_with_providers( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_boolean_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) async def resolve_boolean_details_async( @@ -371,12 +1017,18 @@ async def 
resolve_boolean_details_async( default_value: bool, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[bool]: - """Async boolean evaluation using provider async methods.""" return await self._evaluate_with_providers_async( + FlagType.BOOLEAN, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_boolean_details_async(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_boolean_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) def resolve_string_details( @@ -386,10 +1038,17 @@ def resolve_string_details( evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[str]: return self._evaluate_with_providers( + FlagType.STRING, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_string_details(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_string_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) async def resolve_string_details_async( @@ -398,12 +1057,18 @@ async def resolve_string_details_async( default_value: str, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[str]: - """Async string evaluation using provider async methods.""" return await self._evaluate_with_providers_async( + FlagType.STRING, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_string_details_async(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_string_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) def resolve_integer_details( @@ -413,10 +1078,17 @@ def resolve_integer_details( evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[int]: return self._evaluate_with_providers( + FlagType.INTEGER, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_integer_details(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_integer_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) async def resolve_integer_details_async( @@ -425,12 +1097,18 @@ async def resolve_integer_details_async( default_value: int, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[int]: - """Async integer evaluation using provider async methods.""" return await self._evaluate_with_providers_async( + FlagType.INTEGER, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_integer_details_async(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_integer_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) def resolve_float_details( @@ -440,10 +1118,17 @@ def resolve_float_details( evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[float]: return self._evaluate_with_providers( + FlagType.FLOAT, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_float_details(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_float_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) async def resolve_float_details_async( @@ -452,12 +1137,18 @@ async def 
resolve_float_details_async( default_value: float, evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[float]: - """Async float evaluation using provider async methods.""" return await self._evaluate_with_providers_async( + FlagType.FLOAT, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_float_details_async(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_float_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) def resolve_object_details( @@ -467,10 +1158,17 @@ def resolve_object_details( evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: return self._evaluate_with_providers( + FlagType.OBJECT, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_object_details(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_object_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) async def resolve_object_details_async( @@ -479,10 +1177,16 @@ async def resolve_object_details_async( default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], evaluation_context: EvaluationContext | None = None, ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: - """Async object evaluation using provider async methods.""" return await self._evaluate_with_providers_async( + FlagType.OBJECT, flag_key, default_value, evaluation_context, - lambda p, k, d, ctx: p.resolve_object_details_async(k, d, ctx), + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_object_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), ) diff --git a/tests/test_multi_provider.py b/tests/test_multi_provider.py index 2ba7759a..aa6cea26 100644 --- a/tests/test_multi_provider.py +++ b/tests/test_multi_provider.py @@ -1,297 +1,607 @@ +import asyncio +import threading +from unittest.mock import MagicMock + import pytest from openfeature import api from openfeature.evaluation_context import EvaluationContext -from openfeature.exception import GeneralError -from openfeature.flag_evaluation import FlagResolutionDetails, Reason -from openfeature.provider import Metadata -from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider -from openfeature.provider.multi_provider import ( +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, GeneralError +from openfeature.flag_evaluation import ( + FlagEvaluationDetails, + FlagResolutionDetails, + Reason, +) +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.provider import ( + AbstractProvider, + ComparisonStrategy, FirstMatchStrategy, + FirstSuccessfulStrategy, + Metadata, MultiProvider, ProviderEntry, + ProviderStatus, ) -from openfeature.provider.no_op_provider import NoOpProvider + + +class BooleanProvider(AbstractProvider): + def __init__( + self, + name: str, + boolean_result: FlagResolutionDetails[bool] | None = None, + boolean_exception: Exception | None = None, + hook_list: list[Hook] | None = None, + sync_blocker: "SyncBlocker | None" = None, + async_blocker: "AsyncBlocker | None" = None, + ) -> None: + super().__init__() + self.name = name + self.booleanResult = boolean_result + self.booleanException = 
boolean_exception + self.hookList = hook_list or [] + self.sync_blocker = sync_blocker + self.async_blocker = async_blocker + self.resolveCount = 0 + self.seenContexts: list[dict[str, object]] = [] + + def get_metadata(self) -> Metadata: + return Metadata(name=self.name) + + def get_provider_hooks(self) -> list[Hook]: + return self.hookList + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + del flag_key + self.resolveCount += 1 + self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes)) + if self.sync_blocker is not None: + self.sync_blocker.wait() + if self.booleanException is not None: + raise self.booleanException + if self.booleanResult is not None: + return self.booleanResult + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + del flag_key + self.resolveCount += 1 + self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes)) + if self.async_blocker is not None: + await self.async_blocker.wait() + if self.booleanException is not None: + raise self.booleanException + if self.booleanResult is not None: + return self.booleanResult + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_object_details( + self, + flag_key: str, + default_value: dict | list, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[dict | list]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + +class RecordingHook(Hook): + def __init__(self, hook_name: str) -> None: + self.hookName = hook_name + self.events: list[str] = [] + + def before( + self, hook_context: HookContext, hints: HookHints + ) -> EvaluationContext | None: + del hook_context + del hints + self.events.append("before") + return EvaluationContext(attributes={"hookOwner": self.hookName}) + + def after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[object], + hints: HookHints, + ) -> None: + del hook_context + del details + del hints + self.events.append("after") + + def error( + self, hook_context: HookContext, exception: Exception, hints: HookHints + ) -> None: + del hook_context + del exception + del hints + self.events.append("error") + + def finally_after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[object], + hints: HookHints, + ) -> None: + 
del hook_context + del details + del hints + self.events.append("finally") + + +class SyncBlocker: + def __init__(self, expected_count: int) -> None: + self.expectedCount = expected_count + self.enteredCount = 0 + self.enteredEvent = threading.Event() + self.releaseEvent = threading.Event() + self.lock = threading.Lock() + + def wait(self) -> None: + with self.lock: + self.enteredCount += 1 + if self.enteredCount == self.expectedCount: + self.enteredEvent.set() + assert self.releaseEvent.wait(timeout=2) + + +class AsyncBlocker: + def __init__(self, expected_count: int) -> None: + self.expectedCount = expected_count + self.enteredCount = 0 + self.enteredEvent = asyncio.Event() + self.releaseEvent = asyncio.Event() + self.lock = asyncio.Lock() + + async def wait(self) -> None: + async with self.lock: + self.enteredCount += 1 + if self.enteredCount == self.expectedCount: + self.enteredEvent.set() + await asyncio.wait_for(self.releaseEvent.wait(), timeout=2) def test_multi_provider_requires_at_least_one_provider(): - # Given/When/Then with pytest.raises(ValueError, match="At least one provider must be provided"): MultiProvider([]) -def test_multi_provider_uses_explicit_names(): - # Given - provider_a = NoOpProvider() - provider_b = NoOpProvider() - - # When - multi = MultiProvider([ - ProviderEntry(provider_a, name="first"), - ProviderEntry(provider_b, name="second"), - ]) - - # Then - assert len(multi._registered_providers) == 2 - assert multi._registered_providers[0][0] == "first" - assert multi._registered_providers[1][0] == "second" - - -def test_multi_provider_generates_unique_names_when_metadata_conflicts(): - # Given - provider_a = NoOpProvider() - provider_b = NoOpProvider() - - # When - both have same metadata name "NoOpProvider" - multi = MultiProvider([ - ProviderEntry(provider_a), - ProviderEntry(provider_b), - ]) - - # Then - names are auto-indexed - assert len(multi._registered_providers) == 2 - names = [name for name, _ in multi._registered_providers] - assert names == ["NoOpProvider_1", "NoOpProvider_2"] - - def test_multi_provider_rejects_duplicate_explicit_names(): - # Given - provider_a = NoOpProvider() - provider_b = NoOpProvider() - - # When/Then + first_provider = BooleanProvider("provider") + second_provider = BooleanProvider("provider") + with pytest.raises(ValueError, match="Provider name 'duplicate' is not unique"): - MultiProvider([ - ProviderEntry(provider_a, name="duplicate"), - ProviderEntry(provider_b, name="duplicate"), - ]) - - -def test_multi_provider_first_match_strategy_sequential(): - # Given - flags_a = { - "flag1": InMemoryFlag("off", {"on": True, "off": False}), - } - flags_b = { - "flag1": InMemoryFlag("on", {"on": True, "off": False}), - "flag2": InMemoryFlag("on", {"on": True, "off": False}), - } - - provider_a = InMemoryProvider(flags_a) - provider_b = InMemoryProvider(flags_b) - - multi = MultiProvider([ - ProviderEntry(provider_a, name="primary"), - ProviderEntry(provider_b, name="fallback"), - ], strategy=FirstMatchStrategy()) - - # When - flag1 exists in both, should use first (primary) - result = multi.resolve_boolean_details("flag1", False) - - # Then - assert result.value == False # primary provider returns "off" variant - assert result.reason != Reason.ERROR - - -def test_multi_provider_fallback_to_second_provider(): - # Given - flags_a = {} # primary has no flags - flags_b = { - "flag1": InMemoryFlag("on", {"on": True, "off": False}), - } - - provider_a = InMemoryProvider(flags_a) - provider_b = InMemoryProvider(flags_b) - - multi = 
MultiProvider([ - ProviderEntry(provider_a, name="primary"), - ProviderEntry(provider_b, name="fallback"), - ]) - - # When - flag1 doesn't exist in primary, should fallback - result = multi.resolve_boolean_details("flag1", False) - - # Then - assert result.value == True # fallback provider has the flag - assert result.reason != Reason.ERROR - - -def test_multi_provider_all_types_work(): - # Given - flags = { - "bool-flag": InMemoryFlag("on", {"on": True, "off": False}), - "string-flag": InMemoryFlag("greeting", {"greeting": "hello", "farewell": "goodbye"}), - "int-flag": InMemoryFlag("big", {"small": 10, "big": 100}), - "float-flag": InMemoryFlag("pi", {"pi": 3.14, "e": 2.71}), - "object-flag": InMemoryFlag("full", { - "full": {"name": "test", "value": 42}, - "empty": {}, - }), - } - - provider = InMemoryProvider(flags) - multi = MultiProvider([ProviderEntry(provider)]) - - # When/Then - bool_result = multi.resolve_boolean_details("bool-flag", False) - assert bool_result.value == True - - string_result = multi.resolve_string_details("string-flag", "default") - assert string_result.value == "hello" - - int_result = multi.resolve_integer_details("int-flag", 0) - assert int_result.value == 100 - - float_result = multi.resolve_float_details("float-flag", 0.0) - assert float_result.value == 3.14 - - object_result = multi.resolve_object_details("object-flag", {}) - assert object_result.value == {"name": "test", "value": 42} - - -def test_multi_provider_initialize_all_providers(): - # Given - provider_a = NoOpProvider() - provider_b = NoOpProvider() - - # Track if initialize was called - provider_a.initialize = lambda ctx: None - provider_b.initialize = lambda ctx: None - - a_initialized = False - b_initialized = False - - def track_a_init(ctx): - nonlocal a_initialized - a_initialized = True - - def track_b_init(ctx): - nonlocal b_initialized - b_initialized = True - - provider_a.initialize = track_a_init - provider_b.initialize = track_b_init - - multi = MultiProvider([ - ProviderEntry(provider_a), - ProviderEntry(provider_b), - ]) - - # When - multi.initialize(EvaluationContext()) - - # Then - assert a_initialized - assert b_initialized - - -def test_multi_provider_initialization_failures_are_aggregated(): - # Given - provider_a = NoOpProvider() - provider_b = NoOpProvider() - - def fail_init(ctx): - raise Exception("Init failed") - - provider_a.initialize = fail_init - provider_b.initialize = fail_init - - multi = MultiProvider([ - ProviderEntry(provider_a, name="a"), - ProviderEntry(provider_b, name="b"), - ]) - - # When/Then - with pytest.raises(GeneralError, match="Multi-provider initialization failed"): - multi.initialize(EvaluationContext()) - - -def test_multi_provider_returns_error_when_no_providers_have_flag(): - # Given - provider_a = InMemoryProvider({}) - provider_b = InMemoryProvider({}) - - multi = MultiProvider([ - ProviderEntry(provider_a), - ProviderEntry(provider_b), - ]) - - # When - result = multi.resolve_boolean_details("nonexistent", False) - - # Then - assert result.value == False # default value - assert result.reason == Reason.ERROR + MultiProvider( + [ + ProviderEntry(first_provider, name="duplicate"), + ProviderEntry(second_provider, name="duplicate"), + ] + ) + + +def test_comparison_strategy_rejects_unknown_fallback_provider(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + + with pytest.raises(ValueError, match="Fallback provider 'missing' is not registered"): + MultiProvider( + [ + ProviderEntry(first_provider, 
name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy(fallback_provider="missing"), + ) + + +def test_first_match_uses_fallback_after_flag_not_found(): + missing_result = FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.FLAG_NOT_FOUND, + error_message="missing", + ) + first_provider = BooleanProvider("first", boolean_result=missing_result) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_first_match_stops_on_non_flag_not_found_error(): + error_result = FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message="boom", + ) + first_provider = BooleanProvider("first", boolean_result=error_result) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert second_provider.resolveCount == 0 + + +def test_first_successful_skips_general_errors(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("broken")) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_first_successful_aggregates_errors_when_all_providers_fail(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) + second_provider = BooleanProvider("second", boolean_exception=GeneralError("second")) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert "first: GENERAL (first)" in result.error_message + assert "second: GENERAL (second)" in result.error_message + + +def test_comparison_strategy_returns_fallback_value_and_calls_on_mismatch(): + mismatch_spy = MagicMock() + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy( + fallback_provider="second", + on_mismatch=mismatch_spy, + ), + 
) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + mismatch_spy.assert_called_once() + + +def test_comparison_strategy_aggregates_provider_errors(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert "first: GENERAL (first)" in result.error_message + + +def test_multi_provider_runs_sync_parallel_evaluation(): + sync_blocker = SyncBlocker(expected_count=2) + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + sync_blocker=sync_blocker, + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + sync_blocker=sync_blocker, + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(run_mode="parallel"), + ) + + result_holder: list[FlagResolutionDetails[bool]] = [] + + def evaluate() -> None: + result_holder.append(multi_provider.resolve_boolean_details("flagKey", False)) + + worker_thread = threading.Thread(target=evaluate) + worker_thread.start() + + assert sync_blocker.enteredEvent.wait(timeout=2) + sync_blocker.releaseEvent.set() + worker_thread.join(timeout=2) + + assert result_holder[0].value is False + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 @pytest.mark.asyncio -async def test_multi_provider_async_methods_work(): - # Given - flags = { - "async-flag": InMemoryFlag("on", {"on": True, "off": False}), - } - provider = InMemoryProvider(flags) - multi = MultiProvider([ProviderEntry(provider)]) - - # When - result = await multi.resolve_boolean_details_async("async-flag", False) - - # Then - assert result.value == True - assert result.reason != Reason.ERROR - - -def test_multi_provider_can_be_used_with_api(): - # Given - api.clear_providers() - flags = { - "api-flag": InMemoryFlag("on", {"on": True, "off": False}), - } - provider = InMemoryProvider(flags) - multi = MultiProvider([ProviderEntry(provider)]) - - # When - api.set_provider(multi) +async def test_multi_provider_runs_async_parallel_evaluation(): + async_blocker = AsyncBlocker(expected_count=2) + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + async_blocker=async_blocker, + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + async_blocker=async_blocker, + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(run_mode="parallel"), + ) + + evaluation_task = asyncio.create_task( + multi_provider.resolve_boolean_details_async("flagKey", False) + ) + + await asyncio.wait_for(async_blocker.enteredEvent.wait(), timeout=2) + async_blocker.releaseEvent.set() + result = await asyncio.wait_for(evaluation_task, timeout=2) + + assert result.value is False + assert 
first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_multi_provider_isolates_provider_hooks_and_runs_lifecycle(): + first_hook = RecordingHook("first") + second_hook = RecordingHook("second") + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.FLAG_NOT_FOUND, + error_message="missing", + ), + hook_list=[first_hook], + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + hook_list=[second_hook], + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + api.set_provider(multi_provider) + client = api.get_client() + result = client.get_boolean_details( + "flagKey", + False, + evaluation_context=EvaluationContext(attributes={"base": "value"}), + ) + + assert result.value is True + assert first_hook.events == ["before", "error", "finally"] + assert second_hook.events == ["before", "after", "finally"] + assert first_provider.seenContexts[0]["base"] == "value" + assert first_provider.seenContexts[0]["hookOwner"] == "first" + assert second_provider.seenContexts[0]["base"] == "value" + assert second_provider.seenContexts[0]["hookOwner"] == "second" + + +def test_multi_provider_does_not_run_unused_provider_hooks(): + first_hook = RecordingHook("first") + second_hook = RecordingHook("second") + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + hook_list=[first_hook], + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + hook_list=[second_hook], + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + api.set_provider(multi_provider) client = api.get_client() - value = client.get_boolean_value("api-flag", False) - - # Then - assert value == True - - -def test_multi_provider_metadata(): - # Given - multi = MultiProvider([ProviderEntry(NoOpProvider())]) - - # When - metadata = multi.get_metadata() - - # Then - assert metadata.name == "MultiProvider" - - -def test_multi_provider_aggregates_hooks(): - # Given - from unittest.mock import MagicMock - - provider_a = NoOpProvider() - provider_b = NoOpProvider() - - hook_a = MagicMock() - hook_b = MagicMock() - - provider_a.get_provider_hooks = lambda: [hook_a] - provider_b.get_provider_hooks = lambda: [hook_b] - - multi = MultiProvider([ - ProviderEntry(provider_a), - ProviderEntry(provider_b), - ]) - - # When - hooks = multi.get_provider_hooks() - - # Then - assert len(hooks) == 2 - assert hook_a in hooks - assert hook_b in hooks + result = client.get_boolean_details("flagKey", False) + + assert result.value is True + assert first_hook.events == ["before", "after", "finally"] + assert second_hook.events == [] + + +def test_multi_provider_aggregates_status_and_deduplicates_events(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + spy = MagicMock() + client.add_handler(ProviderEvent.PROVIDER_READY, spy.provider_ready) 
+ client.add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) + client.add_handler(ProviderEvent.PROVIDER_STALE, spy.provider_stale) + spy.provider_ready.reset_mock() + + first_provider.emit_provider_stale(ProviderEventDetails(message="stale")) + assert client.get_provider_status() == ProviderStatus.STALE + assert spy.provider_stale.call_count == 1 + + second_provider.emit_provider_stale(ProviderEventDetails(message="still stale")) + assert client.get_provider_status() == ProviderStatus.STALE + assert spy.provider_stale.call_count == 1 + + first_provider.emit_provider_error( + ProviderEventDetails(error_code=ErrorCode.GENERAL, message="error") + ) + assert client.get_provider_status() == ProviderStatus.ERROR + assert spy.provider_error.call_count == 1 + + second_provider.emit_provider_error( + ProviderEventDetails(error_code=ErrorCode.PROVIDER_FATAL, message="fatal") + ) + assert client.get_provider_status() == ProviderStatus.FATAL + assert spy.provider_error.call_count == 2 + + second_provider.emit_provider_ready(ProviderEventDetails()) + assert client.get_provider_status() == ProviderStatus.ERROR + assert spy.provider_error.call_count == 3 + + first_provider.emit_provider_ready(ProviderEventDetails()) + assert client.get_provider_status() == ProviderStatus.READY + assert spy.provider_ready.call_count == 1 + + +def test_multi_provider_forwards_configuration_changed_events(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + spy = MagicMock() + client.add_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + spy.provider_configuration_changed, + ) + + first_provider.emit_provider_configuration_changed(ProviderEventDetails(message="one")) + second_provider.emit_provider_configuration_changed(ProviderEventDetails(message="two")) + + assert spy.provider_configuration_changed.call_count == 2 + + +def test_multi_provider_reports_not_ready_after_shutdown(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + + api.shutdown() + + assert client.get_provider_status() == ProviderStatus.NOT_READY From 521d9bb01985b6dfb2f4765a87a87356fc1747ec Mon Sep 17 00:00:00 2001 From: Jonathan Norris Date: Mon, 16 Mar 2026 14:57:39 -0400 Subject: [PATCH 7/8] fix: address multi-provider review issues - Fix ContextVar propagation to ThreadPoolExecutor workers (Python <3.12) - Fix _refresh_aggregate_status dropping events during partial init failure - Add shouldEvaluateThisProvider check to skip NOT_READY/FATAL providers - Fix ComparisonStrategy to return first provider result on no-mismatch - Add InternalHookProvider protocol replacing fragile duck-typing - Scope get_status override in registry to InternalHookProvider only - Rename camelCase instance variables to snake_case Signed-off-by: Jonathan Norris --- openfeature/client.py | 20 ++-- openfeature/provider/__init__.py | 27 +++++ openfeature/provider/_registry.py | 16 +-- openfeature/provider/multi_provider.py | 136 ++++++++++++++++--------- 4 files changed, 137 insertions(+), 62 deletions(-) diff --git a/openfeature/client.py b/openfeature/client.py index 
d01ee56b..bc8b9bf7 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -30,7 +30,7 @@ before_hooks, error_hooks, ) -from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider import FeatureProvider, InternalHookProvider, ProviderStatus from openfeature.provider._registry import provider_registry from openfeature.transaction_context import get_transaction_context @@ -471,8 +471,10 @@ def _establish_hooks_and_provider( ) def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool: - uses_internal_hooks = getattr(provider, "uses_internal_provider_hooks", None) - return bool(callable(uses_internal_hooks) and uses_internal_hooks()) + return ( + isinstance(provider, InternalHookProvider) + and provider.uses_internal_provider_hooks() + ) def _set_internal_provider_hook_runtime( self, @@ -480,12 +482,11 @@ def _set_internal_provider_hook_runtime( flag_type: FlagType, hook_hints: HookHints, ) -> object | None: - if not self._provider_uses_internal_hooks(provider): + if not isinstance(provider, InternalHookProvider): return None - set_hook_runtime = getattr(provider, "set_internal_provider_hook_runtime", None) - if not callable(set_hook_runtime): + if not provider.uses_internal_provider_hooks(): return None - return set_hook_runtime( + return provider.set_internal_provider_hook_runtime( flag_type=flag_type, client_metadata=self.get_metadata(), hook_hints=hook_hints, @@ -496,9 +497,8 @@ def _reset_internal_provider_hook_runtime( ) -> None: if runtime_token is None: return - reset_hook_runtime = getattr(provider, "reset_internal_provider_hook_runtime", None) - if callable(reset_hook_runtime): - reset_hook_runtime(runtime_token) + if isinstance(provider, InternalHookProvider): + provider.reset_internal_provider_hook_runtime(runtime_token) def _assert_provider_status( self, diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index b022bbc7..4a000790 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -22,6 +22,7 @@ "FeatureProvider", "FirstMatchStrategy", "FirstSuccessfulStrategy", + "InternalHookProvider", "Metadata", "MultiProvider", "ProviderEntry", @@ -128,6 +129,32 @@ async def resolve_object_details_async( ]: ... +@typing.runtime_checkable +class InternalHookProvider(typing.Protocol): + """Protocol for providers that manage their own provider hook execution. + + Providers implementing this protocol (e.g. MultiProvider) handle provider + hook lifecycle internally. The client will skip its own provider hook + invocations and instead delegate to the provider via the set/reset methods. + + The registry will also use get_status() from this protocol instead of its + own internal status tracking for providers that implement it. + """ + + def uses_internal_provider_hooks(self) -> bool: ... + + def set_internal_provider_hook_runtime( + self, + flag_type: typing.Any, + client_metadata: typing.Any, + hook_hints: typing.Any, + ) -> typing.Any: ... + + def reset_internal_provider_hook_runtime(self, token: typing.Any) -> None: ... + + def get_status(self) -> ProviderStatus: ... 
+ + class AbstractProvider(FeatureProvider): def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: # this makes sure to invoke the parent of `FeatureProvider` -> `object` diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index 1944dd9c..cb4083a3 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -5,7 +5,7 @@ ProviderEventDetails, ) from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError -from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider import FeatureProvider, InternalHookProvider, ProviderStatus from openfeature.provider.no_op_provider import NoOpProvider @@ -80,6 +80,9 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: try: if hasattr(provider, "initialize"): provider.initialize(self._get_evaluation_context()) + # InternalHookProvider (e.g. MultiProvider) emits its own events + # during initialize(), so only dispatch PROVIDER_READY if the + # provider hasn't already transitioned away from NOT_READY. if self.get_provider_status(provider) == ProviderStatus.NOT_READY: self.dispatch_event( provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() @@ -90,6 +93,8 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL ) + # Same guard: skip if the provider already emitted its own error + # event and transitioned out of NOT_READY. if self.get_provider_status(provider) == ProviderStatus.NOT_READY: self.dispatch_event( provider, @@ -117,11 +122,10 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: provider.detach() def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: - provider_status_getter = getattr(provider, "get_status", None) - if callable(provider_status_getter): - status = provider_status_getter() - if isinstance(status, ProviderStatus): - return status + # Only InternalHookProvider implementations (e.g. MultiProvider) manage + # their own status. For all other providers, use the registry's tracking. + if isinstance(provider, InternalHookProvider): + return provider.get_status() return self._provider_status.get(provider, ProviderStatus.NOT_READY) def dispatch_event( diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index daff1cdf..965a8337 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -290,23 +290,29 @@ def determine_final_result( "Comparison strategy received provider errors", ) + # The first provider's result is the "final resolution" (used on agreement). + # The fallback provider's result is used on mismatch (per JS SDK reference). 
+ final_evaluation = evaluations[0] fallback_evaluation = self._select_fallback_evaluation(evaluations) - fallback_value = fallback_evaluation.result.value + reference_value = final_evaluation.result.value has_mismatch = any( - evaluation.result.value != fallback_value for evaluation in evaluations + evaluation.result.value != reference_value for evaluation in evaluations ) - if has_mismatch and self.on_mismatch is not None: - mismatch_results = { - evaluation.provider_name: evaluation.result for evaluation in evaluations - } - try: - self.on_mismatch(flag_key, mismatch_results) - except Exception: - logger.exception( - "Comparison strategy mismatch callback failed for flag '%s'", - flag_key, - ) - return fallback_evaluation.result + if has_mismatch: + if self.on_mismatch is not None: + mismatch_results = { + evaluation.provider_name: evaluation.result + for evaluation in evaluations + } + try: + self.on_mismatch(flag_key, mismatch_results) + except Exception: + logger.exception( + "Comparison strategy mismatch callback failed for flag '%s'", + flag_key, + ) + return fallback_evaluation.result + return final_evaluation.result def _select_fallback_evaluation( self, evaluations: list[_ProviderEvaluation[FlagValueType]] @@ -342,26 +348,27 @@ def __init__( raise ValueError("At least one provider must be provided") self.strategy = strategy or FirstMatchStrategy() - self._registeredProviders: list[tuple[str, FeatureProvider]] = [] + self._registered_providers: list[tuple[str, FeatureProvider]] = [] self._provider_names: dict[FeatureProvider, str] = {} self._provider_statuses: dict[str, ProviderStatus] = {} self._aggregate_status = ProviderStatus.NOT_READY - self._statusLock = threading.Lock() - self._hookRuntime: contextvars.ContextVar[_ProviderHookRuntime | None] = ( + self._initialized = False + self._status_lock = threading.Lock() + self._hook_runtime: contextvars.ContextVar[_ProviderHookRuntime | None] = ( contextvars.ContextVar( - f"multiProviderHookRuntime:{id(self)}", + f"multi_provider_hook_runtime:{id(self)}", default=None, ) ) self._register_providers(providers) self._provider_statuses = { provider_name: ProviderStatus.NOT_READY - for provider_name, _ in self._registeredProviders + for provider_name, _ in self._registered_providers } validate_provider_names = getattr(self.strategy, "validate_provider_names", None) if callable(validate_provider_names): validate_provider_names( - [provider_name for provider_name, _ in self._registeredProviders] + [provider_name for provider_name, _ in self._registered_providers] ) def uses_internal_provider_hooks(self) -> bool: @@ -373,7 +380,7 @@ def set_internal_provider_hook_runtime( client_metadata: typing.Any, hook_hints: HookHints, ) -> contextvars.Token[_ProviderHookRuntime | None]: - return self._hookRuntime.set( + return self._hook_runtime.set( _ProviderHookRuntime( flag_type=flag_type, client_metadata=client_metadata, @@ -384,10 +391,10 @@ def set_internal_provider_hook_runtime( def reset_internal_provider_hook_runtime( self, token: contextvars.Token[_ProviderHookRuntime | None] ) -> None: - self._hookRuntime.reset(token) + self._hook_runtime.reset(token) def get_status(self) -> ProviderStatus: - with self._statusLock: + with self._status_lock: return self._aggregate_status def _register_providers(self, providers: list[ProviderEntry]) -> None: @@ -415,7 +422,7 @@ def _register_providers(self, providers: list[ProviderEntry]) -> None: break used_names.add(provider_name) - self._registeredProviders.append((provider_name, entry.provider)) + 
self._registered_providers.append((provider_name, entry.provider)) self._provider_names[entry.provider] = provider_name def get_metadata(self) -> Metadata: @@ -429,11 +436,11 @@ def attach( on_emit: Callable[[FeatureProvider, ProviderEvent, ProviderEventDetails], None], ) -> None: super().attach(on_emit) - for _, provider in self._registeredProviders: + for _, provider in self._registered_providers: provider.attach(self._handle_provider_event) def detach(self) -> None: - for _, provider in self._registeredProviders: + for _, provider in self._registered_providers: provider.detach() super().detach() @@ -448,8 +455,8 @@ def initialize_provider( except Exception as err: return provider_name, err - with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: - init_results = list(executor.map(initialize_provider, self._registeredProviders)) + with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: + init_results = list(executor.map(initialize_provider, self._registered_providers)) error_messages: list[str] = [] event_details = ProviderEventDetails() @@ -464,13 +471,14 @@ def initialize_provider( ) event_details = self._details_from_exception(err, provider_name) - self._refresh_aggregate_status(event_details) + self._initialized = True + self._refresh_aggregate_status(event_details, force=True) if error_messages: raise GeneralError(f"Multi-provider initialization failed: {'; '.join(error_messages)}") def shutdown(self) -> None: - for _, provider in self._registeredProviders: + for _, provider in self._registered_providers: provider.detach() def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: @@ -480,13 +488,13 @@ def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: except Exception: logger.exception("Provider '%s' shutdown failed", provider_name) - with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: - list(executor.map(shutdown_provider, self._registeredProviders)) + with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: + list(executor.map(shutdown_provider, self._registered_providers)) - with self._statusLock: + with self._status_lock: self._provider_statuses = { provider_name: ProviderStatus.NOT_READY - for provider_name, _ in self._registeredProviders + for provider_name, _ in self._registered_providers } self._aggregate_status = ProviderStatus.NOT_READY @@ -519,14 +527,29 @@ def _handle_provider_event( def _set_provider_status( self, provider_name: str, provider_status: ProviderStatus ) -> None: - with self._statusLock: + with self._status_lock: self._provider_statuses[provider_name] = provider_status def _mark_provider_ready(self, provider_name: str) -> None: - with self._statusLock: + with self._status_lock: if self._provider_statuses.get(provider_name) == ProviderStatus.NOT_READY: self._provider_statuses[provider_name] = ProviderStatus.READY + def _should_evaluate_provider(self, provider_name: str) -> bool: + """Check if a provider should be evaluated based on its status. + + Providers with NOT_READY or FATAL status are skipped, matching the + JS SDK reference behavior (shouldEvaluateThisProvider). + + Before initialize() has been called, all providers are eligible since + status tracking is not yet meaningful. 
+ """ + if not self._initialized: + return True + with self._status_lock: + status = self._provider_statuses.get(provider_name, ProviderStatus.NOT_READY) + return status not in (ProviderStatus.NOT_READY, ProviderStatus.FATAL) + def _calculate_aggregate_status(self) -> ProviderStatus: statuses = tuple(self._provider_statuses.values()) if not statuses: @@ -536,13 +559,17 @@ def _calculate_aggregate_status(self) -> ProviderStatus: return status return ProviderStatus.NOT_READY - def _refresh_aggregate_status(self, details: ProviderEventDetails) -> None: + def _refresh_aggregate_status( + self, + details: ProviderEventDetails, + force: bool = False, + ) -> None: event_to_emit: ProviderEvent | None = None event_details = details - with self._statusLock: + with self._status_lock: previous_status = self._aggregate_status aggregate_status = self._calculate_aggregate_status() - if previous_status == aggregate_status: + if previous_status == aggregate_status and not force: return self._aggregate_status = aggregate_status event_to_emit = self._event_from_status(aggregate_status) @@ -676,7 +703,7 @@ def _evaluate_provider_sync( # noqa: PLR0913 FlagResolutionDetails[T], ], ) -> _ProviderEvaluation[T]: - runtime = self._hookRuntime.get() + runtime = self._hook_runtime.get() if runtime is None or not provider.get_provider_hooks(): try: return _ProviderEvaluation( @@ -763,7 +790,7 @@ async def _evaluate_provider_async( # noqa: PLR0913 Awaitable[FlagResolutionDetails[T]], ], ) -> _ProviderEvaluation[T]: - runtime = self._hookRuntime.get() + runtime = self._hook_runtime.get() if runtime is None or not provider.get_provider_hooks(): try: return _ProviderEvaluation( @@ -850,10 +877,21 @@ def _evaluate_with_providers( FlagResolutionDetails[T], ], ) -> FlagResolutionDetails[T]: + eligible_providers = [ + (name, provider) + for name, provider in self._registered_providers + if self._should_evaluate_provider(name) + ] + if self.strategy.run_mode == "parallel": - with ThreadPoolExecutor(max_workers=len(self._registeredProviders)) as executor: + # Each worker thread gets its own copy of the current context so + # that ContextVars (e.g. _hook_runtime) are propagated correctly. + # ThreadPoolExecutor does not automatically copy context on + # Python < 3.12, and a single Context.run() is not reentrant. 
+ with ThreadPoolExecutor(max_workers=len(eligible_providers) or 1) as executor: futures = [ executor.submit( + contextvars.copy_context().run, self._evaluate_provider_sync, provider_name, provider, @@ -863,7 +901,7 @@ def _evaluate_with_providers( evaluation_context, resolve_fn, ) - for provider_name, provider in self._registeredProviders + for provider_name, provider in eligible_providers ] evaluations = [future.result() for future in futures] return typing.cast( @@ -879,7 +917,7 @@ def _evaluate_with_providers( ) evaluations: list[_ProviderEvaluation[T]] = [] - for provider_name, provider in self._registeredProviders: + for provider_name, provider in eligible_providers: evaluation = self._evaluate_provider_sync( provider_name, provider, @@ -926,6 +964,12 @@ async def _evaluate_with_providers_async( Awaitable[FlagResolutionDetails[T]], ], ) -> FlagResolutionDetails[T]: + eligible_providers = [ + (name, provider) + for name, provider in self._registered_providers + if self._should_evaluate_provider(name) + ] + if self.strategy.run_mode == "parallel": tasks = [ asyncio.create_task( @@ -939,7 +983,7 @@ async def _evaluate_with_providers_async( resolve_fn, ) ) - for provider_name, provider in self._registeredProviders + for provider_name, provider in eligible_providers ] evaluations = await asyncio.gather(*tasks) return typing.cast( @@ -955,7 +999,7 @@ async def _evaluate_with_providers_async( ) evaluations: list[_ProviderEvaluation[T]] = [] - for provider_name, provider in self._registeredProviders: + for provider_name, provider in eligible_providers: evaluation = await self._evaluate_provider_async( provider_name, provider, From d6fda15cf950995d52633edb1ecb2065150426b3 Mon Sep 17 00:00:00 2001 From: Jonathan Norris Date: Mon, 16 Mar 2026 15:05:42 -0400 Subject: [PATCH 8/8] fix: resolve CI failures in multi-provider - Add _is_internal_hook_provider class marker to avoid Mock false positives with runtime_checkable Protocol isinstance checks - Fix mypy no-redef errors by hoisting evaluations declaration before branch - Fix mypy no-any-return by assigning to typed local before returning - Fix mypy attr-defined by using _as_internal_hook_provider narrowing helper - Apply ruff formatting fixes Signed-off-by: Jonathan Norris --- openfeature/client.py | 29 +++++++---- openfeature/provider/__init__.py | 6 +++ openfeature/provider/_registry.py | 7 ++- openfeature/provider/multi_provider.py | 72 ++++++++++++++++++-------- tests/test_multi_provider.py | 24 ++++++--- uv.lock | 10 ++-- 6 files changed, 103 insertions(+), 45 deletions(-) diff --git a/openfeature/client.py b/openfeature/client.py index bc8b9bf7..09adb211 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -470,11 +470,19 @@ def _establish_hooks_and_provider( merged_eval_context, ) + def _as_internal_hook_provider( + self, provider: FeatureProvider + ) -> InternalHookProvider | None: + """Return the provider as InternalHookProvider if it opts in, else None.""" + if getattr(provider, "_is_internal_hook_provider", False) and isinstance( + provider, InternalHookProvider + ): + return provider + return None + def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool: - return ( - isinstance(provider, InternalHookProvider) - and provider.uses_internal_provider_hooks() - ) + ihp = self._as_internal_hook_provider(provider) + return ihp is not None and ihp.uses_internal_provider_hooks() def _set_internal_provider_hook_runtime( self, @@ -482,23 +490,24 @@ def _set_internal_provider_hook_runtime( flag_type: FlagType, 
hook_hints: HookHints, ) -> object | None: - if not isinstance(provider, InternalHookProvider): - return None - if not provider.uses_internal_provider_hooks(): + ihp = self._as_internal_hook_provider(provider) + if ihp is None or not ihp.uses_internal_provider_hooks(): return None - return provider.set_internal_provider_hook_runtime( + result: object | None = ihp.set_internal_provider_hook_runtime( flag_type=flag_type, client_metadata=self.get_metadata(), hook_hints=hook_hints, ) + return result def _reset_internal_provider_hook_runtime( self, provider: FeatureProvider, runtime_token: object | None ) -> None: if runtime_token is None: return - if isinstance(provider, InternalHookProvider): - provider.reset_internal_provider_hook_runtime(runtime_token) + ihp = self._as_internal_hook_provider(provider) + if ihp is not None: + ihp.reset_internal_provider_hook_runtime(runtime_token) def _assert_provider_status( self, diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 4a000790..8bcd7721 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -139,8 +139,14 @@ class InternalHookProvider(typing.Protocol): The registry will also use get_status() from this protocol instead of its own internal status tracking for providers that implement it. + + Implementations must set ``_is_internal_hook_provider = True`` as a class + attribute. This marker is checked alongside ``isinstance`` to avoid false + positives from duck-typed objects (e.g. ``Mock``). """ + _is_internal_hook_provider: typing.ClassVar[bool] + def uses_internal_provider_hooks(self) -> bool: ... def set_internal_provider_hook_runtime( diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index cb4083a3..59fede6f 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -124,7 +124,12 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: # Only InternalHookProvider implementations (e.g. MultiProvider) manage # their own status. For all other providers, use the registry's tracking. - if isinstance(provider, InternalHookProvider): + # We check _is_internal_hook_provider (a concrete class attribute) in + # addition to isinstance, because runtime_checkable Protocols match any + # object that has the right method names — including Mock objects. 
+ if getattr(provider, "_is_internal_hook_provider", False) and isinstance( + provider, InternalHookProvider + ): return provider.get_status() return self._provider_status.get(provider, ProviderStatus.NOT_READY) diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py index 965a8337..0cc241c9 100644 --- a/openfeature/provider/multi_provider.py +++ b/openfeature/provider/multi_provider.py @@ -109,7 +109,9 @@ def _validate_run_mode(run_mode: RunMode) -> RunMode: def _format_result_error( provider_name: str, result: FlagResolutionDetails[FlagValueType] ) -> str: - error_code = result.error_code.value if result.error_code else ErrorCode.GENERAL.value + error_code = ( + result.error_code.value if result.error_code else ErrorCode.GENERAL.value + ) error_message = result.error_message or "Unknown error" return f"{provider_name}: {error_code} ({error_message})" @@ -280,7 +282,9 @@ def determine_final_result( evaluations: list[_ProviderEvaluation[FlagValueType]], ) -> FlagResolutionDetails[FlagValueType]: failed_evaluations = [ - evaluation for evaluation in evaluations if not _is_success(evaluation.result) + evaluation + for evaluation in evaluations + if not _is_success(evaluation.result) ] if failed_evaluations: return _build_aggregated_error( @@ -338,6 +342,8 @@ class MultiProvider(AbstractProvider): ProviderStatus.READY, ) + _is_internal_hook_provider: typing.ClassVar[bool] = True + def __init__( self, providers: list[ProviderEntry], @@ -365,7 +371,9 @@ def __init__( provider_name: ProviderStatus.NOT_READY for provider_name, _ in self._registered_providers } - validate_provider_names = getattr(self.strategy, "validate_provider_names", None) + validate_provider_names = getattr( + self.strategy, "validate_provider_names", None + ) if callable(validate_provider_names): validate_provider_names( [provider_name for provider_name, _ in self._registered_providers] @@ -455,8 +463,12 @@ def initialize_provider( except Exception as err: return provider_name, err - with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: - init_results = list(executor.map(initialize_provider, self._registered_providers)) + with ThreadPoolExecutor( + max_workers=len(self._registered_providers) + ) as executor: + init_results = list( + executor.map(initialize_provider, self._registered_providers) + ) error_messages: list[str] = [] event_details = ProviderEventDetails() @@ -475,7 +487,9 @@ def initialize_provider( self._refresh_aggregate_status(event_details, force=True) if error_messages: - raise GeneralError(f"Multi-provider initialization failed: {'; '.join(error_messages)}") + raise GeneralError( + f"Multi-provider initialization failed: {'; '.join(error_messages)}" + ) def shutdown(self) -> None: for _, provider in self._registered_providers: @@ -488,7 +502,9 @@ def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: except Exception: logger.exception("Provider '%s' shutdown failed", provider_name) - with ThreadPoolExecutor(max_workers=len(self._registered_providers)) as executor: + with ThreadPoolExecutor( + max_workers=len(self._registered_providers) + ) as executor: list(executor.map(shutdown_provider, self._registered_providers)) with self._status_lock: @@ -522,7 +538,9 @@ def _handle_provider_event( provider_name, self._status_from_event_details(details), ) - self._refresh_aggregate_status(self._with_provider_metadata(details, provider_name)) + self._refresh_aggregate_status( + self._with_provider_metadata(details, provider_name) + ) def 
_set_provider_status( self, provider_name: str, provider_status: ProviderStatus @@ -547,7 +565,9 @@ def _should_evaluate_provider(self, provider_name: str) -> bool: if not self._initialized: return True with self._status_lock: - status = self._provider_statuses.get(provider_name, ProviderStatus.NOT_READY) + status = self._provider_statuses.get( + provider_name, ProviderStatus.NOT_READY + ) return status not in (ProviderStatus.NOT_READY, ProviderStatus.FATAL) def _calculate_aggregate_status(self) -> ProviderStatus: @@ -577,7 +597,9 @@ def _refresh_aggregate_status( if event_to_emit is not None: self.emit(event_to_emit, event_details) - def _event_from_status(self, provider_status: ProviderStatus) -> ProviderEvent | None: + def _event_from_status( + self, provider_status: ProviderStatus + ) -> ProviderEvent | None: if provider_status == ProviderStatus.READY: return ProviderEvent.PROVIDER_READY if provider_status == ProviderStatus.STALE: @@ -632,9 +654,7 @@ def _details_from_exception( self, err: Exception, provider_name: str ) -> ProviderEventDetails: error_code = ( - err.error_code - if isinstance(err, OpenFeatureError) - else ErrorCode.GENERAL + err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL ) error_message = self._error_message_from_exception(err) return ProviderEventDetails( @@ -652,9 +672,7 @@ def _resolution_from_exception( self, default_value: T, err: Exception ) -> FlagResolutionDetails[T]: error_code = ( - err.error_code - if isinstance(err, OpenFeatureError) - else ErrorCode.GENERAL + err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL ) error_message = self._error_message_from_exception(err) return FlagResolutionDetails( @@ -709,7 +727,9 @@ def _evaluate_provider_sync( # noqa: PLR0913 return _ProviderEvaluation( provider_name=provider_name, provider=provider, - result=resolve_fn(provider, flag_key, default_value, evaluation_context), + result=resolve_fn( + provider, flag_key, default_value, evaluation_context + ), ) except Exception as err: return _ProviderEvaluation( @@ -821,7 +841,9 @@ async def _evaluate_provider_async( # noqa: PLR0913 try: before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) resolved_context = provider_context.merge(before_context) - resolution = await resolve_fn(provider, flag_key, default_value, resolved_context) + resolution = await resolve_fn( + provider, flag_key, default_value, resolved_context + ) flag_evaluation = resolution.to_flag_evaluation_details(flag_key) if err := flag_evaluation.get_exception(): error_hooks( @@ -883,12 +905,16 @@ def _evaluate_with_providers( if self._should_evaluate_provider(name) ] + evaluations: list[_ProviderEvaluation[T]] = [] + if self.strategy.run_mode == "parallel": # Each worker thread gets its own copy of the current context so # that ContextVars (e.g. _hook_runtime) are propagated correctly. # ThreadPoolExecutor does not automatically copy context on # Python < 3.12, and a single Context.run() is not reentrant. 
- with ThreadPoolExecutor(max_workers=len(eligible_providers) or 1) as executor: + with ThreadPoolExecutor( + max_workers=len(eligible_providers) or 1 + ) as executor: futures = [ executor.submit( contextvars.copy_context().run, @@ -916,7 +942,6 @@ def _evaluate_with_providers( ), ) - evaluations: list[_ProviderEvaluation[T]] = [] for provider_name, provider in eligible_providers: evaluation = self._evaluate_provider_sync( provider_name, @@ -970,6 +995,8 @@ async def _evaluate_with_providers_async( if self._should_evaluate_provider(name) ] + evaluations: list[_ProviderEvaluation[T]] = [] + if self.strategy.run_mode == "parallel": tasks = [ asyncio.create_task( @@ -985,7 +1012,7 @@ async def _evaluate_with_providers_async( ) for provider_name, provider in eligible_providers ] - evaluations = await asyncio.gather(*tasks) + evaluations = list(await asyncio.gather(*tasks)) return typing.cast( FlagResolutionDetails[T], self.strategy.determine_final_result( @@ -993,12 +1020,11 @@ async def _evaluate_with_providers_async( default_value, typing.cast( list[_ProviderEvaluation[FlagValueType]], - list(evaluations), + evaluations, ), ), ) - evaluations: list[_ProviderEvaluation[T]] = [] for provider_name, provider in eligible_providers: evaluation = await self._evaluate_provider_async( provider_name, diff --git a/tests/test_multi_provider.py b/tests/test_multi_provider.py index aa6cea26..a7959666 100644 --- a/tests/test_multi_provider.py +++ b/tests/test_multi_provider.py @@ -60,7 +60,9 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: del flag_key self.resolveCount += 1 - self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes)) + self.seenContexts.append( + dict((evaluation_context or EvaluationContext()).attributes) + ) if self.sync_blocker is not None: self.sync_blocker.wait() if self.booleanException is not None: @@ -77,7 +79,9 @@ async def resolve_boolean_details_async( ) -> FlagResolutionDetails[bool]: del flag_key self.resolveCount += 1 - self.seenContexts.append(dict((evaluation_context or EvaluationContext()).attributes)) + self.seenContexts.append( + dict((evaluation_context or EvaluationContext()).attributes) + ) if self.async_blocker is not None: await self.async_blocker.wait() if self.booleanException is not None: @@ -225,7 +229,9 @@ def test_comparison_strategy_rejects_unknown_fallback_provider(): first_provider = BooleanProvider("first") second_provider = BooleanProvider("second") - with pytest.raises(ValueError, match="Fallback provider 'missing' is not registered"): + with pytest.raises( + ValueError, match="Fallback provider 'missing' is not registered" + ): MultiProvider( [ ProviderEntry(first_provider, name="first"), @@ -311,7 +317,9 @@ def test_first_successful_skips_general_errors(): def test_first_successful_aggregates_errors_when_all_providers_fail(): first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) - second_provider = BooleanProvider("second", boolean_exception=GeneralError("second")) + second_provider = BooleanProvider( + "second", boolean_exception=GeneralError("second") + ) multi_provider = MultiProvider( [ ProviderEntry(first_provider, name="first"), @@ -583,8 +591,12 @@ def test_multi_provider_forwards_configuration_changed_events(): spy.provider_configuration_changed, ) - first_provider.emit_provider_configuration_changed(ProviderEventDetails(message="one")) - second_provider.emit_provider_configuration_changed(ProviderEventDetails(message="two")) + 
first_provider.emit_provider_configuration_changed( + ProviderEventDetails(message="one") + ) + second_provider.emit_provider_configuration_changed( + ProviderEventDetails(message="two") + ) assert spy.provider_configuration_changed.call_count == 2 diff --git a/uv.lock b/uv.lock index 5b5ed9ab..cd187edc 100644 --- a/uv.lock +++ b/uv.lock @@ -259,12 +259,12 @@ dev = [ [package.metadata.requires-dev] dev = [ - { name = "behave" }, - { name = "coverage", extras = ["toml"], specifier = ">=6.5" }, - { name = "poethepoet", specifier = ">=0.40.0" }, + { name = "behave", specifier = ">=1.3.0,<2.0.0" }, + { name = "coverage", extras = ["toml"], specifier = ">=7.10.0,<8.0.0" }, + { name = "poethepoet", specifier = ">=0.40.0,<1.0.0" }, { name = "pre-commit" }, - { name = "pytest", specifier = ">=9.0.0" }, - { name = "pytest-asyncio", specifier = ">=1.3.0" }, + { name = "pytest", specifier = ">=9.0.0,<10.0.0" }, + { name = "pytest-asyncio", specifier = ">=1.3.0,<2.0.0" }, ] [[package]]
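
Note on the ContextVar handling introduced in PATCH 7/8: in parallel run mode each provider evaluation is submitted to the ThreadPoolExecutor through contextvars.copy_context().run(), so the per-evaluation hook runtime stored in a ContextVar is visible inside the worker threads. The sketch below is not part of the patches and the variable and function names are illustrative; it only demonstrates the difference between submitting a callable directly and wrapping it in a copied context.

    import contextvars
    from concurrent.futures import ThreadPoolExecutor

    hook_runtime = contextvars.ContextVar("hook_runtime", default=None)

    def read_runtime():
        # Reports what the worker thread sees for the ContextVar.
        return hook_runtime.get()

    hook_runtime.set("per-evaluation state")

    with ThreadPoolExecutor(max_workers=2) as executor:
        # Submitted directly: a worker that does not inherit the caller's
        # context reads the default (None).
        plain = executor.submit(read_runtime).result()
        # Wrapped in a copy of the caller's context taken at submit time:
        # the worker reads "per-evaluation state" on any interpreter.
        propagated = executor.submit(
            contextvars.copy_context().run, read_runtime
        ).result()

    print(plain, propagated)

This is the motivation for wrapping _evaluate_provider_sync in copy_context().run() for the "parallel" run mode rather than relying on the executor to propagate context by itself; a fresh copy per submitted task also avoids the non-reentrancy of a single Context.run() noted in the patch comments.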