diff --git a/gigl/src/common/types/pb_wrappers/gbml_config.py b/gigl/src/common/types/pb_wrappers/gbml_config.py index 67e3726ee..f153a6805 100644 --- a/gigl/src/common/types/pb_wrappers/gbml_config.py +++ b/gigl/src/common/types/pb_wrappers/gbml_config.py @@ -7,6 +7,7 @@ from gigl.common import Uri, UriFactory from gigl.common.logger import Logger from gigl.common.utils.proto_utils import ProtoUtils +from gigl.src.common.types.graph_data import EdgeType, NodeType from gigl.src.common.types.pb_wrappers.dataset_metadata import DatasetMetadataPbWrapper from gigl.src.common.types.pb_wrappers.flattened_graph_metadata import ( FlattenedGraphMetadataPbWrapper, @@ -23,6 +24,7 @@ TrainedModelMetadataPbWrapper, ) from gigl.src.common.utils.file_loader import FileLoader +from gigl.src.data_preprocessor.lib.types import FeatureSchema from snapchat.research.gbml import ( dataset_metadata_pb2, flattened_graph_metadata_pb2, @@ -62,6 +64,16 @@ class GbmlConfigPbWrapper: SubgraphSamplingStrategyPbWrapper ] = field(default=None, init=False) + _node_type_to_feature_dim_map: dict[NodeType, int] = field( + default_factory=dict, init=False + ) + _node_type_to_feature_schema_map: dict[NodeType, FeatureSchema] = field( + default_factory=dict, init=False + ) + _edge_type_to_feature_dim_map: dict[EdgeType, int] = field( + default_factory=dict, init=False + ) + def __post_init__(self): # Populate the _preprocessed_metadata_pb_wrapper field self.__load_preprocessed_metadata_pb_wrapper( @@ -80,6 +92,44 @@ def __post_init__(self): self.__load_graph_metadata_pb_wrapper( graph_metadata_pb=self.gbml_config_pb.graph_metadata ) + # Derive typed-keyed feature maps by joining the just-populated + # _preprocessed_metadata_pb_wrapper (condensed-keyed) with + # _graph_metadata_pb_wrapper (condensed→typed lookup). 
+ if hasattr(self, "_graph_metadata_pb_wrapper") and hasattr( + self, "_preprocessed_metadata_pb_wrapper" + ): + graph_metadata = self._graph_metadata_pb_wrapper + preprocessed_metadata = self._preprocessed_metadata_pb_wrapper + object.__setattr__( + self, + "_node_type_to_feature_dim_map", + { + graph_metadata.condensed_node_type_to_node_type_map[ + condensed_node_type + ]: feature_dim + for condensed_node_type, feature_dim in preprocessed_metadata.condensed_node_type_to_feature_dim_map.items() + }, + ) + object.__setattr__( + self, + "_node_type_to_feature_schema_map", + { + graph_metadata.condensed_node_type_to_node_type_map[ + condensed_node_type + ]: feature_schema + for condensed_node_type, feature_schema in preprocessed_metadata.condensed_node_type_to_feature_schema_map.items() + }, + ) + object.__setattr__( + self, + "_edge_type_to_feature_dim_map", + { + graph_metadata.condensed_edge_type_to_edge_type_map[ + condensed_edge_type + ]: feature_dim + for condensed_edge_type, feature_dim in preprocessed_metadata.condensed_edge_type_to_feature_dim_map.items() + }, + ) # Populate the _flattened_graph_metadata_pb_wrapper field if self.gbml_config_pb.shared_config.HasField("flattened_graph_metadata"): flattened_graph_metadata_pb = ( @@ -365,6 +415,30 @@ def preprocessed_metadata_pb_wrapper(self) -> PreprocessedMetadataPbWrapper: ) return self._preprocessed_metadata_pb_wrapper + @property + def node_type_to_feature_dim_map(self) -> dict[NodeType, int]: + """ + Returns: + dict[NodeType, int]: Mapping from NodeType to its input feature dimension. + """ + return self._node_type_to_feature_dim_map + + @property + def node_type_to_feature_schema_map(self) -> dict[NodeType, FeatureSchema]: + """ + Returns: + dict[NodeType, FeatureSchema]: Mapping from NodeType to its FeatureSchema. 
+ """ + return self._node_type_to_feature_schema_map + + @property + def edge_type_to_feature_dim_map(self) -> dict[EdgeType, int]: + """ + Returns: + dict[EdgeType, int]: Mapping from EdgeType to its feature dimension. + """ + return self._edge_type_to_feature_dim_map + @property def trained_model_metadata_pb_wrapper(self) -> TrainedModelMetadataPbWrapper: """ diff --git a/pyproject.toml b/pyproject.toml index f6448874a..b22e1f0e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,7 +152,7 @@ lint = [ "ruff==0.15.10", "mdformat==0.7.22", "mdformat_tables==1.0.0", - "ty~=0.0.29", + "ty==0.0.31", "mypy-protobuf==3.3.0", # Used for protobuf stub generation (protoc-gen-mypy), not type checking ] diff --git a/tests/unit/src/common/utils/gbml_config_test.py b/tests/unit/src/common/utils/gbml_config_test.py index 4baf64627..1c1df4662 100644 --- a/tests/unit/src/common/utils/gbml_config_test.py +++ b/tests/unit/src/common/utils/gbml_config_test.py @@ -3,8 +3,16 @@ from gigl.common import LocalUri from gigl.common.utils.proto_utils import ProtoUtils +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.utils.file_loader import FileLoader from snapchat.research.gbml import gbml_config_pb2 +from tests.test_assets.graph_metadata_constants import ( + EXAMPLE_HETEROGENEOUS_CONDENSED_EDGE_TYPES, + EXAMPLE_HETEROGENEOUS_CONDENSED_NODE_TYPES, + EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB, + EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB_WRAPPER, + EXAMPLE_HETEROGENEOUS_PREPROCESSED_METADATA_PB, +) from tests.test_assets.test_case import TestCase @@ -43,3 +51,101 @@ def test_gbml_config_read_and_write_yaml(self): uri=self.target_yaml_uri, proto_cls=gbml_config_pb2.GbmlConfig ) self.assertEqual(obj, obj2) + + def test_typed_keyed_feature_maps_match_condensed_maps(self): + """`*_to_feature_*_map` properties are typed-key views of the condensed maps. 
+ + Builds a GbmlConfigPbWrapper with both graph_metadata and + preprocessed_metadata populated (heterogeneous example fixture), then + verifies each typed-keyed property contains the same values as the + underlying condensed-keyed map after joining through the + graph_metadata's condensed→typed map. + """ + # Stage the heterogeneous preprocessed metadata on disk so + # __load_preprocessed_metadata_pb_wrapper can read it via the URI. + preprocessed_metadata_uri = LocalUri.join( + self.tmp_dir.name, + f"{self.gbml_config_test_run_id}_preprocessed_metadata.yaml", + ) + self.proto_utils.write_proto_to_yaml( + proto=EXAMPLE_HETEROGENEOUS_PREPROCESSED_METADATA_PB, + uri=preprocessed_metadata_uri, + ) + + gbml_config_pb = gbml_config_pb2.GbmlConfig( + graph_metadata=EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB, + ) + gbml_config_pb.shared_config.preprocessed_metadata_uri = ( + preprocessed_metadata_uri.uri + ) + + gbml_config_pb_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb) + + condensed_to_node_type = EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB_WRAPPER.condensed_node_type_to_node_type_map + condensed_to_edge_type = EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB_WRAPPER.condensed_edge_type_to_edge_type_map + expected_node_types = { + condensed_to_node_type[c] + for c in EXAMPLE_HETEROGENEOUS_CONDENSED_NODE_TYPES + } + expected_edge_types = { + condensed_to_edge_type[c] + for c in EXAMPLE_HETEROGENEOUS_CONDENSED_EDGE_TYPES + } + + # Keysets reflect every condensed type rekeyed via the type map. + self.assertEqual( + set(gbml_config_pb_wrapper.node_type_to_feature_dim_map.keys()), + expected_node_types, + ) + self.assertEqual( + set(gbml_config_pb_wrapper.node_type_to_feature_schema_map.keys()), + expected_node_types, + ) + self.assertEqual( + set(gbml_config_pb_wrapper.edge_type_to_feature_dim_map.keys()), + expected_edge_types, + ) + + # Values agree with the condensed-keyed source for every entry. 
+ preprocessed = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper + for ( + condensed_node_type, + expected_dim, + ) in preprocessed.condensed_node_type_to_feature_dim_map.items(): + node_type = condensed_to_node_type[condensed_node_type] + self.assertEqual( + gbml_config_pb_wrapper.node_type_to_feature_dim_map[node_type], + expected_dim, + ) + for ( + condensed_node_type, + expected_schema, + ) in preprocessed.condensed_node_type_to_feature_schema_map.items(): + node_type = condensed_to_node_type[condensed_node_type] + self.assertEqual( + gbml_config_pb_wrapper.node_type_to_feature_schema_map[node_type], + expected_schema, + ) + for ( + condensed_edge_type, + expected_dim, + ) in preprocessed.condensed_edge_type_to_feature_dim_map.items(): + edge_type = condensed_to_edge_type[condensed_edge_type] + self.assertEqual( + gbml_config_pb_wrapper.edge_type_to_feature_dim_map[edge_type], + expected_dim, + ) + + def test_typed_keyed_feature_maps_default_empty_when_metadata_missing(self): + """When neither graph_metadata nor preprocessed_metadata is populated + on the input GbmlConfig, the typed-keyed maps default to empty dicts. + + This matches the behavior of the underlying loaders, which silently + skip population when their inputs are absent. + """ + gbml_config_pb = gbml_config_pb2.GbmlConfig() + gbml_config_pb_wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb) + + self.assertEqual(gbml_config_pb_wrapper.node_type_to_feature_dim_map, {}) + self.assertEqual(gbml_config_pb_wrapper.node_type_to_feature_schema_map, {}) + self.assertEqual(gbml_config_pb_wrapper.edge_type_to_feature_dim_map, {})