高效强化学习训练 - 优化slime中的权重同步

本文也在我的知乎专栏中发布,知乎链接

作者

何标

朱子霖

李冀

1. 什么是slime?

slime 是一个强化学习大规模训练框架,提供以下核心能力:


什么是slime?
slime主要由三个核心模块组成2

2. 什么是权重同步?

什么是权重同步?

在LLM强化学习中,权重同步是指将更新好的训练端的模型权重传输给到推理端的过程,以确保推理工作节点始终使用最新的模型参数。

为什么需要权重同步?

在LLM的强化学习(如PPO、GRPO等)中:

  1. 训练引擎在每个optimizer.step()后更新策略模型权重。
  2. 推理引擎生成rollout、采样动作,但它需要使用最新的策略模型权重以与训练保持一致。
  3. 这两个组件通常分别运行在不同的进程和不同的框架(如Megatron/FSDP vs. SGLang/vLLM),因此需要显式同步

注意:这篇文章专门关注同卡(Colocate)模式,我们在整个权重更新过程中使用update_weights_from_tensor api。在分离(Disaggregate)模式下,slime使用update_weights_from_distributed api,通常通过NVLink/InfiniBand互连传输权重。

3. 权重同步在slime中如何工作?

权重同步在slime中如何工作?

在slime的同卡(Colocate)模式下,Megatron的工作进程和SGLang的工作进程共同位于相同的物理GPU上。为了实现零拷贝权重传输,Megatron不发送数据本身,而是通过将Tensor序列化成CudaIpcHandlers再将其发送给SGLang的工作进程,而SGLang可以直接通过这些CudaIpcHandlers来访问权重数据进行映射,这样可以极大的提高传输效率。

以下是详细的5步工作流程:

  1. 收集分布式Tensor:从Megatron训练进程中的PP/TP/EP/ETP等级的分布式工作节点收集,并gather成完整的Tensor。代码
  2. 序列化为CUDA IPC:将Tensor转换为CudaIpcHandlers并将其聚合成一个个大约为512MB的bucket tensor中。代码
  3. API通信:通过update_weights_from_tensor api将序列化好的CudaIpcHandlers发送到SGLang Server。代码
  4. 分发到工作节点:SGLang Server将CudaIpcHandlers分发到SGLang在各个GPU Rank上启动好的TP Worker进程。代码
  5. 重构和加载:TP Worker将CudaIpcHandlers反序列化并进行映射,指向Megatron之前聚合好的同一片GPU地址,从而将Megatron的权重加载到SGLang中。代码


为什么采用基于服务器的架构?

  1. 保证训推一致。 因为线上任务肯定是用的 server based。所以 RL 这里用完全相同的配置,可以
    • 避免训出来的模型上线或者单独评测的指标不匹配
    • 可以充分复用 sglang 对 server 做的测试和性能优化
  2. 为了减少用户自定义 rollout 时的心智负担
    • 通过 server based + router,让写 rollout 就像是常规打线上服务。这样比较好配合让算法老师自定义 rollout funnction 的思路
    • 可以把 router address 对外暴露,从而让外部的 agent 环境可以随意调用 slime 内部的 sglang server,从而实现纯异步训练

4. 我们的工作:将QWen3-30B-A3B模型的权重更新时间从60秒优化到7秒

我们的优化之旅

注意:上图是根据这个Github Issue3里的所有PR做完之后往回捋出来的,以便更容易理解逻辑,实际上,我们没有按照上图所示的改进顺序进行,因为实际工作场景中自然是按照从易到难实现,而不是根据物理传输过程中的顺序。

4.0 GPU上的跨进程数据传输:CUDA IPC Handler

在进程间传输大型模型权重时,我们肯定想要避免将整个模型序列化成Base64这种方式然后传输,尤其在同卡情况下,这样传输效率太低,内存和延迟都会爆炸。

不太现实的传统方法

传统方法 vs CUDA IPC

利用CUDA IPC Handler同卡零拷贝传输

CUDA IPC如何工作

主要优势:

  1. 零拷贝传输 通过内存映射来传输数据,避免在进程间传送大量的数据
  2. 最小CPU内存开销:CUDA IPC Handler非常小 vs 序列化数据的GB级别

这其实只是我们的baseline实现,虽然比直接传数据要快得多,但仍然花了60秒,显然有很多优化空间。

4.1 优化Megatron Worker中Tensor聚合过程:从60秒到50秒

第一个瓶颈来自于聚合分散在不同Megatron Worker中的Tensor,在此之前先浅浅介绍一下不同的并行策略(TP/PP/EP)下的聚合通信方式。这里简单介绍一下,对于后续的优化会有帮助。

按并行类型划分的通信策略

并行方式 通信方式 原因
张量并行 (TP) all_gather 每个rank有部分Tensor → 收集所有部分以重构完整Tensor
流水线并行 (PP) broadcast 源rank有完整层 → 分发到其他PP Rank
专家并行 (EP) broadcast 源rank有完整专家 → 分发到其他专家组

我们采取的优化很简单,就是采用异步收集Tensor来打满带宽,在下面的例子中,我们以TP Tensor的all_gather为例。

解决方案:异步TENSOR收集/广播

def async_tensor_gathering():
    # 阶段1:同时启动所有异步操作
    handles = []
    for param in tensor_parallel_params:
        handle = dist.all_gather(
            param_partitions, param.data, 
            group=tp_group, async_op=True  # 关键:非阻塞
        )
        handles.append(handle)
    
    # 阶段2:等待所有操作完成
    for handle in handles:
        handle.wait()  # 通过批量等待最大化并行性
    
    # 阶段3:所有通信完成后处理所有结果
    gathered_params = []
    for info, direct_param, handle, param_partitions, partition_dim in gather_tasks:
        param = torch.cat(param_partitions, dim=partition_dim)
        gathered_params.append(param)

    return gathered_params

性能影响:

代码参考:slime/backends/megatron_utils/update_weight_utils.py

相关PRhttps://github.com/THUDM/slime/pull/135

4.2 通过Tensor分桶优化SGLang服务器调用:从50秒到30秒

下一个瓶颈是对SGLang服务器的API调用数量。在基础实现里,我们对每个Tensor进行单独的HTTP请求造成了显著的开销。 这在Dense Model里问题不是很大,因为相对来说Tensor的数量较少,而MOE Model经常会有上万个Tensor需要传播,因此这个问题比较严重。

问题:太多小的API调用

# 低效:每个Tensor一个API调用
for name, tensor in named_tensors.items():
    response = requests.post(
        f"http://{server_host}/update_weights_from_tensor",
        json={"tensor_name": name, "tensor_data": serialize(tensor)}
    )

解决方案:Tensor分桶

优化方案是在传输前将参数智能地分组为最优大小的bucket。以下是slime的样例代码:

def get_param_info_buckets(args, model) -> list[list[ParamInfo]]:
    param_infos = get_param_infos(args, model)
    param_info_buckets = [[]]
    buffer_size = 0
    
    for info in param_infos:
        param_size = info.size * tp_size

        # 当超过大小限制时创建新桶
        if buffer_size + param_size > args.update_weight_buffer_size:
            param_info_buckets.append([])
            buffer_size = 0
            
        param_info_buckets[-1].append(info)
        buffer_size += param_size
    
    return param_info_buckets

self.param_info_buckets = get_param_info_buckets(args, model)

# 发送桶而不是单个Tensor
for param_infos in tqdm(self.param_info_buckets, disable=rank != 0, desc="Update weights"):
    self._update_bucket_weights_from_tensor(param_infos)

注意:通过多次实验,我们发现512MB是在内存和延迟之间平衡的最佳bucket大小。当然这个参数可以直接在slime的参数中调整,我们试过1GB,2GB的速度也不错,用户可以自己稍微尝试一下。

性能影响:

代码参考

4.3 合并多个Tensor成一个Tensor:减少CUDA IPC开销:从30秒到20秒

即使有了Tensor分桶,我们仍然面临一个重要瓶颈:CUDA IPC Handler Open/Close开销。每个Tensor都需要自己的IPC Handler创建和清理,导致上万个频繁的操作。目前这个过程过于频繁,已经成为整个同步过程中的瓶颈。

问题:太多CUDA IPC操作

太多CUDA IPC操作

性能分析

上面的flame chart揭示了我们权重同步过程中的真正瓶颈。以下是详细分解:

阶段 持续时间 百分比 主要活动
IPC句柄打开 22ms 54% CUDA IPC句柄创建和内存映射
加载权重 8ms 19% 实际权重加载和Tensor重构
IPC句柄关闭 11ms 27% CUDA IPC清理和资源释放
总计 41ms 100% SGLang中完整的权重更新周期

关键发现81%的时间花费在CUDA IPC操作(打开+关闭)上,而只有19%用于实际权重加载。这解释了为什么合并多个Tensor可以提供如此显著的改进。

扁平化后

扁平化Tensor后的性能

阶段 持续时间 百分比 改进
IPC句柄打开 3ms 15% 快86%
重建 5ms 25% Tensor重构的新阶段
加载权重 12ms 60% 轻微变化
IPC句柄关闭 200μs 1% 快98%
总计 20ms 100% 相比合并前减少了51%

关键成就:通过合并多个Tensor,我们将IPC操作从总时间的81%减少到16%,而权重加载在60%时成为主导阶段 - 这正是我们想要的!

有关如何实现合并多个Tensor等技术细节,请参考以下PR:

相关PR:

4.4 加载权重优化:最终性能提升:从20秒到7秒

在优化IPC开销后,我们还发现了权重加载过程本身的其他瓶颈,特别是对于MoE模型。

关键优化:

1. 参数字典缓存

# 之前:每次权重更新时昂贵的模型遍历
params_dict = dict(self.named_parameters())

# 之后:缓存参数字典
if not hasattr(self, "_cached_params_dict"):
    self._cached_params_dict = dict(self.named_parameters())
params_dict = self._cached_params_dict

2. 重复的Expert Map GPU Device Sync优化

# 避免专家映射的重复GPU到CPU同步
if self.expert_map_cpu is not None and self.expert_map_gpu is None:
    # 将专家映射移动到GPU一次并缓存
    self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

3. 重复的CUDA Device查询优化

# 缓存CUDA设备查询以避免重复的昂贵调用
@lru_cache(maxsize=8)
def get_device(device_id: Optional[int] = None) -> str:
    # 缓存的设备查找消除了重复的torch.cuda.is_available()调用

性能影响:

相关PR:

5. 未来优化

目前 slime 可以做到 7s 完成训推一体下 Qwen3 30B-A3B 模型 bf16 权重的参数同步。100s 完成 GLM4.5 355B-A32B 的 fp8 blockwise 量化 + 参数更新。

但还有不少的优化空间,欢迎社区的小伙伴联系我们一起继续优化。下面是一些可能的优化方向:

6. 致谢

7. 参考文献