diff --git a/gigl-core/core/sampling/ppr_forward_push.cpp b/gigl-core/core/sampling/ppr_forward_push.cpp index 9a2a17f03..c71d81ada 100644 --- a/gigl-core/core/sampling/ppr_forward_push.cpp +++ b/gigl-core/core/sampling/ppr_forward_push.cpp @@ -147,20 +147,33 @@ std::optional> PPRForwardPush::drainQ } void PPRForwardPush::pushResiduals( - const std::unordered_map>& fetchedByEtypeId) { - // Step 1: Unpack the input map into a C++ map keyed by packKey(nodeId, edgeTypeId) + const std::unordered_map>& + fetchedByEtypeId) { + // Step 1: Unpack the input map into C++ maps keyed by packKey(nodeId, edgeTypeId) // for fast lookup during the residual-push loop below. + // fetchedWeights is populated only in weighted mode (_hasWeights becomes true on + // the first call that includes a non-empty flat_weights tensor). std::unordered_map> fetched; + std::unordered_map> fetchedWeights; + for (const auto& [edgeTypeId, neighborTensors] : fetchedByEtypeId) { const auto& nodeIdsTensor = std::get<0>(neighborTensors); const auto& flatNeighborIdsTensor = std::get<1>(neighborTensors); const auto& countsTensor = std::get<2>(neighborTensors); + const auto& flatWeightsTensor = std::get<3>(neighborTensors); + + bool etypeHasWeights = flatWeightsTensor.numel() > 0; + if (etypeHasWeights) { + _hasWeights = true; + } // accessor() gives a bounds-checked, typed 1-D view into // each tensor's data — equivalent to iterating over a NumPy array. auto nodeIdsAccessor = nodeIdsTensor.accessor(); auto flatNeighborIdsAccessor = flatNeighborIdsTensor.accessor(); auto countsAccessor = countsTensor.accessor(); + // Raw pointer for weights avoids a conditional accessor construction. + const double* flatWeightsPtr = etypeHasWeights ? flatWeightsTensor.data_ptr() : nullptr; // Walk the flat neighbor list, slicing out each node's neighbors using // the running offset into the concatenated flat buffer. @@ -168,20 +181,54 @@ void PPRForwardPush::pushResiduals( for (int64_t nodeIdx = 0; nodeIdx < nodeIdsTensor.size(0); ++nodeIdx) { auto nodeId = static_cast(nodeIdsAccessor[nodeIdx]); int64_t count = countsAccessor[nodeIdx]; + uint64_t key = packKey(nodeId, edgeTypeId); + std::vector neighborIds(count); for (int64_t neighborIdx = 0; neighborIdx < count; ++neighborIdx) { neighborIds[neighborIdx] = static_cast(flatNeighborIdsAccessor[offset + neighborIdx]); } - fetched[packKey(nodeId, edgeTypeId)] = std::move(neighborIds); + fetched[key] = std::move(neighborIds); + + if (flatWeightsPtr != nullptr) { + std::vector neighborWeights(count); + for (int64_t neighborIdx = 0; neighborIdx < count; ++neighborIdx) { + neighborWeights[neighborIdx] = flatWeightsPtr[offset + neighborIdx]; + } + fetchedWeights[key] = std::move(neighborWeights); + } + offset += count; } } + // Promote neighbor and weight lists for a newly re-queued node into the persistent cache. + // Called from both uniform and weighted paths — the _hasWeights guard inside handles + // whether _weightCache is populated. + auto promoteToCache = [&](int32_t neighborNodeId, int32_t dstNodeTypeId) { + for (int32_t neighborEdgeTypeId : _nodeTypeToEdgeTypeIds[dstNodeTypeId]) { + uint64_t packedKey = packKey(neighborNodeId, neighborEdgeTypeId); + if (_neighborCache.find(packedKey) == _neighborCache.end()) { + auto fetchedNeighborEntry = fetched.find(packedKey); + if (fetchedNeighborEntry != fetched.end()) { + _neighborCache[packedKey] = fetchedNeighborEntry->second; + if (_hasWeights) { + auto fetchedWeightsNeighborEntry = fetchedWeights.find(packedKey); + if (fetchedWeightsNeighborEntry != fetchedWeights.end()) { + _weightCache[packedKey] = fetchedWeightsNeighborEntry->second; + } + } + } + } + } + }; + // Step 2: For every node that was in the queue (captured in _queuedNodes // by drainQueue()), apply one PPR push step: // a. Absorb residual into the PPR score. - // b. Distribute (1-alpha) * residual equally to each neighbor. - // c. Enqueue any neighbor whose residual now exceeds the requeue threshold. + // b. Compute the normalisation factor (neighbor count or total weight). + // c. Distribute (1-alpha) * residual to each neighbor: uniformly when + // _hasWeights is false; proportionally to edge weight when true. + // d. Enqueue any neighbor whose residual now exceeds the requeue threshold. for (int32_t seedIdx = 0; seedIdx < _batchSize; ++seedIdx) { for (int32_t nodeTypeId = 0; nodeTypeId < _numNodeTypes; ++nodeTypeId) { auto& srcNodeTypeState = _state[seedIdx][nodeTypeId]; @@ -197,89 +244,167 @@ void PPRForwardPush::pushResiduals( srcNodeTypeState.pprScores[sourceNodeId] += sourceResidual; srcNodeTypeState.residuals[sourceNodeId] = 0.0; - // b. Count total fetched/cached neighbors across all edge types for - // this source node. We normalise by the number of neighbors we - // actually retrieved, not the true degree, so residual is fully - // distributed among known neighbors rather than leaking to unfetched - // ones (which matters when num_neighbors_per_hop < true_degree). - int32_t totalFetched = 0; - for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) { - auto fetchedEntry = fetched.find(packKey(sourceNodeId, edgeTypeId)); - if (fetchedEntry != fetched.end()) { - totalFetched += static_cast(fetchedEntry->second.size()); - } else { - auto cachedEntry = _neighborCache.find(packKey(sourceNodeId, edgeTypeId)); - if (cachedEntry != _neighborCache.end()) { - totalFetched += static_cast(cachedEntry->second.size()); + if (!_hasWeights) { + // --- Uniform path --- + // b. Count total fetched/cached neighbors across all edge types for + // this source node. We normalise by the number of neighbors we + // actually retrieved, not the true degree, so residual is fully + // distributed among known neighbors rather than leaking to unfetched + // ones (which matters when num_neighbors_per_hop < true_degree). + int32_t totalFetched = 0; + for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) { + auto fetchedEntry = fetched.find(packKey(sourceNodeId, edgeTypeId)); + if (fetchedEntry != fetched.end()) { + totalFetched += static_cast(fetchedEntry->second.size()); + } else { + auto cachedEntry = _neighborCache.find(packKey(sourceNodeId, edgeTypeId)); + if (cachedEntry != _neighborCache.end()) { + totalFetched += static_cast(cachedEntry->second.size()); + } } } - } - // Two cases reach here: - // 1. True sink node (no outgoing edges): absorbing the full residual is correct. - // 2. Budget exhausted, no cache entry: the (1-α)·r that should flow to - // neighbors has nowhere to go, so it gets absorbed into src's score instead. - // This overstates src and understates its neighbors. This is expected - // behavior when max_fetch_iterations is set, which intentionally trades - // theoretical PPR correctness for better throughput. - if (totalFetched == 0) { - continue; - } + // Two cases reach here: + // 1. True sink node (no outgoing edges): absorbing the full residual is correct. + // 2. Budget exhausted, no cache entry: the (1-α)·r that should flow to + // neighbors has nowhere to go, so it gets absorbed into src's score instead. + // This overstates src and understates its neighbors. This is expected + // behavior when max_fetch_iterations is set, which intentionally trades + // theoretical PPR correctness for better throughput. + if (totalFetched == 0) { + continue; + } - double residualPerNeighbor = (1.0 - _alpha) * sourceResidual / static_cast(totalFetched); + double residualPerNeighbor = (1.0 - _alpha) * sourceResidual / static_cast(totalFetched); + + for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) { + // Invariant: fetched and _neighborCache are mutually exclusive for + // any given (node, etype) key within one iteration. drainQueue() + // only requests a fetch for nodes absent from _neighborCache, so a + // key is in at most one of the two. + // + // Neighbor list for this (src, edgeTypeId) pair, borrowed from whichever + // map holds it. reference_wrapper is used because std::optional cannot + // hold a reference directly, and we want to avoid copying the vector — + // the data already exists in fetched or _neighborCache and both outlive + // this loop body. Access via neighborList->get(). + std::optional>> neighborList; + auto fetchedEntry = fetched.find(packKey(sourceNodeId, edgeTypeId)); + if (fetchedEntry != fetched.end()) { + neighborList = std::cref(fetchedEntry->second); + } else { + auto cachedEntry = _neighborCache.find(packKey(sourceNodeId, edgeTypeId)); + if (cachedEntry != _neighborCache.end()) { + neighborList = std::cref(cachedEntry->second); + } + } + if (!neighborList || neighborList->get().empty()) { + continue; + } - for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) { - // Invariant: fetched and _neighborCache are mutually exclusive for - // any given (node, etype) key within one iteration. drainQueue() - // only requests a fetch for nodes absent from _neighborCache, so a - // key is in at most one of the two. - // - // Neighbor list for this (src, edgeTypeId) pair, borrowed from whichever - // map holds it. reference_wrapper is used because std::optional cannot - // hold a reference directly, and we want to avoid copying the vector — - // the data already exists in fetched or _neighborCache and both outlive - // this loop body. Access via neighborList->get(). - std::optional>> neighborList; - auto fetchedEntry = fetched.find(packKey(sourceNodeId, edgeTypeId)); - if (fetchedEntry != fetched.end()) { - neighborList = std::cref(fetchedEntry->second); - } else { - auto cachedEntry = _neighborCache.find(packKey(sourceNodeId, edgeTypeId)); - if (cachedEntry != _neighborCache.end()) { - neighborList = std::cref(cachedEntry->second); + int32_t dstNodeTypeId = _edgeTypeToDstNtypeId[edgeTypeId]; + + // c. Accumulate residual for each neighbor and re-enqueue if threshold + // exceeded. + auto& dstNodeTypeState = _state[seedIdx][dstNodeTypeId]; + for (int32_t neighborNodeId : neighborList->get()) { + dstNodeTypeState.residuals[neighborNodeId] += residualPerNeighbor; + + double threshold = _requeueThresholdFactor * + static_cast(getTotalDegree(neighborNodeId, dstNodeTypeId)); + + if (dstNodeTypeState.queue.find(neighborNodeId) == dstNodeTypeState.queue.end() && + dstNodeTypeState.residuals[neighborNodeId] >= threshold) { + dstNodeTypeState.queue.insert(neighborNodeId); + ++_numNodesInQueue; + promoteToCache(neighborNodeId, dstNodeTypeId); + } + } + } + } else { + // --- Weighted path --- + // b. Sum total weight of fetched/cached neighbors across all edge types. + // We normalise by total fetched weight rather than by true out-weight so + // that the residual is fully distributed among known neighbors, consistent + // with how the uniform path handles truncated neighbor lists. + double totalFetchedWeight = 0.0; + for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) { + uint64_t key = packKey(sourceNodeId, edgeTypeId); + auto fetchedWeightsEntry = fetchedWeights.find(key); + if (fetchedWeightsEntry != fetchedWeights.end()) { + for (double w : fetchedWeightsEntry->second) { + totalFetchedWeight += w; + } + } else { + auto cachedWeightsEntry = _weightCache.find(key); + if (cachedWeightsEntry != _weightCache.end()) { + for (double w : cachedWeightsEntry->second) { + totalFetchedWeight += w; + } + } } } - if (!neighborList || neighborList->get().empty()) { + // Sink node or all-zero-weight edges: absorb residual, nothing to distribute. + if (totalFetchedWeight == 0.0) { continue; } - int32_t dstNodeTypeId = _edgeTypeToDstNtypeId[edgeTypeId]; - - // c. Accumulate residual for each neighbor and re-enqueue if threshold - // exceeded. - auto& dstNodeTypeState = _state[seedIdx][dstNodeTypeId]; - for (int32_t neighborNodeId : neighborList->get()) { - dstNodeTypeState.residuals[neighborNodeId] += residualPerNeighbor; - - double threshold = _requeueThresholdFactor * - static_cast(getTotalDegree(neighborNodeId, dstNodeTypeId)); - - if (dstNodeTypeState.queue.find(neighborNodeId) == dstNodeTypeState.queue.end() && - dstNodeTypeState.residuals[neighborNodeId] >= threshold) { - dstNodeTypeState.queue.insert(neighborNodeId); - ++_numNodesInQueue; - - // Promote neighbor lists to the persistent cache: this node will - // be processed next iteration, so caching avoids a re-fetch. - for (int32_t neighborEdgeTypeId : _nodeTypeToEdgeTypeIds[dstNodeTypeId]) { - uint64_t packedKey = packKey(neighborNodeId, neighborEdgeTypeId); - if (_neighborCache.find(packedKey) == _neighborCache.end()) { - auto fetchedNeighborEntry = fetched.find(packedKey); - if (fetchedNeighborEntry != fetched.end()) { - _neighborCache[packedKey] = fetchedNeighborEntry->second; - } + double baseResidual = (1.0 - _alpha) * sourceResidual; + + for (int32_t edgeTypeId : _nodeTypeToEdgeTypeIds[nodeTypeId]) { + uint64_t key = packKey(sourceNodeId, edgeTypeId); + + std::optional>> neighborList; + std::optional>> weightList; + + auto fetchedEntry = fetched.find(key); + if (fetchedEntry != fetched.end()) { + neighborList = std::cref(fetchedEntry->second); + auto fetchedWeightsEntry = fetchedWeights.find(key); + if (fetchedWeightsEntry != fetchedWeights.end()) { + weightList = std::cref(fetchedWeightsEntry->second); + } + } else { + auto cachedEntry = _neighborCache.find(key); + if (cachedEntry != _neighborCache.end()) { + neighborList = std::cref(cachedEntry->second); + auto cachedWeightsEntry = _weightCache.find(key); + if (cachedWeightsEntry != _weightCache.end()) { + weightList = std::cref(cachedWeightsEntry->second); } } } + if (!neighborList || neighborList->get().empty()) { + continue; + } + + int32_t dstNodeTypeId = _edgeTypeToDstNtypeId[edgeTypeId]; + auto& dstNodeTypeState = _state[seedIdx][dstNodeTypeId]; + + // c. Accumulate weight-proportional residual for each neighbor. + // weightList is always populated alongside neighborList in weighted mode: + // fetched[key] and fetchedWeights[key] are set together in Step 1, + // and _neighborCache[key] and _weightCache[key] are promoted together. + TORCH_INTERNAL_ASSERT(weightList.has_value(), + "weightList must be populated alongside neighborList in weighted mode"); + const auto& neighbors = neighborList->get(); + const auto& weights = weightList->get(); + TORCH_INTERNAL_ASSERT(weights.size() == neighbors.size(), + "weightList and neighborList must have the same size"); + for (int32_t i = 0; i < static_cast(neighbors.size()); ++i) { + int32_t neighborNodeId = neighbors[i]; + dstNodeTypeState.residuals[neighborNodeId] += + baseResidual * weights[i] / totalFetchedWeight; + + double threshold = _requeueThresholdFactor * + static_cast(getTotalDegree(neighborNodeId, dstNodeTypeId)); + + if (dstNodeTypeState.queue.find(neighborNodeId) == dstNodeTypeState.queue.end() && + dstNodeTypeState.residuals[neighborNodeId] >= threshold) { + dstNodeTypeState.queue.insert(neighborNodeId); + ++_numNodesInQueue; + promoteToCache(neighborNodeId, dstNodeTypeId); + } + } } } } diff --git a/gigl-core/core/sampling/ppr_forward_push.h b/gigl-core/core/sampling/ppr_forward_push.h index 1c1eef670..2ba811e1a 100644 --- a/gigl-core/core/sampling/ppr_forward_push.h +++ b/gigl-core/core/sampling/ppr_forward_push.h @@ -52,9 +52,12 @@ class PPRForwardPush { std::optional> drainQueue(); // Push residuals given fetched neighbor data. - // fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N])} + // fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N], flat_weights[sum(counts)])} + // flat_weights is empty (numel()==0) for uniform-residual mode; non-empty for + // weight-proportional mode. _hasWeights is latched true on the first call with a + // non-empty flat_weights and never reset within one PPRForwardPush lifetime. void pushResiduals(const std::unordered_map< - int32_t, std::tuple>& + int32_t, std::tuple>& fetchedByEtypeId); // Return top-k PPR nodes per seed per node type. @@ -103,6 +106,14 @@ class PPRForwardPush { // impractical (contrast with _state above). Populated incrementally; avoids re-fetching. std::unordered_map> _neighborCache; + // True once any pushResiduals call receives a non-empty flat_weights tensor. + // Latched true for the object lifetime; never reset. + bool _hasWeights{false}; + + // Per-edge weights parallel to _neighborCache: _weightCache[packKey(node, etype)][i] + // is the weight of the i-th cached neighbor. Only populated in weighted mode. + std::unordered_map> _weightCache; + }; } // namespace gigl diff --git a/gigl-core/core/sampling/python_ppr_forward_push.cpp b/gigl-core/core/sampling/python_ppr_forward_push.cpp index 22981a48a..4d3674a2f 100644 --- a/gigl-core/core/sampling/python_ppr_forward_push.cpp +++ b/gigl-core/core/sampling/python_ppr_forward_push.cpp @@ -20,15 +20,21 @@ namespace gigl { // pushResiduals: a wrapper is needed solely to release the GIL during the C++ push. // pybind11/stl.h handles all type conversions automatically; the other methods use // direct member function pointers for the same reason. +// +// Each tuple value is (node_ids, flat_nbrs, counts, flat_weights). flat_weights is +// an empty tensor in uniform-residual mode and a non-empty float64 tensor in +// weight-proportional mode. static void pushResidualsWrapper(PPRForwardPush& state, const py::dict& fetchedByEtypeId) { - std::unordered_map> neighborTensorsByEtypeId; + std::unordered_map> + neighborTensorsByEtypeId; // Dict iteration touches Python objects — GIL must be held here. for (auto item : fetchedByEtypeId) { auto edgeTypeId = item.first.cast(); auto neighborTensors = item.second.cast(); neighborTensorsByEtypeId[edgeTypeId] = {neighborTensors[0].cast(), neighborTensors[1].cast(), - neighborTensors[2].cast()}; + neighborTensors[2].cast(), + neighborTensors[3].cast()}; } // C++ push only uses tensor accessor/data_ptr APIs — GIL-safe to release. // Releasing here lets the asyncio event loop process RPC completion callbacks diff --git a/gigl-core/src/gigl_core/ppr_forward_push.pyi b/gigl-core/src/gigl_core/ppr_forward_push.pyi index 0c1ea79af..31b052dd0 100644 --- a/gigl-core/src/gigl_core/ppr_forward_push.pyi +++ b/gigl-core/src/gigl_core/ppr_forward_push.pyi @@ -14,7 +14,9 @@ class PPRForwardPush: def drain_queue(self) -> dict[int, torch.Tensor] | None: ... def push_residuals( self, - fetched_by_etype_id: dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + fetched_by_etype_id: dict[ + int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + ], ) -> None: ... def extract_top_k( self, max_ppr_nodes: int diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index d993d83ca..4bacf8478 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -348,7 +348,6 @@ def validate_for_weighted_sampling( Raises: ValueError: If ``with_weight=True`` but no edge weights are registered. - NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested. """ if not with_weight: return @@ -362,12 +361,6 @@ def validate_for_weighted_sampling( "with_weight=True requires edge weights to be registered in the dataset. " "Pass weight_edge_feat_name to build_dataset() to register edge weights." ) - # TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR. - if with_weight and isinstance(sampler_options, PPRSamplerOptions): - raise NotImplementedError( - "Weighted sampling is not yet supported with PPRSamplerOptions. " - "Weight-proportional residual propagation for PPR is planned but not implemented." - ) @staticmethod def create_sampling_config( diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 402e381c1..8e61170aa 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -48,6 +48,14 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler): scores are approximated here using the Forward Push algorithm (Andersen et al., 2006). + **Weighted PPR**: when ``edge_weights`` is provided, neighbor fetching during + traversal always uses uniform sampling (all neighbors are fetched without + weight-biased selection). The weighting is applied exclusively in how the PPR + residual is spread: each neighbor receives residual proportional to its edge + weight rather than an equal share. This is the correct formulation — using + weighted sampling during traversal would double-count high-weight edges (once + by over-representing them in the fetch and again by giving them more residual). + This sampler supports both homogeneous and heterogeneous graphs. For heterogeneous graphs, the PPR algorithm traverses across all edge types, switching edge types based on the current node type and the configured edge direction. @@ -90,6 +98,9 @@ def __init__( total_degree_dtype: torch.dtype = torch.int32, degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]], max_fetch_iterations: Optional[int] = None, + edge_weights: Optional[ + Union[torch.Tensor, dict[EdgeType, torch.Tensor]] + ] = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -98,6 +109,7 @@ def __init__( self._requeue_threshold_factor = alpha * eps self._num_neighbors_per_hop = num_neighbors_per_hop self._max_fetch_iterations = max_fetch_iterations + self._edge_weights = edge_weights # Build mapping from node type to edge types that can be traversed from that node type. self._node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict( @@ -251,7 +263,7 @@ async def _batch_fetch_neighbors( self, nodes_by_etype_id: dict[int, torch.Tensor], device: torch.device, - ) -> dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + ) -> dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """Batch fetch neighbors for nodes grouped by integer edge type ID. Issues one ``_sample_one_hop`` call per edge type (not per node), so all @@ -265,11 +277,13 @@ async def _batch_fetch_neighbors( device: Torch device for intermediate tensor creation. Returns: - Dict mapping etype_id to ``(node_ids, flat_neighbors, counts)`` as - int64 tensors, ready to pass directly to ``push_residuals``. + Dict mapping etype_id to ``(node_ids, flat_neighbors, counts, flat_weights)`` + as tensors, ready to pass directly to ``push_residuals``. ``flat_neighbors`` is the flat concatenation of all neighbor lists for that edge type; ``counts[i]`` is the neighbor count for - ``node_ids[i]``. + ``node_ids[i]``. ``flat_weights`` is a float64 tensor of the same + shape as ``flat_neighbors`` in weighted mode, or an empty tensor in + uniform mode. Example:: @@ -279,8 +293,8 @@ async def _batch_fetch_neighbors( } # Might return (neighbor lists depend on graph structure): { - 2: (tensor([0, 3]), tensor([5, 9, 2, 1]), tensor([3, 1])), - 5: (tensor([7]), tensor([0, 3]), tensor([2])), + 2: (tensor([0, 3]), tensor([5, 9, 2, 1]), tensor([3, 1]), tensor([])), + 5: (tensor([7]), tensor([0, 3]), tensor([2]), tensor([])), } """ # Fire all per-edge-type RPC calls concurrently. Each _sample_one_hop @@ -299,10 +313,36 @@ async def _batch_fetch_neighbors( ) ) outputs: list[NeighborOutput] = await asyncio.gather(*sample_tasks) - return { - eid: (nodes_by_etype_id[eid], output.nbr, output.nbr_num) - for eid, output in zip(eids, outputs) - } + + _empty_weights = torch.empty(0, dtype=torch.float64) + + result: dict[ + int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + ] = {} + for eid, output in zip(eids, outputs): + if self._edge_weights is not None: + assert output.edge is not None, ( + "output.edge must be set when edge_weights is provided; " + "ensure with_edge=True in SamplingConfig (hardcoded in create_sampling_config)." + ) + if self._is_homogeneous: + assert isinstance(self._edge_weights, torch.Tensor) + flat_weights = self._edge_weights[output.edge].to(torch.float64) + else: + assert isinstance(self._edge_weights, dict) + etype = self._etype_id_to_etype[eid] + flat_weights = self._edge_weights[etype][output.edge].to( + torch.float64 + ) + else: + flat_weights = _empty_weights + result[eid] = ( + nodes_by_etype_id[eid], + output.nbr, + output.nbr_num, + flat_weights, + ) + return result async def _compute_ppr_scores( self, diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py index 3a51715e2..13c0ff751 100644 --- a/gigl/distributed/dist_sampling_producer.py +++ b/gigl/distributed/dist_sampling_producer.py @@ -103,6 +103,7 @@ def _sampling_worker_loop( sampler_options=sampler_options, degree_tensors=degree_tensors, current_device=current_device, + edge_weights=data.edge_weights, ) dist_sampler.start_loop() diff --git a/gigl/distributed/graph_store/shared_dist_sampling_producer.py b/gigl/distributed/graph_store/shared_dist_sampling_producer.py index 0f7461196..a8a6d6c50 100644 --- a/gigl/distributed/graph_store/shared_dist_sampling_producer.py +++ b/gigl/distributed/graph_store/shared_dist_sampling_producer.py @@ -690,6 +690,7 @@ def _handle_command(command: SharedMpCommand, payload: CommandPayload) -> bool: sampler_options=sampler_options, degree_tensors=degree_tensors, current_device=current_device, + edge_weights=data.edge_weights, ) sampler.start_loop() with state_lock: diff --git a/gigl/distributed/utils/dist_sampler.py b/gigl/distributed/utils/dist_sampler.py index 0333f4138..f218bb8eb 100644 --- a/gigl/distributed/utils/dist_sampler.py +++ b/gigl/distributed/utils/dist_sampler.py @@ -37,6 +37,7 @@ def create_dist_sampler( sampler_options: SamplerOptions, degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], current_device: torch.device, + edge_weights: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None, ) -> SamplerRuntime: """Create a GiGL sampler runtime for one channel on one worker. @@ -49,6 +50,10 @@ def create_dist_sampler( degree_tensors: Pre-computed degree tensors required by PPR sampling. Must not be ``None`` when ``sampler_options`` is :class:`PPRSamplerOptions`. current_device: The device on which sampling will run. + edge_weights: Per-edge weight tensors for this rank's partition. Required when + ``sampler_options`` is :class:`PPRSamplerOptions` and + ``sampling_config.with_weight`` is ``True``. Ignored for k-hop sampling + (GLT handles weight-biased sampling internally via ``with_weight``). Returns: A configured sampler runtime, either :class:`DistNeighborSampler` or @@ -77,8 +82,12 @@ def create_dist_sampler( ) elif isinstance(sampler_options, PPRSamplerOptions): assert degree_tensors is not None + # PPR traversal must always use uniform neighbor sampling: biased selection + # would double-count high-weight edges (once in the fetch, once in residual + # distribution). Weight influence is captured entirely in the residual math. + # Pass edge_weights only when with_weight=True; None disables weighted residuals. sampler = DistPPRNeighborSampler( - **shared_sampler_kwargs, + **{**shared_sampler_kwargs, "with_weight": False}, alpha=sampler_options.alpha, eps=sampler_options.eps, max_ppr_nodes=sampler_options.max_ppr_nodes, @@ -86,6 +95,7 @@ def create_dist_sampler( num_neighbors_per_hop=sampler_options.num_neighbors_per_hop, total_degree_dtype=sampler_options.total_degree_dtype, degree_tensors=degree_tensors, + edge_weights=edge_weights if sampling_config.with_weight else None, ) else: raise NotImplementedError( diff --git a/tests/test_assets/distributed/bipartite_weight_graph.py b/tests/test_assets/distributed/bipartite_weight_graph.py new file mode 100644 index 000000000..acf37ba85 --- /dev/null +++ b/tests/test_assets/distributed/bipartite_weight_graph.py @@ -0,0 +1,188 @@ +"""Shared bipartite graph builders for weighted-sampling tests. + +Used by both distributed_weighted_sampling_test and +distributed_ppr_weighted_sampling_test. Each builder returns a single-rank +PartitionOutput that encodes node type as a feature (hub/user=2.0, good=1.0, +bad=0.0) so tests can assert that weight=0 edges never appear in any sampled +subgraph. +""" + +import torch + +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + FeaturePartitionData, + GraphPartitionData, + PartitionOutput, +) + +USER = NodeType("user") +ITEM = NodeType("item") +USER_TO_ITEM = EdgeType(USER, Relation("to"), ITEM) +ITEM_TO_USER = EdgeType(ITEM, Relation("to"), USER) + + +def build_homogeneous_bipartite_weight_graph() -> tuple[PartitionOutput, int]: + """Build a homogeneous graph with hub, good, and bad nodes. + + Graph structure: + - 10 hub nodes (0..9): used as seed nodes; feature value = 2.0 + - 50 good nodes (10..59): reachable from hubs via weight=1 edges; feature = 1.0 + - 40 bad nodes (60..99): reachable from hubs via weight=0 edges; feature = 0.0 + - Each good node also has 5 outgoing weight=1 edges to nearby good nodes + (ring topology, for 2nd-hop sampling). + + With weighted sampling only good nodes should ever appear as sampled + neighbors — weight=0 edges to bad nodes must never be traversed. + + Returns: + (partition_output, n_hub) + """ + n_hub = 10 + n_good = 50 + n_bad = 40 + n = n_hub + n_good + n_bad # 100 + + hub_ids = torch.arange(n_hub) + good_ids = torch.arange(n_hub, n_hub + n_good) + bad_ids = torch.arange(n_hub + n_good, n) + + # Hub → Good: weight=1 + hub_good_src = hub_ids.repeat_interleave(n_good) + hub_good_dst = good_ids.repeat(n_hub) + hub_good_w = torch.ones(n_hub * n_good) + + # Hub → Bad: weight=0 + hub_bad_src = hub_ids.repeat_interleave(n_bad) + hub_bad_dst = bad_ids.repeat(n_hub) + hub_bad_w = torch.zeros(n_hub * n_bad) + + # Good → Good: ring with 5 outgoing edges per node, weight=1 (2nd-hop targets) + connections_per_good = 5 + good_src = good_ids.repeat_interleave(connections_per_good) + # Row i of [connections_per_good, n_good].T gives neighbors of good_ids[i] + good_dst = torch.stack( + [torch.roll(good_ids, -j) for j in range(1, connections_per_good + 1)] + ).T.reshape(-1) + good_w = torch.ones(n_good * connections_per_good) + + edge_src = torch.cat([hub_good_src, hub_bad_src, good_src]) + edge_dst = torch.cat([hub_good_dst, hub_bad_dst, good_dst]) + weights = torch.cat([hub_good_w, hub_bad_w, good_w]) + edge_index = torch.stack([edge_src, edge_dst]) + n_edges = edge_src.shape[0] + + # Feature encodes node type: hub=2.0, good=1.0, bad=0.0 + node_feats = torch.cat( + [ + torch.full((n_hub, 1), 2.0), + torch.full((n_good, 1), 1.0), + torch.full((n_bad, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book=torch.zeros(n), + edge_partition_book=torch.zeros(n_edges), + partitioned_edge_index=GraphPartitionData( + edge_index=edge_index, + edge_ids=None, + weights=weights, + ), + partitioned_node_features=FeaturePartitionData( + feats=node_feats, + ids=torch.arange(n), + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_hub + + +def build_heterogeneous_bipartite_weight_graph() -> tuple[PartitionOutput, int]: + """Build a heterogeneous (user/item) graph with good and bad item nodes. + + Graph structure: + - 10 user nodes (0..9): seed nodes; user feature = 2.0 + - 60 item nodes total: + - Items 0..39: good, reachable from users via weight=1 edges; feature = 1.0 + - Items 40..59: bad, reachable from users via weight=0 edges; feature = 0.0 + - Good items also have weight=1 edges back to all users (for 2nd-hop). + + With weighted sampling only good item nodes should ever appear as sampled + item neighbors. + + Returns: + (partition_output, n_user) + """ + n_user = 10 + n_good_item = 40 + n_bad_item = 20 + n_item = n_good_item + n_bad_item # 60 + + user_ids = torch.arange(n_user) + good_item_ids = torch.arange(n_good_item) + bad_item_ids = torch.arange(n_good_item, n_item) + + # User → Good Item: weight=1 + u2gi_src = user_ids.repeat_interleave(n_good_item) + u2gi_dst = good_item_ids.repeat(n_user) + u2gi_w = torch.ones(n_user * n_good_item) + + # User → Bad Item: weight=0 + u2bi_src = user_ids.repeat_interleave(n_bad_item) + u2bi_dst = bad_item_ids.repeat(n_user) + u2bi_w = torch.zeros(n_user * n_bad_item) + + # Good Item → User: weight=1 (2nd-hop back to users) + gi2u_src = good_item_ids.repeat_interleave(n_user) + gi2u_dst = user_ids.repeat(n_good_item) + gi2u_w = torch.ones(n_good_item * n_user) + + u2i_src = torch.cat([u2gi_src, u2bi_src]) + u2i_dst = torch.cat([u2gi_dst, u2bi_dst]) + u2i_w = torch.cat([u2gi_w, u2bi_w]) + n_u2i_edges = u2i_src.shape[0] + + user_feats = torch.full((n_user, 1), 2.0) + # Item feature encodes type: good=1.0, bad=0.0 + item_feats = torch.cat( + [ + torch.full((n_good_item, 1), 1.0), + torch.full((n_bad_item, 1), 0.0), + ] + ) + + partition_output = PartitionOutput( + node_partition_book={ + USER: torch.zeros(n_user), + ITEM: torch.zeros(n_item), + }, + edge_partition_book={ + USER_TO_ITEM: torch.zeros(n_u2i_edges), + ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), + }, + partitioned_edge_index={ + USER_TO_ITEM: GraphPartitionData( + edge_index=torch.stack([u2i_src, u2i_dst]), + edge_ids=None, + weights=u2i_w, + ), + ITEM_TO_USER: GraphPartitionData( + edge_index=torch.stack([gi2u_src, gi2u_dst]), + edge_ids=None, + weights=gi2u_w, + ), + }, + partitioned_node_features={ + USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), + ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + return partition_output, n_user diff --git a/tests/unit/distributed/distributed_ppr_weighted_sampling_test.py b/tests/unit/distributed/distributed_ppr_weighted_sampling_test.py new file mode 100644 index 000000000..8f666ccb7 --- /dev/null +++ b/tests/unit/distributed/distributed_ppr_weighted_sampling_test.py @@ -0,0 +1,190 @@ +"""End-to-end correctness tests for PPR weighted sampling. + +Verifies that DistNeighborLoader with PPRSamplerOptions and with_weight=True +never traverses weight=0 edges. The test graph encodes node type in features +(hub=2.0, good=1.0, bad=0.0); any bad node in a sampled subgraph indicates +that a weight=0 edge contributed PPR residual — a test failure. + +With weight-proportional residual distribution, a weight=0 edge contributes +zero to totalFetchedWeight and receives zero residual per push step. Bad +nodes therefore accumulate a PPR score of exactly 0 and are excluded from +every top-k result. +""" + +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from torch_geometric.data import Data, HeteroData + +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.sampler_options import PPRSamplerOptions +from tests.test_assets.distributed.bipartite_weight_graph import ( + ITEM, + USER, + build_heterogeneous_bipartite_weight_graph, + build_homogeneous_bipartite_weight_graph, +) +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +# PPR parameters used across all tests. +_PPR_ALPHA = 0.5 +_PPR_EPS = 1e-4 +_PPR_MAX_NODES = 60 +_PPR_NUM_NBRS = 200 + + +# --------------------------------------------------------------------------- +# Subprocess functions +# --------------------------------------------------------------------------- + + +def _run_ppr_weighted_correctness_homogeneous( + _: int, + dataset: DistDataset, + n_hub: int, +) -> None: + """Subprocess: verifies weight=0 edges never contribute PPR residual (homogeneous). + + Seeds are hub nodes only. Node features encode type: hub=2.0, good=1.0, bad=0.0. + Any batch containing a bad node (feature==0.0) means a weight=0 edge contributed + PPR residual — a test failure. + """ + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=torch.arange(n_hub), + num_neighbors=[], + sampler_options=PPRSamplerOptions( + alpha=_PPR_ALPHA, + eps=_PPR_EPS, + max_ppr_nodes=_PPR_MAX_NODES, + num_neighbors_per_hop=_PPR_NUM_NBRS, + ), + with_weight=True, + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data), f"Expected Data, got {type(datum)}" + assert datum.x is not None, "Node features missing from PPR batch" + bad_mask = datum.x[:, 0] == 0.0 + assert not bad_mask.any(), ( + f"weight=0 edge contributed PPR residual: bad node(s) found. " + f"Features of bad nodes: {datum.x[bad_mask].squeeze().tolist()}" + ) + count += 1 + assert count == n_hub, f"Expected {n_hub} batches, got {count}" + shutdown_rpc() + + +def _run_ppr_weighted_correctness_heterogeneous( + _: int, + dataset: DistDataset, + n_user: int, +) -> None: + """Subprocess: verifies weight=0 edges never contribute PPR residual (heterogeneous). + + Seeds are user nodes. Item features encode type: good=1.0, bad=0.0. + Any batch containing a bad item node means a weight=0 edge contributed PPR residual. + """ + create_test_process_group() + node_ids = dataset.node_ids + assert not isinstance(node_ids, torch.Tensor) and node_ids is not None, ( + "Expected heterogeneous dataset with dict node_ids" + ) + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(USER, node_ids[USER]), + num_neighbors=[], + sampler_options=PPRSamplerOptions( + alpha=_PPR_ALPHA, + eps=_PPR_EPS, + max_ppr_nodes=_PPR_MAX_NODES, + num_neighbors_per_hop=_PPR_NUM_NBRS, + ), + with_weight=True, + pin_memory_device=torch.device("cpu"), + batch_size=1, + ) + count = 0 + for datum in loader: + assert isinstance(datum, HeteroData), f"Expected HeteroData, got {type(datum)}" + if ITEM in datum.node_types: + item_x = datum[ITEM].x + assert item_x is not None, "Item features missing from PPR batch" + bad_mask = item_x[:, 0] == 0.0 + assert not bad_mask.any(), ( + f"weight=0 edge contributed PPR residual: bad item(s) found. " + f"Features of bad items: {item_x[bad_mask].squeeze().tolist()}" + ) + count += 1 + assert count == n_user, f"Expected {n_user} batches, got {count}" + shutdown_rpc() + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + + +class PPRWeightedSamplingTest(TestCase): + """End-to-end correctness tests for PPR sampling with weight_proportional_residuals. + + Each test builds a bipartite graph with "good" neighbors (weight=1) and "bad" + neighbors (weight=0) reachable from seed nodes. With weight-proportional PPR, + bad nodes must never appear in any sampled subgraph because weight=0 edges + contribute zero residual per push step. + """ + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + super().tearDown() + + def test_ppr_weighted_never_traverses_zero_weight_edges_homogeneous(self) -> None: + """Homogeneous: weight=0 edges to bad nodes never contribute PPR residual. + + Graph: 10 hub seeds, each connected to 50 good nodes (weight=1) and 40 bad + nodes (weight=0). Good nodes have 5 further weight=1 edges for deeper walks. + PPR max_ppr_nodes=60 is larger than the number of good neighbors, so the + sampler must actively filter: correct weighting excludes bad nodes entirely. + """ + partition_output, n_hub = build_homogeneous_bipartite_weight_graph() + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + self.assertTrue(dataset.has_edge_weights) + + mp.spawn( + fn=_run_ppr_weighted_correctness_homogeneous, + args=(dataset, n_hub), + nprocs=1, + ) + + def test_ppr_weighted_never_traverses_zero_weight_edges_heterogeneous(self) -> None: + """Heterogeneous: weight=0 user→item edges to bad items never contribute PPR residual. + + Graph: 10 user seeds, each connected to 40 good items (weight=1) and 20 bad + items (weight=0). Good items connect back to all users via weight=1 (2nd-hop). + PPR max_ppr_nodes=60 is larger than n_good, so correct weighting is required + to exclude bad items. + """ + partition_output, n_user = build_heterogeneous_bipartite_weight_graph() + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + self.assertTrue(dataset.has_edge_weights) + + mp.spawn( + fn=_run_ppr_weighted_correctness_heterogeneous, + args=(dataset, n_user), + nprocs=1, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/distributed/distributed_weighted_sampling_test.py b/tests/unit/distributed/distributed_weighted_sampling_test.py index 659c23f5e..eff598a80 100644 --- a/tests/unit/distributed/distributed_weighted_sampling_test.py +++ b/tests/unit/distributed/distributed_weighted_sampling_test.py @@ -22,12 +22,19 @@ from gigl.distributed.dist_range_partitioner import DistRangePartitioner from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.utils.networking import get_free_port -from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.types.graph import ( FeaturePartitionData, GraphPartitionData, PartitionOutput, ) +from tests.test_assets.distributed.bipartite_weight_graph import ( + ITEM, + ITEM_TO_USER, + USER, + USER_TO_ITEM, + build_heterogeneous_bipartite_weight_graph, + build_homogeneous_bipartite_weight_graph, +) from tests.test_assets.distributed.constants import ( MOCKED_NUM_PARTITIONS, MOCKED_U2U_EDGE_INDEX_ON_RANK_ONE, @@ -43,189 +50,13 @@ from tests.test_assets.distributed.utils import create_test_process_group from tests.test_assets.test_case import TestCase -_USER = NodeType("user") -_ITEM = NodeType("item") -_USER_TO_ITEM = EdgeType(_USER, Relation("to"), _ITEM) -_ITEM_TO_USER = EdgeType(_ITEM, Relation("to"), _USER) - - # --------------------------------------------------------------------------- # Graph builders # --------------------------------------------------------------------------- -def _build_homogeneous_bipartite_weight_graph() -> tuple[ - PartitionOutput, int, int, int -]: - """Build a homogeneous graph with hub, good, and bad nodes. - - Graph structure: - - 10 hub nodes (0..9): used as seed nodes; feature value = 2.0 - - 50 good nodes (10..59): reachable from hubs via weight=1 edges; feature = 1.0 - - 40 bad nodes (60..99): reachable from hubs via weight=0 edges; feature = 0.0 - - Each good node also has 5 outgoing weight=1 edges to nearby good nodes - (ring topology, for 2nd-hop sampling). - - With weighted sampling only good nodes should ever appear as sampled - neighbors — weight=0 edges to bad nodes must never be traversed. - - Returns: - (partition_output, n_hub, n_good, n_bad) - """ - n_hub = 10 - n_good = 50 - n_bad = 40 - n = n_hub + n_good + n_bad # 100 - - hub_ids = torch.arange(n_hub) - good_ids = torch.arange(n_hub, n_hub + n_good) - bad_ids = torch.arange(n_hub + n_good, n) - - # Hub → Good: weight=1 - hub_good_src = hub_ids.repeat_interleave(n_good) - hub_good_dst = good_ids.repeat(n_hub) - hub_good_w = torch.ones(n_hub * n_good) - - # Hub → Bad: weight=0 - hub_bad_src = hub_ids.repeat_interleave(n_bad) - hub_bad_dst = bad_ids.repeat(n_hub) - hub_bad_w = torch.zeros(n_hub * n_bad) - - # Good → Good: ring with 5 outgoing edges per node, weight=1 (2nd-hop targets) - connections_per_good = 5 - good_src = good_ids.repeat_interleave(connections_per_good) - # Row i of [connections_per_good, n_good].T gives neighbors of good_ids[i] - good_dst = torch.stack( - [torch.roll(good_ids, -j) for j in range(1, connections_per_good + 1)] - ).T.reshape(-1) - good_w = torch.ones(n_good * connections_per_good) - - edge_src = torch.cat([hub_good_src, hub_bad_src, good_src]) - edge_dst = torch.cat([hub_good_dst, hub_bad_dst, good_dst]) - weights = torch.cat([hub_good_w, hub_bad_w, good_w]) - edge_index = torch.stack([edge_src, edge_dst]) - n_edges = edge_src.shape[0] - - # Feature encodes node type: hub=2.0, good=1.0, bad=0.0 - node_feats = torch.cat( - [ - torch.full((n_hub, 1), 2.0), - torch.full((n_good, 1), 1.0), - torch.full((n_bad, 1), 0.0), - ] - ) - - partition_output = PartitionOutput( - node_partition_book=torch.zeros(n), - edge_partition_book=torch.zeros(n_edges), - partitioned_edge_index=GraphPartitionData( - edge_index=edge_index, - edge_ids=None, - weights=weights, - ), - partitioned_node_features=FeaturePartitionData( - feats=node_feats, - ids=torch.arange(n), - ), - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=None, - ) - return partition_output, n_hub, n_good, n_bad - - -def _build_heterogeneous_bipartite_weight_graph() -> tuple[ - PartitionOutput, int, int, int -]: - """Build a heterogeneous (user/item) graph with good and bad item nodes. - - Graph structure: - - 10 user nodes (0..9): seed nodes; user feature = 2.0 - - 60 item nodes total: - - Items 0..39: good, reachable from users via weight=1 edges; feature = 1.0 - - Items 40..59: bad, reachable from users via weight=0 edges; feature = 0.0 - - Good items also have weight=1 edges back to all users (for 2nd-hop). - - With weighted sampling only good item nodes should ever appear as sampled - item neighbors. - - Returns: - (partition_output, n_user, n_good_item, n_bad_item) - """ - n_user = 10 - n_good_item = 40 - n_bad_item = 20 - n_item = n_good_item + n_bad_item # 60 - - user_ids = torch.arange(n_user) - good_item_ids = torch.arange(n_good_item) - bad_item_ids = torch.arange(n_good_item, n_item) - - # User → Good Item: weight=1 - u2gi_src = user_ids.repeat_interleave(n_good_item) - u2gi_dst = good_item_ids.repeat(n_user) - u2gi_w = torch.ones(n_user * n_good_item) - - # User → Bad Item: weight=0 - u2bi_src = user_ids.repeat_interleave(n_bad_item) - u2bi_dst = bad_item_ids.repeat(n_user) - u2bi_w = torch.zeros(n_user * n_bad_item) - - # Good Item → User: weight=1 (2nd-hop back to users) - gi2u_src = good_item_ids.repeat_interleave(n_user) - gi2u_dst = user_ids.repeat(n_good_item) - gi2u_w = torch.ones(n_good_item * n_user) - - u2i_src = torch.cat([u2gi_src, u2bi_src]) - u2i_dst = torch.cat([u2gi_dst, u2bi_dst]) - u2i_w = torch.cat([u2gi_w, u2bi_w]) - n_u2i_edges = u2i_src.shape[0] - - user_feats = torch.full((n_user, 1), 2.0) - # Item feature encodes type: good=1.0, bad=0.0 - item_feats = torch.cat( - [ - torch.full((n_good_item, 1), 1.0), - torch.full((n_bad_item, 1), 0.0), - ] - ) - - partition_output = PartitionOutput( - node_partition_book={ - _USER: torch.zeros(n_user), - _ITEM: torch.zeros(n_item), - }, - edge_partition_book={ - _USER_TO_ITEM: torch.zeros(n_u2i_edges), - _ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), - }, - partitioned_edge_index={ - _USER_TO_ITEM: GraphPartitionData( - edge_index=torch.stack([u2i_src, u2i_dst]), - edge_ids=None, - weights=u2i_w, - ), - _ITEM_TO_USER: GraphPartitionData( - edge_index=torch.stack([gi2u_src, gi2u_dst]), - edge_ids=None, - weights=gi2u_w, - ), - }, - partitioned_node_features={ - _USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), - _ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), - }, - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=None, - ) - return partition_output, n_user, n_good_item, n_bad_item - - def _build_heterogeneous_bipartite_partial_weight_graph() -> tuple[ - PartitionOutput, int, int, int + PartitionOutput, int ]: """Same graph as _build_heterogeneous_bipartite_weight_graph but ITEM_TO_USER is unweighted. @@ -268,35 +99,35 @@ def _build_heterogeneous_bipartite_partial_weight_graph() -> tuple[ partition_output = PartitionOutput( node_partition_book={ - _USER: torch.zeros(n_user), - _ITEM: torch.zeros(n_item), + USER: torch.zeros(n_user), + ITEM: torch.zeros(n_item), }, edge_partition_book={ - _USER_TO_ITEM: torch.zeros(n_u2i_edges), - _ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), + USER_TO_ITEM: torch.zeros(n_u2i_edges), + ITEM_TO_USER: torch.zeros(gi2u_src.shape[0]), }, partitioned_edge_index={ - _USER_TO_ITEM: GraphPartitionData( + USER_TO_ITEM: GraphPartitionData( edge_index=torch.stack([u2i_src, u2i_dst]), edge_ids=None, weights=u2i_w, ), - _ITEM_TO_USER: GraphPartitionData( + ITEM_TO_USER: GraphPartitionData( edge_index=torch.stack([gi2u_src, gi2u_dst]), edge_ids=None, weights=None, # unweighted — samples uniformly ), }, partitioned_node_features={ - _USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), - _ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), + USER: FeaturePartitionData(feats=user_feats, ids=torch.arange(n_user)), + ITEM: FeaturePartitionData(feats=item_feats, ids=torch.arange(n_item)), }, partitioned_edge_features=None, partitioned_positive_labels=None, partitioned_negative_labels=None, partitioned_node_labels=None, ) - return partition_output, n_user, n_good_item, n_bad_item + return partition_output, n_user # --------------------------------------------------------------------------- @@ -355,7 +186,7 @@ def _run_weighted_sampling_correctness_heterogeneous( ) loader = DistNeighborLoader( dataset=dataset, - input_nodes=(_USER, node_ids[_USER]), + input_nodes=(USER, node_ids[USER]), num_neighbors=[10, 5], with_weight=True, pin_memory_device=torch.device("cpu"), @@ -363,8 +194,8 @@ def _run_weighted_sampling_correctness_heterogeneous( count = 0 for datum in loader: assert isinstance(datum, HeteroData), f"Expected HeteroData, got {type(datum)}" - if _ITEM in datum.node_types: - item_x = datum[_ITEM].x + if ITEM in datum.node_types: + item_x = datum[ITEM].x assert item_x is not None, "Item features missing from sampled subgraph" bad_mask = item_x[:, 0] == 0.0 assert not bad_mask.any(), ( @@ -802,7 +633,7 @@ def test_weighted_sampling_never_traverses_zero_weight_edges_homogeneous( sampling. Fanout [10, 5] samples fewer neighbors than available good ones, so the weighted sampler actively selects from the pool each hop. """ - partition_output, n_hub, _, _ = _build_homogeneous_bipartite_weight_graph() + partition_output, n_hub = build_homogeneous_bipartite_weight_graph() assert isinstance(partition_output.partitioned_edge_index, GraphPartitionData) expected_weights = partition_output.partitioned_edge_index.weights @@ -829,7 +660,7 @@ def test_weighted_sampling_never_traverses_zero_weight_edges_heterogeneous( Fanout [10, 5] is smaller than the 40 available good items, so the sampler actively selects. """ - partition_output, n_user, _, _ = _build_heterogeneous_bipartite_weight_graph() + partition_output, n_user = build_heterogeneous_bipartite_weight_graph() dataset = DistDataset(rank=0, world_size=1, edge_dir="out") dataset.build(partition_output=partition_output) @@ -846,9 +677,7 @@ def test_weighted_sampling_partial_weights_heterogeneous(self) -> None: Verifies that mixing weighted and unweighted edge types in one heterogeneous graph does not crash and that weighted edges still behave correctly. """ - partition_output, n_user, _, _ = ( - _build_heterogeneous_bipartite_partial_weight_graph() - ) + partition_output, n_user = _build_heterogeneous_bipartite_partial_weight_graph() dataset = DistDataset(rank=0, world_size=1, edge_dir="out") dataset.build(partition_output=partition_output)