From 70b22d6b03cf91fe518c55cd7621c9baa1e28bd6 Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Mon, 30 Oct 2023 09:22:06 -1000 Subject: [PATCH 1/5] Add Answer Layout Criteria --- prompting/validators/criteria.py | 76 ++++++++++++++++++++++++++++++++ prompting/validators/tasks.py | 8 +++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index 23afbc3..343f002 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -17,6 +17,9 @@ # DEALINGS IN THE SOFTWARE. import re import torch +import json +import yaml +import ast from dataclasses import dataclass from abc import ABC, abstractmethod from typing import List @@ -94,3 +97,76 @@ def evaluate(self, completions: List[str]) -> torch.FloatTensor: def compose_text(self) -> str: return self.text.format(target_length=self.target_length, unit=self.unit.value) + +class LayoutTypeEnum(Enum): + JSON = "json" + YAML = "yaml" + DICTIONARY = "python dictionary" + NUMBEREDLIST = "numbered list" + BULLETPOINTLIST = "bullet point list" + +@dataclass +class MatchLayoutCriteria(TaskCriterion): + text: str = "Your response must be in the form of a {format_type}{w}{fields}" + penalty: float = 0.1 + format_type: LayoutTypeEnum = LayoutTypeEnum.JSON + num_fields: int = 0 + fields : str = " with {num_fields} fields" + + def is_json(text): + try: + json.loads(text) + return True + except ValueError: + return False + + def is_yaml(text): + try: + yaml.safe_load(text) + return True + except yaml.YAMLError: + return False + + def is_dictionary(text): + try: + if type(ast.literal_eval(text)) == dict: + return True + else: + return False + except (ValueError, SyntaxError): + return False + + def is_numbered_list(input_string): + pattern = r'^\d.*\n?$' + lines = input_string.split('\n') + return all(re.match(pattern, line) for line in lines) + + def is_bullet_point_list(input_string): + pattern = r'^\s*[-*+]\s.*\n?(\s*[-*+]\s.*\n?)*$' + return bool(re.match(pattern, input_string)) + + def _get_format_match(self, response : str) -> bool: + if self.format_type == LayoutTypeEnum.JSON: + return self.is_json(response) + elif self.format_type == LayoutTypeEnum.YAML: + return self.is_yaml(response) + elif self.format_type == LayoutTypeEnum.DICTIONARY: + return self.is_dictionary(response) + elif self.format_type == LayoutTypeEnum.NUMBEREDLIST: + return self.is_numbered_list(response) + elif self.format_type == LayoutTypeEnum.BULLETPOINTLIST: + return self.is_bullet_point_list(response) + else: + return False + + def evaluate(self, completions: list[str]) -> torch.FloatTensor: + penalties = torch.zeros(len(completions), dtype = torch.float32) + for idx, completion in enumerate(completions): + if not self._get_format_match(completion): + penalties[idx] = self.penalty + return penalties + + def compose_text(self) -> str: + if self.num_fields == 0: + return self.text.format(format_type = self.format_type.value, fields = "") + return self.text.format(format_type = self.format_type.value, fields = self.fields) diff --git a/prompting/validators/tasks.py b/prompting/validators/tasks.py index 2799169..023b96e 100644 --- a/prompting/validators/tasks.py +++ b/prompting/validators/tasks.py @@ -25,6 +25,8 @@ TaskCriterion, MatchLengthCriteria, TextLengthUnitEnum, + MatchLayoutCriteria, + LayoutTypeEnum, ) @@ -169,8 +171,12 @@ def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask: target_length=random.randint(4, 8), unit=TextLengthUnitEnum.SENTENCES, ) + match_layout_criteria = MatchLayoutCriteria( + penalty = 0.1, + target_layout = LayoutTypeEnum.JSON + ) - criteria = [match_words_criteria, match_length_criteria] + criteria = [match_words_criteria, match_length_criteria, match_layout_criteria] return QuestionAnswerTask( base_text=base_text, From a99c62555f696a6ad005a1c17ce0b22b331b79b3 Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Tue, 31 Oct 2023 08:37:03 -1000 Subject: [PATCH 2/5] Add Layout Matching Criteria --- prompting/validators/criteria.py | 15 +++++++++++---- prompting/validators/tasks.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index 343f002..f8003c4 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -18,7 +18,8 @@ import re import torch import json -import yaml +import random +#import yaml import ast from dataclasses import dataclass from abc import ABC, abstractmethod @@ -105,6 +106,12 @@ class LayoutTypeEnum(Enum): NUMBEREDLIST = "numbered list" BULLETPOINTLIST = "bullet point list" + def _select_random_attribute(self): + attribute_names = [attr for attr in vars(self) if not attr.startswith('_')] + random_attribute_name = random.choice(attribute_names) + random_attribute_value = getattr(self, random_attribute_name) + return random_attribute_value + @dataclass class MatchLayoutCriteria(TaskCriterion): text: str = "Your response must be in the form of a {format_type}{w}{fields}" @@ -122,9 +129,9 @@ def is_json(text): def is_yaml(text): try: - yaml.safe_load(text) + #yaml.safe_load(text) return True - except yaml.YAMLError: + except: #yaml.YAMLError: return False def is_dictionary(text): @@ -158,7 +165,7 @@ def _get_format_match(self, response : str) -> bool: return self.is_bullet_point_list(response) else: return False - + def evaluate(self, completions: list[str]) -> torch.FloatTensor: penalties = torch.zeros(len(completions), dtype = torch.float32) for idx, completion in enumerate(completions): diff --git a/prompting/validators/tasks.py b/prompting/validators/tasks.py index 023b96e..d37d91c 100644 --- a/prompting/validators/tasks.py +++ b/prompting/validators/tasks.py @@ -173,7 +173,7 @@ def create_qa_task(base_text: str, index: int) -> QuestionAnswerTask: ) match_layout_criteria = MatchLayoutCriteria( penalty = 0.1, - target_layout = LayoutTypeEnum.JSON + target_layout = LayoutTypeEnum._select_random_attribute(LayoutTypeEnum), ) criteria = [match_words_criteria, match_length_criteria, match_layout_criteria] From eaf8886513c0cc6772b13c34cd679384d9023339 Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Tue, 31 Oct 2023 08:45:48 -1000 Subject: [PATCH 3/5] Remove Yaml format --- prompting/validators/criteria.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index f8003c4..84a706d 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -101,7 +101,6 @@ def compose_text(self) -> str: class LayoutTypeEnum(Enum): JSON = "json" - YAML = "yaml" DICTIONARY = "python dictionary" NUMBEREDLIST = "numbered list" BULLETPOINTLIST = "bullet point list" @@ -126,13 +125,6 @@ def is_json(text): return True except ValueError: return False - - def is_yaml(text): - try: - #yaml.safe_load(text) - return True - except: #yaml.YAMLError: - return False def is_dictionary(text): try: @@ -155,8 +147,6 @@ def is_bullet_point_list(input_string): def _get_format_match(self, response : str) -> bool: if self.format_type == LayoutTypeEnum.JSON: return self.is_json(response) - elif self.format_type == LayoutTypeEnum.YAML: - return self.is_yaml(response) elif self.format_type == LayoutTypeEnum.DICTIONARY: return self.is_dictionary(response) elif self.format_type == LayoutTypeEnum.NUMBEREDLIST: From ef0c7b9d20369a7611ba9daac2d8865c07e4eacc Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Tue, 31 Oct 2023 08:47:03 -1000 Subject: [PATCH 4/5] Update criteria.py --- prompting/validators/criteria.py | 1 - 1 file changed, 1 deletion(-) diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index 84a706d..e46381a 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -19,7 +19,6 @@ import torch import json import random -#import yaml import ast from dataclasses import dataclass from abc import ABC, abstractmethod From 6036ab8a3e9fc5c9ba5219822d0abd493c609bd2 Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Wed, 1 Nov 2023 11:11:00 -1000 Subject: [PATCH 5/5] Remove unused variable and add instance passing --- prompting/validators/criteria.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/prompting/validators/criteria.py b/prompting/validators/criteria.py index e46381a..030c243 100644 --- a/prompting/validators/criteria.py +++ b/prompting/validators/criteria.py @@ -112,20 +112,20 @@ def _select_random_attribute(self): @dataclass class MatchLayoutCriteria(TaskCriterion): - text: str = "Your response must be in the form of a {format_type}{w}{fields}" + text: str = "Your response must be in the form of a {format_type}{fields}" penalty: float = 0.1 format_type: LayoutTypeEnum = LayoutTypeEnum.JSON num_fields: int = 0 fields : str = " with {num_fields} fields" - def is_json(text): + def is_json(self, text): try: json.loads(text) return True except ValueError: return False - def is_dictionary(text): + def is_dictionary(self, text): try: if type(ast.literal_eval(text)) == dict: return True @@ -134,13 +134,13 @@ def is_dictionary(text): except (ValueError, SyntaxError): return False - def is_numbered_list(input_string): + def is_numbered_list(self, input_string): pattern = r'^\d.*\n?$' lines = input_string.split('\n') return all(re.match(pattern, line) for line in lines) - def is_bullet_point_list(input_string): - pattern = r'^\s*[-*+]\s.*\n?(\s*[-*+]\s.*\n?)*$' + def is_bullet_point_list(self, input_string): + pattern = r'^\s*[-*+]\s*.*\n?(\s*[-*+]\s.*\n?)*$' return bool(re.match(pattern, input_string)) def _get_format_match(self, response : str) -> bool: