From 2dea196991d1f8e10ad489db1a55c2c50fd5279d Mon Sep 17 00:00:00 2001 From: vizier-team Date: Fri, 4 Apr 2025 14:16:57 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 744067296 --- vizier/_src/algorithms/designers/gp/yjt.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vizier/_src/algorithms/designers/gp/yjt.py b/vizier/_src/algorithms/designers/gp/yjt.py index 08b7cb831..e5b2437c3 100644 --- a/vizier/_src/algorithms/designers/gp/yjt.py +++ b/vizier/_src/algorithms/designers/gp/yjt.py @@ -79,16 +79,24 @@ def optimal_transformation( lambdas = preprocessing.PowerTransformer( method, standardize=False).fit(data).lambdas_.astype(dtype) - logging.info('Optimal lambda was: %s', lambdas) + logging.info('Optimal lambda was: %s, %s', lambdas, lambdas.dtype) if dimension == 1: # Make it a scalar, so we don't end up with batch_shape = [1] in the # bijector. lambdas = lambdas.item() if method == 'yeo-johnson': - warp = tfsb.YeoJohnson(lambdas) + # Cast the default values of `rho` and `shift` to the same dtype as `data` + # to avoid dtype mismatch errors. + warp = tfsb.YeoJohnson( + lambdas, rho=np.asarray(2.0, dtype=dtype), shift=np.asarray(1.0, dtype) + ) elif method == 'box-cox': - warp = tfsb.YeoJohnson(lambdas, shift=.0) + # Cast the default values of `rho` and `shift` to the same dtype as `data` + # to avoid dtype mismatch errors. + warp = tfsb.YeoJohnson( + lambdas, rho=np.asarray(2.0, dtype), shift=np.asarray(0.0, dtype) + ) else: raise ValueError(f'Unknown method: {method}')