Skip to content

spreading activation and hierarchical memory #13

@DiTo97

Description

@DiTo97

Some ideas borrowed from memoripy, re-implemented on top of a Neo4j graph database:

from __future__ import annotations

import time
import typing

import neo4j
import numpy
from numpy.typing import NDArray
from sklearn.cluster import KMeans


class Interaction(typing.TypedDict):
    """A single prompt/response exchange persisted in the memory graph."""

    id: str  # unique identifier of the interaction
    prompt: str  # input that triggered the interaction
    output: str  # response generated for the prompt
    embedding: list[float]  # vector representation of the interaction
    timestamp: int  # epoch seconds at which the interaction occurred
    access_count: int  # how many times this interaction was accessed
    decay_factor: float  # relevance decay applied as the interaction ages
    concepts: set[str]  # concept tags linked to the interaction


class SimilarityRecord(typing.TypedDict):
    """Row returned by the embedding-similarity retrieval query."""

    id: str  # interaction id
    prompt: str  # interaction prompt, for display/debugging
    similarity: float  # cosine similarity against the query embedding


class ReinforcementRecord(typing.TypedDict):
    """Row returned after applying access reinforcement and time decay."""

    id: str  # interaction id
    prompt: str  # interaction prompt
    adjusted_relevance: float  # similarity scaled by reinforcement and decay


class ClusteringRecord(typing.TypedDict):
    """Row returned by the cluster-membership retrieval query."""

    id: str  # interaction id
    prompt: str  # interaction prompt

class Record(typing.TypedDict):
    """Per-interaction score record in the combined retrieval result."""

    similarity: float  # relevance contribution from embedding similarity
    activation_score: float  # contribution from spreading activation


class MemoryStore:
    """Hierarchical interaction memory backed by a Neo4j graph.

    Interactions start life as ``ShortTermMemory`` nodes, are promoted to
    ``LongTermMemory`` once frequently accessed, and are linked through a
    weighted concept co-occurrence graph used for spreading activation.
    Requires the APOC-free Cypher features plus the GDS cosine similarity
    function used in the retrieval queries.
    """

    def __init__(self, endpoint: str, username: str, password: str):
        """Open a Neo4j driver against *endpoint* with basic auth."""
        self.driver = neo4j.GraphDatabase.driver(endpoint, auth=(username, password))

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self.close()

    def close(self):
        """Close the driver and release its connection pool."""
        self.driver.close()

    def add_interaction(self, interaction: Interaction):
        """Persist *interaction* and wire up its concept association graph."""
        with self.driver.session() as session:
            # execute_write replaces the deprecated write_transaction
            # (the test suite in this file already uses the execute_* API).
            session.execute_write(self._add_interaction_tx, interaction)

    @staticmethod
    def _add_interaction_tx(tx: neo4j.Transaction, interaction: Interaction):
        """Create the interaction node and weight pairwise concept links.

        Every ordered pair of distinct concepts gets an ASSOCIATED_WITH edge
        whose weight counts how often the pair co-occurred.
        """
        query = """
CREATE (i:Interaction:ShortTermMemory {id: $id, prompt: $prompt, output: $output, 
                                        embedding: $embedding, timestamp: $timestamp, 
                                        access_count: $access_count, decay_factor: $decay_factor})
WITH i
UNWIND $concepts AS concept
MERGE (c:Concept {name: concept})
WITH i, c
UNWIND $concepts AS other_concept
WITH i, c, other_concept
MATCH (c1:Concept {name: other_concept})
WHERE c <> c1
MERGE (c)-[r:ASSOCIATED_WITH]->(c1)
ON CREATE SET r.weight = 1
ON MATCH SET r.weight = r.weight + 1
RETURN i
"""
        tx.run(query, {
            "id": interaction["id"],
            "prompt": interaction["prompt"],
            "output": interaction["output"],
            "embedding": interaction["embedding"],
            # datetime({epochSeconds: ...}) downstream requires an integer;
            # time.time() returns a float, so coerce.
            "timestamp": int(interaction.get("timestamp", time.time())),
            "access_count": interaction.get("access_count", 1),
            "decay_factor": interaction.get("decay_factor", 1.0),
            # The driver cannot serialize Python sets (Interaction declares
            # concepts as set[str]); send a list instead.
            "concepts": list(interaction.get("concepts", [])),
        })

    def classify_memory(self):
        """Promote frequently accessed short-term memories to long-term."""
        with self.driver.session() as session:
            session.execute_write(self._classify_memory_tx)

    @staticmethod
    def _classify_memory_tx(tx: neo4j.Transaction):
        # Promotion threshold: strictly more than 10 recorded accesses.
        query = """
MATCH (i:Interaction:ShortTermMemory)
WHERE i.access_count > 10
REMOVE i:ShortTermMemory
SET i:LongTermMemory
"""
        tx.run(query)

    def cluster_interactions(self, min_num_clusters: int = 10):
        """Cluster interactions by embedding and store labels on the nodes.

        Args:
            min_num_clusters: Upper bound on the number of clusters to form
                (capped by the number of interactions that have embeddings).
        """
        with self.driver.session() as session:
            session.execute_write(self._cluster_interactions_tx, min_num_clusters)

    @staticmethod
    def _cluster_interactions_tx(tx: neo4j.Transaction, num_clusters: int = 10):
        """
        Cluster interactions based on embeddings and assign cluster labels to nodes.

        Args:
            tx (Transaction): The Neo4j transaction object.
            num_clusters (int): Number of clusters to form, capped by the
                number of interactions that carry an embedding.
        """
        print("Retrieving interactions and clustering...")

        # Retrieve all interaction nodes with their embeddings.
        result = tx.run(
            """
        MATCH (i:Interaction)
        RETURN id(i) AS node_id, i.embedding AS embedding
        """
        )

        node_ids = []
        embeddings = []
        for record in result:
            embedding = record["embedding"]
            if embedding:  # skip nodes without an embedding
                embeddings.append(numpy.array(embedding))
                node_ids.append(record["node_id"])

        # Nothing to cluster: numpy.vstack([]) and KMeans would both raise.
        if not embeddings:
            return

        # Honour the caller-supplied cluster count (previously hard-coded
        # to min(10, n), silently ignoring the parameter).
        num_clusters = min(num_clusters, len(embeddings))

        embeddings_matrix = numpy.vstack(embeddings)
        kmeans = KMeans(n_clusters=num_clusters, random_state=0)
        cluster_labels = kmeans.fit_predict(embeddings_matrix)

        print(f"Assigning cluster labels to {len(node_ids)} interactions...")

        for node_id, cluster_label in zip(node_ids, cluster_labels):
            tx.run(
                """
                MATCH (i:Interaction)
                WHERE id(i) = $node_id
                SET i.cluster = $cluster_label
                """,
                node_id=node_id,
                cluster_label=int(cluster_label),  # numpy int -> Python int
            )

        print("Clustering completed and labels assigned to interaction nodes.")

    def retrieve(
        self,
        query_embedding: NDArray[numpy.float64],
        query_concepts: list[str],
        similarity_threshold: float = 0.8,
        access_multiplier: float = 1.1,
        decay_multiplier: float = 0.9,
    ) -> list[tuple[str, Record]]:
        """Retrieve interactions relevant to a query.

        Combines three signals: embedding similarity, spreading activation
        over the concept graph, and membership in the best-matching cluster.

        Args:
            query_embedding: Embedding of the query.
            query_concepts: Concepts extracted from the query.
            similarity_threshold: Minimum cosine similarity to keep a hit.
            access_multiplier: Reinforcement base per recorded access (>1
                rewards frequently accessed interactions).
            decay_multiplier: Decay base per second of age (<1 penalizes old
                interactions).

        Returns:
            (interaction id, score record) pairs, highest combined score first.
        """
        with self.driver.session() as session:
            initial_results = session.execute_read(
                self._retrieve_by_similarity_tx, query_embedding, similarity_threshold
            )
            activated_concepts = session.execute_read(self._spreading_activation_tx, query_concepts)
            cluster_results = session.execute_read(self._retrieve_from_cluster_tx, query_embedding)

            enhanced_results = session.execute_read(
                self._apply_decay_and_reinforcement_tx,
                initial_results,
                access_multiplier,
                decay_multiplier,
            )

        return self._combine_results(enhanced_results, activated_concepts, cluster_results)

    @staticmethod
    def _apply_decay_and_reinforcement_tx(
        tx: neo4j.Transaction,
        initial_results: list[SimilarityRecord],
        access_multiplier: float,
        decay_multiplier: float,
    ) -> list[ReinforcementRecord]:
        """Re-rank similarity hits by access reinforcement and time decay."""
        # Native Cypher `^` replaces apoc.math.pow (not a core APOC function),
        # and duration.inSeconds(...) yields the *total* age in seconds —
        # duration.between(...).seconds only holds the sub-day component.
        query = """
UNWIND $results AS result
MATCH (i:Interaction {id: result.id})
WITH i, result.similarity AS similarity
WITH i, similarity,
    duration.inSeconds(datetime({epochSeconds: i.timestamp}), datetime()).seconds AS age
WITH i, similarity, age,
    ($access_multiplier ^ i.access_count) AS reinforced_access,
    ($decay_multiplier ^ age) AS decay_factor
RETURN i.id AS id,
    i.prompt AS prompt,
    similarity * reinforced_access * decay_factor AS adjusted_relevance
ORDER BY adjusted_relevance DESC
LIMIT 10
"""
        results = tx.run(query, {
            "results": initial_results,
            "access_multiplier": access_multiplier,
            "decay_multiplier": decay_multiplier,
        })

        return [
            {"id": record["id"], "prompt": record["prompt"], "adjusted_relevance": record["adjusted_relevance"]}
            for record in results
        ]

    @staticmethod
    def _retrieve_by_similarity_tx(
        tx: neo4j.Transaction, query_embedding: NDArray[numpy.float64], similarity_threshold: float
    ) -> list[SimilarityRecord]:
        """Return the top 10 interactions above the similarity threshold."""
        # NOTE(review): gds.alpha.similarity.cosine is an alpha-tier GDS
        # function and has been renamed in newer GDS releases — confirm the
        # deployed GDS version supports it.
        query = """
MATCH (i:Interaction)
WITH i, gds.alpha.similarity.cosine(i.embedding, $query_embedding) AS similarity
WHERE similarity >= $similarity_threshold
RETURN i.id AS id, i.prompt AS prompt, similarity
ORDER BY similarity DESC
LIMIT 10
"""
        results = tx.run(query, {"query_embedding": query_embedding, "similarity_threshold": similarity_threshold})
        return [{"id": record["id"], "prompt": record["prompt"], "similarity": record["similarity"]} for record in results]

    @staticmethod
    def _spreading_activation_tx(tx: neo4j.Transaction, query_concepts: list[str]) -> dict[str, float]:
        """Spread activation 1-2 hops out from the query concepts.

        Returns a mapping of neighbor concept name -> accumulated edge weight.
        """
        # In a variable-length pattern `r` binds to a *list* of relationships,
        # so the weights must be folded with reduce(); the previous
        # SUM(r.weight) is not valid Cypher on a list.
        query = """
UNWIND $query_concepts AS concept
MATCH (c:Concept {name: concept})-[r:ASSOCIATED_WITH*1..2]-(neighbor:Concept)
WITH neighbor, SUM(reduce(total = 0, rel IN r | total + rel.weight)) AS activation_score
RETURN neighbor.name AS concept, activation_score
ORDER BY activation_score DESC
"""
        results = tx.run(query, {"query_concepts": query_concepts})
        return {record["concept"]: record["activation_score"] for record in results}

    @staticmethod
    def _retrieve_from_cluster_tx(tx: neo4j.Transaction, query_embedding) -> list[ClusteringRecord]:
        """Return every member of the cluster closest to the query.

        "Closest" is defined by the single most similar clustered
        interaction; the previous query passed a list of embeddings to
        gds.util.asNode (which expects a node id) and could not run.
        """
        query = """
MATCH (i:Interaction)
WHERE i.cluster IS NOT NULL
WITH i, gds.alpha.similarity.cosine(i.embedding, $query_embedding) AS similarity
ORDER BY similarity DESC
LIMIT 1
MATCH (member:Interaction {cluster: i.cluster})
RETURN member.id AS id, member.prompt AS prompt
"""
        results = tx.run(query, {"query_embedding": query_embedding})
        return [{"id": record["id"], "prompt": record["prompt"]} for record in results]

    def _combine_results(
        self,
        reinforcement_results: list[ReinforcementRecord],
        activated_concepts: dict[str, float],
        cluster_results: list[ClusteringRecord],
    ) -> list[tuple[str, Record]]:
        """Merge the three retrieval signals into one ranked result list.

        Returns:
            (interaction id, score record) pairs sorted by the sum of
            similarity, activation score and cluster bonus, best first.
        """
        combined: dict[str, dict] = {}

        for result in reinforcement_results:
            # Reinforcement records carry "adjusted_relevance"; the previous
            # code read the non-existent "similarity" key and raised KeyError.
            combined[result["id"]] = {
                "similarity": result["adjusted_relevance"],
                "activation_score": 0,
            }

        # NOTE(review): activated_concepts is keyed by *concept name* while
        # combined is keyed by *interaction id*, so this lookup only matches
        # if the two namespaces overlap — a concept->interaction join is
        # probably intended; confirm the desired semantics.
        for concept, score in activated_concepts.items():
            if concept in combined:
                combined[concept]["activation_score"] += score

        # Flat bonus for belonging to the best-matching cluster.
        for result in cluster_results:
            if result["id"] in combined:
                combined[result["id"]]["cluster_score"] = 1.0

        return sorted(
            combined.items(),
            key=lambda item: (
                item[1]["similarity"]
                + item[1].get("activation_score", 0)
                + item[1].get("cluster_score", 0)
            ),
            reverse=True,
        )

Usage

# Start a local Neo4j with the APOC plugin enabled (required by the apoc.*
# procedure settings below); replace [password] with a real password.
docker run \
    --restart always \
    --env NEO4J_AUTH=neo4j/[password] \
    --publish=7474:7474 \
    --publish=7687:7687 \
    --env NEO4J_PLUGINS='["apoc"]' \
    --env NEO4J_dbms_security_procedures_unrestricted=apoc.*,algo.* \
    --env NEO4J_dbms_security_procedures_allowlist=apoc.*,algo.* \
    neo4j:latest
# Connect to the Neo4j instance started above; the context manager closes
# the driver on exit.
with MemoryStore("bolt://localhost:7687", "neo4j", "[password]") as store:
    # Ids are strings everywhere else in this file (Interaction declares
    # id: str and the tests use "1"); previously an int 1 was passed here.
    # timestamp, access_count and decay_factor fall back to the defaults
    # applied inside add_interaction.
    interaction = {
        "id": "1",
        "prompt": "Explain the theory of relativity",
        "output": "The theory of relativity is...",
        "embedding": [0.1, 0.2, 0.3],
        "concepts": ["relativity", "physics"],
    }

    store.add_interaction(interaction)

    # Promote frequently accessed interactions to long-term memory, then
    # cluster all interactions by embedding similarity.
    store.classify_memory()
    store.cluster_interactions()

    # Query the store with the same embedding/concepts used at insert time.
    query_embedding = [0.1, 0.2, 0.3]
    query_concepts = ["relativity", "physics"]
    results = store.retrieve(
        query_embedding=query_embedding,
        query_concepts=query_concepts,
        similarity_threshold=0.8,
        access_multiplier=1.1,
        decay_multiplier=0.9,
    )
    print(results)

Testing

import pytest
from neo4j import GraphDatabase
# NOTE(review): the implementation above defines `MemoryStore`, not
# `MemoryManager` — this import will fail as written; align the module and
# class names before running the suite.
from memory_manager import MemoryManager  # Assuming this is the name of your module

# Database URI and credentials (use test credentials)
NEO4J_URI = "neo4j://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "password"

@pytest.fixture(scope="module")
def driver():
    """Yield a module-scoped Neo4j driver, closed again at teardown."""
    db_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    yield db_driver
    db_driver.close()

@pytest.fixture(scope="function")
def clear_database(driver):
    """Wipe every node and relationship so each test starts empty."""
    def _wipe(tx):
        tx.run("MATCH (n) DETACH DELETE n")

    with driver.session() as session:
        session.execute_write(_wipe)

# Test cases
# NOTE(review): `add_interaction` on the implementation above is an *instance*
# method that opens its own session — the transaction function to pass to
# execute_write is `_add_interaction_tx`; also the class is `MemoryStore`,
# not `MemoryManager`.
def test_add_interaction(driver, clear_database):
    """Test adding an interaction to the memory."""
    interaction = {
        "id": "1",
        "prompt": "Test prompt",
        "output": "Test output",
        "embedding": [0.1, 0.2, 0.3],
        "timestamp": 1234567890,
        "access_count": 1,
        "decay_factor": 1.0,
        # NOTE(review): the neo4j driver cannot serialize Python sets as
        # query parameters — this likely needs list(...) coercion; verify.
        "concepts": {"concept1", "concept2"}
    }

    with driver.session() as session:
        session.execute_write(MemoryManager.add_interaction, interaction)

        # Verify interaction was added
        result = session.run("MATCH (i:Interaction) RETURN count(i) AS count").single()
        assert result["count"] == 1

        # Verify concepts were added
        result = session.run("MATCH (c:Concept) RETURN count(c) AS count").single()
        assert result["count"] == 2

        # Verify relationships between concepts
        # NOTE(review): the creation Cypher MERGEs a directed edge for each
        # *ordered* concept pair, and an undirected pattern matches each edge
        # from both ends — expect this count to be higher than 1; verify.
        result = session.run("MATCH (:Concept)-[r:ASSOCIATED_WITH]-(:Concept) RETURN count(r) AS count").single()
        assert result["count"] == 1  # Only one relationship between the two concepts


# NOTE(review): no `retrieve_interactions` exists on the implementation above —
# retrieval is `MemoryStore.retrieve`, an instance method that manages its own
# sessions (the read tx functions are `_retrieve_by_similarity_tx` etc.);
# this call signature needs rework.
def test_retrieve_interactions(driver, clear_database):
    """Test retrieving interactions based on semantic similarity and spreading activation."""
    interactions = [
        {
            "id": "1",
            "prompt": "Prompt 1",
            "output": "Output 1",
            "embedding": [0.1, 0.2, 0.3],
            "timestamp": 1234567890,
            "access_count": 1,
            "decay_factor": 1.0,
            "concepts": {"concept1", "concept2"}
        },
        {
            "id": "2",
            "prompt": "Prompt 2",
            "output": "Output 2",
            "embedding": [0.4, 0.5, 0.6],
            "timestamp": 1234567900,
            "access_count": 2,
            "decay_factor": 1.0,
            "concepts": {"concept2", "concept3"}
        }
    ]

    with driver.session() as session:
        # Add interactions
        # NOTE(review): add_interaction is not a transaction function — see
        # the note on test_add_interaction.
        for interaction in interactions:
            session.execute_write(MemoryManager.add_interaction, interaction)

        # Query embeddings similar to interaction "1"
        query_embedding = [0.1, 0.2, 0.3]
        query_concepts = {"concept2"}
        # NOTE(review): cosine similarity lies in [-1, 1]; a threshold of 50
        # would filter out every hit — presumably 0.5 was intended.
        retrieved = session.execute_read(MemoryManager.retrieve_interactions, query_embedding, query_concepts, similarity_threshold=50)

        assert len(retrieved) > 0
        assert retrieved[0]["id"] == "1"  # Most similar interaction

# NOTE(review): `cluster_interactions` on the implementation above is an
# instance method whose parameter is named `min_num_clusters` (the tx
# function is `_cluster_interactions_tx`); this call needs to be adapted.
def test_cluster_interactions(driver, clear_database):
    """Test clustering interactions and assigning cluster labels."""
    interactions = [
        {
            "id": "1",
            "prompt": "Prompt 1",
            "output": "Output 1",
            "embedding": [0.1, 0.2, 0.3],
            "timestamp": 1234567890,
            "access_count": 1,
            "decay_factor": 1.0,
            "concepts": {"concept1"}
        },
        {
            "id": "2",
            "prompt": "Prompt 2",
            "output": "Output 2",
            "embedding": [0.4, 0.5, 0.6],
            "timestamp": 1234567900,
            "access_count": 2,
            "decay_factor": 1.0,
            "concepts": {"concept2"}
        }
    ]

    with driver.session() as session:
        # Add interactions
        for interaction in interactions:
            session.execute_write(MemoryManager.add_interaction, interaction)

        # Cluster interactions
        session.execute_write(MemoryManager.cluster_interactions, num_clusters=2)

        # Verify cluster labels
        result = session.run("MATCH (i:Interaction) RETURN i.cluster AS cluster")
        clusters = {record["cluster"] for record in result}
        assert len(clusters) == 2  # Two clusters should exist

# NOTE(review): the retrieval queries in the implementation above are
# read-only — nothing ever SETs i.access_count or i.decay_factor — so the
# reinforcement assertions at the bottom cannot pass as written; either the
# store needs a write-back step or these assertions need to change.
def test_decay_and_reinforcement(driver, clear_database):
    """Test decay factor and access count reinforcement during retrieval."""
    interaction = {
        "id": "1",
        "prompt": "Decay test",
        "output": "Decay output",
        "embedding": [0.1, 0.2, 0.3],
        "timestamp": 1234567890,
        "access_count": 1,
        "decay_factor": 1.0,
        "concepts": {"concept1"}
    }

    with driver.session() as session:
        # Add interaction
        session.execute_write(MemoryManager.add_interaction, interaction)

        # Query for retrieval to trigger decay and reinforcement
        query_embedding = [0.1, 0.2, 0.3]
        query_concepts = {"concept1"}
        # NOTE(review): `retrieve_interactions` does not exist and the
        # threshold of 50 is outside the cosine range — see the notes on
        # test_retrieve_interactions.
        retrieved = session.execute_read(MemoryManager.retrieve_interactions, query_embedding, query_concepts, similarity_threshold=50)

        # Verify decay factor and access count update
        result = session.run("MATCH (i:Interaction {id: '1'}) RETURN i.access_count AS access_count, i.decay_factor AS decay_factor").single()
        assert result["access_count"] > 1
        assert result["decay_factor"] > 1.0  # Reinforced decay factor

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions