高效强化学习训练 - 优化slime中的权重同步
本文也在我的知乎专栏中发布,知乎链接
1. 什么是slime?
slime 是一个强化学习大规模训练框架,提供以下核心能力:
- 多功能 – 拥有完全可定制的rollout接口和灵活的训练设置(同卡或分离、同步或异步、RL或SFT)。
- 高性能 - 原生集成Megatron和SGLang进行训练和推理。
- 易维护 - 轻量级代码库,并可从Megatron预训练平滑过渡到SGLang部署。1
- 大规模验证 - 最近发布的zai-org/GLM-4.5(355B) 和 zai-org/GLM-4.5-Air(106B) 都是通过slime做的RL训练。
slime主要由三个核心模块组成2:
- 训练模块(Megatron) – 处理主要的训练过程,从数据缓冲区读取数据,并在训练后与rollout模块同步参数
- Rollout模块(SGLang + Router) – 生成新数据,包括奖励和sampling后的输出结果,并将其写入数据缓冲区
- 数据Buffer模块 – 作为桥接模块,管理prompt初始化、自定义数据和rollout生成策略。
2. 什么是权重同步?
在LLM强化学习中,权重同步是指将更新好的训练端的模型权重传输给到推理端的过程,以确保推理工作节点始终使用最新的模型参数。
为什么需要权重同步?
在LLM的强化学习(如PPO、GRPO等)中:
- 训练引擎在每个
optimizer.step()
后更新策略模型权重。 - 推理引擎生成rollout、采样动作,但它需要使用最新的策略模型权重以与训练保持一致。
- 这两个组件通常分别运行在不同的进程和不同的框架(如Megatron/FSDP vs. SGLang/vLLM),因此需要显式同步。
注意:这篇文章专门关注同卡(Colocate)模式,我们在整个权重更新过程中使用
update_weights_from_tensor
api。在分离(Disaggregate)模式下,slime使用update_weights_from_distributed
api,通常通过NVLink/InfiniBand互连传输权重。
3. 权重同步在slime中如何工作?
在slime的同卡(Colocate)模式下,Megatron的工作进程和SGLang的工作进程共同位于相同的物理GPU上。为了实现零拷贝权重传输,Megatron不发送数据本身,而是通过将Tensor序列化成CudaIpcHandlers再将其发送给SGLang的工作进程,而SGLang可以直接通过这些CudaIpcHandlers来访问权重数据进行映射,这样可以极大的提高传输效率。
以下是详细的5步工作流程:
- 收集分布式Tensor:从Megatron训练进程中的PP/TP/EP/ETP等级的分布式工作节点收集,并gather成完整的Tensor。代码
- 序列化为CUDA IPC:将Tensor转换为CudaIpcHandlers并将其聚合成一个个大约为512MB的bucket tensor中。代码
- API通信:通过
update_weights_from_tensor
api将序列化好的CudaIpcHandlers发送到SGLang Server。代码 - 分发到工作节点:SGLang Server将CudaIpcHandlers分发到SGLang在各个GPU Rank上启动好的TP Worker进程。代码
- 重构和加载:TP Worker将CudaIpcHandlers反序列化并进行映射,指向Megatron之前聚合好的同一片GPU地址,从而将Megatron的权重加载到SGLang中。代码
为什么采用基于服务器的架构?
- 保证训推一致。 因为线上任务肯定是用的 server based。所以 RL 这里用完全相同的配置,可以
- 避免训出来的模型上线或者单独评测的指标不匹配
- 可以充分复用 sglang 对 server 做的测试和性能优化
- 为了减少用户自定义 rollout 时的心智负担
- 通过 server based + router,让写 rollout 就像是常规打线上服务。这样比较好配合让算法老师自定义 rollout funnction 的思路
- 可以把 router address 对外暴露,从而让外部的 agent 环境可以随意调用 slime 内部的 sglang server,从而实现纯异步训练
4. 我们的工作:将QWen3-30B-A3B模型的权重更新时间从60秒优化到7秒
注意:上图是根据这个Github Issue3里的所有PR做完之后往回捋出来的,以便更容易理解逻辑,实际上,我们没有按照上图所示的改进顺序进行,因为实际工作场景中自然是按照从易到难实现,而不是根据物理传输过程中的顺序。
4.0 GPU上的跨进程数据传输:CUDA IPC Handler
在进程间传输大型模型权重时,我们肯定想要避免将整个模型序列化成Base64这种方式然后传输,尤其在同卡情况下,这样传输效率太低,内存和延迟都会爆炸。
不太现实的传统方法
利用CUDA IPC Handler同卡零拷贝传输
主要优势:
- 零拷贝传输 通过内存映射来传输数据,避免在进程间传送大量的数据
- 最小CPU内存开销:CUDA IPC Handler非常小 vs 序列化数据的GB级别
这其实只是我们的baseline实现,虽然比直接传数据要快得多,但仍然花了60秒,显然有很多优化空间。
4.1 优化Megatron Worker中Tensor聚合过程:从60秒到50秒
第一个瓶颈来自于聚合分散在不同Megatron Worker中的Tensor,在此之前先浅浅介绍一下不同的并行策略(TP/PP/EP)下的聚合通信方式。这里简单介绍一下,对于后续的优化会有帮助。
按并行类型划分的通信策略
并行方式 | 通信方式 | 原因 |
---|---|---|
张量并行 (TP) | all_gather | 每个rank有部分Tensor → 收集所有部分以重构完整Tensor |
流水线并行 (PP) | broadcast | 源rank有完整层 → 分发到其他PP Rank |
专家并行 (EP) | broadcast | 源rank有完整专家 → 分发到其他专家组 |
我们采取的优化很简单,就是采用异步收集Tensor来打满带宽,在下面的例子中,我们以TP Tensor的all_gather
为例。
解决方案:异步TENSOR收集/广播
def async_tensor_gathering():
# 阶段1:同时启动所有异步操作
handles = []
for param in tensor_parallel_params:
handle = dist.all_gather(
param_partitions, param.data,
group=tp_group, async_op=True # 关键:非阻塞
)
handles.append(handle)
# 阶段2:等待所有操作完成
for handle in handles:
handle.wait() # 通过批量等待最大化并行性
# 阶段3:所有通信完成后处理所有结果
gathered_params = []
for info, direct_param, handle, param_partitions, partition_dim in gather_tasks:
param = torch.cat(param_partitions, dim=partition_dim)
gathered_params.append(param)
return gathered_params
性能影响:
- 之前:顺序收集 → 60秒
- 之后:并行异步Tensor收集 → 50秒
代码参考:slime/backends/megatron_utils/update_weight_utils.py
相关PR:https://github.com/THUDM/slime/pull/135
4.2 通过Tensor分桶优化SGLang服务器调用:从50秒到30秒
下一个瓶颈是对SGLang服务器的API调用数量。在基础实现里,我们对每个Tensor进行单独的HTTP请求造成了显著的开销。 这在Dense Model里问题不是很大,因为相对来说Tensor的数量较少,而MOE Model经常会有上万个Tensor需要传播,因此这个问题比较严重。
问题:太多小的API调用
# 低效:每个Tensor一个API调用
for name, tensor in named_tensors.items():
response = requests.post(
f"http://{server_host}/update_weights_from_tensor",
json={"tensor_name": name, "tensor_data": serialize(tensor)}
)
解决方案:Tensor分桶
优化方案是在传输前将参数智能地分组为最优大小的bucket。以下是slime的样例代码:
def get_param_info_buckets(args, model) -> list[list[ParamInfo]]:
param_infos = get_param_infos(args, model)
param_info_buckets = [[]]
buffer_size = 0
for info in param_infos:
param_size = info.size * tp_size
# 当超过大小限制时创建新桶
if buffer_size + param_size > args.update_weight_buffer_size:
param_info_buckets.append([])
buffer_size = 0
param_info_buckets[-1].append(info)
buffer_size += param_size
return param_info_buckets
self.param_info_buckets = get_param_info_buckets(args, model)
# 发送桶而不是单个Tensor
for param_infos in tqdm(self.param_info_buckets, disable=rank != 0, desc="Update weights"):
self._update_bucket_weights_from_tensor(param_infos)
注意:通过多次实验,我们发现512MB是在内存和延迟之间平衡的最佳bucket大小。当然这个参数可以直接在slime的参数中调整,我们试过1GB,2GB的速度也不错,用户可以自己稍微尝试一下。
性能影响:
- 之前:上万个单独API调用 → 50秒
- 之后:几百个API调用 → 30秒
- 改进:通过最小化HTTP开销减少40%
4.3 合并多个Tensor成一个Tensor:减少CUDA IPC开销:从30秒到20秒
即使有了Tensor分桶,我们仍然面临一个重要瓶颈:CUDA IPC Handler Open/Close开销。每个Tensor都需要自己的IPC Handler创建和清理,导致上万个频繁的操作。目前这个过程过于频繁,已经成为整个同步过程中的瓶颈。
问题:太多CUDA IPC操作
性能分析
上面的flame chart揭示了我们权重同步过程中的真正瓶颈。以下是详细分解:
阶段 | 持续时间 | 百分比 | 主要活动 |
---|---|---|---|
IPC句柄打开 | 22ms | 54% | CUDA IPC句柄创建和内存映射 |
加载权重 | 8ms | 19% | 实际权重加载和Tensor重构 |
IPC句柄关闭 | 11ms | 27% | CUDA IPC清理和资源释放 |
总计 | 41ms | 100% | SGLang中完整的权重更新周期 |
关键发现:81%的时间花费在CUDA IPC操作(打开+关闭)上,而只有19%用于实际权重加载。这解释了为什么合并多个Tensor可以提供如此显著的改进。
扁平化Tensor后的性能
阶段 | 持续时间 | 百分比 | 改进 |
---|---|---|---|
IPC句柄打开 | 3ms | 15% | 快86% |
重建 | 5ms | 25% | Tensor重构的新阶段 |
加载权重 | 12ms | 60% | 轻微变化 |
IPC句柄关闭 | 200μs | 1% | 快98% |
总计 | 20ms | 100% | 相比合并前减少了51% |
关键成就:通过合并多个Tensor,我们将IPC操作从总时间的81%减少到16%,而权重加载在60%时成为主导阶段 - 这正是我们想要的!
有关如何实现合并多个Tensor等技术细节,请参考以下PR:
相关PR:
4.4 加载权重优化:最终性能提升:从20秒到7秒
在优化IPC开销后,我们还发现了权重加载过程本身的其他瓶颈,特别是对于MoE模型。
关键优化:
1. 参数字典缓存
# 之前:每次权重更新时昂贵的模型遍历
params_dict = dict(self.named_parameters())
# 之后:缓存参数字典
if not hasattr(self, "_cached_params_dict"):
self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict
2. 重复的Expert Map GPU Device Sync优化
# 避免专家映射的重复GPU到CPU同步
if self.expert_map_cpu is not None and self.expert_map_gpu is None:
# 将专家映射移动到GPU一次并缓存
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
3. 重复的CUDA Device查询优化
# 缓存CUDA设备查询以避免重复的昂贵调用
@lru_cache(maxsize=8)
def get_device(device_id: Optional[int] = None) -> str:
# 缓存的设备查找消除了重复的torch.cuda.is_available()调用
性能影响:
- 之前:各种权重加载瓶颈 → 20秒
- 之后:优化的参数缓存和设备处理 → 7秒
- 改进:最终权重加载时间减少65%
相关PR:
5. 未来优化
目前 slime 可以做到 7s 完成训推一体下 Qwen3 30B-A3B 模型 bf16 权重的参数同步。100s 完成 GLM4.5 355B-A32B 的 fp8 blockwise 量化 + 参数更新。
但还有不少的优化空间,欢迎社区的小伙伴联系我们一起继续优化。下面是一些可能的优化方向:
- 异步收集和发送:Megatron Worker的收集和SGLang Worker的加载实际上可以异步,理论上最高能加快1倍的速度。
- 异步权重加载:非阻塞模型权重更新
- 零冗余布局:预计算推理引擎内存布局并进行零冗余拷贝,比如megatron rank 0 只传送sglang rank 0 实际需要的权重,目前还是有很大的冗余的。
6. 致谢
- slime团队: https://github.com/THUDM/slime
- SGLang团队: https://github.com/sgl-project/sglang