diff --git a/openfeature/client.py b/openfeature/client.py index a02693c1..09adb211 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -30,7 +30,7 @@ before_hooks, error_hooks, ) -from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider import FeatureProvider, InternalHookProvider, ProviderStatus from openfeature.provider._registry import provider_registry from openfeature.transaction_context import get_transaction_context @@ -429,6 +429,11 @@ def _establish_hooks_and_provider( client_metadata = self.get_metadata() provider_metadata = provider.get_metadata() + provider_hooks = ( + [] + if self._provider_uses_internal_hooks(provider) + else provider.get_provider_hooks() + ) # Hooks need to be handled in different orders at different stages # in the flag evaluation @@ -450,7 +455,7 @@ def _establish_hooks_and_provider( get_hooks(), self.hooks, evaluation_hooks, - provider.get_provider_hooks(), + provider_hooks, ) ] # after, error, finally: Provider, Invocation, Client, API @@ -465,6 +470,45 @@ def _establish_hooks_and_provider( merged_eval_context, ) + def _as_internal_hook_provider( + self, provider: FeatureProvider + ) -> InternalHookProvider | None: + """Return the provider as InternalHookProvider if it opts in, else None.""" + if getattr(provider, "_is_internal_hook_provider", False) and isinstance( + provider, InternalHookProvider + ): + return provider + return None + + def _provider_uses_internal_hooks(self, provider: FeatureProvider) -> bool: + ihp = self._as_internal_hook_provider(provider) + return ihp is not None and ihp.uses_internal_provider_hooks() + + def _set_internal_provider_hook_runtime( + self, + provider: FeatureProvider, + flag_type: FlagType, + hook_hints: HookHints, + ) -> object | None: + ihp = self._as_internal_hook_provider(provider) + if ihp is None or not ihp.uses_internal_provider_hooks(): + return None + result: object | None = ihp.set_internal_provider_hook_runtime( + flag_type=flag_type, + client_metadata=self.get_metadata(), + hook_hints=hook_hints, + ) + return result + + def _reset_internal_provider_hook_runtime( + self, provider: FeatureProvider, runtime_token: object | None + ) -> None: + if runtime_token is None: + return + ihp = self._as_internal_hook_provider(provider) + if ihp is not None: + ihp.reset_internal_provider_hook_runtime(runtime_token) + def _assert_provider_status( self, ) -> OpenFeatureError | None: @@ -611,13 +655,21 @@ async def evaluate_flag_details_async( merged_eval_context, ) - flag_evaluation = await self._create_provider_evaluation_async( + runtime_token = self._set_internal_provider_hook_runtime( provider, flag_type, - flag_key, - default_value, - merged_context, + hook_hints, ) + try: + flag_evaluation = await self._create_provider_evaluation_async( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + finally: + self._reset_internal_provider_hook_runtime(provider, runtime_token) if err := flag_evaluation.get_exception(): error_hooks( flag_type, err, reversed_merged_hooks_and_context, hook_hints @@ -787,13 +839,21 @@ def evaluate_flag_details( merged_eval_context, ) - flag_evaluation = self._create_provider_evaluation( + runtime_token = self._set_internal_provider_hook_runtime( provider, flag_type, - flag_key, - default_value, - merged_context, + hook_hints, ) + try: + flag_evaluation = self._create_provider_evaluation( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + finally: + self._reset_internal_provider_hook_runtime(provider, runtime_token) if err := flag_evaluation.get_exception(): error_hooks( flag_type, err, reversed_merged_hooks_and_context, hook_hints diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index aea5069f..8bcd7721 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -15,7 +15,19 @@ if typing.TYPE_CHECKING: from openfeature.flag_evaluation import FlagValueType -__all__ = ["AbstractProvider", "FeatureProvider", "Metadata", "ProviderStatus"] +__all__ = [ + "AbstractProvider", + "ComparisonStrategy", + "EvaluationStrategy", + "FeatureProvider", + "FirstMatchStrategy", + "FirstSuccessfulStrategy", + "InternalHookProvider", + "Metadata", + "MultiProvider", + "ProviderEntry", + "ProviderStatus", +] class ProviderStatus(Enum): @@ -117,6 +129,38 @@ async def resolve_object_details_async( ]: ... +@typing.runtime_checkable +class InternalHookProvider(typing.Protocol): + """Protocol for providers that manage their own provider hook execution. + + Providers implementing this protocol (e.g. MultiProvider) handle provider + hook lifecycle internally. The client will skip its own provider hook + invocations and instead delegate to the provider via the set/reset methods. + + The registry will also use get_status() from this protocol instead of its + own internal status tracking for providers that implement it. + + Implementations must set ``_is_internal_hook_provider = True`` as a class + attribute. This marker is checked alongside ``isinstance`` to avoid false + positives from duck-typed objects (e.g. ``Mock``). + """ + + _is_internal_hook_provider: typing.ClassVar[bool] + + def uses_internal_provider_hooks(self) -> bool: ... + + def set_internal_provider_hook_runtime( + self, + flag_type: typing.Any, + client_metadata: typing.Any, + hook_hints: typing.Any, + ) -> typing.Any: ... + + def reset_internal_provider_hook_runtime(self, token: typing.Any) -> None: ... + + def get_status(self) -> ProviderStatus: ... + + class AbstractProvider(FeatureProvider): def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: # this makes sure to invoke the parent of `FeatureProvider` -> `object` @@ -247,3 +291,13 @@ def emit_provider_stale(self, details: ProviderEventDetails) -> None: def emit(self, event: ProviderEvent, details: ProviderEventDetails) -> None: if hasattr(self, "_on_emit"): self._on_emit(self, event, details) + + +from .multi_provider import ( # noqa: E402 + ComparisonStrategy, + EvaluationStrategy, + FirstMatchStrategy, + FirstSuccessfulStrategy, + MultiProvider, + ProviderEntry, +) diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index bf8fa9a8..59fede6f 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -5,7 +5,7 @@ ProviderEventDetails, ) from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError -from openfeature.provider import FeatureProvider, ProviderStatus +from openfeature.provider import FeatureProvider, InternalHookProvider, ProviderStatus from openfeature.provider.no_op_provider import NoOpProvider @@ -80,23 +80,30 @@ def _initialize_provider(self, provider: FeatureProvider) -> None: try: if hasattr(provider, "initialize"): provider.initialize(self._get_evaluation_context()) - self.dispatch_event( - provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() - ) + # InternalHookProvider (e.g. MultiProvider) emits its own events + # during initialize(), so only dispatch PROVIDER_READY if the + # provider hasn't already transitioned away from NOT_READY. + if self.get_provider_status(provider) == ProviderStatus.NOT_READY: + self.dispatch_event( + provider, ProviderEvent.PROVIDER_READY, ProviderEventDetails() + ) except Exception as err: error_code = ( err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL ) - self.dispatch_event( - provider, - ProviderEvent.PROVIDER_ERROR, - ProviderEventDetails( - message=f"Provider initialization failed: {err}", - error_code=error_code, - ), - ) + # Same guard: skip if the provider already emitted its own error + # event and transitioned out of NOT_READY. + if self.get_provider_status(provider) == ProviderStatus.NOT_READY: + self.dispatch_event( + provider, + ProviderEvent.PROVIDER_ERROR, + ProviderEventDetails( + message=f"Provider initialization failed: {err}", + error_code=error_code, + ), + ) def _shutdown_provider(self, provider: FeatureProvider) -> None: try: @@ -115,6 +122,15 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: provider.detach() def get_provider_status(self, provider: FeatureProvider) -> ProviderStatus: + # Only InternalHookProvider implementations (e.g. MultiProvider) manage + # their own status. For all other providers, use the registry's tracking. + # We check _is_internal_hook_provider (a concrete class attribute) in + # addition to isinstance, because runtime_checkable Protocols match any + # object that has the right method names — including Mock objects. + if getattr(provider, "_is_internal_hook_provider", False) and isinstance( + provider, InternalHookProvider + ): + return provider.get_status() return self._provider_status.get(provider, ProviderStatus.NOT_READY) def dispatch_event( diff --git a/openfeature/provider/multi_provider.py b/openfeature/provider/multi_provider.py new file mode 100644 index 00000000..0cc241c9 --- /dev/null +++ b/openfeature/provider/multi_provider.py @@ -0,0 +1,1262 @@ +from __future__ import annotations + +import asyncio +import contextvars +import logging +import threading +import typing +from collections.abc import Awaitable, Callable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass + +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, GeneralError, OpenFeatureError +from openfeature.flag_evaluation import ( + FlagEvaluationDetails, + FlagResolutionDetails, + FlagType, + FlagValueType, + Reason, +) +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.hook._hook_support import ( + after_all_hooks, + after_hooks, + before_hooks, + error_hooks, +) +from openfeature.provider import ( + AbstractProvider, + FeatureProvider, + Metadata, + ProviderStatus, +) + +__all__ = [ + "ComparisonStrategy", + "EvaluationStrategy", + "FirstMatchStrategy", + "FirstSuccessfulStrategy", + "MultiProvider", + "ProviderEntry", +] + +logger = logging.getLogger("openfeature") + +T = typing.TypeVar("T", bound=FlagValueType) +RunMode: typing.TypeAlias = typing.Literal["sequential", "parallel"] +ComparisonMismatchHandler: typing.TypeAlias = Callable[ + [str, Mapping[str, FlagResolutionDetails[FlagValueType]]], None +] + + +@dataclass(frozen=True) +class ProviderEntry: + provider: FeatureProvider + name: str | None = None + + +@dataclass(frozen=True) +class _ProviderEvaluation(typing.Generic[T]): + provider_name: str + provider: FeatureProvider + result: FlagResolutionDetails[T] + + +@dataclass(frozen=True) +class _ProviderHookRuntime: + flag_type: FlagType + client_metadata: typing.Any + hook_hints: HookHints + + +class EvaluationStrategy(typing.Protocol): + run_mode: RunMode + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: ... + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: ... + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: ... + + +def _is_success(result: FlagResolutionDetails[FlagValueType]) -> bool: + return result.error_code is None and result.reason != Reason.ERROR + + +def _validate_run_mode(run_mode: RunMode) -> RunMode: + if run_mode not in ("sequential", "parallel"): + raise ValueError(f"Unsupported run_mode '{run_mode}'") + return run_mode + + +def _format_result_error( + provider_name: str, result: FlagResolutionDetails[FlagValueType] +) -> str: + error_code = ( + result.error_code.value if result.error_code else ErrorCode.GENERAL.value + ) + error_message = result.error_message or "Unknown error" + return f"{provider_name}: {error_code} ({error_message})" + + +def _build_aggregated_error( + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + prefix: str, +) -> FlagResolutionDetails[FlagValueType]: + if not evaluations: + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=f"{prefix} for flag '{flag_key}': no providers returned a result", + ) + + errors_text = "; ".join( + _format_result_error(evaluation.provider_name, evaluation.result) + for evaluation in evaluations + ) + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=f"{prefix} for flag '{flag_key}': {errors_text}", + ) + + +class FirstMatchStrategy: + def __init__(self, run_mode: RunMode = "sequential") -> None: + self.run_mode = _validate_run_mode(run_mode) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return _is_success(result) + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return result.error_code == ErrorCode.FLAG_NOT_FOUND + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + for evaluation in evaluations: + if self.should_use_result( + flag_key, evaluation.provider_name, evaluation.result + ): + return evaluation.result + if not self.should_continue( + flag_key, evaluation.provider_name, evaluation.result + ): + return evaluation.result + if evaluations: + return evaluations[-1].result + return _build_aggregated_error( + flag_key, + default_value, + evaluations, + "Multi-provider evaluation failed", + ) + + +class FirstSuccessfulStrategy: + def __init__(self, run_mode: RunMode = "sequential") -> None: + self.run_mode = _validate_run_mode(run_mode) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + return _is_success(result) + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return True + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + for evaluation in evaluations: + if _is_success(evaluation.result): + return evaluation.result + return _build_aggregated_error( + flag_key, + default_value, + evaluations, + "All providers failed", + ) + + +class ComparisonStrategy: + run_mode: RunMode = "parallel" + + def __init__( + self, + fallback_provider: str | None = None, + on_mismatch: ComparisonMismatchHandler | None = None, + ) -> None: + self.fallback_provider = fallback_provider + self.on_mismatch = on_mismatch + + def validate_provider_names(self, provider_names: Sequence[str]) -> None: + if ( + self.fallback_provider is not None + and self.fallback_provider not in provider_names + ): + raise ValueError( + f"Fallback provider '{self.fallback_provider}' is not registered" + ) + + def should_use_result( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return False + + def should_continue( + self, + flag_key: str, + provider_name: str, + result: FlagResolutionDetails[FlagValueType], + ) -> bool: + del flag_key + del provider_name + del result + return True + + def determine_final_result( + self, + flag_key: str, + default_value: FlagValueType, + evaluations: list[_ProviderEvaluation[FlagValueType]], + ) -> FlagResolutionDetails[FlagValueType]: + failed_evaluations = [ + evaluation + for evaluation in evaluations + if not _is_success(evaluation.result) + ] + if failed_evaluations: + return _build_aggregated_error( + flag_key, + default_value, + failed_evaluations, + "Comparison strategy received provider errors", + ) + + # The first provider's result is the "final resolution" (used on agreement). + # The fallback provider's result is used on mismatch (per JS SDK reference). + final_evaluation = evaluations[0] + fallback_evaluation = self._select_fallback_evaluation(evaluations) + reference_value = final_evaluation.result.value + has_mismatch = any( + evaluation.result.value != reference_value for evaluation in evaluations + ) + if has_mismatch: + if self.on_mismatch is not None: + mismatch_results = { + evaluation.provider_name: evaluation.result + for evaluation in evaluations + } + try: + self.on_mismatch(flag_key, mismatch_results) + except Exception: + logger.exception( + "Comparison strategy mismatch callback failed for flag '%s'", + flag_key, + ) + return fallback_evaluation.result + return final_evaluation.result + + def _select_fallback_evaluation( + self, evaluations: list[_ProviderEvaluation[FlagValueType]] + ) -> _ProviderEvaluation[FlagValueType]: + if not evaluations: + raise ValueError("ComparisonStrategy requires at least one provider") + if self.fallback_provider is None: + return evaluations[0] + for evaluation in evaluations: + if evaluation.provider_name == self.fallback_provider: + return evaluation + raise ValueError( + f"Fallback provider '{self.fallback_provider}' is not registered" + ) + + +class MultiProvider(AbstractProvider): + _status_precedence: tuple[ProviderStatus, ...] = ( + ProviderStatus.FATAL, + ProviderStatus.NOT_READY, + ProviderStatus.ERROR, + ProviderStatus.STALE, + ProviderStatus.READY, + ) + + _is_internal_hook_provider: typing.ClassVar[bool] = True + + def __init__( + self, + providers: list[ProviderEntry], + strategy: EvaluationStrategy | None = None, + ) -> None: + super().__init__() + if not providers: + raise ValueError("At least one provider must be provided") + + self.strategy = strategy or FirstMatchStrategy() + self._registered_providers: list[tuple[str, FeatureProvider]] = [] + self._provider_names: dict[FeatureProvider, str] = {} + self._provider_statuses: dict[str, ProviderStatus] = {} + self._aggregate_status = ProviderStatus.NOT_READY + self._initialized = False + self._status_lock = threading.Lock() + self._hook_runtime: contextvars.ContextVar[_ProviderHookRuntime | None] = ( + contextvars.ContextVar( + f"multi_provider_hook_runtime:{id(self)}", + default=None, + ) + ) + self._register_providers(providers) + self._provider_statuses = { + provider_name: ProviderStatus.NOT_READY + for provider_name, _ in self._registered_providers + } + validate_provider_names = getattr( + self.strategy, "validate_provider_names", None + ) + if callable(validate_provider_names): + validate_provider_names( + [provider_name for provider_name, _ in self._registered_providers] + ) + + def uses_internal_provider_hooks(self) -> bool: + return True + + def set_internal_provider_hook_runtime( + self, + flag_type: FlagType, + client_metadata: typing.Any, + hook_hints: HookHints, + ) -> contextvars.Token[_ProviderHookRuntime | None]: + return self._hook_runtime.set( + _ProviderHookRuntime( + flag_type=flag_type, + client_metadata=client_metadata, + hook_hints=hook_hints, + ) + ) + + def reset_internal_provider_hook_runtime( + self, token: contextvars.Token[_ProviderHookRuntime | None] + ) -> None: + self._hook_runtime.reset(token) + + def get_status(self) -> ProviderStatus: + with self._status_lock: + return self._aggregate_status + + def _register_providers(self, providers: list[ProviderEntry]) -> None: + name_counts: dict[str, int] = {} + for entry in providers: + metadata_name = entry.provider.get_metadata().name or "provider" + name_counts[metadata_name] = name_counts.get(metadata_name, 0) + 1 + + used_names: set[str] = set() + name_indexes: dict[str, int] = {} + + for entry in providers: + metadata_name = entry.provider.get_metadata().name or "provider" + if entry.name: + if entry.name in used_names: + raise ValueError(f"Provider name '{entry.name}' is not unique") + provider_name = entry.name + elif name_counts[metadata_name] == 1 and metadata_name not in used_names: + provider_name = metadata_name + else: + while True: + name_indexes[metadata_name] = name_indexes.get(metadata_name, 0) + 1 + provider_name = f"{metadata_name}_{name_indexes[metadata_name]}" + if provider_name not in used_names: + break + + used_names.add(provider_name) + self._registered_providers.append((provider_name, entry.provider)) + self._provider_names[entry.provider] = provider_name + + def get_metadata(self) -> Metadata: + return Metadata(name="MultiProvider") + + def get_provider_hooks(self) -> list[Hook]: + return [] + + def attach( + self, + on_emit: Callable[[FeatureProvider, ProviderEvent, ProviderEventDetails], None], + ) -> None: + super().attach(on_emit) + for _, provider in self._registered_providers: + provider.attach(self._handle_provider_event) + + def detach(self) -> None: + for _, provider in self._registered_providers: + provider.detach() + super().detach() + + def initialize(self, evaluation_context: EvaluationContext) -> None: + def initialize_provider( + entry: tuple[str, FeatureProvider], + ) -> tuple[str, Exception | None]: + provider_name, provider = entry + try: + provider.initialize(evaluation_context) + return provider_name, None + except Exception as err: + return provider_name, err + + with ThreadPoolExecutor( + max_workers=len(self._registered_providers) + ) as executor: + init_results = list( + executor.map(initialize_provider, self._registered_providers) + ) + + error_messages: list[str] = [] + event_details = ProviderEventDetails() + for provider_name, err in init_results: + if err is None: + self._mark_provider_ready(provider_name) + continue + provider_status = self._status_from_exception(err) + self._set_provider_status(provider_name, provider_status) + error_messages.append( + f"Provider '{provider_name}' initialization failed: {self._error_message_from_exception(err)}" + ) + event_details = self._details_from_exception(err, provider_name) + + self._initialized = True + self._refresh_aggregate_status(event_details, force=True) + + if error_messages: + raise GeneralError( + f"Multi-provider initialization failed: {'; '.join(error_messages)}" + ) + + def shutdown(self) -> None: + for _, provider in self._registered_providers: + provider.detach() + + def shutdown_provider(entry: tuple[str, FeatureProvider]) -> None: + provider_name, provider = entry + try: + provider.shutdown() + except Exception: + logger.exception("Provider '%s' shutdown failed", provider_name) + + with ThreadPoolExecutor( + max_workers=len(self._registered_providers) + ) as executor: + list(executor.map(shutdown_provider, self._registered_providers)) + + with self._status_lock: + self._provider_statuses = { + provider_name: ProviderStatus.NOT_READY + for provider_name, _ in self._registered_providers + } + self._aggregate_status = ProviderStatus.NOT_READY + + def _handle_provider_event( + self, + provider: FeatureProvider, + event: ProviderEvent, + details: ProviderEventDetails, + ) -> None: + provider_name = self._provider_names.get(provider) + if provider_name is None: + return + if event == ProviderEvent.PROVIDER_CONFIGURATION_CHANGED: + self.emit( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + self._with_provider_metadata(details, provider_name), + ) + return + if event == ProviderEvent.PROVIDER_READY: + self._set_provider_status(provider_name, ProviderStatus.READY) + elif event == ProviderEvent.PROVIDER_STALE: + self._set_provider_status(provider_name, ProviderStatus.STALE) + elif event == ProviderEvent.PROVIDER_ERROR: + self._set_provider_status( + provider_name, + self._status_from_event_details(details), + ) + self._refresh_aggregate_status( + self._with_provider_metadata(details, provider_name) + ) + + def _set_provider_status( + self, provider_name: str, provider_status: ProviderStatus + ) -> None: + with self._status_lock: + self._provider_statuses[provider_name] = provider_status + + def _mark_provider_ready(self, provider_name: str) -> None: + with self._status_lock: + if self._provider_statuses.get(provider_name) == ProviderStatus.NOT_READY: + self._provider_statuses[provider_name] = ProviderStatus.READY + + def _should_evaluate_provider(self, provider_name: str) -> bool: + """Check if a provider should be evaluated based on its status. + + Providers with NOT_READY or FATAL status are skipped, matching the + JS SDK reference behavior (shouldEvaluateThisProvider). + + Before initialize() has been called, all providers are eligible since + status tracking is not yet meaningful. + """ + if not self._initialized: + return True + with self._status_lock: + status = self._provider_statuses.get( + provider_name, ProviderStatus.NOT_READY + ) + return status not in (ProviderStatus.NOT_READY, ProviderStatus.FATAL) + + def _calculate_aggregate_status(self) -> ProviderStatus: + statuses = tuple(self._provider_statuses.values()) + if not statuses: + return ProviderStatus.NOT_READY + for status in self._status_precedence: + if status in statuses: + return status + return ProviderStatus.NOT_READY + + def _refresh_aggregate_status( + self, + details: ProviderEventDetails, + force: bool = False, + ) -> None: + event_to_emit: ProviderEvent | None = None + event_details = details + with self._status_lock: + previous_status = self._aggregate_status + aggregate_status = self._calculate_aggregate_status() + if previous_status == aggregate_status and not force: + return + self._aggregate_status = aggregate_status + event_to_emit = self._event_from_status(aggregate_status) + event_details = self._details_for_status(aggregate_status, details) + if event_to_emit is not None: + self.emit(event_to_emit, event_details) + + def _event_from_status( + self, provider_status: ProviderStatus + ) -> ProviderEvent | None: + if provider_status == ProviderStatus.READY: + return ProviderEvent.PROVIDER_READY + if provider_status == ProviderStatus.STALE: + return ProviderEvent.PROVIDER_STALE + if provider_status in (ProviderStatus.ERROR, ProviderStatus.FATAL): + return ProviderEvent.PROVIDER_ERROR + return None + + def _details_for_status( + self, provider_status: ProviderStatus, details: ProviderEventDetails + ) -> ProviderEventDetails: + error_code = details.error_code + if provider_status == ProviderStatus.FATAL: + error_code = ErrorCode.PROVIDER_FATAL + elif provider_status == ProviderStatus.ERROR and error_code is None: + error_code = ErrorCode.GENERAL + return ProviderEventDetails( + flags_changed=details.flags_changed, + message=details.message, + error_code=error_code, + metadata=dict(details.metadata), + ) + + def _with_provider_metadata( + self, details: ProviderEventDetails, provider_name: str + ) -> ProviderEventDetails: + metadata = dict(details.metadata) + metadata["provider_name"] = provider_name + return ProviderEventDetails( + flags_changed=details.flags_changed, + message=details.message, + error_code=details.error_code, + metadata=metadata, + ) + + def _status_from_event_details( + self, details: ProviderEventDetails + ) -> ProviderStatus: + if details.error_code == ErrorCode.PROVIDER_FATAL: + return ProviderStatus.FATAL + return ProviderStatus.ERROR + + def _status_from_exception(self, err: Exception) -> ProviderStatus: + if ( + isinstance(err, OpenFeatureError) + and err.error_code == ErrorCode.PROVIDER_FATAL + ): + return ProviderStatus.FATAL + return ProviderStatus.ERROR + + def _details_from_exception( + self, err: Exception, provider_name: str + ) -> ProviderEventDetails: + error_code = ( + err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL + ) + error_message = self._error_message_from_exception(err) + return ProviderEventDetails( + message=f"Provider '{provider_name}' failed: {error_message}", + error_code=error_code, + metadata={"provider_name": provider_name}, + ) + + def _error_message_from_exception(self, err: Exception) -> str: + if isinstance(err, OpenFeatureError) and err.error_message: + return err.error_message + return str(err) + + def _resolution_from_exception( + self, default_value: T, err: Exception + ) -> FlagResolutionDetails[T]: + error_code = ( + err.error_code if isinstance(err, OpenFeatureError) else ErrorCode.GENERAL + ) + error_message = self._error_message_from_exception(err) + return FlagResolutionDetails( + value=default_value, + reason=Reason.ERROR, + error_code=error_code, + error_message=error_message, + ) + + def _create_provider_hook_contexts( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: FlagValueType, + evaluation_context: EvaluationContext, + client_metadata: typing.Any, + ) -> list[tuple[Hook, HookContext]]: + provider_metadata = provider.get_metadata() + return [ + ( + hook, + HookContext( + flag_key=flag_key, + flag_type=flag_type, + default_value=default_value, + evaluation_context=evaluation_context, + client_metadata=client_metadata, + provider_metadata=provider_metadata, + hook_data={}, + ), + ) + for hook in provider.get_provider_hooks() + ] + + def _evaluate_provider_sync( # noqa: PLR0913 + self, + provider_name: str, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + FlagResolutionDetails[T], + ], + ) -> _ProviderEvaluation[T]: + runtime = self._hook_runtime.get() + if runtime is None or not provider.get_provider_hooks(): + try: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolve_fn( + provider, flag_key, default_value, evaluation_context + ), + ) + except Exception as err: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + + provider_context = evaluation_context or EvaluationContext() + hook_contexts = self._create_provider_hook_contexts( + provider, + flag_type, + flag_key, + default_value, + provider_context, + runtime.client_metadata, + ) + reversed_hook_contexts = list(reversed(hook_contexts)) + flag_evaluation = FlagEvaluationDetails(flag_key=flag_key, value=default_value) + try: + before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) + resolved_context = provider_context.merge(before_context) + resolution = resolve_fn(provider, flag_key, default_value, resolved_context) + flag_evaluation = resolution.to_flag_evaluation_details(flag_key) + if err := flag_evaluation.get_exception(): + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + after_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + except Exception as err: + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + finally: + after_all_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + + async def _evaluate_provider_async( # noqa: PLR0913 + self, + provider_name: str, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + Awaitable[FlagResolutionDetails[T]], + ], + ) -> _ProviderEvaluation[T]: + runtime = self._hook_runtime.get() + if runtime is None or not provider.get_provider_hooks(): + try: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=await resolve_fn( + provider, flag_key, default_value, evaluation_context + ), + ) + except Exception as err: + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + + provider_context = evaluation_context or EvaluationContext() + hook_contexts = self._create_provider_hook_contexts( + provider, + flag_type, + flag_key, + default_value, + provider_context, + runtime.client_metadata, + ) + reversed_hook_contexts = list(reversed(hook_contexts)) + flag_evaluation = FlagEvaluationDetails(flag_key=flag_key, value=default_value) + try: + before_context = before_hooks(flag_type, hook_contexts, runtime.hook_hints) + resolved_context = provider_context.merge(before_context) + resolution = await resolve_fn( + provider, flag_key, default_value, resolved_context + ) + flag_evaluation = resolution.to_flag_evaluation_details(flag_key) + if err := flag_evaluation.get_exception(): + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + after_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=resolution, + ) + except Exception as err: + error_hooks( + flag_type, + err, + reversed_hook_contexts, + runtime.hook_hints, + ) + return _ProviderEvaluation( + provider_name=provider_name, + provider=provider, + result=self._resolution_from_exception(default_value, err), + ) + finally: + after_all_hooks( + flag_type, + flag_evaluation, + reversed_hook_contexts, + runtime.hook_hints, + ) + + def _evaluate_with_providers( + self, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + FlagResolutionDetails[T], + ], + ) -> FlagResolutionDetails[T]: + eligible_providers = [ + (name, provider) + for name, provider in self._registered_providers + if self._should_evaluate_provider(name) + ] + + evaluations: list[_ProviderEvaluation[T]] = [] + + if self.strategy.run_mode == "parallel": + # Each worker thread gets its own copy of the current context so + # that ContextVars (e.g. _hook_runtime) are propagated correctly. + # ThreadPoolExecutor does not automatically copy context on + # Python < 3.12, and a single Context.run() is not reentrant. + with ThreadPoolExecutor( + max_workers=len(eligible_providers) or 1 + ) as executor: + futures = [ + executor.submit( + contextvars.copy_context().run, + self._evaluate_provider_sync, + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + for provider_name, provider in eligible_providers + ] + evaluations = [future.result() for future in futures] + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + for provider_name, provider in eligible_providers: + evaluation = self._evaluate_provider_sync( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + evaluations.append(evaluation) + if self.strategy.should_use_result( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + return evaluation.result + if not self.strategy.should_continue( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + break + + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + async def _evaluate_with_providers_async( + self, + flag_type: FlagType, + flag_key: str, + default_value: T, + evaluation_context: EvaluationContext | None, + resolve_fn: Callable[ + [FeatureProvider, str, T, EvaluationContext | None], + Awaitable[FlagResolutionDetails[T]], + ], + ) -> FlagResolutionDetails[T]: + eligible_providers = [ + (name, provider) + for name, provider in self._registered_providers + if self._should_evaluate_provider(name) + ] + + evaluations: list[_ProviderEvaluation[T]] = [] + + if self.strategy.run_mode == "parallel": + tasks = [ + asyncio.create_task( + self._evaluate_provider_async( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + ) + for provider_name, provider in eligible_providers + ] + evaluations = list(await asyncio.gather(*tasks)) + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + for provider_name, provider in eligible_providers: + evaluation = await self._evaluate_provider_async( + provider_name, + provider, + flag_type, + flag_key, + default_value, + evaluation_context, + resolve_fn, + ) + evaluations.append(evaluation) + if self.strategy.should_use_result( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + return evaluation.result + if not self.strategy.should_continue( + flag_key, + provider_name, + typing.cast(FlagResolutionDetails[FlagValueType], evaluation.result), + ): + break + + return typing.cast( + FlagResolutionDetails[T], + self.strategy.determine_final_result( + flag_key, + default_value, + typing.cast( + list[_ProviderEvaluation[FlagValueType]], + evaluations, + ), + ), + ) + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + return self._evaluate_with_providers( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_boolean_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + return await self._evaluate_with_providers_async( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_boolean_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + return self._evaluate_with_providers( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_string_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + return await self._evaluate_with_providers_async( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_string_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + return self._evaluate_with_providers( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_integer_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + return await self._evaluate_with_providers_async( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_integer_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + return self._evaluate_with_providers( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_float_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + return await self._evaluate_with_providers_async( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_float_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + def resolve_object_details( + self, + flag_key: str, + default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return self._evaluate_with_providers( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_object_details( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) + + async def resolve_object_details_async( + self, + flag_key: str, + default_value: Sequence[FlagValueType] | Mapping[str, FlagValueType], + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[Sequence[FlagValueType] | Mapping[str, FlagValueType]]: + return await self._evaluate_with_providers_async( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + lambda provider, resolved_flag_key, resolved_default_value, resolved_context: ( + provider.resolve_object_details_async( + resolved_flag_key, + resolved_default_value, + resolved_context, + ) + ), + ) diff --git a/tests/test_multi_provider.py b/tests/test_multi_provider.py new file mode 100644 index 00000000..a7959666 --- /dev/null +++ b/tests/test_multi_provider.py @@ -0,0 +1,619 @@ +import asyncio +import threading +from unittest.mock import MagicMock + +import pytest + +from openfeature import api +from openfeature.evaluation_context import EvaluationContext +from openfeature.event import ProviderEvent, ProviderEventDetails +from openfeature.exception import ErrorCode, GeneralError +from openfeature.flag_evaluation import ( + FlagEvaluationDetails, + FlagResolutionDetails, + Reason, +) +from openfeature.hook import Hook, HookContext, HookHints +from openfeature.provider import ( + AbstractProvider, + ComparisonStrategy, + FirstMatchStrategy, + FirstSuccessfulStrategy, + Metadata, + MultiProvider, + ProviderEntry, + ProviderStatus, +) + + +class BooleanProvider(AbstractProvider): + def __init__( + self, + name: str, + boolean_result: FlagResolutionDetails[bool] | None = None, + boolean_exception: Exception | None = None, + hook_list: list[Hook] | None = None, + sync_blocker: "SyncBlocker | None" = None, + async_blocker: "AsyncBlocker | None" = None, + ) -> None: + super().__init__() + self.name = name + self.booleanResult = boolean_result + self.booleanException = boolean_exception + self.hookList = hook_list or [] + self.sync_blocker = sync_blocker + self.async_blocker = async_blocker + self.resolveCount = 0 + self.seenContexts: list[dict[str, object]] = [] + + def get_metadata(self) -> Metadata: + return Metadata(name=self.name) + + def get_provider_hooks(self) -> list[Hook]: + return self.hookList + + def resolve_boolean_details( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + del flag_key + self.resolveCount += 1 + self.seenContexts.append( + dict((evaluation_context or EvaluationContext()).attributes) + ) + if self.sync_blocker is not None: + self.sync_blocker.wait() + if self.booleanException is not None: + raise self.booleanException + if self.booleanResult is not None: + return self.booleanResult + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[bool]: + del flag_key + self.resolveCount += 1 + self.seenContexts.append( + dict((evaluation_context or EvaluationContext()).attributes) + ) + if self.async_blocker is not None: + await self.async_blocker.wait() + if self.booleanException is not None: + raise self.booleanException + if self.booleanResult is not None: + return self.booleanResult + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_string_details( + self, + flag_key: str, + default_value: str, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[str]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_integer_details( + self, + flag_key: str, + default_value: int, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[int]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_float_details( + self, + flag_key: str, + default_value: float, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[float]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + def resolve_object_details( + self, + flag_key: str, + default_value: dict | list, + evaluation_context: EvaluationContext | None = None, + ) -> FlagResolutionDetails[dict | list]: + del flag_key + del evaluation_context + return FlagResolutionDetails(value=default_value, reason=Reason.DEFAULT) + + +class RecordingHook(Hook): + def __init__(self, hook_name: str) -> None: + self.hookName = hook_name + self.events: list[str] = [] + + def before( + self, hook_context: HookContext, hints: HookHints + ) -> EvaluationContext | None: + del hook_context + del hints + self.events.append("before") + return EvaluationContext(attributes={"hookOwner": self.hookName}) + + def after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[object], + hints: HookHints, + ) -> None: + del hook_context + del details + del hints + self.events.append("after") + + def error( + self, hook_context: HookContext, exception: Exception, hints: HookHints + ) -> None: + del hook_context + del exception + del hints + self.events.append("error") + + def finally_after( + self, + hook_context: HookContext, + details: FlagEvaluationDetails[object], + hints: HookHints, + ) -> None: + del hook_context + del details + del hints + self.events.append("finally") + + +class SyncBlocker: + def __init__(self, expected_count: int) -> None: + self.expectedCount = expected_count + self.enteredCount = 0 + self.enteredEvent = threading.Event() + self.releaseEvent = threading.Event() + self.lock = threading.Lock() + + def wait(self) -> None: + with self.lock: + self.enteredCount += 1 + if self.enteredCount == self.expectedCount: + self.enteredEvent.set() + assert self.releaseEvent.wait(timeout=2) + + +class AsyncBlocker: + def __init__(self, expected_count: int) -> None: + self.expectedCount = expected_count + self.enteredCount = 0 + self.enteredEvent = asyncio.Event() + self.releaseEvent = asyncio.Event() + self.lock = asyncio.Lock() + + async def wait(self) -> None: + async with self.lock: + self.enteredCount += 1 + if self.enteredCount == self.expectedCount: + self.enteredEvent.set() + await asyncio.wait_for(self.releaseEvent.wait(), timeout=2) + + +def test_multi_provider_requires_at_least_one_provider(): + with pytest.raises(ValueError, match="At least one provider must be provided"): + MultiProvider([]) + + +def test_multi_provider_rejects_duplicate_explicit_names(): + first_provider = BooleanProvider("provider") + second_provider = BooleanProvider("provider") + + with pytest.raises(ValueError, match="Provider name 'duplicate' is not unique"): + MultiProvider( + [ + ProviderEntry(first_provider, name="duplicate"), + ProviderEntry(second_provider, name="duplicate"), + ] + ) + + +def test_comparison_strategy_rejects_unknown_fallback_provider(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + + with pytest.raises( + ValueError, match="Fallback provider 'missing' is not registered" + ): + MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy(fallback_provider="missing"), + ) + + +def test_first_match_uses_fallback_after_flag_not_found(): + missing_result = FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.FLAG_NOT_FOUND, + error_message="missing", + ) + first_provider = BooleanProvider("first", boolean_result=missing_result) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_first_match_stops_on_non_flag_not_found_error(): + error_result = FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message="boom", + ) + first_provider = BooleanProvider("first", boolean_result=error_result) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert second_provider.resolveCount == 0 + + +def test_first_successful_skips_general_errors(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("broken")) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_first_successful_aggregates_errors_when_all_providers_fail(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) + second_provider = BooleanProvider( + "second", boolean_exception=GeneralError("second") + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert "first: GENERAL (first)" in result.error_message + assert "second: GENERAL (second)" in result.error_message + + +def test_comparison_strategy_returns_fallback_value_and_calls_on_mismatch(): + mismatch_spy = MagicMock() + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy( + fallback_provider="second", + on_mismatch=mismatch_spy, + ), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.value is True + mismatch_spy.assert_called_once() + + +def test_comparison_strategy_aggregates_provider_errors(): + first_provider = BooleanProvider("first", boolean_exception=GeneralError("first")) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=ComparisonStrategy(), + ) + + result = multi_provider.resolve_boolean_details("flagKey", False) + + assert result.error_code == ErrorCode.GENERAL + assert "first: GENERAL (first)" in result.error_message + + +def test_multi_provider_runs_sync_parallel_evaluation(): + sync_blocker = SyncBlocker(expected_count=2) + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + sync_blocker=sync_blocker, + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + sync_blocker=sync_blocker, + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(run_mode="parallel"), + ) + + result_holder: list[FlagResolutionDetails[bool]] = [] + + def evaluate() -> None: + result_holder.append(multi_provider.resolve_boolean_details("flagKey", False)) + + worker_thread = threading.Thread(target=evaluate) + worker_thread.start() + + assert sync_blocker.enteredEvent.wait(timeout=2) + sync_blocker.releaseEvent.set() + worker_thread.join(timeout=2) + + assert result_holder[0].value is False + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +@pytest.mark.asyncio +async def test_multi_provider_runs_async_parallel_evaluation(): + async_blocker = AsyncBlocker(expected_count=2) + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + async_blocker=async_blocker, + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + async_blocker=async_blocker, + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstSuccessfulStrategy(run_mode="parallel"), + ) + + evaluation_task = asyncio.create_task( + multi_provider.resolve_boolean_details_async("flagKey", False) + ) + + await asyncio.wait_for(async_blocker.enteredEvent.wait(), timeout=2) + async_blocker.releaseEvent.set() + result = await asyncio.wait_for(evaluation_task, timeout=2) + + assert result.value is False + assert first_provider.resolveCount == 1 + assert second_provider.resolveCount == 1 + + +def test_multi_provider_isolates_provider_hooks_and_runs_lifecycle(): + first_hook = RecordingHook("first") + second_hook = RecordingHook("second") + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails( + value=False, + reason=Reason.ERROR, + error_code=ErrorCode.FLAG_NOT_FOUND, + error_message="missing", + ), + hook_list=[first_hook], + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + hook_list=[second_hook], + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + api.set_provider(multi_provider) + client = api.get_client() + result = client.get_boolean_details( + "flagKey", + False, + evaluation_context=EvaluationContext(attributes={"base": "value"}), + ) + + assert result.value is True + assert first_hook.events == ["before", "error", "finally"] + assert second_hook.events == ["before", "after", "finally"] + assert first_provider.seenContexts[0]["base"] == "value" + assert first_provider.seenContexts[0]["hookOwner"] == "first" + assert second_provider.seenContexts[0]["base"] == "value" + assert second_provider.seenContexts[0]["hookOwner"] == "second" + + +def test_multi_provider_does_not_run_unused_provider_hooks(): + first_hook = RecordingHook("first") + second_hook = RecordingHook("second") + first_provider = BooleanProvider( + "first", + boolean_result=FlagResolutionDetails(value=True, reason=Reason.STATIC), + hook_list=[first_hook], + ) + second_provider = BooleanProvider( + "second", + boolean_result=FlagResolutionDetails(value=False, reason=Reason.STATIC), + hook_list=[second_hook], + ) + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ], + strategy=FirstMatchStrategy(), + ) + + api.set_provider(multi_provider) + client = api.get_client() + result = client.get_boolean_details("flagKey", False) + + assert result.value is True + assert first_hook.events == ["before", "after", "finally"] + assert second_hook.events == [] + + +def test_multi_provider_aggregates_status_and_deduplicates_events(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + spy = MagicMock() + client.add_handler(ProviderEvent.PROVIDER_READY, spy.provider_ready) + client.add_handler(ProviderEvent.PROVIDER_ERROR, spy.provider_error) + client.add_handler(ProviderEvent.PROVIDER_STALE, spy.provider_stale) + spy.provider_ready.reset_mock() + + first_provider.emit_provider_stale(ProviderEventDetails(message="stale")) + assert client.get_provider_status() == ProviderStatus.STALE + assert spy.provider_stale.call_count == 1 + + second_provider.emit_provider_stale(ProviderEventDetails(message="still stale")) + assert client.get_provider_status() == ProviderStatus.STALE + assert spy.provider_stale.call_count == 1 + + first_provider.emit_provider_error( + ProviderEventDetails(error_code=ErrorCode.GENERAL, message="error") + ) + assert client.get_provider_status() == ProviderStatus.ERROR + assert spy.provider_error.call_count == 1 + + second_provider.emit_provider_error( + ProviderEventDetails(error_code=ErrorCode.PROVIDER_FATAL, message="fatal") + ) + assert client.get_provider_status() == ProviderStatus.FATAL + assert spy.provider_error.call_count == 2 + + second_provider.emit_provider_ready(ProviderEventDetails()) + assert client.get_provider_status() == ProviderStatus.ERROR + assert spy.provider_error.call_count == 3 + + first_provider.emit_provider_ready(ProviderEventDetails()) + assert client.get_provider_status() == ProviderStatus.READY + assert spy.provider_ready.call_count == 1 + + +def test_multi_provider_forwards_configuration_changed_events(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + spy = MagicMock() + client.add_handler( + ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, + spy.provider_configuration_changed, + ) + + first_provider.emit_provider_configuration_changed( + ProviderEventDetails(message="one") + ) + second_provider.emit_provider_configuration_changed( + ProviderEventDetails(message="two") + ) + + assert spy.provider_configuration_changed.call_count == 2 + + +def test_multi_provider_reports_not_ready_after_shutdown(): + first_provider = BooleanProvider("first") + second_provider = BooleanProvider("second") + multi_provider = MultiProvider( + [ + ProviderEntry(first_provider, name="first"), + ProviderEntry(second_provider, name="second"), + ] + ) + + api.set_provider(multi_provider) + client = api.get_client() + + api.shutdown() + + assert client.get_provider_status() == ProviderStatus.NOT_READY diff --git a/uv.lock b/uv.lock index 5b5ed9ab..cd187edc 100644 --- a/uv.lock +++ b/uv.lock @@ -259,12 +259,12 @@ dev = [ [package.metadata.requires-dev] dev = [ - { name = "behave" }, - { name = "coverage", extras = ["toml"], specifier = ">=6.5" }, - { name = "poethepoet", specifier = ">=0.40.0" }, + { name = "behave", specifier = ">=1.3.0,<2.0.0" }, + { name = "coverage", extras = ["toml"], specifier = ">=7.10.0,<8.0.0" }, + { name = "poethepoet", specifier = ">=0.40.0,<1.0.0" }, { name = "pre-commit" }, - { name = "pytest", specifier = ">=9.0.0" }, - { name = "pytest-asyncio", specifier = ">=1.3.0" }, + { name = "pytest", specifier = ">=9.0.0,<10.0.0" }, + { name = "pytest-asyncio", specifier = ">=1.3.0,<2.0.0" }, ] [[package]]