ChrisKim
Do not go gentle into that good night.
颢天

基于 Swift/Megatron 完成多模态大模型分布式训练

近期由于项目需要训练大模型,学习了基于 Swift/Megatron 的大模型分布式训练技术,发现细节不少,踩了不少坑,在此记录希望能帮到大家。

本文的最重要关键词就是分布式,毕竟模型训练谁不会呢,拿 transformers 库就能训下来了,在多机平台上如何跑起高效的训练才是难点。本文将会介绍如何在多台 NVIDIA H200 141GB 节点上,基于 ms-swift/megatron 框架完成 Qwen3-VL-8B-Thinking 的训练。

另外,本文不是零基础教程,都接触到分布式训练了,基础概念本文就不再解释了。

1 Swift 训练

在多机条件下,每个节点的环境最好做到完全一致,因此基于 Docker 容器的训练环境部署是最佳选择。当前 ms-swift 的最新版本是 4.1.2,读者可选择阅读时的最新版。

启动容器

在每台机器上运行以下指令启动容器即可,-v 的目录挂载按自己数据来选择。

sudo docker run -it \
    --gpus all \
    --shm-size=64g \
    --network=host \
    -v /data:/data \
    modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.9.1-py312-torch2.10.0-vllm0.19.0-modelscope1.35.4-swift4.1.2 \
    /bin/bash

多机并行环境变量

在完成容器启动后,接下来就是配置多节点并行需要的环境变量了。对于 ms-swift,多节点训练是使用 torchrun 来完成的,因此对 torchrun 熟悉的可以发现这些环境变量就是一模一样的。

环境变量解释示例
NNODES总节点数2
NODE_RANK当前节点编号(0 到 NNODES-1)0
MASTER_ADDR主节点 IP 地址10.112.36.7
MASTER_PORT主节点端口29500

因此,假如有 4 台机器 IP 分别为 10.112.36.7, 10.112.36.8, 10.112.36.9, 10.112.36.10,选定第一个为主节点,那么每个节点的环境变量便为:

# 10.112.36.7
export NNODES=4
export NODE_RANK=0
export MASTER_ADDR=10.112.36.7
export MASTER_PORT=29500

# 10.112.36.8
export NNODES=4
export NODE_RANK=1
export MASTER_ADDR=10.112.36.7
export MASTER_PORT=29500

# 10.112.36.9
export NNODES=4
export NODE_RANK=2
export MASTER_ADDR=10.112.36.7
export MASTER_PORT=29500

# 10.112.36.10
export NNODES=4
export NODE_RANK=3
export MASTER_ADDR=10.112.36.7
export MASTER_PORT=29500

请注意,两台机器必须要能够通过该主节点 IP 成功建立 TCP 连接,也就是说必须保证没有防火墙或者什么东西阻止其他节点访问 10.112.36.7:29500 .

启动训练

在每台机器上配置环境变量后,便可以基于 ms-swift 的命令行启动训练了,参数什么的看 ms-swift 的文档就可以了,这里不再讲解。

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
OMP_NUM_THREADS=16 \
IMAGE_MAX_TOKEN_NUM=1280 \
swift pt \
    --tuner_type full \
    --ddp_backend nccl \
    --ddp_timeout 8640000 \
    --model /data/models/Qwen3-VL-8B-Thinking \
    --attn_impl flash_attn \
    --dataset /data/train_data/pt_*.jsonl \
    --dataset_num_proc 16 \
    --remove_unused_columns true \
    --max_length 20480 \
    --padding_free true \
    --output_dir /data/train_output/debug-qwen3-vl-8b-thinking-swift \
    --gradient_checkpointing true \
    --deepspeed zero1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --learning_rate 1e-5 \
    --report_to tensorboard \
    --logging_steps 1 \
    --num_train_epochs 1 \
    --save_steps 1000 \
    --save_total_limit 10 \
    --warmup_ratio 0.01 \
    --packing true \
    --save_only_model false \
    --dataloader_num_workers 64

以上示例应该是可以正常在 H200 上跑起来的,单卡显存 69GB/141GB。但是很有可能你已经遇到了问题,因为这里有非常多易错坑点

2 易错坑点

2.1 未开启序列打包 (packing)

最开始由于 Swift 版本兼容性的问题,--packing 参数一直开不起来,当时想着可能问题不大于是就没管了,没想到跑起来发现训练时间居然要 3000 多个小时!

原因是我的数据集虽然有 85M 条目,但只有 34B Tokens,平均长度仅 400,但最大长度有 20K,设定 max_length=20k,这让计算卡计算时绝大多数 Token 都被 padding 浪费掉了。

  • 不开序列打包:85M 条目,每条 5s,32 并行,共计 3689 小时。
  • 开启序列打包:34B Tokens,打包成约 20k 长度后有 1.7M 条目,每条 5s,32 并行,共计 147 小时。

可以看到开启序列打包后加速了 25 倍。

2.2 未缓存多模态数据集

在调试时就发现多模态数据集加载真的慢得离谱,而且每次启动训练都要等待预处理非常之久。经过调研才知道 Swift 支持预先处理数据集并缓存到本地,再通过 --cached_dataset 加载。使用以下指令预处理数据集:

OMP_NUM_THREADS=16 \
MAX_PIXELS=1048576 \
swift export \
    --model /data/models/Qwen3-VL-8B-Thinking \
    --dataset /data/train_data/pt/pt_00001.jsonl \
    --split_dataset_ratio 0 \
    --dataset_num_proc 16 \
    --to_cached_dataset true \
    --output_dir /data/train_data/pt/cached_pt_00001/

这里还包含一个坑点,一定要记得设定 OMP_NUM_THREADS 为一个比较小的合适值!不设定的默认值为 CPU 核心数量,如此大的线程数使得 CPU 资源几乎全花在了线程切换上,导致 CPU 全核 100% 但是处理速度只有 40条/s。经过修改后,CPU 总占用只有 5~10%,但速度来到了 3000~4000条/s,速度显著提升。(该问题官方文档里也写了的)

我这里有 85M 条目 34B Tokens 的多模态预训练数据,预处理耗时约 6 小时。完成预处理后,之后训练时使用 --cached_dataset 代替 --dataset 参数即可快速加载,几乎是瞬间加载。不过,--packing 操作是在加载数据集后进行处理的,每次加载仍然要等数据集打包。不过这个速度非常快,我的 85M 条目 34B Tokens 数据集只需要花 10 分钟就可以完成数据集打包。

https://assets.zouht.com/img/blog/4276-05.webp

2.3 未使用 InfiniBand 高速网卡

在初期测试时,完全没有考虑到节点互联速度这回事,因此就按上面的方式启动了训练,但经过了效率实验之后发现了奇怪的事情:

  • 1*8 卡 NVIDIA H200 速度:4.76s/it
  • 2*8 卡 NVIDIA H200 速度:20.71s/it
  • 4*8 卡 NVIDIA H200 速度:27.75s/it

要知道,以上示例跑的是纯 Deepspeed 也就相当于 DDP,那么该实验结果的意思便是:1 台需要 5 秒处理 8 条数据,2 台需要 20 秒处理 16 条数据,4 台需要 27 秒处理 32 条数据。并行越多反而越慢了!

此时我才意识到,单机和多机最大的区别就是互联速度,在单机条件下,PCIe 5.0 可以提供 512Gbps (64GB/s) 的互联带宽,若使用 NVLink 甚至可以达到 7200Gbps (900GB/s)。但当我查看我使用的网卡时才发现,这四台节点的以太网卡只有 25Gbps (3.1GB/s) 的速率,和 PCIe 都有高达 20 倍的差距

使用指令 sar -n DEV 1 查看网络设备占用情况,果真如此,25Gbps 以太网卡都跑到满载了(下图网卡为空载,仅为示例,当时训练时显示 80% 多)。

https://assets.zouht.com/img/blog/4276-01.webp

另外我发现,机上实际上是有 400Gbps 的 InfiniBand 高速网卡的,使用 ibstat 指令可以查看当前机器的 InfiniBand 网卡情况,发现有八个 mlx5_* 开头的 400Gbps 端口。

https://assets.zouht.com/img/blog/4276-02.webp

解决方案

要使用 InfiniBand 网卡,需要使用以下指令:

sudo docker run -it \
    --privileged \
    --gpus all \
    --shm-size=64g \
    --network=host \
    -v /dev/infiniband:/dev/inifiniband \
    -v /data:/data \
    modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.9.1-py312-torch2.10.0-vllm0.19.0-modelscope1.35.4-swift4.1.2 \
    /bin/bash

首先是 --privileged 参数,它提供了容器 CAP_IPC_LOCK 权限,因为 InfiniBand 这种 RDMA 网卡需要直接通过 DMA 读写物理内存。也可以通过 --cap-add=IPC_LOCK 参数实现。然后就是通过 -v /dev/infiniband:/dev/inifiniband 把设备映射到容器内,这个好理解。

完成以上配置后,在训练时还得添加以下环境变量:

环境变量解释示例
NCCL_DEBUG输出日志,调通后就可以删了INFO
NCCL_IB_DISABLE是否禁用 InfiniBand0
NCCL_IB_HCAInfiniBand 使用的 RDMA 通信接口名mlx5
NCCL_SOCKET_IFNAMEInfiniBand 使用的 IP 通信接口名bond0
  • NCCL_IB_HCA 可以直接填写 ibstat 指令输出的网卡前缀,这样 NCCL 会自动寻找并选择可用网卡,你也可以选择指定好端口例如 mlx5_0,mlx5_2
  • NCCL_SOCKET_IFNAME 这里需要填写有 IP 地址,可以互联的以太网接口,也就是后面 MASTER_ADDR 那个 IP 地址所在的网卡名。

综上,训练指令为(每个节点仍然要配置 NNODES 等环境变量,和第 1 章所述相同)

NCCL_DEBUG=INFO \
NCCL_IB_DISABLE=0 \
NCCL_IB_HCA=mlx5 \
NCCL_SOCKET_IFNAME=bond0 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
OMP_NUM_THREADS=16 \
IMAGE_MAX_TOKEN_NUM=1280 \
swift pt \
    ......

此时再观察 sar -n DEV 1 就可以发现以太网卡几乎没有负载了,训练时间也正常了。

  • 1*8 卡 NVIDIA H200 速度:4.76s/it
  • 2*8 卡 NVIDIA H200 速度:5.36s/it
  • 3*8 卡 NVIDIA H200 速度:6.22s/it
  • 4*8 卡 NVIDIA H200 速度:[待测试]
https://assets.zouht.com/img/blog/4276-03.webp

2.4 设定的 shm-size 过小

该问题在小规模测试时一般不会触发,但在正式训练时很可能就爆了,报错为 uable to allocate shared memory(shm),如图:

https://assets.zouht.com/img/blog/4276-04.webp

这种情况就是启动 Docker 容器时设定的 shm-size 过小,提高即可,我这里提升到 512GB 即可解决该问题。

同时也可以选择直接与宿主机共享 IPC,使用 --ipc=host 参数,让容器直接使用宿主机所有的共享内存空间,不再受限于 Docker 内部的配额。

3 Megatron 训练

如果把 Swift 训练调通且速度正常之后,Megatron 的训练基本上就水到渠成了。Megatron 框架相对于 Swift 提供了更加高级的并行方式,得以训练更大的模型,同时能够更高效利用硬件。

Megatron 训练同样使用 Swift 框架,因为 Swift 在近期已完成对 Megatron 的适配,因此使用和第 1 章同样的指令启动容器。

转换模型

不同的是,Megatron 需要使用自己独有的模型格式,而不是 Hugginface Transformers 的格式,因此训练前首先需要转换模型:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
swift export \
    --model /data/models/Qwen3-VL-8B-Thinking \
    --to_mcore true \
    --torch_dtype bfloat16 \
    --output_dir /data/models/Qwen3-VL-8B-Thinking-mcore \
    --test_convert_precision true

注意,不同版本 Megatron 转出来的模型可能不通用,在期间我升级了一次 Docker 镜像版本,就会发现之前转换的模型训练会报错。

启动训练

然后就可以开始训练了,多机并行的环境变量同第 1 章所述,训练的启动指令修改为如下:

NCCL_IB_DISABLE=0 \
NCCL_IB_HCA=mlx5 \
NCCL_SOCKET_IFNAME=bond0 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
OMP_NUM_THREADS=16 \
IMAGE_MAX_TOKEN_NUM=1280 \
megatron pt \
    --micro_batch_size 1 \
    --global_batch_size 4 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --num_train_epochs 1 \
    --cross_entropy_loss_fusion true \
    --attention_backend flash \
    --optimizer adam \
    --optimizer_cpu_offload false \
    --seed 42 \
    --cached_dataset /data/train_data/pt/cached_pt_*/train \
    --load_from_cache_file true \
    --dataloader_num_workers 64 \
    --dataset_num_proc 16 \
    --max_length 20480 \
    --packing true \
    --group_by_length false \
    --lr 1e-5 \
    --lr_warmup_fraction 0.01 \
    --min_lr 1e-6 \
    --weight_decay 0.1 \
    --clip_grad 1.0 \
    --output_dir /data/train_output/debug-qwen3-vl-8b-thinking-megatron \
    --mcore_model /data/models/Qwen3-VL-8B-Thinking-mcore \
    --save_steps 10000 \
    --finetune true \
    --save_total_limit 10 \
    --ddp_backend nccl \
    --ddp_timeout 864000 \
    --use_distributed_optimizer true \
    --tensor_model_parallel_size 4 \
    --pipeline_model_parallel_size 1 \
    --sequence_parallel true \
    --context_parallel_size 1 \
    --overlap_grad_reduce true \
    --overlap_param_gather true \
    --report_to tensorboard \
    --logging_steps 1

因为 8B 的模型太小了,所以我这里也没开什么特别复杂的并行,就是开了个 TP=4。但每张卡的显存占用得到了极为显著的下降,速度也得到了显著的提升。

双机速度从之前的 5.36s/it 提升到 1.05s/it,显存占用从之前的 100GB 下降到了 18GB。经过测试,原来的训练 micro_batch 最多开到 2,而现在的显存余量足以开到 32.

https://assets.zouht.com/img/blog/4276-06.webp

如果读者想要了解各种并行方式(DP, TP, SP, CP, PP, EP),可以查看我的笔记:

本文链接:https://www.zouht.com/4276.html
本文使用:CC BY-NC-SA 4.0 许可
# # # # # # #
首页      技术      基于 Swift/Megatron 完成多模态大模型分布式训练

发表回复

textsms
account_circle
email

颢天

基于 Swift/Megatron 完成多模态大模型分布式训练
近期由于项目需要训练大模型,学习了基于 Swift/Megatron 的大模型分布式训练技术,发现细节不少,踩了不少坑,在此记录希望能帮到大家。 本文的最重要关键词就是分布式,毕竟模型训练…
扫描二维码继续阅读
2026-04-29