[JAX] Debugging inspect utility #2651
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview
Greptile Summary
This PR introduces a new JAX “inspect” FFI/primitive wired from Python ( However, the change set also disables an existing cuBLAS shape/alignment validation for quantized GEMM ( Before merge, the inspect path should not introduce unconditional stdout logging, and the Python API should either implement the documented behavior (including using the
Confidence Score: 2/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User code
participant JAX as JAX runtime
participant InspectPy as transformer_engine/jax/inspect.py
participant FFI as jax.ffi
participant PyBind as transformer_engine_jax (pybind)
participant InspectCpp as transformer_engine/jax/csrc/extensions/amax.cpp
User->>InspectPy: inspect_array(x, name)
InspectPy->>InspectPy: _inspect(x) (custom_vjp)
InspectPy->>JAX: bind outer_primitive
JAX->>FFI: ffi_lowering("te_inspect_ffi", operand_output_aliases {0:0})
FFI->>PyBind: lookup registration te_inspect_ffi.execute
PyBind->>InspectCpp: call InspectHandler (InspectFFI)
InspectCpp-->>InspectCpp: NVTE_CHECK input/output alias
InspectCpp-->>JAX: return status (ffi_with_cuda_error_check)
JAX-->>InspectPy: returns x (aliased buffer)
InspectPy-->>User: returns x
5 files reviewed, 4 comments
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");

  printf("JTEST: Hello\n");

  return ffi_with_cuda_error_check();
}
Debug printf in FFI
InspectFFI prints "JTEST: Hello" unconditionally. This will spam stdout/stderr on every execution (including under jax.jit where it may execute many times / across devices) and is not an acceptable side effect for a library primitive. Please remove the printf or gate it behind an explicit debug flag that defaults to off.
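The gating the comment asks for could look roughly like the following. This is only a sketch in Python, not the C++ handler from the PR; the environment variable name `TE_INSPECT_DEBUG` and both helper functions are hypothetical. The same idea would apply in the FFI handler by checking an environment variable (for example via `std::getenv`) before printing.

```python
import os
import sys


def debug_enabled(env=None):
    """Return True only when the hypothetical TE_INSPECT_DEBUG flag is set to "1".

    The flag defaults to off, so library users see no output unless they opt in.
    """
    env = os.environ if env is None else env
    return env.get("TE_INSPECT_DEBUG", "0") == "1"


def debug_log(message, env=None, stream=None):
    """Write a debug line only if the flag is on; return whether it was written."""
    if not debug_enabled(env):
        return False
    print(message, file=stream if stream is not None else sys.stderr)
    return True


# Silent by default, verbose only when explicitly requested.
silent = debug_log("inspect: called", env={})
verbose = debug_log("inspect: called", env={"TE_INSPECT_DEBUG": "1"})
```

Passing `env` explicitly is only there to make the behavior easy to check; real code would read the process environment once at startup.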
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # TODO: Handle the name of the tensor in the primitive and output files
    return _inspect(x)
Unused name argument
inspect_array(x, name) documents printing the array name/shape/dtype/stats, but name is unused and the function currently just returns _inspect(x) (which itself is a no-op). This is misleading API surface and will confuse callers expecting output. Either implement passing/using name (and the actual inspection behavior) or drop the name parameter and update the docstring/export accordingly.
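One way the documented behavior could be approximated without the custom FFI is a host callback. The sketch below is an assumption about what "using `name`" might mean, not the PR's implementation: `inspect_array_sketch`, `format_summary`, and `_print_summary` are hypothetical names, and the reporting goes through `jax.debug.callback` rather than the `te_inspect_ffi` primitive.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def format_summary(name, x):
    """Build a one-line summary (name, shape, dtype, basic stats) on the host."""
    x = np.asarray(x)
    return (
        f"{name}: shape={x.shape} dtype={x.dtype} "
        f"min={float(x.min()):.4g} max={float(x.max()):.4g} "
        f"mean={float(x.mean()):.4g}"
    )


def _print_summary(name, x):
    # Runs on the host. `name` is bound with functools.partial because a
    # Python string cannot be passed as a traced JAX value.
    print(format_summary(name, x))


def inspect_array_sketch(x, name):
    """Hypothetical stand-in for inspect_array: report stats, return x unchanged."""
    jax.debug.callback(functools.partial(_print_summary, name), x)
    return x


@jax.jit
def scaled(x):
    # The inspected value flows through unchanged, so the call can be
    # dropped into an existing computation.
    return inspect_array_sketch(x * 2.0, "scaled_x")


out = scaled(jnp.arange(4, dtype=jnp.float32))
jax.effects_barrier()  # wait for the debug callback to finish
```

Because the callback only observes the value, the function stays an identity from the caller's point of view, which matches the aliased input/output contract the FFI handler checks.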
def _inspect_fwd_rule(
    x,
):
    """"""
    ctx = ()
    x = InspectPrimitive.outer_primitive.bind(x)
    return x, ctx
Missing outer primitive assertion
_inspect_fwd_rule calls InspectPrimitive.outer_primitive.bind(x) without asserting outer_primitive is initialized. If registration fails (or this module is imported before primitives are set up), this will raise an attribute error at runtime. Other primitives in this repo typically guard with assert ... is not None before binding; please add the same guard here.
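The requested guard can be illustrated in isolation. In this sketch `_FakePrimitive`, `InspectPrimitiveSketch`, and `inspect_fwd_rule_sketch` are hypothetical stand-ins that mimic the shape of the code above; they are not the classes from the repository.

```python
class _FakePrimitive:
    """Hypothetical stand-in for a registered JAX primitive."""

    def bind(self, x):
        # A real primitive would dispatch to its lowering; this one is an identity.
        return x


class InspectPrimitiveSketch:
    # Populated during registration; stays None if registration never ran.
    outer_primitive = None


def inspect_fwd_rule_sketch(x):
    """Forward rule with an explicit guard before binding."""
    assert (
        InspectPrimitiveSketch.outer_primitive is not None
    ), "InspectPrimitive.outer_primitive is not registered"
    ctx = ()
    return InspectPrimitiveSketch.outer_primitive.bind(x), ctx


# Before registration the guard fails with a descriptive message instead of
# an AttributeError on None.
try:
    inspect_fwd_rule_sketch(1.0)
except AssertionError as err:
    failure = str(err)

# After registration the rule behaves as an identity and returns an empty context.
InspectPrimitiveSketch.outer_primitive = _FakePrimitive()
result = inspect_fwd_rule_sketch(1.0)
```

The assertion does not change the successful path; it only turns a confusing failure into one that names the missing registration.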
Additional Comments (1)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
da437ca to f2d1629
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
966b035 to f56d869
for more information, see https://pre-commit.ci
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: