From bf587d4ba954a203b1bae1cb742b45bc580dae11 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Sun, 24 Mar 2024 16:55:38 -0400 Subject: [PATCH] fix the bug. --- python/dgl/convert.py | 2 +- python/dgl/utils/data.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/dgl/convert.py b/python/dgl/convert.py index 242f96b5b061..c2fd1e789c96 100644 --- a/python/dgl/convert.py +++ b/python/dgl/convert.py @@ -551,7 +551,7 @@ def create_block( data, idtype, bipartite=True, - infer_node_count=False, + infer_node_count=need_infer, ) node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays) if need_infer: diff --git a/python/dgl/utils/data.py b/python/dgl/utils/data.py index 3363a2ff9a63..da78d6916d88 100644 --- a/python/dgl/utils/data.py +++ b/python/dgl/utils/data.py @@ -199,8 +199,7 @@ def graphdata2tensors( num_src, num_dst = ( infer_num_nodes(data, bipartite=bipartite) if infer_node_count - else None, - None, + else (None, None) ) elif isinstance(data, list): src, dst = elist2tensor(data, idtype) @@ -208,16 +207,14 @@ def graphdata2tensors( num_src, num_dst = ( infer_num_nodes(data, bipartite=bipartite) if infer_node_count - else None, - None, + else (None, None) ) elif isinstance(data, sp.sparse.spmatrix): # We can get scipy matrix's number of rows and columns easily. num_src, num_dst = ( infer_num_nodes(data, bipartite=bipartite) if infer_node_count - else None, - None, + else (None, None) ) data = scipy2tensor(data, idtype) elif isinstance(data, nx.Graph): @@ -225,8 +222,7 @@ def graphdata2tensors( num_src, num_dst = ( infer_num_nodes(data, bipartite=bipartite) if infer_node_count - else None, - None, + else (None, None) ) edge_id_attr_name = kwargs.get("edge_id_attr_name", None) if bipartite: