Skip to content

Conversation

@jberchtold-nvidia
Copy link
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft February 4, 2026 22:51
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 4, 2026

Greptile Overview

Greptile Summary

This PR introduces a new JAX “inspect” FFI/primitive wired from Python (transformer_engine/jax/inspect.py) through pybind registrations (transformer_engine/jax/csrc/extensions/pybind.cpp) to a new XLA FFI handler (InspectHandler) declared in transformer_engine/jax/csrc/extensions.h and implemented in transformer_engine/jax/csrc/extensions/amax.cpp.

However, the change set also disables an existing cuBLAS shape/alignment validation for quantized GEMM (assert_cublas_requirements in transformer_engine/jax/cpp_extensions/gemm.py), which can unsupported shapes to reach the cuBLAS custom call.

Before merge, the inspect path should not introduce unconditional stdout logging, and the Python API should either implement the documented behavior (including using the name argument) or adjust the API/docstrings to match the actual no-op behavior.

Confidence Score: 2/5

  • This PR is not safe to merge as-is due to a disabled GEMM validation and noisy inspect side effects.
  • The PR comments out a required cuBLAS quantized GEMM alignment assertion (can cause runtime failures) and introduces an inspect FFI handler that unconditionally prints to stdout; the new Python API also exposes an unused name parameter and lacks a guard for primitive initialization, indicating incomplete/debug-only code paths.
  • transformer_engine/jax/cpp_extensions/gemm.py, transformer_engine/jax/csrc/extensions/amax.cpp, transformer_engine/jax/inspect.py

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Comments out cuBLAS quantized contracting-dimension alignment assert, allowing invalid shapes to reach the cuBLAS custom call.
transformer_engine/jax/csrc/extensions.h Adds InspectHandler symbol declaration for new inspect FFI handler.
transformer_engine/jax/csrc/extensions/amax.cpp Adds InspectFFI handler but includes unconditional printf, creating noisy side effects in normal execution.
transformer_engine/jax/csrc/extensions/pybind.cpp Registers new te_inspect_ffi execute capsule mapping to InspectHandler.
transformer_engine/jax/inspect.py Introduces JAX inspect primitive/custom_vjp wrapper, but name arg is unused and forward rule lacks guard that outer_primitive is initialized.

Sequence Diagram

sequenceDiagram
  participant User as User code
  participant JAX as JAX runtime
  participant InspectPy as transformer_engine/jax/inspect.py
  participant FFI as jax.ffi
  participant PyBind as transformer_engine_jax (pybind)
  participant InspectCpp as transformer_engine/jax/csrc/extensions/amax.cpp

  User->>InspectPy: inspect_array(x, name)
  InspectPy->>InspectPy: _inspect(x) (custom_vjp)
  InspectPy->>JAX: bind outer_primitive
  JAX->>FFI: ffi_lowering("te_inspect_ffi", operand_output_aliases {0:0})
  FFI->>PyBind: lookup registration te_inspect_ffi.execute
  PyBind->>InspectCpp: call InspectHandler (InspectFFI)
  InspectCpp-->>InspectCpp: NVTE_CHECK input/output alias
  InspectCpp-->>JAX: return status (ffi_with_cuda_error_check)
  JAX-->>InspectPy: returns x (aliased buffer)
  InspectPy-->>User: returns x
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 100 to 110
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
NVTE_CHECK(output_buf->untyped_data() != nullptr,
"Output must be provided for inspect operation");
NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
"Input and output must point to the same buffer for inspect operation");

printf("JTEST: Hello\n");

return ffi_with_cuda_error_check();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug printf in FFI

InspectFFI prints "JTEST: Hello" unconditionally. This will spam stdout/stderr on every execution (including under jax.jit where it may execute many times / across devices) and is not an acceptable side effect for a library primitive. Please remove the printf or gate it behind an explicit debug flag that defaults to off.

Comment on lines +103 to +111
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
"""Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

Args:
x (jnp.ndarray): The JAX array to inspect.
name (str): The name of the array for identification in the output.
"""
# TODO: Handle the name of the tensor in the primitive and output files
return _inspect(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused name argument

inspect_array(x, name) documents printing the array name/shape/dtype/stats, but name is unused and the function currently just returns _inspect(x) (which itself is a no-op). This is misleading API surface and will confuse callers expecting output. Either implement passing/using name (and the actual inspection behavior) or drop the name parameter and update the docstring/export accordingly.

Comment on lines +82 to +91
def _inspect_fwd_rule(
x,
):
""""""
ctx = ()
x = InspectPrimitive.outer_primitive.bind(x)
return x, ctx


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing outer primitive assertion

_inspect_fwd_rule calls InspectPrimitive.outer_primitive.bind(x) without asserting outer_primitive is initialized. If registration fails (or this module is imported before primitives are set up), this will raise an attribute error at runtime. Other primitives in this repo typically guard with assert ... is not None before binding; please add the same guard here.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 4, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/gemm.py
Disabled cuBLAS alignment check

assert_cublas_requirements no longer enforces contracting_size % alignment == 0 for quantized GEMM (the assert is commented out). This will allow invalid shapes through to the cuBLAS custom call and can trigger runtime failures or incorrect behavior when using FP8/NVFP4 inputs. Please restore the check or replace it with an equivalent validation (and only relax it if the backend truly supports unaligned contracting sizes).

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant