Add sigmoid GLU #2656

Open

singleheart wants to merge 7 commits into NVIDIA:main from singleheart:feature/add-sigmoid-glu

Conversation

@singleheart

Description

Add the original GLU (Gated Linear Unit) activation function as described in
Dauphin et al. (2017) and referenced in
Shazeer (2020), "GLU Variants Improve Transformer".

GLU is defined as:

$$\text{GLU}(a, b) = \sigma(a) \odot b$$

where $\sigma$ is the sigmoid function and the input is split into two halves $a$ and $b$ along the last dimension.
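For reference, a minimal pure-PyTorch sketch of this definition (just the math above, not the CUDA kernel added by this PR; the function name is illustrative):

    import torch

    def glu_reference(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into halves a and b, then gate b with sigmoid(a):
        # GLU(a, b) = sigmoid(a) * b
        a, b = x.chunk(2, dim=-1)
        return torch.sigmoid(a) * b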

Transformer Engine already supports several GLU variants (GEGLU, ReGLU, SReGLU, SwiGLU, etc.)
but was missing the original sigmoid-gated GLU. This PR fills that gap so that users can
simply pass activation="glu" to LayerNormMLP or TransformerLayer.
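As an illustrative usage sketch (hypothetical sizes, assuming the existing LayerNormMLP constructor arguments):

    import transformer_engine.pytorch as te

    # Hypothetical sizes; activation="glu" selects the new sigmoid-gated GLU.
    mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="glu")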

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • transformer_engine/common/activation/glu.cu (new file): CUDA kernels nvte_glu and nvte_dglu using existing sigmoid/dsigmoid primitives from math.h and the gated_act_fn/dgated_act_fn templates.
  • transformer_engine/common/include/transformer_engine/activation.h: Added GLU to NVTE_Activation_Type enum; declared nvte_glu and nvte_dglu with doxygen documentation.
  • transformer_engine/common/CMakeLists.txt: Registered activation/glu.cu in both arch_specific_sources and fast_math build lists.
  • transformer_engine/pytorch/csrc/extensions/activation.cpp: Added glu() and dglu() C++ wrapper functions.
  • transformer_engine/pytorch/csrc/extensions.h: Declared glu and dglu.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Exposed tex.glu and tex.dglu to Python.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Added "glu" to _get_act_func_supported_list (all 3 recipe branches), FC1 output-doubling condition, ONNX export activation_map, and docstring.
  • transformer_engine/pytorch/ops/basic/activation.py: Added GLU operation class with forward (tex.glu) and backward (tex.dglu); a reference sketch of the forward/backward math is shown after this list.
  • transformer_engine/pytorch/ops/basic/__init__.py: Exported GLU.
  • transformer_engine/pytorch/transformer.py: Updated TransformerLayer docstring to list 'glu' as a supported activation.
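For the ops entry above, a pure-PyTorch autograd sketch of the forward/backward math (a reference for the gradients only, not the tex.glu/tex.dglu kernels; the class name is illustrative):

    import torch

    class GLURef(torch.autograd.Function):
        """Reference sigmoid-gated GLU: forward is sigmoid(a) * b, with matching gradients."""

        @staticmethod
        def forward(ctx, x):
            a, b = x.chunk(2, dim=-1)
            ctx.save_for_backward(a, b)
            return torch.sigmoid(a) * b

        @staticmethod
        def backward(ctx, grad_out):
            a, b = ctx.saved_tensors
            s = torch.sigmoid(a)
            grad_a = grad_out * b * s * (1 - s)  # d/da [sigmoid(a) * b]
            grad_b = grad_out * s                # d/db [sigmoid(a) * b]
            return torch.cat([grad_a, grad_b], dim=-1)

As with the other gated variants, FC1 produces twice as many output features as the ungated case so the activation can split its input into the a and b halves.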

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

Adds the original GLU (Gated Linear Unit) activation with sigmoid gating to match the paper definition.

Key changes

  • CUDA kernels in activation/glu.cu use existing sigmoid primitives with gated activation templates
  • Python API exposed through PyBind, C++ wrappers, and ops layer following existing patterns
  • Integration with LayerNormMLP and TransformerLayer modules
  • Test coverage added with PyTorch reference implementation

Issues found

  • Critical: test will fail with KeyError - make_op dictionary at line 1655 is missing the glu entry mapping to te_ops.GLU

Confidence Score: 2/5

  • PR cannot merge - test will fail at runtime with missing dictionary key
  • implementation is well-structured and follows existing patterns correctly, but test has a critical bug that will cause immediate failure when running the glu test case
  • tests/pytorch/test_fusible_ops.py requires fix to make_op dictionary before tests can pass

Important Files Changed

Filename | Overview
tests/pytorch/test_fusible_ops.py | test reference code added, but missing glu in make_op dict causing test failure
transformer_engine/common/activation/glu.cu | correctly implemented GLU forward/backward using existing sigmoid templates
transformer_engine/pytorch/ops/basic/activation.py | added GLU operation class with correct forward/backward implementations

Sequence Diagram

sequenceDiagram
    participant User
    participant Python as Python Layer<br/>(layernorm_mlp.py)
    participant OpsAPI as Ops API<br/>(ops/basic/activation.py)
    participant PyBind as PyBind<br/>(pybind.cpp)
    participant CPP as C++ Wrapper<br/>(activation.cpp)
    participant CUDA as CUDA Kernel<br/>(glu.cu)
    
    User->>Python: LayerNormMLP(activation="glu")
    Python->>Python: Check activation in supported list
    Python->>Python: Set fc1_output_features = 2 * hidden_size
    
    Note over User,CUDA: Forward Pass
    User->>OpsAPI: GLU.forward(input)
    OpsAPI->>PyBind: tex.glu(input, quantizer)
    PyBind->>CPP: glu(input, quantizer)
    CPP->>CPP: Create output tensor (shape_divisor=2)
    CPP->>CUDA: nvte_glu(input, output, stream)
    CUDA->>CUDA: gated_act_fn<sigmoid>(input, output)
    CUDA->>CUDA: Split input into a, b
    CUDA->>CUDA: Compute sigmoid(a) * b
    CUDA-->>CPP: output
    CPP-->>PyBind: output tensor
    PyBind-->>OpsAPI: output tensor
    OpsAPI-->>User: result
    
    Note over User,CUDA: Backward Pass
    User->>OpsAPI: GLU.backward(grad_output)
    OpsAPI->>PyBind: tex.dglu(grad, input, quantizer)
    PyBind->>CPP: dglu(grad, input, quantizer)
    CPP->>CUDA: nvte_dglu(grad, input, output, stream)
    CUDA->>CUDA: dgated_act_fn<sigmoid, dsigmoid>(grad, input, output)
    CUDA->>CUDA: Compute gradient w.r.t. a and b
    CUDA-->>CPP: grad_input
    CPP-->>PyBind: grad_input tensor
    PyBind-->>OpsAPI: grad_input tensor
    OpsAPI-->>User: gradient

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 3 comments

@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Additional Comments (3)

tests/pytorch/test_sanity.py
"glu" not added to test list - new activation won't be tested

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
    "clamped_swiglu",
]

tests/pytorch/test_numerics.py
"glu" missing from test list

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
]

tests/pytorch/test_fusible_ops.py
"glu" missing from test parameters - add glu to tuple and handle in test logic below (around line 1631)

singleheart and others added 2 commits February 6, 2026 20:15
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
for more information, see https://pre-commit.ci

Signed-off-by: Kim, Jin <jinn.kim@sk.com>
@singleheart singleheart force-pushed the feature/add-sigmoid-glu branch from 8adb19a to f7c6385 on February 6, 2026 11:15
Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

timmoon10 and others added 2 commits February 6, 2026 18:40
Signed-off-by: Tim Moon <tmoon@nvidia.com>

timmoon10 previously approved these changes Feb 6, 2026
Collaborator

@timmoon10 timmoon10 left a comment
LGTM, pending CI

If we're adding GLU, it would also be natural to have sigmoid. Not critical though.

Contributor

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 4 comments

@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Additional Comments (3)

transformer_engine/jax/csrc/extensions/activation.cpp
Missing GLU dispatch

NVTE_Activation_Type::GLU was added to the shared activation enum, but the JAX FFI dispatch here doesn’t handle it. Calling act_lu/dact_lu with GLU will hit the default branch and raise NVTE_ERROR("Unsupported ActivationEnum").

Fix: add explicit case NVTE_Activation_Type::GLU: branches that call nvte_glu(...) and nvte_dglu(...) (the backward path is in the DActLuDBiasQuantizeFFI switch later in this file as well).


transformer_engine/jax/csrc/extensions/pybind.cpp
GLU not exported to Python

The pybind enum export for NVTE_Activation_Type doesn’t include GLU, so JAX Python can’t reference NVTE_Activation_Type.GLU even though it exists in the C++ header. This will block using GLU from Python-side activation selection.

Fix: add .value("GLU", NVTE_Activation_Type::GLU) to this enum binding.


transformer_engine/jax/cpp_extensions/activation.py
ActivationEnum missing GLU

ActivationEnum doesn’t have an entry mapping any activation tuple to NVTE_Activation_Type.GLU, so GLU can’t be selected through the existing string/tuple activation routing in JAX.

Fix: add the appropriate key (consistent with the existing conventions, e.g. how GEGLU/REGLU/SWIGLU are represented) mapping to NVTE_Activation_Type.GLU.

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 2 comments

@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
glu is missing from the make_op dictionary, causing a KeyError when the test runs

        make_op = dict(
            gelu=te_ops.GELU,
            geglu=te_ops.GEGLU,
            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
            qgeglu=te_ops.QGEGLU,
            relu=te_ops.ReLU,
            reglu=te_ops.ReGLU,
            srelu=te_ops.SReLU,
            sreglu=te_ops.SReGLU,
            silu=te_ops.SiLU,
            swiglu=te_ops.SwiGLU,
        )[activation]

Comment on lines +1632 to +1633

    elif activation == "sigmoid":
        y_ref = torch.nn.functional.sigmoid(x_ref)

Member

Sigmoid is not an option in the test, is this leftover code?

Collaborator

It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.

@ptrendx
Member

ptrendx commented Feb 7, 2026

Could we also add a test for the regular LayerNormMLP module in test_numerics.py?
Also @jberchtold-nvidia FYI - we will also want to expose it in JAX.
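A minimal smoke-test sketch of what such a module-level test could look like (hypothetical test name and sizes; the actual test_numerics.py tests compare against a reference implementation rather than only checking shapes):

    import torch
    import transformer_engine.pytorch as te

    def test_layernorm_mlp_glu_smoke():
        # Hypothetical smoke test: forward/backward through LayerNormMLP with activation="glu".
        mlp = te.LayerNormMLP(hidden_size=128, ffn_hidden_size=256, activation="glu")
        x = torch.randn(32, 128, device="cuda", requires_grad=True)
        y = mlp(x)
        y.sum().backward()
        assert y.shape == x.shape
        assert torch.isfinite(y).all()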

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Contributor

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 2 comments

@greptile-apps
Contributor

greptile-apps bot commented Feb 7, 2026

Additional Comments (2)

transformer_engine/jax/csrc/extensions/pybind.cpp
Missing GLU enum export

NVTE_Activation_Type gained a new GLU value in transformer_engine/common/include/transformer_engine/activation.h, but the JAX pybind enum here doesn’t export it. This will break JAX-side code that expects the enum values to match the core header (e.g., passing NVTE_Activation_Type.GLU from Python will fail / be unavailable).

Also update the JAX activation dispatch switch to handle NVTE_Activation_Type::GLU (see transformer_engine/jax/csrc/extensions/activation.cpp:105-143).


transformer_engine/jax/csrc/extensions/activation.cpp
GLU not handled in switch

The activation dispatch switch (act_type) doesn’t include a case NVTE_Activation_Type::GLU, even though GLU was added to the shared NVTE_Activation_Type enum. If JAX passes the new enum value, this currently falls into default: and raises NVTE_ERROR("Unsupported ActivationEnum").

Add a GLU case that calls nvte_glu(...), and ensure the backward/quantized paths (other switches later in this file) are also updated consistently.

@timmoon10
Collaborator

/te-ci

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment

@greptile-apps
Contributor

greptile-apps bot commented Feb 7, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
missing glu in test's make_op dictionary - test will fail with KeyError when running the glu case

        make_op = dict(
            gelu=te_ops.GELU,
            geglu=te_ops.GEGLU,
            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
            qgeglu=te_ops.QGEGLU,
            relu=te_ops.ReLU,
            reglu=te_ops.ReGLU,
            srelu=te_ops.SReLU,
            sreglu=te_ops.SReGLU,
            silu=te_ops.SiLU,
            swiglu=te_ops.SwiGLU,
        )[activation]
