[BUG] Unhelpful error when cutedsl runtime shared libs aren't found #3329

Description

@slayton58

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug

If cutedsl cannot auto-resolve the location of the cutedsl runtime .so (or potentially auto-resolves to a version that is incompatible with the Python package), and CUTE_DSL_LIBS is unset or set incorrectly, the code fails. Failing is the right behavior, but the error message ends up looking something like:

RuntimeError: Unknown function cutlass_add_one_FakeTensorFloat322561_FakeTensorFloat322561_FakeStream

This says nothing about the actual problem. Before this error, you should see:

JIT session error: Symbols not found: [ cuda_dialect_unload_library_once, cuda_dialect_init_library_once, cuda_dialect_get_error_name, _cuKernelGetAttribute, _cudaDeviceGetAttribute, _cudaFuncSetAttribute, _cudaGetDevice, _cudaKernelSetAttributeForDevice, _cudaLaunchKernelEx, _cudaLibraryGetKernel, _cudaLibraryLoadData, _cudaSetDevice ]

But this does not seem to be treated as a hard error; execution continues and fails elsewhere. That made debugging the RuntimeError above in production workflows significantly harder, because it led us in a very (very) wrong direction.
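As a stopgap on the user side, the JIT session error quoted above could be turned into a hard failure by scanning whatever log text has been captured. The following is only a minimal sketch: the helper names are hypothetical, the pattern is based solely on the message wording in this report, and how the log text gets captured is left out.

```python
import re

# Pattern based on the message quoted in this report; the real wording may vary.
_MISSING_SYMBOLS = re.compile(
    r"JIT session error: Symbols not found: \[(.*?)\]", re.DOTALL
)


def missing_jit_symbols(log_text: str) -> list:
    """Return the symbol names listed in a 'Symbols not found' message, if any."""
    match = _MISSING_SYMBOLS.search(log_text)
    if match is None:
        return []
    return [name.strip() for name in match.group(1).split(",") if name.strip()]


def raise_if_symbols_missing(log_text: str) -> None:
    """Hypothetical helper: fail loudly instead of continuing after the JIT error."""
    symbols = missing_jit_symbols(log_text)
    if symbols:
        raise RuntimeError(
            "JIT could not resolve %d runtime symbol(s), e.g. %s; "
            "the cutedsl runtime shared libraries were probably not found "
            "(check CUTE_DSL_LIBS)." % (len(symbols), symbols[0])
        )
```

With a check like this, the first symptom reported would name the missing runtime symbols rather than the later "Unknown function" error.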

Steps/Code to reproduce bug
Run the following script with CUTE_DSL_LIBS set to something silly.

import traceback

import cutlass
import cutlass.cute as cute
from cutlass import Float32

print("cutlass version:", getattr(cutlass, "__version__", "?"))


@cute.kernel
def _add_one_kernel(src: cute.Tensor, dst: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    if tidx < src.shape[0]:
        dst[tidx] = src[tidx] + Float32(1.0)


@cute.jit
def add_one(src: cute.Tensor, dst: cute.Tensor, stream):
    _add_one_kernel(src, dst).launch(grid=(1, 1, 1), block=(256, 1, 1), stream=stream)


def main():
    src = cute.runtime.make_fake_tensor(Float32, (256,), stride=(1,), assumed_align=4)
    dst = cute.runtime.make_fake_tensor(Float32, (256,), stride=(1,), assumed_align=4)
    stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
    try:
        compiled = cute.compile(add_one, src, dst, stream, options="--enable-tvm-ffi")
        print("COMPILE OK ->", compiled)
    except Exception as e:
        print("=== RAISED", type(e).__name__, "===")
        print(str(e)[:600])
        traceback.print_exc()


if __name__ == "__main__":
    main()
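One way to run the script above with a bad CUTE_DSL_LIBS value, without touching the current shell environment, is to launch it in a subprocess. This is a rough sketch: the helper name, the default bogus path, and the `repro.py` filename in the usage comment are all made up for illustration.

```python
import os
import subprocess
import sys


def run_with_bad_libs(script_path, bad_value="/nonexistent/cute_dsl_libs"):
    """Run a script with CUTE_DSL_LIBS overridden to a deliberately bad value."""
    env = dict(os.environ)  # copy, so the parent environment is left alone
    env["CUTE_DSL_LIBS"] = bad_value  # bogus value, chosen only for illustration
    return subprocess.run(
        [sys.executable, script_path],
        env=env,
        capture_output=True,
        text=True,
    )


# Usage (assuming the script above was saved as repro.py, a hypothetical name):
#   result = run_with_bad_libs("repro.py")
#   print(result.stdout)
#   print(result.stderr)
```

Capturing both streams keeps the earlier JIT message and the final RuntimeError together for comparison.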

Expected behavior
An immediate hard failure indicating that necessary libraries were not found.
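To illustrate the kind of fail-fast behavior being asked for, here is a user-side preflight check that could run before `cute.compile`. It is a sketch under assumptions, not how cutedsl resolves libraries: it assumes CUTE_DSL_LIBS holds one or more filesystem paths separated by `os.pathsep`, and the function name is hypothetical.

```python
import os


def check_cute_dsl_libs(environ=None):
    """Raise early if CUTE_DSL_LIBS names paths that do not exist.

    Assumption: the variable is an os.pathsep-separated list of paths.
    Returns the list of paths that were checked (empty if the variable is unset).
    """
    environ = os.environ if environ is None else environ
    value = environ.get("CUTE_DSL_LIBS")
    if not value:
        # Unset or empty: nothing to validate here; auto-resolution would apply.
        return []
    paths = [p for p in value.split(os.pathsep) if p]
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise RuntimeError(
            "CUTE_DSL_LIBS points at path(s) that do not exist: "
            + ", ".join(missing)
        )
    return paths
```

This only catches nonexistent paths. It would not detect a library that exists but is incompatible with the installed Python package, which is the other failure case described above.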

Environment details (please complete the following information):

  • Bare metal, version 4.5.2
