Megatron-LM源码系列(六):Distributed-Optimizer分布式优化器实现Part1

1. 使用说明

在megatron中指定--use-distributed-optimizer就能开启分布式优化器, 参数定义在megatron/arguments.py中。分布式优化器的思路是将训练中的优化器状态均匀地分布到不同数据并行的rank结点上,相当于开启ZERO-1的训练。

    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')

在使用--use-distributed-optimizer, 同时会check两个参数 args.DDP_impl == 'local'(默认开启)和args.use_contiguous_buffers_in_local_ddp(默认开启)。

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

分布式优化器节省的理论显存值依赖参数类型和梯度类型,以下是每一个parameter对应占用的理论字节数(d表示数据并行的size大小,也就是一个数据并行中的卡数, 等于 T P × P P TP \times PP TP×PP ):

训练数据类型Non-distributed optim(单位Byte)Distributed optim(单位Byte)
float16 param, float16 grads204 + 16/d
float16 param, fp32 grads186 + 12/d
fp32 param, fp32 grads168 + 8/d

2. 实现介绍

  • Distributed-Optimizer分布式优化器的主要实现是通过连续的grad buffer来进行的,grad buffer中用于模型状态和优化器状态之间进行parameter参数和grad梯度的通信。grad buffer中使用reduce-scatter和all-gather进行通信。

  • 数据流如下:
    在这里插入图片描述

    1. 在每个dp的rank上计算完grad后,组成待更新的grad buffer数组
    2. 更新的时候通过reduce-scatter将grad buffer切分到各个rank上
    3. 在每个rank上完成优化器的step操作
    4. 最后将所有结果执行allgather操作得到更新后的grad buffer。
  • 以fp16类型grad为例,grad buffer分片说明如下:
    在这里插入图片描述

    • 一共有4个参数,分别用绿/黄/蓝/红表示;总参数大小为16个fp16类型数据
    • 按DP中rank的个数对总数据均匀切分
    • 如果参数过大,每个rank可能会只包含部分参数的数据,所以要考虑参数的偏移
    • 每个DP rank中的每个param参数都对应有3个偏移,一个是world_index表示总的数据偏移,一个是local_index表示在当前rank中的数据偏移,一个是param_index相对于param来说,表示当前rank结点存的数据的偏移。
    • 以黄色参数Param1为例,在rank0存了Param1的一个元素,rank1存了Param1的4个元素;world_index来说rank0上黄色部分的元素是总数据的[3,4], rank1上黄色部分的4个元素是总数据的[4,8]; local_index来说在rank0上表示[3,4],rank1表示当前结点全部的4个元素,范围也就是[0,4];param_index来说,对于rank0上的Param1的param_index就是[0,1],在rank2上的param_index就是[1,5];
  • 关键步骤详解:

    1. 上图中每个方块看成是一个grad buffer中的一个fp16类型元素,在反向结束以后,grad buffer中有16个fp16类型的元素
    2. 在每一个DP rank上调用reduce-scatter操作
    3. 每个DP rank的grad buffer中都有4个fp16类型元素经过了reduce-scatter操作更新,没更新的12个fp16类型元素等待后续垃圾回收
    4. 每个DP rank从grad buffer中拷贝更新后的4个fp16类型元素到fp32类型的main grad buffer中,准备开始后续的更新操作,例如
      • DP rank0拷贝[0:4]个元素
      • DP rank1拷贝[4:8]个元素
      • DP rank2拷贝[8:12]个元素
      • DP rank3拷贝[12:16]个元素
    5. 执行Optimizer.step(), step()操作必须通过fp32类型来进行计算
    6. 每个DP rank从main grad buffer中拷贝step()更新后的4个fp32类型元素到fp16类型的grad buffer中
    7. 执行allgather操作, 这样每个grad buffer就都是最新更新后的数据了
    8. 基于grad buffer来更新各个模型的fp16类型的参数
    9. 开始进行下一轮的更新

3. 源码实现

3.1 程序入口

  • 初始化的入口在文件megatron/training.pyget_model函数中,在创建LocalDDP的实例中会传入args.use_contiguous_buffers_in_local_ddp
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    ...
    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            ...
        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
    ...
  • 训练的入口定义在train_step函数中, 基本流程如下:
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    ...
    
    # 清除grad
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()
    ...
    
    # 执行前反向计算
    losses_reduced = forward_backward_func(...)
    ...
    
    # 对梯度执行Reduce-Scatter操作
    optimizer.reduce_model_grads(args, timers)
    ...
    
    # 更新梯度
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()
    ...
    
    # 对更新后的param执行gather操作
    if update_successful:
        optimizer.gather_model_params(args, timers)
    ...
    
    # 通过scheduler更新学习率
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1
    ...

3.2 grad buffer初始化(DistributedDataParallel类)

  • grad buffer初始化是在类DistributedDataParallel的init函数中, 源码定义在megatron/optimizer/distrib_optimizer.py文件中。
class DistributedDataParallel(DistributedDataParallelBase):
    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):
  • 创建grad buffer和index map
            self._grad_buffers = {}
            self._grad_buffer_param_index_map = {}
            data_parallel_world_size = mpu.get_data_parallel_world_size()
  • 按类型分别计算每个类型元素的个数,使用type_num_elements map进行存储,key是元素类型,value是类型出现的元素个数
            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()
  • 实际开始分配grad buffer, 为了支持被DP并行数正好切分,需要先对每个类型出现的个数进行padding操作;然后通过MemoryBuffer进行存储的分配
            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():

                # If using distributed optimizer, pad memory buffer to be
                # multiple of data_parallel_world_size. (This padding is done
                # due to a constraint with the reduce_scatter op, which requires
                # all tensors have equal size. See: optimizer.py.)
                num_elements_padded = data_parallel_world_size * \
                    int(math.ceil(num_elements / data_parallel_world_size))

                # Allocate grad buffer.
                self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                         num_elements_padded,
                                                         dtype)
  • 从grad buffer中给每一个param参数分配对应的main_grad空间,在分配main_grad时根据每个param参数的类型从对应的self._grad_buffers[dtype]中得到跟param.data.shape一样的tensor,这里的tensor与grad buffer共享存储。同时grad buffer的分配是按倒序来分配的,比如self.module.parameters()中有三个参数分别是[p1, p2, p3], 在grad buffer中存储则是[p3_grad, p2_grad, p1_grad]_grad_buffer_param_index_map用来记录每个param的梯度在grad buffer中存储的起始和结束位置。
            ...
            # Assume the back prop order is reverse the params order,
            # store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    # get的第二个参数是start_index,这里的start_index是从grad_buffer从大到小来算的
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])
                    if dtype not in self._grad_buffer_param_index_map:
                        self._grad_buffer_param_index_map[dtype] = {}
                    self._grad_buffer_param_index_map[dtype][param] = (
                        type_num_elements[dtype],
                        type_num_elements[dtype] + param.data.nelement(),
                    )
  • 遍历每一个参数,对于每一个参数的grad_fn的下一个function累加grad_acc函数进行改写,由于param本身没有grad_fn,通过trick方式使用param.expand_as给param加上了grad_fn函数。
            ...
            # Backward hook.
            # Accumalation function for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # 使用expand_as使param具有grad_fn.
                    param_tmp = param.expand_as(param)
                    # 获取梯度累加函数,并注册hook改写
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)
    
    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                # The gradient function of linear layers is fused with GEMMs
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

4. 参考