
support cuda graph capture offloading module #2435

Open
lhb8125 wants to merge 22 commits into NVIDIA:main from lhb8125:hongbinl/offload_activation_cuda_graph

Conversation

@lhb8125 (Contributor) commented Dec 1, 2025

Description

This PR adds support for offloading modules captured by partial CUDA graphs in Megatron-LM.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Megatron's offloading relies on cpu_offload_v1, but we need to reuse mark_not_offload() so that the weights won't be offloaded.
  • Do not offload the output of core_attention.
  • Refine the allocation strategy for FP8/FP4 tensors: the previous implementation allocated tensors with from_blob(), which is not compatible with record_stream().
  • Add pre_warmup_hook and post_warmup_hook.
  • Pass cuda_graph_stream and cuda_graph_event through user_kwargs so that the forward/backward replays run on a side stream, where cuda_graph_event is recorded after the computation finishes and the current stream waits on it (see the sketch after this list).
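
A minimal sketch of the side-stream replay pattern described in the last bullet, assuming fwd_graph is a captured torch.cuda.CUDAGraph and cuda_graph_event is an event recorded inside the captured region after the compute; the helper name is hypothetical:

```python
import torch

def replay_forward_on_side_stream(fwd_graph, cuda_graph_stream, cuda_graph_event):
    # Hypothetical helper mirroring the replay flow discussed in this PR.
    # The side stream first waits for work already queued on the current stream.
    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(cuda_graph_stream):
        fwd_graph.replay()
    # The current stream waits only on the event recorded inside the graph, so
    # any remaining work on the side stream (e.g. offload traffic) can keep
    # overlapping with subsequent work on the current stream.
    torch.cuda.current_stream().wait_event(cuda_graph_event)
```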

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
@lhb8125 lhb8125 marked this pull request as draft December 1, 2025 05:34
greptile-apps bot commented Dec 1, 2025

Greptile Overview

Greptile Summary

  • Extends make_graphed_callables to support pre/post warmup hooks and to replay forward/backward graphs on a provided side stream/event, plus an extra backward_dw hook trigger for TE modules.
  • Updates quantization bulk-allocation to keep backing uint8 buffers and changes split_quantize to return both quantized outputs and these buffers for offload/stream-recording use cases.
  • Adjusts GroupedLinear CPU-offload behavior to offload the shared columnwise buffer rather than individual quantized tensor views and attempts to restore views after reload.
  • Adds mark_not_offload on attention outputs to prevent them from being offloaded in CPU-offload workflows (see the sketch below this list).
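
A hedged sketch of that marking step: the import path follows the cpu_offload.py location referenced in this review, while the wrapper function and tensor names are placeholders rather than the backend's actual code.

```python
from transformer_engine.pytorch.cpu_offload import mark_not_offload

def run_core_attention(core_attention, query, key, value):
    # Placeholder wrapper: compute attention, then exclude its output from
    # activation offloading so it stays resident on the GPU.
    core_attn_out = core_attention(query, key, value)
    mark_not_offload(core_attn_out)
    return core_attn_out
```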

Confidence Score: 2/5

  • This PR likely needs fixes before merging due to a correctness issue in restored quantized tensor subviews and a behavior change in CPU offload v1 marking.
  • Core feature work is clear, but restore_columnwise_subviews() currently reconstructs float/typed internal tensors as uint8 views, which will break any downstream computations using those fields. Additionally, removing the v1 guard in mark_not_offload() changes semantics under NVTE_CPU_OFFLOAD_V1 and can introduce overhead/state churn while not affecting v1 offload decisions.
  • transformer_engine/pytorch/quantized_tensor.py, transformer_engine/pytorch/module/grouped_linear.py, transformer_engine/pytorch/cpu_offload.py

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/cpu_offload.py | Removes the NVTE_CPU_OFFLOAD_V1 early-return guard in mark_not_offload(), changing behavior under v1 and potentially causing unnecessary quantized-tensor pack/unpack overhead. |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds mark_not_offload() on attention outputs to avoid offloading; interacts with CPU offload v1 path where offloading is controlled via mark_activation_offload(). |
| transformer_engine/pytorch/graph.py | Adds pre/post warmup hooks, routes graph replay through a provided stream/event, and triggers TE module backward_dw hooks after bwd_dw graph replay; changes autograd signature to include cuda_graph_stream/event. |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | Changes split_quantize to return (outputs, buffer_list) and switches bulk-allocation views to be based on allocated uint8 buffers (kept for offload/record_stream). |
| transformer_engine/pytorch/csrc/extensions.h | Updates split_quantize C++ API signature to return a tuple of (quantized outputs, backing buffers). |
| transformer_engine/pytorch/module/grouped_linear.py | Reworks grouped_linear quantized activation offload by offloading the shared buffer and restoring columnwise subviews in backward; changes usage flags during CPU offload. |
| transformer_engine/pytorch/module/base.py | Adds trigger_backward_dw() to execute delayed wgrad accumulation/reduce hooks, used from graphed callable attribute functions. |
| transformer_engine/pytorch/quantized_tensor.py | Adds helpers to capture columnwise subview offsets and restore subviews after buffer reload using as_strided; used to support offloading of shared buffers. |
| tests/pytorch/nvfp4/test_nvfp4_group_quantize.py | Updates test to unpack new split_quantize tuple return (ignoring buffer list). |
| tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | Updates test to unpack new split_quantize tuple return (ignoring buffer list). |

Sequence Diagram

sequenceDiagram
    participant User as User code
    participant Make as make_graphed_callables
    participant G as Graphed.autograd Function
    participant CS as cuda_graph_stream
    participant Cur as current_stream
    participant F as fwd_graph
    participant B as bwd_graph
    participant E as cuda_graph_event
    participant Mods as TE modules

    User->>Make: make_graphed_callables(func, ..., pre_warmup_hook, post_warmup_hook)
    Make->>Make: warmup captures
    alt pre/post hooks provided
        Make->>User: pre_warmup_hook()
        Make->>User: post_warmup_hook()
    end

    User->>G: forward(skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs)
    G->>CS: wait_stream(current_stream)
    G->>CS: (context) replay fwd_graph
    CS->>F: replay()
    G->>Cur: wait_event(cuda_graph_event)
    G-->>User: return static_outputs.detach()

    User->>G: backward(*grad_outputs)
    G->>CS: wait_stream(current_stream)
    G->>CS: (context) replay bwd_graph
    CS->>B: replay()
    G->>Cur: wait_event(cuda_graph_event)
    G-->>User: return (None, None, None, *grad_inputs)

    User->>Make: graphed_callable.backward_dw()
    Make->>B: replay bwd_dw_graph
    Make->>Mods: trigger_backward_dw() for modules
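
For reference, a hedged usage sketch of the new warmup hooks shown in the diagram above, built on the public te.make_graphed_callables entry point; the hook bodies (toggling activation offloading around warmup) are an assumption drawn from the reviewer discussion below, not from this PR's documentation.

```python
import torch
import transformer_engine.pytorch as te

def pre_warmup_hook():
    # Assumption: e.g. disable activation offloading before warmup iterations.
    pass

def post_warmup_hook():
    # Assumption: e.g. re-enable activation offloading once warmup finishes.
    pass

layer = te.TransformerLayer(1024, 4096, 16).cuda()
sample_input = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    pre_warmup_hook=pre_warmup_hook,    # added by this PR
    post_warmup_hook=post_warmup_hook,  # added by this PR
)
```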

greptile-apps bot left a comment

1 file reviewed, 3 comments

lhb8125 and others added 17 commits December 8, 2025 06:35
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: root <root@eos0046.eos.clusters.nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
@lhb8125 (Contributor Author) commented Feb 5, 2026

/te-ci pytorch L1

@lhb8125 lhb8125 marked this pull request as ready for review February 5, 2026 03:37
greptile-apps bot left a comment

2 files reviewed, no comments

@lhb8125 (Contributor Author) commented Feb 5, 2026

@buptzyb @zhongbozhu @pggPL Could you review this PR?

"""
return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()

def trigger_backward_dw(self):
lhb8125 (Contributor Author) commented:

Please ignore this method; it will be removed after https://github.com/NVIDIA/TransformerEngine/pull/2614/files is merged.

bwd_dw_graphs[graph_idx].replay()
for module in te_modules:
    if hasattr(module, "trigger_backward_dw"):
        module.trigger_backward_dw()
lhb8125 (Contributor Author) commented:

Please ignore this code block; it will be removed after https://github.com/NVIDIA/TransformerEngine/pull/2614/files is merged.

pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
pre_warmup_hook: Optional[Callable] = None,
A reviewer (Contributor) commented:

Just to confirm: are the hooks used to disable and re-enable offloading for warmup? Could you point me to the implementation in MCore?

    cuda_graph_event = user_kwargs["cuda_graph_event"]
    user_kwargs.pop("cuda_graph_event")
else:
    cuda_graph_event = torch.cuda.Event()
A reviewer (Contributor) commented:

Suggested change:
-     cuda_graph_event = torch.cuda.Event()
+     cuda_graph_event = None

Comment on lines +801 to +804
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
    fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
A reviewer (Contributor) commented:

Suggested change:
- cuda_graph_stream.wait_stream(torch.cuda.current_stream())
- with cuda_graph_stream:
-     fwd_graph.replay()
- torch.cuda.current_stream().wait_event(cuda_graph_event)
+ if cuda_graph_stream != torch.cuda.current_stream():
+     cuda_graph_stream.wait_stream(torch.cuda.current_stream())
+     with cuda_graph_stream:
+         fwd_graph.replay()
+     if cuda_graph_event is not None:
+         torch.cuda.current_stream().wait_event(cuda_graph_event)
+     else:
+         torch.cuda.current_stream().wait_stream(cuda_graph_stream)
+ else:
+     fwd_graph.replay()

Comment on lines +821 to +824
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
    bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
A reviewer (Contributor) commented:

Suggested change:
- ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
- with ctx.cuda_graph_stream:
-     bwd_graph.replay()
- torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
+ if ctx.cuda_graph_stream != torch.cuda.current_stream():
+     ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
+     with ctx.cuda_graph_stream:
+         bwd_graph.replay()
+     if ctx.cuda_graph_event is not None:
+         torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
+     else:
+         torch.cuda.current_stream().wait_stream(ctx.cuda_graph_stream)
+ else:
+     bwd_graph.replay()


skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

if "cuda_graph_stream" in user_kwargs:
A reviewer (Contributor) commented:

I suggest adding more comments to tell the developers how they can take advantage of the new cuda_graph_stream/cuda_graph_event arguments. For example:

Replay the CUDA graph in a dedicated cuda_graph_stream to allow overlap with work on the main stream. When cuda_graph_event is given, it should be an external event captured in the CUDA graph and is used to sync back to the main stream. If no cuda_graph_event is given, the cuda_graph_stream is synced back to the main stream. Note that a dedicated cuda_graph_stream rather than the current stream must be given if cuda_graph_event is given.
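
A hedged caller-side sketch of that convention; graphed_callable and hidden_states are placeholders for a callable returned by make_graphed_callables and an input matching the captured shapes, and the keyword names follow the user_kwargs handling quoted above.

```python
import torch

def call_graphed_on_side_stream(graphed_callable, hidden_states):
    # Placeholder driver illustrating the reviewer's suggested convention.
    side_stream = torch.cuda.Stream()
    out = graphed_callable(
        hidden_states,
        # New in this PR: replay runs on this dedicated side stream instead of
        # the current stream, allowing overlap with work on the main stream.
        cuda_graph_stream=side_stream,
        # cuda_graph_event may also be passed; per the note above it must be an
        # external event captured inside the graph, which the replay path then
        # uses to sync the main stream back, so it is omitted in this sketch.
    )
    return out
```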

else:
    cuda_graph_stream = torch.cuda.current_stream()
if "cuda_graph_event" in user_kwargs:
    cuda_graph_event = user_kwargs["cuda_graph_event"]
A reviewer (Contributor) commented:

Suggested change:
-     cuda_graph_event = user_kwargs["cuda_graph_event"]
+     assert cuda_graph_stream != torch.cuda.current_stream()
+     cuda_graph_event = user_kwargs["cuda_graph_event"]

Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
greptile-apps bot left a comment

10 files reviewed, 2 comments

Comment on lines +212 to +214
if col_scale is not None
else None
)
greptile-apps bot (Contributor) commented:

Restored subviews have wrong dtype

restore_columnwise_subviews() rebuilds _columnwise_data/_columnwise_scale_inv/_columnwise_amax via columnwise_buffer.as_strided(...), but as_strided preserves the base tensor dtype (uint8). After reload, _columnwise_scale_inv (float32) and _columnwise_amax (e.g., float32) will become uint8 views, so any later kernels reading these fields will interpret bytes as elements and produce incorrect results. You likely need to reconstruct views with the original dtype (e.g., columnwise_buffer.view(orig_dtype).as_strided(...) with element offsets adjusted) rather than directly as-striding the uint8 buffer.
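
A minimal sketch of the dtype-aware reconstruction suggested here, assuming a flat uint8 backing buffer and byte offsets aligned to the target dtype's element size; the helper name and arguments are hypothetical.

```python
import torch

def restore_subview(uint8_buffer, orig_dtype, shape, stride_in_elems, byte_offset):
    # Hypothetical helper: reinterpret the flat uint8 backing buffer in the
    # original dtype first, then rebuild the strided view with the offset
    # converted from bytes to elements of that dtype.
    elem_size = torch.empty((), dtype=orig_dtype).element_size()
    typed = uint8_buffer.view(orig_dtype)  # total byte length must divide evenly
    return typed.as_strided(shape, stride_in_elems, byte_offset // elem_size)
```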

Comment on lines 45 to 50

def mark_not_offload(*tensors: torch.Tensor):
    """Marks tensors to prevent them from being offloaded."""
    if NVTE_CPU_OFFLOAD_V1:
        return

    tensors, tensor_obj = prepare_for_saving(*tensors)

greptile-apps bot (Contributor) commented:

mark_not_offload ignored in v1

This PR removed the if NVTE_CPU_OFFLOAD_V1: return guard in mark_not_offload(). In the v1 offload path (cpu_offload_v1.py), offload eligibility is controlled via activation_offloading and the handler’s tensor_need_offloading_checker_TE_do_not_offload is never checked—so this becomes an ineffective marker in v1 mode while still doing prepare_for_saving()/restore_from_saved() work on (possibly) QuantizedTensorStorage. That’s a behavior change under NVTE_CPU_OFFLOAD_V1=1 and can introduce unnecessary overhead / state churn. Either restore the v1 early-return or make the v1 path honor the same “do not offload” marker.

