Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions gigl/distributed/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
DistNodeSplitter,
NodeAnchorLinkSplitter,
NodeSplitter,
get_max_labels_per_anchor_node_from_runtime_args,
select_ssl_positive_label_edges,
)

Expand Down Expand Up @@ -502,6 +503,8 @@ def build_dataset_from_task_config_uri(
- should_load_tensors_in_parallel (bool): Whether TFRecord loading should happen in parallel across entities
Must be None if supervised edge labels are provided in advance.
Slotted for refactor once this functionality is available in the transductive `splitter` directly.
- max_labels_per_anchor_node (Optional[int]): Cap for how many labels to
materialize per anchor node for ABLP label fetching.
If training there are two additional arguments:
- num_val (float): Percentage of edges to use for validation, defaults to 0.1. Must be in range [0, 1].
- num_test (float): Percentage of edges to use for testing, defaults to 0.1. Must be in range [0, 1].
Expand Down Expand Up @@ -530,6 +533,7 @@ def build_dataset_from_task_config_uri(
)

ssl_positive_label_percentage: Optional[float] = None
max_labels_per_anchor_node: Optional[int] = None
splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None
if is_inference:
args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args)
Expand Down Expand Up @@ -576,6 +580,7 @@ def build_dataset_from_task_config_uri(
raise ValueError(
f"Unsupported task metadata type: {task_metadata_pb_wrapper.task_metadata_type}"
)
max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args(args)

assert sample_edge_direction in (
"in",
Expand Down Expand Up @@ -628,5 +633,6 @@ def build_dataset_from_task_config_uri(
splitter=splitter,
_ssl_positive_label_percentage=ssl_positive_label_percentage,
)
dataset.max_labels_per_anchor_node = max_labels_per_anchor_node

return dataset
1 change: 1 addition & 0 deletions gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ def _setup_for_colocated(
node_ids=curr_process_nodes,
positive_label_edge_type=positive_label_edge_type,
negative_label_edge_type=negative_label_edge_type,
max_labels_per_anchor_node=dataset.max_labels_per_anchor_node,
)
positive_labels_by_label_edge_type[positive_label_edge_type] = (
positive_labels
Expand Down
23 changes: 22 additions & 1 deletion gigl/distributed/dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
GraphPartitionData,
PartitionOutput,
)
from gigl.utils.data_splitters import NodeAnchorLinkSplitter, NodeSplitter
from gigl.utils.data_splitters import (
NodeAnchorLinkSplitter,
NodeSplitter,
)
from gigl.utils.share_memory import share_memory

logger = Logger()
Expand Down Expand Up @@ -80,6 +83,7 @@ def __init__(
degree_tensor: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
] = None,
max_labels_per_anchor_node: Optional[int] = None,
) -> None:
"""
Initializes the fields of the DistDataset class. This function is called upon each serialization of the DistDataset instance.
Expand All @@ -105,6 +109,8 @@ def __init__(
edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous.
Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case.
degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property.
max_labels_per_anchor_node (Optional[int]): Optional cap for how many
labels to materialize per anchor node for ABLP label fetching.
"""
self._rank: int = rank
self._world_size: int = world_size
Expand Down Expand Up @@ -143,6 +149,7 @@ def __init__(
self._degree_tensor: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
] = degree_tensor
self._max_labels_per_anchor_node = max_labels_per_anchor_node

# TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear
# naming (i.e. rank, world_size).
Expand Down Expand Up @@ -329,6 +336,16 @@ def degree_tensor(
self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph)
return self._degree_tensor

@property
def max_labels_per_anchor_node(self) -> Optional[int]:
    """Optional cap on how many labels are materialized per anchor node for
    ABLP label fetching. ``None`` means no cap is applied."""
    return self._max_labels_per_anchor_node

@max_labels_per_anchor_node.setter
def max_labels_per_anchor_node(
    self, new_max_labels_per_anchor_node: Optional[int]
) -> None:
    """Set (or clear, by passing ``None``) the per-anchor label cap."""
    self._max_labels_per_anchor_node = new_max_labels_per_anchor_node

@property
def train_node_ids(
self,
Expand Down Expand Up @@ -858,6 +875,7 @@ def share_ipc(
Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]],
Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]],
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
Optional[int],
]:
"""
Serializes the member variables of the DistDatasetClass
Expand All @@ -880,6 +898,7 @@ def share_ipc(
Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous
Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous
Optional[int]: Optional per-anchor label cap for ABLP label fetching
"""
# TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function

Expand Down Expand Up @@ -908,6 +927,7 @@ def share_ipc(
self._node_feature_info, # Additional field unique to DistDataset class
self._edge_feature_info, # Additional field unique to DistDataset class
self._degree_tensor, # Additional field unique to DistDataset class
self._max_labels_per_anchor_node, # Additional field unique to DistDataset class
)
return ipc_handle

Expand Down Expand Up @@ -1164,6 +1184,7 @@ def _rebuild_distributed_dataset(
Union[FeatureInfo, dict[EdgeType, FeatureInfo]]
], # Edge feature dim and its data type
Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors
Optional[int], # Optional per-anchor label cap for ABLP label fetching
],
):
dataset = DistDataset.from_ipc_handle(ipc_handle)
Expand Down
6 changes: 5 additions & 1 deletion gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,11 @@ def get_ablp_input(
request.supervision_edge_type, self.dataset.get_edge_types()
)
positive_labels, negative_labels = get_labels_for_anchor_nodes(
self.dataset, anchors, positive_label_edge_type, negative_label_edge_type
self.dataset,
anchors,
positive_label_edge_type,
negative_label_edge_type,
max_labels_per_anchor_node=self.dataset.max_labels_per_anchor_node,
)
return anchors, positive_labels, negative_labels

Expand Down
19 changes: 17 additions & 2 deletions gigl/distributed/graph_store/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
)
from gigl.env.distributed import GraphStoreInfo
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter, DistNodeSplitter
from gigl.utils.data_splitters import (
DistNodeAnchorLinkSplitter,
DistNodeSplitter,
get_max_labels_per_anchor_node_from_runtime_args,
)

logger = Logger()

Expand All @@ -45,6 +49,7 @@ def build_storage_dataset(
splitter: Optional[Union[DistNodeAnchorLinkSplitter, DistNodeSplitter]] = None,
should_load_tensors_in_parallel: bool = True,
ssl_positive_label_percentage: Optional[float] = None,
max_labels_per_anchor_node: Optional[int] = None,
) -> DistDataset:
"""Build a :class:`DistDataset` for a storage node from a task config.

Expand All @@ -71,26 +76,36 @@ def build_storage_dataset(
self-supervised positive labels. Must be ``None`` when
supervised edge labels are already provided. For example,
``0.1`` selects 10% of edges.
max_labels_per_anchor_node: Optional cap for how many labels to
materialize per anchor node when the storage server serves ABLP
input. If ``None``, this is inferred from the task config's
``trainer_args``.

Returns:
A partitioned :class:`DistDataset` ready to be served.
"""
gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=task_config_uri
)
if max_labels_per_anchor_node is None:
max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args(
dict(gbml_config_pb_wrapper.trainer_config.trainer_args)
)
Comment on lines +91 to +93
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm, I meant to leave a comment here about using storage args for graph store mode 1.

How about we just never call this function here, and always require it to be passed in? We can just parse it as a CLI flag e.g. at 2

serialized_graph_metadata = convert_pb_to_serialized_graph_metadata(
preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
tfrecord_uri_pattern=tf_record_uri_pattern,
)
return build_dataset(
dataset = build_dataset(
serialized_graph_metadata=serialized_graph_metadata,
sample_edge_direction=sample_edge_direction,
should_load_tensors_in_parallel=should_load_tensors_in_parallel,
partitioner_class=DistRangePartitioner,
splitter=splitter,
_ssl_positive_label_percentage=ssl_positive_label_percentage,
)
dataset.max_labels_per_anchor_node = max_labels_per_anchor_node
return dataset


def _run_storage_server_session(
Expand Down
56 changes: 54 additions & 2 deletions gigl/utils/data_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,48 @@
logger = Logger()

PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64)
MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG: Final[str] = "max_labels_per_anchor_node"

# We need to make the protocols for the node splitter and node anchor linked spliter runtime checkable so that
# we can make isinstance() checks on them at runtime.


def validate_max_labels_per_anchor_node(
    max_labels_per_anchor_node: Optional[int],
) -> None:
    """Check that an optional per-anchor label cap is usable.

    Args:
        max_labels_per_anchor_node: The cap to check. ``None`` means
            "no cap" and is always accepted.

    Raises:
        ValueError: If a non-``None`` value is zero or negative.
    """
    # None means "no cap", so nothing further to check.
    if max_labels_per_anchor_node is None:
        return
    if max_labels_per_anchor_node <= 0:
        raise ValueError(
            "max_labels_per_anchor_node must be a positive integer when provided."
        )


def get_max_labels_per_anchor_node_from_runtime_args(
    runtime_args: Mapping[str, str],
) -> Optional[int]:
    """Parse the optional per-anchor label cap from runtime args.

    Args:
        runtime_args: String-to-string runtime arguments (e.g. trainer_args
            or inferencer_args from the task config).

    Returns:
        The parsed positive cap, or ``None`` if the arg is absent.

    Raises:
        ValueError: If the value is present but is not a positive integer.
    """
    raw_max_labels_per_anchor_node = runtime_args.get(
        MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG
    )
    if raw_max_labels_per_anchor_node is None:
        return None
    try:
        parsed_max_labels_per_anchor_node = int(raw_max_labels_per_anchor_node)
    except ValueError as exc:
        raise ValueError(
            f"Invalid {MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG} value "
            f"{raw_max_labels_per_anchor_node!r}. Expected a positive integer."
        ) from exc
    # Fail fast on zero/negative values: the error message promises a positive
    # integer, and downstream consumers (e.g. _get_padded_labels via
    # validate_max_labels_per_anchor_node) reject non-positive caps anyway.
    # Validating here surfaces the misconfiguration at parse time.
    if parsed_max_labels_per_anchor_node <= 0:
        raise ValueError(
            f"Invalid {MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG} value "
            f"{raw_max_labels_per_anchor_node!r}. Expected a positive integer."
        )
    return parsed_max_labels_per_anchor_node


@runtime_checkable
class NodeAnchorLinkSplitter(Protocol):
"""Protocol that should be satisfied for anything that is used to split on edges.
Expand Down Expand Up @@ -562,6 +599,7 @@ def get_labels_for_anchor_nodes(
node_ids: torch.Tensor,
positive_label_edge_type: PyGEdgeType,
negative_label_edge_type: Optional[PyGEdgeType] = None,
max_labels_per_anchor_node: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Selects labels for the given node ids based on the provided edge types.

Expand Down Expand Up @@ -592,6 +630,8 @@ def get_labels_for_anchor_nodes(
positive_label_edge_type (PyGEdgeType): The edge type to use for the positive labels.
negative_label_edge_type (Optional[PyGEdgeType]): The edge type to use for the negative labels.
Defaults to None. If not provided no negative labels will be returned.
max_labels_per_anchor_node (Optional[int]): If provided, caps the number of
positive and negative labels materialized per anchor node.
Returns:
Tuple of (positive labels, negative_labels?)
negative labels may be None depending on if negative_label_edge_type is provided.
Expand All @@ -612,13 +652,19 @@ def get_labels_for_anchor_nodes(

# Labels is NxM, where N is the number of nodes, and M is the max number of labels.
positive_labels = _get_padded_labels(
node_ids, positive_node_topo, allow_non_existant_node_ids=False
node_ids,
positive_node_topo,
allow_non_existant_node_ids=False,
max_labels_per_anchor_node=max_labels_per_anchor_node,
)

if negative_node_topo is not None:
# Labels is NxM, where N is the number of nodes, and M is the max number of labels.
negative_labels = _get_padded_labels(
node_ids, negative_node_topo, allow_non_existant_node_ids=True
node_ids,
negative_node_topo,
allow_non_existant_node_ids=True,
max_labels_per_anchor_node=max_labels_per_anchor_node,
)
else:
negative_labels = None
Expand All @@ -630,6 +676,7 @@ def _get_padded_labels(
anchor_node_ids: torch.Tensor,
topo: Topology,
allow_non_existant_node_ids: bool = False,
max_labels_per_anchor_node: Optional[int] = None,
) -> torch.Tensor:
"""Returns the padded labels and the max range of labels.

Expand All @@ -642,9 +689,12 @@ def _get_padded_labels(
topo (Topology): The topology to use for the labels.
allow_non_existant_node_ids (bool): If True, will allow anchor node ids that do not exist in the topology.
This means that the returned tensor will be padded with `PADDING_NODE` for those anchor node ids.
max_labels_per_anchor_node (Optional[int]): If provided, caps the number of
labels materialized per anchor node.
Returns:
The shape of the returned tensor is [N, max_number_of_labels].
"""
validate_max_labels_per_anchor_node(max_labels_per_anchor_node)
# indptr is the ROW_INDEX of a CSR matrix.
# and indices is the COL_INDEX of a CSR matrix.
# See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
Expand All @@ -660,6 +710,8 @@ def _get_padded_labels(
ends = indptr[anchor_node_ids + 1] # [N]

max_range = int(torch.max(ends - starts).item())
if max_labels_per_anchor_node is not None:
max_range = min(max_range, max_labels_per_anchor_node)

# Sample all labels based on the CSR start/stop indices.
# Creates "indices" for us to us, e.g [[0, 1], [2, 3]]
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/distributed/dist_ablp_neighborloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def tearDown(self):
10: torch.tensor([13, 16]),
15: torch.tensor([17]),
},
max_labels_per_anchor_node=None,
),
param(
"Positive edges",
Expand All @@ -457,6 +458,28 @@ def tearDown(self):
15: torch.tensor([16]),
},
expected_negative_labels=None,
max_labels_per_anchor_node=None,
),
param(
"Positive and Negative edges with label cap",
labeled_edges={
_POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]),
_NEGATIVE_EDGE_TYPE: torch.tensor(
[[10, 10, 11, 15], [13, 16, 14, 17]]
),
},
expected_node=torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]),
expected_srcs=torch.tensor([10, 10, 15, 15, 16, 16, 11, 11]),
expected_dsts=torch.tensor([11, 12, 13, 14, 12, 14, 13, 17]),
expected_positive_labels={
10: torch.tensor([15]),
15: torch.tensor([16]),
},
expected_negative_labels={
10: torch.tensor([13]),
15: torch.tensor([17]),
},
max_labels_per_anchor_node=1,
),
]
)
Expand All @@ -469,6 +492,7 @@ def test_ablp_dataloader(
expected_dsts,
expected_positive_labels,
expected_negative_labels,
max_labels_per_anchor_node,
):
# Graph looks like https://is.gd/w2oEVp:
# Message passing
Expand Down Expand Up @@ -511,7 +535,12 @@ def test_ablp_dataloader(
partitioned_positive_labels=None,
partitioned_node_labels=None,
)
dataset = DistDataset(rank=0, world_size=1, edge_dir="out")
dataset = DistDataset(
rank=0,
world_size=1,
edge_dir="out",
max_labels_per_anchor_node=max_labels_per_anchor_node,
)
dataset.build(partition_output=partition_output)

mp.spawn(
Expand Down
Loading