Skip to content

Conversation

@wj-laskowski
Copy link
Contributor

@wj-laskowski wj-laskowski commented Jan 16, 2026

Proposed changes

Added bias bnorm clamp operation for WMMA conv fwd large tensor (FP16/BF16 data type and NHWGCxGKYXC layout).

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

Following operations are added for FP16/BF16 data type and NHWGCxGKYXC layout.
- grouped_conv2d_fwd_bias_bnorm_clamp
- grouped_conv3d_fwd_bias_bnorm_clamp
@wj-laskowski wj-laskowski force-pushed the streamhpc/grouped-conv-fwd-wmma-large-tensor-bias_bnorm_clamp branch from 7b0341d to b5c541f Compare January 19, 2026 09:35
@wj-laskowski wj-laskowski changed the title Streamhpc/grouped conv fwd wmma large tensor bias bnorm clamp WMMA grouped conv fwd large tensor bias bnorm clamp Jan 19, 2026
@wj-laskowski wj-laskowski marked this pull request as ready for review January 19, 2026 12:58

gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
auto* gemm_args = &gemm_desc_kernel_args_.At(valid_gemms_count_);
new(gemm_args) GemmArgs{p_as_grid,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice use of operator new here. The basic idea is good but this is difficult to maintain, as operator new with placement-args comes with a caveat: you have to manually call the dtor.
In our case (I think that) the implicitly defaulted dtor doesn't do anything useful. -I haven't checked that- But the implicitly deleted copy assignment operator that made you use this approach is concerning.
My point is basically that: if a future dev changes something that requires the GemmArg dtor to be called, they could very easily oversee this.
My proposal is to define an Emplace function in the Array struct to handle this case :

    template<typename... Args>
    auto Emplace(ck::index_t i, Args&&... args) -> std::enable_if_t<std::is_nothrow_constructible_v<TData, Args...>>
    {
        static_assert(i >= 0 && i < NSize);
        mData[i].~TData();
        new(mData + i) TData(std::forward<Args>(args)...);
    }

On another note... I see this way of initializing structs quite often and I can boldly say that I really don't like it. In C++17 we can already use designated initializers given that they are all included. This makes code more explicit, more robust and safer. So I'd change this section of the code to :

                gemm_desc_kernel_args_.Emplace(valid_gemms_count_, GemmArgs{.a_ptrs_ = p_as_grid,
                                                                                    .b_ptrs_ = p_bs_grid,
                                                                                    .ds_ptrs_ = p_ds_grid,
                                                                                    .e_ptr_ = p_e_grid,
                                                                                    .a_element_op_ = a_element_op_,
                                                                                    .b_element_op_ = b_element_op_,
                                                                                    .cde_element_op_ = cde_element_op_,
                                                                                    .M_ = gemm_m,
                                                                                    .N_ = gemm_n,
                                                                                    .a_grid_desc_ = a_grid_desc,
                                                                                    .b_grid_desc_ = b_grid_desc,
                                                                                    .ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                                                                                        ds_desc_mblock_mperblock_nblock_nperblock,
                                                                                    .e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                                                                                        e_desc_mblock_mperblock_nblock_nperblock,
                                                                                    .BlockStart_ = BlockStart,
                                                                                    .BlockEnd_ = BlockEnd});

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants