From fb0547d613926017a097c6f8f7a292554eb00533 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 4 Jul 2024 12:03:16 -0400 Subject: [PATCH] [GraphBolt] Add temporal labor sampling to graph. (#7500) --- .../graphbolt/fused_csc_sampling_graph.h | 4 +- graphbolt/src/cuda/common.h | 10 -- graphbolt/src/cuda/neighbor_sampler.cu | 1 + graphbolt/src/fused_csc_sampling_graph.cc | 60 +++++--- graphbolt/src/macro.h | 10 ++ .../impl/fused_csc_sampling_graph.py | 133 ++++++++++++++++++ .../impl/test_fused_csc_sampling_graph.py | 14 +- 7 files changed, 203 insertions(+), 29 deletions(-) diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index a2d280777444..3ea2827573c1 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -404,7 +404,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { torch::optional input_nodes_pre_time_window, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, - torch::optional edge_timestamp_attr_name) const; + torch::optional edge_timestamp_attr_name, + torch::optional random_seed, + double seed2_contribution) const; /** * @brief Copy the graph to shared memory. diff --git a/graphbolt/src/cuda/common.h b/graphbolt/src/cuda/common.h index ea11b005952c..c2ffc719c438 100644 --- a/graphbolt/src/cuda/common.h +++ b/graphbolt/src/cuda/common.h @@ -194,16 +194,6 @@ struct CopyScalar { bool is_ready_; }; -// This includes all integer, float and boolean types. -#define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...) \ - AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__) - -#define GRAPHBOLT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)) - #define GRAPHBOLT_DISPATCH_ELEMENT_SIZES(element_size, name, ...) \ [&] { \ switch (element_size) { \ diff --git a/graphbolt/src/cuda/neighbor_sampler.cu b/graphbolt/src/cuda/neighbor_sampler.cu index d41104ce57af..71d357706672 100644 --- a/graphbolt/src/cuda/neighbor_sampler.cu +++ b/graphbolt/src/cuda/neighbor_sampler.cu @@ -38,6 +38,7 @@ #include #include +#include "../macro.h" #include "../random.h" #include "../utils.h" #include "./common.h" diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 3c217e011fce..68be349cb032 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -910,7 +910,9 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( torch::optional input_nodes_pre_time_window, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, - torch::optional edge_timestamp_attr_name) const { + torch::optional edge_timestamp_attr_name, + torch::optional random_seed, + double seed2_contribution) const { torch::optional> seed_offsets = torch::nullopt; // 1. Get probs_or_mask. if (probs_or_mask.has_value()) { @@ -928,19 +930,45 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name); // 4. Call SampleNeighborsImpl if (layer) { - const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( - static_cast(0), std::numeric_limits::max()); - SamplerArgs args{indices_, random_seed, NumNodes()}; - return SampleNeighborsImpl( - input_nodes, seed_offsets, fanouts, return_eids, - GetTemporalNumPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - type_per_edge_, input_nodes_pre_time_window, probs_or_mask, - node_timestamp, edge_timestamp), - GetTemporalPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - indptr_.options(), type_per_edge_, input_nodes_pre_time_window, - probs_or_mask, node_timestamp, edge_timestamp, args)); + if (random_seed.has_value() && random_seed->numel() >= 2) { + SamplerArgs args{ + indices_, + {random_seed.value(), static_cast(seed2_contribution)}, + NumNodes()}; + return SampleNeighborsImpl( + input_nodes, seed_offsets, fanouts, return_eids, + GetTemporalNumPickFn( + input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_, + input_nodes_pre_time_window, probs_or_mask, node_timestamp, + edge_timestamp), + GetTemporalPickFn( + input_nodes_timestamp, indices_, fanouts, replace, + indptr_.options(), type_per_edge_, input_nodes_pre_time_window, + probs_or_mask, node_timestamp, edge_timestamp, args)); + } else { + auto args = [&] { + if (random_seed.has_value() && random_seed->numel() == 1) { + return SamplerArgs{ + indices_, random_seed.value(), NumNodes()}; + } else { + return SamplerArgs{ + indices_, + RandomEngine::ThreadLocal()->RandInt( + static_cast(0), std::numeric_limits::max()), + NumNodes()}; + } + }(); + return SampleNeighborsImpl( + input_nodes, seed_offsets, fanouts, return_eids, + GetTemporalNumPickFn( + input_nodes_timestamp, indices_, fanouts, replace, type_per_edge_, + input_nodes_pre_time_window, probs_or_mask, node_timestamp, + edge_timestamp), + GetTemporalPickFn( + input_nodes_timestamp, indices_, fanouts, replace, + indptr_.options(), type_per_edge_, input_nodes_pre_time_window, + probs_or_mask, node_timestamp, edge_timestamp, args)); + } } else { SamplerArgs args; return SampleNeighborsImpl( @@ -1560,7 +1588,7 @@ int64_t TemporalPick( masked_prob = probs_or_mask.value().slice(0, offset, offset + num_neighbors) * mask; } else { - masked_prob = mask.to(torch::kFloat32); + masked_prob = S == SamplerType::NEIGHBOR ? mask.to(torch::kFloat32) : mask; } if constexpr (S == SamplerType::NEIGHBOR) { auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace); @@ -1693,7 +1721,7 @@ std::enable_if_t Pick( probs_or_mask.value(), picked_data_ptr); } else { int64_t picked_count; - AT_DISPATCH_FLOATING_TYPES( + GRAPHBOLT_DISPATCH_ALL_TYPES( probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] { if (replace) { picked_count = LaborPick( diff --git a/graphbolt/src/macro.h b/graphbolt/src/macro.h index 9e130b9a848c..8ff929b85266 100644 --- a/graphbolt/src/macro.h +++ b/graphbolt/src/macro.h @@ -25,6 +25,16 @@ namespace graphbolt { TORCH_CHECK(false, name, " is only available on CUDA device."); #endif +// This includes all integer, float and boolean types. +#define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__) + +#define GRAPHBOLT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)) + } // namespace graphbolt #endif // GRAPHBOLT_MACRO_H_ diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index f9c583c30d98..1e753a2aeee6 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -1119,6 +1119,139 @@ def temporal_sample_neighbors( probs_or_mask, node_timestamp_attr_name, edge_timestamp_attr_name, + None, # random_seed, labor parameter + 0, # seed2_contribution, labor_parameter + ) + return self._convert_to_sampled_subgraph(C_sampled_subgraph) + + def temporal_sample_layer_neighbors( + self, + nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], + input_nodes_timestamp: Union[torch.Tensor, Dict[str, torch.Tensor]], + fanouts: torch.Tensor, + replace: bool = False, + input_nodes_pre_time_window: Optional[ + Union[torch.Tensor, Dict[str, torch.Tensor]] + ] = None, + probs_name: Optional[str] = None, + node_timestamp_attr_name: Optional[str] = None, + edge_timestamp_attr_name: Optional[str] = None, + random_seed: torch.Tensor = None, + seed2_contribution: float = 0.0, + ) -> torch.ScriptObject: + """Temporally Sample neighboring edges of the given nodes and return the induced + subgraph via layer-neighbor sampling from the NeurIPS 2023 paper + `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs + `__ + + If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given, + the sampled neighbor or edge of an input node must have a timestamp + that is smaller than that of the input node. + + Parameters + ---------- + nodes: torch.Tensor + IDs of the given seed nodes. + input_nodes_timestamp: torch.Tensor + Timestamps of the given seed nodes. + fanouts: torch.Tensor + The number of edges to be sampled for each node with or without + considering edge types. + - When the length is 1, it indicates that the fanout applies to + all neighbors of the node as a collective, regardless of the + edge type. + - Otherwise, the length should equal to the number of edge + types, and each fanout value corresponds to a specific edge + type of the nodes. + The value of each fanout should be >= 0 or = -1. + - When the value is -1, all neighbors (with non-zero probability, + if weighted) will be sampled once regardless of replacement. It + is equivalent to selecting all neighbors with non-zero + probability when the fanout is >= the number of neighbors (and + replace is set to false). + - When the value is a non-negative integer, it serves as a + minimum threshold for selecting neighbors. + replace: bool + Boolean indicating whether the sample is preformed with or + without replacement. If True, a value can be selected multiple + times. Otherwise, each value can be selected only once. + input_nodes_pre_time_window: torch.Tensor + The time window of the nodes represents a period of time before + `input_nodes_timestamp`. If provided, only neighbors and related + edges whose timestamps fall within `[input_nodes_timestamp - + input_nodes_pre_time_window, input_nodes_timestamp]` will be + filtered. + probs_name: str, optional + An optional string specifying the name of an edge attribute. This + attribute tensor should contain (unnormalized) probabilities + corresponding to each neighboring edge of a node. It must be a 1D + floating-point or boolean tensor, with the number of elements + equalling the total number of edges. + node_timestamp_attr_name: str, optional + An optional string specifying the name of an node attribute. + edge_timestamp_attr_name: str, optional + An optional string specifying the name of an edge attribute. + random_seed: torch.Tensor, optional + An int64 tensor with one or two elements. + + The passed random_seed makes it so that for any seed node ``s`` and + its neighbor ``t``, the rolled random variate ``r_t`` is the same + for any call to this function with the same random seed. When + sampling as part of the same batch, one would want identical seeds + so that LABOR can globally sample. One example is that for + heterogenous graphs, there is a single random seed passed for each + edge type. This will sample much fewer nodes compared to having + unique random seeds for each edge type. If one called this function + individually for each edge type for a heterogenous graph with + different random seeds, then it would run LABOR locally for each + edge type, resulting into a larger number of nodes being sampled. + + If this function is called without a ``random_seed``, we get the + random seed by getting a random number from GraphBolt. Use this + argument with identical random_seed if multiple calls to this + function are used to sample as part of a single batch. + + If given two numbers, then the ``seed2_contribution`` argument + determines the interpolation between the two random seeds. + seed2_contribution: float, optional + A float value between [0, 1) that determines the contribution of the + second random seed, ``random_seed[-1]``, to generate the random + variates. + + Returns + ------- + SampledSubgraphImpl + The sampled subgraph. + """ + if isinstance(nodes, dict): + ( + nodes, + input_nodes_timestamp, + input_nodes_pre_time_window, + ) = self._convert_to_homogeneous_nodes( + nodes, input_nodes_timestamp, input_nodes_pre_time_window + ) + + # Ensure nodes is 1-D tensor. + probs_or_mask = self.edge_attributes[probs_name] if probs_name else None + self._check_sampler_arguments(nodes, fanouts, probs_or_mask) + has_original_eids = ( + self.edge_attributes is not None + and ORIGINAL_EDGE_ID in self.edge_attributes + ) + C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors( + nodes, + input_nodes_timestamp, + fanouts.tolist(), + replace, + True, + has_original_eids, + input_nodes_pre_time_window, + probs_or_mask, + node_timestamp_attr_name, + edge_timestamp_attr_name, + random_seed, + seed2_contribution, ) return self._convert_to_sampled_subgraph(C_sampled_subgraph) diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py index 4204fba159a6..1f23e8c685af 100644 --- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py +++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py @@ -825,10 +825,16 @@ def test_in_subgraph_hetero(): @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("replace", [False, True]) +@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("use_node_timestamp", [False, True]) @pytest.mark.parametrize("use_edge_timestamp", [False, True]) def test_temporal_sample_neighbors_homo( - indptr_dtype, indices_dtype, replace, use_node_timestamp, use_edge_timestamp + indptr_dtype, + indices_dtype, + replace, + labor, + use_node_timestamp, + use_edge_timestamp, ): """Original graph in COO: 1 0 1 0 1 @@ -853,7 +859,11 @@ def test_temporal_sample_neighbors_homo( # Generate subgraph via sample neighbors. fanouts = torch.LongTensor([2]) - sampler = graph.temporal_sample_neighbors + sampler = ( + graph.temporal_sample_layer_neighbors + if labor + else graph.temporal_sample_neighbors + ) seed_list = [1, 3, 4] seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)