23 changes: 21 additions & 2 deletions tests/pytorch/test_fusible_ops.py
@@ -1557,7 +1557,19 @@ def test_make_extra_output(

@pytest.mark.parametrize(
"activation",
("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
(
"gelu",
"geglu",
"qgelu",
"qgeglu",
"relu",
"reglu",
"glu",
"srelu",
"sreglu",
"silu",
"swiglu",
),
)
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@@ -1577,7 +1589,7 @@ def test_activation(

# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
in_shape[-1] *= 2

# Skip invalid configurations
@@ -1617,6 +1629,13 @@ def test_activation(
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "sigmoid":
y_ref = torch.nn.functional.sigmoid(x_ref)
Comment on lines +1632 to +1633

Member: Sigmoid is not an option in the test; is this leftover code?

Collaborator: It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.

elif activation == "glu":
x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
x = x.flip(-2) # PyTorch GLU swaps gate and linear unit
x = x.reshape(in_shape)
y_ref = torch.nn.functional.glu(x)
elif activation == "srelu":
y_ref = torch.nn.functional.relu(x_ref) ** 2
elif activation == "sreglu":
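A standalone sketch of the convention difference exercised by the flip trick above, in plain PyTorch. The helper name is made up for illustration and the shapes are arbitrary:

```python
import torch

def te_glu_reference(x: torch.Tensor) -> torch.Tensor:
    # Transformer Engine convention: sigmoid gate on the FIRST half of the
    # last dimension, linear unit on the second half.
    a, b = x.chunk(2, dim=-1)
    return torch.sigmoid(a) * b

x = torch.randn(4, 8)
# torch.nn.functional.glu gates the SECOND half, so swap the halves first,
# exactly as the test above does with flip(-2).
swapped = torch.cat(list(x.chunk(2, dim=-1))[::-1], dim=-1)
assert torch.allclose(te_glu_reference(x), torch.nn.functional.glu(swapped, dim=-1))
```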
2 changes: 2 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -168,6 +168,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources

list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu
cast/cast.cu
@@ -352,6 +353,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu)
endif()
24 changes: 24 additions & 0 deletions transformer_engine/common/activation/glu.cu
@@ -0,0 +1,24 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_glu);
using namespace transformer_engine;
Empty e = {};
gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}

void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dglu);
using namespace transformer_engine;
Empty e = {};
dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
stream);
}
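For reference, the math the kernel above encodes, written as a hedged pure-PyTorch sketch; the `[da, db]` concatenation order assumes the output gradient mirrors the input layout:

```python
import torch

def glu_reference(x: torch.Tensor) -> torch.Tensor:
    # Forward: GLU(a, b) = sigmoid(a) * b, with a = first half, b = second half.
    a, b = x.chunk(2, dim=-1)
    return torch.sigmoid(a) * b

def dglu_reference(grad: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Backward w.r.t. the [N, 2*H] input for an incoming [N, H] gradient.
    a, b = x.chunk(2, dim=-1)
    s = torch.sigmoid(a)
    da = grad * b * s * (1 - s)   # d/da of sigmoid(a) * b
    db = grad * s                 # d/db of sigmoid(a) * b
    return torch.cat([da, db], dim=-1)

# Cross-check against autograd.
x = torch.randn(3, 10, dtype=torch.float64, requires_grad=True)
g = torch.randn(3, 5, dtype=torch.float64)
(glu_reference(x) * g).sum().backward()
assert torch.allclose(x.grad, dglu_reference(g, x))
```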
27 changes: 27 additions & 0 deletions transformer_engine/common/include/transformer_engine/activation.h
@@ -31,6 +31,7 @@ extern "C" {
enum class NVTE_Activation_Type {
GELU,
GEGLU,
GLU,
SILU,
SWIGLU,
RELU,
@@ -262,6 +263,32 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);

/*! \brief Computes the GLU (Gated Linear Unit) activation of the input.
* GLU(a,b) = sigmoid(a) * b
* See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083)
* and "GLU Variants Improve Transformer" (arXiv:2002.05202).
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes sigmoid(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Computes the GLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);

/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -163,6 +163,11 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
* Activations
**************************************************************************************************/

/* GLU (sigmoid gate) */
py::object glu(const at::Tensor &input, py::handle quantizer);

py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/* GELU and variants*/
py::object gelu(const at::Tensor &input, py::handle quantizer);

8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -246,6 +246,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
}

py::object glu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_glu, nullptr>(input, quantizer, 2);
}

py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dglu, nullptr>(grad, input, quantizer);
}

py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
}
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -132,6 +132,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
/* GLU (sigmoid gate) */
m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"),
py::arg("quantizer"));
/* GELU and variants*/
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
@@ -158,6 +161,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
"SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
/* Backward of GLU */
m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
/* Backward of GELU and variants */
m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
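A rough sketch of exercising the new bindings directly from Python. The compiled-extension module name and the idea that passing None as the quantizer yields an unquantized output are assumptions based on how the other activation bindings are commonly called; treat this as illustrative only:

```python
import torch
import transformer_engine_torch as tex  # assumed name of the compiled extension module

x = torch.randn(16, 128, device="cuda")
y = tex.glu(x, None)        # forward: sigmoid gate on the first half -> shape (16, 64)
g = torch.randn_like(y)
dx = tex.dglu(g, x, None)   # backward w.r.t. the (16, 128) forward input
```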
16 changes: 14 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -98,6 +98,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"glu": (tex.glu, tex.dglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
@@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
return {
"gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
"geglu": (tex.geglu, tex.dgeglu, None),
"glu": (tex.glu, tex.dglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, tex.dbias_drelu),
@@ -136,6 +138,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
"glu": (tex.glu, tex.dglu, None),
"qgelu": (tex.qgelu, tex.dqgelu, None),
"qgeglu": (tex.qgeglu, tex.dqgeglu, None),
"relu": (tex.relu, tex.drelu, None),
@@ -1665,7 +1668,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : dict, default = None
Additional parameters for the activation function.
@@ -1884,7 +1887,15 @@ def __init__(
self.layer_norm_bias = None

# FC1 init
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
if self.activation in [
"geglu",
"glu",
"qgeglu",
"reglu",
"sreglu",
"swiglu",
"clamped_swiglu",
]:
fc1_output_features = 2 * self.size_per_partition
else:
fc1_output_features = self.size_per_partition
@@ -2308,6 +2319,7 @@ def _clamped_swiglu(x, limit, alpha):
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
"qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh")
* x.chunk(2, -1)[1],
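Usage-wise the change is a single string. A minimal sketch of selecting the new activation in LayerNormMLP, assuming the usual module signature; sizes and device are arbitrary, and as the `__init__` hunk above shows, FC1's output width is doubled internally for gated activations:

```python
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="glu",   # FC1 now produces 2 * ffn_hidden_size features
    device="cuda",
)
x = torch.randn(8, 1024, device="cuda")
y = mlp(x)              # shape (8, 1024)
```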
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -7,6 +7,7 @@
from .activation import (
GELU,
GEGLU,
GLU,
QGELU,
QGEGLU,
ReLU,
33 changes: 33 additions & 0 deletions transformer_engine/pytorch/ops/basic/activation.py
@@ -20,6 +20,7 @@
__all__ = [
"GELU",
"GEGLU",
"GLU",
"QGELU",
"QGEGLU",
"ReLU",
@@ -164,6 +165,38 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dgelu(*args, **kwargs)


class GLU(_ActivationOperation):
r"""Gated Linear Unit

The input tensor is split into chunks :math:`a` and :math:`b`
along the last dimension and the following is computed:

.. math::

\text{GLU}(a,b) = \sigma(a) * b

where :math:`\sigma` is the sigmoid function.

.. warning::

Transformer Engine's gated activations and PyTorch's GLU
activation follow opposite conventions for :math:`a` and
:math:`b`. Transformer Engine applies the gating function to
the first half of the input tensor, while PyTorch applies it to
the second half.

See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`__
and `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.

"""

def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.glu(*args, **kwargs)

def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dglu(*args, **kwargs)


class GEGLU(_ActivationOperation):
r"""Gaussian Error Gated Linear Unit

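And a minimal sketch of the op-fuser path, assuming GLU is re-exported under te.ops like the other basic activation ops and takes no required constructor arguments:

```python
import torch
import transformer_engine.pytorch as te

glu = te.ops.GLU()
x = torch.randn(16, 64, device="cuda")  # last dim must be even (gate half + linear half)
y = glu(x)                              # shape (16, 32): sigmoid(x[..., :32]) * x[..., 32:]
```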
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/transformer.py
@@ -184,7 +184,7 @@ class TransformerLayer(torch.nn.Module):
if set to ``False``, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : Optional[dict], default = None
Additional parameters for the activation function.
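For completeness, the same string selects GLU in the MLP block of TransformerLayer (a sketch with placeholder sizes):

```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation="glu",
)
```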