Validate EmbedLayerNormalization indices and use 64-bit offsets in CUDA BERT LayerNorm/SkipLayerNorm #29257
Open
titaiwangms wants to merge 1 commit into
…DA BERT LayerNorm/SkipLayerNorm

Two related changes to the CUDA BERT normalization path:

(1) EmbedLayerNormalization input validation. The CUDA kernel validates the word/position/segment ids against their embedding-table row counts on the device and returns a clear error instead of indexing past the tables (no silent clamp), mirroring the CPU implementation. CheckInputs requires the position_embedding table to have at least sequence_length rows when position_ids is not provided. The host error-flag readback is skipped while a CUDA graph is being captured so graph capture remains supported. Read-side offset arithmetic uses int64.

(2) 64-bit write-element offsets. The global write-element offsets in the CUDA BERT LayerNorm/SkipLayerNorm/EmbedLayerNorm write path are widened to 64-bit so tensors where batch*seq*hidden exceeds 2^31 index correctly; the gamma/beta indices (bounded by hidden_size) stay 32-bit. Comments document why each widened site must remain 64-bit.

No behavior change at normal sizes; existing LayerNorm/SkipLayerNorm numeric tests are unchanged. Adds EmbedLayerNorm expect-failure tests for the validated cases.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
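The validation rule in (1) can be illustrated with a small host-side C++ sketch. This is not the PR's CUDA kernel: the names IdError, InRange, and ValidateIds are hypothetical, ids are assumed to be 32-bit integers in flat row-major (batch, sequence) order, and an empty position_ids vector stands in for "position_ids not provided", in which case the position is taken to be the token's index within its sequence. The point it shows is the "report, do not clamp" behavior: an out-of-range id produces an error code rather than being silently adjusted.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical error codes for illustration; not ONNX Runtime identifiers.
enum class IdError {
  kOk = 0,
  kWordIdOutOfRange,
  kPositionIdOutOfRange,
  kSegmentIdOutOfRange
};

// An id is valid only if it addresses an existing row of its table.
inline bool InRange(int64_t id, int64_t rows) { return id >= 0 && id < rows; }

// Returns the first violation found and never clamps an id.
// Assumptions: sequence_length > 0; position_ids and segment_ids are either
// empty or have the same length as word_ids.
inline IdError ValidateIds(const std::vector<int32_t>& word_ids,
                           const std::vector<int32_t>& position_ids,
                           const std::vector<int32_t>& segment_ids,
                           int64_t sequence_length,
                           int64_t word_rows,
                           int64_t position_rows,
                           int64_t segment_rows) {
  const int64_t n = static_cast<int64_t>(word_ids.size());
  for (int64_t i = 0; i < n; ++i) {
    const std::size_t k = static_cast<std::size_t>(i);
    if (!InRange(word_ids[k], word_rows)) {
      return IdError::kWordIdOutOfRange;
    }
    // Without explicit position_ids, the position is the index in the sequence,
    // which is why the table needs at least sequence_length rows.
    const int64_t pos = position_ids.empty() ? (i % sequence_length)
                                             : static_cast<int64_t>(position_ids[k]);
    if (!InRange(pos, position_rows)) {
      return IdError::kPositionIdOutOfRange;
    }
    if (!segment_ids.empty() && !InRange(segment_ids[k], segment_rows)) {
      return IdError::kSegmentIdOutOfRange;
    }
  }
  return IdError::kOk;
}
```

In the actual change the equivalent check runs on the device and the result is read back on the host, except during CUDA graph capture; the sketch only models the range test itself.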
Description
This change has two related parts in the CUDA BERT embedding/layer-norm path:

(1) EmbedLayerNormalization index validation (CUDA). The CUDA kernel now validates the word, position, and segment ids against their embedding-table row counts before using them to index the tables, mirroring the existing CPU path (device-side, early return, no silent clamp). CheckInputs additionally requires position_embedding to have at least sequence_length rows when position_ids is not supplied. The device error-flag host readback is skipped while a CUDA graph is being captured (cudaStreamIsCapturing) so CUDA-graph capture remains supported.

(2) 64-bit write-element offsets in CUDA BERT LayerNorm / SkipLayerNorm / EmbedLayerNorm. The global output write-element offsets in the LayerNorm, SimplifiedLayerNorm, LayerNormSmall, and SimplifiedLayerNormSmall device helpers (and their EmbedLayerNorm / SkipLayerNorm call sites) are widened to 64-bit so that large tensors (batch * seq * hidden > 2^31) index the output correctly. The gamma/beta indices remain 32-bit (bounded by the hidden dimension).

Changes

- Validate input_ids, position_ids, and segment_ids values against the corresponding embedding-table row counts in the CUDA kernel.
- Require position_embedding rows >= sequence_length in CheckInputs when position_ids is absent.

Motivation

Improves input validation and error diagnostics for malformed EmbedLayerNormalization inputs (aligning CUDA with CPU), and ensures correct indexing for large tensors. No behavior change for valid, in-range inputs.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
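To make the large-tensor case concrete, here is a minimal host-side C++ sketch of the offset arithmetic. It is an illustration under assumptions, not code from the PR: ElementOffset and FitsInInt32 are hypothetical helpers, and the output is assumed to be laid out row-major with one row of hidden_size elements per token. Widening the row index before the multiply keeps the product in 64 bits, while the column index (the gamma/beta index) is always below hidden_size and so stays within 32 bits.

```cpp
#include <cstdint>
#include <limits>

// Global element offset of (row, col) in a row-major [rows, hidden_size] tensor.
// The cast happens before the multiply so the product is computed in 64 bits.
constexpr int64_t ElementOffset(int32_t row, int32_t hidden_size, int32_t col) {
  return static_cast<int64_t>(row) * hidden_size + col;
}

// True if the value can be represented as a signed 32-bit integer.
constexpr bool FitsInInt32(int64_t v) {
  return v >= std::numeric_limits<int32_t>::min() &&
         v <= std::numeric_limits<int32_t>::max();
}

// Example: hidden_size = 1024. The last element of row 2097151 is the largest
// offset that still fits in int32; the first element of the next row does not.
static_assert(FitsInInt32(ElementOffset(2097151, 1024, 1023)),
              "offset 2^31 - 1 fits in int32");
static_assert(!FitsInInt32(ElementOffset(2097152, 1024, 0)),
              "offset 2^31 needs a 64-bit type");
```

With these example sizes, any tensor with more than 2,097,152 token rows of 1024 elements has output offsets that a 32-bit signed index cannot represent, which is the situation the widened write offsets are meant to cover.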