From 641b3a0bc270609bc54849aa378416e42b4ec9ab Mon Sep 17 00:00:00 2001 From: Richard Zhang Date: Fri, 13 Oct 2023 09:40:24 -0700 Subject: [PATCH] Add string parameter to metadata conversion support. PiperOrigin-RevId: 573243721 --- .../pyvizier/converters/string_converters.py | 68 +++++++++++++++++ .../converters/string_converters_test.py | 73 +++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 vizier/pyvizier/converters/string_converters.py create mode 100644 vizier/pyvizier/converters/string_converters_test.py diff --git a/vizier/pyvizier/converters/string_converters.py b/vizier/pyvizier/converters/string_converters.py new file mode 100644 index 000000000..29f3888ee --- /dev/null +++ b/vizier/pyvizier/converters/string_converters.py @@ -0,0 +1,68 @@ +# Copyright 2023 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +"""Converter utils for parameters for free-form strings.""" + +from typing import Sequence +import copy +import json +import attrs +from vizier import pyvizier as vz + +_METADATA_VERSION = '0.0.1a' +PROMPT_TUNING_NS = 'prompt_tuning' + + +@attrs.define +class PromptTuningConfig: + """Variables and utils for configuring prompt tuning.""" + + default_prompts: dict[str, str] = attrs.field(factory=dict) + + def augment_problem( + self, problem: vz.ProblemStatement + ) -> vz.ProblemStatement: + """Augments problem statement to enable for prompt tuning.""" + for k, v in self.default_prompts.items(): + problem.search_space.root.add_categorical_param(k, [v], default_value=v) + problem.metadata.ns(PROMPT_TUNING_NS)['version'] = _METADATA_VERSION + return problem + + def to_prompt_trials(self, trials: Sequence[vz.Trial]) -> Sequence[vz.Trial]: + """Convert to prompt Trial via metadata to string valued parameters.""" + prompt_trials = copy.deepcopy(trials) + for trial in prompt_trials: + prompt_values = json.loads(trial.metadata.ns(PROMPT_TUNING_NS)['values']) + for k in self.default_prompts.keys(): + if k in prompt_values: + trial.parameters[k] = prompt_values[k] + return prompt_trials + + def to_valid_suggestions( + self, suggestions: Sequence[vz.TrialSuggestion] + ) -> Sequence[vz.TrialSuggestion]: + """Returns TrialSuggestions that are valid in the augmented problem.""" + valid_suggestions = copy.deepcopy(suggestions) + for suggestion in valid_suggestions: + prompt_values = {} + for k, default_value in self.default_prompts.items(): + prompt_values[k] = suggestion.parameters[k].value + suggestion.parameters[k] = default_value + suggestion.metadata.ns(PROMPT_TUNING_NS)['values'] = json.dumps( + prompt_values + ) + suggestion.metadata.ns(PROMPT_TUNING_NS)['version'] = _METADATA_VERSION + return valid_suggestions diff --git a/vizier/pyvizier/converters/string_converters_test.py b/vizier/pyvizier/converters/string_converters_test.py new file mode 100644 index 000000000..0a276c665 --- /dev/null +++ b/vizier/pyvizier/converters/string_converters_test.py @@ -0,0 +1,73 @@ +# Copyright 2023 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json + +from vizier import pyvizier as vz +from vizier.pyvizier.converters import string_converters + +from absl.testing import absltest + + +class StringConvertersTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.default_prompts = {'prompt1': 'def1', 'prompt2': 'def2'} + self.config = string_converters.PromptTuningConfig( + default_prompts=self.default_prompts + ) + + def test_augment_problem(self): + problem = vz.ProblemStatement() + tuning_problem = self.config.augment_problem(problem) + for k, v in self.default_prompts.items(): + pconfig = tuning_problem.search_space.get(k) + self.assertCountEqual(pconfig.feasible_values, [v]) + self.assertEqual(pconfig.default_value, v) + + def test_prompt_trials(self): + trial = vz.Trial(parameters={'int': 3, 'float': 1.2, 'cat': 'test'}) + trial.metadata.ns(string_converters.PROMPT_TUNING_NS)['values'] = ( + json.dumps({'prompt1': 'test1', 'prompt2': 'test2'}) + ) + results = self.config.to_prompt_trials([trial]) + self.assertLen(results, 1) + + self.assertStartsWith(results[0].parameters['prompt1'].value, 'test1') + self.assertEqual(results[0].parameters['prompt2'].value, 'test2') + + def test_valid_suggestions(self): + problem = vz.ProblemStatement() + tuning_problem = self.config.augment_problem(problem) + suggestion = vz.TrialSuggestion( + parameters={'prompt1': 'test1', 'prompt2': 'test2'} + ) + valid_suggestions = self.config.to_valid_suggestions([suggestion]) + self.assertLen(valid_suggestions, 1) + valid_suggestion = valid_suggestions[0] + self.assertTrue( + tuning_problem.search_space.contains(valid_suggestion.parameters) + ) + + # Test the reverse conversion retrieves original parameters. + trial = valid_suggestion.to_trial().complete(vz.Measurement()) + prompt_trial = self.config.to_prompt_trials([trial])[0] + self.assertCountEqual(prompt_trial.parameters, suggestion.parameters) + + +if __name__ == '__main__': + absltest.main()