.. _ch_deepspeed_hook:

##########################################
Stage 3 - hook registration
##########################################

Once the parameters have been partitioned, they have to be restored before the forward and the backward pass can run.
Likewise, after the forward and the backward pass finish, the parameters that are no longer needed must be released again.
``pytorch``'s ``hook`` mechanism is used to insert the corresponding actions at these four key points:
the ``pytorch`` ``Module`` class provides a family of ``register_xxx_hook`` methods for exactly this purpose.

All of ``deepspeed``'s hook logic is implemented in the class ``DeepSpeedZeRoOffload``,
more precisely in the method ``DeepSpeedZeRoOffload::setup_zero_stage3_hooks``.

.. code-block:: python

    def setup_zero_stage3_hooks(self):
        """Register the stage-3 related hooks."""
        self.hierarchy = 0

        #reset step if in inference mode
        @instrument_w_nvtx
        def _end_of_forward_hook(module, *args):
            if not torch._C.is_grad_enabled():
                self.get_param_coordinator(training=False).reset_step()

        #likely one of them should be enough but just to be safe
        # Register all of the hooks:
        # pre_forward, pre_backward, post_forward and post_backward
        self._register_hooks_recursively(self.module)
        self.module.register_forward_hook(_end_of_forward_hook)

        # Add top module to stack trace
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(self.module)

    def _register_hooks_recursively(self, module, count=[0]):
        """This is where the hooks are actually registered."""
        my_count = count[0]
        module.id = my_count

        #print(f"{module.__class__} : {module.id}")

        for child in module.children():
            count[0] = count[0] + 1
            self._register_hooks_recursively(child, count=count)

        @instrument_w_nvtx
        def _pre_forward_module_hook(module, *args):
            self.pre_sub_module_forward_function(module)

        @instrument_w_nvtx
        def _post_forward_module_hook(module, input, output):
            global FWD_MODULE_STACK
            FWD_MODULE_STACK.pop()

            if output is None:
                output = []
            elif not isinstance(output, (list, tuple)):
                if torch.is_tensor(output):
                    output = [output]
                else:
                    #print(f'got UNKNOWN type {type(output)}')
                    outputs = []
                    output = output if isinstance(output, dict) else vars(output)
                    for name, val in output.items():
                        if not name.startswith('__') and torch.is_tensor(val):
                            outputs.append(val)
                    output = outputs

            for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output):
                key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
                actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias

                if not any(key in m._external_params for m in FWD_MODULE_STACK):
                    actual_external_param.is_external_param = True
                    module_to_register = FWD_MODULE_STACK[-1]
                    register_external_parameter(module_to_register, actual_external_param)
                    print_rank_0(
                        f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
                        force=False)

                    # It's possible that the parameter was already external to the completed module. If so, remove it the
                    # registration as it will be covered by the outer module instead.
                    if key in module._external_params:
                        print_rank_0(
                            f'  Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
                            force=False)
                        unregister_external_parameter(module, actual_external_param)

                    actual_external_param.all_gather()

            self.post_sub_module_forward_function(module)

        def _pre_backward_module_hook(module, inputs, output):

            @instrument_w_nvtx
            def _run_before_backward_function(sub_module):
                # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
                # before doing backwards, so each backward will need a pre-fetch - using reference
                # counting to support this scenario
                #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
                if sub_module.applied_pre_backward_ref_cnt > 0:
                    self.pre_sub_module_backward_function(sub_module)
                    sub_module.applied_pre_backward_ref_cnt -= 1
                    #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

            return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output)

        #This is an alternate to doing _post_backward_module_hook
        #it uses tensor.register_hook instead of using torch.autograd.Function
        def _alternate_post_backward_module_hook(module, inputs):
            module.ds_grads_remaining = 0

            #print(f"Before Forward {module.__class__.__name__}")

            def _run_after_backward_hook(*unused):
                module.ds_grads_remaining = module.ds_grads_remaining - 1
                if module.ds_grads_remaining == 0:
                    #print(f"After backward {module.__class__.__name__}")
                    self.post_sub_module_backward_function(module)

            def _run_before_forward_function(input):
                if input.requires_grad:
                    module.ds_grads_remaining += 1

            return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
                                                               _run_after_backward_hook, inputs)

        def _post_backward_module_hook(module, inputs):
            module.ds_grads_remaining = 0

            @instrument_w_nvtx
            def _run_after_backward_function(sub_module):
                if sub_module.ds_grads_remaining == 0:
                    self.post_sub_module_backward_function(sub_module)

            return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs)

        # Pre forward hook
        self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))

        # Post forward hook
        self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

        # Pre backward hook
        self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

        # post backward hook
        self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
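Note that the two "backward" hooks are in fact registered as *forward* hooks:
``_pre_backward_module_hook`` does not trigger anything at registration time; it wraps the module's
output tensors in a ``torch.autograd.Function`` (``PreBackwardFunction``) so that the pre-backward
action fires the moment gradients flow back into that output, and ``PostBackwardFunction`` appears to play the
symmetric role on the module *inputs*. The following minimal, self-contained sketch illustrates this
pattern on a single ``nn.Linear`` (toy code, not DeepSpeed's implementation; ``_PreBackward`` and the
print callbacks are invented for illustration):

.. code-block:: python

    import torch
    import torch.nn as nn


    class _PreBackward(torch.autograd.Function):
        """Identity in forward; runs a callback right before gradients reach the wrapped output."""

        @staticmethod
        def forward(ctx, tensor, callback):
            ctx.callback = callback
            return tensor.clone()

        @staticmethod
        def backward(ctx, grad_output):
            ctx.callback()                  # the "pre-backward" action fires here
            return grad_output, None        # no gradient for the callback argument


    def pre_forward(module, inputs):
        # in ZeRO-3 this is where the partitioned parameters would be all-gathered
        print(f"[pre-forward ] gather params of {module.__class__.__name__}")


    def post_forward(module, inputs, output):
        # in ZeRO-3 this is where the full parameters would be released again;
        # returning the wrapped tensor replaces the module output, so the callback
        # below runs during backward, before this module's gradients are computed
        print(f"[post-forward] release params of {module.__class__.__name__}")
        return _PreBackward.apply(output, lambda: print("[pre-backward] gather params again"))


    layer = nn.Linear(4, 4)
    layer.register_forward_pre_hook(pre_forward)
    layer.register_forward_hook(post_forward)

    out = layer(torch.randn(2, 4))
    out.sum().backward()   # "[pre-backward] ..." prints before layer's grads are accumulated

In the real code, ``_apply_to_tensors_only`` generalizes this wrapping to arbitrarily nested module outputs.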
**Before the forward pass: pre_forward**

Obviously, before a sub-module's forward pass can run we need to **restore its partitioned parameters**,
which is done by re-assembling this layer's parameters through an ``AllGather`` collective.
The hook is registered via ``module.register_forward_pre_hook(_pre_forward_module_hook)``;
following ``_pre_forward_module_hook`` down the call chain:

.. code-block:: python

    @torch.no_grad()
    def pre_sub_module_forward_function(self, sub_module):
        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(sub_module)

        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)

        # The actual parameter all-gather happens here;
        # param_coordinator is a deepspeed.runtime.zero.PartitionedParameterCoordinator
        param_coordinator.fetch_sub_module(sub_module, forward=True)

        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
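To make the ``AllGather`` restore step more concrete, here is a conceptual sketch of what fetching a
ZeRO-3 parameter amounts to (illustration only, not DeepSpeed code; ``gather_param`` is an invented name):
every rank keeps just a ``1/world_size`` shard of the flattened parameter, the shards are all-gathered into
the full flat tensor and viewed back to the original shape, and "releasing" simply drops the full tensor
while the local shard stays resident. It can be launched with e.g. ``torchrun --nproc_per_node=2``:

.. code-block:: python

    import torch
    import torch.distributed as dist


    def gather_param(shard: torch.Tensor, full_shape: torch.Size) -> torch.Tensor:
        """All-gather the local shards from every rank and rebuild the full parameter."""
        world_size = dist.get_world_size()
        full_flat = torch.empty(shard.numel() * world_size, dtype=shard.dtype)
        dist.all_gather_into_tensor(full_flat, shard)
        return full_flat[:full_shape.numel()].view(full_shape)   # drop padding, restore shape


    if __name__ == "__main__":
        dist.init_process_group("gloo")
        rank, world = dist.get_rank(), dist.get_world_size()

        full_shape = torch.Size([4, 6])                    # shape of the "real" parameter
        numel = full_shape.numel()
        padded = (numel + world - 1) // world * world      # pad so the flat tensor splits evenly
        flat = torch.arange(padded, dtype=torch.float32)   # stand-in for the flat parameter
        shard = flat.chunk(world)[rank].clone()            # the shard this rank keeps resident

        full = gather_param(shard, full_shape)             # roughly what fetch_sub_module achieves
        assert torch.equal(full.flatten(), flat[:numel])
        del full                                           # "release": only `shard` stays resident
        dist.destroy_process_group()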
**After the forward pass: post_forward**

.. code-block:: python

    @torch.no_grad()
    def post_sub_module_forward_function(self, sub_module):
        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                         force=False)

        # Release the parameters again.
        param_coordinator = self.get_param_coordinator(training=sub_module.training)
        # The real work happens in deepspeed.runtime.zero.PartitionedParameterCoordinator::release_sub_module
        param_coordinator.release_sub_module(sub_module, backward=False)

        see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
                         force=False)

**Before the backward pass: pre_backward**

.. code-block:: python

    @torch.no_grad()
    def pre_sub_module_backward_function(self, sub_module):
        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
        param_coordinator = self.get_param_coordinator(training=True)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module, forward=False)

**After the backward pass: post_backward**

.. code-block:: python

    @torch.no_grad()
    def post_sub_module_backward_function(self, sub_module):
        assert sub_module.training, "backward pass is invalid for module in evaluation mode"
        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
            force=False)

        self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)

        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
            force=False)

As you can see, the actual implementation behind each of these hook points lives in
``deepspeed.runtime.zero.PartitionedParameterCoordinator``;
execution eventually ends up inside that class.
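To summarize the division of labor: the four hook points only decide *when* to act, while *what* happens
is delegated to the parameter coordinator. The toy, single-process stand-in below is illustration only
(``ToyParamCoordinator`` is not part of DeepSpeed; the real ``PartitionedParameterCoordinator`` additionally
handles trace recording, prefetching and asynchronous all-gathers) and simply maps the same four calls onto
one training step:

.. code-block:: python

    import torch
    import torch.nn as nn


    class ToyParamCoordinator:
        """Tracks which sub-modules currently have their (conceptually gathered) full parameters resident."""

        def __init__(self):
            self.resident = set()

        def fetch_sub_module(self, sub_module: nn.Module, forward: bool) -> None:
            # real coordinator: all-gather the partitioned parameters of sub_module
            self.resident.add(id(sub_module))

        def release_sub_module(self, sub_module: nn.Module, backward: bool) -> None:
            # real coordinator: free the full parameters, keeping only the local shard
            self.resident.discard(id(sub_module))


    coordinator = ToyParamCoordinator()
    layer = nn.Linear(8, 8)

    coordinator.fetch_sub_module(layer, forward=True)        # pre_forward hook
    out = layer(torch.randn(2, 8))                           # forward runs with the full parameters
    coordinator.release_sub_module(layer, backward=False)    # post_forward hook

    coordinator.fetch_sub_module(layer, forward=False)       # pre_backward (via PreBackwardFunction)
    out.sum().backward()                                     # backward runs with the full parameters
    coordinator.release_sub_module(layer, backward=True)     # post_backward (via PostBackwardFunction)

    assert id(layer) not in coordinator.resident             # nothing stays resident after the step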