diff --git a/AGENTS.md b/AGENTS.md index cb4732c..21fc361 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -61,4 +61,16 @@ Strong success criteria let you loop independently. Weak criteria ("make it work --- -**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes. \ No newline at end of file +**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes. + +## 5. Specifics for this project. + +All code interactions are done through uv. Please use `uv run` to run python. + +### Running ruff, mypy, and pytest + +All three tools are pre-configured in `pyproject.toml` and can be run without extra arguments: + +- **ruff**: `uv run ruff check` (excludes `docs/`, configured in `[tool.ruff]`) +- **mypy**: `uv run mypy` (targets `src/`, configured in `[tool.mypy]`) +- **pytest**: `uv run pytest` (targets `test/`, configured in `[tool.pytest.ini_options]`) diff --git a/protocol_corrections_architecture.md b/protocol_corrections_architecture.md new file mode 100644 index 0000000..70726b9 --- /dev/null +++ b/protocol_corrections_architecture.md @@ -0,0 +1,619 @@ +# Protocol Corrections Architecture + +> **For implementors:** Read this file alongside +> `CQEDToolbox/src/cqedtoolbox/protocols/operations/single_qubit/res_spec.py`, +> which is the canonical reference implementation. The doc describes the API; +> `res_spec.py` shows a complete, real migration including `CorrectionParameter` +> subclasses with `_qick_getter`/`_qick_setter` bodies, `Correction` subclasses +> with full state tracking, and `__init__` wiring. + +## Background + +The protocol system (`src/labcore/protocols/base.py`) orchestrates multi-step lab +measurements. Each `ProtocolOperation` runs a fixed workflow: + +``` +measure() → load_data() → analyze() → evaluate() → correct() +``` + +Before this change, `evaluate()` did two things: assessed results **and** mutated +hardware parameters. The retry mechanism was blunt — just re-run the same operation +with the same settings. + +## What Changed + +### 1. Separated concerns across `evaluate()` and `correct()` + +| Method | Responsibility | +|---|---| +| `evaluate()` | **Pure assessment.** Returns named check results + overall status. No side effects. | +| `correct()` | **Only place parameters are changed.** Applies found values on success, corrective actions on retry. | + +`correct()` is always called inside `execute()` after `evaluate()`. Its return value +(an `EvaluateResult`) is what the protocol executor sees. + +### 2. New types + +#### `CheckResult` +```python +@dataclass +class CheckResult: + name: str # e.g. "snr_check", "peak_exists" + passed: bool + description: str # e.g. "SNR=1.5, threshold=2.0" +``` + +#### `EvaluateResult` +```python +@dataclass +class EvaluateResult: + status: OperationStatus # SUCCESS / RETRY / FAILURE + checks: list[CheckResult] = [] # named check outcomes +``` +Return type for both `evaluate()` and `correct()`. + +#### `Correction` +```python +class Correction: + name: str = "" + description: str = "" + triggered_by: str = "" # name of the CheckResult that triggers this + + def can_apply(self) -> bool: + """Return False when strategy is exhausted → correct() escalates to FAILURE.""" + return True + + def apply(self) -> None: + """Apply the correction in-place. Called before the next retry attempt.""" + raise NotImplementedError + + def report_output(self) -> str: + """Optional. Return a human-readable description of what apply() just changed. + Called by correct() after apply() and appended to the correction log line.""" + return "" +``` + +Subclass this for each corrective strategy. One **instance per operation**, created +in `__init__` and reused across retries so stateful strategies (e.g. stepping +through a frequency list) work correctly. + +`report_output()` is called after `apply()` and appended inline to the correction +log entry as `| **Change:** `. Implement it to surface the actual values +that changed (before → after), which makes the HTML report self-explanatory. + +**Example:** +```python +class FrequencySweepCorrection(Correction): + name = "scan_next_frequency_window" + description = "Step through candidate frequency windows until a peak is found" + triggered_by = "peak_exists" + + def __init__(self, freq_center_param, windows: list[float]): + self.freq_center_param = freq_center_param + self.windows = windows + self._idx = 0 + self._last_change: str = "" + + def can_apply(self) -> bool: + return self._idx < len(self.windows) + + def apply(self) -> None: + old = self.freq_center_param() + new = self.windows[self._idx] + self.freq_center_param(new) + self._idx += 1 + self._last_change = f"center: {old * 1e-9:.4f} → {new * 1e-9:.4f} GHz" + + def report_output(self) -> str: + return self._last_change +``` + +#### `CorrectionParameter` +```python +class CorrectionParameter(ProtocolParameterBase): + is_correction: ClassVar[bool] = True + # Skips hardware params validation in __post_init__ + # Otherwise identical to ProtocolParameterBase — same callable interface, + # same platform-specific getter/setter pattern for unit differences. +``` + +Used for parameters that control correction strategy (window sizes, step counts, +noise tolerances) rather than actual hardware state. Subclass exactly like +`ProtocolParameterBase`. + +**Important:** `CorrectionParameter` subclasses **must** be decorated with `@dataclass` +so that `name` and `description` fields with `init=False` defaults are resolved +correctly. Without `@dataclass` the fields are not processed and the class will not +instantiate correctly. + +```python +@dataclass # required +class MyThreshold(CorrectionParameter): + name: str = field(default="my_threshold", init=False) + description: str = field(default="...", init=False) + + def _qick_getter(self): return self.params.corrections.my_op.threshold() + def _qick_setter(self, v): self.params.corrections.my_op.threshold(v) +``` + +--- + +## Registration API + +Operations can use a registration-based path (covers most cases) or override +`evaluate()` / `correct()` directly for complex logic. + +### Registering checks + +```python +# In __init__: +self._register_check( + name="snr_check", + check_func=self._check_snr, + correction=self._snr_correction, # single Correction, or list[Correction], or None +) +self._register_check( + name="peak_exists", + check_func=self._check_peak, + correction=[self._freq_correction, self._fallback_correction], # fallback chain +) +``` + +The `correction` argument accepts: +- `None` — no correction; failed check → immediate FAILURE +- A single `Correction` instance — normalized to a list of one internally +- A `list[Correction]` — tried in order on each retry; first where `can_apply()` is True is used + +**Default `evaluate()`** runs all registered checks: +- All pass → `EvaluateResult(SUCCESS, checks)` +- Any fail → `EvaluateResult(RETRY, checks)` + +**Default `correct()`**: +- Appends a check summary table to `report_output` +- If the operation has any `figure_paths`, appends `figure_paths[-1]` to `report_output` immediately after the table (so the plot appears below the check results in the HTML report). No override needed — this is automatic. +- On RETRY: for each failed check, finds the **first** registered `Correction` where `can_apply()` is True: + - No corrections registered → returns `EvaluateResult(FAILURE, checks)` + - All corrections exhausted → returns `EvaluateResult(FAILURE, checks)` + - Otherwise → calls `apply()`, then `report_output()`, and logs both to `report_output` +- On SUCCESS: applies all registered success updates (see below) +- On FAILURE: no-op + +### Registering success updates + +```python +# In __init__: +self._register_success_update( + param=self.frequency, + value_func=lambda: self.peak_freq, # called lazily at correct() time +) +``` + +On SUCCESS, `correct()` calls each registered `value_func`, writes the result to `param`, +records a `ParamImprovement`, and appends a line to `report_output`. Multiple updates are +applied in registration order. + +`value_func` is called lazily so it can safely reference attributes set during `analyze()` +(e.g. `self.fit_result`). + +`self.improvements` is reset to `[]` at the start of each `execute()` call, so it always +reflects only the current attempt. + +### Registering correction parameters + +```python +# In __init__: +self._register_correction_params( + window_size=WindowSizeParam(params), + max_steps=MaxStepsParam(params), +) +``` + +Stored in `self.correction_params`. Excluded from `verify_all_parameters()` (no +hardware to check). Accessible as attributes: `self.window_size()`. + +--- + +## Complete operation pattern + +```python +class FindResonatorOperation(ProtocolOperation): + SNR_THRESHOLD = 2.0 + + def __init__(self, params=None): + super().__init__() + self._register_inputs(center=ResonatorCenter(params)) + self._register_outputs(frequency=ResonatorFrequency(params)) + + # Correction strategies — persist across retries + self._freq_sweep = FrequencySweepCorrection( + freq_center_param=self.center, + windows=[5.0e9, 5.5e9, 6.0e9, 6.5e9], + ) + self._fallback_sweep = WideSweepCorrection(self.center) + self._increase_avg = IncreaseAveragingCorrection(self.averages) + + # Register checks → corrections (list = fallback chain) + self._register_check("peak_exists", self._check_peak, + [self._freq_sweep, self._fallback_sweep]) + self._register_check("snr_check", self._check_snr, self._increase_avg) + + # On success, write the found frequency automatically + self._register_success_update(self.frequency, lambda: self.peak_freq) + + # Correction strategy parameters (platform-aware knobs) + self._register_correction_params( + window_size=FrequencyWindowSize(params), + ) + + self.peak_freq: float | None = None + self.snr: float | None = None + + # --- platform-specific measurement (implement for QICK / OPX) --- + def _measure_dummy(self) -> Path: ... + def _load_data_dummy(self) -> None: ... + + def analyze(self) -> None: + # detect peaks, compute SNR — no param mutations here + ... + + # --- checks (pure assessment) --- + def _check_peak(self) -> CheckResult: + passed = self.peak_freq is not None + return CheckResult("peak_exists", passed, + f"{'peak at ' + str(self.peak_freq) if passed else 'no peak detected'}") + + def _check_snr(self) -> CheckResult: + snr = self.snr or 0.0 + passed = snr >= self.SNR_THRESHOLD + return CheckResult("snr_check", passed, + f"SNR={snr:.2f}, threshold={self.SNR_THRESHOLD}") + + # No correct() override needed — base class handles: + # RETRY → applies first applicable correction per failed check + # SUCCESS → writes self.frequency via _register_success_update + # + # Override correct() only for custom report messages or additional logic. +``` + +If extra reporting is needed on SUCCESS, override `correct()` and call `super()` first: + +```python +def correct(self, result: EvaluateResult) -> EvaluateResult: + result = super().correct(result) # check table + corrections + success updates + if result.status == OperationStatus.SUCCESS: + self.report_output.append( + f"Resonator found at {self.peak_freq:.3e} Hz (SNR={self.snr:.2f})\n" + ) + return result +``` + +### Custom report layouts with multiple figures + +The default `correct()` auto-appends `figure_paths[-1]` immediately after the check table. +This works for simple operations with one plot. For operations that produce several named +figures (e.g. colorbar, per-trace plots, summary plot) and need a specific report order, +pop the named figures out of `figure_paths` **before** calling `super()`, then clear the +list so the auto-append has nothing to fire on. + +```python +def correct(self, result: EvaluateResult) -> EvaluateResult: + # Pull named figures out before super() can auto-append the last one. + # figure_paths order after analyze(): [0]=colorbar, [1..N-1]=traces, [-2]=snr_plot, [-1]=summary + colorbar = self.figure_paths.pop(0) if len(self.figure_paths) >= 3 else None + summary_plot = self.figure_paths.pop(-1) if self.figure_paths else None + snr_plot = self.figure_paths.pop(-1) if self.figure_paths else None + trace_figures = list(self.figure_paths) + self.figure_paths.clear() # prevent auto-append + + # Build header and main plots first + self.report_output.extend([ + "## My Operation\n...\n", + "**Colorbar:**\n", colorbar, + "**Summary:**\n", summary_plot, + "**SNR plot:**\n", snr_plot, + ]) + + result = super().correct(result) # adds check table; no auto-figure since list is empty + + if result.status == OperationStatus.SUCCESS: + self.report_output.append("### Per-trace results\n") + for fig in trace_figures: + self.report_output.append(fig) + + return result +``` + +Key points: +- `report_output` is a plain `list`. Append strings (markdown) or `Path` objects (figures) in any order. +- Pop figures in reverse order from the end to avoid index shifting. +- Call `super().correct()` after building the preamble so the check table appears below the plots. + +--- + +## `SuperOperationBase` changes + +- Sub-operations call their own `correct()` internally (inside `execute()`). +- `SuperOperationBase.execute()` now returns `EvaluateResult`. +- `SuperOperationBase` has its own `correct()` — default is a no-op. Override for + super-level parameter changes. + +--- + +## Exported symbols (`protocols/__init__.py`) + +New exports added: +- `CheckResult` +- `Correction` +- `CorrectionParameter` +- `EvaluateResult` + +--- + +## Dummy package additions + +| File | Addition | +|---|---| +| `parameters.py` | `_DummyCorrectionParameterBase(CorrectionParameter)` — in-memory correction params | +| All 6 operation files | `evaluate()` returns `EvaluateResult`; parameter updates moved to `correct()` | +| `dummy_protocol.py` | `DummySuperOperation.evaluate()` returns `EvaluateResult` | + +--- + +## `_DummyCorrectionParameterBase` pattern + +```python +@dataclass +class _DummyCorrectionParameterBase(CorrectionParameter): + def __post_init__(self): + super().__post_init__() + self._value: float = 0.0 + + def _dummy_getter(self) -> float: + return self._value + + def _dummy_setter(self, v: float) -> None: + self._value = v + +# Concrete correction parameter: +@dataclass +class ResonatorWindowSize(_DummyCorrectionParameterBase): + name: str = field(default="resonator_window_size", init=False) + description: str = field(default="Frequency search window width (Hz)", init=False) +``` + +--- + +## Migrating an existing operation + +Follow these steps to convert an operation that has an old-style `evaluate()` that both assesses and mutates parameters. + +### Step 1 — Split `evaluate()` into check methods + +Each condition that was tested in `evaluate()` becomes a `_check_*` method returning a `CheckResult`. Keep it pure — no side effects. + +```python +# Before +def evaluate(self): + if self.snr < THRESHOLD: + return OperationStatus.FAILURE + self.readout_freq(self.fit_result.params["f_0"].value) + return OperationStatus.SUCCESS + +# After +def _check_snr(self) -> CheckResult: + passed = self.snr >= self.snr_threshold() + return CheckResult("snr_check", passed, f"SNR={self.snr:.3f}, threshold={self.snr_threshold():.3f}") +``` + +### Step 2 — Define `Correction` classes for each failure mode + +Each way you'd retry the measurement becomes a `Correction` subclass. Make it stateful so it steps through options across retries. Implement `report_output()` to log what changed. + +```python +class MyCorrection(Correction): + name = "my_correction" + description = "Short description of what this does" + + def __init__(self, param): + self.param = param + self._count = 0 + self._last_change = "" + + def can_apply(self) -> bool: + return self._count < MAX + + def apply(self) -> None: + old = self.param() + new = compute_new_value(old, self._count) + self.param(new) + self._count += 1 + self._last_change = f"{old} → {new}" + + def report_output(self) -> str: + return self._last_change +``` + +### Step 3 — Define `CorrectionParameter` classes for configurable knobs + +Any threshold or limit that should be adjustable from the parameter manager becomes a `CorrectionParameter`. Always add `@dataclass`. Parameters live under `params.corrections..` by convention. + +```python +@dataclass +class MyThreshold(CorrectionParameter): + name: str = field(default="my_op_threshold", init=False) + description: str = field(default="...", init=False) + + def _qick_getter(self): return self.params.corrections.my_op.threshold() + def _qick_setter(self, v): self.params.corrections.my_op.threshold(v) +``` + +Then add the parameter to the instrument server. Connect via the instrumentserver client and call +`add_parameter` for each correction parameter: + +```python +from instrumentserver.client.proxy import Client + +c = Client() +params = c.get_instrument("parameter_manager") + +params.add_parameter("corrections.my_op.threshold", initial_value=2.0, unit="") +params.add_parameter("corrections.my_op.max_steps", initial_value=3, unit="") +``` + +The `unit` argument is required; use `unit=""` for dimensionless quantities. + +### Step 4 — Wire everything up in `__init__` + +```python +def __init__(self, params): + super().__init__() + self._register_inputs(...) + self._register_outputs(...) + + # 1. Register correction parameters first (they become self.* attributes) + self._register_correction_params( + my_threshold=MyThreshold(params), + ) + + # 2. Create correction instances (can now reference self.my_threshold) + self._my_correction = MyCorrection(self.some_param, self.my_threshold) + + # 3. Register checks with their correction fallback chains + self._register_check("snr_check", self._check_snr, self._my_correction) + + # 4. Register what to write on success + self._register_success_update( + self.output_param, + lambda: self.fit_result.params["f_0"].value, + ) +``` + +### Step 5 — Move parameter mutations out of `evaluate()` + +Delete the old `evaluate()` — the default implementation now handles everything via +registered checks. Delete any `correct()` override that only applied found values — +`_register_success_update` handles that too. + +Only keep a `correct()` override if you need custom report messages on SUCCESS/FAILURE: + +```python +def correct(self, result: EvaluateResult) -> EvaluateResult: + result = super().correct(result) # always call super first + if result.status == OperationStatus.SUCCESS: + self.report_output.append(f"Found frequency: {self.fit_result.params['f_0'].value:.6e} Hz\n") + return result +``` + +--- + +## Multi-level correction hierarchy + +When a single check has multiple levels of corrective action (fast cheap fixes first, +slow expensive fixes last), pass a list to `_register_check`. The list is a **fallback +chain**: the first correction where `can_apply()` is True is used on each retry. + +The pattern used in `ResonatorSpectroscopy`: + +``` +Level 0: WindowShiftCorrection — shift measurement window ±1, ±2, ... spans +Level 1: IncreaseSamplingRate — increase frequency steps × factor, reset window +Level 2: IncreaseAveraging — increase repetitions × factor, reset window +``` + +Key design points: +- **Level 0** is tried first on every retry until exhausted (all ±n shifts attempted). +- **Level 1** fires only when Level 0 is exhausted. Its `apply()` increases steps + AND calls `window_correction.reset()` so Level 0 starts over with the new settings. +- **Level 2** fires only when Level 1 is also exhausted. Same reset pattern. +- Higher-level corrections hold a reference to lower-level ones and call `reset()` on them. + +```python +self._window_shift = WindowShiftCorrection( + self.start_frequency, self.end_frequency, self.max_window_shifts +) +self._increase_sampling = IncreaseSamplingRateCorrection( + self.steps, self._window_shift, self.sampling_factor, self.max_sampling_increases +) +self._increase_averaging = IncreaseAveragingCorrection( + self.repetitions, self._window_shift, self.averaging_factor, self.max_averaging_increases +) + +self._register_check( + "quality_check", + self._check_quality, + [self._window_shift, self._increase_sampling, self._increase_averaging], +) +``` + +With defaults of `max_window_shifts=3`, `max_sampling_increases=2`, `max_averaging_increases=2`, +this gives up to 30 retries before FAILURE (6 window shifts × 5 sampling/averaging levels). + +--- + +## Multi-criteria quality checks + +A single `CheckResult` can combine multiple independent criteria. The convention is to +build a list of failures and join them into the description so the report is specific: + +```python +def _check_quality(self) -> CheckResult: + threshold = self.snr_threshold() + snr_passed = self.snr >= threshold + + max_error = self.max_fit_param_error() # e.g. 1.0 = 100% + bad_params = [] + for pname, param in self.fit_result.params.items(): + if param.stderr is None: + bad_params.append(f"{pname}(no stderr)") + elif param.value == 0 or abs(param.stderr / param.value) > max_error: + pct = abs(param.stderr / param.value) * 100 if param.value != 0 else float("inf") + bad_params.append(f"{pname}({pct:.0f}%)") + + passed = snr_passed and len(bad_params) == 0 + parts = [f"SNR={self.snr:.3f} (threshold={threshold:.3f})"] + if bad_params: + parts.append(f"high-error params: {', '.join(bad_params)}") + return CheckResult("quality_check", passed, "; ".join(parts)) +``` + +The `max_fit_param_error` is a `CorrectionParameter` stored at +`params.corrections.res_spec.max_fit_param_error` (default `1.0`). Certain fit +parameters with known large uncertainties can be excluded by name before the loop. + +--- + +## What is NOT yet done + +- No new `CorrectionParameter` subclasses in the dummy package (the base class is + there; concrete examples should be added alongside real operations). +- The `_assemble_report()` HTML does not yet have a dedicated "Correction + Parameters" section — check tables appear in `report_output` via the default + `correct()`, but `correction_params` values are not rendered separately. +- Dummy operations have not yet been updated to use `_register_success_update` — + they still override `correct()` manually. That update is deferred. + +--- + +## Files changed + +### Initial corrections architecture +``` +src/labcore/protocols/base.py +src/labcore/protocols/__init__.py +src/labcore/testing/protocol_dummy/parameters.py +src/labcore/testing/protocol_dummy/gaussian.py +src/labcore/testing/protocol_dummy/cosine.py +src/labcore/testing/protocol_dummy/linear.py +src/labcore/testing/protocol_dummy/exponential.py +src/labcore/testing/protocol_dummy/exponential_decay.py +src/labcore/testing/protocol_dummy/exponentially_decaying_sine.py +src/labcore/testing/protocol_dummy/dummy_protocol.py +test/pytest/test_protocols.py +test/pytest/test_protocols_realistic.py +``` + +### Gap fixes (registration-based success updates + fallback corrections) +``` +src/labcore/protocols/base.py +test/pytest/test_protocols.py +``` diff --git a/src/labcore/analysis/fitfuncs/generic.py b/src/labcore/analysis/fitfuncs/generic.py index c14b86f..f3905d5 100644 --- a/src/labcore/analysis/fitfuncs/generic.py +++ b/src/labcore/analysis/fitfuncs/generic.py @@ -166,3 +166,23 @@ def guess( A = data[i_max] - of sigma = np.abs((coordinates[-1] - coordinates[0])) / 20 return dict(x0=x0, sigma=sigma, A=A, of=of) + + +class Lorentzian(Fit): + @staticmethod + def model( + coordinates: np.ndarray, x0: float, gamma: float, A: float, of: float + ) -> np.ndarray: + return A * (gamma**2) / ((coordinates - x0) ** 2 + gamma**2) + of + + @staticmethod + def guess( + coordinates: Union[Tuple[np.ndarray, ...], np.ndarray], data: np.ndarray + ) -> Dict[str, float]: + of = np.mean(data) + dev = data - of + i_max = np.argmax(np.abs(dev)) + x0 = coordinates[i_max] + A = data[i_max] - of + gamma = np.abs((coordinates[-1] - coordinates[0])) / 20 + return dict(x0=x0, gamma=gamma, A=A, of=of) diff --git a/src/labcore/protocols/__init__.py b/src/labcore/protocols/__init__.py index 50ceb59..29390b2 100644 --- a/src/labcore/protocols/__init__.py +++ b/src/labcore/protocols/__init__.py @@ -4,9 +4,21 @@ from labcore.protocols.base import ( BranchBase as BranchBase, ) +from labcore.protocols.base import ( + CheckResult as CheckResult, +) from labcore.protocols.base import ( Condition as Condition, ) +from labcore.protocols.base import ( + Correction as Correction, +) +from labcore.protocols.base import ( + CorrectionParameter as CorrectionParameter, +) +from labcore.protocols.base import ( + EvaluateResult as EvaluateResult, +) from labcore.protocols.base import ( OperationStatus as OperationStatus, ) diff --git a/src/labcore/protocols/base.py b/src/labcore/protocols/base.py index 44fdefe..c102c19 100644 --- a/src/labcore/protocols/base.py +++ b/src/labcore/protocols/base.py @@ -2,10 +2,10 @@ import base64 import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum, auto from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, ClassVar import markdown import numpy as np @@ -179,13 +179,114 @@ class ParamImprovement: param: ProtocolParameterBase +@dataclass +class CheckResult: + """Result of a single named check in evaluate().""" + + name: str + passed: bool + description: str + + +# TODO: Add a reset mechanism +class Correction: + """ + Base class for stateful correction strategies. + + Subclass this and attach an instance to the operation in ``__init__``. The + instance persists across retry attempts so that stateful strategies (e.g. + stepping through a list of frequency windows) work correctly. + + Example:: + + class FrequencySweepCorrection(Correction): + name = "scan_next_frequency_window" + description = "Step through candidate frequency windows" + triggered_by = "peak_exists" + + def __init__(self, freq_param, windows: list[float]): + self.freq_param = freq_param + self.windows = windows + self._idx = 0 + + def can_apply(self) -> bool: + return self._idx < len(self.windows) + + def apply(self) -> None: + self.freq_param(self.windows[self._idx]) + self._idx += 1 + """ + + name: str = "" + description: str = "" + triggered_by: str = "" + + def can_apply(self) -> bool: + """Return False when this strategy is exhausted, forcing FAILURE.""" + return True + + def apply(self) -> None: + """Apply the correction. Called before the next retry attempt.""" + raise NotImplementedError( + f"Correction '{self.__class__.__name__}' must implement apply()" + ) + + def report_output(self) -> str: + """Return a string describing what apply() just changed. Optional.""" + return "" + + +@dataclass +class EvaluateResult: + """Return type for ProtocolOperation.evaluate() and correct().""" + + status: OperationStatus + checks: list[CheckResult] = field(default_factory=list) + + +class CorrectionParameter(ProtocolParameterBase): + """ + Platform-aware correction strategy knob. + + Use this instead of ProtocolParameterBase for parameters that control + correction strategies (e.g. window size, step count, noise tolerance). + May have platform-specific units or scaling — implement the platform + getter/setter methods for any unit conversions needed. + + Subclass exactly like ProtocolParameterBase. + """ + + is_correction: ClassVar[bool] = True + + def __post_init__(self) -> None: + if self.platform_type is None: + self.platform_type = PLATFORMTYPE + + +@dataclass +class _RegisteredCheck: + """Internal: maps a check function to its optional correction strategies.""" + + name: str + check_func: Callable[[], CheckResult] + corrections: list[Correction] = field(default_factory=list) + + +@dataclass +class _RegisteredSuccessUpdate: + """Internal: param to write and callable that produces the new value, applied on SUCCESS.""" + + param: ProtocolParameterBase + value_func: Callable[[], Any] + + # TODO: How do we handle different saving for different scenarios? For example: # For the lab we use run_measurement, for something like the lccf it will be something different. # In the same way that if some other lab wants to run this, they might want to automatically save other stuff. class ProtocolOperation: """ """ - DEFAULT_MAX_ATTEMPTS = 3 # Default max retry attempts for operations + DEFAULT_MAX_ATTEMPTS = 100 # Default max retry attempts for operations def __init__(self) -> None: global PLATFORMTYPE @@ -214,6 +315,11 @@ def __init__(self) -> None: self.current_attempt: int = 0 self.total_attempts_made: int = 0 + # Check and correction registration + self._registered_checks: list[_RegisteredCheck] = [] + self._registered_success_updates: list[_RegisteredSuccessUpdate] = [] + self.correction_params: dict[str, CorrectionParameter] = {} + def _register_inputs(self, **kwargs: ProtocolParameterBase) -> None: """Register input parameters as both attributes and in the dictionary""" for name, param in kwargs.items(): @@ -226,6 +332,64 @@ def _register_outputs(self, **kwargs: ProtocolParameterBase) -> None: setattr(self, name, param) self.output_params[name] = param + def _register_correction_params(self, **kwargs: CorrectionParameter) -> None: + """Register correction strategy parameters as attributes and in the dictionary.""" + for name, param in kwargs.items(): + setattr(self, name, param) + self.correction_params[name] = param + + def _register_check( + self, + name: str, + check_func: Callable[[], CheckResult], + correction: Correction | list[Correction] | None = None, + ) -> None: + """ + Register a named check and its optional correction strategy (or strategies). + + When evaluate() uses the default implementation, it runs all registered + checks and returns SUCCESS if all pass, RETRY if any fail. When + correct() uses the default implementation, it calls the first applicable + correction for each failed check. + + Args: + name: Unique identifier for this check (used in reports and routing). + check_func: Callable returning a CheckResult — pure assessment, no side effects. + correction: Optional Correction instance or list of instances to try in order + when this check fails. The first one where can_apply() returns True is used. + """ + if correction is None: + corrections: list[Correction] = [] + elif isinstance(correction, list): + corrections = correction + else: + corrections = [correction] + self._registered_checks.append(_RegisteredCheck(name, check_func, corrections)) + + def _register_success_update( + self, + param: ProtocolParameterBase, + value_func: Callable[[], Any], + ) -> None: + """Register a param to write when the operation succeeds. + + value_func is called lazily at correct() time, so it can safely + reference instance attributes set during analyze() (e.g. self.fit_result). + Multiple updates are applied in registration order. + """ + self._registered_success_updates.append( + _RegisteredSuccessUpdate(param, value_func) + ) + + def _correction_for_check(self, check_name: str) -> Correction | None: + """Return the first applicable Correction for the given check name, or None.""" + for rc in self._registered_checks: + if rc.name == check_name: + for c in rc.corrections: + if c.can_apply(): + return c + return None + def _measure_qick(self) -> Path: raise NotImplementedError("QICK measurement not implemented") @@ -317,44 +481,132 @@ def load_data(self) -> bool: return self._verify_shape() - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ - Evaluate operation results and recommend next action. + Assess operation results and return named check outcomes. - Subclasses must implement custom logic based on their domain knowledge. + Default implementation runs all checks registered via _register_check(). + Returns SUCCESS if all checks pass, RETRY if any fail. + + Override for complex logic (conditional checks, custom status rules). + Overrides should still return an EvaluateResult with CheckResult objects + so that reports and the default correct() work correctly. Returns: - OperationStatus.SUCCESS: Proceed to next operation - OperationStatus.RETRY: Retry this operation (if attempts remain) - OperationStatus.FAILURE: Stop protocol execution + EvaluateResult containing the status and individual check outcomes. + """ + if self._registered_checks: + checks = [rc.check_func() for rc in self._registered_checks] + all_passed = all(c.passed for c in checks) + status = OperationStatus.SUCCESS if all_passed else OperationStatus.RETRY + return EvaluateResult(status, checks) + raise NotImplementedError( + "Subclasses must either register checks with _register_check() or implement evaluate()" + ) + + def correct(self, result: EvaluateResult) -> EvaluateResult: """ - raise NotImplementedError("Subclasses must implement evaluate()") + Apply parameter changes based on the evaluation result. - def execute(self) -> OperationStatus: + This is the only place where parameter values should be modified. + + Default implementation: + - Appends a check summary table to the report. + - On RETRY: applies registered corrections for each failed check. + If any correction's can_apply() returns False, escalates to FAILURE. + - On SUCCESS / FAILURE: no-op (override to apply found values on SUCCESS). + + Override to apply found values on SUCCESS or implement complex correction + logic. Call super().correct(result) first to get check reporting and + default RETRY correction handling, then handle SUCCESS. + + Args: + result: The EvaluateResult returned by evaluate(). + + Returns: + Possibly modified EvaluateResult (e.g. RETRY escalated to FAILURE + when a correction is exhausted). """ - Execute the full operation workflow: measure -> load_data -> analyze -> evaluate. + if result.checks: + rows = "\n".join( + f"| {c.name} | {'✓ PASS' if c.passed else '✗ FAIL'} | {c.description} |" + for c in result.checks + ) + table = ( + f"| Check | Result | Details |\n|-------|--------|----------|\n{rows}\n" + ) + self.report_output.append(table) + if self.figure_paths: + self.report_output.append(self.figure_paths[-1].resolve()) + + if result.status == OperationStatus.RETRY: + for check in result.checks: + if not check.passed: + correction = self._correction_for_check(check.name) + if correction is None: + # Distinguish: no corrections registered vs. all exhausted + registered = next( + ( + rc.corrections + for rc in self._registered_checks + if rc.name == check.name + ), + [], + ) + if registered: + msg = ( + f"**Correction exhausted:** all strategies for check " + f"`{check.name}` have been exhausted — operation failed.\n" + ) + else: + msg = f"**No correction registered for failed check `{check.name}` — operation failed.**\n" + logger.warning(msg.strip()) + self.report_output.append(msg) + return EvaluateResult(OperationStatus.FAILURE, result.checks) + correction.apply() + msg = f"**Correction applied:** `{correction.name}` — {correction.description}" + change = correction.report_output() + if change: + msg += f" | **Change:** {change}" + msg += "\n" + logger.info(msg.strip()) + self.report_output.append(msg) + + if ( + result.status == OperationStatus.SUCCESS + and self._registered_success_updates + ): + for upd in self._registered_success_updates: + old = upd.param() + new = upd.value_func() + upd.param(new) + self.improvements.append(ParamImprovement(old, new, upd.param)) + self.report_output.append( + f"{upd.param.name} updated: {old} → {new:.3f}\n" + ) - This method increments attempt counters and adds repetition headers to reports. + return result + + def execute(self) -> EvaluateResult: + """ + Execute the full operation workflow: measure -> load_data -> analyze -> evaluate -> correct. Returns: - OperationStatus from evaluate() method + EvaluateResult from correct(), which may differ from evaluate()'s result + (e.g. RETRY escalated to FAILURE when a correction is exhausted). """ - # Increment attempt counter + self.improvements = [] self.current_attempt += 1 self.total_attempts_made += 1 - # Add repetition header to report if this is a retry if self.current_attempt > 1: - repetition_header = f"### ATTEMPT {self.current_attempt}\n\n" - self.report_output.append(repetition_header) + self.report_output.append(f"### ATTEMPT {self.current_attempt}\n\n") - # Execute the four-step workflow self.measure() self.load_data() self.analyze() - status = self.evaluate() - - return status + eval_result = self.evaluate() + return self.correct(eval_result) class SuperOperationBase(ProtocolOperation): @@ -415,7 +667,7 @@ def _validate_operations(self) -> None: f"Only ProtocolOperation instances are allowed." ) - def execute(self) -> OperationStatus: + def execute(self) -> EvaluateResult: """ Execute all sub-operations in sequence and aggregate their reports. @@ -423,75 +675,63 @@ def execute(self) -> OperationStatus: all sub-operations instead of calling measure/load_data/analyze. Returns: - OperationStatus from the evaluate() method + EvaluateResult from correct() after all sub-operations complete. """ - # Validate operations before executing self._validate_operations() - # Increment attempt counter self.current_attempt += 1 self.total_attempts_made += 1 - # Add retry header if needed if self.current_attempt > 1: - repetition_header = f"### ATTEMPT {self.current_attempt}\n\n" - self.report_output.append(repetition_header) + self.report_output.append(f"### ATTEMPT {self.current_attempt}\n\n") - # Add SuperOperation header to report - header = f"## {self.name}\n\n" - self.report_output.append(header) + self.report_output.append(f"## {self.name}\n\n") - # Execute each sub-operation for i, op in enumerate(self.operations): logger.info( f" [{self.name}] Executing sub-operation {i + 1}/{len(self.operations)}: {op.name}" ) - # Execute the operation (measure -> load_data -> analyze -> evaluate) try: - status = op.execute() + result = op.execute() except Exception as e: logger.error( f" [{self.name}] Exception in sub-operation {op.name}: {e}" ) - # If a sub-operation fails, the SuperOperation fails - return OperationStatus.FAILURE + return EvaluateResult(OperationStatus.FAILURE) - # Aggregate the sub-operation's report output if op.report_output: - # Add sub-operation section header self.report_output.append(f"### {op.name}\n\n") - # Add all report items from the sub-operation self.report_output.extend(op.report_output) self.report_output.append("\n") - # Check sub-operation status - if status == OperationStatus.FAILURE: + if result.status == OperationStatus.FAILURE: logger.error( f" [{self.name}] Sub-operation {op.name} failed critically" ) - return OperationStatus.FAILURE - elif status == OperationStatus.RETRY: + return EvaluateResult(OperationStatus.FAILURE) + elif result.status == OperationStatus.RETRY: logger.warning( f" [{self.name}] Sub-operation {op.name} requested retry" ) - # Don't immediately fail - let evaluate() decide - # But we could track this for evaluation logic - elif status == OperationStatus.SUCCESS: + elif result.status == OperationStatus.SUCCESS: logger.info(f" [{self.name}] Sub-operation {op.name} succeeded") - # Aggregate figure paths if hasattr(op, "figure_paths"): self.figure_paths.extend(op.figure_paths) - - # Aggregate improvements if hasattr(op, "improvements"): self.improvements.extend(op.improvements) - # Call the subclass's evaluate() method to determine overall status - status = self.evaluate() + eval_result = self.evaluate() + return self.correct(eval_result) - return status + def correct(self, result: EvaluateResult) -> EvaluateResult: + """ + Apply super-operation level corrections. Override to implement + super-level parameter changes. Default: no-op (sub-operations handle + their own corrections inside their own execute() calls). + """ + return result def measure(self) -> Path: """Not used in SuperOperation - operations handle their own measurement""" @@ -515,6 +755,7 @@ def analyze(self) -> None: ) +# TODO: remove condition from the protocol. This simply should be all the checks. Have this reflect in the new report as well. class ProtocolBase: def __init__(self, report_path: Path = Path("")): @@ -596,13 +837,22 @@ def verify_all_parameters(self) -> bool: for op in all_ops: for param_name, param in op.input_params.items(): try: - param() # Use callable syntax to verify parameter access + val = param() + param(val) except Exception as e: failures[param.name] = e for param_name, param in op.output_params.items(): try: - param() # Use callable syntax to verify parameter access + val = param() + param(val) + except Exception as e: + failures[param.name] = e + + for param_name, param in op.correction_params.items(): + try: + val = param() + param(val) except Exception as e: failures[param.name] = e @@ -881,7 +1131,6 @@ def _assemble_report(self) -> Path: .section-content {{ overflow: hidden; transition: max-height 0.3s ease, opacity 0.3s ease; - max-height: 10000px; opacity: 1; }} @@ -891,18 +1140,31 @@ def _assemble_report(self) -> Path: }} @@ -930,43 +1192,39 @@ def _execute_operation(self, op: ProtocolOperation) -> bool: """ max_attempts = op.max_attempts - # Reset attempt counter for this operation op.current_attempt = 0 while op.current_attempt < max_attempts: - # Execute operation (it will increment current_attempt internally) try: - status = op.execute() + result = op.execute() except Exception as e: logger.error(f" Exception during {op.name}: {e}") return False - # Handle status - if status == OperationStatus.SUCCESS: + if result.status == OperationStatus.SUCCESS: logger.info(f" SUCCESS: {op.name} succeeded") return True - elif status == OperationStatus.RETRY: + elif result.status == OperationStatus.RETRY: if op.current_attempt < max_attempts: logger.warning( f" RETRY: {op.name} requesting retry (attempt {op.current_attempt}/{max_attempts})" ) - continue # Retry + continue else: logger.error( f" FAILURE: {op.name} exhausted {max_attempts} attempts" ) return False - elif status == OperationStatus.FAILURE: + elif result.status == OperationStatus.FAILURE: logger.error(f" FAILURE: {op.name} failed critically") return False else: - logger.error(f" Unknown status: {status}") + logger.error(f" Unknown status: {result.status}") return False - # Should not reach here return False def _execute_branch( diff --git a/src/labcore/testing/protocol_dummy/cosine.py b/src/labcore/testing/protocol_dummy/cosine.py index 14d97eb..0a973ff 100644 --- a/src/labcore/testing/protocol_dummy/cosine.py +++ b/src/labcore/testing/protocol_dummy/cosine.py @@ -12,7 +12,7 @@ from labcore.measurement import Sweep from labcore.measurement.record import dependent, independent, recording from labcore.measurement.storage import run_and_save_sweep -from labcore.protocols.base import OperationStatus, ParamImprovement, ProtocolOperation +from labcore.protocols.base import CheckResult, ProtocolOperation from labcore.testing.protocol_dummy.parameters import ( CosineAmplitude, CosineFrequency, @@ -44,6 +44,11 @@ def __init__(self, params: Any = None) -> None: self.condition = f"Success if the SNR of the Cosine fit is bigger than the current threshold of {self.SNR_THRESHOLD}" + self._register_check("snr_check", self._check_snr) + self._register_success_update( + self.amplitude, lambda: cast(FitResult, self.fit_result).params["A"].value + ) + self.independents = {"x_values": []} self.dependents = {"y_values": []} @@ -132,55 +137,28 @@ def analyze(self) -> None: image_path = ds._new_file_path(ds.savefolders[1], self.name, suffix="png") self.figure_paths.append(image_path) - def evaluate(self) -> OperationStatus: - """ - Evaluate if the fit was successful based on SNR threshold. - If successful, update the amplitude output parameter with the fitted amplitude value. - """ - header = ( - f"## Cosine - Amplitude Fit\n" - f"Generated fake Cosine data and fitted it to extract amplitude.\n" - f"Data Path: `{self.data_loc}`\n" - f"Plot:\n" - ) - plot_image = self.figure_paths[0].resolve() - - assert self.snr is not None - assert self.fit_result is not None - if self.snr >= self.SNR_THRESHOLD: - logger.info( - f"SNR of {self.snr} is bigger than threshold of {self.SNR_THRESHOLD}. Applying new values" + self.report_output.append( + f"## Cosine - Amplitude Fit\n" + f"Generated fake Cosine data and fitted it to extract amplitude.\n" + f"Data Path: `{self.data_loc}`\n" + f"Plot:\n" ) - - old_value = self.amplitude() - new_value = self.fit_result.params["A"].value - - logger.info( - f"Updating {self.amplitude.name} from {old_value} to {new_value}" + self.report_output.append(image_path.resolve()) + self.report_output.append( + f"**Fit Report:**\n```\n{self.fit_result.lmfit_result.fit_report()}\n```\n" ) - self.amplitude(new_value) - self.improvements = [ParamImprovement(old_value, new_value, self.amplitude)] - - msg_2 = ( - f"Fit was **SUCCESSFUL** with an SNR of {self.snr:.3f}.\n" - f"{self.amplitude.name} updated: {old_value} -> {new_value:.3f}\n\n" - f"**Fit Report:**\n```\n{str(self.fit_result.lmfit_result.fit_report())}\n```\n\n" + def _check_snr(self) -> CheckResult: + snr = self.snr or 0.0 + passed = snr >= self.SNR_THRESHOLD + if passed: + self.report_output.append( + f"Fit was **SUCCESSFUL** with an SNR of {snr:.3f}.\n" ) - - self.report_output = [header, plot_image, msg_2] - - return OperationStatus.SUCCESS - - logger.info( - f"SNR of {self.snr} is smaller than threshold of {self.SNR_THRESHOLD}. Evaluation failed" - ) - - msg_2 = ( - f"Fit was **UNSUCCESSFUL** with an SNR of {self.snr:.3f}.\n" - f"NO value has been changed.\n" - f"Fit Report:\n\n```\n{str(self.fit_result.lmfit_result.fit_report())}\n```\n" + else: + self.report_output.append( + f"Fit was **UNSUCCESSFUL** with an SNR of {snr:.3f}. NO value has been changed.\n" + ) + return CheckResult( + "snr_check", passed, f"SNR={snr:.3f}, threshold={self.SNR_THRESHOLD}" ) - self.report_output = [header, plot_image, msg_2] - - return OperationStatus.FAILURE diff --git a/src/labcore/testing/protocol_dummy/dummy_protocol.py b/src/labcore/testing/protocol_dummy/dummy_protocol.py index 7a74bec..5bf8d8d 100644 --- a/src/labcore/testing/protocol_dummy/dummy_protocol.py +++ b/src/labcore/testing/protocol_dummy/dummy_protocol.py @@ -5,6 +5,7 @@ from labcore.protocols.base import ( BranchBase, Condition, + EvaluateResult, OperationStatus, ProtocolBase, SuperOperationBase, @@ -47,7 +48,7 @@ def __init__(self, params: Any = None) -> None: # Configure retry behavior self.max_attempts = 3 # Will retry up to 3 times total - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ Evaluate the overall success of all sub-operations. Uses same retry testing mechanism as GaussianProtocol. @@ -61,10 +62,10 @@ def evaluate(self) -> OperationStatus: logger.info( f"[{self.name}] At {self.total_attempts_made} attempts, requesting retry for testing" ) - return OperationStatus.RETRY + return EvaluateResult(OperationStatus.RETRY) logger.info(f"[{self.name}] Reached 3 attempts, returning SUCCESS") - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) class DummyProtocol(ProtocolBase): diff --git a/src/labcore/testing/protocol_dummy/exponential.py b/src/labcore/testing/protocol_dummy/exponential.py index 8fe1dfc..3a0b399 100644 --- a/src/labcore/testing/protocol_dummy/exponential.py +++ b/src/labcore/testing/protocol_dummy/exponential.py @@ -12,7 +12,12 @@ from labcore.measurement.record import dependent, independent, recording from labcore.measurement.storage import run_and_save_sweep from labcore.measurement.sweep import Sweep -from labcore.protocols.base import OperationStatus, ParamImprovement, ProtocolOperation +from labcore.protocols.base import ( + EvaluateResult, + OperationStatus, + ParamImprovement, + ProtocolOperation, +) from labcore.testing.protocol_dummy.parameters import ExponentialA, ExponentialB plt.switch_backend("agg") @@ -122,7 +127,7 @@ def analyze(self) -> None: image_path = ds._new_file_path(ds.savefolders[1], self.name, suffix="png") self.figure_paths.append(image_path) - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ Evaluate if the fit was successful based on SNR threshold. If successful, update the 'a' output parameter with the fitted coefficient value. @@ -158,7 +163,7 @@ def evaluate(self) -> OperationStatus: self.report_output = [header, plot_image, msg_2] - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) logger.info( f"SNR of {self.snr} is smaller than threshold of {self.SNR_THRESHOLD}. Evaluation failed" @@ -171,4 +176,4 @@ def evaluate(self) -> OperationStatus: ) self.report_output = [header, plot_image, msg_2] - return OperationStatus.FAILURE + return EvaluateResult(OperationStatus.FAILURE) diff --git a/src/labcore/testing/protocol_dummy/exponential_decay.py b/src/labcore/testing/protocol_dummy/exponential_decay.py index 2aeccb2..2edacf4 100644 --- a/src/labcore/testing/protocol_dummy/exponential_decay.py +++ b/src/labcore/testing/protocol_dummy/exponential_decay.py @@ -12,7 +12,12 @@ from labcore.measurement.record import dependent, independent, recording from labcore.measurement.storage import run_and_save_sweep from labcore.measurement.sweep import Sweep -from labcore.protocols.base import OperationStatus, ParamImprovement, ProtocolOperation +from labcore.protocols.base import ( + EvaluateResult, + OperationStatus, + ParamImprovement, + ProtocolOperation, +) from labcore.testing.protocol_dummy.parameters import ( ExponentialDecayAmplitude, ExponentialDecayOffset, @@ -128,7 +133,7 @@ def analyze(self) -> None: image_path = ds._new_file_path(ds.savefolders[1], self.name, suffix="png") self.figure_paths.append(image_path) - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ Evaluate if the fit was successful based on SNR threshold. If successful, update the amplitude output parameter with the fitted amplitude value. @@ -166,7 +171,7 @@ def evaluate(self) -> OperationStatus: self.report_output = [header, plot_image, msg_2] - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) logger.info( f"SNR of {self.snr} is smaller than threshold of {self.SNR_THRESHOLD}. Evaluation failed" @@ -179,4 +184,4 @@ def evaluate(self) -> OperationStatus: ) self.report_output = [header, plot_image, msg_2] - return OperationStatus.FAILURE + return EvaluateResult(OperationStatus.FAILURE) diff --git a/src/labcore/testing/protocol_dummy/exponentially_decaying_sine.py b/src/labcore/testing/protocol_dummy/exponentially_decaying_sine.py index 9180ea2..cc3534f 100644 --- a/src/labcore/testing/protocol_dummy/exponentially_decaying_sine.py +++ b/src/labcore/testing/protocol_dummy/exponentially_decaying_sine.py @@ -12,7 +12,12 @@ from labcore.measurement.record import dependent, independent, recording from labcore.measurement.storage import run_and_save_sweep from labcore.measurement.sweep import Sweep -from labcore.protocols.base import OperationStatus, ParamImprovement, ProtocolOperation +from labcore.protocols.base import ( + EvaluateResult, + OperationStatus, + ParamImprovement, + ProtocolOperation, +) from labcore.testing.protocol_dummy.parameters import ( ExponentiallyDecayingSineAmplitude, ExponentiallyDecayingSineFrequency, @@ -144,7 +149,7 @@ def analyze(self) -> None: image_path = ds._new_file_path(ds.savefolders[1], self.name, suffix="png") self.figure_paths.append(image_path) - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ Evaluate if the fit was successful based on SNR threshold. If successful, update the amplitude output parameter with the fitted amplitude value. @@ -182,7 +187,7 @@ def evaluate(self) -> OperationStatus: self.report_output = [header, plot_image, msg_2] - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) logger.info( f"SNR of {self.snr} is smaller than threshold of {self.SNR_THRESHOLD}. Evaluation failed" @@ -195,4 +200,4 @@ def evaluate(self) -> OperationStatus: ) self.report_output = [header, plot_image, msg_2] - return OperationStatus.FAILURE + return EvaluateResult(OperationStatus.FAILURE) diff --git a/src/labcore/testing/protocol_dummy/gaussian.py b/src/labcore/testing/protocol_dummy/gaussian.py index ce871ba..9a8b239 100644 --- a/src/labcore/testing/protocol_dummy/gaussian.py +++ b/src/labcore/testing/protocol_dummy/gaussian.py @@ -12,7 +12,12 @@ from labcore.measurement.record import dependent, independent, recording from labcore.measurement.storage import run_and_save_sweep from labcore.measurement.sweep import Sweep -from labcore.protocols.base import OperationStatus, ParamImprovement, ProtocolOperation +from labcore.protocols.base import ( + EvaluateResult, + OperationStatus, + ParamImprovement, + ProtocolOperation, +) from labcore.testing.protocol_dummy.parameters import ( GaussianAmplitude, GaussianCenter, @@ -130,7 +135,7 @@ def analyze(self) -> None: image_path = ds._new_file_path(ds.savefolders[1], self.name, suffix="png") self.figure_paths.append(image_path) - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ Evaluate if the fit was successful based on SNR threshold. If successful, update the amplitude output parameter with the fitted amplitude value. @@ -173,9 +178,9 @@ def evaluate(self) -> OperationStatus: if self.total_attempts_made != 3: msg_3 = f"Protocol at {self.total_attempts_made} repetitions, repeating for testing." self.report_output.append(msg_3) - return OperationStatus.RETRY + return EvaluateResult(OperationStatus.RETRY) - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) logger.info( f"SNR of {self.snr} is smaller than threshold of {self.SNR_THRESHOLD}. Evaluation failed" @@ -188,4 +193,4 @@ def evaluate(self) -> OperationStatus: ) self.report_output = [header, plot_image, msg_2] - return OperationStatus.FAILURE + return EvaluateResult(OperationStatus.FAILURE) diff --git a/src/labcore/testing/protocol_dummy/linear.py b/src/labcore/testing/protocol_dummy/linear.py index 2ac6ab5..307cd3b 100644 --- a/src/labcore/testing/protocol_dummy/linear.py +++ b/src/labcore/testing/protocol_dummy/linear.py @@ -12,7 +12,12 @@ from labcore.measurement.record import dependent, independent, recording from labcore.measurement.storage import run_and_save_sweep from labcore.measurement.sweep import Sweep -from labcore.protocols.base import OperationStatus, ParamImprovement, ProtocolOperation +from labcore.protocols.base import ( + EvaluateResult, + OperationStatus, + ParamImprovement, + ProtocolOperation, +) from labcore.testing.protocol_dummy.parameters import LinearOffset, LinearSlope plt.switch_backend("agg") @@ -120,7 +125,7 @@ def analyze(self) -> None: image_path = ds._new_file_path(ds.savefolders[1], self.name, suffix="png") self.figure_paths.append(image_path) - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: """ Evaluate if the fit was successful based on SNR threshold. If successful, update the slope output parameter with the fitted slope value. @@ -156,7 +161,7 @@ def evaluate(self) -> OperationStatus: self.report_output = [header, plot_image, msg_2] - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) logger.info( f"SNR of {self.snr} is smaller than threshold of {self.SNR_THRESHOLD}. Evaluation failed" @@ -169,4 +174,4 @@ def evaluate(self) -> OperationStatus: ) self.report_output = [header, plot_image, msg_2] - return OperationStatus.FAILURE + return EvaluateResult(OperationStatus.FAILURE) diff --git a/test/pytest/test_protocols.py b/test/pytest/test_protocols.py index 1e18bed..6e58efc 100644 --- a/test/pytest/test_protocols.py +++ b/test/pytest/test_protocols.py @@ -17,7 +17,11 @@ import labcore.protocols.base as proto_base from labcore.protocols.base import ( BranchBase, + CheckResult, Condition, + Correction, + CorrectionParameter, + EvaluateResult, OperationStatus, ParamImprovement, PlatformTypes, @@ -80,9 +84,9 @@ def _load_data_dummy(self): def analyze(self): log.append("analyze") - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: log.append("evaluate") - return status + return EvaluateResult(status) return _Op(), log @@ -229,8 +233,8 @@ def test_verify_shape_cases(self, case, expected): class TestSuperOperationBase: def _make_super(self, sub_ops, evaluate_status=OperationStatus.SUCCESS): class _Super(SuperOperationBase): - def evaluate(self) -> OperationStatus: - return evaluate_status + def evaluate(self) -> EvaluateResult: + return EvaluateResult(evaluate_status) s = _Super() s.operations = sub_ops @@ -261,7 +265,7 @@ def test_execute_aggregates_sub_op_reports(self): def test_execute_returns_failure_on_sub_op_exception(self): class _BadOp(ProtocolOperation): - def execute(self) -> OperationStatus: + def execute(self) -> EvaluateResult: raise RuntimeError("boom") def _measure_dummy(self): @@ -273,33 +277,33 @@ def _load_data_dummy(self): def analyze(self): pass - def evaluate(self): - return OperationStatus.SUCCESS + def evaluate(self) -> EvaluateResult: + return EvaluateResult(OperationStatus.SUCCESS) s = self._make_super([_BadOp()]) result = s.execute() - assert result == OperationStatus.FAILURE + assert result.status == OperationStatus.FAILURE def test_execute_returns_failure_on_sub_op_failure(self): op, _ = make_simple_op(status=OperationStatus.FAILURE) s = self._make_super([op]) result = s.execute() - assert result == OperationStatus.FAILURE + assert result.status == OperationStatus.FAILURE def test_execute_calls_evaluate_at_end(self): called = [] class _Super(SuperOperationBase): - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: called.append(True) - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) op, _ = make_simple_op() s = _Super() s.operations = [op] result = s.execute() assert called == [True] - assert result == OperationStatus.SUCCESS + assert result.status == OperationStatus.SUCCESS # =========================================================================== @@ -450,14 +454,14 @@ def _load_data_dummy(self): def analyze(self): pass - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: old = self.result() new = old + 5 self.improvements.append( ParamImprovement(old_value=old, new_value=new, param=self.result) ) self.result(new) - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.SUCCESS) op = _Op() proto = make_protocol([op], report_path=tmp_path) @@ -483,11 +487,11 @@ def _load_data_dummy(self): def analyze(self): pass - def evaluate(self) -> OperationStatus: + def evaluate(self) -> EvaluateResult: attempt["count"] += 1 if attempt["count"] < 3: - return OperationStatus.RETRY - return OperationStatus.SUCCESS + return EvaluateResult(OperationStatus.RETRY) + return EvaluateResult(OperationStatus.SUCCESS) op = _Op() op.max_attempts = 3 @@ -510,8 +514,8 @@ def _load_data_dummy(self): def analyze(self): pass - def evaluate(self) -> OperationStatus: - return OperationStatus.RETRY + def evaluate(self) -> EvaluateResult: + return EvaluateResult(OperationStatus.RETRY) op = _Op() op.max_attempts = 2 @@ -520,3 +524,275 @@ def evaluate(self) -> OperationStatus: assert proto.success is False assert op.total_attempts_made == 2 + + +# =========================================================================== +# 8. Success update registration +# =========================================================================== + + +def make_op_with_check(status: OperationStatus, check_name: str = "test_check"): + """Return an operation whose evaluate() returns a single named check result.""" + + class _Op(ProtocolOperation): + def _measure_dummy(self): + return Path(".") + + def _load_data_dummy(self): + pass + + def analyze(self): + pass + + def evaluate(self) -> EvaluateResult: + passed = status == OperationStatus.SUCCESS + return EvaluateResult(status, [CheckResult(check_name, passed, "stub")]) + + return _Op() + + +class TestSuccessUpdateRegistration: + def test_update_applied_on_success(self): + op = make_op_with_check(OperationStatus.SUCCESS) + param, store = make_param({"value": 0.0}) + op._register_success_update(param, lambda: 99.0) + op.correct(op.evaluate()) + assert store["value"] == 99.0 + assert len(op.improvements) == 1 + assert op.improvements[0].new_value == 99.0 + + def test_update_not_applied_on_retry(self): + op = make_op_with_check(OperationStatus.RETRY) + param, store = make_param({"value": 0.0}) + op._register_success_update(param, lambda: 99.0) + op._register_check("test_check", lambda: CheckResult("test_check", False, "")) + op.correct(op.evaluate()) + assert store["value"] == 0.0 + assert op.improvements == [] + + def test_multiple_updates_all_applied(self): + op = make_op_with_check(OperationStatus.SUCCESS) + param1, store1 = make_param({"value": 0.0}) + param2, store2 = make_param({"value": 0.0}) + op._register_success_update(param1, lambda: 1.0) + op._register_success_update(param2, lambda: 2.0) + op.correct(op.evaluate()) + assert store1["value"] == 1.0 + assert store2["value"] == 2.0 + assert len(op.improvements) == 2 + + def test_report_contains_param_name(self): + op = make_op_with_check(OperationStatus.SUCCESS) + param, _ = make_param() + op._register_success_update(param, lambda: 5.0) + op.correct(op.evaluate()) + combined = " ".join(str(s) for s in op.report_output) + assert param.name in combined + + +# =========================================================================== +# 9. Multiple / fallback corrections per check +# =========================================================================== + + +class _TrackingCorrection(Correction): + """Correction that records apply() calls and has a configurable can_apply().""" + + def __init__(self, can: bool = True): + self._can = can + self.applied = 0 + + def can_apply(self) -> bool: + return self._can + + def apply(self) -> None: + self.applied += 1 + + +def make_op_with_failing_check(check_name: str = "test_check"): + """Operation whose evaluate() always returns RETRY with a single failed check.""" + + class _Op(ProtocolOperation): + def _measure_dummy(self): + return Path(".") + + def _load_data_dummy(self): + pass + + def analyze(self): + pass + + def evaluate(self) -> EvaluateResult: + return EvaluateResult( + OperationStatus.RETRY, + [CheckResult(check_name, False, "stub")], + ) + + return _Op() + + +class TestMultipleFallbackCorrections: + def test_first_correction_applied_when_both_can_apply(self): + op = make_op_with_failing_check() + c1 = _TrackingCorrection(can=True) + c2 = _TrackingCorrection(can=True) + op._register_check( + "test_check", lambda: CheckResult("test_check", False, ""), [c1, c2] + ) + op.correct(op.evaluate()) + assert c1.applied == 1 + assert c2.applied == 0 + + def test_fallback_to_second_when_first_exhausted(self): + op = make_op_with_failing_check() + c1 = _TrackingCorrection(can=False) + c2 = _TrackingCorrection(can=True) + op._register_check( + "test_check", lambda: CheckResult("test_check", False, ""), [c1, c2] + ) + op.correct(op.evaluate()) + assert c1.applied == 0 + assert c2.applied == 1 + + def test_failure_when_all_exhausted(self): + op = make_op_with_failing_check() + c1 = _TrackingCorrection(can=False) + c2 = _TrackingCorrection(can=False) + op._register_check( + "test_check", lambda: CheckResult("test_check", False, ""), [c1, c2] + ) + result = op.correct(op.evaluate()) + assert result.status == OperationStatus.FAILURE + + def test_single_correction_backward_compat(self): + op = make_op_with_failing_check() + c = _TrackingCorrection() + op._register_check( + "test_check", lambda: CheckResult("test_check", False, ""), c + ) + assert op._registered_checks[0].corrections == [c] + + def test_list_stored_in_order(self): + op = make_op_with_failing_check() + c1 = _TrackingCorrection() + c2 = _TrackingCorrection() + op._register_check( + "test_check", lambda: CheckResult("test_check", False, ""), [c1, c2] + ) + assert op._registered_checks[0].corrections == [c1, c2] + + +# =========================================================================== +# 10. Default evaluate() using registered checks +# =========================================================================== + + +def make_op_with_registered_checks(passing: dict[str, bool]): + """Operation that uses _register_check() and relies on default evaluate().""" + + class _Op(ProtocolOperation): + def __init__(self): + super().__init__() + for name, should_pass in passing.items(): + self._register_check( + name, + lambda p=should_pass, n=name: CheckResult(n, p, f"stub:{p}"), + ) + + def _measure_dummy(self): + return Path(".") + + def _load_data_dummy(self): + pass + + def analyze(self): + pass + + return _Op() + + +class TestDefaultEvaluate: + def test_all_checks_pass_returns_success(self): + op = make_op_with_registered_checks({"a": True, "b": True}) + result = op.evaluate() + assert result.status == OperationStatus.SUCCESS + assert len(result.checks) == 2 + assert all(c.passed for c in result.checks) + + def test_any_check_fails_returns_retry(self): + op = make_op_with_registered_checks({"a": True, "b": False}) + result = op.evaluate() + assert result.status == OperationStatus.RETRY + + def test_check_names_match_registered(self): + op = make_op_with_registered_checks({"snr": True, "peak": False}) + result = op.evaluate() + names = [c.name for c in result.checks] + assert names == ["snr", "peak"] + + def test_no_correction_registered_escalates_to_failure(self): + """Failed check with correction=None → correct() returns FAILURE.""" + op = make_op_with_registered_checks({"peak": False}) + result = op.correct(op.evaluate()) + assert result.status == OperationStatus.FAILURE + + def test_check_table_appended_to_report(self): + op = make_op_with_registered_checks({"snr": True}) + op.correct(op.evaluate()) + combined = " ".join(op.report_output) + assert "snr" in combined + + def test_improvements_reset_on_each_execute(self): + store = {"value": 0.0} + op = make_op_with_registered_checks({"ok": True}) + param, _ = make_param(store) + op._register_success_update(param, lambda: 1.0) + op.execute() + assert len(op.improvements) == 1 + op.execute() + assert len(op.improvements) == 1 # reset, not accumulated + + +# =========================================================================== +# 11. CorrectionParameter +# =========================================================================== + + +def make_correction_param(): + @dataclass + class _CParam(CorrectionParameter): + name: str = field(default="window_size", init=False) + description: str = field(default="search window width", init=False) + + def _dummy_getter(self): + return self._value + + def _dummy_setter(self, v): + self._value = v + + def __post_init__(self): + super().__post_init__() + self._value = 0.0 + + return _CParam(params=None) + + +class TestCorrectionParameter: + def test_getter_setter(self): + p = make_correction_param() + p(42.0) + assert p() == 42.0 + + def test_registered_as_attribute(self): + op, _ = make_simple_op() + p = make_correction_param() + op._register_correction_params(win=p) + assert op.win is p + assert op.correction_params["win"] is p + + def test_included_in_verify_all_parameters(self, tmp_path): + """CorrectionParameter should be checked in verify_all_parameters().""" + op, _ = make_simple_op() + op._register_correction_params(win=make_correction_param()) + proto = make_protocol([op], report_path=tmp_path) + assert proto.verify_all_parameters() is True