Updating memory fails for datasets that are not bipartite #29

@daniel-gomm

Description

Hi,

If I am not mistaken, there is a bug when using the model on a unipartite dataset while updating the memory at the end of each batch (memory_update_at_start=False).

Running the model like this incorrectly triggers the AssertionError "Trying to update to time in the past" in the memory_updater module. The cause is lines 185-186 in tgn.py:

def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                  edge_idxs, n_neighbors=20):
    ...
    if self.use_memory:
      if self.memory_update_at_start:
        # Update memory for all nodes with messages stored in previous batches
        memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                      self.memory.messages)
      else:
        memory = self.memory.get_memory(list(range(self.n_nodes)))
        last_update = self.memory.last_update

      ...

    if self.use_memory:
      if self.memory_update_at_start:
        # Persist the updates to the memory only for sources and destinations (since now we have
        # new messages for them)
        self.update_memory(positives, self.memory.messages)

        assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
          "Something wrong in how the memory was updated"

        # Remove messages for the positives since we have already updated the memory using them
        self.memory.clear_messages(positives)

      unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes, source_node_embedding, destination_nodes, destination_node_embedding, edge_times, edge_idxs)
      unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes, destination_node_embedding, source_nodes, source_node_embedding, edge_times, edge_idxs)
      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)                  <-- 185
        self.update_memory(unique_destinations, destination_id_to_messages)        <-- 186

     ...

    return source_node_embedding, destination_node_embedding, negative_node_embedding

When source_nodes and destination_nodes contain non-overlapping node ids, this is not a problem. On a unipartite graph, however, the same node id can appear in both source_nodes and destination_nodes, and the assertion fires whenever that node id is associated with a later timestamp on the source side than on the destination side: the source-side update (line 185) moves the node's memory forward in time, and the destination-side update (line 186) then tries to move it back.
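The failure mode can be reproduced with a toy sketch (a hypothetical ToyMemory class with made-up timestamps, not the actual tgn.py code) that enforces the same monotonicity check as the memory_updater module:

```python
# Toy reproduction of the ordering problem. Each node stores the timestamp of
# its last memory update; updating to an earlier time raises, like the assert
# in the memory_updater module.
class ToyMemory:
    def __init__(self, n_nodes):
        self.last_update = [0.0] * n_nodes

    def update(self, node_ids, times):
        for n, t in zip(node_ids, times):
            assert t >= self.last_update[n], \
                "Trying to update to time in the past"
            self.last_update[n] = t

# Unipartite batch: node 1 is a destination at t=5 and a source at t=10.
sources, src_times = [1, 2], [10.0, 3.0]
destinations, dst_times = [0, 1], [10.0, 5.0]

mem = ToyMemory(3)
mem.update(sources, src_times)           # node 1 updated to t=10
try:
    mem.update(destinations, dst_times)  # node 1 asked to go back to t=5
except AssertionError as e:
    print("AssertionError:", e)
```

Updating the destinations first would not help: the same conflict appears whenever the later timestamp sits on whichever side is updated first.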

This problem can be resolved by replacing:

      if self.memory_update_at_start:
        self.memory.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
      else:
        self.update_memory(unique_sources, source_id_to_messages)
        self.update_memory(unique_destinations, destination_id_to_messages)

with:

      self.memory.store_raw_messages(unique_sources, source_id_to_messages)
      self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)

      if not self.memory_update_at_start:
        unique_node_ids = np.unique(np.concatenate((unique_sources, unique_destinations)))
        self.update_memory(unique_node_ids, self.memory.messages)
        self.memory.clear_messages(unique_node_ids)

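The effect of the combined update can be sketched in isolation (a hypothetical dict-based message store with made-up timestamps, not the real Memory class): messages from both sides are stored first, and each unique node is then updated exactly once, at its latest stored message time, so the update never moves backwards:

```python
import numpy as np

# Hypothetical stand-in for the Memory message store.
messages = {}  # node_id -> list of (payload, time)

def store_raw_messages(node_ids, id_to_messages):
    for n in node_ids:
        messages.setdefault(n, []).extend(id_to_messages[n])

# Unipartite batch: node 1 appears as both source (t=10) and destination (t=5).
store_raw_messages([1, 2], {1: [("src-msg", 10.0)], 2: [("src-msg", 3.0)]})
store_raw_messages([0, 1], {0: [("dst-msg", 10.0)], 1: [("dst-msg", 5.0)]})

# One update per unique node, at its latest stored message time.
unique_node_ids = np.unique(np.concatenate(([1, 2], [0, 1])))
latest = {int(n): max(t for _, t in messages[n]) for n in unique_node_ids}
messages.clear()  # mirrors self.memory.clear_messages(unique_node_ids)

print(latest)  # node 1 is updated once, at t=10.0, never back to t=5.0
```

np.unique also sorts the concatenated ids, so the single pass visits each node exactly once regardless of which side it appeared on.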
Edit: found an issue in the fix initially proposed and updated it to match the pull request.
