Skip to content

Commit

Permalink
Update the Random Walk binding (#1599)
Browse files Browse the repository at this point in the history
  • Loading branch information
Iroy30 authored Jun 3, 2021
1 parent 4e20f73 commit 637c139
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 236 deletions.
14 changes: 13 additions & 1 deletion cpp/include/cugraph/utilities/cython.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ struct random_walk_ret_t {
std::unique_ptr<rmm::device_buffer> d_sizes_;
};

struct random_walk_path_t {
std::unique_ptr<rmm::device_buffer> d_v_offsets;
std::unique_ptr<rmm::device_buffer> d_w_sizes;
std::unique_ptr<rmm::device_buffer> d_w_offsets;
};

struct graph_generator_t {
std::unique_ptr<rmm::device_buffer> d_source;
std::unique_ptr<rmm::device_buffer> d_destination;
Expand Down Expand Up @@ -538,7 +544,13 @@ call_random_walks(raft::handle_t const& handle,
graph_container_t const& graph_container,
vertex_t const* ptr_start_set,
edge_t num_paths,
edge_t max_depth);
edge_t max_depth,
bool use_padding);

template <typename index_t>
std::unique_ptr<random_walk_path_t> call_rw_paths(raft::handle_t const& handle,
index_t num_paths,
index_t const* vertex_path_sizes);

// convertor from random_walks return type to COO:
//
Expand Down
36 changes: 30 additions & 6 deletions cpp/src/utilities/cython.cu
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,8 @@ call_random_walks(raft::handle_t const& handle,
graph_container_t const& graph_container,
vertex_t const* ptr_start_set,
edge_t num_paths,
edge_t max_depth)
edge_t max_depth,
bool use_padding)
{
if (graph_container.weightType == numberTypeEnum::floatType) {
using weight_t = float;
Expand All @@ -888,7 +889,7 @@ call_random_walks(raft::handle_t const& handle,
detail::create_graph<vertex_t, edge_t, weight_t, false, false>(handle, graph_container);

auto triplet = cugraph::experimental::random_walks(
handle, graph->view(), ptr_start_set, num_paths, max_depth);
handle, graph->view(), ptr_start_set, num_paths, max_depth, use_padding);

random_walk_ret_t rw_tri{std::get<0>(triplet).size(),
std::get<1>(triplet).size(),
Expand All @@ -907,7 +908,7 @@ call_random_walks(raft::handle_t const& handle,
detail::create_graph<vertex_t, edge_t, weight_t, false, false>(handle, graph_container);

auto triplet = cugraph::experimental::random_walks(
handle, graph->view(), ptr_start_set, num_paths, max_depth);
handle, graph->view(), ptr_start_set, num_paths, max_depth, use_padding);

random_walk_ret_t rw_tri{std::get<0>(triplet).size(),
std::get<1>(triplet).size(),
Expand All @@ -924,6 +925,20 @@ call_random_walks(raft::handle_t const& handle,
}
}

template <typename index_t>
std::unique_ptr<random_walk_path_t> call_rw_paths(raft::handle_t const& handle,
index_t num_paths,
index_t const* vertex_path_sizes)
{
auto triplet =
cugraph::experimental::query_rw_sizes_offsets<index_t>(handle, num_paths, vertex_path_sizes);
random_walk_path_t rw_path_tri{
std::make_unique<rmm::device_buffer>(std::get<0>(triplet).release()),
std::make_unique<rmm::device_buffer>(std::get<1>(triplet).release()),
std::make_unique<rmm::device_buffer>(std::get<2>(triplet).release())};
return std::make_unique<random_walk_path_t>(std::move(rw_path_tri));
}

template <typename vertex_t, typename index_t>
std::unique_ptr<random_walk_coo_t> random_walks_to_coo(raft::handle_t const& handle,
random_walk_ret_t& rw_tri)
Expand Down Expand Up @@ -1354,21 +1369,30 @@ template std::unique_ptr<random_walk_ret_t> call_random_walks<int32_t, int32_t>(
graph_container_t const& graph_container,
int32_t const* ptr_start_set,
int32_t num_paths,
int32_t max_depth);
int32_t max_depth,
bool use_padding);

template std::unique_ptr<random_walk_ret_t> call_random_walks<int32_t, int64_t>(
raft::handle_t const& handle,
graph_container_t const& graph_container,
int32_t const* ptr_start_set,
int64_t num_paths,
int64_t max_depth);
int64_t max_depth,
bool use_padding);

template std::unique_ptr<random_walk_ret_t> call_random_walks<int64_t, int64_t>(
raft::handle_t const& handle,
graph_container_t const& graph_container,
int64_t const* ptr_start_set,
int64_t num_paths,
int64_t max_depth);
int64_t max_depth,
bool use_padding);

template std::unique_ptr<random_walk_path_t> call_rw_paths<int32_t>(
raft::handle_t const& handle, int32_t num_paths, int32_t const* vertex_path_sizes);

template std::unique_ptr<random_walk_path_t> call_rw_paths<int64_t>(
raft::handle_t const& handle, int64_t num_paths, int64_t const* vertex_path_sizes);

template std::unique_ptr<random_walk_coo_t> random_walks_to_coo<int32_t, int32_t>(
raft::handle_t const& handle, random_walk_ret_t& rw_tri);
Expand Down
178 changes: 30 additions & 148 deletions notebooks/sampling/RandomWalk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -58,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -78,28 +78,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(34, 78)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# some stats on the graph\n",
"(G.number_of_nodes(), G.number_of_edges() )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -109,11 +98,21 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rw, so = cugraph.random_walks(G, seeds, 4)"
"# random walk path length\n",
"path_length = 4"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rw, so, sz = cugraph.random_walks(G, seeds, path_length, use_padding=True)"
]
},
{
Expand All @@ -131,144 +130,27 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0\n",
"1 3\n",
"2 6\n",
"dtype: int64"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"so"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>src</th>\n",
" <th>dst</th>\n",
" <th>weight</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>17</td>\n",
" <td>6</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>6</td>\n",
" <td>17</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>17</td>\n",
" <td>6</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>19</td>\n",
" <td>33</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>33</td>\n",
" <td>31</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>31</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" src dst weight\n",
"0 17 6 1.0\n",
"1 6 17 1.0\n",
"2 17 6 1.0\n",
"3 19 33 1.0\n",
"4 33 31 1.0\n",
"5 31 2 1.0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"rw"
"rw.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"seed 17 starts at index 0 and is 3 rows\n",
"seed 19 starts at index 3 and is 3 rows\n"
]
}
],
"outputs": [],
"source": [
"idx = 0\n",
"for i in range(len(seeds)):\n",
" print(f\"seed {seeds[i]} starts at index {so[i]} and is {so[1 + 1] - so[1]} rows\")"
" for j in range(path_length):\n",
" print(f\"{rw[idx]}\", end=\" \")\n",
" idx += 1\n",
" print(\" \")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -305,7 +187,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.8.10"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion python/cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
from cugraph.raft import raft_include_test
from cugraph.comms import comms

from cugraph.sampling import random_walks
from cugraph.sampling import random_walks, rw_path

# Versioneer
from ._version import get_versions
Expand Down
2 changes: 1 addition & 1 deletion python/cugraph/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from cugraph.sampling.random_walks import random_walks
from cugraph.sampling.random_walks import random_walks, rw_path
8 changes: 7 additions & 1 deletion python/cugraph/sampling/random_walks.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,10 @@ cdef extern from "cugraph/utilities/cython.hpp" namespace "cugraph::cython":
const graph_container_t &g,
const vertex_t *ptr_d_start,
edge_t num_paths,
edge_t max_depth) except +
edge_t max_depth,
bool use_padding) except +

cdef unique_ptr[random_walk_path_t] call_rw_paths[index_t](
const handle_t &handle,
index_t num_paths,
const index_t* sizes) except +
Loading

0 comments on commit 637c139

Please sign in to comment.