LLM并行训练4-megascale论文学习

llm,megascale · 浏览次数 : 0

小编点评

这段文字似乎是对深度学习模型训练中的一些优化技术的描述,包括模型的并行化、优化器选择、以及数据加载等方面。以下是对这些技术的简要总结: 1. **算法优化与并行注意力机制**: * 串行版本和并行版本的差异:串行版本中,注意力机制后面跟着一层MLP;而并行版本中,MLP和注意力机制被合并在一起。 * 为什么可以这样合并:Palm论文中指出,在62B模型的实验中,将MLP融合到注意力机制中并不会降低性能。这可能是通过堆叠不同大小的窗口来捕获句子中的信息,从而减少计算量。 2. **AdamW优化器**: * 权重衰减项的处理:AdamW优化器将权重衰减项从梯度的计算中分离出来,并直接加在最后的权重更新步骤上。 * 新的截断函数:为了防止批量大小过大时动量出现极端值影响反向传播(BP),引入了一个截断函数。 * 批量大小扩展:论文中提到可以将批量大小增大4倍以加速训练。 3. **3D并行优化与序列并行**: * 张量并行优化:通过将张量分割到不同的设备上进行并行计算,以平摊计算和内存开销。 * 序列并行:通过将数据流分割到不同的设备上进行并行处理,以减少通信延迟。 * AllGather和Reduce-scatter操作:用于在不同设备之间同步数据和梯度。 4. **流水线优化**: * 交错式1F1B架构:在每个节点的前向传播(FP)之后,通过AllGather操作将计算结果发送到其他节点进行求和,然后在反向传播(BP)时使用Reduce-scatter操作。 * 缩短等待时间:通过将AllGather的接收和发送拆分,并优先处理接收,可以缩短等待时间,提高效率。 5. **数据加载优化**: * 同步梯度与预加载:在批量梯度同步后,可以直接释放前向相关的数据,并预加载下一轮训练所需的数据。 * 数据重用:避免在单机上重复读取相同的数据,从而减少I/O开销。 6. **网络通信优化**: * 集群容错与错误检测:通过心跳机制和中心节点的自动诊断,确保集群状态的正常,并在异常时进行自动恢复。 * Checkpoint保存与读取:通过将参数复制到内存并在另一个线程中将参数写入HDFS,可以隐藏HDFS读取的延迟。 7. **LLM状态恢复**: * 节点重入与状态替换:在节点重入后,替换原来的rank_id,以便继续训练。 * 状态监控:使用CUDA事件的时间线可视化工具来监控模型的状态。 8. **流水线并行**: * 层级并行:根据节点的rank在启动时预先分配好层次结构。 * 节点重入:节点重入后,替换原来的rank_id。 9. **其他注意事项**: * 梯度累积:在warm-up/cool-down过程中,所有节点必须等待通信完成后才能进行计算。 * 通信隐藏:在实际应用中,通信通常会被隐藏起来,以避免性能下降。 请注意,这些技术可能涉及多个论文和技术的结合,如AdamW、LAMB、AllGather、Reduce-scatter、流水线并行等。具体的实现细节和效果可能需要查阅相关论文和代码库。

正文

算法优化

并行注意力机制

\[串行版本: y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))) \]

\[并行版本: y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x)))) \]

乍一看确实不是等价的, attention那块的后置mlp去哪了..这个其实没有理论证明, Palm论文里提到把mlp融合到attention里实验62B模型上性能没有下降. 主要对应的是下图网络结构的并行化改造.

image-20240629132922611

滑动窗口Attention

image-20240629163320388

通过堆叠不同大小的窗口来捕获句子中的信息,所需要的计算量会比直接计算整个输入文本长度的计算量要小很多

滑动窗口attention的原理参考这个文章的解释:因为模型都是多层叠加的,所以层级越高,attend的视野域就越广。如果w=3,那么第一层只能注意3个位置,但到第二层能注意到第一层输出的三个位置,换算到第一层的输入,就是5个位置。所以随着层级越高,理论上每个位置注意到的区域就越大,所能存储的信息就越接近全局attention时的状态

AdamW优化(LAMB):

adamW对比adam是把权重衰减项从梯度的计算中拿出来直接加在了最后的权重更新步骤上, 为了把权重衰减和梯度计算解耦(如果加到梯度计算里会影响到动量的滑动平均), 从而提升优化效果.

image-20240629164115223

这里做的优化是新增了一个 \(\phi\)截断函数, 主要目的是为了防止batch_size太大的时候导致优化过程中动量出现极端值影响bp. 这个方法论文里说可以把batch_size增大4倍从而加速训练.

\[W_t\leftarrow W_{t-1}-\alpha \cdot\phi(\frac{||W_{t-1}||}{||r_t+\lambda W_{t-1}||})(r_t+\lambda W_{t-1}) \]

3D并行优化

张量并行优化

image-20240629151955858

序列并行(SP)主要有2个目的: 平摊LayerNorm和Dropout的计算开销, 而且Activation占用显存也很多, 能够平摊显存消耗.

[!NOTE]

这里有个疑问: LayerNorm不是要算全局均值和方差么..这个拆分后是只算该设备内部的均值还是说需要进行额外的allReduce?

AllGather优化

序列并行(SP)后, 在进行张量并行(TP)前需要在fp的时候需要先通过gather把之前层的切片从其他节点copy汇聚过来. 如果等gather完成再跑mlp和attention就会让gpu在通信这段时间空置等待, 这里可以优化成每通信完成一个切片后, 进行这个切片的MLP列切分计算, 同时直接把gather结果送给attention并行计算, 最后再把切片计算结果concat到一起. 比如在copy完A0后, A0的前向计算就和A1的通信并行起来了, 这样就能尽量的隐藏通信

另外对矩阵做切片后再进行矩阵乘法, 计算效率要也比2个超大的矩阵乘法要高.

Reduce-Scatter优化

这块是需要把汇聚计算完成的tensor在重新进行切分发送到序列并行的节点里, 这里是把MLP的第二次行切分和attention结果加和给merge到了一起, 完成一个切片的计算后就发送出去, 同步进行下一个切片的计算使计算和通信异步进行.

流水线优化

image-20240629171330492

回顾一下交错式1F1B, 每个节点fp前需要等recv之前layer的结果, 在当前层fp完后, 通过allGather send出去计算完成的数据, 在bp的时候需要通过Reduce-scatter发送出去计算完的grad.

在warm-up/cool-down过程里, 都是必须等通信完成才能进行计算的. 为了缩短等待时间megascale把allGather的recv/send拆分开, recv优先级高于send, recv后就能直接开始计算, 不需要等send的长尾. 从而缩短等待时间.

在稳定状态的时候应该和megatron一样, 通信都会和计算异步. 实际情况里通信一般都会被隐藏掉(这里我没看懂为啥上面画的对比图是个纯串行的流程)

数据加载优化

这章的主要思想工作中经常用到就不细看了, 主要有2部分:

  1. 在bp完同步梯度的时候, 所有前向相关的数据就没用了, 就可以直接释放回池预加载下一轮fp需要的embed
  2. 避免单机内多张卡重复读相同的冗余数据(这里可能指的是embed集合么?), 先在内存里去好重再copy到显存

网络通信优化

TODO待补充..网络这块基本都忘完了.

集群容错

image-20240629155710287

错误检测

主要思想和flux-cpu有很多相似点, 主要有以下几个点

  1. 每个worker定期上报心跳给中心节点, 确保当前状态正常
  2. 状态异常时的自动化诊断(NCCL allToAll, allReduce. 同主机RDMA网卡间的连接和带宽, 网卡到GPU/MEM的连接和带宽), 完成诊断后上报给中心节点.
  3. 中心节点向k8s申请失败节点的拉黑和重分配替换

状态恢复

  • checkpoint保存: 这个看着实现方法和async_patch是一样的, 先把参数copy到内存, 模型继续训练. 同步再起一个异步线程用来把内存里的参数写到hdfs. 这样就可以把非常耗时的hdfs写入给隐藏掉.
  • checkpoint读取: 主要优化手段是在同一数据并行组里的卡, 只选一个GPU对应的训练线程读hdfs后写内存, 然后通过broadcast给这个数据并行组里的其他卡. 可以降低hdfs的读取压力.

LLM的状态恢复感觉还挺复杂的, 如果有一个节点挂了在重分配后是所有节点全部回滚到上一个checkpoint还是有更快的方法..pipeline并行应该是在根据节点rank在启动的时候就分好了层, 节点重入后要替换原来的rank_id.

状态监控

基于cuda_event的timeline可视化, 算是老熟人了. 这里的难点感觉在于超多卡的实时日志收集, 根据DP来画出卡和卡的数据流依赖关系

参考:

megascale: https://arxiv.org/abs/2402.15627

Palm(并行attention): https://public.agent-matrix.com/publish/shared/Paper/Palm.pdf

滑动窗口注意力解释: https://zhuanlan.zhihu.com/p/223430086

与LLM并行训练4-megascale论文学习相似的内容:

LLM并行训练4-megascale论文学习

算法优化 并行注意力机制 \[串行版本: y = x + MLP(LayerNorm(x + Attention(LayerNorm(x)))) \]\[并行版本: y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x)))) \]乍一看确实不是等价的,

LLM并行训练3-数据并行

前置知识 混合精度训练 在参数存储时采取fp32, 开始进行fp/bp时转成fp16运算, 拿到fp16梯度后再转回fp32更新参数. ZeRO对显存占用的估算: 模型状态: Weights(fp16)、grad(fp16) 和 MasterWeights(fp32 模型参数备份),momentum

构建RAG应用-day05: 如何评估 LLM 应用 评估并优化生成部分 评估并优化检索部分

评估 LLM 应用 1.一般评估思路 首先,你会在一到三个样本的小样本中调整 Prompt ,尝试使其在这些样本上起效。 随后,当你对系统进行进一步测试时,可能会遇到一些棘手的例子,这些例子无法通过 Prompt 或者算法解决。 最终,你会将足够多的这些例子添加到你逐步扩大的开发集中,以至于手动运行

LLM实战:当网页爬虫集成gpt3.5

本文主要是通过Scrapegraph-ai集成gpt3.5实现一个简单的网页爬取并解析的demo应用,其中涉及到gpt3.5免费申请,Scrapegraph-ai底层原理简介,demo应用源码等。

解密Prompt系列31. LLM Agent之从经验中不断学习的智能体

模型想要完成自主能力进化和自主能力获得,需要通过Self-Reflection from Past Experience来实现。那如何获得经历,把经历转化成经验,并在推理中使用呢?本章介绍三种方案

解密Prompt系列29. LLM Agent之真实世界海量API解决方案:ToolLLM & AnyTool

很早之前我们就聊过ToolFormer,Gorilla这类API调用的Agent范式,这一章我们针对真实世界中工具调用的以下几个问题,介绍微调(ToolLLM)和prompt(AnyTool)两种方案。 真实世界的API数量庞大且多样:之前的多数工具调用论文,工具数量有限,工具相对简单具体,并且往往

langchain中的LLM模型使用介绍

# 简介 构建在大语言模型基础上的应用通常有两种,第一种叫做text completion,也就是一问一答的模式,输入是text,输出也是text。这种模型下应用并不会记忆之前的问题内容,每一个问题都是最新的。通常用来做知识库。 还有一种是类似聊天机器人这种会话模式,也叫Chat models。这种

mac本地搭建ollama

mac本地搭建ollama webUI *简介:ollama-webUI是一个开源项目,简化了安装部署过程,并能直接管理各种大型语言模型(LLM)。本文将介绍如何在你的macOS上安装Ollama服务并配合webUI调用api来完成聊天。 开源地址 https://github.com/812781

微软博客上几篇 Semantic-kernel (SK)文章

自从最近微软开源Semantic-kernel (SK) 来帮助开发人员在其应用程序中使用AI大型语言模型(LLM)以来,Microsoft一直在忙于改进它,发布了有关如何使用它的新指南并发布了5篇文章介绍他的功能。 开发人员可以使用Semantic-kernel (SK) 创建自然语言提示、生成响

深入探讨Function Calling:实现外部函数调用的工作原理

引言 Function Calling 是一个允许大型语言模型(如 GPT)在生成文本的过程中调用外部函数或服务的功能。 Function Calling允许我们以 JSON 格式向 LLM 模型描述函数,并使用模型的固有推理能力来决定在生成响应之前是否调用该函数。模型本身不执行函数,而是生成包含函