3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions.h
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);

// Inspect
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

2 changes: 0 additions & 2 deletions transformer_engine/jax/csrc/extensions/amax.cpp
@@ -5,8 +5,6 @@
************************************************************************/
#include <cuda_runtime.h>

#include <iostream>

#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
98 changes: 98 additions & 0 deletions transformer_engine/jax/csrc/extensions/inspect.cpp
@@ -0,0 +1,98 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>

#include <fstream>
#include <iostream>

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
Result_Type output_buf) {
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
NVTE_CHECK(output_buf->untyped_data() != nullptr,
"Output must be provided for inspect operation");
NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
"Input and output must point to the same buffer for inspect operation");

std::vector<uint8_t> input_data(input_buf.size_bytes());
cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(),
cudaMemcpyDeviceToHost, stream);

float min_val{}, max_val{}, mean_val{}, std_val{};
cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream);
cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream);
cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost,
stream);
cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream);

cudaStreamSynchronize(stream);

int device;
cudaGetDevice(&device);

// Write the tensor data to a file as a binary blob
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
Comment (Contributor): the hardcoded filename my_tensor_gpu prevents distinguishing between different tensors. The name parameter from Python is not passed through to C++, so all dumps use the same base filename, and data is overwritten when multiple tensors are inspected.

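One possible direction, sketched in Python under the assumption that the name is forwarded as a string FFI attribute; the target name te_inspect_named_ffi and the attribute are illustrative, not part of this PR:

import jax
import jax.numpy as jnp

def inspect_named(x: jnp.ndarray, name: str) -> jnp.ndarray:
    # Hypothetical variant: keyword arguments to the callable returned by
    # jax.ffi.ffi_call are forwarded to the handler as FFI attributes, so a
    # C++ handler could read "name" and build a per-tensor filename.
    return jax.ffi.ffi_call(
        "te_inspect_named_ffi",  # illustrative target, not registered by this PR
        jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, name=name)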
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
}
Comment on lines +44 to +48 (Contributor): file write failures are silently ignored. If the file fails to open (e.g. permission denied) or the write fails (e.g. disk full), the function continues without any indication. Add error handling or logging.


// Write out a metadata file
std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
std::ofstream meta_file(meta_filename);
if (meta_file.is_open()) {
meta_file << "{";
meta_file << "\"shape\": [";
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
meta_file << input_buf.dimensions()[i];
if (i < input_buf.dimensions().size() - 1) {
meta_file << ", ";
}
}
meta_file << "], ";
meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
meta_file << ", \"min\": " << min_val;
meta_file << ", \"max\": " << max_val;
meta_file << ", \"mean\": " << mean_val;
meta_file << ", \"std\": " << std_val;
meta_file << "}";
meta_file.close();
}

// Log the tensor metadata to the console
printf("Tensor data written to %s (shape: [", filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
if (i < input_buf.dimensions().size() - 1) {
printf(", ");
}
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
Comment on lines +72 to +81 (Contributor): the unconditional printf will spam output every time this primitive executes (including in jitted code, on every device). For a debugging utility, consider gating it behind an environment variable or otherwise providing a way to disable verbose output.

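The printf itself is on the C++ side, but a complementary Python-level gate is one option; a minimal sketch, assuming an environment variable is the switch (the name NVTE_JAX_INSPECT_VERBOSE is illustrative):

import os
import jax.numpy as jnp
from transformer_engine.jax.debug.experimental import inspect_array

def inspect_array_gated(x: jnp.ndarray, name: str) -> jnp.ndarray:
    # Hypothetical opt-in gate: trace the inspect primitive only when the
    # (illustrative) NVTE_JAX_INSPECT_VERBOSE variable is set; otherwise the
    # array passes through untouched and no dump or console output occurs.
    if os.environ.get("NVTE_JAX_INSPECT_VERBOSE", "0") == "1":
        return inspect_array(x, name)
    return x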

return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // min
.Arg<Buffer_Type>() // max
.Arg<Buffer_Type>() // mean
.Arg<Buffer_Type>() // std
.Ret<Buffer_Type>() // output
);

} // namespace jax
} // namespace transformer_engine
3 changes: 3 additions & 0 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));

dict["te_inspect_ffi"] =
pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));

return dict;
}

14 changes: 14 additions & 0 deletions transformer_engine/jax/debug/experimental/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
This API is experimental and may change or be removed without deprecation in future releases.
"""

from .inspect import inspect_array, load_array_dump

__all__ = [
"inspect_array",
"load_array_dump",
]
170 changes: 170 additions & 0 deletions transformer_engine/jax/debug/experimental/inspect.py
@@ -0,0 +1,170 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental JAX array inspection utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jax import ffi

from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive

__all__ = ["inspect_array", "load_array_dump"]


class InspectPrimitive(BasePrimitive):
"""
No-op primitive used to inspect array values.
"""

name = "te_inspect_ffi"
multiple_results = False
impl_static_args = ()
inner_primitive = None
outer_primitive = None

@staticmethod
def abstract(
x_aval,
x_min_aval,
x_max_aval,
x_mean_aval,
x_std_aval,
):
"""
inspect abstract
"""
assert (
x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
), "x_min must be a scalar with dtype float32"
assert (
x_max_aval.shape == () and x_max_aval.dtype == jnp.float32
), "x_max must be a scalar with dtype float32"
assert (
x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32
), "x_mean must be a scalar with dtype float32"
assert (
x_std_aval.shape == () and x_std_aval.dtype == jnp.float32
), "x_std must be a scalar with dtype float32"
return x_aval

@staticmethod
def lowering(
ctx,
x,
x_min,
x_max,
x_mean,
x_std,
):
"""
inspect lowering rules
"""

return ffi.ffi_lowering(
InspectPrimitive.name,
operand_output_aliases={0: 0}, # donate input buffer to output buffer
)(
ctx,
x,
x_min,
x_max,
x_mean,
x_std,
)

@staticmethod
def impl(
x,
x_min,
x_max,
x_mean,
x_std,
):
"""
inspect implementation
"""
assert InspectPrimitive.inner_primitive is not None
x = InspectPrimitive.inner_primitive.bind(
x,
x_min,
x_max,
x_mean,
x_std,
)
return x


register_primitive(InspectPrimitive)


def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
return InspectPrimitive.outer_primitive.bind(
x,
jnp.min(x).astype(jnp.float32),
jnp.max(x).astype(jnp.float32),
jnp.mean(x.astype(jnp.float32)),
jnp.std(x.astype(jnp.float32)),
)

Comment (Contributor): missing assert InspectPrimitive.outer_primitive is not None before bind. Other primitives in this codebase guard this to prevent an AttributeError if registration fails (see activation.py:351, amax.py:381, etc.).


@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
x,
):
""" """
output, _ = _inspect_fwd_rule(
x,
)
return output


def _inspect_fwd_rule(
x,
):
""""""
ctx = ()
x = _inspect_array_inner(x)
return x, ctx


Comment on lines 124 to 132 (Contributor): Missing outer primitive assertion. _inspect_fwd_rule calls InspectPrimitive.outer_primitive.bind(x) without asserting that outer_primitive is initialized. If registration fails (or this module is imported before primitives are set up), this will raise an AttributeError at runtime. Other primitives in this repo typically guard with assert ... is not None before binding; please add the same guard here.

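A sketch of the guard both reviewers are asking for, mirroring _inspect_array_inner above and the pattern the comments attribute to other primitives in this repo:

def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    # Fail loudly if primitive registration did not run, instead of raising
    # an AttributeError from NoneType at bind time.
    assert InspectPrimitive.outer_primitive is not None
    return InspectPrimitive.outer_primitive.bind(
        x,
        jnp.min(x).astype(jnp.float32),
        jnp.max(x).astype(jnp.float32),
        jnp.mean(x.astype(jnp.float32)),
        jnp.std(x.astype(jnp.float32)),
    )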
def _inspect_bwd_rule(
ctx,
grad,
):
""""""
del ctx
return (grad,)


_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
"""Utility function to inspect JAX arrays: dumps the array contents to a binary file and logs the shape, dtype, and statistics (min/max/mean/std) to the console.

Args:
x (jnp.ndarray): The JAX array to inspect.
name (str): The name of the array for identification in the output. Currently unused; see the TODO below.

Returns:
jnp.ndarray: The input array, unchanged.
"""
# TODO: Handle the name of the tensor in the primitive and output files
return _inspect(x)

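A usage sketch; the import path and inspect_array come from this PR, and the dump filename pattern my_tensor_gpu<device>.bin comes from the C++ handler above:

import jax
import jax.numpy as jnp
from transformer_engine.jax.debug.experimental import inspect_array

@jax.jit
def forward(x):
    h = x * 2.0
    # Dumps h to my_tensor_gpu<device>.bin (plus a _meta.json sidecar) and
    # logs its stats; h flows through unchanged, so results are unaffected.
    h = inspect_array(h, "hidden")
    return h.sum()

print(forward(jnp.arange(8.0).reshape(2, 4)))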

def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
"""Utility function to load a JAX array from a dumped binary file.

Args:
filename (str): The path to the binary file containing the array data.
shape (tuple): The shape of the array to be loaded.
dtype (jnp.dtype): The data type of the array to be loaded.

Returns:
jnp.ndarray: The loaded JAX array.
"""
with open(filename, "rb") as f:
data = f.read()
array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
return array
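A round-trip sketch, assuming a dump produced by the handler above on GPU 0; the filename pattern and the _meta.json contents come from the C++ code, and the float32 dtype is illustrative (it must match the inspected array):

import json
import jax.numpy as jnp
from transformer_engine.jax.debug.experimental import load_array_dump

# The metadata sidecar records the shape; its "dtype" field is XLA's
# element-type enum, so the JAX dtype is still passed explicitly here.
with open("my_tensor_gpu0_meta.json") as f:
    meta = json.load(f)

arr = load_array_dump("my_tensor_gpu0.bin", tuple(meta["shape"]), jnp.float32)
print(arr.shape, arr.dtype, float(arr.min()), float(arr.max()))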