Accuracy drops after mixed-format training #13
Description
Hi @AlbertTan404!
I downloaded the ckpt you provided and ran it end-to-end; the eval results are

Compared with the CoLaR-5 numbers in the paper, this looks normal.
Then, starting from this colar_best.ckpt, I tried mixed-format training (or call it multi-format alignment?). Concretely: the original CoLaR output is step -> answer. Keeping the question unchanged, I added two new variants for every sample in train.json: answer -> step and answer only. The format is
[
Semantics of the three variants:
- step_answer: cot + <think> + answer
- answer_step: answer + <think> + cot
- answer_only: answer (strictly no reasoning tokens)
]
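For concreteness, the augmentation could be sketched like this. The field names `question`/`steps`/`answer` and the helper `make_variants` are my assumptions for illustration, not the actual train.json schema:

```python
SEP = "<think>"  # the newly registered separator token

def make_variants(sample):
    """Expand one original sample into the three format variants.
    Field names 'question', 'steps', 'answer' are assumed, not the
    real train.json schema."""
    q, cot, ans = sample["question"], sample["steps"], sample["answer"]
    return [
        # step_answer: cot + <think> + answer (original order)
        {"question": q, "variant": "step_answer", "target": f"{cot}{SEP}{ans}"},
        # answer_step: answer + <think> + cot
        {"question": q, "variant": "answer_step", "target": f"{ans}{SEP}{cot}"},
        # answer_only: answer with no reasoning tokens at all
        {"question": q, "variant": "answer_only", "target": ans},
    ]
```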
I also wrote three corresponding versions of the prompt and registered <think> as a new separator token.
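One failure mode worth ruling out here: if <think> is not registered as one atomic special token, the tokenizer splits it into ordinary sub-tokens and the model never sees the separator as a single symbol. A toy tokenizer (not the project's real one) illustrates the difference:

```python
import re

# Toy illustration only -- NOT the project's actual tokenizer. It shows why
# <think> must be registered as one atomic special token: otherwise the
# fallback splitting shatters it into ordinary sub-tokens.
def toy_tokenize(text, special_tokens):
    pattern = "(" + "|".join(re.escape(t) for t in special_tokens) + ")"
    pieces = [p for p in re.split(pattern, text) if p]
    tokens = []
    for p in pieces:
        if p in special_tokens:
            tokens.append(p)                              # kept atomic
        else:
            tokens.extend(re.findall(r"\w+|[^\w\s]", p))  # naive fallback
    return tokens

print(toy_tokenize("answer <think> cot", ["<think>"]))
# ['answer', '<think>', 'cot']
print(toy_tokenize("answer <think> cot", ["###"]))
# ['answer', '<', 'think', '>', 'cot']  -- separator shattered
```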
The training setup was:
[
OMP_NUM_THREADS=1 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python run.py --model=colar --dataset=qsa --devices=0 --project_root=/root/autodl-tmp/workspace/colar --artifact_root=/root/autodl-tmp/workspace --model_root=/root/autodl-tmp/workspace/models --dataset_root=/root/autodl-tmp/workspace/data/text_reasoning --load_ckpt_path=/root/autodl-tmp/workspace/ckpt/logs/colar/qsa-gsm/colar-final/checkpoints/colar_best.ckpt --do_test --log_suffix=aug433_promptfix_v2_bs8x4 dataset_name=gsm enable_mixed_train=True batch_size=8 val_batch_size=8 accumulate_grad_batches=4 max_epochs=12 num_workers=4 > /root/autodl-tmp/workspace/output/aug433_promptfix_v2_run.log 2>&1
]
Regarding the settings: batch_size=32 ran out of memory (OOM), so I used batch_size=8 with accumulate_grad_batches=4 for an effective batch size of 32.
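That equivalence holds exactly for a mean-reduced loss (setting aside batch-statistics layers such as BatchNorm): scaling each micro-batch's mean loss by 1/accumulate_grad_batches before backward gives the same gradient as one full batch. The underlying identity in numbers:

```python
import random

# batch_size=8 with accumulate_grad_batches=4 vs. an effective batch of 32:
# for a mean-reduced loss, summing the four micro-batch means each scaled by
# 1/4 equals the mean over all 32 samples.
random.seed(0)
per_sample_loss = [random.random() for _ in range(32)]

full_batch = sum(per_sample_loss) / 32

micro_batches = [per_sample_loss[i * 8:(i + 1) * 8] for i in range(4)]
accumulated = sum((sum(mb) / 8) / 4 for mb in micro_batches)

assert abs(full_batch - accumulated) < 1e-12
```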
The results are
[
Test results: defaultdict(<class 'list'>, {'monitor': [0.0803639143705368, 0.07126610726118088, 0.06823351234197617, 0.07354056090116501, 0.07278241217136383], 'test/acc': [0.0803639143705368, 0.07126610726118088, 0.06823351234197617, 0.07354056090116501, 0.07278241217136383], 'test/n_latent_forward': [39.329795837402344, 38.274452209472656, 40.071266174316406, 39.56937026977539, 39.01819610595703], 'test/n_latent_forward_on_acc': [14.899672508239746, 14.629264831542969, 14.66464614868164, 16.517562866210938, 12.629264831542969], 'test/output_length': [4.0644426345825195, 4.033358573913574, 4.018195629119873, 3.980288028717041, 4.078847408294678]})
Test statistics with 5 replications: {'monitor': (0.07323730140924453, 0.003506395940529609), 'test/acc': (0.07323730140924453, 0.003506395940529609), 'test/n_latent_forward': (39.25261611938477, 0.5242016523301881), 'test/n_latent_forward_on_acc': (14.668082237243652, 1.0829781426128342), 'test/output_length': (4.0350264549255375, 0.030535620977975873)}
]
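As a sanity check, the five-replication mean for test/acc can be reproduced with plain Python (only the mean is checked; the logger's std convention may differ):

```python
# Reproducing the 5-replication mean for test/acc reported above,
# independent of the project's logging code.
acc = [0.0803639143705368, 0.07126610726118088, 0.06823351234197617,
       0.07354056090116501, 0.07278241217136383]
mean = sum(acc) / len(acc)
print(mean)  # ≈ 0.073237, matching the reported 'monitor' statistic
```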
Accuracy is only about 0.07, while n_latent_forward has jumped to about 39.
Could this be because the epoch count and batch size are too small and training hasn't converged? Or is the newly designed <think> token interfering with the original '###' separator?