Problem about update_memory #25

@Void-JackLee

Description


Hi @emalgorithm, I ran into some problems while reading your code.

When memory_update_at_start=True, msg_agg and msg_func run twice: once before compute_embedding and once after it. Before compute_embedding, get_updated_memory computes the updated memory for all nodes; after compute_embedding, update_memory persists the memory of the positive nodes only.
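Here is a minimal toy sketch of that control flow with memory_update_at_start=True; the names mirror the TGN code (get_updated_memory, update_memory, compute_embedding), but the bodies are simplified stand-ins, not the real implementation:

```python
# Toy stand-in for the TGN memory module; real memory entries are tensors,
# here they are floats so the double aggregation is easy to trace.
class ToyMemory:
    def __init__(self, n_nodes):
        self.memory = [0.0] * n_nodes
        self.messages = {i: [] for i in range(n_nodes)}

trace = []

def get_updated_memory(memory, nodes):
    # First aggregation pass: computes (but does NOT persist) updated memory
    # for ALL nodes, using messages stored at the end of the previous batch.
    trace.append("get_updated_memory")
    return [memory.memory[n] + sum(memory.messages[n]) for n in nodes]

def update_memory(memory, nodes):
    # Second aggregation pass: persists the update for the given nodes,
    # again from the stored messages -- msg_agg / msg_func run a second time.
    trace.append("update_memory")
    for n in nodes:
        memory.memory[n] += sum(memory.messages[n])
        memory.messages[n] = []

mem = ToyMemory(3)
mem.messages[0] = [1.0]                        # raw message left over from the last batch

updated = get_updated_memory(mem, [0, 1, 2])   # before compute_embedding, all nodes
# ... compute_embedding would consume `updated` here ...
update_memory(mem, [0])                        # after compute_embedding, positives only

print(trace)          # aggregation ran twice: ['get_updated_memory', 'update_memory']
print(mem.memory[0])  # 1.0 -- the persisted value matches the first pass
```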

if self.memory_update_at_start:
    # Persist the updates to the memory only for sources and destinations (since now we have
    # new messages for them)
    self.update_memory(positives, self.memory.messages)

    assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
      "Something wrong in how the memory was updated"

    # Remove messages for the positives since we have already updated the memory using them
    self.memory.clear_messages(positives)

unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                              source_node_embedding,
                                                              destination_nodes,
                                                              destination_node_embedding,
                                                              edge_times, edge_idxs)
unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                        destination_node_embedding,
                                                                        source_nodes,
                                                                        source_node_embedding,
                                                                        edge_times, edge_idxs)
if self.memory_update_at_start:
  self.memory.store_raw_messages(unique_sources, source_id_to_messages)
  self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
else:
  self.update_memory(unique_sources, source_id_to_messages)
  self.update_memory(unique_destinations, destination_id_to_messages)

The comment here says "Persist the updates to the memory only for sources and destinations (since now we have new messages for them)", but the messages for this batch are actually stored after the memory update, so update_memory is applying messages from the last batch. The problem is that update_memory(positives, self.memory.messages) updates this batch's positive nodes using messages that came from the last batch. I don't understand why the code does this; maybe it's a bug?

I think the memory of all nodes needs to be updated here (or the last batch's positive nodes need to be recorded), or the memory should be persisted directly in get_updated_memory (i.e. replace it with update_memory).
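A hedged sketch of the "record the last batch's positives" variant: persist memory for exactly the nodes whose stored messages are pending, i.e. the previous batch's positives. `last_batch_positives`, `persisted`, and `process_batch` are hypothetical names introduced for illustration, not part of the original TGN code:

```python
class MemorySketch:
    def __init__(self):
        self.last_batch_positives = []  # nodes whose raw messages are still pending
        self.persisted = []             # records which nodes each flush touched

    def update_memory(self, nodes):
        # Stand-in for the real update: just record which nodes were persisted.
        self.persisted.append(list(nodes))

    def process_batch(self, positives):
        # Persist updates for the nodes whose stored messages actually exist,
        # i.e. the PREVIOUS batch's positives, not this batch's.
        self.update_memory(self.last_batch_positives)
        # ... compute embeddings, then store this batch's raw messages ...
        self.last_batch_positives = positives  # remember for the next batch

m = MemorySketch()
m.process_batch([1, 2])  # first batch: nothing pending yet
m.process_batch([3, 4])  # second batch: flushes the messages for [1, 2]
```

This keeps the node set and the message set aligned on the same batch, which is what the current pairing of this batch's positives with last batch's messages appears to violate.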
