Skip to content

Support graph surrogate model execution.#193

Open
YohannDudouit wants to merge 2 commits intoyohann/graph-executefrom
yohann/graph-surrogate-execute
Open

Support graph surrogate model execution.#193
YohannDudouit wants to merge 2 commits intoyohann/graph-executefrom
yohann/graph-surrogate-execute

Conversation

@YohannDudouit
Copy link
Copy Markdown
Collaborator

This pull request adds initial support for graph surrogate model execution in AMS.

More specifically, it introduces:

  • graph-native execution paths for homogeneous and heterogeneous graph inputs,
  • the associated wrapper and callback plumbing,
  • and simple dummy graph surrogate models used to validate the end-to-end infrastructure.

This PR focuses on establishing the core execution path for graph-based surrogates. The current implementation is intentionally minimal and is tested using simple graph surrogate models rather than full GNN/MGN applications.

- Add tests for homogeneous and heterogenous graph surrogate models.
@YohannDudouit YohannDudouit self-assigned this Apr 24, 2026
Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cpp-linter Review

Used clang-format v18.1.8

Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp
index c9cf07a..e9cb326 100644
--- a/src/AMSlib/ml/surrogate.hpp
+++ b/src/AMSlib/ml/surrogate.hpp
@@ -29 +29,2 @@
-namespace ams {
+namespace ams
+{
@@ -37 +38 @@ bool tryGraphSurrogate(AMSWorkflow*,
-}
+}  // namespace ams
diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp
index 506fdc0..3113952 100644
--- a/src/AMSlib/wf/interface.cpp
+++ b/src/AMSlib/wf/interface.cpp
@@ -304 +304,2 @@ amsEdgeStoresToTorchDict(
-    out.insert(ams::edgeTypeToString(edge_type), amsTensorMapToTorchDict(store));
+    out.insert(ams::edgeTypeToString(edge_type),
+               amsTensorMapToTorchDict(store));
@@ -371 +372,2 @@ void callApplication(ams::HeterogeneousGraphDomainFn CallBack,
-namespace ams {
+namespace ams
+{
diff --git a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
index b23b55d..24249d9 100644
--- a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
@@ -10,2 +10,4 @@ using namespace ams;
-static const char* HOMOGENEOUS_GRAPH_MODEL_PATH = "../models/homogeneous_graph.pt";
-static const char* HETEROGENEOUS_GRAPH_MODEL_PATH = "../models/heterogeneous_graph.pt";
+static const char* HOMOGENEOUS_GRAPH_MODEL_PATH =
+    "../models/homogeneous_graph.pt";
+static const char* HETEROGENEOUS_GRAPH_MODEL_PATH =
+    "../models/heterogeneous_graph.pt";
@@ -19,2 +21,4 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph surrogate execution",
-  auto model = AMSRegisterAbstractModel("test_homo_surrogate", 0.5,
-                                        HOMOGENEOUS_GRAPH_MODEL_PATH, false);
+  auto model = AMSRegisterAbstractModel("test_homo_surrogate",
+                                        0.5,
+                                        HOMOGENEOUS_GRAPH_MODEL_PATH,
+                                        false);
@@ -90,2 +94,4 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution",
-  auto model = AMSRegisterAbstractModel("test_hetero_surrogate", 0.5,
-                                        HETEROGENEOUS_GRAPH_MODEL_PATH, false);
+  auto model = AMSRegisterAbstractModel("test_hetero_surrogate",
+                                        0.5,
+                                        HETEROGENEOUS_GRAPH_MODEL_PATH,
+                                        false);
@@ -140,26 +146,26 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution",
-  HeterogeneousGraphDomainFn callback =
-      [&](const AMSHeterogeneousGraph& g, SmallVector<AMSTensor>& outputs) {
-        callback_invoked = true;
-
-        // Verify graph structure
-        CATCH_REQUIRE(g.containsNodeStore("node"));
-        const auto* node_store = g.findNodeStore("node");
-        CATCH_REQUIRE(node_store != nullptr);
-        CATCH_REQUIRE(containsTensor(*node_store, "x"));
-
-        // Create output tensor
-        AMSTensor::IntDimType out_shape[] = {10, 8};
-        AMSTensor::IntDimType out_strides[] = {8, 1};
-        auto out_tensor = AMSTensor::create<float>(
-            ams::ArrayRef<AMSTensor::IntDimType>(out_shape, 2),
-            ams::ArrayRef<AMSTensor::IntDimType>(out_strides, 2),
-            AMSResourceType::AMS_HOST);
-
-        float* out_data = out_tensor.data<float>();
-        for (int i = 0; i < 80; ++i) {
-          out_data[i] = static_cast<float>(i);
-        }
-
-        outputs.clear();
-        outputs.push_back(std::move(out_tensor));
-      };
+  HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g,
+                                            SmallVector<AMSTensor>& outputs) {
+    callback_invoked = true;
+
+    // Verify graph structure
+    CATCH_REQUIRE(g.containsNodeStore("node"));
+    const auto* node_store = g.findNodeStore("node");
+    CATCH_REQUIRE(node_store != nullptr);
+    CATCH_REQUIRE(containsTensor(*node_store, "x"));
+
+    // Create output tensor
+    AMSTensor::IntDimType out_shape[] = {10, 8};
+    AMSTensor::IntDimType out_strides[] = {8, 1};
+    auto out_tensor = AMSTensor::create<float>(
+        ams::ArrayRef<AMSTensor::IntDimType>(out_shape, 2),
+        ams::ArrayRef<AMSTensor::IntDimType>(out_strides, 2),
+        AMSResourceType::AMS_HOST);
+
+    float* out_data = out_tensor.data<float>();
+    for (int i = 0; i < 80; ++i) {
+      out_data[i] = static_cast<float>(i);
+    }
+
+    outputs.clear();
+    outputs.push_back(std::move(out_tensor));
+  };
@@ -184,2 +190 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
-  auto model = AMSRegisterAbstractModel(
-      "test_no_model", 0.5, "", false);
+  auto model = AMSRegisterAbstractModel("test_no_model", 0.5, "", false);
@@ -192,4 +197,4 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
-  auto features = AMSTensor::create<float>(
-      ams::ArrayRef<AMSTensor::IntDimType>(shape, 2),
-      ams::ArrayRef<AMSTensor::IntDimType>(strides, 2),
-      AMSResourceType::AMS_HOST);
+  auto features =
+      AMSTensor::create<float>(ams::ArrayRef<AMSTensor::IntDimType>(shape, 2),
+                               ams::ArrayRef<AMSTensor::IntDimType>(strides, 2),
+                               AMSResourceType::AMS_HOST);

Have any feedback or feature suggestions? Share it here.

Comment thread src/AMSlib/ml/surrogate.hpp Outdated
Comment thread src/AMSlib/ml/surrogate.hpp Outdated
Comment thread src/AMSlib/wf/interface.cpp Outdated
Comment thread src/AMSlib/wf/interface.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Copy link
Copy Markdown
Member

@lpottier lpottier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like a good start. Can you rebase on develop to fix the tests?

namespace ams
{
class AMSWorkflow;
bool tryGraphSurrogate(AMSWorkflow*,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I would like to avoid AMSWorkflow spilling into the surrogate part of the code. Can we move these to workflow.hpp?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants