Clarification about hyperparameters and CoLaR stability #12
Description
Thanks for the awesome work. I reviewed the code and the different training settings, and I trained a Llama-1B-Instruct model with CoT to replicate some CoLaR results. I used the GSM8K dataset; my CoT model reaches around 40% validation accuracy.
However, when I then trained CoLaR, I used the math SFT hyperparameters you shared here:
https://huggingface.co/AlbertTan/CoLaR/blob/main/logs/colar/qsa-math/colar-math-sft-llama/hparams.yaml
Unfortunately, CoLaR is unstable during training and stays at random-guess performance even after 10 epochs.
Could you provide some guidelines on setting the latent-head parameters and the relative loss weights?
These are my current parameters:

```yaml
latent_cot_config:
  ce_weight: 1
  embed_modeling_weight: 1
  embed_modeling_loss: mse  # {nll, mse}
  entropy_weight: -1e-6
  pred_embed_forward_weight: 1
  max_compression_factor: 5
  pred_compressed_cot: True
  sqrt_mean: True
```
- Dataset: GSM8K
- Initialization: CoT checkpoint
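For reference, this is a minimal sketch of how I understand the weights above are meant to combine the loss terms. The function name and term names are my own (hypothetical), mirroring the hparams keys; the actual CoLaR objective in the repo may combine them differently:

```python
# Hedged sketch: weighted sum of loss terms, mirroring latent_cot_config.
# All names here are illustrative, not taken from the CoLaR codebase.

def total_loss(ce, embed_mse, entropy, pred_embed_forward,
               ce_weight=1.0, embed_modeling_weight=1.0,
               entropy_weight=-1e-6, pred_embed_forward_weight=1.0):
    """Combine the per-term losses with the configured weights.

    Note the negative entropy_weight: a positive entropy term lowers the
    total loss, i.e. it acts as a (tiny) entropy bonus on the latent head.
    """
    return (ce_weight * ce
            + embed_modeling_weight * embed_mse
            + entropy_weight * entropy
            + pred_embed_forward_weight * pred_embed_forward)

# With unit weights, the -1e-6 entropy bonus barely moves the total,
# so in practice the CE, MSE, and forward-prediction terms dominate.
print(total_loss(ce=2.0, embed_mse=0.5, entropy=3.0, pred_embed_forward=0.25))
```

Given how small `entropy_weight` is, I mainly wonder whether instability comes from the balance between `ce_weight`, `embed_modeling_weight`, and `pred_embed_forward_weight`.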
Thank you in advance.