6. Stage3 - 前后向过程

在 Hook 章节(节 5)中我们看到,每个 Module 的前后向过程都被挂上了 hook,在 Module::forward 前后都会执行特定的动作,被分割参数的还原与重新释放正是在这些 hook 中完成的。关键的动作都在类 PartitionedParameterCoordinator 中,本章我们研究前后向过程中参数的还原与重新分割。

6.1. 参数还原

理论上,前后向过程前都要还原参数, 先把 hook 的函数贴过来,

@torch.no_grad()
def pre_sub_module_forward_function(self, sub_module):
    """Pre-forward hook: make the parameters of ``sub_module`` available.

    The module is pushed onto the global ``FWD_MODULE_STACK``, the parameter
    coordinator is notified (and records the module while it is recording a
    trace), and finally the coordinator is asked to fetch the module's
    parameters for the forward pass.

    Args:
        sub_module: The module whose forward is about to run.
    """
    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(sub_module)

    # The coordinator is selected by the module's training flag.
    # (Per the original author's note it is a
    # deepspeed.runtime.zero.PartitionedParameterCoordinator.)
    coordinator = self.get_param_coordinator(training=sub_module.training)
    coordinator.trace_prologue(sub_module)
    recording = coordinator.is_record_trace()
    if recording:
        coordinator.record_module(sub_module)

    # This call is where the partitioned parameters are actually gathered.
    coordinator.fetch_sub_module(sub_module, forward=True)

    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
    """Pre-backward hook: make the parameters of ``sub_module`` available.

    Mirrors the pre-forward hook but always uses the training coordinator
    and fetches with ``forward=False``.

    Args:
        sub_module: The module whose backward is about to run; it must be in
            training mode.

    Raises:
        AssertionError: If ``sub_module.training`` is falsy.
    """
    assert sub_module.training, "backward pass is invalid for module in evaluation mode"

    coordinator = self.get_param_coordinator(training=True)
    coordinator.trace_prologue(sub_module)
    recording = coordinator.is_record_trace()
    if recording:
        coordinator.record_module(sub_module)

    # Gather the parameters needed by this module's backward pass.
    coordinator.fetch_sub_module(sub_module, forward=False)

前向和后向之前都要进入 param_coordinator.fetch_sub_module(sub_module, forward=...) ,通过入参 forward 的取值(True / False)区分前向还是后向,接下来跟踪进去看看

@instrument_w_nvtx
@torch.no_grad()
def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
    """This method does the following (in order):
    1. kick off fetch for parameters in immediately required sub module
    2. kick off fetch for next few parameters we will need later (prefetch)
    3. block on parameters in immediately required sub module

    NOTE(review): in the body below the blocking wait (3) is performed before
    the prefetch (2) is submitted, so the numbering above does not match the
    statement order.

    Args:
        current_submodule: The module whose parameters are needed right now.
        forward: True when called from the pre-forward hook, False when called
            from the pre-backward hook. It selects the profiler event names
            and is passed through to the all-gather.

    Raises:
        RuntimeError: If a complete trace exists and the parameters popped
            from the fetch queue do not match the ones this module needs.
    """
    if logger.isEnabledFor(logging.DEBUG):
        debug_rank0(
            f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} "
            + str({
                "avail": f"{self.__n_available_params:.1e}",
                "queue_sz": f"{len(self.__param_queue or [])}",
                "inflight": [p.ds_id for p in self.__inflight_param_registry],
            }))
    # Per the original author's note, iter_params yields only the parameters
    # owned directly by this module and does not recurse into child modules
    # (the helper itself is not shown here).
    params_to_fetch = frozenset(iter_params(current_submodule))
    # Total partition_numel() of the parameters that are not yet available.
    fetch_numel = sum(
        [p.partition_numel() for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
    if fetch_numel > 0:
        # Pick the profiler event name for the forward or backward pass.
        event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
        # Per the original note, this call is only for logging.
        self._dump_param_ids(event_name, current_submodule.id,
                             [p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
        self.__profiler.start_event(event_name)
        # kick off all gather for params in the immediately required submodule
        # for param in params_to_fetch:
        if logger.isEnabledFor(logging.DEBUG):
            for param in params_to_fetch:
                debug_rank0(f"-fetch: {param.ds_summary()}")
        # Submit the all-gather; this is where the parameters start being
        # reconstructed.
        self.__all_gather_params(params_to_fetch, forward)
        self.__profiler.stop_event(event_name, fetch_numel)
    # The all-gather submitted above is not waited on there; the loop below
    # waits on the registered handle of each in-flight parameter.
    wait_numel = 0
    wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
    self.__profiler.start_event(wait_event_name)
    # wait for parameters in the immediately needed submodule to become available
    for param in params_to_fetch:
        param.ds_active_sub_modules.add(current_submodule.id)
        if logger.isEnabledFor(logging.DEBUG):
            debug_rank0(f"-wait: {param.ds_summary()}")
        if param in self.__inflight_param_registry:
            wait_numel += param.partition_numel()
            # Background on streams (original author's link):
            # https://zhuanlan.zhihu.com/p/369367933
            with get_accelerator().stream(self.__allgather_stream):
                # Drop leading events for which query() is truthy (i.e. that
                # have already completed). This loop does not spin: it stops
                # at the first event that is still pending.
                while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query():
                    self.__ongoing_fetch_events.popleft()
                # If more events than the configured maximum are still
                # outstanding, block on the oldest one.
                if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
                    self.__ongoing_fetch_events.popleft().synchronize()
                # Wait for the all-gather of this parameter to finish.
                # Per the original note there are two handle kinds:
                #   1. AllGatherHandle          - covers a single parameter
                #   2. AllGatherCoalescedHandle - covers a batch of parameters
                # TODO(review): the same coalesced handle is registered for
                # every parameter of its batch, so this wait() presumably
                # waits for the whole batch at once - confirm.
                self.__inflight_param_registry.pop(param).wait()
                if not get_accelerator().is_synchronized_device():
                    # The accelerator is not a synchronized device (the
                    # original note names CUDA): record an event after the
                    # wait and keep it in the ongoing-events queue.

                    event = get_accelerator().Event()
                    event.record()
                    self.__ongoing_fetch_events.append(event)
        # The parameter must be fully gathered and usable at this point.
        assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
    if not get_accelerator().is_synchronized_device():
        # Make the current stream wait for the all-gather stream.
        get_accelerator().current_stream().wait_stream(self.__allgather_stream)
    self.__profiler.stop_event(wait_event_name, wait_numel)

    # kick off parameter prefetches for upcoming modules
    # don't prefetch if we dont have a completed model trace
    if self.is_complete_trace():
        # go through the parameters we need for the current module and pop them
        # off the fetch queue so that they aren't prefetched later.
        # if params have already been popped off the fetch queue by earlier
        # prefetches we won't look for them here
        discarded_from_prefetch_queue = set()
        params_not_already_fetched = set(
            filter(lambda p: self.__most_recent_step_id_param_fetched_for[p] < self.__step_id, params_to_fetch))
        while self.__param_queue and len(discarded_from_prefetch_queue) < len(params_not_already_fetched):
            param_in_trace = self.__param_queue.popleft()
            self.__most_recent_step_id_param_fetched_for[
                param_in_trace.param] = param_in_trace.step_id_last_used_at
            discarded_from_prefetch_queue.add(param_in_trace.param)

        # The popped parameters must be exactly the ones this module needs;
        # anything else means execution diverged from the recorded trace.
        if discarded_from_prefetch_queue != params_not_already_fetched:
            raise RuntimeError(
                f"tracing error at step {self.__step_id}: \n"
                f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
                f"expected the next {len(params_not_already_fetched)} parameters in the "
                f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
                f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.")

        def _is_currently_on_nvme(param):
            # True only when the parameter has an NVMe swapper and its
            # partition lives on NVMe and is currently NOT_AVAILABLE.
            if param.nvme_swapper is None:
                return False

            return param.ds_tensor.final_location == OffloadDeviceEnum.nvme \
                and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE

        # kick off all gather for params in the next few submodules (prefetch)
        if self.__prefetch_bucket_sz > 0:
            # Prefetch budget: the remaining headroom under the maximum number
            # of available elements, capped by the prefetch bucket size.
            max_params_to_prefetch = min(self.__max_n_available_params - self.__n_available_params,
                                         self.__prefetch_bucket_sz)
            params_to_prefetch = set()
            numel_prefetching = 0
            while self.__param_queue and numel_prefetching < max_params_to_prefetch:
                param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft()

                if _is_currently_on_nvme(param_in_trace.param):
                    # nvme prefetch is handled elsewhere. Need to break here to preserve fetch order
                    self.__param_queue.appendleft(param_in_trace)
                    break

                do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
                if param_in_trace.param in params_to_prefetch:
                    # Avoid duplicates
                    do_prefetch = False

                # Remember the latest step this parameter was fetched for.
                self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \
                    max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param],
                        param_in_trace.step_id_last_used_at)

                if do_prefetch:
                    params_to_prefetch.add(param_in_trace.param)
                    numel_prefetching += param_in_trace.param.ds_numel

            if numel_prefetching > 0:
                event_name = __class__.FORWARD_PREFETCH_SUBMIT if forward else __class__.BACKWARD_PREFETCH_SUBMIT
                self.__profiler.start_event(event_name)
                if logger.isEnabledFor(logging.DEBUG):
                    for param in params_to_prefetch:
                        debug_rank0(f"-prefetch: {param.ds_summary()}")
                self.__all_gather_params(params_to_prefetch, forward)
                self.__profiler.stop_event(event_name, numel_prefetching)

            if self.__prefetch_nvme:
                self.__prefetch_nvme_param_partitions()

    # One step is counted per fetch_sub_module call.
    self.__step_id += 1

6.1.1. __all_gather_params

@instrument_w_nvtx
def __all_gather_params(self, params: Set[Parameter], forward: bool) -> None:
    """Split ``params`` by quantization state and gather each group.

    A parameter whose ``ds_tensor`` carries a ``ds_quant_scale`` attribute is
    gathered with ``quantize=True``; every other parameter is gathered with
    ``quantize=self.zero_quantized_weights``.

    Args:
        params: Parameters to gather.
        forward: Passed through unchanged to ``__all_gather_params_``.
    """
    with_scale, without_scale = [], []
    for candidate in params:
        bucket = with_scale if hasattr(candidate.ds_tensor, 'ds_quant_scale') else without_scale
        bucket.append(candidate)

    if with_scale:
        self.__all_gather_params_(with_scale, forward, quantize=True)
    if without_scale:
        self.__all_gather_params_(without_scale, forward, quantize=self.zero_quantized_weights)

根据参数分片上是否已经带有量化信息(ds_quant_scale),把参数分成两组分别处理,两组最终都会进入 __all_gather_params_:

def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize: bool = False) -> None:
    """Start one all-gather for the still-partitioned parameters in ``params``.

    Parameters whose status is not NOT_AVAILABLE are ignored. For the rest a
    single ``all_gather_coalesced`` call is issued on the all-gather stream
    and the returned handle is registered for every gathered parameter so
    that it can be waited on later.

    Args:
        params: Candidate parameters.
        forward: Selects the profiler event name and is forwarded to
            ``all_gather_coalesced``.
        quantize: Forwarded to ``all_gather_coalesced``.
    """
    # Only parameters that are currently partitioned need gathering.
    pending = [p for p in params if p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
    if not pending:
        return

    gather_numel = sum(p.ds_numel for p in pending)
    # Account for the elements before the gather is submitted.
    self.__n_available_params += gather_numel

    # Issue the gather on the dedicated all-gather stream rather than on the
    # current stream.
    with get_accelerator().stream(self.__allgather_stream):
        event_name = __class__.FORWARD_ALL_GATHER if forward else __class__.BACKWARD_ALL_GATHER

        self.__profiler.start_event(event_name)
        # all_gather_coalesced is called on the first parameter but receives
        # the whole list, so a single call covers the entire batch.
        handle = pending[0].all_gather_coalesced(pending, forward=forward, quantize=quantize)
        self.__profiler.stop_event(event_name, gather_numel)

    # Every gathered parameter must now be in flight; they all share the
    # handle returned above.
    for gathered in pending:
        assert gathered.ds_status == ZeroParamStatus.INFLIGHT, gathered.ds_summary()
        self.__inflight_param_registry[gathered] = handle

    # Release swap buffers for persisted params on nvme since they will never be partitioned or evicted from GPU
    persisted_on_nvme = [
        p for p in pending if p.ds_persist and p.ds_tensor.final_location == OffloadDeviceEnum.nvme
    ]
    if persisted_on_nvme:
        persisted_on_nvme[0].nvme_swapper.remove_partition_and_release_buffers(persisted_on_nvme)

我们发现最终又跳到了每个参数的方法 all_gather_coalesced 中, 这个方法是前面参数分割章节( 节 4 ) 中赋予的,具体在 节 4.3.1

@instrument_w_nvtx
def all_gather_coalesced(params: Iterable[Parameter],
                         forward: bool = True,
                         safe_mode: bool = False,
                         quantize: bool = False) -> AllGatherCoalescedHandle:
    """Start an all-gather for ``params`` and return a handle to wait on.

    Every parameter is switched to INFLIGHT and the list is sorted by
    ``ds_id``. A single parameter is gathered straight into a buffer that
    becomes ``param.data``; several parameters are concatenated into one flat
    tensor and gathered with a single call. When ``quantize`` is set, the
    int8 payload and its scales are gathered separately and described by a
    ``QuantizationInfo`` attached to the handle.

    NOTE(review): ``self`` is not a parameter of this function; it is
    presumably captured from an enclosing scope that is not part of this
    excerpt.

    Args:
        params: Parameters to gather; all must be NOT_AVAILABLE.
        forward: When False and ``self.zero_param_process_group`` is set, that
            group (with its rank and size) is used, and secondary tensors are
            used where present.
        safe_mode: When True, check across ranks that the ``ds_id`` list and
            the partition sizes agree before gathering.
        quantize: When True, gather int8 data plus quantization scales.

    Returns:
        The result of ``_no_gather_coalesced`` when there is one partition,
        an ``AllGatherHandle`` for a single parameter, otherwise an
        ``AllGatherCoalescedHandle``. NOTE(review): the return annotation
        only names the last of these.

    Raises:
        RuntimeError: If any parameter is not in the NOT_AVAILABLE state.
    """

    # fetches from nvme if the partition is not available and in nvme
    self._ensure_availability_of_partitioned_params(params)

    # With a single partition there is nothing to gather across ranks.
    if self.num_partitions == 1:
        return _no_gather_coalesced(params)

    for param in params:
        if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
            raise RuntimeError(param.ds_summary())
        param.ds_status = ZeroParamStatus.INFLIGHT

    # use appropriate all gather process group
    ds_process_group = self.ds_process_group
    rank_in_group = self.rank
    world_size = self.dp_world_size
    use_secondary_tensor = False
    if self.zero_param_process_group and not forward:
        ds_process_group = self.zero_param_process_group  # intragroup
        rank_in_group = self.rank_in_group
        world_size = self.num_ranks_in_param_group

    # ensure that each rank has params in same order. the allgather
    # is done by flattening the parameter list into a single tensor that
    # can be allgathered in a single call - this means that if each rank
    # gives a list of the same parameters in a different order we will
    # silently get incorrect parameter values, and have very difficult
    # to debug correctness issues.
    params = sorted(params, key=lambda p: p.ds_id)

    if logger.isEnabledFor(logging.DEBUG):
        debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}")

    if safe_mode:
        # ensure that same list (with same ordering) of parameters are
        # being allgathered across all ranks, otherwise could mix
        # data between tensors.
        assert_ints_same_as_other_ranks([p.ds_id for p in params])
        # ensure that tensors from each rank agree on the same ds_numel
        # otherwise could mix data between tensors.
        assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params])

    if len(params) == 1:
        # have an opportunity to avoid some intermediate memory allocations
        param, = params
        # Round the element count up to a multiple of world_size.
        buffer_size = math.ceil(param.ds_numel / world_size) * world_size
        if not forward and param.ds_secondary_tensor is not None:
            buffer_size = param.ds_secondary_tensor.shape[
                              0] * world_size  # make sure out is appropriately sized

        # Backward with a secondary tensor gathers from the secondary tensor;
        # every other case gathers from the primary partition.
        param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
        param_buffer = torch.empty(
            buffer_size,
            dtype=param_ds_tensor.dtype if not quantize else torch.int8,
            device=get_accelerator().current_device_name(),
            requires_grad=False,
        )
        if not quantize:
            handles = _dist_allgather_fn(
                param_ds_tensor.to(get_accelerator().current_device_name()),
                param_buffer,
                ds_process_group,
            )
            # Point param.data at the gather buffer, trimmed to ds_numel
            # elements and viewed with the parameter's full shape.
            param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
            return AllGatherHandle(handles, param)
        else:
            # Reuse the stored quantized data and scales when the partition
            # already carries them; otherwise quantize it now.
            if hasattr(param_ds_tensor, "ds_quant_scale"):
                scales = param_ds_tensor.ds_quant_scale
                quantized_param = param_ds_tensor.data
            else:
                quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor)
            handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()),
                                        param_buffer, ds_process_group)

            # The scales are gathered with a second, separate all-gather.
            quant_scale_buffer = torch.empty(
                scales.numel() * world_size,
                dtype=scales.dtype,
                device=get_accelerator().current_device_name(),
                requires_grad=False,
            )
            quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
                                              quant_scale_buffer, ds_process_group)
            quant_info = QuantizationInfo()
            quant_info.quantized_param = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(
                param.device)
            quant_info.backend = self.quantizer_module
            quant_info.quant_handle = quant_handle
            quant_info.scale_buffer = quant_scale_buffer
            return AllGatherHandle(handle, param, quantization=quant_info)

    else:
        # Number of elements this rank contributes for the whole batch.
        partition_sz = sum(p.ds_tensor.ds_numel for p in params)

        if params[0].ds_secondary_tensor is not None and not forward:
            partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

        # One flat buffer holding world_size slices of partition_sz elements.
        flat_tensor = torch.empty(partition_sz * world_size,
                                  dtype=get_only_unique_item(p.ds_tensor.dtype
                                                             for p in params) if not quantize else torch.int8,
                                  device=get_accelerator().current_device_name(),
                                  requires_grad=False)
        if not quantize:
            # partitions[i] is the slice of flat_tensor at offset
            # partition_sz * i.
            partitions: List[Parameter] = []
            for i in range(world_size):
                partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))

            # Concatenate this rank's shards directly into its own slice.
            if params[0].ds_secondary_tensor is not None and not forward:
                use_secondary_tensor = True
                instrument_w_nvtx(torch.cat)(
                    [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
                    out=partitions[rank_in_group])
            else:
                instrument_w_nvtx(
                    torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
                               out=partitions[rank_in_group])
            handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
            # Fix get_partition_dp_group(params[0]))

            return AllGatherCoalescedHandle(
                allgather_handle=handle,
                params=params,
                partitions=partitions,
                world_size=world_size,
                use_secondary_tensor=use_secondary_tensor,
                forward=forward,
            )
        else:
            # Quantized path: build the concatenated int8 payload and scales,
            # reusing stored quantized data when the first parameter's tensor
            # carries ds_quant_scale, otherwise quantizing the concatenation.
            if params[0].ds_secondary_tensor is not None and not forward:
                use_secondary_tensor = True
                if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"):
                    quantized_param = instrument_w_nvtx(torch.cat)([
                        p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params
                    ])
                    scales = instrument_w_nvtx(torch.cat)([
                        p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name())
                        for p in params
                    ])
                else:
                    quantized_param, scales = self.quantizer_module.quantize(
                        instrument_w_nvtx(torch.cat)([
                            p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params
                        ]))
            else:
                if hasattr(params[0].ds_tensor, "ds_quant_scale"):
                    quantized_param = instrument_w_nvtx(torch.cat)(
                        [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params])
                    scales = instrument_w_nvtx(torch.cat)([
                        p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in params
                    ])
                else:
                    quantized_param, scales = self.quantizer_module.quantize(
                        instrument_w_nvtx(torch.cat)(
                            [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
            # NOTE(review): the dtype is hard-coded to float32 here, whereas
            # the single-parameter branch uses scales.dtype - confirm that
            # this difference is intentional.
            quant_scale_buffer = torch.empty(
                scales.numel() * world_size,
                dtype=torch.float32,
                device=get_accelerator().current_device_name(),
                requires_grad=False,
            )
            handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
            quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
            quant_info = QuantizationInfo()
            quant_info.quantized_param = flat_tensor
            quant_info.backend = self.quantizer_module
            quant_info.quant_handle = quant_handle
            quant_info.scale_buffer = quant_scale_buffer
            quant_info.partition_sz = partition_sz
            quant_info.world_size = world_size
            return AllGatherCoalescedHandle(
                allgather_handle=handle,
                params=params,
                partitions=None,
                world_size=world_size,
                use_secondary_tensor=use_secondary_tensor,
                forward=forward,
                quantization=quant_info,
            )

6.2. 参数重新分割

同样先把直接 hook 的函数贴过来,

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
    """Post-forward hook: release the parameters of ``sub_module`` again.

    Memory usage is reported before and after the release. The release
    itself is delegated to the coordinator's ``release_sub_module`` with
    ``backward=False``.

    Args:
        sub_module: The module whose forward has just finished.
    """
    see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                     force=False)

    # The coordinator is selected by the module's training flag; it decides
    # which parameters are re-partitioned.
    coordinator = self.get_param_coordinator(training=sub_module.training)
    coordinator.release_sub_module(sub_module, backward=False)

    see_memory_usage(f"After sub module function {sub_module.__class__.__name__}  {sub_module.id} after release",
                     force=False)
@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
    """Post-backward hook: release the parameters of ``sub_module`` again.

    Always uses the training coordinator and releases with
    ``backward=True``; memory usage is reported before and after.

    Args:
        sub_module: The module whose backward has just finished; it must be
            in training mode.

    Raises:
        AssertionError: If ``sub_module.training`` is falsy.
    """
    assert sub_module.training, "backward pass is invalid for module in evaluation mode"
    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
        force=False)

    coordinator = self.get_param_coordinator(training=True)
    coordinator.release_sub_module(sub_module, backward=True)

    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
        force=False)

显然核心的逻辑在 param_coordinator.release_sub_module(sub_module, backward=...) (前向传入 backward=False,后向传入 backward=True),接下来跟进去分析。

6.2.1. release_param

@instrument_w_nvtx
@torch.no_grad()
def release_sub_module(self, submodule: Module, backward: bool) -> None:
    """Release the parameters of ``submodule`` that are eligible for release.

    With a complete trace the set of releasable ids comes from
    ``__params_to_release``; otherwise every parameter of the module is a
    candidate. External parameters are never released here.

    Args:
        submodule: The module that has just finished running.
        backward: Forwarded to ``__release_param``.
    """
    if self.is_complete_trace():
        releasable_ids = self.__params_to_release(submodule, self.__step_id)
    else:
        releasable_ids = {p.ds_id for p in iter_params(submodule)}

    for param in iter_params(submodule):
        # This module no longer counts as an active user of the parameter.
        param.ds_active_sub_modules.discard(submodule.id)
        if param.ds_id not in releasable_ids or param.is_external_param:
            continue
        self.__release_param(param, backward)

@instrument_w_nvtx
def __release_param(self, param: Parameter, backward: bool) -> None:
    """Re-partition ``param`` if it is available and no module still uses it.

    Nothing happens unless the parameter is AVAILABLE and its set of active
    sub-modules is empty. On release the available-element counter is
    reduced by ``param.ds_numel``.

    Args:
        param: The parameter to release.
        backward: Forwarded to ``param.partition``.
    """
    is_available = param.ds_status == ZeroParamStatus.AVAILABLE
    if not is_available or param.ds_active_sub_modules:
        return

    if logger.isEnabledFor(logging.DEBUG):
        debug_rank0(f"-release: {param.ds_summary()}")
    param.partition(backward=backward)
    self.__n_available_params -= param.ds_numel

可以看到,最终调用 param.partition(backward=backward) 对参数又重新进行了一遍分割流程,这又回到了节 4.3.2。