Conversation
69ebbfa to
0144c44
Compare
5cda8d9 to
a12c11f
Compare
a12c11f to
7dbfde7
Compare
0bc8e21 to
9d285ce
Compare
9d285ce to
ffffafe
Compare
| ) | ||
| from baybe.surrogates.gaussian_process.components.mean import MeanFactoryProtocol | ||
| from baybe.surrogates.gaussian_process.presets import ( | ||
| GaussianProcessPreset, |
There was a problem hiding this comment.
GaussianProcessPreset no longer exists but is used in this file.
There was a problem hiding this comment.
Not sure what you mean 🤔 GaussianProcessPreset still exists and is available via this path. Can you elaborate?
7da186f to
be99459
Compare
73059a9 to
118e019
Compare
There was a problem hiding this comment.
should we better call this selectors.py instead of selector.py?
There was a problem hiding this comment.
Pull request overview
This PR introduces kernel dimension control by deriving GPyTorch kernel active_dims/ard_num_dims from a SearchSpace, adds parameter (sub)selection utilities, and updates GP surrogate behavior around multi-task kernels (including a deprecation guard for custom kernels).
Changes:
- Replace
Kernel.to_gpytorch(**kwargs)dimension arguments withKernel.to_gpytorch(searchspace)and addBasicKernel.parameter_namesto control which parameters a kernel operates on. - Add parameter selector primitives and update default GP kernel preset to use an explicit ICM-style (base-kernel × task-kernel) construction for multi-task search spaces.
- Update tests and changelog to reflect the new kernel translation API and the new multi-task custom-kernel deprecation behavior.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
baybe/kernels/base.py |
Adds searchspace-driven dimension resolution and whitelisting for BayBE-only attrs (e.g., parameter_names) during GPyTorch translation. |
baybe/parameters/selector.py |
Introduces a parameter selector protocol + implementations to support convenient parameter subselection. |
baybe/surrogates/gaussian_process/components/kernel.py |
Adds KernelFactory base (with parameter selector support) and introduces ICMKernelFactory for multi-task kernel composition. |
baybe/surrogates/gaussian_process/presets/baybe.py |
Makes the default kernel preset multi-task aware by dispatching to an ICM kernel factory and defining a default task kernel factory. |
baybe/surrogates/gaussian_process/presets/edbo.py / edbo_smoothed.py |
Wires kernel factories to pass selected parameter_names into base kernels; updates likelihood factory typing. |
baybe/surrogates/gaussian_process/core.py |
Updates surrogate to call kernel.to_gpytorch(searchspace=...) and adds a DeprecationError gate for custom kernels in multi-task contexts. |
baybe/surrogates/gaussian_process/components/{generic,mean,likelihood}.py |
Renames the component factory protocol type and aligns runtime typing. |
baybe/settings.py |
Validates the new BAYBE_DISABLE_CUSTOM_KERNEL_WARNING env var as a boolean-like string. |
tests/test_kernels.py |
Updates kernel translation tests to build a SearchSpace and verify derived active_dims/ard_num_dims. |
tests/hypothesis_strategies/kernels.py |
Extends kernel strategies to optionally generate kernels with parameter_names. |
tests/test_deprecations.py |
Adds a test covering the new multi-task custom-kernel deprecation behavior. |
CHANGELOG.md |
Documents the new kernel API, parameter subselection, and the multi-task custom-kernel deprecation behavior. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # Exception: initial values are not used during construction but are set | ||
| # on the created object (see code at the end of the method). | ||
| missing = set(unmatched) - set(kernel_attrs) | ||
| missing = set(unmatched) - set(kernel_attrs) - self._whitelisted_attributes |
There was a problem hiding this comment.
missing is computed from unmatched (the loop-local from the last base class iteration) instead of the aggregated unmatched_attrs. This makes the sanity check incomplete and can miss BayBE attrs that don’t map to any GPyTorch constructor arg. Use unmatched_attrs when computing missing (or otherwise ensure the check considers all unmatched attributes across base classes).
| missing = set(unmatched) - set(kernel_attrs) - self._whitelisted_attributes | |
| missing = set(unmatched_attrs) - set(kernel_attrs) - self._whitelisted_attributes |
| # passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)` | ||
| # fails if we explicitly pass `ard_num_dims=None`. |
There was a problem hiding this comment.
The comment about filtering kwargs to avoid passing ard_num_dims=None / batch_shape=None is now out of sync with the implementation: kw only contains active_dims, while ard_num_dims is always passed explicitly later. Please update/remove the comment or restore the original filtering pattern so the code and rationale stay consistent.
| # passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)` | |
| # fails if we explicitly pass `ard_num_dims=None`. | |
| # passed or not. Explicitly passing `None` for a kwarg can change behavior | |
| # compared to omitting it, so we filter out any `None` values here. |
| @define | ||
| class ParameterSelector(ParameterSelectorProtocol): | ||
| """Base class for parameter selectors.""" | ||
|
|
||
| exclude: bool = field(default=False, validator=instance_of(bool), kw_only=True) | ||
| """Boolean flag indicating whether invert the selection criterion.""" | ||
|
|
||
| @abstractmethod | ||
| def _is_match(self, parameter: Parameter) -> bool: | ||
| """Determine if a parameter meets the selection criterion.""" |
There was a problem hiding this comment.
ParameterSelector defines _is_match with @abstractmethod but the class doesn’t inherit from abc.ABC / use ABCMeta, so the abstractness won’t be enforced and ParameterSelector() can be instantiated accidentally. Inherit from ABC (e.g., class ParameterSelector(ABC, ParameterSelectorProtocol):) to make the abstract contract effective.
| base_kernel_factory: KernelFactoryProtocol = field(alias="base_kernel_or_factory") | ||
| """The factory for the base kernel operating on numerical input features.""" | ||
|
|
||
| task_kernel_factory: KernelFactoryProtocol = field(alias="task_kernel_or_factory") | ||
| """The factory for the task kernel operating on the task indices.""" | ||
|
|
There was a problem hiding this comment.
ICMKernelFactory’s base_kernel_factory / task_kernel_factory fields are typed as factories and have no converter, so passing a plain Kernel (as suggested by the deprecation error message’s guidance to “include the task kernel explicitly”) will fail at runtime when the kernel is called. Consider adding the same to_component_factory converter pattern used by GaussianProcessSurrogate.kernel_factory so these fields accept either a kernel object or a factory.
671e81c to
118e019
Compare
DevPR, parent is #745
parameter_namesICMKernelFactoryclass