Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions gigl/src/common/types/pb_wrappers/gbml_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.src.common.types.pb_wrappers.dataset_metadata import DatasetMetadataPbWrapper
from gigl.src.common.types.pb_wrappers.flattened_graph_metadata import (
FlattenedGraphMetadataPbWrapper,
Expand All @@ -23,6 +24,7 @@
TrainedModelMetadataPbWrapper,
)
from gigl.src.common.utils.file_loader import FileLoader
from gigl.src.data_preprocessor.lib.types import FeatureSchema
from snapchat.research.gbml import (
dataset_metadata_pb2,
flattened_graph_metadata_pb2,
Expand Down Expand Up @@ -62,6 +64,16 @@ class GbmlConfigPbWrapper:
SubgraphSamplingStrategyPbWrapper
] = field(default=None, init=False)

# Typed-key feature maps (keyed by NodeType / EdgeType instead of condensed types).
# All three are excluded from the generated __init__ (init=False) and start out as
# empty dicts; __post_init__ replaces them via object.__setattr__ when the instance
# has both `_graph_metadata_pb_wrapper` and `_preprocessed_metadata_pb_wrapper`.

# NodeType -> feature dimension.
_node_type_to_feature_dim_map: dict[NodeType, int] = field(
default_factory=dict, init=False
)
# NodeType -> FeatureSchema.
_node_type_to_feature_schema_map: dict[NodeType, FeatureSchema] = field(
default_factory=dict, init=False
)
# EdgeType -> feature dimension.
_edge_type_to_feature_dim_map: dict[EdgeType, int] = field(
default_factory=dict, init=False
)

def __post_init__(self):
# Populate the _preprocessed_metadata_pb_wrapper field
self.__load_preprocessed_metadata_pb_wrapper(
Expand All @@ -80,6 +92,44 @@ def __post_init__(self):
self.__load_graph_metadata_pb_wrapper(
graph_metadata_pb=self.gbml_config_pb.graph_metadata
)
# Derive typed-keyed feature maps by joining the just-populated
# _preprocessed_metadata_pb_wrapper (condensed-keyed) with
# _graph_metadata_pb_wrapper (condensed→typed lookup).
if hasattr(self, "_graph_metadata_pb_wrapper") and hasattr(
self, "_preprocessed_metadata_pb_wrapper"
):
graph_metadata = self._graph_metadata_pb_wrapper
preprocessed_metadata = self._preprocessed_metadata_pb_wrapper
object.__setattr__(
self,
"_node_type_to_feature_dim_map",
{
graph_metadata.condensed_node_type_to_node_type_map[
condensed_node_type
]: feature_dim
for condensed_node_type, feature_dim in preprocessed_metadata.condensed_node_type_to_feature_dim_map.items()
},
)
object.__setattr__(
self,
"_node_type_to_feature_schema_map",
{
graph_metadata.condensed_node_type_to_node_type_map[
condensed_node_type
]: feature_schema
for condensed_node_type, feature_schema in preprocessed_metadata.condensed_node_type_to_feature_schema_map.items()
},
)
object.__setattr__(
self,
"_edge_type_to_feature_dim_map",
{
graph_metadata.condensed_edge_type_to_edge_type_map[
condensed_edge_type
]: feature_dim
for condensed_edge_type, feature_dim in preprocessed_metadata.condensed_edge_type_to_feature_dim_map.items()
},
)
# Populate the _flattened_graph_metadata_pb_wrapper field
if self.gbml_config_pb.shared_config.HasField("flattened_graph_metadata"):
flattened_graph_metadata_pb = (
Expand Down Expand Up @@ -365,6 +415,30 @@ def preprocessed_metadata_pb_wrapper(self) -> PreprocessedMetadataPbWrapper:
)
return self._preprocessed_metadata_pb_wrapper

@property
def node_type_to_feature_dim_map(self) -> dict[NodeType, int]:
    """Map each NodeType to its input feature dimension.

    The mapping is read from the instance attribute
    ``_node_type_to_feature_dim_map`` and handed back as-is; no copy is made,
    so callers receive the same dict object that is stored on the wrapper.

    Returns:
        dict[NodeType, int]: Mapping from NodeType to its input feature dimension.
    """
    return self._node_type_to_feature_dim_map

@property
def node_type_to_feature_schema_map(self) -> dict[NodeType, FeatureSchema]:
    """Expose the FeatureSchema recorded for each NodeType.

    The stored dict is returned directly rather than copied.

    Returns:
        dict[NodeType, FeatureSchema]: Mapping from NodeType to its FeatureSchema.
    """
    schema_by_node_type = self._node_type_to_feature_schema_map
    return schema_by_node_type

@property
def edge_type_to_feature_dim_map(self) -> dict[EdgeType, int]:
    """Expose the feature dimension recorded for each EdgeType.

    The stored dict is returned directly rather than copied.

    Returns:
        dict[EdgeType, int]: Mapping from EdgeType to its feature dimension.
    """
    dim_by_edge_type = self._edge_type_to_feature_dim_map
    return dim_by_edge_type

@property
def trained_model_metadata_pb_wrapper(self) -> TrainedModelMetadataPbWrapper:
"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ lint = [
"ruff==0.15.10",
"mdformat==0.7.22",
"mdformat_tables==1.0.0",
"ty~=0.0.29",
"ty==0.0.31",
Comment thread
svij-sc marked this conversation as resolved.
"mypy-protobuf==3.3.0", # Used for protobuf stub generation (protoc-gen-mypy), not type checking
]

Expand Down
106 changes: 106 additions & 0 deletions tests/unit/src/common/utils/gbml_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@

from gigl.common import LocalUri
from gigl.common.utils.proto_utils import ProtoUtils
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from snapchat.research.gbml import gbml_config_pb2
from tests.test_assets.graph_metadata_constants import (
EXAMPLE_HETEROGENEOUS_CONDENSED_EDGE_TYPES,
EXAMPLE_HETEROGENEOUS_CONDENSED_NODE_TYPES,
EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB,
EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB_WRAPPER,
EXAMPLE_HETEROGENEOUS_PREPROCESSED_METADATA_PB,
)
from tests.test_assets.test_case import TestCase


Expand Down Expand Up @@ -43,3 +51,101 @@ def test_gbml_config_read_and_write_yaml(self):
uri=self.target_yaml_uri, proto_cls=gbml_config_pb2.GbmlConfig
)
self.assertEqual(obj, obj2)

def test_typed_keyed_feature_maps_match_condensed_maps(self):
    """Typed-keyed feature maps mirror the condensed-keyed maps.

    A GbmlConfigPbWrapper is constructed from the heterogeneous example
    fixture with graph_metadata set inline and preprocessed_metadata
    referenced by URI. Each typed-keyed property must then (1) have exactly
    the keys obtained by translating every condensed type through the
    fixture's condensed→typed map, and (2) hold, for every entry, the same
    value as the condensed-keyed map on the preprocessed metadata wrapper.
    """
    # The preprocessed metadata is referenced by URI in the config, so it has
    # to exist on disk before the wrapper is constructed.
    yaml_name = f"{self.gbml_config_test_run_id}_preprocessed_metadata.yaml"
    metadata_uri = LocalUri.join(self.tmp_dir.name, yaml_name)
    self.proto_utils.write_proto_to_yaml(
        proto=EXAMPLE_HETEROGENEOUS_PREPROCESSED_METADATA_PB,
        uri=metadata_uri,
    )

    config_pb = gbml_config_pb2.GbmlConfig(
        graph_metadata=EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB,
    )
    config_pb.shared_config.preprocessed_metadata_uri = metadata_uri.uri
    wrapper = GbmlConfigPbWrapper(gbml_config_pb=config_pb)

    # Reference condensed→typed lookups taken from the fixture wrapper.
    fixture = EXAMPLE_HETEROGENEOUS_GRAPH_METADATA_PB_WRAPPER
    node_type_of = fixture.condensed_node_type_to_node_type_map
    edge_type_of = fixture.condensed_edge_type_to_edge_type_map
    expected_node_types = {
        node_type_of[condensed]
        for condensed in EXAMPLE_HETEROGENEOUS_CONDENSED_NODE_TYPES
    }
    expected_edge_types = {
        edge_type_of[condensed]
        for condensed in EXAMPLE_HETEROGENEOUS_CONDENSED_EDGE_TYPES
    }

    # Key sets: every condensed type shows up, rekeyed through the type map.
    for typed_map, expected_keys in (
        (wrapper.node_type_to_feature_dim_map, expected_node_types),
        (wrapper.node_type_to_feature_schema_map, expected_node_types),
        (wrapper.edge_type_to_feature_dim_map, expected_edge_types),
    ):
        self.assertEqual(set(typed_map), expected_keys)

    # Values: each typed entry equals its condensed-keyed counterpart.
    source = wrapper.preprocessed_metadata_pb_wrapper
    for condensed, dim in source.condensed_node_type_to_feature_dim_map.items():
        self.assertEqual(
            wrapper.node_type_to_feature_dim_map[node_type_of[condensed]],
            dim,
        )
    for condensed, schema in source.condensed_node_type_to_feature_schema_map.items():
        self.assertEqual(
            wrapper.node_type_to_feature_schema_map[node_type_of[condensed]],
            schema,
        )
    for condensed, dim in source.condensed_edge_type_to_feature_dim_map.items():
        self.assertEqual(
            wrapper.edge_type_to_feature_dim_map[edge_type_of[condensed]],
            dim,
        )

def test_typed_keyed_feature_maps_default_empty_when_metadata_missing(self):
    """Typed-keyed maps are empty dicts for a default-constructed GbmlConfig.

    The wrapper is built from a GbmlConfig with no fields set (so neither
    graph_metadata nor a preprocessed metadata URI is populated); all three
    typed-keyed properties are expected to compare equal to ``{}``.
    """
    wrapper = GbmlConfigPbWrapper(gbml_config_pb=gbml_config_pb2.GbmlConfig())

    typed_maps = (
        wrapper.node_type_to_feature_dim_map,
        wrapper.node_type_to_feature_schema_map,
        wrapper.edge_type_to_feature_dim_map,
    )
    for typed_map in typed_maps:
        self.assertEqual(typed_map, {})