diff --git a/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu index 06f1e652cc..6d95d8c04c 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu @@ -19,14 +19,21 @@ Tensor dense_to_jagged_forward( // D is the embedding dimension auto D = dense.size(-1); - // If total_L is not given then compute it - at::SymInt total_L_computed; + // If total_L is not given then compute it. We realize total_L to a + // concrete int64 here (via guard_int) instead of forwarding the raw + // SymInt to at::empty_symint. The aten empty.memory_format CUDA/HIP + // wrapper calls C10_AS_INTARRAYREF_SLOW on its size argument, which + // TORCH_CHECKs that no SymInt in the array is heap-allocated -- so a + // heap SymInt that arrives here (e.g. an unbacked SymInt produced + // inside a torch.compile region) would crash with + // "SymIntArrayRef expected to contain only concrete integers". + int64_t total_L_computed; if (total_L.has_value()) { - total_L_computed = total_L.value(); + total_L_computed = total_L.value().guard_int(__FILE__, __LINE__); } else { - total_L_computed = (int64_t)offsets.back().max().item(); + total_L_computed = offsets.back().max().item(); } - auto values = at::empty_symint({total_L_computed, D}, dense.options()); + auto values = at::empty({total_L_computed, D}, dense.options()); auto output = at::empty_like(values); CUDA_DEVICE_GUARD(dense); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp index fb7b9d3360..a904ae45a0 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp @@ -467,16 +467,17 @@ Tensor dense_to_jagged_forward( // D is the embedding dimension auto D = dense.size(-1); - // If total_L is not given then compute it - at::SymInt total_L_computed; + // Realize total_L to a concrete int64 (see CUDA variant for the full + // explanation -- forwarding a heap SymInt to empty/zeros_symint crashes + // in the empty.memory_format wrapper's C10_AS_INTARRAYREF_SLOW). + int64_t total_L_computed; if (total_L.has_value()) { - total_L_computed = total_L.value(); + total_L_computed = total_L.value().guard_int(__FILE__, __LINE__); } else { - total_L_computed = - static_cast(offsets.back().max().item()); + total_L_computed = offsets.back().max().item(); } - auto values = at::empty_symint({total_L_computed, D}, dense.options()); - auto output = at::zeros_symint({total_L_computed, D}, dense.options()); + auto values = at::empty({total_L_computed, D}, dense.options()); + auto output = at::zeros({total_L_computed, D}, dense.options()); FBGEMM_DISPATCH_ALL_TYPES(values.scalar_type(), "jagged_scalars", [&] { jagged_dense_elementwise_jagged_output_( diff --git a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py index 5048d71112..0b5e1e3c2b 100644 --- a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py +++ b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py @@ -368,6 +368,79 @@ def dense_to_jagged_noL( # verify forward assert dense.size() == dense2.size() + @optests.dontGenerateOpCheckTests("regression test, not an op-shape check") + @unittest.skipIf(*gpu_unavailable) + def test_dense_to_jagged_heap_symint_total_L(self) -> None: + """Regression: dense_to_jagged_forward crashes on AMD/HIP when the + total_L argument is a heap-allocated SymInt rather than an inline + concrete int. + + Production failure (Stories LSR on MI350X, MAST f1096341099): the + outer forward is wrapped by torch.compile; downstream a Python int + num_events was symbolicalized into a heap SymInt and forwarded to + torch.ops.fbgemm.dense_to_jagged. dense_to_jagged_forward.cu then + calls + at::empty_symint({total_L_computed, D}, ...) + with the heap SymInt. On the hipified empty.memory_format dispatcher + the shape array goes through asIntArrayRefSlow which raises + "SymIntArrayRef expected to contain only concrete integers" + or, when the SymNode pointer slips through unchecked, lands in + empty_generic with the pointer reinterpreted as int64, producing + "Trying to create tensor with negative dimension ". + + This test bypasses dynamo and constructs the heap SymInt directly + via ShapeEnv.create_unbacked_symint(), then calls the real op with + a real CUDA/HIP tensor. That reproduces the exact kernel-level + condition without depending on which torch.compile backend / version + chooses to preserve vs. realize the SymInt. + + Fix: dense_to_jagged_forward must realize the total_L SymInt to a + concrete int64 (via .guard_int(__FILE__, __LINE__)) before + constructing the empty output, instead of forwarding the raw SymInt + to empty_symint. + """ + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + device = torch.accelerator.current_accelerator() + + B = 4 + D = 8 + max_L = 16 + total_L_concrete = 17 + + lengths = torch.tensor([3, 5, 2, 7], dtype=torch.long, device=device) + offsets = torch.zeros(B + 1, dtype=torch.long, device=device) + offsets[1:] = lengths.cumsum(0) + self.assertEqual(int(offsets[-1].item()), total_L_concrete) + + dense = torch.randn(B, max_L, D, dtype=torch.float32, device=device) + + # Construct an unbacked, heap-allocated SymInt with no hint. The + # IValue(c10::SymInt) constructor only preserves heap-ness when the + # SymNode's maybe_as_int() returns None -- i.e. when the SymInt is + # truly unbacked (no hint, no constant simplification). A SymInt + # with a hint is collapsed to Tag::Int (inline) by IValue and would + # not exercise the kernel-level crash path. + shape_env = ShapeEnv() + total_L_sym = shape_env.create_unbacked_symint() + + # Pre-fix kernel forwards this heap SymInt to at::empty_symint, + # which crashes in the empty.memory_format HIP wrapper at + # "SymIntArrayRef expected to contain only concrete integers". + # Post-fix the kernel realizes total_L via guard_int(__FILE__, __LINE__) + # before constructing the empty output. guard_int on a truly + # unbacked SymNode (no hint, no runtime guard) raises a clean + # GuardOnDataDependentSymNode -- which is the correct user-visible + # behavior: "the kernel cannot allocate without a concrete size." + # That clean error is what we assert here. The buggy pre-fix path + # raised the low-level SymIntArrayRef error from inside the empty + # wrapper, which fails this regex match. + with self.assertRaisesRegex( + RuntimeError, + r"Could not extract specialized integer from data-dependent expression", + ): + torch.ops.fbgemm.dense_to_jagged(dense, [offsets], total_L_sym) + if __name__ == "__main__": unittest.main()