From 2f4c6b814625bd9303ac24642e126f8574fe46f8 Mon Sep 17 00:00:00 2001
From: Lifu Zhang
Date: Fri, 6 Feb 2026 10:33:10 -0800
Subject: [PATCH 1/4] Fix on TE to support Mcore Vision Encoder CUDA Graph

Signed-off-by: Lifu Zhang
---
 transformer_engine/pytorch/graph.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index f587ca9946..80b5f07a2f 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -452,9 +452,9 @@ def hook_fn(
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
                 with _none_grad_context_wrapper(inputs):
                     torch.autograd.backward(
-                        tuple(o for o in outputs if o.requires_grad),
+                        tuple(o for o in outputs if o is not None and o.requires_grad),
                         grad_tensors=tuple(
-                            torch.empty_like(o) for o in outputs if o.requires_grad
+                            torch.empty_like(o) for o in outputs if o is not None and o.requires_grad
                         ),
                     )
                     grad_inputs = tuple(input.grad for input in inputs)
@@ -616,19 +616,19 @@ def hook_fn(
                         # Note for _reuse_graph_input_output_buffers: grad output is only used
                         # within backward, so we can reuse the same static buffers every time.
                         static_grad_outputs_keys = tuple(
-                            (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                            (o.shape, o.dtype, o.layout) for o in static_outputs if o is not None and o.requires_grad
                         )
                         if static_grad_outputs_keys in static_grad_outputs_dict:
                             static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
                         else:
                             static_grad_outputs = tuple(
-                                torch.empty_like(o) if o.requires_grad else None
+                                torch.empty_like(o) if o is not None and o.requires_grad else None
                                 for o in static_outputs
                             )
                             static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
                     else:
                         static_grad_outputs = tuple(
-                            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                            torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
                         )
                     if is_training:
                         inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +636,7 @@ def hook_fn(
                             bwd_graph, pool=mempool
                         ):
                             torch.autograd.backward(
-                                tuple(o for o in static_outputs if o.requires_grad),
+                                tuple(o for o in static_outputs if o is not None and o.requires_grad),
                                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                                 retain_graph=retain_graph_in_backward,
                             )
@@ -719,7 +719,7 @@ def hook_fn(
             ):
                 # For now, assumes all static_outputs require grad
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                    torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
                 )
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -727,7 +727,7 @@ def hook_fn(
                     bwd_graph, pool=mempool
                 ):
                     torch.autograd.backward(
-                        tuple(o for o in static_outputs if o.requires_grad),
+                        tuple(o for o in static_outputs if o is not None and o.requires_grad),
                         grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                         retain_graph=retain_graph_in_backward,
                     )
@@ -794,7 +794,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
             # Replay forward graph
             fwd_graph.replay()
             assert isinstance(static_outputs, tuple)
-            return tuple(o.detach() for o in static_outputs)
+            return tuple(o.detach() if o is not None else o for o in static_outputs)

         @staticmethod
         @torch.autograd.function.once_differentiable

From 503a64458641f1a0c033d8de77e3bbf8f628f8d8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 6 Feb 2026 18:39:18 +0000
Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/graph.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 80b5f07a2f..cc7aade361 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -454,7 +454,9 @@ def hook_fn(
                     torch.autograd.backward(
                         tuple(o for o in outputs if o is not None and o.requires_grad),
                         grad_tensors=tuple(
-                            torch.empty_like(o) for o in outputs if o is not None and o.requires_grad
+                            torch.empty_like(o)
+                            for o in outputs
+                            if o is not None and o.requires_grad
                         ),
                     )
                     grad_inputs = tuple(input.grad for input in inputs)
@@ -616,7 +618,9 @@ def hook_fn(
                         # Note for _reuse_graph_input_output_buffers: grad output is only used
                         # within backward, so we can reuse the same static buffers every time.
                         static_grad_outputs_keys = tuple(
-                            (o.shape, o.dtype, o.layout) for o in static_outputs if o is not None and o.requires_grad
+                            (o.shape, o.dtype, o.layout)
+                            for o in static_outputs
+                            if o is not None and o.requires_grad
                         )
                         if static_grad_outputs_keys in static_grad_outputs_dict:
                             static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
@@ -628,7 +632,8 @@ def hook_fn(
                             static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
                     else:
                         static_grad_outputs = tuple(
-                            torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
+                            torch.empty_like(o) if o is not None and o.requires_grad else None
+                            for o in static_outputs
                         )
                     if is_training:
                         inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +641,9 @@ def hook_fn(
                             bwd_graph, pool=mempool
                         ):
                             torch.autograd.backward(
-                                tuple(o for o in static_outputs if o is not None and o.requires_grad),
+                                tuple(
+                                    o for o in static_outputs if o is not None and o.requires_grad
+                                ),
                                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                                 retain_graph=retain_graph_in_backward,
                             )
@@ -719,7 +726,8 @@ def hook_fn(
             ):
                 # For now, assumes all static_outputs require grad
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o is not None and o.requires_grad else None for o in static_outputs
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
+                    for o in static_outputs
                 )
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)

From 1f1efff062ffefc7523666d3c26d0d9dbecde638 Mon Sep 17 00:00:00 2001
From: Lifu Zhang
Date: Tue, 10 Feb 2026 16:20:53 -0800
Subject: [PATCH 3/4] refactoring code

Signed-off-by: Lifu Zhang
---
 transformer_engine/pytorch/graph.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index cc7aade361..3d6b217d0b 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -451,13 +451,10 @@ def hook_fn(
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
                 with _none_grad_context_wrapper(inputs):
+                    outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
                     torch.autograd.backward(
-                        tuple(o for o in outputs if o is not None and o.requires_grad),
-                        grad_tensors=tuple(
-                            torch.empty_like(o)
-                            for o in outputs
-                            if o is not None and o.requires_grad
-                        ),
+                        outputs_requiring_grad,
+                        grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
                     )
                     grad_inputs = tuple(input.grad for input in inputs)

From 6dd08b516d53c4fd7fee00a50b93cfd613fe3b64 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 11 Feb 2026 00:25:54 +0000
Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/pytorch/graph.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 3d6b217d0b..2b30e3f02a 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -451,7 +451,9 @@ def hook_fn(
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
                 with _none_grad_context_wrapper(inputs):
-                    outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
+                    outputs_requiring_grad = tuple(
+                        o for o in outputs if o is not None and o.requires_grad
+                    )
                     torch.autograd.backward(
                         outputs_requiring_grad,
                         grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
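
The whole series applies one pattern: filter None entries out of the flattened output tuple before reading o.requires_grad or building grad tensors, because a graphed callable such as the Mcore vision encoder may return optional (None) outputs. Below is a minimal standalone sketch of that pattern, assuming a hypothetical ToyEncoder module (not TE or Megatron Core code); it runs on CPU without CUDA graphs.

import torch

class ToyEncoder(torch.nn.Module):
    """Toy stand-in for an encoder whose flattened output tuple can contain None."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        # The second output is optional and may be None, like the optional
        # outputs the patches guard against.
        return self.proj(x), None

model = ToyEncoder()
outputs = model(torch.randn(4, 8))

# Without the `is not None` guard, o.requires_grad raises AttributeError as
# soon as one flattened output is None.
outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
torch.autograd.backward(
    outputs_requiring_grad,
    grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)
print(model.proj.weight.grad.shape)  # gradients flow only through the real tensors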