Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,39 @@
)


def _optimize_ard(
    optimizer: optimizers.Optimizer[types.ParameterDict],
    model: sp.StochasticProcessModel,
    data: types.StochasticProcessModelData,
    seed: jax.random.KeyArray,
) -> tuple[types.ParameterDict, Any]:
  """Runs ARD to find the best stochastic-process model parameters.

  Args:
    optimizer: Vizier optimizer that minimizes the model loss.
    model: Flax stochastic-process model whose hyperparameters are tuned.
    data: Observed data used to evaluate the model loss.
    seed: JAX PRNG key driving random initialization/restarts.

  Returns:
    A tuple of (best model parameters, optimizer metrics).
  """
  jitted_setup = jax.jit(
      gp_bandit_utils.stochastic_process_model_setup,
      static_argnames=('model',),
  )

  def setup(k: jax.random.KeyArray):
    return jitted_setup(k, model=model, data=data)

  constraints = sp.get_constraints(model)

  # SGD-style optimizers expect a normalized (per-example) loss so the same
  # learning rate works regardless of the number of examples (see
  # `OptaxTrainWithRandomRestarts` docstring).
  normalize = isinstance(optimizer, optimizers.OptaxTrainWithRandomRestarts)
  jitted_loss = jax.jit(
      gp_bandit_utils.stochastic_process_model_loss_fn,
      static_argnames=('model', 'normalize'),
  )
  loss_fn = functools.partial(
      jitted_loss,
      model=model,
      data=data,
      normalize=normalize,
  )

  # The ARD optimizers JIT the train step/loop internally.
  return optimizer(setup, loss_fn, seed, constraints=constraints)


@attr.define(auto_attribs=False)
class VizierGPBandit(vza.Designer, vza.Predictor):
"""GP-Bandit using a Flax model.
Expand Down Expand Up @@ -342,26 +375,23 @@ def _convert_trials_to_arrays(
def _find_best_model_params(
self,
model: sp.StochasticProcessModel,
loss_fn: optimizers.LossFunction,
data: types.StochasticProcessModelData,
seed: jax.random.KeyArray,
) -> tuple[types.ParameterDict, Any]:
"""Perform ARD on the current model to find best model parameters."""
# Run ARD.
setup = functools.partial(
jax.jit(
gp_bandit_utils.stochastic_process_model_setup,
static_argnames=('model',),
),
model=model,
data=data,
)
constraints = sp.get_constraints(model)
if isinstance(self._ard_optimizer, optimizers.JaxoptLbfgsB):
optimize_ard = jax.jit(_optimize_ard, static_argnames='model')
else:
optimize_ard = _optimize_ard

logging.info('Optimizing the loss function...')
run_ard_with_profiling = profiler.record_runtime(
optimize_ard, name_prefix='VizierGPBandit', name='ard', also_log=True
)

# The ARD optimizers JIT the train step/loop internally.
return self._ard_optimizer(setup, loss_fn, seed, constraints=constraints)
logging.info('Optimizing the loss function...')
return run_ard_with_profiling(self._ard_optimizer, model, data, seed)

@profiler.record_runtime(name_prefix='VizierGPBandit', name='compute_state')
def _compute_state(
Expand Down Expand Up @@ -404,25 +434,9 @@ def _compute_state(
dimension_is_missing=dimension_is_missing,
)
model = self._build_model(features)
# TODO: Avoid retracing vmapped loss when loss function API is
# redesigned.
loss_fn = functools.partial(
jax.jit(
gp_bandit_utils.stochastic_process_model_loss_fn,
static_argnames=('model', 'normalize'),
),
model=model,
data=data,
# For SGD, normalize the loss so we can use the same learning rate
# regardless of the number of examples (see
# `OptaxTrainWithRandomRestarts` docstring).
normalize=isinstance(
self._ard_optimizer, optimizers.OptaxTrainWithRandomRestarts
),
)
self._rng, ard_rng = jax.random.split(self._rng, 2)
best_model_params, metrics = self._find_best_model_params(
model, loss_fn, data, ard_rng
model, data, ard_rng
)
# Logging for debugging purposes.
logging.info('Best model parameters: %s', best_model_params)
Expand Down
2 changes: 1 addition & 1 deletion vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@


ensemble_ard_optimizer = optimizers.default_optimizer()
noensemble_ard_optimizer = optimizers.JaxoptScipyLbfgsB(
noensemble_ard_optimizer = optimizers.JaxoptLbfgsB(
optimizers.LbfgsBOptions(random_restarts=5, best_n=1)
)

Expand Down
18 changes: 12 additions & 6 deletions vizier/_src/jax/optimizers/jaxopt_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@ class LbfgsBOptions:
less than or equal to `random_restarts`.
"""

num_line_search_steps: int = struct.field(kw_only=True, default=20)
random_restarts: int = struct.field(kw_only=True, default=4)
num_line_search_steps: int = struct.field(
kw_only=True, default=20, pytree_node=False
)
random_restarts: int = struct.field(
kw_only=True, default=4, pytree_node=False
)
tol: float = struct.field(kw_only=True, default=1e-8)
maxiter: int = struct.field(kw_only=True, default=50)
maxiter: int = struct.field(kw_only=True, default=50, pytree_node=False)
best_n: int = struct.field(kw_only=True, default=1, pytree_node=False)

def __post_init__(self):
Expand Down Expand Up @@ -171,7 +175,7 @@ def best_n(self) -> int:
return self._options.best_n


@attr.define
@struct.dataclass
class JaxoptLbfgsB(core.Optimizer[core.Params]):
"""Jaxopt's L-BFGS-B optimizer.

Expand All @@ -193,8 +197,10 @@ class JaxoptLbfgsB(core.Optimizer[core.Params]):
_speed_test: If True, return speed test results.
"""

_options: LbfgsBOptions = attr.field(default=LbfgsBOptions())
_speed_test: bool = attr.field(kw_only=True, default=False)
_options: LbfgsBOptions = struct.field(default_factory=LbfgsBOptions)
_speed_test: bool = struct.field(
pytree_node=False, kw_only=True, default=False
)

def __call__(
self,
Expand Down