From 0d3b8ee50dc41556e62637b5a2fd79a1492cfc51 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 21 Nov 2024 17:48:59 -0800 Subject: [PATCH 1/6] feat: dpo dataset new openai chat completion format Signed-off-by: Terry Kong --- examples/nlp/gpt/train_gpt_sft.py | 83 +------------- nemo_aligner/data/nlp/datasets.py | 103 ++++++++++++++++-- .../data/nlp/scripts/undo_special_tokens.py | 68 ++++++++++++ setup/requirements.txt | 1 + tests/test_datasets.py | 56 ++++++++++ 5 files changed, 226 insertions(+), 85 deletions(-) create mode 100644 nemo_aligner/data/nlp/scripts/undo_special_tokens.py diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index 371c0f5aa..676ff7a7c 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -39,7 +39,7 @@ resolve_and_create_trainer, retrieve_custom_trainer_state_dict, ) -from nemo_aligner.utils.utils import load_from_nemo +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo """Script to start SFT training""" @@ -49,75 +49,10 @@ mp.set_start_method("spawn", force=True) -def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): - """ - This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg). - The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. - """ - OmegaConf.set_struct(gpt_cfg, True) - OmegaConf.resolve(cfg) - with open_dict(gpt_cfg): - gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False) - gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size - gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size - gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False) - gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) - gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) - gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) - gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get( - "activations_checkpoint_layers_per_pipeline", None - ) - gpt_cfg.peft = cfg.model.peft - gpt_cfg.data = cfg.model.data - gpt_cfg.optim = cfg.model.optim - gpt_cfg.precision = cfg.trainer.precision - gpt_cfg.answer_only_loss = cfg.model.answer_only_loss - gpt_cfg.restore_from_path = cfg.model.restore_from_path - gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint - gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end - gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view - gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0) - gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0) - gpt_cfg.ffn_dropout = cfg.model.ffn_dropout - gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False) - # if TP/PP size is -1, use default TP/PP size as original model - if cfg.model.get("tensor_model_parallel_size", 1) > 0: - gpt_cfg.tensor_model_parallel_size = cfg.model.get("tensor_model_parallel_size", 1) - if cfg.model.get("pipeline_model_parallel_size", 1) > 0: - gpt_cfg.pipeline_model_parallel_size = cfg.model.get("pipeline_model_parallel_size", 1) - gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get("pipeline_model_parallel_split_rank", 0) - - if cfg.model.data.get("chat", False): - # chat model, overwrite the prompt template - prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens) - gpt_cfg.data.train_ds.prompt_template = prompt_template - gpt_cfg.data.validation_ds.prompt_template = prompt_template - - sft_cls = GPTSFTModel - gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}" - - if cfg.model.get("use_flash_attention", None) is not None: - gpt_cfg.use_flash_attention = cfg.model.use_flash_attention - - if cfg.model.get("seq_len_interpolation_factor", None) is not None: - gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor - - if cfg.model.get("dist_ckpt_load_strictness", None) is not None: - gpt_cfg.dist_ckpt_load_strictness = cfg.model.dist_ckpt_load_strictness - - gpt_cfg.inference = cfg.model.get("inference", {}) - - # This is needed when modifying a hparam file directly to load `.ckpt` files. - # This is not needed to modify the cfg in `.nemo` files. - if add_cfg_to_tree: - OmegaConf.resolve(gpt_cfg) - gpt_cfg.cfg = gpt_cfg - - return gpt_cfg - - @hydra_runner(config_path="conf", config_name="gpt_sft") def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model) + logging.info("\n\n************** Experiment configuration ***********") logging.info(f"\n{OmegaConf.to_yaml(cfg)}") @@ -129,17 +64,11 @@ def main(cfg) -> None: with open_dict(cfg): cfg.model.precision = cfg.trainer.precision - ptl_model, updated_cfg = load_from_nemo( - GPTSFTModel, - cfg, - trainer, - strict=True, - modify_config_fn=_modify_config, - restore_path=cfg.model.restore_from_path, - return_updated_cfg=True, + ptl_model = load_from_nemo( + GPTSFTModel, cfg, trainer, strict=True, restore_path=cfg.model.restore_from_path, return_updated_cfg=False, ) - init_peft(ptl_model, updated_cfg) + init_peft(ptl_model, cfg.model) with open_dict(cfg): # overwrite the model config with the config from the checkpoint diff --git a/nemo_aligner/data/nlp/datasets.py b/nemo_aligner/data/nlp/datasets.py index b605942b0..a07bf61a1 100644 --- a/nemo_aligner/data/nlp/datasets.py +++ b/nemo_aligner/data/nlp/datasets.py @@ -15,13 +15,19 @@ """Custom datasets for RLHF training""" import os +from typing import Dict, List import numpy as np import scipy import torch +from omegaconf import OmegaConf from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import _create_ltor_masks_and_position_ids -from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset +from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import ( + GPTSFTChatDataset, + _get_header_conversation_type_mask_role, + get_prompt_template_example, +) from nemo.core import Dataset from nemo.utils import logging @@ -344,16 +350,97 @@ def encode(self, text, append_eod=False): return text_ids, len(text_ids) + @staticmethod + def _convert_messages( + input_list: List[Dict[str, str]] + ) -> Dict: # TODO: (@adithyare) this method should live elsewhare.. + """ + args: + input_list: is a list of dicts in the openai format + for example: + [{"role": "system", "content": "you are helpful}, + {"role": "user", "content": "Why is the sky blue?"}, + {"role": "assistant", "content": "Because blablabla"}, + ...] + returns: + output_dict: a dict in nemo's format {"system": "sytem prompt", + "conversation": [], + ... + } + """ + output_dict = { + "system": "", + "conversations": [], + "mask": "User", + "type": "VALUE_TO_TEXT", + } + + # Extract the system message + num_system_msg = 0 + for msg in input_list: + if msg["role"] == "system": + output_dict["system"] = msg["content"] + num_system_msg += 1 + if num_system_msg > 1: + raise RuntimeError("Multiple system messages seen, please consolidate into a single system message.") + + # Build the conversations list + for msg in input_list: + if msg["role"] != "system": + conversation_entry = { + "from": msg["role"].capitalize(), # Capitalize 'user' and 'assistant' + "value": msg["content"], + "label": None, + } + output_dict["conversations"].append(conversation_entry) + + return output_dict + + def convert(self, messages): + """ + args: + messages: is a list of dicts in the openai format + for example: + [{"role": "system", "content": "you are helpful}, + {"role": "user", "content": "Why is the sky blue?"}, + {"role": "assistant", "content": "Because blablabla"}, + ...] + returns: + conversation: is a string formatted with the chat template + """ + if OmegaConf.select(self.cfg, "data.chat_prompt_tokens") is None: + raise RuntimeError( + "You don't have a model (model_config.yaml) which has chat_prompt_tokens, are you sure this is a Chat/Instruction model?" + ) + special_tokens = self.cfg.data.chat_prompt_tokens + nemo_source = self._convert_messages(messages) + header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role( + nemo_source, special_tokens + ) + return conversation + def __getitem__(self, idx): """Returns a pair of chosen/rejected pairs, their respective lengths, and labels.""" payload = self.data[idx] - prompt, prompt_len = self.encode(payload["prompt"], append_eod=False) - chosen, chosen_len = self.encode( - payload["prompt"] + payload["chosen_response"], append_eod=self.cfg.data.get("append_eod", False) - ) - reject, reject_len = self.encode( - payload["prompt"] + payload["rejected_response"], append_eod=self.cfg.data.get("append_eod", False) - ) + + if isinstance(payload["prompt"], str): + # (@adithyare) format with hardcoded chat tokens + # will allow this for the time being. + prompt_fmtd = payload["prompt"] + chosen_fmtd = payload["prompt"] + payload["chosen_response"] + rejected_fmtd = payload["prompt"] + payload["rejected_response"] + logging.warning( + "Pre-formatting chat conversation as string with hardcoded chat tokens will be deprecated." + ) # (@adithyare) this will spam the console for now. + else: + prompt_fmtd = self.convert(payload["prompt"]) # (@adithyare) read var as "prompt formatted" + chosen_fmtd = self.convert(payload["prompt"] + [payload["chosen_response"]]) + rejected_fmtd = self.convert(payload["prompt"] + [payload["rejected_response"]]) + + prompt, prompt_len = self.encode(prompt_fmtd, append_eod=False) + chosen, chosen_len = self.encode(chosen_fmtd, append_eod=self.cfg.data.get("append_eod", False)) + reject, reject_len = self.encode(rejected_fmtd, append_eod=self.cfg.data.get("append_eod", False)) + # chosen_response_only, chosen_response_len = self.encode(payload['chosen_response']) # reject_response_only, reject_response_len = self.encode(payload['rejected_response']) chosen_labels = ([-100] * prompt_len) + chosen[prompt_len:] diff --git a/nemo_aligner/data/nlp/scripts/undo_special_tokens.py b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py new file mode 100644 index 000000000..591fabd6a --- /dev/null +++ b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to remove special tokens from dpo datasets +and convert them into list of messages format""" + +import argparse +import json +import re + + +def format_conversation(input_string): + # Define roles and patterns + role_patterns = {"System": "system", "User": "user", "Assistant": "assistant"} + + # Initialize an empty output list + conversation = [] + + # Use regex to find each segment's role and content + segments = re.findall(r"([^\n]+)\n(.*?)((?=")[0] + rejected = rejected.split("\n")[0] + chosen_message = {"role": empty_assistant["role"], "content": chosen} + rejected_message = {"role": empty_assistant["role"], "content": rejected} + j_out = { + "prompt": undo_spl_prompt, + "chosen_response": chosen_message, + "rejected_response": rejected_message, + "chosen_reward": j["chosen_reward"], + "rejected_reward": j["rejected_reward"], + } + w.write(json.dumps(j_out) + "\n") diff --git a/setup/requirements.txt b/setup/requirements.txt index 198d2e07a..d074f3672 100644 --- a/setup/requirements.txt +++ b/setup/requirements.txt @@ -1,3 +1,4 @@ +Jinja2~=3.1.4 jsonlines megatron_core>=0.8 nemo_toolkit[nlp] diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 01425357b..701ff33e5 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -22,8 +22,16 @@ from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo_aligner.algorithms.dpo import dpo_custom_collate from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets +from nemo_aligner.data.nlp.scripts.undo_special_tokens import format_conversation from nemo_aligner.utils import parallel_state +try: + from jinja2 import Template + + HAS_JINJA2 = True +except: + HAS_JINJA2 = False + @pytest.fixture def llama3_tokenizer(): @@ -136,6 +144,54 @@ def test_dpo_loader(init_model_parallel, make_tmp_jsonl, llama3_tokenizer): assert num_mini_batches == 2 +@pytest.mark.run_only_on("GPU") +def test_dpo_dataset_conversion(): + prompt_str = """System\n\nUser\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\nAssistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\nUser\nCould you explain in detail both the advantages and disadvantages from different perspectives?\nAssistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\nUser\ncould you add more in a table\nAssistant\n""" + + expected_oai_messages = [ + {"role": "system", "content": ""}, + { + "role": "user", + "content": "Does GPT-4 use RNN or Transformer models, and which one is better for this type of project?", + }, + { + "role": "assistant", + "content": "GPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.", + }, + { + "role": "user", + "content": "Could you explain in detail both the advantages and disadvantages from different perspectives?", + }, + { + "role": "assistant", + "content": """Yes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.""", + }, + {"role": "user", "content": "could you add more in a table"}, + {"role": "assistant", "content": ""}, + ] + + oai_messages_prompt = format_conversation(prompt_str) + assert expected_oai_messages == oai_messages_prompt + + if HAS_JINJA2: + # (@adithyare) bonus test! convert oai style messages back into a string using Jinja + + def remove_trailing(s, t): + if s.endswith(t): + s = s[: -len(t)] + return s + + jinja_template = """{% for message in conversation %}{%- if message.role == "system" -%}System\n{{ message.content }}\n{% elif message.role == "user" -%}User\n{{ message.content }}\n{% elif message.role == "assistant" -%}Assistant\n{{ message.content }}\n{% endif %}{% endfor %}""" + jinja_template = Template(jinja_template) + prompt_str_jinja_rendered = jinja_template.render(conversation=oai_messages_prompt) + prompt_str_jinja_rendered = remove_trailing( + prompt_str_jinja_rendered, "\n" + ) # (@adithyare) jinja will add the ending of message token which we should remove to make a prompt. + assert prompt_str == prompt_str_jinja_rendered + + return True + + @pytest.mark.run_only_on("GPU") def test_dpo_loader_original(init_model_parallel, make_tmp_jsonl, llama3_tokenizer): init_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) From db3eb40550eaacfd3e4323c981c7793bb7db0c7d Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Tue, 3 Dec 2024 08:56:28 -0800 Subject: [PATCH 2/6] Update test_datasets.py --- tests/test_datasets.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 701ff33e5..d8c0e3b41 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -189,8 +189,6 @@ def remove_trailing(s, t): ) # (@adithyare) jinja will add the ending of message token which we should remove to make a prompt. assert prompt_str == prompt_str_jinja_rendered - return True - @pytest.mark.run_only_on("GPU") def test_dpo_loader_original(init_model_parallel, make_tmp_jsonl, llama3_tokenizer): From adb8130c1311e1cd8e07d33885141af7c17252d0 Mon Sep 17 00:00:00 2001 From: arendu Date: Tue, 3 Dec 2024 23:20:02 +0000 Subject: [PATCH 3/6] updated to use importskip Signed-off-by: arendu --- tests/test_datasets.py | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index d8c0e3b41..334ec05fe 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -25,13 +25,6 @@ from nemo_aligner.data.nlp.scripts.undo_special_tokens import format_conversation from nemo_aligner.utils import parallel_state -try: - from jinja2 import Template - - HAS_JINJA2 = True -except: - HAS_JINJA2 = False - @pytest.fixture def llama3_tokenizer(): @@ -173,21 +166,25 @@ def test_dpo_dataset_conversion(): oai_messages_prompt = format_conversation(prompt_str) assert expected_oai_messages == oai_messages_prompt - if HAS_JINJA2: - # (@adithyare) bonus test! convert oai style messages back into a string using Jinja - - def remove_trailing(s, t): - if s.endswith(t): - s = s[: -len(t)] - return s - - jinja_template = """{% for message in conversation %}{%- if message.role == "system" -%}System\n{{ message.content }}\n{% elif message.role == "user" -%}User\n{{ message.content }}\n{% elif message.role == "assistant" -%}Assistant\n{{ message.content }}\n{% endif %}{% endfor %}""" - jinja_template = Template(jinja_template) - prompt_str_jinja_rendered = jinja_template.render(conversation=oai_messages_prompt) - prompt_str_jinja_rendered = remove_trailing( - prompt_str_jinja_rendered, "\n" - ) # (@adithyare) jinja will add the ending of message token which we should remove to make a prompt. - assert prompt_str == prompt_str_jinja_rendered + # (@adithyare) bonus test! convert oai style messages back into a string using Jinja + # Attempt to import jinja2 via importorskip + jinja2 = pytest.importorskip("jinja2", reason="jinja2 library is not installed") + + # Now it's safe to use jinja2 + from jinja2 import Template + + def remove_trailing(s, t): + if s.endswith(t): + s = s[: -len(t)] + return s + + jinja_template = """{% for message in conversation %}{%- if message.role == "system" -%}System\n{{ message.content }}\n{% elif message.role == "user" -%}User\n{{ message.content }}\n{% elif message.role == "assistant" -%}Assistant\n{{ message.content }}\n{% endif %}{% endfor %}""" + jinja_template = Template(jinja_template) + prompt_str_jinja_rendered = jinja_template.render(conversation=oai_messages_prompt) + prompt_str_jinja_rendered = remove_trailing( + prompt_str_jinja_rendered, "\n" + ) # (@adithyare) jinja will add the ending of message token which we should remove to make a prompt. + assert prompt_str == prompt_str_jinja_rendered @pytest.mark.run_only_on("GPU") From 1d732adfd817f6d1c779040d4100e1145cfd3a08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:20:20 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: NeMo-Aligner CI --- tests/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 334ec05fe..79fb5e77d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -169,7 +169,7 @@ def test_dpo_dataset_conversion(): # (@adithyare) bonus test! convert oai style messages back into a string using Jinja # Attempt to import jinja2 via importorskip jinja2 = pytest.importorskip("jinja2", reason="jinja2 library is not installed") - + # Now it's safe to use jinja2 from jinja2 import Template From a76c29af3fb6a920bd213fd86ca7ffc256fc14d4 Mon Sep 17 00:00:00 2001 From: arendu Date: Wed, 4 Dec 2024 01:48:24 +0000 Subject: [PATCH 5/6] fix for batch size misconfiguration Signed-off-by: arendu --- examples/nlp/gpt/conf/gpt_sft.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index bdd757f31..b3dfd8ff8 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -87,6 +87,10 @@ model: attention_dropout: 0.0 ffn_dropout: 0.0 + # batch size needs to be specified here as well to resolve AssertionError: Issue with batch size configuration! error + global_batch_size: ${.data.train_ds.global_batch_size} + micro_batch_size: ${.data.train_ds.micro_batch_size} + steerlm2: forward_micro_batch_size: 1 # the micro batch size for the forward pass, used to compute the weights micro_batch_size: 1 # the steerlm2 training micro batch size From db1d5f12730d66dc2077ef39a30a36c22ff44693 Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Tue, 3 Dec 2024 17:56:05 -0800 Subject: [PATCH 6/6] Update gpt_sft.yaml removed comment --- examples/nlp/gpt/conf/gpt_sft.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index b3dfd8ff8..9946b094a 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -87,7 +87,6 @@ model: attention_dropout: 0.0 ffn_dropout: 0.0 - # batch size needs to be specified here as well to resolve AssertionError: Issue with batch size configuration! error global_batch_size: ${.data.train_ds.global_batch_size} micro_batch_size: ${.data.train_ds.micro_batch_size}