Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/link_prediction/heterogeneous_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def _inference_process(
node_type_to_input_node_ids: Optional[
Union[torch.Tensor, dict[NodeType, torch.Tensor]]
] = args.dataset.node_ids
assert isinstance(node_type_to_input_node_ids, dict), (
assert node_type_to_input_node_ids is not None and not isinstance(
node_type_to_input_node_ids, torch.Tensor
), (
f"Node IDs must be a dictionary for heterogeneous inference, got {type(node_type_to_input_node_ids)}"
)
input_node_ids: torch.Tensor = node_type_to_input_node_ids[args.inference_node_type]
Expand Down
145 changes: 140 additions & 5 deletions gigl/common/data/load_torch_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,63 @@
_ID_FMT = "{entity}_ids"
_FEATURE_FMT = "{entity}_features"
_LABEL_FMT = "{entity}_labels"
_EDGE_WEIGHTS_KEY = "edge_weights"
_NODE_KEY = "node"


def _extract_weight_col(
feat_tensor: torch.Tensor,
feature_keys: list[str],
feature_spec: dict,
col_name: str,
edge_type: EdgeType,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Slice a named weight column out of a feature tensor.

Accounts for multi-dim features: each feature key may contribute more than one column
to ``feat_tensor`` (e.g. ``FixedLenFeature(shape=[16])`` contributes 16 columns).
The weight feature must be a scalar (width 1).

Args:
feat_tensor: Edge feature tensor of shape ``[num_edges, total_feature_cols]``.
feature_keys: Ordered list of feature names matching the columns of ``feat_tensor``.
feature_spec: Feature spec dict mapping feature name to its TF feature spec (used to
determine per-key column widths).
col_name: Name of the column to extract as weights.
edge_type: Edge type (used only in error messages).

Returns:
A tuple ``(weights, trimmed_features)`` where ``weights`` is a 1-D tensor of shape
``[num_edges]`` and ``trimmed_features`` is ``feat_tensor`` with the weight column
removed.

Raises:
ValueError: If ``col_name`` is not in ``feature_keys`` or the weight feature is not
width 1.
"""
if col_name not in feature_keys:
raise ValueError(
f"weight_edge_feat_name '{col_name}' not found in edge feature keys "
f"for edge type {edge_type}: {feature_keys}"
)
key_idx = feature_keys.index(col_name)
col_widths = []
for key in feature_keys:
spec = feature_spec[key]
col_widths.append(spec.shape[-1] if spec.shape else 1)
weight_width = col_widths[key_idx]
if weight_width != 1:
raise ValueError(
f"weight_edge_feat_name '{col_name}' for edge type {edge_type} must be a scalar "
f"feature (width 1), but has width {weight_width}."
)
col_offset = sum(col_widths[:key_idx])
weights = feat_tensor[:, col_offset]
keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset]
trimmed = feat_tensor[:, keep_cols] if keep_cols else None
return weights, trimmed


_EDGE_KEY = "edge"
_POSITIVE_LABEL_KEY = "positive_label"
_NEGATIVE_LABEL_KEY = "negative_label"
Expand Down Expand Up @@ -72,6 +128,7 @@ def _data_loading_process(
],
rank: int,
tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
Comment thread
mkolodner-sc marked this conversation as resolved.
) -> None:
"""
Spawned multiprocessing.Process which loads homogeneous or heterogeneous information for a specific entity type [node, edge, positive_label, negative_label]
Expand All @@ -89,6 +146,11 @@ def _data_loading_process(
Serialized information for current entity
rank (int): Rank of the current machine
tf_dataset_options (TFDatasetOptions): The options to use when building the dataset.
weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Only used when
``entity_type == _EDGE_KEY``. Name of the edge feature column to extract as
sampling weights. Ignored for node, positive_label, and negative_label entities.
Supply a single string for homogeneous graphs or a per-edge-type dict for
heterogeneous graphs.
"""
# We add a try - except clause here to ensure that exceptions are properly circulated back to the parent process
try:
Expand Down Expand Up @@ -117,6 +179,7 @@ def _data_loading_process(
ids: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
features: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
labels: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
weights: dict[Union[NodeType, EdgeType], torch.Tensor] = {}
for (
graph_type,
serialized_entity_tf_record_info,
Expand All @@ -129,14 +192,13 @@ def _data_loading_process(
raise NotImplementedError(
"Label keys are not supported for edge entities"
)
(
entity_ids,
entity_features,
entity_labels,
) = tf_record_dataloader.load_as_torch_tensors(
loaded_entity = tf_record_dataloader.load_as_torch_tensors(
serialized_tf_record_info=serialized_entity_tf_record_info,
tf_dataset_options=tf_dataset_options,
)
entity_ids = loaded_entity.ids
entity_features = loaded_entity.features
entity_labels = loaded_entity.labels
ids[graph_type] = entity_ids
logger.info(
f"Rank {rank} finished loading {entity_type} ids of shape {entity_ids.shape} for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}"
Expand All @@ -161,6 +223,61 @@ def _data_loading_process(
f"Rank {rank} did not detect {entity_type} labels for graph type {graph_type} from {serialized_entity_tf_record_info.tfrecord_uri_prefix.uri}"
)

# Extract weight column from edge features when weight_edge_feat_name is set.
# The weight column is sliced out of each edge type's feature tensor and stored
# separately so it is not duplicated in the feature matrix.
if weight_edge_feat_name is not None and entity_type == _EDGE_KEY:
if isinstance(weight_edge_feat_name, str):
if len(serialized_tf_record_info) != 1 or len(features) != 1:
raise ValueError(
f"weight_edge_feat_name must be a dict[EdgeType, str] for heterogeneous "
f"graphs with multiple edge types ({sorted(serialized_tf_record_info)}). "
"Provide an explicit per-edge-type mapping instead of a single string."
)
col_name = weight_edge_feat_name
edge_type, feat_tensor = next(iter(features.items()))
assert isinstance(edge_type, EdgeType)
feature_keys = list(serialized_tf_record_info[edge_type].feature_keys)
weights[edge_type], trimmed = _extract_weight_col(
feat_tensor,
feature_keys,
serialized_tf_record_info[edge_type].feature_spec,
col_name,
edge_type,
)
if trimmed is not None:
features[edge_type] = trimmed
else:
del features[edge_type]
logger.info(
f"Rank {rank} extracted weight column '{col_name}' "
f"from {entity_type} features for type {edge_type}"
)
else:
# Iterate the EdgeType-keyed dict directly to stay within EdgeType.
for edge_type, col_name in weight_edge_feat_name.items():
if edge_type not in features:
continue
feat_tensor = features[edge_type]
feature_keys = list(
serialized_tf_record_info[edge_type].feature_keys
)
weights[edge_type], trimmed = _extract_weight_col(
feat_tensor,
feature_keys,
serialized_tf_record_info[edge_type].feature_spec,
col_name,
edge_type,
)
if trimmed is not None:
features[edge_type] = trimmed
else:
del features[edge_type]
logger.info(
f"Rank {rank} extracted weight column '{col_name}' "
f"from {entity_type} features for type {edge_type}"
)

logger.info(
f"Rank {rank} is attempting to share {entity_type} id memory for tfrecord directories: {all_tf_record_uris}"
)
Expand All @@ -180,6 +297,12 @@ def _data_loading_process(
)
share_memory(labels)

if weights:
logger.info(
f"Rank {rank} is attempting to share {entity_type} weight memory for tfrecord directories: {all_tf_record_uris}"
)
share_memory(weights)

output_dict[_ID_FMT.format(entity=entity_type)] = (
list(ids.values())[0] if is_input_homogeneous else ids
)
Expand All @@ -191,6 +314,10 @@ def _data_loading_process(
output_dict[_LABEL_FMT.format(entity=entity_type)] = (
list(labels.values())[0] if is_input_homogeneous else labels
)
if weights:
output_dict[_EDGE_WEIGHTS_KEY] = (
list(weights.values())[0] if is_input_homogeneous else weights
)

logger.info(
f"Rank {rank} has finished loading {entity_type} data from tfrecord directories: {all_tf_record_uris}, elapsed time: {time.time() - start_time:.2f} seconds"
Expand All @@ -207,6 +334,7 @@ def load_torch_tensors_from_tf_record(
rank: int = 0,
node_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
edge_tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
) -> LoadedGraphTensors:
"""
Loads all torch tensors from a SerializedGraphMetadata object for all entity [node, edge, positive_label, negative_label] and edge / node types.
Expand All @@ -222,6 +350,10 @@ def load_torch_tensors_from_tf_record(
rank (int): Rank on current machine
node_tf_dataset_options (TFDatasetOptions): The options to use for nodes when building the dataset.
edge_tf_dataset_options (TFDatasetOptions): The options to use for edges when building the dataset.
weight_edge_feat_name (Optional[Union[str, dict[EdgeType, str]]]): Name of the edge feature column to extract
as sampling weights. The column is removed from the edge feature matrix and returned separately via
``LoadedGraphTensors.edge_weights``. Supply a single string for homogeneous graphs or a per-edge-type
dict for heterogeneous graphs.
Returns:
loaded_graph_tensors (LoadedGraphTensors): Unpartitioned Graph Tensors
"""
Expand Down Expand Up @@ -269,6 +401,7 @@ def load_torch_tensors_from_tf_record(
"serialized_tf_record_info": serialized_graph_metadata.edge_entity_info,
"rank": rank,
"tf_dataset_options": edge_tf_dataset_options,
"weight_edge_feat_name": weight_edge_feat_name,
},
)

Expand Down Expand Up @@ -351,6 +484,7 @@ def load_torch_tensors_from_tf_record(

edge_index = edge_output_dict[_ID_FMT.format(entity=_EDGE_KEY)]
edge_features = edge_output_dict.get(_FEATURE_FMT.format(entity=_EDGE_KEY), None)
edge_weights = edge_output_dict.get(_EDGE_WEIGHTS_KEY, None)

positive_labels = edge_output_dict.get(
_ID_FMT.format(entity=_POSITIVE_LABEL_KEY), None
Expand Down Expand Up @@ -378,4 +512,5 @@ def load_torch_tensors_from_tf_record(
edge_features=edge_features,
positive_label=positive_labels,
negative_label=negative_labels,
edge_weights=edge_weights,
)
2 changes: 1 addition & 1 deletion gigl/common/metrics/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def wrap(*args: Any, **kwargs: Any) -> Any:
result = func(*args, **kwargs)
except Exception as e:
logger.info(
f"Exception raised, will flush metrics for: {func.__name__} and re-raise exception"
f"Exception raised, will flush metrics for: {getattr(func, '__name__')} and re-raise exception"
)
logger.error(f"Exception: {e}")
logger.error(traceback.format_exc())
Expand Down
42 changes: 41 additions & 1 deletion gigl/distributed/base_dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,50 @@ def __init__(
"for graph-store mode."
)

@staticmethod
def validate_for_weighted_sampling(
with_weight: bool,
dataset: Union[DistDataset, RemoteDistDataset],
sampler_options: SamplerOptions,
) -> None:
"""Validates the ``with_weight`` parameter against the dataset and sampler.

Args:
with_weight: Whether weighted sampling was requested.
dataset: The dataset being sampled from.
sampler_options: The sampler to be used.

Raises:
ValueError: If ``with_weight=True`` but no edge weights are registered.
NotImplementedError: If ``with_weight=True`` and a PPR sampler is requested.
"""
if not with_weight:
return
has_edge_weights = (
dataset.has_edge_weights
if isinstance(dataset, DistDataset)
else dataset.fetch_edge_weights_registered()
)
if not has_edge_weights:
raise ValueError(
"with_weight=True requires edge weights to be registered in the dataset. "
"Pass weight_edge_feat_name to build_dataset() to register edge weights."
)
# TODO(mkolodner-sc): Implement weight-proportional residual propagation for PPR.
if with_weight and isinstance(sampler_options, PPRSamplerOptions):
raise NotImplementedError(
"Weighted sampling is not yet supported with PPRSamplerOptions. "
"Weight-proportional residual propagation for PPR is planned but not implemented."
)

@staticmethod
def create_sampling_config(
num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
dataset_schema: DatasetSchema,
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
with_weight: bool = False,
) -> SamplingConfig:
"""Creates a SamplingConfig with patched fanout.

Expand All @@ -352,6 +389,9 @@ def create_sampling_config(
batch_size: How many samples per batch.
shuffle: Whether to shuffle input nodes.
drop_last: Whether to drop the last incomplete batch.
with_weight: Whether to use edge weights for sampling. Requires that
edge weights were registered during dataset construction via
``DistPartitioner.register_edge_weights()``.

Returns:
A fully configured SamplingConfig.
Expand All @@ -369,7 +409,7 @@ def create_sampling_config(
with_edge=True,
collect_features=True,
with_neg=False,
with_weight=False,
with_weight=with_weight,
edge_dir=dataset_schema.edge_dir,
seed=None,
)
Expand Down
Loading