From 830224ddc3d1c72ba4c6cace19285a853f17d125 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 18 Dec 2025 14:54:50 +0100 Subject: [PATCH 1/8] Add new `GaussianProcessSurrogate.from_prior` method - construct a GP by transferring knowledge from a pre-trained prior GP - basic implementation for full mean transfer - the posterior mean of the pretrained GP is used as mean module for GP - hyperparameters are frozen and mean is evaluated at source points - interface might later be extended to other mean transfers (initialize hyperparameters) or covariance transfer - new `PriorMean` class that implements mean of prior GP as botorch module --- baybe/surrogates/gaussian_process/core.py | 76 ++++++++++++++++--- .../gaussian_process/prior_modules.py | 55 ++++++++++++++ 2 files changed, 121 insertions(+), 10 deletions(-) create mode 100644 baybe/surrogates/gaussian_process/prior_modules.py diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index c0148aca55..3ab9e1c788 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -24,9 +24,11 @@ DefaultKernelFactory, _default_noise_factory, ) +from baybe.surrogates.gaussian_process.prior_modules import PriorMean from baybe.utils.conversion import to_string if TYPE_CHECKING: + from botorch.models import SingleTaskGP from botorch.models.gpytorch import GPyTorchModel from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform @@ -113,11 +115,57 @@ class GaussianProcessSurrogate(Surrogate): _model = field(init=False, default=None, eq=False) """The actual model.""" + # Transfer learning fields + _prior_gp = field(init=False, default=None, eq=False) + """Prior GP to extract mean/covariance from for transfer learning.""" + @staticmethod def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate: """Create a Gaussian process surrogate from one of the defined
presets.""" return make_gp_from_preset(preset) + @classmethod + def _from_prior( + cls, + prior_gp: SingleTaskGP, + kernel_factory: KernelFactory | None = None, + **kwargs, + ) -> GaussianProcessSurrogate: + """Create a GP surrogate with mean function transfer learning. + + Args: + prior_gp: Fitted SingleTaskGP to use as prior + kernel_factory: Kernel factory for covariance components + **kwargs: Additional arguments for GaussianProcessSurrogate constructor + + Returns: + New GaussianProcessSurrogate instance with transfer learning + + Raises: + ValueError: If prior_gp is not fitted + """ + from copy import deepcopy + + from botorch.models import SingleTaskGP + + # Validate prior GP is fitted + if not isinstance(prior_gp, SingleTaskGP): + raise ValueError("prior_gp must be a fitted SingleTaskGP instance") + if not hasattr(prior_gp, "train_inputs") or prior_gp.train_inputs is None: + raise ValueError("Prior GP must be fitted (have train_inputs) before use") + + # Configure kernel factory (always needed since we only do mean transfer now) + if kernel_factory is None: + kernel_factory = DefaultKernelFactory() + + # Create new surrogate instance + instance = cls(kernel_or_factory=kernel_factory, **kwargs) + + # Configure for transfer learning + instance._prior_gp = deepcopy(prior_gp) + + return instance + @override def to_botorch(self) -> GPyTorchModel: return self._model @@ -152,22 +200,30 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None: assert self._searchspace is not None context = _ModelContext(self._searchspace) - numerical_idxs = context.get_numerical_indices(train_x.shape[-1]) - # For GPs, we let botorch handle the scaling. See [Scaling Workaround] above. 
- input_transform = Normalize( - train_x.shape[-1], - bounds=context.parameter_bounds, - indices=list(numerical_idxs), - ) - outcome_transform = Standardize(train_y.shape[-1]) - # extract the batch shape of the training data batch_shape = train_x.shape[:-2] + # Configure input/output transforms + if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"): + # Use prior's transforms for consistency in transfer learning + input_transform = self._prior_gp.input_transform + outcome_transform = self._prior_gp.outcome_transform + else: + # For GPs, we let botorch handle scaling. See [Scaling Workaround] above. + input_transform = Normalize( + train_x.shape[-1], + bounds=context.parameter_bounds, + indices=numerical_idxs, + ) + outcome_transform = Standardize(train_y.shape[-1]) + # create GP mean - mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape) + if self._prior_gp is not None: + mean_module = PriorMean(self._prior_gp, batch_shape=batch_shape) + else: + mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape) # define the covariance module for the numeric dimensions base_covar_module = self.kernel_factory( diff --git a/baybe/surrogates/gaussian_process/prior_modules.py b/baybe/surrogates/gaussian_process/prior_modules.py new file mode 100644 index 0000000000..f411349cce --- /dev/null +++ b/baybe/surrogates/gaussian_process/prior_modules.py @@ -0,0 +1,55 @@ +"""Prior modules for Gaussian process transfer learning.""" + +from __future__ import annotations + +from copy import deepcopy +from typing import Any + +import gpytorch +import torch +from botorch.models import SingleTaskGP +from torch import Tensor + + +class PriorMean(gpytorch.means.Mean): + """GPyTorch mean module using a trained GP as prior mean. + + This mean module wraps a trained Gaussian Process and uses its predictions + as the mean function for another GP. + + Args: + gp: Trained Gaussian Process to use as mean function. 
+ batch_shape: Batch shape for the mean module. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any + ) -> None: + super().__init__() + + # Deep copy and freeze the GP + self.gp: SingleTaskGP = deepcopy(gp) + self.batch_shape: torch.Size = batch_shape + + # Freeze parameters and set eval mode once + for param in self.gp.parameters(): + param.requires_grad = False + self.gp.eval() + self.gp.likelihood.eval() + + def forward(self, x: Tensor) -> Tensor: + """Compute the mean function using the wrapped GP. + + Args: + x: Input tensor for which to compute the mean. + + Returns: + Mean predictions from the wrapped GP. + """ + with torch.no_grad(), gpytorch.settings.fast_pred_var(): + mean = self.gp(x).mean.detach() + + # Handle batch dimensions + target_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-1]) + return mean.reshape(target_shape) From e88f94bf1fa22044219b43b3bbd4055f70620970 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 18 Dec 2025 15:18:49 +0100 Subject: [PATCH 2/8] Make constructor method public --- baybe/surrogates/gaussian_process/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 3ab9e1c788..4bd0a6dace 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -125,7 +125,7 @@ def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate: return make_gp_from_preset(preset) @classmethod - def _from_prior( + def from_prior( cls, prior_gp: SingleTaskGP, kernel_factory: KernelFactory | None = None, From cdf71a9d7a56bb515c582f50cfa43d6d0d2535ee Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 18 Dec 2025 15:24:56 +0100 Subject: [PATCH 3/8] Set source GP to eval mode for predictions --- baybe/surrogates/gaussian_process/prior_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 
2 deletions(-) diff --git a/baybe/surrogates/gaussian_process/prior_modules.py b/baybe/surrogates/gaussian_process/prior_modules.py index f411349cce..d718b3c7a6 100644 --- a/baybe/surrogates/gaussian_process/prior_modules.py +++ b/baybe/surrogates/gaussian_process/prior_modules.py @@ -35,8 +35,6 @@ def __init__( # Freeze parameters and set eval mode once for param in self.gp.parameters(): param.requires_grad = False - self.gp.eval() - self.gp.likelihood.eval() def forward(self, x: Tensor) -> Tensor: """Compute the mean function using the wrapped GP. @@ -47,6 +45,8 @@ def forward(self, x: Tensor) -> Tensor: Returns: Mean predictions from the wrapped GP. """ + self.gp.eval() + self.gp.likelihood.eval() with torch.no_grad(), gpytorch.settings.fast_pred_var(): mean = self.gp(x).mean.detach() From 1dff68569a4da0b46cb689edf72f3a9889329055 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 8 Jan 2026 12:05:13 +0100 Subject: [PATCH 4/8] Fix typo in docstring Co-authored-by: Martin Fitzner --- baybe/surrogates/gaussian_process/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 4bd0a6dace..ccf9c01117 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -117,7 +117,7 @@ class GaussianProcessSurrogate(Surrogate): # Transfer learning fields _prior_gp = field(init=False, default=None, eq=False) - """Prior GP to extract mean/covariance from for transfer learning.""" + """Prior GP to extract mean/covariance for transfer learning.""" @staticmethod def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate: From 7d17a126471883124cc3171742f74486cb487c92 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 8 Jan 2026 12:21:35 +0100 Subject: [PATCH 5/8] Fix type annotation --- baybe/surrogates/gaussian_process/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index ccf9c01117..dc48d6fd0e 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -116,7 +116,7 @@ class GaussianProcessSurrogate(Surrogate): """The actual model.""" # Transfer learning fields - _prior_gp = field(init=False, default=None, eq=False) + _prior_gp: SingleTaskGP | None = field(init=False, default=None, eq=False) """Prior GP to extract mean/covariance for transfer learning.""" @staticmethod From eacdb28244fb6db7871717e162807c25de89ea20 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 8 Jan 2026 12:30:28 +0100 Subject: [PATCH 6/8] Create new GP from GaussianProcessSurrogate instead of botorch instance --- baybe/surrogates/gaussian_process/core.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index dc48d6fd0e..ea12a9cc15 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -127,14 +127,14 @@ def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate: @classmethod def from_prior( cls, - prior_gp: SingleTaskGP, + prior_gp: GaussianProcessSurrogate, kernel_factory: KernelFactory | None = None, **kwargs, ) -> GaussianProcessSurrogate: """Create a GP surrogate with mean function transfer learning. 
Args: - prior_gp: Fitted SingleTaskGP to use as prior + prior_gp: Fitted GaussianProcessSurrogate to use as prior kernel_factory: Kernel factory for covariance components **kwargs: Additional arguments for GaussianProcessSurrogate constructor @@ -146,13 +146,13 @@ def from_prior( """ from copy import deepcopy - from botorch.models import SingleTaskGP - # Validate prior GP is fitted - if not isinstance(prior_gp, SingleTaskGP): - raise ValueError("prior_gp must be a fitted SingleTaskGP instance") - if not hasattr(prior_gp, "train_inputs") or prior_gp.train_inputs is None: - raise ValueError("Prior GP must be fitted (have train_inputs) before use") + if not isinstance(prior_gp, cls): + raise ValueError( + "prior_gp must be a fitted GaussianProcessSurrogate instance" + ) + if prior_gp._model is None: + raise ValueError("Prior GP must be fitted before use") # Configure kernel factory (always needed since we only do mean transfer now) if kernel_factory is None: @@ -161,8 +161,8 @@ def from_prior( # Create new surrogate instance instance = cls(kernel_or_factory=kernel_factory, **kwargs) - # Configure for transfer learning - instance._prior_gp = deepcopy(prior_gp) + # Configure for transfer learning - store the BoTorch model + instance._prior_gp = deepcopy(prior_gp.to_botorch()) return instance From 4cc837d30238783245fce1d4e08fed160595ca64 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 8 Jan 2026 12:35:47 +0100 Subject: [PATCH 7/8] Add all raises to docstring --- baybe/surrogates/gaussian_process/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index ea12a9cc15..4038ec1fa4 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -142,7 +142,7 @@ def from_prior( New GaussianProcessSurrogate instance with transfer learning Raises: - ValueError: If prior_gp is not fitted + ValueError: If prior_gp is not a 
GaussianProcessSurrogate or is not fitted """ from copy import deepcopy From 63e385c4e6a42b51545d8e8b202440116381a214 Mon Sep 17 00:00:00 2001 From: kalama-ai Date: Thu, 8 Jan 2026 13:57:53 +0100 Subject: [PATCH 8/8] Update docstring --- baybe/surrogates/gaussian_process/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/baybe/surrogates/gaussian_process/core.py b/baybe/surrogates/gaussian_process/core.py index 4038ec1fa4..7e41049665 100644 --- a/baybe/surrogates/gaussian_process/core.py +++ b/baybe/surrogates/gaussian_process/core.py @@ -131,7 +131,10 @@ def from_prior( kernel_factory: KernelFactory | None = None, **kwargs, ) -> GaussianProcessSurrogate: - """Create a GP surrogate with mean function transfer learning. + """Create a GP surrogate using a prior GP's predictions as the mean function. + + Transfers knowledge by using the prior GP's posterior mean predictions + as the mean function for a new GP, while learning covariance from scratch. Args: prior_gp: Fitted GaussianProcessSurrogate to use as prior