support cuda graph capture offloading module #2435

lhb8125 wants to merge 22 commits into NVIDIA:main from lhb8125:hongbinl/offload_activation_cuda_graph
Conversation
Greptile Overview

Greptile Summary
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User as User code
participant Make as make_graphed_callables
participant G as Graphed.autograd Function
participant CS as cuda_graph_stream
participant Cur as current_stream
participant F as fwd_graph
participant B as bwd_graph
participant E as cuda_graph_event
participant Mods as TE modules
User->>Make: make_graphed_callables(func, ..., pre_warmup_hook, post_warmup_hook)
Make->>Make: warmup captures
alt pre/post hooks provided
Make->>User: pre_warmup_hook()
Make->>User: post_warmup_hook()
end
User->>G: forward(skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs)
G->>CS: wait_stream(current_stream)
G->>CS: (context) replay fwd_graph
CS->>F: replay()
G->>Cur: wait_event(cuda_graph_event)
G-->>User: return static_outputs.detach()
User->>G: backward(*grad_outputs)
G->>CS: wait_stream(current_stream)
G->>CS: (context) replay bwd_graph
CS->>B: replay()
G->>Cur: wait_event(cuda_graph_event)
G-->>User: return (None, None, None, *grad_inputs)
User->>Make: graphed_callable.backward_dw()
Make->>B: replay bwd_dw_graph
Make->>Mods: trigger_backward_dw() for modules
```
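For orientation, a minimal usage sketch of the flow in the diagram above. `module`, `sample_args`, `inputs`, and the hook bodies are placeholders; the `pre_warmup_hook`/`post_warmup_hook` arguments and the `cuda_graph_stream`/`cuda_graph_event` kwargs come from this PR, and their exact plumbing follows the PR's code, not the released TE API:

```python
import torch
import transformer_engine.pytorch as te

side_stream = torch.cuda.Stream()  # dedicated stream for graph replay
replay_done = torch.cuda.Event()   # external event captured inside the graph

graphed = te.make_graphed_callables(
    module,                              # placeholder TE module
    sample_args,                         # placeholder sample inputs
    pre_warmup_hook=pre_warmup_hook,     # e.g., disable offloading for warmup
    post_warmup_hook=post_warmup_hook,   # e.g., re-enable it before capture
)

# Per the diagram, replay runs on side_stream and the current stream
# syncs back by waiting on the captured event.
out = graphed(*inputs, cuda_graph_stream=side_stream, cuda_graph_event=replay_done)
```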
/te-ci pytorch L1
@buptzyb @zhongbozhu @pggPL Could you review this PR?
| """ | ||
| return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() | ||
|
|
||
| def trigger_backward_dw(self): |
Please ignore this method; it will be removed after https://github.com/NVIDIA/TransformerEngine/pull/2614/files is merged.
```python
bwd_dw_graphs[graph_idx].replay()
for module in te_modules:
    if hasattr(module, "trigger_backward_dw"):
        module.trigger_backward_dw()
```
Please ignore this code block; it will be removed after https://github.com/NVIDIA/TransformerEngine/pull/2614/files is merged.
```python
    pool: Optional[Tuple[int, ...]] = None,
    retain_graph_in_backward: bool = False,
    _reuse_graph_input_output_buffers: bool = False,
    pre_warmup_hook: Optional[Callable] = None,
```
Just to confirm: are the hooks used to disable and re-enable offloading for warmup? Could you point me to the implementation in MCore?
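For reference, a hedged sketch of the pattern the question is asking about: toggling offloading off for warmup and back on before capture. The `OffloadToggle` class is hypothetical, not MCore's actual API:

```python
import transformer_engine.pytorch as te

class OffloadToggle:
    """Hypothetical offload switch; MCore's real mechanism may differ."""

    def __init__(self):
        self.enabled = True
        self._saved = None

    def pre_warmup_hook(self):
        # Disable offloading so warmup iterations don't register
        # activations with the offload handler.
        self._saved = self.enabled
        self.enabled = False

    def post_warmup_hook(self):
        # Restore the previous state before graph capture.
        self.enabled = self._saved

toggle = OffloadToggle()
graphed = te.make_graphed_callables(
    module, sample_args,  # placeholders, as above
    pre_warmup_hook=toggle.pre_warmup_hook,
    post_warmup_hook=toggle.post_warmup_hook,
)
```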
```python
    cuda_graph_event = user_kwargs["cuda_graph_event"]
    user_kwargs.pop("cuda_graph_event")
else:
    cuda_graph_event = torch.cuda.Event()
```
Suggested change:

```diff
-    cuda_graph_event = torch.cuda.Event()
+    cuda_graph_event = None
```
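Presumably the motivation (my reading, not stated in the thread): an Event created here at replay time is never recorded inside the captured graph, so waiting on it orders nothing; defaulting to `None` lets replay fall back to a full `wait_stream` sync instead. A tiny illustration:

```python
import torch

e = torch.cuda.Event()
# e has never been recorded, so this wait is a functional no-op:
# cudaStreamWaitEvent on an unrecorded event completes immediately.
torch.cuda.current_stream().wait_event(e)
```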
```python
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
    fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
```
Suggested change:

```diff
-cuda_graph_stream.wait_stream(torch.cuda.current_stream())
-with cuda_graph_stream:
-    fwd_graph.replay()
-torch.cuda.current_stream().wait_event(cuda_graph_event)
+if cuda_graph_stream != torch.cuda.current_stream():
+    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
+    with cuda_graph_stream:
+        fwd_graph.replay()
+    if cuda_graph_event is not None:
+        torch.cuda.current_stream().wait_event(cuda_graph_event)
+    else:
+        torch.cuda.current_stream().wait_stream(cuda_graph_stream)
+else:
+    fwd_graph.replay()
```
```python
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
    bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
```
Suggested change:

```diff
-ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
-with ctx.cuda_graph_stream:
-    bwd_graph.replay()
-torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
+if ctx.cuda_graph_stream != torch.cuda.current_stream():
+    ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
+    with ctx.cuda_graph_stream:
+        bwd_graph.replay()
+    if ctx.cuda_graph_event is not None:
+        torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
+    else:
+        torch.cuda.current_stream().wait_stream(ctx.cuda_graph_stream)
+else:
+    bwd_graph.replay()
```
|
|
```python
skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

if "cuda_graph_stream" in user_kwargs:
```
I suggest adding more comments to tell developers how they can take advantage of the new cuda_graph_stream/cuda_graph_event arguments. For example:
Replay the cudagraph on a dedicated cuda_graph_stream to allow overlap with work on the main stream. When cuda_graph_event is given, it should be an external event captured in the cudagraph, used to sync back to the main stream; if no cuda_graph_event is given, the cuda_graph_stream itself is synced back to the main stream. Note that a dedicated cuda_graph_stream, rather than the current stream, must be given whenever cuda_graph_event is given.
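To make the "external event captured in the cudagraph" idea concrete, a minimal sketch using plain `torch.cuda.CUDAGraph` capture; names are illustrative, and this is not the PR's actual capture code:

```python
import torch

stream = torch.cuda.Stream()
event = torch.cuda.Event()  # NOTE: some PyTorch/CUDA versions require an
                            # "external" event to record inside a capture
graph = torch.cuda.CUDAGraph()

static_in = torch.randn(8, 8, device="cuda")
with torch.cuda.graph(graph, stream=stream):
    static_out = static_in * 2
    event.record(stream)  # the record becomes a node of the captured graph

# Replay on the side stream; the main stream waits only on the captured
# event rather than the whole side stream, so replay can overlap with
# other work on the main stream.
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    graph.replay()
torch.cuda.current_stream().wait_event(event)
```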
```python
else:
    cuda_graph_stream = torch.cuda.current_stream()
if "cuda_graph_event" in user_kwargs:
    cuda_graph_event = user_kwargs["cuda_graph_event"]
```
Suggested change:

```diff
-    cuda_graph_event = user_kwargs["cuda_graph_event"]
+    assert cuda_graph_stream != torch.cuda.current_stream()
+    cuda_graph_event = user_kwargs["cuda_graph_event"]
```
```python
    if col_scale is not None
    else None
)
```
Restored subviews have wrong dtype
restore_columnwise_subviews() rebuilds _columnwise_data/_columnwise_scale_inv/_columnwise_amax via columnwise_buffer.as_strided(...), but as_strided preserves the base tensor dtype (uint8). After reload, _columnwise_scale_inv (float32) and _columnwise_amax (e.g., float32) will become uint8 views, so any later kernels reading these fields will interpret bytes as elements and produce incorrect results. You likely need to reconstruct views with the original dtype (e.g., columnwise_buffer.view(orig_dtype).as_strided(...) with element offsets adjusted) rather than directly as-striding the uint8 buffer.
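A tiny repro of the pitfall being described, assuming a uint8 backing buffer; the sizes, strides, and the `scale_inv` name are illustrative, not the PR's actual fields:

```python
import torch

buf = torch.zeros(64, dtype=torch.uint8)

# as_strided keeps the base dtype: this "scale_inv" view reads raw bytes.
bad_scale_inv = buf.as_strided((4,), (1,))
assert bad_scale_inv.dtype == torch.uint8

# Reinterpreting the buffer first restores an element-typed view; note that
# sizes/strides/offsets must then be in float32 elements, not bytes.
good_scale_inv = buf.view(torch.float32).as_strided((4,), (1,))
assert good_scale_inv.dtype == torch.float32
```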
|
|
```python
def mark_not_offload(*tensors: torch.Tensor):
    """Marks tensors to prevent them from being offloaded."""
    if NVTE_CPU_OFFLOAD_V1:
        return

    tensors, tensor_obj = prepare_for_saving(*tensors)
```
mark_not_offload ignored in v1
This PR removed the if NVTE_CPU_OFFLOAD_V1: return guard in mark_not_offload(). In the v1 offload path (cpu_offload_v1.py), offload eligibility is controlled via activation_offloading and the handler’s tensor_need_offloading_checker—_TE_do_not_offload is never checked—so this becomes an ineffective marker in v1 mode while still doing prepare_for_saving()/restore_from_saved() work on (possibly) QuantizedTensorStorage. That’s a behavior change under NVTE_CPU_OFFLOAD_V1=1 and can introduce unnecessary overhead / state churn. Either restore the v1 early-return or make the v1 path honor the same “do not offload” marker.
Description
This PR supports offloading modules captured by partial cuda graph in Megatron-LM.
Changes

Please list the changes introduced in this PR:

- `record_stream()`
- Add `pre_warmup_hook` and `post_warmup_hook`
- Add `cuda_graph_stream` and `cuda_graph_event` to `user_kwargs`, so that the fwd & bwd replay runs on a side stream, with `cuda_graph_event` recorded once the computation finishes so the current stream can sync on it.

Checklist: