From a2deb1819515a179486efd91ddfc8cc42d3489b6 Mon Sep 17 00:00:00 2001
From: Emily Fertig
Date: Thu, 15 Jun 2023 13:50:41 -0700
Subject: [PATCH] JIT the whole ARD optimization loop in VizierGPBandit when
 JaxoptLbfgsB is used.

PiperOrigin-RevId: 540680400
---
 vizier/_src/algorithms/designers/gp_bandit.py | 71 +++++++++++--------
 .../algorithms/designers/gp_bandit_test.py    |  2 +-
 vizier/_src/jax/optimizers/jaxopt_wrappers.py | 18 +++--
 3 files changed, 55 insertions(+), 36 deletions(-)

diff --git a/vizier/_src/algorithms/designers/gp_bandit.py b/vizier/_src/algorithms/designers/gp_bandit.py
index 31ee5c376..f72465cdd 100644
--- a/vizier/_src/algorithms/designers/gp_bandit.py
+++ b/vizier/_src/algorithms/designers/gp_bandit.py
@@ -78,6 +78,39 @@
 )
 
 
+def _optimize_ard(
+    optimizer: optimizers.Optimizer[types.ParameterDict],
+    model: sp.StochasticProcessModel,
+    data: types.StochasticProcessModelData,
+    seed: jax.random.KeyArray,
+) -> tuple[types.ParameterDict, Any]:
+  """Performs ARD on the current model to find the best model parameters."""
+
+  # Run ARD.
+  def setup(k: jax.random.KeyArray):
+    return jax.jit(
+        gp_bandit_utils.stochastic_process_model_setup,
+        static_argnames=('model',),
+    )(k, model=model, data=data)
+
+  constraints = sp.get_constraints(model)
+  loss_fn = functools.partial(
+      jax.jit(
+          gp_bandit_utils.stochastic_process_model_loss_fn,
+          static_argnames=('model', 'normalize'),
+      ),
+      model=model,
+      data=data,
+      # For SGD, normalize the loss so we can use the same learning rate
+      # regardless of the number of examples (see
+      # `OptaxTrainWithRandomRestarts` docstring).
+      normalize=isinstance(optimizer, optimizers.OptaxTrainWithRandomRestarts),
+  )
+
+  # The ARD optimizers JIT the train step/loop internally.
+  return optimizer(setup, loss_fn, seed, constraints=constraints)
+
+
 @attr.define(auto_attribs=False)
 class VizierGPBandit(vza.Designer, vza.Predictor):
   """GP-Bandit using a Flax model.
@@ -342,26 +375,22 @@ def _convert_trials_to_arrays(
 
   def _find_best_model_params(
       self,
       model: sp.StochasticProcessModel,
-      loss_fn: optimizers.LossFunction,
       data: types.StochasticProcessModelData,
       seed: jax.random.KeyArray,
   ) -> tuple[types.ParameterDict, Any]:
     """Perform ARD on the current model to find best model parameters."""
     # Run ARD.
-    setup = functools.partial(
-        jax.jit(
-            gp_bandit_utils.stochastic_process_model_setup,
-            static_argnames=('model',),
-        ),
-        model=model,
-        data=data,
-    )
-    constraints = sp.get_constraints(model)
+    if isinstance(self._ard_optimizer, optimizers.JaxoptLbfgsB):
+      optimize_ard = jax.jit(_optimize_ard, static_argnames='model')
+    else:
+      optimize_ard = _optimize_ard
     logging.info('Optimizing the loss function...')
+    run_ard_with_profiling = profiler.record_runtime(
+        optimize_ard, name_prefix='VizierGPBandit', name='ard', also_log=True
+    )
 
-    # The ARD optimizers JIT the train step/loop internally.
-    return self._ard_optimizer(setup, loss_fn, seed, constraints=constraints)
+    return run_ard_with_profiling(self._ard_optimizer, model, data, seed)
 
   @profiler.record_runtime(name_prefix='VizierGPBandit', name='compute_state')
   def _compute_state(
@@ -404,25 +433,9 @@ def _compute_state(
         dimension_is_missing=dimension_is_missing,
     )
     model = self._build_model(features)
-    # TODO: Avoid retracing vmapped loss when loss function API is
-    # redesigned.
- loss_fn = functools.partial( - jax.jit( - gp_bandit_utils.stochastic_process_model_loss_fn, - static_argnames=('model', 'normalize'), - ), - model=model, - data=data, - # For SGD, normalize the loss so we can use the same learning rate - # regardless of the number of examples (see - # `OptaxTrainWithRandomRestarts` docstring). - normalize=isinstance( - self._ard_optimizer, optimizers.OptaxTrainWithRandomRestarts - ), - ) self._rng, ard_rng = jax.random.split(self._rng, 2) best_model_params, metrics = self._find_best_model_params( - model, loss_fn, data, ard_rng + model, data, ard_rng ) # Logging for debugging purposes. logging.info('Best model parameters: %s', best_model_params) diff --git a/vizier/_src/algorithms/designers/gp_bandit_test.py b/vizier/_src/algorithms/designers/gp_bandit_test.py index 8194fccec..820b153c6 100644 --- a/vizier/_src/algorithms/designers/gp_bandit_test.py +++ b/vizier/_src/algorithms/designers/gp_bandit_test.py @@ -42,7 +42,7 @@ ensemble_ard_optimizer = optimizers.default_optimizer() -noensemble_ard_optimizer = optimizers.JaxoptScipyLbfgsB( +noensemble_ard_optimizer = optimizers.JaxoptLbfgsB( optimizers.LbfgsBOptions(random_restarts=5, best_n=1) ) diff --git a/vizier/_src/jax/optimizers/jaxopt_wrappers.py b/vizier/_src/jax/optimizers/jaxopt_wrappers.py index 346b4cdd7..21f9e32a3 100644 --- a/vizier/_src/jax/optimizers/jaxopt_wrappers.py +++ b/vizier/_src/jax/optimizers/jaxopt_wrappers.py @@ -47,10 +47,14 @@ class LbfgsBOptions: less than or equal to `random_restarts`. """ - num_line_search_steps: int = struct.field(kw_only=True, default=20) - random_restarts: int = struct.field(kw_only=True, default=4) + num_line_search_steps: int = struct.field( + kw_only=True, default=20, pytree_node=False + ) + random_restarts: int = struct.field( + kw_only=True, default=4, pytree_node=False + ) tol: float = struct.field(kw_only=True, default=1e-8) - maxiter: int = struct.field(kw_only=True, default=50) + maxiter: int = struct.field(kw_only=True, default=50, pytree_node=False) best_n: int = struct.field(kw_only=True, default=1, pytree_node=False) def __post_init__(self): @@ -171,7 +175,7 @@ def best_n(self) -> int: return self._options.best_n -@attr.define +@struct.dataclass class JaxoptLbfgsB(core.Optimizer[core.Params]): """Jaxopt's L-BFGS-B optimizer. @@ -193,8 +197,10 @@ class JaxoptLbfgsB(core.Optimizer[core.Params]): _speed_test: If True, return speed test results. """ - _options: LbfgsBOptions = attr.field(default=LbfgsBOptions()) - _speed_test: bool = attr.field(kw_only=True, default=False) + _options: LbfgsBOptions = struct.field(default_factory=LbfgsBOptions) + _speed_test: bool = struct.field( + pytree_node=False, kw_only=True, default=False + ) def __call__( self,
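
--
Notes for reviewers (illustrative only, not part of the applied patch):

The sketch below uses hypothetical names (`QuadraticModel`, `optimize`); it is
not Vizier's API. It shows the pattern behind
`jax.jit(_optimize_ard, static_argnames='model')`: the Flax model is not a
pytree of arrays, so it is passed as a static (hashable) argument, baked into
the compiled trace as a constant, and supplying a different model triggers a
retrace.

    import functools

    import jax
    import jax.numpy as jnp


    class QuadraticModel:
      """Hypothetical stand-in for `sp.StochasticProcessModel`: hashable, not a pytree."""

      def loss(self, params, data):
        return jnp.sum((params - data) ** 2)


    @functools.partial(jax.jit, static_argnames=('model',))
    def optimize(model, params, data):
      # `model` is a compile-time constant; only `params` and `data` are traced.
      grad = jax.grad(model.loss)(params, data)
      return params - 0.1 * grad


    print(optimize(model=QuadraticModel(), params=jnp.zeros(3), data=jnp.ones(3)))

Here a single gradient step stands in for the full ARD loop; the point is only
the static-argument plumbing, not the optimization itself.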
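A second sketch (hypothetical `FitOptions`/`fit`, not Vizier's API) shows why
`LbfgsBOptions` fields gain `pytree_node=False` and `JaxoptLbfgsB` becomes a
`flax.struct.dataclass`: a struct dataclass is a pytree, so the whole optimizer
object can be passed into the jitted `_optimize_ard`, while fields marked
`pytree_node=False` travel as static metadata and stay concrete Python values
under tracing (e.g. usable as a loop trip count).

    import jax
    import jax.numpy as jnp
    from flax import struct


    @struct.dataclass
    class FitOptions:
      # Static metadata: stays a concrete, hashable Python int under `jax.jit`.
      maxiter: int = struct.field(pytree_node=False, default=50)
      # Ordinary pytree leaf: traced like any other array input.
      learning_rate: float = 1e-2


    def loss(params, data):
      x, y = data
      return jnp.mean((x @ params - y) ** 2)


    @jax.jit
    def fit(options: FitOptions, params, data):
      # `options.maxiter` is static, so it is a valid `fori_loop` trip count.
      def step(_, p):
        return p - options.learning_rate * jax.grad(loss)(p, data)

      return jax.lax.fori_loop(0, options.maxiter, step, params)


    x, y = jnp.ones((8, 3)), jnp.zeros(8)
    print(fit(FitOptions(), jnp.zeros(3), (x, y)))

The same split explains why `best_n` was already `pytree_node=False`: anything
that determines output shapes or loop bounds has to be static under `jit`.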