-
Notifications
You must be signed in to change notification settings - Fork 631
[JAX] Debugging inspect utility #2651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f693220
f2d1629
f56d869
37a7dd5
c3fe902
cf7be54
cdf53f5
39a2194
5b8587d
378b4ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <fstream> | ||
| #include <iostream> | ||
|
|
||
| #include "../extensions.h" | ||
| #include "xla/ffi/api/c_api.h" | ||
|
|
||
| namespace transformer_engine { | ||
| namespace jax { | ||
|
|
||
| Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf, | ||
| Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf, | ||
| Result_Type output_buf) { | ||
| NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation"); | ||
| NVTE_CHECK(output_buf->untyped_data() != nullptr, | ||
| "Output must be provided for inspect operation"); | ||
| NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(), | ||
| "Input and output must point to the same buffer for inspect operation"); | ||
|
|
||
| std::vector<uint8_t> input_data(input_buf.size_bytes()); | ||
| cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(), input_buf.size_bytes(), | ||
| cudaMemcpyDeviceToHost, stream); | ||
|
|
||
| float min_val{}, max_val{}, mean_val{}, std_val{}; | ||
| cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); | ||
| cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); | ||
| cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, | ||
| stream); | ||
| cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float), cudaMemcpyDeviceToHost, stream); | ||
|
|
||
| cudaStreamSynchronize(stream); | ||
|
|
||
| int device; | ||
| cudaGetDevice(&device); | ||
|
|
||
| // Write the tensor data to a file as a binary blob | ||
| std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin"; | ||
| std::ofstream file(filename, std::ios::binary); | ||
| if (file.is_open()) { | ||
| file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size()); | ||
| file.close(); | ||
| } | ||
|
Comment on lines
+44
to
+48
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. file write failures are silently ignored. If |
||
|
|
||
| // Write out a metadata file | ||
| std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json"; | ||
| std::ofstream meta_file(meta_filename); | ||
| if (meta_file.is_open()) { | ||
| meta_file << "{"; | ||
| meta_file << "\"shape\": ["; | ||
| for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { | ||
| meta_file << input_buf.dimensions()[i]; | ||
| if (i < input_buf.dimensions().size() - 1) { | ||
| meta_file << ", "; | ||
| } | ||
| } | ||
| meta_file << "], "; | ||
| meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type()); | ||
| meta_file << ", \"min\": " << min_val; | ||
| meta_file << ", \"max\": " << max_val; | ||
| meta_file << ", \"mean\": " << mean_val; | ||
| meta_file << ", \"std\": " << std_val; | ||
| meta_file << "}"; | ||
| meta_file.close(); | ||
| } | ||
|
|
||
| // Log the tensor metadata to the console | ||
| printf("Tensor data written to %s (shape: [", filename.c_str()); | ||
| for (size_t i = 0; i < input_buf.dimensions().size(); ++i) { | ||
| printf("%zu", static_cast<size_t>(input_buf.dimensions()[i])); | ||
| if (i < input_buf.dimensions().size() - 1) { | ||
| printf(", "); | ||
| } | ||
| } | ||
| printf("], dtype: %d", static_cast<int>(input_buf.element_type())); | ||
| printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val); | ||
|
Comment on lines
+72
to
+81
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unconditional printf will spam output every time this primitive executes (including in jitted code on every device). For a debugging utility, consider gating behind an environment variable or adding a way to disable verbose output. |
||
|
|
||
| return ffi_with_cuda_error_check(); | ||
| } | ||
|
|
||
| XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI, | ||
| FFI::Bind() | ||
| .Ctx<FFI_Stream_Type>() // stream | ||
| .Arg<Buffer_Type>() // input | ||
| .Arg<Buffer_Type>() // min | ||
| .Arg<Buffer_Type>() // max | ||
| .Arg<Buffer_Type>() // mean | ||
| .Arg<Buffer_Type>() // std | ||
| .Ret<Buffer_Type>() // output | ||
| ); | ||
|
|
||
| } // namespace jax | ||
| } // namespace transformer_engine | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
| """EXPERIMENTAL debugging utilities for Transformer Engine JAX. | ||
| This API is experimental and may change or be removed without deprecation in future releases. | ||
| """ | ||
|
|
||
| from .inspect import inspect_array, load_array_dump | ||
|
|
||
| __all__ = [ | ||
| "inspect_array", | ||
| "load_array_dump", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
| """Experimental JAX array inspection utilities.""" | ||
|
|
||
| from functools import partial | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import ffi | ||
|
|
||
| from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive | ||
|
|
||
| __all__ = ["inspect_array", "load_array_dump"] | ||
|
|
||
|
|
||
| class InspectPrimitive(BasePrimitive): | ||
| """ | ||
| No-op used for inspect array values. | ||
| """ | ||
|
|
||
| name = "te_inspect_ffi" | ||
| multiple_results = False | ||
| impl_static_args = () | ||
| inner_primitive = None | ||
| outer_primitive = None | ||
|
|
||
| @staticmethod | ||
| def abstract( | ||
| x_aval, | ||
| x_min_aval, | ||
| x_max_aval, | ||
| x_mean_aval, | ||
| x_std_aval, | ||
| ): | ||
| """ | ||
| inspect abstract | ||
| """ | ||
| assert ( | ||
| x_min_aval.shape == () and x_min_aval.dtype == jnp.float32 | ||
| ), "x_min must be a scalar with dtype float32" | ||
| assert ( | ||
| x_max_aval.shape == () and x_max_aval.dtype == jnp.float32 | ||
| ), "x_max must be a scalar with dtype float32" | ||
| assert ( | ||
| x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32 | ||
| ), "x_mean must be a scalar with dtype float32" | ||
| assert ( | ||
| x_std_aval.shape == () and x_std_aval.dtype == jnp.float32 | ||
| ), "x_std must be a scalar with dtype float32" | ||
| return x_aval | ||
|
|
||
| @staticmethod | ||
| def lowering( | ||
| ctx, | ||
| x, | ||
| x_min, | ||
| x_max, | ||
| x_mean, | ||
| x_std, | ||
| ): | ||
| """ | ||
| inspect lowering rules | ||
| """ | ||
|
|
||
| return ffi.ffi_lowering( | ||
| InspectPrimitive.name, | ||
| operand_output_aliases={0: 0}, # donate input buffer to output buffer | ||
| )( | ||
| ctx, | ||
| x, | ||
| x_min, | ||
| x_max, | ||
| x_mean, | ||
| x_std, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def impl( | ||
| x, | ||
| x_min, | ||
| x_max, | ||
| x_mean, | ||
| x_std, | ||
| ): | ||
| """ | ||
| inspect implementation | ||
| """ | ||
| assert InspectPrimitive.inner_primitive is not None | ||
| (x) = InspectPrimitive.inner_primitive.bind( | ||
| x, | ||
| x_min, | ||
| x_max, | ||
| x_mean, | ||
| x_std, | ||
| ) | ||
| return x | ||
|
|
||
|
|
||
| register_primitive(InspectPrimitive) | ||
|
|
||
|
|
||
| def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray: | ||
| return InspectPrimitive.outer_primitive.bind( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing |
||
| x, | ||
| jnp.min(x).astype(jnp.float32), | ||
| jnp.max(x).astype(jnp.float32), | ||
| jnp.mean(x.astype(jnp.float32)), | ||
| jnp.std(x.astype(jnp.float32)), | ||
| ) | ||
|
|
||
|
|
||
| @partial(jax.custom_vjp, nondiff_argnums=()) | ||
| def _inspect( | ||
| x, | ||
| ): | ||
| """ """ | ||
| output, _ = _inspect_fwd_rule( | ||
| x, | ||
| ) | ||
| return output | ||
|
|
||
|
|
||
| def _inspect_fwd_rule( | ||
| x, | ||
| ): | ||
| """""" | ||
| ctx = () | ||
| x = _inspect_array_inner(x) | ||
| return x, ctx | ||
|
|
||
|
|
||
|
Comment on lines
124
to
132
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing outer primitive assertion: `_inspect_array_inner` should assert that `InspectPrimitive.outer_primitive is not None` before binding, mirroring the inner-primitive check in `impl`.
|
||
| def _inspect_bwd_rule( | ||
| ctx, | ||
| grad, | ||
| ): | ||
| """""" | ||
| del ctx | ||
| return (grad,) | ||
|
|
||
|
|
||
| _inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule) | ||
|
|
||
|
|
||
| def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray: | ||
| """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics. | ||
|
|
||
| Args: | ||
| x (jnp.ndarray): The JAX array to inspect. | ||
| name (str): The name of the array for identification in the output. | ||
| """ | ||
| # TODO: Handle the name of the tensor in the primitive and output files | ||
| return _inspect(x) | ||
|
|
||
|
|
||
| def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray: | ||
| """Utility function to load a JAX array from a dumped binary file. | ||
|
|
||
| Args: | ||
| filename (str): The path to the binary file containing the array data. | ||
| shape (tuple): The shape of the array to be loaded. | ||
| dtype (jnp.dtype): The data type of the array to be loaded. | ||
|
|
||
| Returns: | ||
| jnp.ndarray: The loaded JAX array. | ||
| """ | ||
| with open(filename, "rb") as f: | ||
| data = f.read() | ||
| array = jnp.frombuffer(data, dtype=dtype).reshape(shape) | ||
| return array | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hardcoded filename
`my_tensor_gpu` prevents distinguishing between different tensors. The `name` parameter from Python is not passed through to C++, making all dumps use the same base filename. This will overwrite data when inspecting multiple tensors.