chatglm2-6b-sft

chatglm2-6b, chatglm-6b微调/LORA/推理

源码地址:https://github.com/yongzhuo/ChatGLM2-SFT

踩坑

1. torch>=2.0, 否则微调会报很多错误(单纯推理可以用低版本);
2. tokenizer.encode输出为 [gMASK, sop, 真实文本token]
    64789 = {str} '[MASK]'
    64790 = {str} '[gMASK]'
    64791 = {str} '[sMASK]'
    64792 = {str} 'sop'
    64793 = {str} 'eop'
3. modeling_chatglm.py自带get_masks()的代码full_attention_mask -= padding_mask.unsqueeze(-1) - 1改为
                full_attention_mask = full_attention_mask.long() - padding_mask.unsqueeze(-1).long() - 1
4. 不支持gradient_checkpointing, 修复的话需要modeling_chatglm.py新增get_input_embeddings, set_input_embeddings;
5. modeling_chatglm.py中的ChatGLMForConditionalGeneration类forward函数中的
      if full_attention_mask is None:  前加入  batch_size, seq_length = input_ids.shape
6. get_mask(), 一直以来都对chatglm的mask/position有一些疑惑;
    def get_masks(seq, bos_token_id):
        """  code from model_chatglm.py  """
        if seq.count(bos_token_id) == 2:
            context_length = seq[2:].index(bos_token_id) + 2
        else:
            context_length = seq.index(bos_token_id)
        attention_mask = torch.ones((1, len(seq), len(seq)))
        attention_mask.tril_()
        attention_mask[..., :context_length] = 1
        # attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()
        return attention_mask
7. 严格按照官方prompt构建输入输出:
    输入:"[Round 1]\n\n问:{}\n\n答:"
    输出:"{}"
    输入id: [gMASK, BOS, 输入tokens]
    输出id: [gMASK, BOS, 输出tokens, EOS]

环境配置

transformers==4.27.1
torch>=2.0
sentencepiece
cpm_kernels
mdtex2html
accelerate
protobuf
gradio

微调样例

地址: chatglm2_6b/ft_chatglm2

配置: chatglm2_6b/ft_chatglm2/config.py
训练: python train.py
推理: python predict.py
验证: python evaluation.py
接口: python post_api.py

实验日志

微调日志(alpaca_gpt4forall)

在这里插入图片描述

推理样例(LoRA, R=8)

在这里插入图片描述

在这里插入图片描述

参考/感谢

免责申明

本项目相关资源仅供学术研究之用,使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目不对其准确性作出保证。对于模型输出的任何内容,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。

Logo

更多推荐