-
Notifications
You must be signed in to change notification settings - Fork 304
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
neighbor sampling in COO/CSR format (#1982)
This pull request adds neighborhood sampling, as needed by GNN frameworks (DGL, PyTorch-Geometric). Since I did not hear back on most of the other issues that need to be addressed before this, I am continuing with my plan of first opening a PR with just the API. Once we agree on the final API, and once a minimal version of cugraph-ops is integrated, we can add the implementation of this API. In particular, for now I am suggesting that the sampling type is exposed in the public API (it does not exist yet in cugraph-ops since that has not been integrated yet). This must be decided ahead of sampling for best performance (either by the end user or some automatic heuristic on the original graph), which is why it makes sense to have it as a separate parameter for this API. EDIT: link to issue #1978 Authors: - Matt Joux (https://github.com/MatthiasKohl) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - Robert Maynard (https://github.com/robertmaynard) - Andrei Schaffer (https://github.com/aschaffer) - Chuck Hastings (https://github.com/ChuckHastings) URL: #1982
- Loading branch information
1 parent
df49ad7
commit e95171f
Showing
9 changed files
with
248 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/* | ||
* Copyright (c) 2021-2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cugraph/algorithms.hpp> | ||
|
||
#include <utilities/cugraph_ops_utils.hpp> | ||
|
||
#include <cugraph-ops/graph/sampling.h> | ||
|
||
namespace cugraph { | ||
|
||
template <typename graph_t> | ||
std::tuple<rmm::device_uvector<typename graph_t::edge_type>, | ||
rmm::device_uvector<typename graph_t::vertex_type>> | ||
sample_neighbors_adjacency_list(raft::handle_t const& handle, | ||
ops::gnn::graph::Rng& rng, | ||
graph_t const& graph, | ||
typename graph_t::vertex_type const* ptr_d_start, | ||
size_t num_start_vertices, | ||
size_t sampling_size, | ||
ops::gnn::graph::SamplingAlgoT sampling_algo) | ||
{ | ||
const auto [ops_graph, max_degree] = detail::get_graph_and_max_degree(graph); | ||
return ops::gnn::graph::uniform_sample_csr(rng, | ||
ops_graph, | ||
ptr_d_start, | ||
num_start_vertices, | ||
sampling_size, | ||
sampling_algo, | ||
max_degree, | ||
handle.get_stream()); | ||
} | ||
|
||
template <typename graph_t> | ||
std::tuple<rmm::device_uvector<typename graph_t::vertex_type>, | ||
rmm::device_uvector<typename graph_t::vertex_type>> | ||
sample_neighbors_edgelist(raft::handle_t const& handle, | ||
ops::gnn::graph::Rng& rng, | ||
graph_t const& graph, | ||
typename graph_t::vertex_type const* ptr_d_start, | ||
size_t num_start_vertices, | ||
size_t sampling_size, | ||
ops::gnn::graph::SamplingAlgoT sampling_algo) | ||
{ | ||
const auto [ops_graph, max_degree] = detail::get_graph_and_max_degree(graph); | ||
return ops::gnn::graph::uniform_sample_coo(rng, | ||
ops_graph, | ||
ptr_d_start, | ||
num_start_vertices, | ||
sampling_size, | ||
sampling_algo, | ||
max_degree, | ||
handle.get_stream()); | ||
} | ||
|
||
// template explicit instantiation directives (EIDir's): | ||
// | ||
// CSR SG FP32{ | ||
template std::tuple<rmm::device_uvector<int32_t>, rmm::device_uvector<int32_t>> | ||
sample_neighbors_adjacency_list<graph_view_t<int32_t, int32_t, float, false, false>>( | ||
raft::handle_t const& handle, | ||
ops::gnn::graph::Rng& rng, | ||
graph_view_t<int32_t, int32_t, float, false, false> const& gview, | ||
int32_t const* ptr_d_start, | ||
size_t num_start_vertices, | ||
size_t sampling_size, | ||
ops::gnn::graph::SamplingAlgoT sampling_algo); | ||
|
||
template std::tuple<rmm::device_uvector<int64_t>, rmm::device_uvector<int64_t>> | ||
sample_neighbors_adjacency_list<graph_view_t<int64_t, int64_t, float, false, false>>( | ||
raft::handle_t const& handle, | ||
ops::gnn::graph::Rng& rng, | ||
graph_view_t<int64_t, int64_t, float, false, false> const& gview, | ||
int64_t const* ptr_d_start, | ||
size_t num_start_vertices, | ||
size_t sampling_size, | ||
ops::gnn::graph::SamplingAlgoT sampling_algo); | ||
//} | ||
// | ||
// COO SG FP32{ | ||
template std::tuple<rmm::device_uvector<int32_t>, rmm::device_uvector<int32_t>> | ||
sample_neighbors_edgelist<graph_view_t<int32_t, int32_t, float, false, false>>( | ||
raft::handle_t const& handle, | ||
ops::gnn::graph::Rng& rng, | ||
graph_view_t<int32_t, int32_t, float, false, false> const& gview, | ||
int32_t const* ptr_d_start, | ||
size_t num_start_vertices, | ||
size_t sampling_size, | ||
ops::gnn::graph::SamplingAlgoT sampling_algo); | ||
|
||
template std::tuple<rmm::device_uvector<int64_t>, rmm::device_uvector<int64_t>> | ||
sample_neighbors_edgelist<graph_view_t<int64_t, int64_t, float, false, false>>( | ||
raft::handle_t const& handle, | ||
ops::gnn::graph::Rng& rng, | ||
graph_view_t<int64_t, int64_t, float, false, false> const& gview, | ||
int64_t const* ptr_d_start, | ||
size_t num_start_vertices, | ||
size_t sampling_size, | ||
ops::gnn::graph::SamplingAlgoT sampling_algo); | ||
//} | ||
|
||
} // namespace cugraph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Copyright (c) 2021-2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cugraph/graph_view.hpp>

#include <cugraph-ops/graph/format.h>

#include <limits>
#include <tuple>
|
||
namespace cugraph { | ||
namespace detail { | ||
|
||
template <typename NodeTypeT, typename EdgeTypeT, typename WeightT> | ||
ops::gnn::graph::fg_csr<EdgeTypeT> get_graph( | ||
graph_view_t<NodeTypeT, EdgeTypeT, WeightT, false, false> const& gview) | ||
{ | ||
ops::gnn::graph::fg_csr<EdgeTypeT> graph; | ||
graph.n_nodes = gview.get_number_of_vertices(); | ||
graph.n_indices = gview.get_number_of_edges(); | ||
// FIXME: this is evil and is just temporary until we have a matching type in cugraph-ops | ||
// or we change the type accepted by the functions calling into cugraph-ops | ||
graph.offsets = const_cast<EdgeTypeT*>(gview.get_matrix_partition_view().get_offsets()); | ||
graph.indices = const_cast<EdgeTypeT*>(gview.get_matrix_partition_view().get_indices()); | ||
return graph; | ||
} | ||
|
||
template <typename NodeTypeT, typename EdgeTypeT, typename WeightT> | ||
std::tuple<ops::gnn::graph::fg_csr<EdgeTypeT>, NodeTypeT> get_graph_and_max_degree( | ||
graph_view_t<NodeTypeT, EdgeTypeT, WeightT, false, false> const& gview) | ||
{ | ||
// FIXME this is sufficient for now, but if there is a fast (cached) way | ||
// of getting max degree, use that instead | ||
auto max_degree = std::numeric_limits<NodeTypeT>::max(); | ||
return std::make_tuple(get_graph(gview), max_degree); | ||
} | ||
|
||
} // namespace detail | ||
} // namespace cugraph |