[bugfix] fix qwen3_vl#73
Conversation
Code Review
This pull request modifies the transformer block to support keyword arguments during activation checkpointing by wrapping the forward function and converting kwargs into positional arguments. This ensures that tensor keyword arguments are correctly tracked by the autograd graph. Feedback suggests that other tensor inputs currently captured via closure, such as attention_bias and packed_seq_params, should also be passed as explicit arguments to ensure robust autograd tracking during recomputation.
```python
def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, padding_mask,
                    *extra_args):
    extra_kwargs = dict(zip(extra_kwargs_keys, extra_args))
    return forward_func(
        hidden_states,
        attention_mask,
        context,
        context_mask,
        rotary_pos_emb,
        padding_mask,
        **extra_kwargs,
    )
```
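For readers outside the Megatron codebase, the following is a minimal, self-contained sketch of the same kwargs-to-positional trick using plain torch.utils.checkpoint rather than the project's own checkpoint wrapper; forward_func, extra_kwargs_keys, and the tensor shapes here are illustrative and not the PR's actual call site.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Stand-in for the real layer forward, which takes tensor keyword arguments.
def forward_func(hidden_states, attention_mask, **extra_kwargs):
    out = hidden_states * 2
    if extra_kwargs.get("attention_bias") is not None:
        out = out + extra_kwargs["attention_bias"]
    return out

hidden_states = torch.randn(4, 8, requires_grad=True)
attention_mask = None
attention_bias = torch.randn(4, 8, requires_grad=True)

# Flatten the tensor kwargs into a fixed key order plus a positional tail.
extra_kwargs_keys = ["attention_bias"]
extra_args = (attention_bias,)

def wrapped_forward(hidden_states, attention_mask, *extra_args):
    # Rebuild the kwargs dict from the positional tail before calling the real forward.
    extra_kwargs = dict(zip(extra_kwargs_keys, extra_args))
    return forward_func(hidden_states, attention_mask, **extra_kwargs)

# Every tensor enters checkpoint() positionally, so it is recorded as an
# autograd input and replayed correctly when the block is recomputed.
out = checkpoint(wrapped_forward, hidden_states, attention_mask, *extra_args,
                 use_reentrant=True)
out.sum().backward()
print(hidden_states.grad is not None, attention_bias.grad is not None)  # True True
```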
The wrapped_forward function correctly converts the positional arguments back into keyword arguments, working around the torch.utils.checkpoint limitation that only positional inputs are forwarded to the checkpointed function and tracked as autograd inputs. However, note that attention_bias and packed_seq_params (from the outer _checkpointed_forward scope) are still captured via closure inside custom_forward. While this may be acceptable if they never require gradients, it is generally safer to pass all tensor inputs as explicit arguments to the checkpointed function so the autograd engine tracks them correctly during activation recomputation.
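As a rough sketch of the direction this comment points at (the names forward_func, extra_kwargs_keys, attention_bias, and packed_seq_params are assumed from the diff above; this is not the merged implementation), the closure-captured values could be promoted to explicit parameters of the wrapper:

```python
# Sketch only: one way _checkpointed_forward could build the wrapper so that
# attention_bias and packed_seq_params are explicit checkpoint inputs instead
# of closure captures, as the review comment suggests.
def make_wrapped_forward(forward_func, extra_kwargs_keys):
    def wrapped_forward(hidden_states, attention_mask, context, context_mask,
                        rotary_pos_emb, padding_mask, attention_bias,
                        packed_seq_params, *extra_args):
        # Rebuild the remaining kwargs from the positional tail, then forward
        # the now-explicit inputs on to the layer forward as keywords.
        extra_kwargs = dict(zip(extra_kwargs_keys, extra_args))
        return forward_func(
            hidden_states,
            attention_mask,
            context,
            context_mask,
            rotary_pos_emb,
            padding_mask,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
            **extra_kwargs,
        )
    return wrapped_forward
```

The corresponding checkpoint call would then pass attention_bias and packed_seq_params positionally alongside the other inputs, so nothing the recomputation depends on is hidden in a closure.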