If I am not mistaken, there seems to be a bug when using the model on a unipartite dataset while updating the memory at the end of each batch (`memory_update_at_start=False`).
Running the model this way incorrectly triggers the `AssertionError: Trying to update to time in the past` of the `memory_updater` module. This is due to lines 185-186 in `tgn.py`:
```python
def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                edge_idxs, n_neighbors=20):
  ...
  if self.use_memory:
    if self.memory_update_at_start:
      # Update memory for all nodes with messages stored in previous batches
      memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                    self.memory.messages)
    else:
      memory = self.memory.get_memory(list(range(self.n_nodes)))
      last_update = self.memory.last_update
  ...
  if self.use_memory:
    if self.memory_update_at_start:
      # Persist the updates to the memory only for sources and destinations (since now we have
      # new messages for them)
      self.update_memory(positives, self.memory.messages)

      assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
        "Something wrong in how the memory was updated"

      # Remove messages for the positives since we have already updated the memory using them
      self.memory.clear_messages(positives)

    unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                  source_node_embedding,
                                                                  destination_nodes,
                                                                  destination_node_embedding,
                                                                  edge_times, edge_idxs)
    unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                            destination_node_embedding,
                                                                            source_nodes,
                                                                            source_node_embedding,
                                                                            edge_times, edge_idxs)
    if self.memory_update_at_start:
      self.memory.store_raw_messages(unique_sources, source_id_to_messages)
      self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
    else:
      self.update_memory(unique_sources, source_id_to_messages)            # <-- 185
      self.update_memory(unique_destinations, destination_id_to_messages)  # <-- 186
  ...
  return source_node_embedding, destination_node_embedding, negative_node_embedding
```
When `source_nodes` and `destination_nodes` contain non-overlapping node ids, this is not a problem. However, in a unipartite graph the same node id can appear in both `source_nodes` and `destination_nodes`, which causes the described issue whenever that node id is associated with a later timestamp on the source side than on the destination side.
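A minimal, self-contained sketch of the failure mode (the `Memory` class below is hypothetical and only mimics the monotonic-time assertion of the real `memory_updater`; it is not the repo's API):

```python
class Memory:
    """Toy stand-in for TGN's node memory: tracks each node's last update time."""

    def __init__(self, n_nodes):
        self.last_update = [0.0] * n_nodes

    def update(self, node, t):
        # Mirrors the assertion raised by the memory_updater module.
        assert t >= self.last_update[node], "Trying to update to time in the past"
        self.last_update[node] = t


mem = Memory(3)
# One batch with two interactions on a unipartite graph:
# node 2 is a source at t=7.0 and a destination at t=5.0.
sources      = [2, 0]
destinations = [1, 2]
times        = [7.0, 5.0]

# tgn.py first updates all sources (line 185)...
for node, t in zip(sources, times):
    mem.update(node, t)          # node 2 is advanced to t=7.0

# ...then all destinations (line 186), asking node 2 to go back to t=5.0.
try:
    for node, t in zip(destinations, times):
        mem.update(node, t)
except AssertionError as e:
    print(e)                     # prints: Trying to update to time in the past
```

With `memory_update_at_start=True` this never triggers, because both sides' messages are stored and only applied together at the start of the next batch.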
This problem can be resolved by replacing:
with:
Edit: found an issue in the fix initially proposed and updated it to match the pull request.
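The replacement code is given in the pull request rather than above. As a hedged sketch of one possible approach (not necessarily the PR's actual fix), the two ordered `update_memory` calls could be collapsed into a single call by merging the per-node message lists from both sides first, so each node's memory is advanced once using all of its messages from the batch; `merge_messages` below is a hypothetical helper introduced only for illustration:

```python
from collections import defaultdict


def merge_messages(unique_sources, source_id_to_messages,
                   unique_destinations, destination_id_to_messages):
    """Merge source- and destination-side message dicts into one dict,
    so a single update_memory call can process each node once."""
    merged = defaultdict(list)
    for node in unique_sources:
        merged[node].extend(source_id_to_messages[node])
    for node in unique_destinations:
        merged[node].extend(destination_id_to_messages[node])
    return list(merged.keys()), merged


# Usage sketch: a node appearing on both sides (here node 2) ends up with
# all of its messages in one list instead of being updated twice in order.
nodes, merged = merge_messages([1, 2], {1: ["m1"], 2: ["m2"]},
                               [2, 3], {2: ["m3"], 3: ["m4"]})
print(nodes)       # prints: [1, 2, 3]
print(merged[2])   # prints: ['m2', 'm3']
```

Whether a single merged update suffices depends on how `update_memory` aggregates multiple messages per node; if it keeps the latest message per node, the out-of-order update can no longer occur.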