Support graph surrogate model execution.#193
Open
YohannDudouit wants to merge 2 commits intoyohann/graph-executefrom
Open
Support graph surrogate model execution.#193YohannDudouit wants to merge 2 commits intoyohann/graph-executefrom
YohannDudouit wants to merge 2 commits intoyohann/graph-executefrom
Conversation
- Add tests for homogeneous and heterogenous graph surrogate models.
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Click here for the full clang-format patch
diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp
index c9cf07a..e9cb326 100644
--- a/src/AMSlib/ml/surrogate.hpp
+++ b/src/AMSlib/ml/surrogate.hpp
@@ -29 +29,2 @@
-namespace ams {
+namespace ams
+{
@@ -37 +38 @@ bool tryGraphSurrogate(AMSWorkflow*,
-}
+} // namespace ams
diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp
index 506fdc0..3113952 100644
--- a/src/AMSlib/wf/interface.cpp
+++ b/src/AMSlib/wf/interface.cpp
@@ -304 +304,2 @@ amsEdgeStoresToTorchDict(
- out.insert(ams::edgeTypeToString(edge_type), amsTensorMapToTorchDict(store));
+ out.insert(ams::edgeTypeToString(edge_type),
+ amsTensorMapToTorchDict(store));
@@ -371 +372,2 @@ void callApplication(ams::HeterogeneousGraphDomainFn CallBack,
-namespace ams {
+namespace ams
+{
diff --git a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
index b23b55d..24249d9 100644
--- a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
@@ -10,2 +10,4 @@ using namespace ams;
-static const char* HOMOGENEOUS_GRAPH_MODEL_PATH = "../models/homogeneous_graph.pt";
-static const char* HETEROGENEOUS_GRAPH_MODEL_PATH = "../models/heterogeneous_graph.pt";
+static const char* HOMOGENEOUS_GRAPH_MODEL_PATH =
+ "../models/homogeneous_graph.pt";
+static const char* HETEROGENEOUS_GRAPH_MODEL_PATH =
+ "../models/heterogeneous_graph.pt";
@@ -19,2 +21,4 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph surrogate execution",
- auto model = AMSRegisterAbstractModel("test_homo_surrogate", 0.5,
- HOMOGENEOUS_GRAPH_MODEL_PATH, false);
+ auto model = AMSRegisterAbstractModel("test_homo_surrogate",
+ 0.5,
+ HOMOGENEOUS_GRAPH_MODEL_PATH,
+ false);
@@ -90,2 +94,4 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution",
- auto model = AMSRegisterAbstractModel("test_hetero_surrogate", 0.5,
- HETEROGENEOUS_GRAPH_MODEL_PATH, false);
+ auto model = AMSRegisterAbstractModel("test_hetero_surrogate",
+ 0.5,
+ HETEROGENEOUS_GRAPH_MODEL_PATH,
+ false);
@@ -140,26 +146,26 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution",
- HeterogeneousGraphDomainFn callback =
- [&](const AMSHeterogeneousGraph& g, SmallVector<AMSTensor>& outputs) {
- callback_invoked = true;
-
- // Verify graph structure
- CATCH_REQUIRE(g.containsNodeStore("node"));
- const auto* node_store = g.findNodeStore("node");
- CATCH_REQUIRE(node_store != nullptr);
- CATCH_REQUIRE(containsTensor(*node_store, "x"));
-
- // Create output tensor
- AMSTensor::IntDimType out_shape[] = {10, 8};
- AMSTensor::IntDimType out_strides[] = {8, 1};
- auto out_tensor = AMSTensor::create<float>(
- ams::ArrayRef<AMSTensor::IntDimType>(out_shape, 2),
- ams::ArrayRef<AMSTensor::IntDimType>(out_strides, 2),
- AMSResourceType::AMS_HOST);
-
- float* out_data = out_tensor.data<float>();
- for (int i = 0; i < 80; ++i) {
- out_data[i] = static_cast<float>(i);
- }
-
- outputs.clear();
- outputs.push_back(std::move(out_tensor));
- };
+ HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g,
+ SmallVector<AMSTensor>& outputs) {
+ callback_invoked = true;
+
+ // Verify graph structure
+ CATCH_REQUIRE(g.containsNodeStore("node"));
+ const auto* node_store = g.findNodeStore("node");
+ CATCH_REQUIRE(node_store != nullptr);
+ CATCH_REQUIRE(containsTensor(*node_store, "x"));
+
+ // Create output tensor
+ AMSTensor::IntDimType out_shape[] = {10, 8};
+ AMSTensor::IntDimType out_strides[] = {8, 1};
+ auto out_tensor = AMSTensor::create<float>(
+ ams::ArrayRef<AMSTensor::IntDimType>(out_shape, 2),
+ ams::ArrayRef<AMSTensor::IntDimType>(out_strides, 2),
+ AMSResourceType::AMS_HOST);
+
+ float* out_data = out_tensor.data<float>();
+ for (int i = 0; i < 80; ++i) {
+ out_data[i] = static_cast<float>(i);
+ }
+
+ outputs.clear();
+ outputs.push_back(std::move(out_tensor));
+ };
@@ -184,2 +190 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
- auto model = AMSRegisterAbstractModel(
- "test_no_model", 0.5, "", false);
+ auto model = AMSRegisterAbstractModel("test_no_model", 0.5, "", false);
@@ -192,4 +197,4 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
- auto features = AMSTensor::create<float>(
- ams::ArrayRef<AMSTensor::IntDimType>(shape, 2),
- ams::ArrayRef<AMSTensor::IntDimType>(strides, 2),
- AMSResourceType::AMS_HOST);
+ auto features =
+ AMSTensor::create<float>(ams::ArrayRef<AMSTensor::IntDimType>(shape, 2),
+ ams::ArrayRef<AMSTensor::IntDimType>(strides, 2),
+ AMSResourceType::AMS_HOST);
Have any feedback or feature suggestions? Share it here.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
lpottier
reviewed
Apr 29, 2026
Member
lpottier
left a comment
There was a problem hiding this comment.
It looks like a good start. Can you rebase on develop to fix the tests?
| namespace ams | ||
| { | ||
| class AMSWorkflow; | ||
| bool tryGraphSurrogate(AMSWorkflow*, |
Member
There was a problem hiding this comment.
Ideally I would like to avoid AMSWorkflow spilling into the surrogate part of the code. Can we move these to workflow.hpp?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This pull request adds initial support for graph surrogate model execution in AMS.
More specifically, it introduces:
This PR focuses on establishing the core execution path for graph-based surrogates. The current implementation is intentionally minimal and is tested using simple graph surrogate models rather than full GNN/MGN applications.