Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,21 @@ Tensor dense_to_jagged_forward(
// D is the embedding dimension
auto D = dense.size(-1);

// If total_L is not given then compute it
at::SymInt total_L_computed;
// If total_L is not given then compute it. We realize total_L to a
// concrete int64 here (via guard_int) instead of forwarding the raw
// SymInt to at::empty_symint. The aten empty.memory_format CUDA/HIP
// wrapper calls C10_AS_INTARRAYREF_SLOW on its size argument, which
// TORCH_CHECKs that no SymInt in the array is heap-allocated -- so a
// heap SymInt that arrives here (e.g. an unbacked SymInt produced
// inside a torch.compile region) would crash with
// "SymIntArrayRef expected to contain only concrete integers".
int64_t total_L_computed;
if (total_L.has_value()) {
total_L_computed = total_L.value();
total_L_computed = total_L.value().guard_int(__FILE__, __LINE__);
} else {
total_L_computed = (int64_t)offsets.back().max().item<int64_t>();
total_L_computed = offsets.back().max().item<int64_t>();
}
auto values = at::empty_symint({total_L_computed, D}, dense.options());
auto values = at::empty({total_L_computed, D}, dense.options());
auto output = at::empty_like(values);

CUDA_DEVICE_GUARD(dense);
Expand Down
15 changes: 8 additions & 7 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,17 @@ Tensor dense_to_jagged_forward(
// D is the embedding dimension
auto D = dense.size(-1);

// If total_L is not given then compute it
at::SymInt total_L_computed;
// Realize total_L to a concrete int64 (see CUDA variant for the full
// explanation -- forwarding a heap SymInt to empty/zeros_symint crashes
// in the empty.memory_format wrapper's C10_AS_INTARRAYREF_SLOW).
int64_t total_L_computed;
if (total_L.has_value()) {
total_L_computed = total_L.value();
total_L_computed = total_L.value().guard_int(__FILE__, __LINE__);
} else {
total_L_computed =
static_cast<int64_t>(offsets.back().max().item<int64_t>());
total_L_computed = offsets.back().max().item<int64_t>();
}
auto values = at::empty_symint({total_L_computed, D}, dense.options());
auto output = at::zeros_symint({total_L_computed, D}, dense.options());
auto values = at::empty({total_L_computed, D}, dense.options());
auto output = at::zeros({total_L_computed, D}, dense.options());

FBGEMM_DISPATCH_ALL_TYPES(values.scalar_type(), "jagged_scalars", [&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
Expand Down
73 changes: 73 additions & 0 deletions fbgemm_gpu/test/jagged/dense_to_jagged_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,79 @@ def dense_to_jagged_noL(
# verify forward
assert dense.size() == dense2.size()

@optests.dontGenerateOpCheckTests("regression test, not an op-shape check")
@unittest.skipIf(*gpu_unavailable)
def test_dense_to_jagged_heap_symint_total_L(self) -> None:
"""Regression: dense_to_jagged_forward crashes on AMD/HIP when the
total_L argument is a heap-allocated SymInt rather than an inline
concrete int.

Production failure (Stories LSR on MI350X, MAST f1096341099): the
outer forward is wrapped by torch.compile; downstream a Python int
num_events was symbolicalized into a heap SymInt and forwarded to
torch.ops.fbgemm.dense_to_jagged. dense_to_jagged_forward.cu then
calls
at::empty_symint({total_L_computed, D}, ...)
with the heap SymInt. On the hipified empty.memory_format dispatcher
the shape array goes through asIntArrayRefSlow which raises
"SymIntArrayRef expected to contain only concrete integers"
or, when the SymNode pointer slips through unchecked, lands in
empty_generic with the pointer reinterpreted as int64, producing
"Trying to create tensor with negative dimension <huge negative>".

This test bypasses dynamo and constructs the heap SymInt directly
via ShapeEnv.create_unbacked_symint(), then calls the real op with
a real CUDA/HIP tensor. That reproduces the exact kernel-level
condition without depending on which torch.compile backend / version
chooses to preserve vs. realize the SymInt.

Fix: dense_to_jagged_forward must realize the total_L SymInt to a
concrete int64 (via .guard_int(__FILE__, __LINE__)) before
constructing the empty output, instead of forwarding the raw SymInt
to empty_symint.
"""
from torch.fx.experimental.symbolic_shapes import ShapeEnv

device = torch.accelerator.current_accelerator()

B = 4
D = 8
max_L = 16
total_L_concrete = 17

lengths = torch.tensor([3, 5, 2, 7], dtype=torch.long, device=device)
offsets = torch.zeros(B + 1, dtype=torch.long, device=device)
offsets[1:] = lengths.cumsum(0)
self.assertEqual(int(offsets[-1].item()), total_L_concrete)

dense = torch.randn(B, max_L, D, dtype=torch.float32, device=device)

# Construct an unbacked, heap-allocated SymInt with no hint. The
# IValue(c10::SymInt) constructor only preserves heap-ness when the
# SymNode's maybe_as_int() returns None -- i.e. when the SymInt is
# truly unbacked (no hint, no constant simplification). A SymInt
# with a hint is collapsed to Tag::Int (inline) by IValue and would
# not exercise the kernel-level crash path.
shape_env = ShapeEnv()
total_L_sym = shape_env.create_unbacked_symint()

# Pre-fix kernel forwards this heap SymInt to at::empty_symint,
# which crashes in the empty.memory_format HIP wrapper at
# "SymIntArrayRef expected to contain only concrete integers".
# Post-fix the kernel realizes total_L via guard_int(__FILE__, __LINE__)
# before constructing the empty output. guard_int on a truly
# unbacked SymNode (no hint, no runtime guard) raises a clean
# GuardOnDataDependentSymNode -- which is the correct user-visible
# behavior: "the kernel cannot allocate without a concrete size."
# That clean error is what we assert here. The buggy pre-fix path
# raised the low-level SymIntArrayRef error from inside the empty
# wrapper, which fails this regex match.
with self.assertRaisesRegex(
RuntimeError,
r"Could not extract specialized integer from data-dependent expression",
):
torch.ops.fbgemm.dense_to_jagged(dense, [offsets], total_L_sym)


if __name__ == "__main__":
unittest.main()
Loading