Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b9be302
update
kmonte Mar 13, 2026
b3022c8
Merge branch 'main' into kmonte/update-node-shard-strategy
kmonte Mar 23, 2026
b437700
test
kmonte Mar 23, 2026
9dc44a4
update
kmonte Mar 23, 2026
2cd219d
Merge branch 'main' into kmonte/update-node-shard-strategy
kmonte Mar 23, 2026
1de868a
maybe cleanup
kmonte Mar 23, 2026
d44847e
Add detailed docstring with examples to ShardStrategy enum
kmonte Mar 24, 2026
4023152
Extract _validate_contiguous_args to free function, remove world_size…
kmonte Mar 24, 2026
dfd8f94
Remove unnecessary .clone() from ServerSlice.slice_tensor
kmonte Mar 24, 2026
d336096
Upgrade integration test to compare actual node IDs, not just counts
kmonte Mar 24, 2026
836a6e7
Simplify TestRemoteDistDatasetContiguous test class
kmonte Mar 24, 2026
60e6786
Add missing type annotations to test helper functions
kmonte Mar 24, 2026
c205582
update
kmonte Mar 24, 2026
be6ce6c
update
kmonte Mar 24, 2026
597e22b
update
kmonte Mar 24, 2026
abcee50
Merge branch 'main' into kmonte/update-node-shard-strategy
Mar 25, 2026
8371bc3
format
Mar 25, 2026
6f11385
Merge branch 'main' into kmonte/update-node-shard-strategy
Mar 26, 2026
5d9959a
swap default
Mar 26, 2026
efb486e
Merge branch 'main' into kmonte/update-node-shard-strategy
kmonte Apr 2, 2026
4d930e6
Merge branch 'main' into kmonte/update-node-shard-strategy
kmonte Apr 7, 2026
9303339
address comments
kmonte Apr 7, 2026
cb79b32
Merge branch 'kmonte/update-node-shard-strategy' of ssh://github.com/…
kmonte Apr 7, 2026
d53a4d3
revert
kmonte Apr 7, 2026
0e6f2eb
address comments
kmonte Apr 7, 2026
b78ea42
Merge remote-tracking branch 'origin/main' into kmonte/update-node-sh…
kmonte Apr 8, 2026
b9cf718
update docs
kmonte Apr 8, 2026
ad25141
update
kmonte Apr 8, 2026
987b23f
one den
kmonte Apr 8, 2026
8e2f741
Merge branch 'main' into kmonte/update-node-shard-strategy
kmontemayor2-sc Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,12 @@ def _setup_for_graph_store(
# Extract supervision edge types and derive label edge types from the
# ABLPInputNodes.labels dict (keyed by supervision edge type).
self._supervision_edge_types = list(first_input.labels.keys())
has_negatives = any(neg is not None for _, neg in first_input.labels.values())
has_negatives = False
for ablp_input in input_nodes.values():
for maybe_negative_labels in ablp_input.labels.values():
if maybe_negative_labels is not None:
has_negatives = True
break

self._positive_label_edge_types = [
message_passing_to_positive_label(et) for et in self._supervision_edge_types
Expand Down
8 changes: 7 additions & 1 deletion gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
FetchABLPInputRequest,
FetchNodesRequest,
)
from gigl.distributed.graph_store.sharding import ServerSlice
from gigl.distributed.sampler import ABLPNodeSamplerInput
from gigl.distributed.sampler_options import PPRSamplerOptions, SamplerOptions
from gigl.distributed.utils.neighborloader import shard_nodes_by_process
Expand Down Expand Up @@ -283,7 +284,7 @@ def get_node_ids(

Args:
request: The node-fetch request, including split, node type,
and round-robin rank/world_size.
and either round-robin rank/world_size or a contiguous slice.

Returns:
The node ids.
Expand All @@ -306,6 +307,7 @@ def get_node_ids(
node_type=request.node_type,
rank=request.rank,
world_size=request.world_size,
server_slice=request.server_slice,
)

def _get_node_ids(
Expand All @@ -314,6 +316,7 @@ def _get_node_ids(
node_type: Optional[NodeType],
rank: Optional[int] = None,
world_size: Optional[int] = None,
server_slice: Optional[ServerSlice] = None,
) -> torch.Tensor:
"""Core implementation for fetching node IDs by split, type, and sharding.

Expand Down Expand Up @@ -366,6 +369,8 @@ def _get_node_ids(
f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}."
)

if server_slice is not None:
return server_slice.slice_tensor(nodes)
if rank is not None and world_size is not None:
return shard_nodes_by_process(nodes, rank, world_size)
return nodes
Expand Down Expand Up @@ -420,6 +425,7 @@ def get_ablp_input(
node_type=request.node_type,
rank=request.rank,
world_size=request.world_size,
server_slice=request.server_slice,
)
positive_label_edge_type, negative_label_edge_type = select_label_edge_types(
request.supervision_edge_type, self.dataset.get_edge_types()
Expand Down
25 changes: 23 additions & 2 deletions gigl/distributed/graph_store/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Literal, Optional, Union

from gigl.distributed.graph_store.sharding import ServerSlice
from gigl.src.common.types.graph_data import EdgeType, NodeType


Expand All @@ -17,6 +18,9 @@ class FetchNodesRequest:
Must be provided together with ``rank``.
split: The split of the dataset to get node ids from.
node_type: The type of nodes to get node ids for.
server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
describing the fraction of this server's data to return.
Cannot be combined with ``rank``/``world_size``.

Examples:
Fetch all nodes without sharding:
Expand All @@ -36,18 +40,25 @@ class FetchNodesRequest:
world_size: Optional[int] = None
split: Optional[Union[Literal["train", "val", "test"], str]] = None
node_type: Optional[NodeType] = None
server_slice: Optional[ServerSlice] = None

def validate(self) -> None:
    """Check that the sharding options on this request are mutually consistent.

    Exactly one sharding mode may be active: either round-robin
    (``rank`` + ``world_size`` together) or a contiguous ``server_slice``,
    but never both.

    Raises:
        ValueError:
            If only one of ``rank`` or ``world_size`` is provided.
            If ``server_slice`` is provided together with ``rank`` or ``world_size``.
    """
    rank_given = self.rank is not None
    world_size_given = self.world_size is not None
    # Round-robin sharding needs both coordinates; one without the other is ambiguous.
    if rank_given != world_size_given:
        raise ValueError(
            "rank and world_size must be provided together. "
            f"Received rank={self.rank}, world_size={self.world_size}"
        )
    # Contiguous slicing and round-robin sharding are mutually exclusive.
    if self.server_slice is not None and (rank_given or world_size_given):
        raise ValueError("server_slice cannot be combined with rank/world_size.")


@dataclass(frozen=True)
Expand All @@ -62,6 +73,9 @@ class FetchABLPInputRequest:
Must be provided together with ``world_size``.
world_size: The total number of processes in the distributed setup.
Must be provided together with ``rank``.
server_slice: An optional :class:`~gigl.distributed.graph_store.sharding.ServerSlice`
describing the fraction of this server's data to return.
Cannot be combined with ``rank``/``world_size``.

Examples:
Fetch training ABLP input without sharding:
Expand All @@ -78,15 +92,22 @@ class FetchABLPInputRequest:
supervision_edge_type: EdgeType
rank: Optional[int] = None
world_size: Optional[int] = None
server_slice: Optional[ServerSlice] = None
Comment thread
kmontemayor2-sc marked this conversation as resolved.

def validate(self) -> None:
    """Validate that this request's sharding parameters do not conflict.

    A request may shard round-robin via the ``rank``/``world_size`` pair, or
    contiguously via ``server_slice`` — the two modes cannot be combined.

    Raises:
        ValueError:
            If only one of ``rank`` or ``world_size`` is provided.
            If ``server_slice`` is provided together with ``rank`` or ``world_size``.
    """
    # rank and world_size are a pair: reject the case where exactly one is set.
    if (self.rank is None) != (self.world_size is None):
        raise ValueError(
            "rank and world_size must be provided together. "
            f"Received rank={self.rank}, world_size={self.world_size}"
        )
    if self.server_slice is None:
        return
    # A slice was requested; round-robin parameters must then be absent.
    if self.rank is not None or self.world_size is not None:
        raise ValueError("server_slice cannot be combined with rank/world_size.")
Loading