Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 201 additions & 76 deletions gigl-core/core/sampling/ppr_forward_push.cpp

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions gigl-core/core/sampling/ppr_forward_push.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ class PPRForwardPush {
std::optional<std::unordered_map<int32_t, torch::Tensor>> drainQueue();

// Push residuals given fetched neighbor data.
// fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N])}
// fetchedByEtypeId: {etype_id: (node_ids[N], flat_nbrs[sum(counts)], counts[N], flat_weights[sum(counts)])}
// flat_weights is empty (numel()==0) for uniform-residual mode; non-empty for
// weight-proportional mode. _hasWeights is latched true on the first call with a
// non-empty flat_weights and never reset within one PPRForwardPush lifetime.
void pushResiduals(const std::unordered_map<
int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>&
int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>&
fetchedByEtypeId);

// Return top-k PPR nodes per seed per node type.
Expand Down Expand Up @@ -103,6 +106,14 @@ class PPRForwardPush {
// impractical (contrast with _state above). Populated incrementally; avoids re-fetching.
std::unordered_map<uint64_t, std::vector<int32_t>> _neighborCache;

// True once any pushResiduals call receives a non-empty flat_weights tensor.
// Latched true for the object lifetime; never reset.
bool _hasWeights{false};

// Per-edge weights parallel to _neighborCache: _weightCache[packKey(node, etype)][i]
// is the weight of the i-th cached neighbor. Only populated in weighted mode.
std::unordered_map<uint64_t, std::vector<double>> _weightCache;

};

} // namespace gigl
10 changes: 8 additions & 2 deletions gigl-core/core/sampling/python_ppr_forward_push.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,21 @@ namespace gigl {
// pushResiduals: a wrapper is needed solely to release the GIL during the C++ push.
// pybind11/stl.h handles all type conversions automatically; the other methods use
// direct member function pointers for the same reason.
//
// Each tuple value is (node_ids, flat_nbrs, counts, flat_weights). flat_weights is
// an empty tensor in uniform-residual mode and a non-empty float64 tensor in
// weight-proportional mode.
static void pushResidualsWrapper(PPRForwardPush& state, const py::dict& fetchedByEtypeId) {
std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> neighborTensorsByEtypeId;
std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
neighborTensorsByEtypeId;
// Dict iteration touches Python objects — GIL must be held here.
for (auto item : fetchedByEtypeId) {
auto edgeTypeId = item.first.cast<int32_t>();
auto neighborTensors = item.second.cast<py::tuple>();
neighborTensorsByEtypeId[edgeTypeId] = {neighborTensors[0].cast<torch::Tensor>(),
neighborTensors[1].cast<torch::Tensor>(),
neighborTensors[2].cast<torch::Tensor>()};
neighborTensors[2].cast<torch::Tensor>(),
neighborTensors[3].cast<torch::Tensor>()};
}
// C++ push only uses tensor accessor/data_ptr APIs — GIL-safe to release.
// Releasing here lets the asyncio event loop process RPC completion callbacks
Expand Down
4 changes: 3 additions & 1 deletion gigl-core/src/gigl_core/ppr_forward_push.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class PPRForwardPush:
def drain_queue(self) -> dict[int, torch.Tensor] | None: ...
def push_residuals(
self,
fetched_by_etype_id: dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
fetched_by_etype_id: dict[
int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
],
) -> None: ...
def extract_top_k(
self, max_ppr_nodes: int
Expand Down
7 changes: 0 additions & 7 deletions gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,6 @@ def validate_for_weighted_sampling(

Raises:
ValueError: If ``with_weight=True`` but no edge weights are registered.
NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested.
"""
if not with_weight:
return
Expand All @@ -362,12 +361,6 @@ def validate_for_weighted_sampling(
"with_weight=True requires edge weights to be registered in the dataset. "
"Pass weight_edge_feat_name to build_dataset() to register edge weights."
)
# TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR.
if with_weight and isinstance(sampler_options, PPRSamplerOptions):
raise NotImplementedError(
"Weighted sampling is not yet supported with PPRSamplerOptions. "
"Weight-proportional residual propagation for PPR is planned but not implemented."
)

@staticmethod
def create_sampling_config(
Expand Down
60 changes: 50 additions & 10 deletions gigl/distributed/dist_ppr_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ class DistPPRNeighborSampler(BaseDistNeighborSampler):
scores are approximated here using the Forward Push algorithm (Andersen et
al., 2006).

**Weighted PPR**: when ``edge_weights`` is provided, neighbor fetching during
traversal always uses uniform sampling (all neighbors are fetched without
weight-biased selection). The weighting is applied exclusively in how the PPR
residual is spread: each neighbor receives residual proportional to its edge
weight rather than an equal share. This is the correct formulation — using
weighted sampling during traversal would double-count high-weight edges (once
by over-representing them in the fetch and again by giving them more residual).

This sampler supports both homogeneous and heterogeneous graphs. For heterogeneous graphs,
the PPR algorithm traverses across all edge types, switching edge types based on the
current node type and the configured edge direction.
Expand Down Expand Up @@ -90,6 +98,9 @@ def __init__(
total_degree_dtype: torch.dtype = torch.int32,
degree_tensors: Union[torch.Tensor, dict[EdgeType, torch.Tensor]],
max_fetch_iterations: Optional[int] = None,
edge_weights: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
Expand All @@ -98,6 +109,7 @@ def __init__(
self._requeue_threshold_factor = alpha * eps
self._num_neighbors_per_hop = num_neighbors_per_hop
self._max_fetch_iterations = max_fetch_iterations
self._edge_weights = edge_weights

# Build mapping from node type to edge types that can be traversed from that node type.
self._node_type_to_edge_types: dict[NodeType, list[EdgeType]] = defaultdict(
Expand Down Expand Up @@ -251,7 +263,7 @@ async def _batch_fetch_neighbors(
self,
nodes_by_etype_id: dict[int, torch.Tensor],
device: torch.device,
) -> dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
) -> dict[int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Batch fetch neighbors for nodes grouped by integer edge type ID.

Issues one ``_sample_one_hop`` call per edge type (not per node), so all
Expand All @@ -265,11 +277,13 @@ async def _batch_fetch_neighbors(
device: Torch device for intermediate tensor creation.

Returns:
Dict mapping etype_id to ``(node_ids, flat_neighbors, counts)`` as
int64 tensors, ready to pass directly to ``push_residuals``.
Dict mapping etype_id to ``(node_ids, flat_neighbors, counts, flat_weights)``
as tensors, ready to pass directly to ``push_residuals``.
``flat_neighbors`` is the flat concatenation of all neighbor lists
for that edge type; ``counts[i]`` is the neighbor count for
``node_ids[i]``.
``node_ids[i]``. ``flat_weights`` is a float64 tensor of the same
shape as ``flat_neighbors`` in weighted mode, or an empty tensor in
uniform mode.

Example::

Expand All @@ -279,8 +293,8 @@ async def _batch_fetch_neighbors(
}
# Might return (neighbor lists depend on graph structure):
{
2: (tensor([0, 3]), tensor([5, 9, 2, 1]), tensor([3, 1])),
5: (tensor([7]), tensor([0, 3]), tensor([2])),
2: (tensor([0, 3]), tensor([5, 9, 2, 1]), tensor([3, 1]), tensor([])),
5: (tensor([7]), tensor([0, 3]), tensor([2]), tensor([])),
}
"""
# Fire all per-edge-type RPC calls concurrently. Each _sample_one_hop
Expand All @@ -299,10 +313,36 @@ async def _batch_fetch_neighbors(
)
)
outputs: list[NeighborOutput] = await asyncio.gather(*sample_tasks)
return {
eid: (nodes_by_etype_id[eid], output.nbr, output.nbr_num)
for eid, output in zip(eids, outputs)
}

_empty_weights = torch.empty(0, dtype=torch.float64)

result: dict[
int, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
] = {}
for eid, output in zip(eids, outputs):
if self._edge_weights is not None:
assert output.edge is not None, (
"output.edge must be set when edge_weights is provided; "
"ensure with_edge=True in SamplingConfig (hardcoded in create_sampling_config)."
)
if self._is_homogeneous:
assert isinstance(self._edge_weights, torch.Tensor)
flat_weights = self._edge_weights[output.edge].to(torch.float64)
else:
assert isinstance(self._edge_weights, dict)
etype = self._etype_id_to_etype[eid]
flat_weights = self._edge_weights[etype][output.edge].to(
torch.float64
)
else:
flat_weights = _empty_weights
result[eid] = (
nodes_by_etype_id[eid],
output.nbr,
output.nbr_num,
flat_weights,
)
return result

async def _compute_ppr_scores(
self,
Expand Down
1 change: 1 addition & 0 deletions gigl/distributed/dist_sampling_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _sampling_worker_loop(
sampler_options=sampler_options,
degree_tensors=degree_tensors,
current_device=current_device,
edge_weights=data.edge_weights,
)
dist_sampler.start_loop()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def _handle_command(command: SharedMpCommand, payload: CommandPayload) -> bool:
sampler_options=sampler_options,
degree_tensors=degree_tensors,
current_device=current_device,
edge_weights=data.edge_weights,
)
sampler.start_loop()
with state_lock:
Expand Down
12 changes: 11 additions & 1 deletion gigl/distributed/utils/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def create_dist_sampler(
sampler_options: SamplerOptions,
degree_tensors: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
current_device: torch.device,
edge_weights: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]] = None,
) -> SamplerRuntime:
"""Create a GiGL sampler runtime for one channel on one worker.

Expand All @@ -49,6 +50,10 @@ def create_dist_sampler(
degree_tensors: Pre-computed degree tensors required by PPR sampling.
Must not be ``None`` when ``sampler_options`` is :class:`PPRSamplerOptions`.
current_device: The device on which sampling will run.
edge_weights: Per-edge weight tensors for this rank's partition. Required when
``sampler_options`` is :class:`PPRSamplerOptions` and
``sampling_config.with_weight`` is ``True``. Ignored for k-hop sampling
(GLT handles weight-biased sampling internally via ``with_weight``).

Returns:
A configured sampler runtime, either :class:`DistNeighborSampler` or
Expand Down Expand Up @@ -77,15 +82,20 @@ def create_dist_sampler(
)
elif isinstance(sampler_options, PPRSamplerOptions):
assert degree_tensors is not None
# PPR traversal must always use uniform neighbor sampling: biased selection
# would double-count high-weight edges (once in the fetch, once in residual
# distribution). Weight influence is captured entirely in the residual math.
# Pass edge_weights only when with_weight=True; None disables weighted residuals.
sampler = DistPPRNeighborSampler(
**shared_sampler_kwargs,
**{**shared_sampler_kwargs, "with_weight": False},
alpha=sampler_options.alpha,
eps=sampler_options.eps,
max_ppr_nodes=sampler_options.max_ppr_nodes,
max_fetch_iterations=sampler_options.max_fetch_iterations,
num_neighbors_per_hop=sampler_options.num_neighbors_per_hop,
total_degree_dtype=sampler_options.total_degree_dtype,
degree_tensors=degree_tensors,
edge_weights=edge_weights if sampling_config.with_weight else None,
)
else:
raise NotImplementedError(
Expand Down
Loading