-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
Some ideas borrowed from memoripy, re-implemented on top of a Neo4j graph database:
import time
import typing
import neo4j
import numpy
from numpy.typing import NDArray
from sklearn.cluster import KMeans
class _InteractionRequired(typing.TypedDict):
    """Keys every interaction must provide when stored."""
    id: str  # unique identifier for the interaction
    prompt: str  # prompt or input associated with the interaction
    output: str  # output or response of the interaction
    embedding: list[float]  # embedding vector representing the interaction


class Interaction(_InteractionRequired, total=False):
    """An interaction record persisted by MemoryStore.

    The keys below are optional: MemoryStore._add_interaction_tx fills
    them in with defaults via dict.get(), so callers may omit them.
    """
    timestamp: int  # epoch seconds when the interaction occurred (default: now)
    access_count: int  # number of times the interaction has been accessed (default: 1)
    decay_factor: float  # decay factor for relevance over time (default: 1.0)
    concepts: set[str]  # concepts associated with the interaction (default: none)
# Row shape returned by MemoryStore._retrieve_by_similarity_tx: an
# interaction id, its prompt, and the cosine similarity to the query.
SimilarityRecord = typing.TypedDict(
    "SimilarityRecord",
    {"id": str, "prompt": str, "similarity": float},
)
# Row shape returned by MemoryStore._apply_decay_and_reinforcement_tx:
# the similarity score re-weighted by access count and age decay.
ReinforcementRecord = typing.TypedDict(
    "ReinforcementRecord",
    {"id": str, "prompt": str, "adjusted_relevance": float},
)
# Row shape returned by MemoryStore._retrieve_from_cluster_tx: members
# of the cluster closest to the query embedding.
ClusteringRecord = typing.TypedDict(
    "ClusteringRecord",
    {"id": str, "prompt": str},
)
class _RecordBase(typing.TypedDict):
    """Scoring fields present on every combined retrieval result."""
    similarity: float  # relevance carried over from the reinforcement pass
    activation_score: float  # spreading-activation contribution (0 when none)


class Record(_RecordBase, total=False):
    """Combined score for one retrieved interaction.

    ``cluster_score`` is optional: MemoryStore._combine_results only sets
    it when the interaction also appears in the cluster-based results
    (and reads it back with .get() fallbacks).
    """
    cluster_score: float  # flat 1.0 bonus for membership in the best cluster
class MemoryStore:
    """Neo4j-backed interaction memory.

    Supports storing interactions with concept co-occurrence edges,
    short-term -> long-term promotion, k-means clustering of embeddings,
    and retrieval combining cosine similarity, spreading activation, and
    decay/reinforcement scoring.

    NOTE(review): the similarity queries call GDS functions; the sample
    docker command in this issue only installs the APOC plugin — confirm
    GDS is actually available on the target server.
    """

    def __init__(self, endpoint: str, username: str, password: str):
        self.driver = neo4j.GraphDatabase.driver(endpoint, auth=(username, password))

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Always release the driver's connection pool, even when the
        # with-block exits via an exception.
        self.close()

    def close(self):
        self.driver.close()

    def add_interaction(self, interaction: Interaction):
        """Persist one interaction as a ShortTermMemory node and link its concepts."""
        with self.driver.session() as session:
            # execute_write replaces the deprecated write_transaction
            # (consistent with the driver 5.x API used by the test suite).
            session.execute_write(self._add_interaction_tx, interaction)

    @staticmethod
    def _add_interaction_tx(tx: neo4j.Transaction, interaction: Interaction):
        """Transaction function: create the Interaction node, MERGE its
        concepts, and bump pairwise ASSOCIATED_WITH weights between every
        pair of co-occurring concepts.

        Note: the pairwise MERGE creates one relationship per direction
        (a->b and b->a), so weights are symmetric by construction.
        """
        query = """
        CREATE (i:Interaction:ShortTermMemory {id: $id, prompt: $prompt, output: $output,
                embedding: $embedding, timestamp: $timestamp,
                access_count: $access_count, decay_factor: $decay_factor})
        WITH i
        UNWIND $concepts AS concept
        MERGE (c:Concept {name: concept})
        WITH i, c
        UNWIND $concepts AS other_concept
        WITH i, c, other_concept
        MATCH (c1:Concept {name: other_concept})
        WHERE c <> c1
        MERGE (c)-[r:ASSOCIATED_WITH]->(c1)
        ON CREATE SET r.weight = 1
        ON MATCH SET r.weight = r.weight + 1
        RETURN i
        """
        tx.run(query, {
            "id": interaction["id"],
            "prompt": interaction["prompt"],
            "output": interaction["output"],
            "embedding": interaction["embedding"],
            # Interaction declares timestamp as int; time.time() is float,
            # so coerce the default to keep the stored type consistent.
            "timestamp": int(interaction.get("timestamp", time.time())),
            "access_count": interaction.get("access_count", 1),
            "decay_factor": interaction.get("decay_factor", 1.0),
            # concepts is declared as a set; Bolt has no set type, so
            # convert to a list before sending it as a parameter.
            "concepts": list(interaction.get("concepts", [])),
        })

    def classify_memory(self):
        """Promote frequently accessed short-term memories to long-term."""
        with self.driver.session() as session:
            session.execute_write(self._classify_memory_tx)

    @staticmethod
    def _classify_memory_tx(tx: neo4j.Transaction):
        # Any interaction accessed more than 10 times graduates from
        # ShortTermMemory to LongTermMemory.
        query = """
        MATCH (i:Interaction:ShortTermMemory)
        WHERE i.access_count > 10
        REMOVE i:ShortTermMemory
        SET i:LongTermMemory
        """
        tx.run(query)

    def cluster_interactions(self, min_num_clusters: int = 10):
        """Cluster all interactions by embedding and label each node with its cluster."""
        with self.driver.session() as session:
            session.execute_write(self._cluster_interactions_tx, min_num_clusters)

    @staticmethod
    def _cluster_interactions_tx(tx: neo4j.Transaction, num_clusters: int = 10):
        """
        Cluster interactions based on embeddings and assign cluster labels to nodes.

        Args:
            tx (Transaction): The Neo4j transaction object.
            num_clusters (int): Desired number of clusters; capped at the
                number of interactions that actually have an embedding.
        """
        print("Retrieving interactions and clustering...")
        # Retrieve every interaction node together with its embedding.
        query = """
        MATCH (i:Interaction)
        RETURN id(i) AS node_id, i.embedding AS embedding
        """
        result = tx.run(query)

        node_ids = []
        embeddings = []
        for record in result:
            if record["embedding"]:  # skip nodes with no/empty embedding
                embeddings.append(numpy.array(record["embedding"]))
                node_ids.append(record["node_id"])

        if not embeddings:
            # Nothing to cluster; vstack/KMeans would raise on empty input.
            print("No embeddings found; skipping clustering.")
            return

        # BUG FIX: the original discarded the num_clusters argument and
        # always used min(10, ...); honor the caller's value instead,
        # capped by the number of available samples.
        num_clusters = min(num_clusters, len(embeddings))

        embeddings_matrix = numpy.vstack(embeddings)
        kmeans = KMeans(n_clusters=num_clusters, random_state=0)
        cluster_labels = kmeans.fit_predict(embeddings_matrix)

        print(f"Assigning cluster labels to {len(node_ids)} interactions...")
        for node_id, cluster_label in zip(node_ids, cluster_labels):
            tx.run(
                """
                MATCH (i:Interaction)
                WHERE id(i) = $node_id
                SET i.cluster = $cluster_label
                """,
                node_id=node_id,
                cluster_label=int(cluster_label),  # numpy int -> plain int for Bolt
            )
        print("Clustering completed and labels assigned to interaction nodes.")

    def retrieve(
        self,
        query_embedding: NDArray[numpy.floating],
        query_concepts: list[str],
        similarity_threshold: float = 0.8,
        access_multiplier: float = 1.1,
        decay_multiplier: float = 0.9
    ) -> list[tuple[str, Record]]:
        """Retrieve interactions relevant to a query embedding and concepts.

        Combines (1) cosine-similarity hits re-ranked by access count and
        age decay, (2) spreading activation over the concept graph, and
        (3) the members of the most similar cluster.

        Returns:
            (interaction_id, Record) pairs sorted by combined score, best first.
        """
        with self.driver.session() as session:
            initial_results = session.execute_read(
                self._retrieve_by_similarity_tx, query_embedding, similarity_threshold
            )
            activated_concepts = session.execute_read(self._spreading_activation_tx, query_concepts)
            cluster_results = session.execute_read(self._retrieve_from_cluster_tx, query_embedding)
            enhanced_results = session.execute_read(
                self._apply_decay_and_reinforcement_tx,
                initial_results,
                access_multiplier,
                decay_multiplier
            )
            return self._combine_results(enhanced_results, activated_concepts, cluster_results)

    @staticmethod
    def _apply_decay_and_reinforcement_tx(
        tx: neo4j.Transaction,
        initial_results: list[SimilarityRecord],
        access_multiplier: float,
        decay_multiplier: float
    ) -> list[ReinforcementRecord]:
        """Re-rank similarity hits: boost by access count, decay by age in seconds."""
        # Cypher's native ^ operator replaces apoc.math.pow, removing the
        # APOC plugin dependency for this query.
        query = """
        UNWIND $results AS result
        MATCH (i:Interaction {id: result.id})
        WITH i, result.similarity AS similarity
        CALL {
            WITH i
            RETURN duration.between(datetime({epochSeconds: i.timestamp}), datetime()).seconds AS age
        }
        WITH i, similarity, age,
             ($access_multiplier ^ i.access_count) AS reinforced_access,
             ($decay_multiplier ^ age) AS decay_factor
        RETURN i.id AS id,
               i.prompt AS prompt,
               similarity * reinforced_access * decay_factor AS adjusted_relevance
        ORDER BY adjusted_relevance DESC
        LIMIT 10
        """
        results = tx.run(query, {
            "results": initial_results,
            "access_multiplier": access_multiplier,
            "decay_multiplier": decay_multiplier,
        })
        return [
            {"id": record["id"], "prompt": record["prompt"], "adjusted_relevance": record["adjusted_relevance"]}
            for record in results
        ]

    @staticmethod
    def _retrieve_by_similarity_tx(
        tx: neo4j.Transaction, query_embedding: NDArray[numpy.floating], similarity_threshold: float
    ) -> list[SimilarityRecord]:
        """Top-10 interactions whose embedding cosine-similarity clears the threshold."""
        # NOTE(review): gds.alpha.similarity.cosine was removed in GDS 2.x
        # in favor of gds.similarity.cosine — confirm the deployed GDS version.
        query = """
        MATCH (i:Interaction)
        WITH i, gds.alpha.similarity.cosine(i.embedding, $query_embedding) AS similarity
        WHERE similarity >= $similarity_threshold
        RETURN i.id AS id, i.prompt AS prompt, similarity
        ORDER BY similarity DESC
        LIMIT 10
        """
        results = tx.run(query, {
            # A numpy array is not Bolt-serializable; send a plain float list.
            "query_embedding": numpy.asarray(query_embedding, dtype=float).tolist(),
            "similarity_threshold": similarity_threshold,
        })
        return [{"id": record["id"], "prompt": record["prompt"], "similarity": record["similarity"]} for record in results]

    @staticmethod
    def _spreading_activation_tx(tx: neo4j.Transaction, query_concepts: list[str]) -> dict[str, float]:
        """Spread activation 1-2 hops from the query concepts; score = summed edge weights."""
        # BUG FIX: with a variable-length pattern, r binds a *list* of
        # relationships, so r.weight is invalid Cypher; sum the weights
        # along each path with reduce() instead.
        query = """
        UNWIND $query_concepts AS concept
        MATCH (c:Concept {name: concept})-[r:ASSOCIATED_WITH*1..2]-(neighbor:Concept)
        WITH neighbor, SUM(reduce(total = 0, rel IN r | total + rel.weight)) AS activation_score
        RETURN neighbor.name AS concept, activation_score
        ORDER BY activation_score DESC
        """
        results = tx.run(query, {"query_concepts": query_concepts})
        return {record["concept"]: record["activation_score"] for record in results}

    @staticmethod
    def _retrieve_from_cluster_tx(
        tx: neo4j.Transaction, query_embedding: NDArray[numpy.floating]
    ) -> list[ClusteringRecord]:
        """Return every interaction belonging to the cluster most similar to the query.

        BUG FIX: the original called gds.util.asNode on a list of
        embeddings, which is not valid Cypher. Rank clusters by the
        average cosine similarity of their members to the query instead,
        then return that best cluster's members.
        """
        query = """
        MATCH (i:Interaction)
        WHERE i.cluster IS NOT NULL
        WITH i.cluster AS cluster,
             avg(gds.alpha.similarity.cosine(i.embedding, $query_embedding)) AS similarity
        ORDER BY similarity DESC
        LIMIT 1
        MATCH (i:Interaction {cluster: cluster})
        RETURN i.id AS id, i.prompt AS prompt
        """
        results = tx.run(query, {
            "query_embedding": numpy.asarray(query_embedding, dtype=float).tolist(),
        })
        return [{"id": record["id"], "prompt": record["prompt"]} for record in results]

    def _combine_results(
        self,
        reinforcement_results: list[ReinforcementRecord],
        activated_concepts: dict[str, float],
        cluster_results: list[ClusteringRecord]
    ) -> list[tuple[str, Record]]:
        """Merge the three retrieval signals and sort by combined score, best first."""
        combined_results: dict[str, Record] = {}
        for result in reinforcement_results:
            # BUG FIX: reinforcement rows carry "adjusted_relevance", not
            # "similarity" — the original raised KeyError here.
            combined_results[result["id"]] = {
                "similarity": result["adjusted_relevance"],
                "activation_score": 0,
            }
        # NOTE(review): activated_concepts is keyed by *concept name* while
        # combined_results is keyed by interaction id, so this lookup will
        # essentially never match; preserved as-is pending a mapping from
        # concepts back to the interactions that mention them.
        for concept, score in activated_concepts.items():
            if concept in combined_results:
                combined_results[concept]["activation_score"] += score
        # Flat bonus for appearing in the best-matching cluster.
        for result in cluster_results:
            if result["id"] in combined_results:
                combined_results[result["id"]]["cluster_score"] = 1.0
        return sorted(
            combined_results.items(),
            key=lambda item: (
                item[1]["similarity"]
                + item[1].get("activation_score", 0)
                + item[1].get("cluster_score", 0)
            ),
            reverse=True,
        )
docker run \
--restart always \
--env NEO4J_AUTH=neo4j/[password] \
--publish=7474:7474 \
--publish=7687:7687 \
--env NEO4J_PLUGINS='["apoc"]' \
--env NEO4J_dbms_security_procedures_unrestricted=apoc.*,algo.* \
--env NEO4J_dbms_security_procedures_allowlist=apoc.*,algo.* \
    neo4j:latest

Usage:

with MemoryStore("bolt://localhost:7687", "neo4j", "[password]") as store:
    interaction = {
        "id": 1,
        "prompt": "Explain the theory of relativity",
        "output": "The theory of relativity is...",
        "embedding": [0.1, 0.2, 0.3],
        "concepts": ["relativity", "physics"]
    }
    store.add_interaction(interaction)
    store.classify_memory()
    store.cluster_interactions()

    query_embedding = [0.1, 0.2, 0.3]
    query_concepts = ["relativity", "physics"]
    results = store.retrieve(
        query_embedding=query_embedding,
        query_concepts=query_concepts,
        similarity_threshold=0.8,
        access_multiplier=1.1,
        decay_multiplier=0.9
    )
    print(results)

Testing:
import pytest
from neo4j import GraphDatabase
from memory_manager import MemoryManager # Assuming this is the name of your module
# Database URI and credentials (use test credentials)
NEO4J_URI = "neo4j://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "password"
@pytest.fixture(scope="module")
def driver():
    """Provide one Neo4j driver for the whole test module, closed at teardown."""
    connection = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    yield connection
    connection.close()
@pytest.fixture(scope="function")
def clear_database(driver):
    """Wipe every node and relationship before each test runs."""
    def wipe(tx):
        tx.run("MATCH (n) DETACH DELETE n")

    with driver.session() as session:
        session.execute_write(wipe)
# Test cases
def test_add_interaction(driver, clear_database):
    """Adding one interaction creates the node, its concepts, and their link."""
    interaction = {
        "id": "1",
        "prompt": "Test prompt",
        "output": "Test output",
        "embedding": [0.1, 0.2, 0.3],
        "timestamp": 1234567890,
        "access_count": 1,
        "decay_factor": 1.0,
        "concepts": {"concept1", "concept2"}
    }
    with driver.session() as session:
        session.execute_write(MemoryManager.add_interaction, interaction)

        # Exactly one Interaction node should exist.
        node_count = session.run("MATCH (i:Interaction) RETURN count(i) AS count").single()["count"]
        assert node_count == 1

        # Both concepts should have been MERGEd in.
        concept_count = session.run("MATCH (c:Concept) RETURN count(c) AS count").single()["count"]
        assert concept_count == 2

        # NOTE(review): an undirected MATCH counts each stored relationship
        # once per direction, and a pairwise MERGE may create one per
        # direction — confirm the expected count of 1 against the actual
        # MemoryManager query.
        rel_count = session.run("MATCH (:Concept)-[r:ASSOCIATED_WITH]-(:Concept) RETURN count(r) AS count").single()["count"]
        assert rel_count == 1  # Only one relationship between the two concepts
def test_retrieve_interactions(driver, clear_database):
    """Retrieval ranks the semantically closest interaction first."""
    interactions = [
        {
            "id": "1",
            "prompt": "Prompt 1",
            "output": "Output 1",
            "embedding": [0.1, 0.2, 0.3],
            "timestamp": 1234567890,
            "access_count": 1,
            "decay_factor": 1.0,
            "concepts": {"concept1", "concept2"}
        },
        {
            "id": "2",
            "prompt": "Prompt 2",
            "output": "Output 2",
            "embedding": [0.4, 0.5, 0.6],
            "timestamp": 1234567900,
            "access_count": 2,
            "decay_factor": 1.0,
            "concepts": {"concept2", "concept3"}
        }
    ]
    with driver.session() as session:
        for interaction in interactions:
            session.execute_write(MemoryManager.add_interaction, interaction)

        # Query with an embedding identical to interaction "1".
        query_embedding = [0.1, 0.2, 0.3]
        query_concepts = {"concept2"}
        # NOTE(review): similarity_threshold=50 looks out of range for cosine
        # similarity (normally within [-1, 1]) — confirm the units
        # MemoryManager.retrieve_interactions expects.
        retrieved = session.execute_read(
            MemoryManager.retrieve_interactions, query_embedding, query_concepts, similarity_threshold=50
        )

        assert len(retrieved) > 0
        assert retrieved[0]["id"] == "1"  # the most similar interaction comes first
def test_cluster_interactions(driver, clear_database):
    """Clustering assigns a distinct cluster label to each embedding group."""
    interactions = [
        {
            "id": "1",
            "prompt": "Prompt 1",
            "output": "Output 1",
            "embedding": [0.1, 0.2, 0.3],
            "timestamp": 1234567890,
            "access_count": 1,
            "decay_factor": 1.0,
            "concepts": {"concept1"}
        },
        {
            "id": "2",
            "prompt": "Prompt 2",
            "output": "Output 2",
            "embedding": [0.4, 0.5, 0.6],
            "timestamp": 1234567900,
            "access_count": 2,
            "decay_factor": 1.0,
            "concepts": {"concept2"}
        }
    ]
    with driver.session() as session:
        for interaction in interactions:
            session.execute_write(MemoryManager.add_interaction, interaction)

        # Run clustering with as many clusters as there are interactions.
        session.execute_write(MemoryManager.cluster_interactions, num_clusters=2)

        # Every interaction should carry a cluster label; with 2 distinct
        # embeddings and num_clusters=2 we expect 2 distinct labels.
        rows = session.run("MATCH (i:Interaction) RETURN i.cluster AS cluster")
        labels = {row["cluster"] for row in rows}
        assert len(labels) == 2  # Two clusters should exist
def test_decay_and_reinforcement(driver, clear_database):
    """Retrieval should reinforce access_count and update decay_factor."""
    interaction = {
        "id": "1",
        "prompt": "Decay test",
        "output": "Decay output",
        "embedding": [0.1, 0.2, 0.3],
        "timestamp": 1234567890,
        "access_count": 1,
        "decay_factor": 1.0,
        "concepts": {"concept1"}
    }
    with driver.session() as session:
        session.execute_write(MemoryManager.add_interaction, interaction)

        # Retrieval is expected to trigger decay and reinforcement updates.
        query_embedding = [0.1, 0.2, 0.3]
        query_concepts = {"concept1"}
        retrieved = session.execute_read(
            MemoryManager.retrieve_interactions, query_embedding, query_concepts, similarity_threshold=50
        )

        # NOTE(review): this assumes retrieval writes reinforcement back to
        # the node; the MemoryStore implementation shown in this issue only
        # *reads* during retrieval — confirm MemoryManager actually persists
        # access_count/decay_factor updates.
        stats = session.run(
            "MATCH (i:Interaction {id: '1'}) RETURN i.access_count AS access_count, i.decay_factor AS decay_factor"
        ).single()
        assert stats["access_count"] > 1
        assert stats["decay_factor"] > 1.0  # reinforced decay factor
Metadata
Metadata
Assignees
Labels
No labels