diff --git a/src/native/cambricon/ops/add/add.h b/src/native/cambricon/ops/add/add.h new file mode 100644 index 000000000..38a165023 --- /dev/null +++ b/src/native/cambricon/ops/add/add.h @@ -0,0 +1,66 @@ +#ifndef INFINI_OPS_CAMBRICON_ADD_H_ +#define INFINI_OPS_CAMBRICON_ADD_H_ + +#include "base/add.h" +#include "native/cambricon/common.h" +#include "native/cambricon/data_type_.h" + +namespace infini::ops { + /* Host-side launcher, defined and explicitly instantiated in kernel.mlu: stages the shape and stride arrays in `workspace`, launches AddKernel on `queue`, then waits for the queue. */ +template +void AddUnion(void* workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void* out, const void* input, + const void* other, const size_t* out_shape, + const ptrdiff_t* out_strides, const size_t* input_shape, + const ptrdiff_t* input_strides, const size_t* other_shape, + const ptrdiff_t* other_strides, size_t output_size, int ndim, + bool fast_path, bool out_contiguous); + /* Cambricon specialization of the Add operator. The constructor queries the launch configuration and allocates `default_workspace_` (checked with CNRT_CHECK), which is used whenever `workspace_` is unset. NOTE(review): the destructor frees that buffer but no copy or move handling is visible here; confirm instances are never copied, otherwise the buffer would be freed twice. */ +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add{input, other, out} { + cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster, + &cluster_count); + CNRT_CHECK(cnrtMalloc(&default_workspace_, workspace_size_in_bytes())); + } + /* Selects the queue and workspace, decides whether the fast path applies, and dispatches AddUnion on the output data type. */ + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto queue = static_cast(stream_ ? stream_ : 0); + auto workspace{workspace_ ? 
workspace_ : default_workspace_}; + /* Fast path: all three tensors contiguous and both operand shapes equal to the output shape. */ + bool fast_path = is_input_contiguous_ && is_other_contiguous_ && + is_out_contiguous_ && input_shape_ == out_shape_ && + other_shape_ == out_shape_; + + DispatchFunc>( + {static_cast(out_type_)}, + [&](auto tag) { + using T = TypeMapType(tag)>; + AddUnion(workspace, core_per_cluster, cluster_count, queue, + out.data(), input.data(), other.data(), out_shape_.data(), + out_strides_.data(), input_shape_.data(), + input_strides_.data(), other_shape_.data(), + other_strides_.data(), output_size_, ndim_, fast_path, + is_out_contiguous_); + }, + "CambriconAdd::operator() - output dispatch"); + } + /* Frees the buffer allocated in the constructor. */ + ~Operator() { cnrtFree(default_workspace_); } + /* Bytes for three shape arrays and three stride arrays of ndim_ entries each, which is the layout AddUnion carves out of the workspace. */ + std::size_t workspace_size_in_bytes() const override { + return ndim_ * (3 * sizeof(size_t) + 3 * sizeof(ptrdiff_t)); + } + /* Fallback workspace, followed by the launch configuration filled in by cnrt_utils::GetLaunchConfig. */ + void* default_workspace_{nullptr}; + int core_per_cluster = 0; + int cluster_count = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cambricon/ops/add/kernel.mlu b/src/native/cambricon/ops/add/kernel.mlu new file mode 100644 index 000000000..f76aecd95 --- /dev/null +++ b/src/native/cambricon/ops/add/kernel.mlu @@ -0,0 +1,205 @@ +#include "add.h" + /* Scratch buffer declared with __nram__; AddKernel splits it into input, other and output tiles. */ +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +namespace infini::ops { + /* Adds n elements of src1 and src2 into dst with __bang_add. The if-constexpr branch routes one element type through reinterpret_cast'ed pointers. */ +template +__mlu_device__ void BangAdd(const T* src1, const T* src2, T* dst, size_t n) { + if constexpr (std::is_same_v) { + __bang_add(reinterpret_cast(dst), + reinterpret_cast(src1), + reinterpret_cast(src2), n); + } else { + __bang_add(dst, src1, src2, n); + } +} + /* Device kernel. Each task handles a contiguous slice of ceil(output_size / taskDim) output elements, processed tile by tile through nram_buffer. With fast_path whole tiles are copied; otherwise each element's source offset is derived from its output coordinates, using coordinate 0 on any dimension where the coordinate is not below the operand's extent. */ +template +__mlu_global__ void AddKernel( + const T* input, const T* other, T* output, const size_t* out_shape, + const ptrdiff_t* out_strides, const size_t* input_shape, + const ptrdiff_t* input_strides, const size_t* other_shape, + const ptrdiff_t* other_strides, size_t output_size, int ndim, + bool fast_path, bool out_contiguous) { + size_t elements_per_task = (output_size + taskDim - 1) / taskDim; + size_t start = taskId * elements_per_task; + size_t end = start 
+ elements_per_task; + if (end > output_size) end = output_size; + size_t num_elements = end > start ? end - start : 0; + if (num_elements == 0) return; + /* Three equal tiles taken from NRAM_MAX_SIZE minus 256 bytes, rounded down to a multiple of 64 elements. */ + size_t nram_usable = NRAM_MAX_SIZE - 256; + size_t block_size = nram_usable / (3 * sizeof(T)); + block_size = (block_size / 64) * 64; // Align to 64 elements + if (block_size == 0) block_size = 64; + + T* input_buf = reinterpret_cast(nram_buffer); + T* other_buf = input_buf + block_size; + T* output_buf = other_buf + block_size; + + size_t processed = 0; + + if (fast_path) { + // Fast path: all tensors contiguous with matching shapes (no broadcast). + while (processed < num_elements) { + size_t curr = block_size; + if (curr > num_elements - processed) curr = num_elements - processed; + + __memcpy(input_buf, input + start + processed, curr * sizeof(T), + GDRAM2NRAM); + __memcpy(other_buf, other + start + processed, curr * sizeof(T), + GDRAM2NRAM); + BangAdd(input_buf, other_buf, output_buf, curr); + __memcpy(output + start + processed, output_buf, curr * sizeof(T), + NRAM2GDRAM); + + processed += curr; + } + return; + } + + // General path: handle non-contiguous tensors and broadcasting. + while (processed < num_elements) { + size_t curr = block_size; + if (curr > num_elements - processed) curr = num_elements - processed; + + for (size_t i = 0; i < curr; ++i) { + size_t flat_idx = start + processed + i; + + // Compute `input` offset. + { + size_t tmp = flat_idx; + ptrdiff_t offset = 0; + for (int d = ndim - 1; d >= 0; --d) { + size_t coord = tmp % out_shape[d]; + tmp /= out_shape[d]; + size_t c = coord < input_shape[d] ? coord : 0; + offset += static_cast(c) * input_strides[d]; + } + input_buf[i] = input[offset]; + } + + // Compute `other` offset. + { + size_t tmp = flat_idx; + ptrdiff_t offset = 0; + for (int d = ndim - 1; d >= 0; --d) { + size_t coord = tmp % out_shape[d]; + tmp /= out_shape[d]; + size_t c = coord < other_shape[d] ? 
coord : 0; + offset += static_cast(c) * other_strides[d]; + } + other_buf[i] = other[offset]; + } + } + + BangAdd(input_buf, other_buf, output_buf, curr); + /* A contiguous output is written back as one tile; otherwise each element is scattered through out_strides. */ + if (out_contiguous) { + __memcpy(output + start + processed, output_buf, curr * sizeof(T), + NRAM2GDRAM); + } else { + for (size_t i = 0; i < curr; ++i) { + size_t flat_idx = start + processed + i; + size_t tmp = flat_idx; + ptrdiff_t offset = 0; + for (int d = ndim - 1; d >= 0; --d) { + size_t coord = tmp % out_shape[d]; + offset += static_cast(coord) * out_strides[d]; + tmp /= out_shape[d]; + } + output[offset] = output_buf[i]; + } + } + + processed += curr; + } +} + /* Host launcher for AddKernel. `workspace` must provide room for 3 * ndim size_t values followed by 3 * ndim ptrdiff_t values; the arrays are copied there with cnrtMemcpyHostToDev. The call returns only after the queue has been synchronized. */ +template +void AddUnion(void* workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void* out, const void* input, + const void* other, const size_t* out_shape, + const ptrdiff_t* out_strides, const size_t* input_shape, + const ptrdiff_t* input_strides, const size_t* other_shape, + const ptrdiff_t* other_strides, size_t output_size, int ndim, + bool fast_path, bool out_contiguous) { + cnrtDim3_t kernel_dim; + cnrtFunctionType_t kernel_type; + /* Union1 launch: x is the cores per cluster, y is the cluster count. */ + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + kernel_type = cnrtFuncTypeUnion1; + + auto out_ = reinterpret_cast(out); + auto input_ = reinterpret_cast(input); + auto other_ = reinterpret_cast(other); + /* Carve the workspace into out/input/other shapes followed by out/input/other strides. */ + char* tmp = reinterpret_cast(workspace); + size_t* mlu_out_shape = reinterpret_cast(tmp); + size_t* mlu_input_shape = mlu_out_shape + ndim; + size_t* mlu_other_shape = mlu_input_shape + ndim; + ptrdiff_t* mlu_out_strides = + reinterpret_cast(mlu_other_shape + ndim); + ptrdiff_t* mlu_input_strides = mlu_out_strides + ndim; + ptrdiff_t* mlu_other_strides = mlu_input_strides + ndim; + /* Stage the metadata in the workspace. NOTE(review): these copies also run when fast_path is true, where AddKernel returns before reading any shape or stride; confirm whether they can be skipped in that case. */ + CNRT_CHECK(cnrtMemcpyAsync(mlu_out_shape, const_cast(out_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_input_shape, const_cast(input_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + 
CNRT_CHECK(cnrtMemcpyAsync(mlu_other_shape, const_cast(other_shape), + ndim * sizeof(size_t), queue, + cnrtMemcpyHostToDev)); + CNRT_CHECK( + cnrtMemcpyAsync(mlu_out_strides, const_cast(out_strides), + ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK( + cnrtMemcpyAsync(mlu_input_strides, const_cast(input_strides), + ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK( + cnrtMemcpyAsync(mlu_other_strides, const_cast(other_strides), + ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + + AddKernel<<>>( + input_, other_, out_, mlu_out_shape, mlu_out_strides, mlu_input_shape, + mlu_input_strides, mlu_other_shape, mlu_other_strides, output_size, ndim, + fast_path, out_contiguous); + /* Wait for the copies and the kernel; the status is checked so a failure reported at synchronization is not silently dropped. */ + CNRT_CHECK(cnrtQueueSync(queue)); +} + /* Explicit instantiations; add.h only declares AddUnion. */ +template void AddUnion<__half>(void*, int, int, cnrtQueue_t, void*, const void*, + const void*, const size_t*, const ptrdiff_t*, + const size_t*, const ptrdiff_t*, const size_t*, + const ptrdiff_t*, size_t, int, bool, bool); + +template void AddUnion<__bang_bfloat16>(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, const size_t*, + const ptrdiff_t*, const size_t*, + const ptrdiff_t*, const size_t*, + const ptrdiff_t*, size_t, int, bool, + bool); + +template void AddUnion(void*, int, int, cnrtQueue_t, void*, const void*, + const void*, const size_t*, const ptrdiff_t*, + const size_t*, const ptrdiff_t*, const size_t*, + const ptrdiff_t*, size_t, int, bool, bool); + +template void AddUnion(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, const size_t*, + const ptrdiff_t*, const size_t*, + const ptrdiff_t*, const size_t*, + const ptrdiff_t*, size_t, int, bool, bool); + +template void AddUnion(void*, int, int, cnrtQueue_t, void*, + const void*, const void*, const size_t*, + const ptrdiff_t*, const size_t*, + const ptrdiff_t*, const size_t*, + const ptrdiff_t*, size_t, int, bool, bool); + +} // namespace infini::ops diff --git a/tests/test_add.py b/tests/test_add.py index e2266c30d..a32b37d62 
100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -60,6 +60,11 @@ def test_add( "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`." ) + if device == "mlu" and (dtype in _UINT_DTYPES or dtype == torch.int16): + pytest.skip( + "The `torch.mlu` test cloning path does not support `int16`, `uint16`, `uint32`, or `uint64`." + ) + if implementation_index == 1 and dtype in _UINT_DTYPES: pytest.skip("ATen `add` does not support unsigned integer types")