聊聊GLM-4-9B开源模型的微调loss计算

glm,9b,loss · 浏览次数 : 0

小编点评

GitHub官方地址:GLM-4 在网上,关于微调的文章已经很多了。这些文章主要介绍了各种微调方法及其应用场景,不过很少有文章深入地讨论微调时的loss计算逻辑。 个人认为,虽然大多数人都关注如何使用微调模型,但了解其底层计算过程同样重要,尤其是在优化和调试时。 为了了解更多关于微调和loss计算的知识,可以参考以下文章: 1. 再聊多轮对话微调训练格式与长序列训练 2. 聊天GLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析 3. 聊天大模型多轮对话的训练及优化 4. 微调格式:[{“messages”:[{"role”: “system”, “content”: “<system prompt text>”, “tools”: [{"name”: “<tool name>”, “args”: {“<arg name>”: “<arg value>”}}]}, {"role”: “user”, “content”: “<user prompt text>”}, {"role”: “assistant”, “content”: “<assistant response text>”}, {"role”: “user”, “content”: “<user prompt text>”}, {"role”: “assistant”, “content”: “<assistant response text>”}, {"role”: “observation”, “content”: “<observation prompt text>”}, {"role”: “assistant”, “content”: “<assistant response observation>”}, {"role”: “user”, “content”: “<user prompt text>”}, {"role”: “assistant”, “content”: “<assistant response text>”}]微调源码地址:finetune.py Loss计算代码: ```python def process_batch( batch: Mapping[str, Sequence], tokenizer: PreTrainedTokenizer, max_input_length: int, max_output_length: int, ) -> dict[str, list]: batched_conv = batch['messages'] batched_input_ids = [] batched_labels = [] # batched_conv 是一个数组 # conv 是数组内的单个 message for conv in batched_conv: input_ids = [151331, 151333] loss_masks = [False, False] # conv 是数组内的单个 message # message 是 单个role json对象 for message in conv: message = process_message(message) # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算 loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True # 获取 input 文本的数字表示(ids) new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:] # 计算整句的 mask new_loss_masks = [loss_mask_val] * len(new_input_ids) # 拼接message中的每段json input_ids += new_input_ids # 拼接message中每段json对应的mask loss_masks += new_loss_masks # 追加结尾的 token id input_ids.append(tokenizer.eos_token_id) loss_masks = [False, *loss_masks] labels = [] for input_id, mask in zip(input_ids, loss_masks): if mask: # 添加到label,计算loss labels.append(input_id) else: # -100 不处理,即ignore_index labels.append(-100) max_length = max_input_length + max_output_length + 1 # 截断 batched_input_ids.append(input_ids[:max_length]) batched_labels.append(labels[:max_length]) return {'input_ids': batched_input_ids, 'labels': batched_labels} ``` 注释在代码中已经写明。 process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示: ```python tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config) data_manager = DataManager(data_dir, ft_config.data_config) # 数据集拆分遍历 train_dataset = data_manager.get_dataset( Split.TRAIN, functools.partial(process_batch, tokenizer=tokenizer, max_input_length=ft_config.max_input_length, max_output_length=ft_config.max_output_length), batched=True, ) print('train_dataset:', train_dataset) ``` Loss计算如下图所示: 总结:相较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其他开源模型/微调框架中,例如InternLM、XTuner、Firefly等,已经支持这种loss计算。对于loss格式的类别,可参考XTuner的官方文档说明:[dataset_format.md](https://mp.weixin.qq.com/s/0mLCQfpaZr7eEonG4a4Etg)。

正文

概述

Github官方地址:GLM-4

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

可了解其它loss计算的文章:
再聊多轮对话微调训练格式与长序列训练
聊聊ChatGLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析
聊聊大模型多轮对话的训练及优化

微调

微调格式:

[
  {
    "messages": [
      {
        "role": "system",
        "content": "<system prompt text>",
        "tools": [
          {
            "name": "<tool name>",
            "args": {
              "<arg name>": "<arg value>"
            }
          }
        ]
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      },
      {
        "role": "observation",
        "content": "<observation prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response observation>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      }
    ]
  }
]

微调源码地址:finetune.py
Loss计算代码:

def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []
    # batched_conv 是一个数组
    # conv 是数组内的单个 message
    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        # conv 是数组内的单个 message
        # message 是 单个role json对象
        for message in conv:
            message = process_message(message)
            # 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            # 获取 input 文本的数字表示(ids)
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            # 计算整句的 mask
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            # 拼接message中的每段json
            input_ids += new_input_ids
            # 拼接message中每段json对应的mask
            loss_masks += new_loss_masks
        # 追加结尾的 token id
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                # 添加到label,计算loss
                labels.append(input_id)
            else:
                # -100 不处理,即ignore_index
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        # 截断
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 数据集拆分遍历
train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    functools.partial(
        process_batch,
        tokenizer=tokenizer,
        max_input_length=ft_config.max_input_length,
        max_output_length=ft_config.max_output_length,
    ),
    batched=True,
)
print('train_dataset:', train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md

原文链接:https://mp.weixin.qq.com/s/0mLCQfpaZr7eEonG4a4Etg

更多大模型相关的文章,请上个人公众号查阅:
image

与聊聊GLM-4-9B开源模型的微调loss计算相似的内容:

聊聊GLM-4-9B开源模型的微调loss计算

概述 Github官方地址:GLM-4 网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。 可了解其它loss

聊聊语言模型与知识图谱

## 语言模型 语言模型泛指:大语言模型LLM、通用模型GLM。 语言模型也是知识库。基于语言模型下的实现,比如ChatGPT,BERT,ChatGLM等等,这类知识库就像是已经人为处理好、编排好、可直接使用的知识库。 ## 知识图谱 知识图谱的定义由Google公司在2012年提出,被界定为用来提

聊聊一个差点被放弃的项目以及近期的开源计划

前言 自从 StarBlog 和 SiteDirectory 之后,我还没写新的关于开源项目的系列,最近又积累了很多想法,正好写一篇博客来总结一下。 关于差点被放弃的项目,就是最近一直在做的单点认证(IdentityServerLite) IdentityServerLite 开发这个项目的起因,是

聊聊 JSON Web Token (JWT) 和 jwcrypto 的使用

哈喽大家好,我是咸鱼。 最近写的一个 Python 项目用到了 jwcrypto 这个库,这个库是专门用来处理 JWT 的,JWT 全称是 JSON Web Token ,JSON 格式的 Token。 今天就来简单入门一下 JWT。 官方介绍:https://jwt.io/introduction

聊聊MySQL是如何处理排序的

在MySQL的查询中常常会用到 order by 和 group by 这两个关键字,它们的相同点是都会对字段进行排序,那查询语句中的排序是如何实现的呢?

聊聊 Linux iowait

哈喽大家好,我是咸鱼。 我们在使用 top 命令来查看 Linux 系统整体 CPU 使用情况的时候,往往看的是下面这一列: %Cpu(s): 0.0 us, 0.0 sy, 0.0 ni,100.0 id, 68.0 wa, 0.0 hi, 0.0 si, 0.0 st 其中,man 手册解释 w

聊聊Mybatis框架原理

好久没有写博客了。最近工作中封装了一个类似ORM框架的东西。大概的原理就是将Excel数据初始化到本地sqlite数据库后,通过json配置文件,对数据库的数据做增删改查等操作。 其实大概的思考了下,就是半ORM框架mybatis的逻辑,只是我们自己封装的简陋蛮多。想想有现成的轮子没用,反而是自己写

聊聊Spring的工厂方法与FactoryBean

概述 工厂方法是比较常见,常用的一种设计模式。FactoryBean是Spring提供的一种Bean注入IOC容器的方式。 工厂方法 在做日常开发时,一般都会避免直接new对象,而且将new的操作丢给IOC容器,但对于第三方系统的集成,我们不太好直接丢给IOC容器,此时可以通过工厂模式, 提供一个工

聊聊Spring Cloud Alibaba解决方案组件

在java的微服务解决方案中,最先出现目前应用比较多的就是spring cloud netfix系列,但是随着阿里的强劲支持,spring cloud alibaba解决方案逐渐可以替代前者,当然dubbo也是不容小觑的。之前面试几家公司应用的都是spring cloud alibaba,随着我自己

聊聊Spring Cloud Alibaba Sentinel的限流

Spring Cloud Alibaba Sentinel限流功能概览,目前先整理一版,东西有点多,想慢慢打开;后续继续更新......