diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index 1f78f83f3..b43301674 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -49,6 +49,7 @@ DistNodeSplitter, NodeAnchorLinkSplitter, NodeSplitter, + get_max_labels_per_anchor_node_from_runtime_args, select_ssl_positive_label_edges, ) @@ -502,6 +503,8 @@ def build_dataset_from_task_config_uri( - should_load_tensors_in_parallel (bool): Whether TFRecord loading should happen in parallel across entities Must be None if supervised edge labels are provided in advance. Slotted for refactor once this functionality is available in the transductive `splitter` directly. + - max_labels_per_anchor_node (Optional[int]): Cap for how many labels to + materialize per anchor node for ABLP label fetching. If training there are two additional arguments: - num_val (float): Percentage of edges to use for validation, defaults to 0.1. Must in in range [0, 1]. - num_test (float): Percentage of edges to use for testing, defaults to 0.1. Must be in range [0, 1]. 
@@ -530,6 +533,7 @@ def build_dataset_from_task_config_uri( ) ssl_positive_label_percentage: Optional[float] = None + max_labels_per_anchor_node: Optional[int] = None splitter: Optional[Union[NodeSplitter, NodeAnchorLinkSplitter]] = None if is_inference: args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) @@ -576,6 +580,7 @@ def build_dataset_from_task_config_uri( raise ValueError( f"Unsupported task metadata type: {task_metadata_pb_wrapper.task_metadata_type}" ) + max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args(args) assert sample_edge_direction in ( "in", @@ -628,5 +633,6 @@ def build_dataset_from_task_config_uri( splitter=splitter, _ssl_positive_label_percentage=ssl_positive_label_percentage, ) + dataset.max_labels_per_anchor_node = max_labels_per_anchor_node return dataset diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 215a92a51..c7845f701 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -538,6 +538,7 @@ def _setup_for_colocated( node_ids=curr_process_nodes, positive_label_edge_type=positive_label_edge_type, negative_label_edge_type=negative_label_edge_type, + max_labels_per_anchor_node=dataset.max_labels_per_anchor_node, ) positive_labels_by_label_edge_type[positive_label_edge_type] = ( positive_labels diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index d37bbc925..b40f2969a 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -26,7 +26,10 @@ GraphPartitionData, PartitionOutput, ) -from gigl.utils.data_splitters import NodeAnchorLinkSplitter, NodeSplitter +from gigl.utils.data_splitters import ( + NodeAnchorLinkSplitter, + NodeSplitter, +) from gigl.utils.share_memory import share_memory logger = Logger() @@ -80,6 +83,7 @@ def __init__( degree_tensor: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = None, 
+ max_labels_per_anchor_node: Optional[int] = None, ) -> None: """ Initializes the fields of the DistDataset class. This function is called upon each serialization of the DistDataset instance. @@ -105,6 +109,8 @@ def __init__( edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Dimension of edge features and its data type, will be a dict if heterogeneous. Note this will be None in the homogeneous case if the data has no edge features, or will only contain edge types with edge features in the heterogeneous case. degree_tensor: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Pre-computed degree tensor. Lazily computed on first access via the degree_tensor property. + max_labels_per_anchor_node (Optional[int]): Optional cap for how many + labels to materialize per anchor node for ABLP label fetching. """ self._rank: int = rank self._world_size: int = world_size @@ -143,6 +149,7 @@ def __init__( self._degree_tensor: Optional[ Union[torch.Tensor, dict[EdgeType, torch.Tensor]] ] = degree_tensor + self._max_labels_per_anchor_node = max_labels_per_anchor_node # TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear # naming (i.e. rank, world_size). 
@@ -329,6 +336,16 @@ def degree_tensor( self._degree_tensor = compute_and_broadcast_degree_tensor(self.graph) return self._degree_tensor + @property + def max_labels_per_anchor_node(self) -> Optional[int]: + return self._max_labels_per_anchor_node + + @max_labels_per_anchor_node.setter + def max_labels_per_anchor_node( + self, new_max_labels_per_anchor_node: Optional[int] + ) -> None: + self._max_labels_per_anchor_node = new_max_labels_per_anchor_node + @property def train_node_ids( self, @@ -858,6 +875,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]], Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]], Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], + Optional[int], ]: """ Serializes the member variables of the DistDatasetClass @@ -880,6 +898,7 @@ def share_ipc( Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]]: Node feature dim and its data type, will be a dict if heterogeneous Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]]: Edge feature dim and its data type, will be a dict if heterogeneous Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]: Degree tensors, will be a dict if heterogeneous + Optional[int]: Optional per-anchor label cap for ABLP label fetching """ # TODO (mkolodner-sc): Investigate moving share_memory calls to the build() function @@ -908,6 +927,7 @@ def share_ipc( self._node_feature_info, # Additional field unique to DistDataset class self._edge_feature_info, # Additional field unique to DistDataset class self._degree_tensor, # Additional field unique to DistDataset class + self._max_labels_per_anchor_node, # Additional field unique to DistDataset class ) return ipc_handle @@ -1164,6 +1184,7 @@ def _rebuild_distributed_dataset( Union[FeatureInfo, dict[EdgeType, FeatureInfo]] ], # Edge feature dim and its data type Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]], # Degree tensors + Optional[int], # Optional per-anchor label cap for ABLP label fetching ], ): 
dataset = DistDataset.from_ipc_handle(ipc_handle) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 1110e47c4..533b1dfb3 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -408,7 +408,11 @@ def get_ablp_input( request.supervision_edge_type, self.dataset.get_edge_types() ) positive_labels, negative_labels = get_labels_for_anchor_nodes( - self.dataset, anchors, positive_label_edge_type, negative_label_edge_type + self.dataset, + anchors, + positive_label_edge_type, + negative_label_edge_type, + max_labels_per_anchor_node=self.dataset.max_labels_per_anchor_node, ) return anchors, positive_labels, negative_labels diff --git a/gigl/distributed/graph_store/storage_utils.py b/gigl/distributed/graph_store/storage_utils.py index 548e8ee7f..e67da070a 100644 --- a/gigl/distributed/graph_store/storage_utils.py +++ b/gigl/distributed/graph_store/storage_utils.py @@ -32,7 +32,11 @@ ) from gigl.env.distributed import GraphStoreInfo from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper -from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter, DistNodeSplitter +from gigl.utils.data_splitters import ( + DistNodeAnchorLinkSplitter, + DistNodeSplitter, + get_max_labels_per_anchor_node_from_runtime_args, +) logger = Logger() @@ -45,6 +49,7 @@ def build_storage_dataset( splitter: Optional[Union[DistNodeAnchorLinkSplitter, DistNodeSplitter]] = None, should_load_tensors_in_parallel: bool = True, ssl_positive_label_percentage: Optional[float] = None, + max_labels_per_anchor_node: Optional[int] = None, ) -> DistDataset: """Build a :class:`DistDataset` for a storage node from a task config. @@ -71,6 +76,10 @@ def build_storage_dataset( self-supervised positive labels. Must be ``None`` when supervised edge labels are already provided. For example, ``0.1`` selects 10 % of edges. 
+ max_labels_per_anchor_node: Optional cap for how many labels to + materialize per anchor node when the storage server serves ABLP + input. If ``None``, this is inferred from the task config's + ``trainer_args``. Returns: A partitioned :class:`DistDataset` ready to be served. @@ -78,12 +87,16 @@ def build_storage_dataset( gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( gbml_config_uri=task_config_uri ) + if max_labels_per_anchor_node is None: + max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args( + dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + ) serialized_graph_metadata = convert_pb_to_serialized_graph_metadata( preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper, graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper, tfrecord_uri_pattern=tf_record_uri_pattern, ) - return build_dataset( + dataset = build_dataset( serialized_graph_metadata=serialized_graph_metadata, sample_edge_direction=sample_edge_direction, should_load_tensors_in_parallel=should_load_tensors_in_parallel, @@ -91,6 +104,8 @@ def build_storage_dataset( splitter=splitter, _ssl_positive_label_percentage=ssl_positive_label_percentage, ) + dataset.max_labels_per_anchor_node = max_labels_per_anchor_node + return dataset def _run_storage_server_session( diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 4aa416f2e..7367b1dbc 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -31,11 +31,48 @@ logger = Logger() PADDING_NODE: Final[torch.Tensor] = torch.tensor(-1, dtype=torch.int64) +MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG: Final[str] = "max_labels_per_anchor_node" # We need to make the protocols for the node splitter and node anchor linked spliter runtime checkable so that # we can make isinstance() checks on them at runtime. 
+def validate_max_labels_per_anchor_node(
+    max_labels_per_anchor_node: Optional[int],
+) -> None:
+    """Validate the optional per-anchor label cap.
+
+    Args:
+        max_labels_per_anchor_node: The value to validate.
+
+    Raises:
+        ValueError: If max_labels_per_anchor_node is not None and not a positive integer.
+    """
+    if max_labels_per_anchor_node is not None and max_labels_per_anchor_node <= 0:
+        raise ValueError(
+            "max_labels_per_anchor_node must be a positive integer when provided."
+        )
+
+
+def get_max_labels_per_anchor_node_from_runtime_args(
+    runtime_args: Mapping[str, str],
+) -> Optional[int]:
+    """Parse the optional per-anchor label cap from runtime args."""
+    raw_max_labels_per_anchor_node = runtime_args.get(
+        MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG
+    )
+    if raw_max_labels_per_anchor_node is None:
+        return None
+    try:
+        parsed_max_labels_per_anchor_node = int(raw_max_labels_per_anchor_node)
+    except ValueError as exc:
+        raise ValueError(
+            f"Invalid {MAX_LABELS_PER_ANCHOR_NODE_RUNTIME_ARG} value {raw_max_labels_per_anchor_node!r}. Expected a positive integer."
+        ) from exc
+    validate_max_labels_per_anchor_node(parsed_max_labels_per_anchor_node)
+    return parsed_max_labels_per_anchor_node
+
+
 @runtime_checkable
 class NodeAnchorLinkSplitter(Protocol):
     """Protocol that should be satisfied for anything that is used to split on edges.
@@ -562,6 +599,7 @@ def get_labels_for_anchor_nodes(
     node_ids: torch.Tensor,
     positive_label_edge_type: PyGEdgeType,
     negative_label_edge_type: Optional[PyGEdgeType] = None,
+    max_labels_per_anchor_node: Optional[int] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Selects labels for the given node ids based on the provided edge types.
 
@@ -592,6 +630,8 @@
         positive_label_edge_type (PyGEdgeType): The edge type to use for the positive labels.
         negative_label_edge_type (Optional[PyGEdgeType]): The edge type to use for the negative labels. Defaults to None.
             If not provided no negative labels will be returned.
+ max_labels_per_anchor_node (Optional[int]): If provided, caps the number of + positive and negative labels materialized per anchor node. Returns: Tuple of (positive labels, negative_labels?) negative labels may be None depending on if negative_label_edge_type is provided. @@ -612,13 +652,19 @@ def get_labels_for_anchor_nodes( # Labels is NxM, where N is the number of nodes, and M is the max number of labels. positive_labels = _get_padded_labels( - node_ids, positive_node_topo, allow_non_existant_node_ids=False + node_ids, + positive_node_topo, + allow_non_existant_node_ids=False, + max_labels_per_anchor_node=max_labels_per_anchor_node, ) if negative_node_topo is not None: # Labels is NxM, where N is the number of nodes, and M is the max number of labels. negative_labels = _get_padded_labels( - node_ids, negative_node_topo, allow_non_existant_node_ids=True + node_ids, + negative_node_topo, + allow_non_existant_node_ids=True, + max_labels_per_anchor_node=max_labels_per_anchor_node, ) else: negative_labels = None @@ -630,6 +676,7 @@ def _get_padded_labels( anchor_node_ids: torch.Tensor, topo: Topology, allow_non_existant_node_ids: bool = False, + max_labels_per_anchor_node: Optional[int] = None, ) -> torch.Tensor: """Returns the padded labels and the max range of labels. @@ -642,9 +689,12 @@ def _get_padded_labels( topo (Topology): The topology to use for the labels. allow_non_existant_node_ids (bool): If True, will allow anchor node ids that do not exist in the topology. This means that the returned tensor will be padded with `PADDING_NODE` for those anchor node ids. + max_labels_per_anchor_node (Optional[int]): If provided, caps the number of + labels materialized per anchor node. Returns: The shape of the returned tensor is [N, max_number_of_labels]. """ + validate_max_labels_per_anchor_node(max_labels_per_anchor_node) # indptr is the ROW_INDEX of a CSR matrix. # and indices is the COL_INDEX of a CSR matrix. 
# See https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format) @@ -660,6 +710,8 @@ def _get_padded_labels( ends = indptr[anchor_node_ids + 1] # [N] max_range = int(torch.max(ends - starts).item()) + if max_labels_per_anchor_node is not None: + max_range = min(max_range, max_labels_per_anchor_node) # Sample all labels based on the CSR start/stop indices. # Creates "indices" for us to us, e.g [[0, 1], [2, 3]] diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 4575b7ad4..315b2b590 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -445,6 +445,7 @@ def tearDown(self): 10: torch.tensor([13, 16]), 15: torch.tensor([17]), }, + max_labels_per_anchor_node=None, ), param( "Positive edges", @@ -457,6 +458,28 @@ def tearDown(self): 15: torch.tensor([16]), }, expected_negative_labels=None, + max_labels_per_anchor_node=None, + ), + param( + "Positive and Negative edges with label cap", + labeled_edges={ + _POSITIVE_EDGE_TYPE: torch.tensor([[10, 15], [15, 16]]), + _NEGATIVE_EDGE_TYPE: torch.tensor( + [[10, 10, 11, 15], [13, 16, 14, 17]] + ), + }, + expected_node=torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]), + expected_srcs=torch.tensor([10, 10, 15, 15, 16, 16, 11, 11]), + expected_dsts=torch.tensor([11, 12, 13, 14, 12, 14, 13, 17]), + expected_positive_labels={ + 10: torch.tensor([15]), + 15: torch.tensor([16]), + }, + expected_negative_labels={ + 10: torch.tensor([13]), + 15: torch.tensor([17]), + }, + max_labels_per_anchor_node=1, ), ] ) @@ -469,6 +492,7 @@ def test_ablp_dataloader( expected_dsts, expected_positive_labels, expected_negative_labels, + max_labels_per_anchor_node, ): # Graph looks like https://is.gd/w2oEVp: # Message passing @@ -511,7 +535,12 @@ def test_ablp_dataloader( partitioned_positive_labels=None, partitioned_node_labels=None, ) - dataset = 
DistDataset(rank=0, world_size=1, edge_dir="out") + dataset = DistDataset( + rank=0, + world_size=1, + edge_dir="out", + max_labels_per_anchor_node=max_labels_per_anchor_node, + ) dataset.build(partition_output=partition_output) mp.spawn( diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py index a0f1a3594..8a38ccceb 100644 --- a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -501,6 +501,35 @@ def test_fetch_ablp_input(self, mock_async_request): torch.tensor([[1]]), ) + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_fetch_ablp_input_respects_max_labels_per_anchor_node( + self, mock_async_request + ): + _create_server_with_splits() + self.assertIsNotNone(_test_server) + assert _test_server is not None + _test_server.dataset.max_labels_per_anchor_node = 1 + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + result = remote_dataset.fetch_ablp_input( + split="train", anchor_node_type=USER, supervision_edge_type=USER_TO_STORY + ) + pos_labels, neg_labels = result[0].labels[USER_TO_STORY] + self.assert_tensor_equality( + pos_labels, + torch.tensor([[0], [1], [2]]), + ) + assert neg_labels is not None + self.assert_tensor_equality( + neg_labels, + torch.tensor([[2], [3], [4]]), + ) + @patch( "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", side_effect=_mock_async_request_server, diff --git a/tests/unit/utils/data_splitters_test.py b/tests/unit/utils/data_splitters_test.py index c9c7f6ce4..22424d889 100644 --- a/tests/unit/utils/data_splitters_test.py +++ b/tests/unit/utils/data_splitters_test.py @@ -18,6 +18,7 @@ _fast_hash, _get_padded_labels, get_labels_for_anchor_nodes, + 
get_max_labels_per_anchor_node_from_runtime_args, select_ssl_positive_label_edges, ) from tests.test_assets.distributed.utils import ( @@ -810,6 +811,37 @@ def test_get_padded_labels(self, _, node_ids, topo, expected): labels = _get_padded_labels(node_ids, topo) assert_close(labels, expected, rtol=0, atol=0) + def test_get_padded_labels_with_max_labels_per_anchor_node(self): + labels = _get_padded_labels( + torch.tensor([0, 1]), + Topology( + edge_index=torch.tensor([[0, 0, 1], [1, 2, 2]], dtype=torch.int64), + layout="CSR", + ), + max_labels_per_anchor_node=1, + ) + assert_close( + labels, + torch.tensor([[1], [2]], dtype=torch.int64), + rtol=0, + atol=0, + ) + + def test_get_max_labels_per_anchor_node_from_runtime_args(self): + self.assertIsNone(get_max_labels_per_anchor_node_from_runtime_args({})) + self.assertEqual( + get_max_labels_per_anchor_node_from_runtime_args( + {"max_labels_per_anchor_node": "3"} + ), + 3, + ) + + def test_get_max_labels_per_anchor_node_from_runtime_args_invalid(self): + with self.assertRaises(ValueError): + get_max_labels_per_anchor_node_from_runtime_args( + {"max_labels_per_anchor_node": "0"} + ) + @parameterized.expand( [ param(