-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform] Mutable transform #1833
Changes from all commits
c2d1e4c
42fd45e
f6a9db8
6f9ed5e
23c7ef4
1358c56
fdfdc86
49b882a
53411e2
fb5cb86
5eeab84
fda0e04
b5dee42
1f6b763
0c9153b
a27dc16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -561,7 +561,7 @@ def add_edges(self, u, v, data=None, etype=None): | |
return | ||
|
||
assert len(u) == len(v) or len(u) == 1 or len(v) == 1, \ | ||
'need the number of source nodes and the number of destination nodes are same, ' \ | ||
'The number of source nodes and the number of destination nodes should be same, ' \ | ||
'or either the number of source nodes or the number of destination nodes is 1.' | ||
|
||
if len(u) == 1 and len(v) > 1: | ||
|
@@ -696,12 +696,8 @@ def remove_edges(self, eids, etype=None): | |
for c_etype in self.canonical_etypes: | ||
# the target edge type | ||
if c_etype == (u_type, e_type, v_type): | ||
old_eids = self.edges(form='eid', order='eid', etype=c_etype) | ||
# trick here, eid_0 is 0 and should be handled | ||
old_eids[0] += 1 | ||
old_eids = F.scatter_row(old_eids, eids, F.full_1d( | ||
len(eids), 0, F.dtype(old_eids), F.context(old_eids))) | ||
edges[c_etype] = F.tensor(F.nonzero_1d(old_eids), dtype=F.dtype(old_eids)) | ||
origin_eids = self.edges(form='eid', order='eid', etype=c_etype) | ||
edges[c_etype] = utils.compensate(eids, origin_eids) | ||
else: | ||
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype) | ||
|
||
|
@@ -788,12 +784,8 @@ def remove_nodes(self, nids, ntype=None): | |
nodes = {} | ||
for c_ntype in self.ntypes: | ||
if self.get_ntype_id(c_ntype) == ntid: | ||
old_nids = self.nodes(c_ntype) | ||
# trick here, nid_0 is 0 and should be handled | ||
old_nids[0] += 1 | ||
old_nids = F.scatter_row(old_nids, nids, F.full_1d( | ||
len(nids), 0, F.dtype(old_nids), F.context(old_nids))) | ||
nodes[c_ntype] = F.tensor(F.nonzero_1d(old_nids), dtype=F.dtype(old_nids)) | ||
original_nids = self.nodes(c_ntype) | ||
nodes[c_ntype] = utils.compensate(nids, original_nids) | ||
else: | ||
nodes[c_ntype] = self.nodes(c_ntype) | ||
|
||
|
@@ -803,127 +795,6 @@ def remove_nodes(self, nids, ntype=None): | |
self._node_frames = sub_g._node_frames | ||
self._edge_frames = sub_g._edge_frames | ||
|
||
def add_selfloop(self, etype=None): | ||
r""" Add self loop for each node in the graph. | ||
|
||
Parameters | ||
---------- | ||
etype : str or tuple of str, optional | ||
The type of the edges to add self-loops to. Can be omitted if there is | ||
only one edge type in the graph. | ||
|
||
Notes | ||
----- | ||
* It is recommended to ``remove_selfloop`` before invoking | ||
``add_selfloop``. | ||
* Inplace update is applied to the current graph. | ||
* Features for the new edges (self-loop edges) will be created | ||
by initializers defined with :func:`set_n_initializer` | ||
(default initializer fills zeros). | ||
|
||
Examples | ||
-------- | ||
|
||
>>> import dgl | ||
>>> import torch | ||
|
||
**Homogeneous Graphs or Heterogeneous Graphs with A Single Node Type** | ||
|
||
>>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([2, 1, 0]))) | ||
>>> g.ndata['hv'] = torch.arange(3).float().reshape(-1, 1) | ||
>>> g.edata['he'] = torch.arange(3).float().reshape(-1, 1) | ||
>>> g.add_selfloop() | ||
>>> g | ||
Graph(num_nodes=3, num_edges=6, | ||
ndata_schemes={'hv': Scheme(shape=(1,), dtype=torch.float32)} | ||
edata_schemes={'he': Scheme(shape=(1,), dtype=torch.float32)}) | ||
>>> g.edata['he'] | ||
tensor([[0.], | ||
[1.], | ||
[2.], | ||
[0.], | ||
[0.], | ||
[0.]]) | ||
|
||
**Heterogeneous Graphs with Multiple Node Types** | ||
|
||
>>> g = dgl.heterograph({ | ||
('user', 'follows', 'user'): (torch.tensor([1, 2]), | ||
torch.tensor([0, 1])), | ||
('user', 'plays', 'game'): (torch.tensor([0, 1]), | ||
torch.tensor([0, 1]))}) | ||
>>> g.add_selfloop(etype='follows') | ||
>>> g | ||
Graph(num_nodes={'user': 3, 'game': 2}, | ||
num_edges={('user', 'plays', 'game'): 2, ('user', 'follows', 'user'): 5}, | ||
metagraph=[('user', 'user'), ('user', 'game')]) | ||
""" | ||
etype = self.to_canonical_etype(etype) | ||
if etype[0] != etype[2]: | ||
raise DGLError( | ||
'add_selfloop does not support unidirectional bipartite graphs: {}.' \ | ||
'Please make sure the types of head node and tail node are identical.' \ | ||
''.format(etype)) | ||
nodes = self.nodes(etype[0]) | ||
self.add_edges(nodes, nodes, etype=etype) | ||
|
||
def remove_selfloop(self, etype=None): | ||
r""" Remove self loops for each node in the graph. | ||
|
||
If there are multiple self loops for a certain node, | ||
all of them will be removed. | ||
|
||
Examples | ||
-------- | ||
|
||
>>> import dgl | ||
>>> import torch | ||
|
||
**Homogeneous Graphs or Heterogeneous Graphs with A Single Node Type** | ||
|
||
>>> g = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([1, 0, 0, 2])), | ||
idtype=idtype, device=F.ctx()) | ||
>>> g.edata['he'] = torch.arange(4).float().reshape(-1, 1) | ||
>>> g.remove_selfloop() | ||
>>> g | ||
Graph(num_nodes=3, num_edges=2, | ||
edata_schemes={'he': Scheme(shape=(2,), dtype=torch.float32)}) | ||
>>> g.edata['he'] | ||
tensor([[0.],[3.]]) | ||
|
||
**Heterogeneous Graphs with Multiple Node Types** | ||
|
||
>>> g = dgl.heterograph({ | ||
>>> ('user', 'follows', 'user'): (torch.tensor([0, 1, 1, 1, 2]), | ||
>>> torch.tensor([0, 0, 1, 1, 1])), | ||
>>> ('user', 'plays', 'game'): (torch.tensor([0, 1]), | ||
>>> torch.tensor([0, 1])) | ||
>>> }) | ||
>>> g.remove_selfloop(etype='follows') | ||
>>> g.num_nodes('user') | ||
3 | ||
>>> g.num_nodes('game') | ||
2 | ||
>>> g.num_edges('follows') | ||
2 | ||
>>> g.num_edges('plays') | ||
2 | ||
|
||
See Also | ||
-------- | ||
add_selfloop | ||
""" | ||
# TODO(xiangsx) need to handle block | ||
etype = self.to_canonical_etype(etype) | ||
if etype[0] != etype[2]: | ||
raise DGLError( | ||
'remove_selfloop does not support unidirectional bipartite graphs: {}.' \ | ||
'Please make sure the types of head node and tail node are identical.' \ | ||
''.format(etype)) | ||
u, v = self.edges(form='uv', order='eid', etype=etype) | ||
self_loop_eids = F.tensor(F.nonzero_1d(u == v), dtype=F.dtype(u)) | ||
self.remove_edges(self_loop_eids, etype=etype) | ||
|
||
################################################################# | ||
# Metagraph query | ||
################################################################# | ||
|
@@ -4476,6 +4347,32 @@ def cpu(self): | |
""" | ||
return self.to(F.cpu()) | ||
|
||
def clone(self): | ||
"""Return a heterograph object that is a clone of current graph. | ||
|
||
Returns | ||
------- | ||
DGLHeteroGraph | ||
The graph object that is a clone of current graph. | ||
""" | ||
meta_edges = [] | ||
for s_ntype, _, d_ntype in self.canonical_etypes: | ||
meta_edges.append((self.get_ntype_id(s_ntype), self.get_ntype_id(d_ntype))) | ||
|
||
metagraph = graph_index.from_edge_list(meta_edges, True) | ||
# rebuild graph idx | ||
num_nodes_per_type = [self.number_of_nodes(c_ntype) for c_ntype in self.ntypes] | ||
relation_graphs = [self._graph.get_relation_graph(self.get_etype_id(c_etype)) | ||
for c_etype in self.canonical_etypes] | ||
hgidx = heterograph_index.create_heterograph_from_relations( | ||
metagraph, relation_graphs, utils.toindex(num_nodes_per_type, "int64")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like we could just pass the same graph index object to the result graph: those in-place mutation operators are actually out-of-place, because each one creates a new graph index object, which means it is safe to share a graph index across different graphs. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The graph index in C++ has lots of shared_ptr objects. Will this cause problems? |
||
|
||
local_node_frames = [fr.clone() for fr in self._node_frames] | ||
local_edge_frames = [fr.clone() for fr in self._edge_frames] | ||
return DGLHeteroGraph(hgidx, self.ntypes, self.etypes, | ||
local_node_frames, local_edge_frames) | ||
|
||
|
||
def local_var(self): | ||
"""Return a heterograph object that can be used in a local function scope. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to create this range array, right? Just an integer is enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eids are supposed to be consecutive, but nids are not.
The compensate function can handle non-consecutive origin_ids.
I think this is not a critical path, so it is ok.