diff --git a/onnxruntime/core/optimizer/group_query_attention_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_fusion.cc index f6bfd29315c58..feb2a6fd92bde 100644 --- a/onnxruntime/core/optimizer/group_query_attention_fusion.cc +++ b/onnxruntime/core/optimizer/group_query_attention_fusion.cc @@ -248,6 +248,48 @@ static bool CheckIfAnyOfRequiredGQANodesDoesNotExist(Node* rotary_node_1, Node* return rotary_node_1 == nullptr || rotary_node_2 == nullptr || q_node == nullptr || k_node == nullptr || v_node == nullptr; } +static bool NodeArgExists(const NodeArg* node_arg) { + return node_arg != nullptr && node_arg->Exists(); +} + +static bool TryGetRotaryEmbeddingArgs(Node& rotary_node, + NodeArg*& cos_cache_arg, + NodeArg*& sin_cache_arg, + NodeArg*& position_ids_arg) { + if (rotary_node.OpType() != "RotaryEmbedding") { + return false; + } + + auto& input_defs = rotary_node.MutableInputDefs(); + if (rotary_node.Domain() == kMSDomain) { + // com.microsoft.RotaryEmbedding inputs: + // input, position_ids, cos_cache, sin_cache + if (input_defs.size() < 4 || !NodeArgExists(input_defs[2]) || !NodeArgExists(input_defs[3])) { + return false; + } + cos_cache_arg = input_defs[2]; + sin_cache_arg = input_defs[3]; + return true; + } + + if (rotary_node.Domain() == kOnnxDomain) { + // ONNX RotaryEmbedding inputs: + // X, cos_cache, sin_cache, optional position_ids + // If position_ids is omitted, ONNX RotaryEmbedding uses 3D per-batch caches, which are + // incompatible with GroupQueryAttention's 2D rotary cache inputs. + if (input_defs.size() < 4 || !NodeArgExists(input_defs[1]) || !NodeArgExists(input_defs[2]) || + !NodeArgExists(input_defs[3])) { + return false; + } + cos_cache_arg = input_defs[1]; + sin_cache_arg = input_defs[2]; + position_ids_arg = input_defs[3]; + return true; + } + + return false; +} + static void FusePreGQANodes(Graph& graph, Node* q_node, Node* k_node, Node* v_node, Node* rotary_node_1, Node* rotary_node_2, Node* new_node, NodeArg& new_node_output_arg) { graph_utils::MoveAllNodeInputEdges(graph, *q_node, *new_node); @@ -318,6 +360,9 @@ Status GroupQueryAttentionFusion::ApplyImpl( NodeArg* cos_cache_arg = nullptr; NodeArg* sin_cache_arg = nullptr; + NodeArg* position_ids_arg = nullptr; + bool position_ids_arg_set = false; + bool position_ids_arg_mismatch = false; NodeArg* past_key_values_key_arg = node.MutableInputDefs()[3]; NodeArg* past_key_values_value_arg = node.MutableInputDefs()[4]; NodeArg* seqlens_k = node.MutableInputDefs()[5]; @@ -334,7 +379,13 @@ Status GroupQueryAttentionFusion::ApplyImpl( for (auto pre_gqa_node = node.InputNodesBegin(); pre_gqa_node != node.InputNodesEnd(); ++pre_gqa_node) { Node& rotary_or_v_node = *graph.GetNode(pre_gqa_node->Index()); - if (rotary_or_v_node.OpType() == "RotaryEmbedding") { + NodeArg* rotary_cos_cache_arg = nullptr; + NodeArg* rotary_sin_cache_arg = nullptr; + NodeArg* rotary_position_ids_arg = nullptr; + if (TryGetRotaryEmbeddingArgs(rotary_or_v_node, + rotary_cos_cache_arg, + rotary_sin_cache_arg, + rotary_position_ids_arg)) { if (!rotary_node_1) { rotary_node_1 = &rotary_or_v_node; } else { @@ -358,18 +409,27 @@ Status GroupQueryAttentionFusion::ApplyImpl( } if (cos_cache_arg == nullptr) { - cos_cache_arg = rotary_or_v_node.MutableInputDefs()[2]; + cos_cache_arg = rotary_cos_cache_arg; } if (sin_cache_arg == nullptr) { - sin_cache_arg = rotary_or_v_node.MutableInputDefs()[3]; + sin_cache_arg = rotary_sin_cache_arg; + } + + if (!position_ids_arg_set) { + position_ids_arg = rotary_position_ids_arg; + position_ids_arg_set = true; + } else if (position_ids_arg != rotary_position_ids_arg) { + position_ids_arg_mismatch = true; } } else if (rotary_or_v_node.OpType() == "MatMulNBits" || rotary_or_v_node.OpType() == "MatMul") { v_node = &rotary_or_v_node; } } - if (CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node)) { + if (position_ids_arg_mismatch || + CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node) || + cos_cache_arg == nullptr || sin_cache_arg == nullptr) { // Some of the required pre-GQA nodes required for fusion were not retrieved, // this can be expected if the model has extra nodes in between MatMuls and rotary embeddings. continue; @@ -493,7 +553,7 @@ Status GroupQueryAttentionFusion::ApplyImpl( std::string empty_name; auto& empty_node_arg = graph.GetOrCreateNodeArg(empty_name, nullptr); - const std::array gqa_input_defs{ + std::vector gqa_input_defs{ &matmul_or_nbits_output, &empty_node_arg, &empty_node_arg, @@ -503,10 +563,16 @@ Status GroupQueryAttentionFusion::ApplyImpl( total_seq_len, cos_cache_arg, sin_cache_arg}; + if (position_ids_arg != nullptr) { + gqa_input_defs.push_back(position_ids_arg); + } auto& gqa_input_args = node.MutableInputArgsCount(); gqa_input_args[7] = 1; gqa_input_args[8] = 1; + if (position_ids_arg != nullptr) { + gqa_input_args[9] = 1; + } // Switch GQA input defs from unfused into the fused form. auto& gqa_node_input_defs = node.MutableInputDefs(); diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index d5aaf0bb2d2ee..8c69e4b2c74c9 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -607,6 +607,127 @@ static void TestGQAFusion(const std::basic_string& file_path, int mat ASSERT_TRUE(op_to_count["com.microsoft.GroupQueryAttention"] == 1); } +static void BuildOnnxRotaryEmbeddingGQAFusionGraph(ModelTestBuilder& builder, bool include_position_ids) { + constexpr int64_t batch_size = 1; + constexpr int64_t sequence_length = 2; + constexpr int64_t input_hidden_size = 8; + constexpr int64_t num_heads = 2; + constexpr int64_t kv_num_heads = 1; + constexpr int64_t head_size = 16; + constexpr int64_t q_hidden_size = num_heads * head_size; + constexpr int64_t kv_hidden_size = kv_num_heads * head_size; + constexpr int64_t max_sequence_length = 8; + constexpr int64_t half_rotary_dim = head_size / 2; + + auto make_weight = [&builder](int64_t rows, int64_t cols, float value) { + return builder.MakeInitializer( + {rows, cols}, std::vector(static_cast(rows * cols), MLFloat16(value))); + }; + + NodeArg* input = builder.MakeInput({{batch_size, sequence_length, input_hidden_size}}); + NodeArg* q_weight = make_weight(input_hidden_size, q_hidden_size, 0.5f); + NodeArg* k_weight = make_weight(input_hidden_size, kv_hidden_size, 0.25f); + NodeArg* v_weight = make_weight(input_hidden_size, kv_hidden_size, 0.125f); + + NodeArg* q_matmul_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, q_hidden_size}); + NodeArg* k_matmul_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, kv_hidden_size}); + NodeArg* v_matmul_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, kv_hidden_size}); + builder.AddNode("MatMul", {input, q_weight}, {q_matmul_out}); + builder.AddNode("MatMul", {input, k_weight}, {k_matmul_out}); + builder.AddNode("MatMul", {input, v_weight}, {v_matmul_out}); + + const std::vector cache_shape = include_position_ids + ? std::vector{max_sequence_length, half_rotary_dim} + : std::vector{batch_size, sequence_length, half_rotary_dim}; + NodeArg* cos_cache = builder.MakeInput(cache_shape); + NodeArg* sin_cache = builder.MakeInput(cache_shape); + NodeArg* position_ids = include_position_ids + ? builder.MakeInput({{batch_size, sequence_length}}) + : nullptr; + + NodeArg* q_rotary_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, q_hidden_size}); + NodeArg* k_rotary_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, kv_hidden_size}); + + std::vector q_rotary_inputs{q_matmul_out, cos_cache, sin_cache}; + std::vector k_rotary_inputs{k_matmul_out, cos_cache, sin_cache}; + if (position_ids != nullptr) { + q_rotary_inputs.push_back(position_ids); + k_rotary_inputs.push_back(position_ids); + } + + Node& q_rotary = builder.AddNode("RotaryEmbedding", q_rotary_inputs, {q_rotary_out}, kOnnxDomain); + q_rotary.AddAttribute("num_heads", num_heads); + Node& k_rotary = builder.AddNode("RotaryEmbedding", k_rotary_inputs, {k_rotary_out}, kOnnxDomain); + k_rotary.AddAttribute("num_heads", kv_num_heads); + + NodeArg* past_key = + builder.MakeInput({{batch_size, kv_num_heads, max_sequence_length, head_size}}); + NodeArg* past_value = + builder.MakeInput({{batch_size, kv_num_heads, max_sequence_length, head_size}}); + NodeArg* seqlens_k = builder.MakeInput({{batch_size}}); + NodeArg* total_sequence_length = builder.MakeInput({{1}}); + NodeArg* gqa_output = + builder.MakeOutput(std::vector{batch_size, sequence_length, q_hidden_size}); + + Node& gqa = builder.AddNode("GroupQueryAttention", + {q_rotary_out, k_rotary_out, v_matmul_out, past_key, past_value, + seqlens_k, total_sequence_length}, + {gqa_output}, + kMSDomain); + gqa.AddAttribute("num_heads", num_heads); + gqa.AddAttribute("kv_num_heads", kv_num_heads); +} + +static Status CheckOnnxRotaryEmbeddingGQAFused(Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 0); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 1); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() != "GroupQueryAttention") { + continue; + } + + TEST_RETURN_IF_NOT(node.InputDefs().size() == 10); + TEST_RETURN_IF_NOT(node.InputDefs()[7] != nullptr && node.InputDefs()[7]->Exists()); + TEST_RETURN_IF_NOT(node.InputDefs()[8] != nullptr && node.InputDefs()[8]->Exists()); + TEST_RETURN_IF_NOT(node.InputDefs()[9] != nullptr && node.InputDefs()[9]->Exists()); + + const auto& attrs = node.GetAttributes(); + auto do_rotary_attr = attrs.find("do_rotary"); + TEST_RETURN_IF_NOT(do_rotary_attr != attrs.end()); + TEST_RETURN_IF_NOT(do_rotary_attr->second.i() == 1); + } + + return Status::OK(); +} + +static Status CheckOnnxRotaryEmbeddingGQANotFused(Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 2); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 3); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() != "GroupQueryAttention") { + continue; + } + + TEST_RETURN_IF_NOT(node.InputDefs().size() == 7); + const auto& attrs = node.GetAttributes(); + auto do_rotary_attr = attrs.find("do_rotary"); + TEST_RETURN_IF_NOT(do_rotary_attr == attrs.end() || do_rotary_attr->second.i() == 0); + } + + return Status::OK(); +} + static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, int skip_ln_count, int cast_count, logging::Logger* logger) { std::shared_ptr p_model; @@ -796,6 +917,28 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) { TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_different_head_sizes.onnx", 1, 0, logger_.get()); } +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, nullptr, + CheckOnnxRotaryEmbeddingGQAFused)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingNoPositionIdsTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, false); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, nullptr, + CheckOnnxRotaryEmbeddingGQANotFused)); +} + TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) { TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get()); TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get());