Enabled Weighted Sampling #635
kmontemayor2-sc
left a comment
Thanks Matt! Me and the robots did a first pass, it's possible they're imagining some of the issues here but I figured I'd flag :)
|
/unit_test |
GiGL Automation@ 23:03:34UTC : 🔄 @ 24:11:45UTC : ❌ Workflow failed. |
GiGL Automation@ 23:03:35UTC : 🔄 @ 23:05:34UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 23:03:35UTC : 🔄 @ 23:13:48UTC : ✅ Workflow completed successfully. |
|
/unit_test |
GiGL Automation@ 06:09:08UTC : 🔄 @ 06:19:33UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 06:09:09UTC : 🔄 @ 06:13:19UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 06:09:09UTC : 🔄 @ 07:13:24UTC : ✅ Workflow completed successfully. |
|
/unit_test |
GiGL Automation@ 22:16:50UTC : 🔄 @ 23:36:03UTC : ❌ Workflow failed. |
GiGL Automation@ 22:16:51UTC : 🔄 @ 22:26:16UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 22:16:51UTC : 🔄 @ 22:18:35UTC : ✅ Workflow completed successfully. |
Summary
Adds native weighted edge sampling to GiGL's distributed training pipeline via GLT's `CPUWeightedSampler`. When enabled, neighbors are sampled proportionally to edge weights rather than uniformly.
New API
- `DistPartitioner.register_edge_weights(edge_weights)`: registers a 1D per-edge weight tensor (homogeneous, or `dict[EdgeType, Tensor]` for heterogeneous) before calling `partition_edge_index_and_edge_features()`. Weights are partitioned alongside edge features in the same pass (co-partitioned, mirroring the node features + labels pattern).
- `load_torch_tensors_from_tf_record(weight_edge_feat_name=...)`: accepts the name of an existing edge feature column to extract as sampling weights during TFRecord loading. The column is sliced out of the feature tensor and stored in `LoadedGraphTensors.edge_weights`; it is never duplicated in memory.
- `build_dataset(weight_edge_feat_name=...)`: threads `weight_edge_feat_name` through to TFRecord loading and then calls `register_edge_weights()` with the extracted weights.
- `DistNeighborLoader(with_weight=True)` / `DistABLPLoader(with_weight=True)`: enables weighted sampling. Defaults to `False`; must be set explicitly.
- `BaseDistLoader.validate_with_weight()`: shared validation. Raises `ValueError` if `with_weight=True` but no weights are registered in the dataset; raises `NotImplementedError` if used with `PPRSamplerOptions` (weight-proportional PPR residual propagation is deferred to a future PR).
Implementation notes
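To make the `with_weight` behaviour concrete, here is a minimal sketch in plain Python. It is not GiGL or GLT code: `validate_with_weight` below is a free-standing stand-in for the checks described for `BaseDistLoader.validate_with_weight()`, `sample_neighbors` is a hypothetical helper, and the order of the two checks and sampling with replacement are assumptions made for illustration only.

```python
import random
from typing import List, Optional, Sequence


def validate_with_weight(with_weight: bool, has_edge_weights: bool, uses_ppr: bool) -> None:
    """Hypothetical stand-in for the validation rules listed in the PR description."""
    if not with_weight:
        return  # Uniform sampling needs no weights, so there is nothing to check.
    if not has_edge_weights:
        raise ValueError("with_weight=True but no edge weights are registered in the dataset")
    if uses_ppr:
        raise NotImplementedError("weighted sampling with PPR sampler options is not supported")


def sample_neighbors(
    neighbors: Sequence[int],
    weights: Optional[Sequence[float]],
    k: int,
    with_weight: bool,
    rng: random.Random,
) -> List[int]:
    """Draw k neighbors (with replacement, an assumption of this sketch)."""
    validate_with_weight(with_weight, has_edge_weights=weights is not None, uses_ppr=False)
    if with_weight:
        # Each neighbor is drawn with probability proportional to its edge weight.
        return rng.choices(neighbors, weights=weights, k=k)
    # Uniform sampling: every neighbor is equally likely, weights are ignored.
    return rng.choices(neighbors, k=k)


rng = random.Random(0)
weighted = sample_neighbors([10, 11, 12], [0.0, 1.0, 3.0], k=200, with_weight=True, rng=rng)
uniform = sample_neighbors([10, 11, 12], [0.0, 1.0, 3.0], k=200, with_weight=False, rng=rng)
print("weighted picks of node 10:", weighted.count(10))
print("uniform picks of node 10:", uniform.count(10))
```

In the weighted draw, a neighbor with weight zero is never selected, while the uniform draw ignores the weights entirely. That difference is why `with_weight` has to be an explicit opt-in and why it is validated against the dataset.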
- `LoadedGraphTensors.edge_weights`: new field carrying extracted weights from TFRecord loading through to `register_edge_weights()`.
- `GraphPartitionData.weights` (field already existed) carries the partitioned weight tensor to `DistDataset._initialize_graph()`, which forwards it to GLT's `init_graph(edge_weights=...)`.
- `DistDataset.has_edge_weights` property reflects whether weights were registered at construction time.
- `SamplingConfig.with_weight` is now threaded through from the loader rather than hardcoded to `False`.
- `DistServer.get_edge_weights_registered()` and `RemoteDistDataset.fetch_edge_weights_registered()` propagate `has_edge_weights` across the RPC boundary so compute nodes can validate `with_weight` against the remote dataset.
Tests
- `tests/unit/distributed/distributed_weighted_sampling_test.py` (8 new tests): `GraphPartitionData.edge_ids == FeaturePartitionData.ids`), and heterogeneous partial weights (one edge type weighted, another not).
- `tests/unit/common/data/dataloaders_test.py` (1 new test): `test_load_edge_weights_from_tf_record`, which verifies that `load_torch_tensors_from_tf_record` correctly extracts a named column into `edge_weights`, removes it from `edge_features`, and returns the right shapes and values.
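The property the TFRecord test checks (a named column moves into `edge_weights` and disappears from `edge_features`) can be illustrated with a small sketch. This is an assumption-laden stand-in, not the GiGL loader: `split_weight_column` is a hypothetical helper, plain nested lists replace tensors, and the copying done here does not reflect the no-duplication guarantee the PR describes.

```python
from typing import List, Tuple


def split_weight_column(
    edge_features: List[List[float]],
    feature_names: List[str],
    weight_feat_name: str,
) -> Tuple[List[List[float]], List[float], List[str]]:
    """Pull one named column out of a row-major feature matrix.

    Returns (remaining_features, weights, remaining_names).
    """
    # list.index raises ValueError if the requested column name is unknown.
    col = feature_names.index(weight_feat_name)
    # One weight per edge, i.e. a 1D sequence with as many entries as rows.
    weights = [row[col] for row in edge_features]
    # Every row keeps all of its columns except the extracted one.
    remaining = [row[:col] + row[col + 1:] for row in edge_features]
    remaining_names = feature_names[:col] + feature_names[col + 1:]
    return remaining, weights, remaining_names


features = [[0.1, 2.0, 7.0], [0.2, 0.5, 8.0]]
remaining, weights, names = split_weight_column(features, ["a", "weight", "b"], "weight")
print(weights)    # one entry per edge
print(remaining)  # feature width shrinks by one column
print(names)
```

A check in the spirit of the new test would assert that the weights have one entry per edge, that the feature width drops by exactly one, and that the remaining values are unchanged.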