26 changes: 17 additions & 9 deletions transformer_engine/pytorch/graph.py
@@ -452,9 +452,11 @@ def hook_fn(
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
                     with _none_grad_context_wrapper(inputs):
                         torch.autograd.backward(
-                            tuple(o for o in outputs if o.requires_grad),
+                            tuple(o for o in outputs if o is not None and o.requires_grad),
                             grad_tensors=tuple(
-                                torch.empty_like(o) for o in outputs if o.requires_grad
+                                torch.empty_like(o)
+                                for o in outputs
+                                if o is not None and o.requires_grad
                             ),
                         )
                     grad_inputs = tuple(input.grad for input in inputs)
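For reference, the guard introduced here can be exercised in isolation: filter out None entries (and outputs that do not require grad) before handing the outputs and dummy gradients to torch.autograd.backward, as the warmup pass above does. A minimal standalone sketch with toy tensors (x, w and the outputs tuple are illustrative, not Transformer Engine objects):

```python
import torch

# Toy "outputs" mixing real tensors and a None placeholder, e.g. an
# optional auxiliary output that is disabled for this call.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 2, requires_grad=True)
outputs = (x @ w, None, (x @ w).sum())

# Same guard as the hunk above: drop None entries and non-grad outputs before
# calling backward; dummy (uninitialized) grads are fine for a warmup-style pass.
torch.autograd.backward(
    tuple(o for o in outputs if o is not None and o.requires_grad),
    grad_tensors=tuple(
        torch.empty_like(o) for o in outputs if o is not None and o.requires_grad
    ),
)
print(x.grad.shape, w.grad.shape)  # gradients accumulated into the leaf tensors
```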
@@ -616,27 +618,32 @@ def hook_fn(
                         # Note for _reuse_graph_input_output_buffers: grad output is only used
                         # within backward, so we can reuse the same static buffers every time.
                         static_grad_outputs_keys = tuple(
-                            (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                            (o.shape, o.dtype, o.layout)
+                            for o in static_outputs
+                            if o is not None and o.requires_grad
                         )
                         if static_grad_outputs_keys in static_grad_outputs_dict:
                             static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
                         else:
                             static_grad_outputs = tuple(
-                                torch.empty_like(o) if o.requires_grad else None
+                                torch.empty_like(o) if o is not None and o.requires_grad else None
                                 for o in static_outputs
                             )
                             static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
                     else:
                         static_grad_outputs = tuple(
-                            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                            torch.empty_like(o) if o is not None and o.requires_grad else None
+                            for o in static_outputs
                         )
                     if is_training:
                         inputs = tuple(i for i in static_input_surface if i.requires_grad)
                         with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
                             bwd_graph, pool=mempool
                         ):
                             torch.autograd.backward(
-                                tuple(o for o in static_outputs if o.requires_grad),
+                                tuple(
+                                    o for o in static_outputs if o is not None and o.requires_grad
+                                ),
                                 grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                                 retain_graph=retain_graph_in_backward,
                             )
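The cache in this hunk keys reusable grad-output buffers on the (shape, dtype, layout) of the outputs that actually receive gradients, while None and no-grad outputs keep a positional None placeholder. A standalone sketch of that idea, with an illustrative cache dict and helper name rather than the actual graph.py internals:

```python
import torch

_grad_output_cache = {}  # illustrative stand-in for static_grad_outputs_dict


def get_static_grad_outputs(static_outputs):
    """Return cached grad-output buffers matching ``static_outputs``.

    The key covers only non-None outputs that require grad; None entries and
    no-grad outputs keep a positional None so indices line up with the
    forward outputs.
    """
    key = tuple(
        (o.shape, o.dtype, o.layout)
        for o in static_outputs
        if o is not None and o.requires_grad
    )
    if key not in _grad_output_cache:
        _grad_output_cache[key] = tuple(
            torch.empty_like(o) if o is not None and o.requires_grad else None
            for o in static_outputs
        )
    return _grad_output_cache[key]


# Example: two output tuples with same-shaped grad-requiring entries share buffers.
a = torch.randn(2, 3, requires_grad=True)
outs1 = (a * 2, None)
outs2 = (a + 1, None)
assert get_static_grad_outputs(outs1)[0] is get_static_grad_outputs(outs2)[0]
```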
@@ -719,15 +726,16 @@ def hook_fn(
         ):
             # For now, assumes all static_outputs require grad
             static_grad_outputs = tuple(
-                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                torch.empty_like(o) if o is not None and o.requires_grad else None
+                for o in static_outputs
             )
             if is_training:
                 inputs = tuple(i for i in static_input_surface if i.requires_grad)
                 with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
                     bwd_graph, pool=mempool
                 ):
                     torch.autograd.backward(
-                        tuple(o for o in static_outputs if o.requires_grad),
+                        tuple(o for o in static_outputs if o is not None and o.requires_grad),
                         grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                         retain_graph=retain_graph_in_backward,
                     )
@@ -794,7 +802,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
             # Replay forward graph
             fwd_graph.replay()
             assert isinstance(static_outputs, tuple)
-            return tuple(o.detach() for o in static_outputs)
+            return tuple(o.detach() if o is not None else o for o in static_outputs)
 
         @staticmethod
         @torch.autograd.function.once_differentiable
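The detach-with-None pass-through in the last hunk relies on standard autograd.Function behavior: non-tensor outputs of forward simply receive a None gradient slot in backward. A tiny self-contained sketch (no CUDA graph capture; _ReplayLike and its doubling op are toy stand-ins for the replayed static outputs):

```python
import torch


class _ReplayLike(torch.autograd.Function):
    """Toy Function whose forward returns a detached tensor plus a None slot."""

    @staticmethod
    def forward(ctx, x):
        out = x * 2  # stands in for a replayed static output buffer
        return out.detach(), None

    @staticmethod
    def backward(ctx, grad_out, grad_none):
        # autograd passes None for the non-tensor (None) forward output
        assert grad_none is None
        return 2 * grad_out


x = torch.randn(3, requires_grad=True)
y, aux = _ReplayLike.apply(x)
assert aux is None and y.grad_fn is not None
y.sum().backward()
print(x.grad)  # each element is 2
```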