From 63a60a0968f1a267d2fbd60e4dad0142908546c7 Mon Sep 17 00:00:00 2001 From: jamari-morrison Date: Tue, 25 Feb 2025 01:04:04 +0000 Subject: [PATCH 1/3] gsm8k evals --- moondream/eval/gsm8k.py | 138 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 moondream/eval/gsm8k.py diff --git a/moondream/eval/gsm8k.py b/moondream/eval/gsm8k.py new file mode 100644 index 00000000..ebf9d4d0 --- /dev/null +++ b/moondream/eval/gsm8k.py @@ -0,0 +1,138 @@ +import argparse +import re +from datasets import load_dataset +from tqdm import tqdm +import torch + +from ..torch.config import MoondreamConfig +from ..torch.moondream import MoondreamModel +from ..torch.weights import load_weights_into_model + +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def parse_answer(text): + """ + Extracts the last number (integer or float) from a string. + + Args: + text (str): The input string to parse + + Returns: + float or int: The last number found in the string + None: If no number is found + """ + # Find all numbers (integers or floats) in the string + # This regex matches integers and decimal numbers (with or without leading zeros, ignores commas ) + numbers = re.findall(r'[-+]?\d*\.\d+|[-+]?\d+', text.replace(",", "")) + + if not numbers: + return None + + # Get the last number found + last_number = numbers[-1] + + # Convert to the appropriate type (int or float) + if '.' in last_number: + return float(last_number) + else: + return int(last_number) + + +def eval_gsm8k(model, debug=False): + """Evaluate the model on the GSM8K dataset.""" + # Load the GSM8K test dataset + gsm8k_test = load_dataset("openai/gsm8k", "main", split="test") + + correct = 0 + total = 0 + results = [] + + for row in tqdm(gsm8k_test, disable=debug, desc="GSM8K"): + + question = row["question"] + + + # Extract the ground truth answer from the answer field, just the number + gt_answer = row["answer"].split("####")[-1].strip() + + if gt_answer is None or not gt_answer: + if debug: + print(f"Warning: Could not parse ground truth answer from: {row['answer']}") + continue + + # Encode the question for the model + model_response = model._text_query(question)["answer"] + + model_answer = parse_answer(model_response) + + # Convert to float for comparison (handling both integers and decimals) + try: + gt_answer_float = float(gt_answer) + if model_answer is not None: + try: + model_answer_float = float(model_answer) + is_correct = abs(model_answer_float - gt_answer_float) < 1e-6 + except: + is_correct = False + print(f'failed to compute model answer float: {model_answer}, slotting in large negative.') + else: + is_correct = False + except ValueError: + is_correct = False + + if is_correct: + correct += 1 + total += 1 + + result = { + "question": question, + "ground_truth": gt_answer, + "model_response": model_response, + "model_answer": model_answer, + "correct": is_correct + } + results.append(result) + + if debug: + print(f"Question: {question}") + print(f"Ground Truth Answer: {gt_answer}") + print(f"Model Response: {model_response}") + print(f"Model Answer: {model_answer}") + print(f"Correct: {is_correct}") + print(f"Current Accuracy: {correct/total:.4f}") + print("---------") + + accuracy = correct / total if total > 0 else 0 + + return { + "accuracy": accuracy, + "correct": correct, + "total": total, + "results": results + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + + if torch.cuda.is_available(): + torch.set_default_device("cuda") + elif torch.backends.mps.is_available(): + torch.set_default_device("mps") + + config = MoondreamConfig() + model = MoondreamModel(config) + + + load_weights_into_model(args.model, model) + + # Compile omitted to make text only work + # model.compile() + + result = eval_gsm8k(model, args.debug) + + print(f"Accuracy: {result['accuracy']:.4f} ({result['correct']}/{result['total']})") \ No newline at end of file From 8ded1d64df4a90df28cf0a3e0b6c7cec0d954320 Mon Sep 17 00:00:00 2001 From: jamari-morrison Date: Tue, 25 Feb 2025 01:05:01 +0000 Subject: [PATCH 2/3] text only query support --- moondream/torch/moondream.py | 37 +++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 6baf2d1b..12bb722f 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -314,6 +314,41 @@ def generator(): else: return {"answer": "".join(list(generator()))} + + def _text_query( + self, + question: str, + stream: bool = False, + settings: Optional[SamplingSettings] = None, + ): + if self.config.tokenizer.templates["query"] is None: + raise NotImplementedError("Model does not support querying.") + + + prompt_tokens = torch.tensor( + [ + self.config.tokenizer.templates["query"]["prefix"] + + self.tokenizer.encode(question).ids + + self.config.tokenizer.templates["query"]["suffix"] + ], + device=self.device, + ) + + max_tokens = DEFAULT_MAX_TOKENS + if settings: + max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS) + + pos = 0 + + def generator(): + for token in self._generate_text(prompt_tokens, pos, max_tokens): + yield token + + if stream: + return {"answer": generator()} + else: + return {"answer": "".join(list(generator()))} + def load_encoded_image(self, encoded_image: EncodedImage): for b, (k, v) in zip(self.text.blocks, encoded_image.caches): b.kv_cache.k_cache[:, :, : k.size(2), :] = k @@ -622,4 +657,4 @@ def detect_gaze( sum(gaze[1] for gaze in detections) / len(detections), ) - return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}} + return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}} \ No newline at end of file From c6992e780d305056827f69844a6acdcd4968ba1f Mon Sep 17 00:00:00 2001 From: jamari-morrison Date: Tue, 25 Feb 2025 01:18:27 +0000 Subject: [PATCH 3/3] apply formatting --- moondream/eval/gsm8k.py | 52 +++++++++++++++++++----------------- moondream/torch/moondream.py | 6 ++--- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/moondream/eval/gsm8k.py b/moondream/eval/gsm8k.py index ebf9d4d0..0c542e08 100644 --- a/moondream/eval/gsm8k.py +++ b/moondream/eval/gsm8k.py @@ -14,26 +14,26 @@ def parse_answer(text): """ Extracts the last number (integer or float) from a string. - + Args: text (str): The input string to parse - + Returns: float or int: The last number found in the string None: If no number is found """ # Find all numbers (integers or floats) in the string # This regex matches integers and decimal numbers (with or without leading zeros, ignores commas ) - numbers = re.findall(r'[-+]?\d*\.\d+|[-+]?\d+', text.replace(",", "")) - + numbers = re.findall(r"[-+]?\d*\.\d+|[-+]?\d+", text.replace(",", "")) + if not numbers: return None - + # Get the last number found last_number = numbers[-1] - + # Convert to the appropriate type (int or float) - if '.' in last_number: + if "." in last_number: return float(last_number) else: return int(last_number) @@ -52,20 +52,21 @@ def eval_gsm8k(model, debug=False): question = row["question"] - # Extract the ground truth answer from the answer field, just the number gt_answer = row["answer"].split("####")[-1].strip() - + if gt_answer is None or not gt_answer: if debug: - print(f"Warning: Could not parse ground truth answer from: {row['answer']}") + print( + f"Warning: Could not parse ground truth answer from: {row['answer']}" + ) continue - + # Encode the question for the model model_response = model._text_query(question)["answer"] model_answer = parse_answer(model_response) - + # Convert to float for comparison (handling both integers and decimals) try: gt_answer_float = float(gt_answer) @@ -75,25 +76,27 @@ def eval_gsm8k(model, debug=False): is_correct = abs(model_answer_float - gt_answer_float) < 1e-6 except: is_correct = False - print(f'failed to compute model answer float: {model_answer}, slotting in large negative.') + print( + f"failed to compute model answer float: {model_answer}, slotting in large negative." + ) else: is_correct = False except ValueError: is_correct = False - + if is_correct: correct += 1 total += 1 - + result = { "question": question, "ground_truth": gt_answer, "model_response": model_response, "model_answer": model_answer, - "correct": is_correct + "correct": is_correct, } results.append(result) - + if debug: print(f"Question: {question}") print(f"Ground Truth Answer: {gt_answer}") @@ -102,14 +105,14 @@ def eval_gsm8k(model, debug=False): print(f"Correct: {is_correct}") print(f"Current Accuracy: {correct/total:.4f}") print("---------") - + accuracy = correct / total if total > 0 else 0 - + return { "accuracy": accuracy, "correct": correct, "total": total, - "results": results + "results": results, } @@ -118,21 +121,20 @@ def eval_gsm8k(model, debug=False): parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() - + if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") - + config = MoondreamConfig() model = MoondreamModel(config) - load_weights_into_model(args.model, model) # Compile omitted to make text only work # model.compile() result = eval_gsm8k(model, args.debug) - - print(f"Accuracy: {result['accuracy']:.4f} ({result['correct']}/{result['total']})") \ No newline at end of file + + print(f"Accuracy: {result['accuracy']:.4f} ({result['correct']}/{result['total']})") diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py index 12bb722f..d391d021 100644 --- a/moondream/torch/moondream.py +++ b/moondream/torch/moondream.py @@ -314,7 +314,6 @@ def generator(): else: return {"answer": "".join(list(generator()))} - def _text_query( self, question: str, @@ -324,7 +323,6 @@ def _text_query( if self.config.tokenizer.templates["query"] is None: raise NotImplementedError("Model does not support querying.") - prompt_tokens = torch.tensor( [ self.config.tokenizer.templates["query"]["prefix"] @@ -337,7 +335,7 @@ def _text_query( max_tokens = DEFAULT_MAX_TOKENS if settings: max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS) - + pos = 0 def generator(): @@ -657,4 +655,4 @@ def detect_gaze( sum(gaze[1] for gaze in detections) / len(detections), ) - return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}} \ No newline at end of file + return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}