Megatron-LM源码系列(六):Distributed-Optimizer分布式优化器实现Part1
1. 使用说明
在megatron中指定--use-distributed-optimizer
就能开启分布式优化器, 参数定义在megatron/arguments.py
中。分布式优化器的思路是将训练中的优化器状态均匀地分布到不同数据并行的rank结点上,相当于开启ZERO-1
的训练。
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
在使用--use-distributed-optimizer
, 同时会check两个参数 args.DDP_impl == 'local'
(默认开启)和args.use_contiguous_buffers_in_local_ddp
(默认开启)。
# If we use the distributed optimizer, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is on.
if args.use_distributed_optimizer:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
分布式优化器节省的理论显存值依赖参数类型和梯度类型,以下是每一个parameter对应占用的理论字节数(d
表示数据并行的size大小,也就是一个数据并行中的卡数, 等于
T
P
×
P
P
TP \times PP
TP×PP ):
训练数据类型 | Non-distributed optim(单位Byte) | Distributed optim(单位Byte) |
---|---|---|
float16 param, float16 grads | 20 | 4 + 16/d |
float16 param, fp32 grads | 18 | 6 + 12/d |
fp32 param, fp32 grads | 16 | 8 + 8/d |
2. 实现介绍
-
Distributed-Optimizer分布式优化器的主要实现是通过连续的
grad buffer
来进行的,grad buffer
中用于模型状态和优化器状态之间进行parameter参数和grad梯度的通信。grad buffer
中使用reduce-scatter和all-gather进行通信。 -
数据流如下:
- 在每个dp的rank上计算完grad后,组成待更新的grad buffer数组
- 更新的时候通过reduce-scatter将grad buffer切分到各个rank上
- 在每个rank上完成优化器的step操作
- 最后将所有结果执行allgather操作得到更新后的grad buffer。
-
以fp16类型grad为例,grad buffer分片说明如下:
- 一共有4个参数,分别用绿/黄/蓝/红表示;总参数大小为16个fp16类型数据
- 按DP中rank的个数对总数据均匀切分
- 如果参数过大,每个rank可能会只包含部分参数的数据,所以要考虑参数的偏移
- 每个DP rank中的每个param参数都对应有3个偏移,一个是world_index表示总的数据偏移,一个是local_index表示在当前rank中的数据偏移,一个是param_index相对于param来说,表示当前rank结点存的数据的偏移。
- 以黄色参数Param1为例,在rank0存了Param1的一个元素,rank1存了Param1的4个元素;world_index来说rank0上黄色部分的元素是总数据的[3,4], rank1上黄色部分的4个元素是总数据的[4,8]; local_index来说在rank0上表示[3,4],rank1表示当前结点全部的4个元素,范围也就是[0,4];param_index来说,对于rank0上的Param1的param_index就是[0,1],在rank2上的param_index就是[1,5];
-
关键步骤详解:
- 上图中每个方块看成是一个grad buffer中的一个fp16类型元素,在反向结束以后,grad buffer中有16个fp16类型的元素
- 在每一个DP rank上调用reduce-scatter操作
- 每个DP rank的grad buffer中都有4个fp16类型元素经过了reduce-scatter操作更新,没更新的12个fp16类型元素等待后续垃圾回收
- 每个DP rank从grad buffer中拷贝更新后的4个fp16类型元素到fp32类型的main grad buffer中,准备开始后续的更新操作,例如
- DP rank0拷贝[0:4]个元素
- DP rank1拷贝[4:8]个元素
- DP rank2拷贝[8:12]个元素
- DP rank3拷贝[12:16]个元素
- 执行Optimizer.step(), step()操作必须通过fp32类型来进行计算
- 每个DP rank从main grad buffer中拷贝step()更新后的4个fp32类型元素到fp16类型的grad buffer中
- 执行allgather操作, 这样每个grad buffer就都是最新更新后的数据了
- 基于grad buffer来更新各个模型的fp16类型的参数
- 开始进行下一轮的更新
3. 源码实现
3.1 程序入口
- 初始化的入口在文件
megatron/training.py
的get_model
函数中,在创建LocalDDP
的实例中会传入args.use_contiguous_buffers_in_local_ddp
。
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
...
if wrap_with_ddp:
if args.DDP_impl == 'torch':
...
elif args.DDP_impl == 'local':
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_local_ddp)
for model_module in model]
...
- 训练的入口定义在
train_step
函数中, 基本流程如下:
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler):
...
# 清除grad
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
for partition in model:
partition.zero_grad_buffer()
optimizer.zero_grad()
...
# 执行前反向计算
losses_reduced = forward_backward_func(...)
...
# 对梯度执行Reduce-Scatter操作
optimizer.reduce_model_grads(args, timers)
...
# 更新梯度
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
timers('optimizer').stop()
...
# 对更新后的param执行gather操作
if update_successful:
optimizer.gather_model_params(args, timers)
...
# 通过scheduler更新学习率
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
...
3.2 grad buffer初始化(DistributedDataParallel类)
grad buffer
初始化是在类DistributedDataParallel
的init函数中, 源码定义在megatron/optimizer/distrib_optimizer.py
文件中。
class DistributedDataParallel(DistributedDataParallelBase):
def __init__(self, module,
accumulate_allreduce_grads_in_fp32,
use_contiguous_buffers):
- 创建grad buffer和index map
self._grad_buffers = {}
self._grad_buffer_param_index_map = {}
data_parallel_world_size = mpu.get_data_parallel_world_size()
- 按类型分别计算每个类型元素的个数,使用type_num_elements map进行存储,key是元素类型,value是类型出现的元素个数
# First calculate total number of elements per type.
type_num_elements = {}
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
+ param.data.nelement()
- 实际开始分配grad buffer, 为了支持被DP并行数正好切分,需要先对每个类型出现的个数进行padding操作;然后通过
MemoryBuffer
进行存储的分配
# Allocate the buffer.
for dtype, num_elements in type_num_elements.items():
# If using distributed optimizer, pad memory buffer to be
# multiple of data_parallel_world_size. (This padding is done
# due to a constraint with the reduce_scatter op, which requires
# all tensors have equal size. See: optimizer.py.)
num_elements_padded = data_parallel_world_size * \
int(math.ceil(num_elements / data_parallel_world_size))
# Allocate grad buffer.
self._grad_buffers[dtype] = MemoryBuffer(num_elements,
num_elements_padded,
dtype)
- 从grad buffer中给每一个param参数分配对应的main_grad空间,在分配main_grad时根据每个param参数的类型从对应的
self._grad_buffers[dtype]
中得到跟param.data.shape
一样的tensor,这里的tensor与grad buffer共享存储。同时grad buffer的分配是按倒序来分配的,比如self.module.parameters()
中有三个参数分别是[p1, p2, p3]
, 在grad buffer中存储则是[p3_grad, p2_grad, p1_grad]
。_grad_buffer_param_index_map
用来记录每个param的梯度在grad buffer中存储的起始和结束位置。
...
# Assume the back prop order is reverse the params order,
# store the start index for the gradients.
for param in self.module.parameters():
if param.requires_grad:
dtype = _get_buffer_type(param)
type_num_elements[dtype] -= param.data.nelement()
# get的第二个参数是start_index,这里的start_index是从grad_buffer从大到小来算的
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
if dtype not in self._grad_buffer_param_index_map:
self._grad_buffer_param_index_map[dtype] = {}
self._grad_buffer_param_index_map[dtype][param] = (
type_num_elements[dtype],
type_num_elements[dtype] + param.data.nelement(),
)
- 遍历每一个参数,对于每一个参数的grad_fn的下一个function累加grad_acc函数进行改写,由于param本身没有
grad_fn
,通过trick方式使用param.expand_as
给param加上了grad_fn
函数。
...
# Backward hook.
# Accumalation function for the gradients. We need
# to store them so they don't go out of scope.
self.grad_accs = []
# Loop over all the parameters in the model.
for param in self.module.parameters():
if param.requires_grad:
# 使用expand_as使param具有grad_fn.
param_tmp = param.expand_as(param)
# 获取梯度累加函数,并注册hook改写
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
# Hook used for back-prop.
def param_hook(*unused):
# Add the gradient to the buffer.
if param.grad is not None:
# The gradient function of linear layers is fused with GEMMs
param.main_grad.add_(param.grad.data)
# Now we can deallocate grad memory.
param.grad = None
return param_hook