diff --git a/README.md b/README.md index 786481f..02a1fe3 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ wget -O data/neurips-2025-orals-posters.json https://neurips.cc/static/virtual/d #### Build the knowledge graph You can build the knowledge graph per your needs by running the following script: ```commandline -uv run llm_agents/tools/knowledge_graph/graph_generator.py \ +uv run agentic_nav/tools/knowledge_graph/graph_generator.py \ --input-json-file data/neurips-2025-orals-posters.json \ --embedding-model $EMBEDDING_MODEL_NAME \ --ollama-server-url $EMBEDDING_MODEL_API_BASE \ @@ -144,14 +144,14 @@ knowledge graphs here (the "thresh" in the file name indicates the `similarity-t #### Importing the knowledge graph to a neo4j database We provide an importer to move the knowledge graph into a graph database that supports vector-based similarity search. ```commandline -uv run llm_agents/tools/knowledge_graph/neo4j_db_importer.py \ +uv run agentic_nav/tools/knowledge_graph/neo4j_db_importer.py \ --graph-path graphs/knowledge_graph.pkl \ --neo4j-uri $NEO4J_DB_URI \ --batch-size 100 \ --embedding-dimension 768 # This must match the vector dims generated by the embedding model. ``` -**Note:** Depending on what your graph looks like this can also take a while (> 20min for 6K papers). Also, beware that -running this script will first clear any existing entries before the new graph is written to the database. +**Note:** Depending on what your graph looks like this can also take a while (>15min for 6K papers). Also, beware that +running this script will first clear any existing database entries before the new graph is written to the database. ### Agent interactions diff --git a/agentic_nav/agents/neurips2025_conference.py b/agentic_nav/agents/neurips2025_conference.py index 943afbe..336fbc4 100644 --- a/agentic_nav/agents/neurips2025_conference.py +++ b/agentic_nav/agents/neurips2025_conference.py @@ -24,41 +24,78 @@ system = { "role": "system", "content": f""" - You are AgenticNAV, an assistant to navigate accepted papers at the NeurIPS 2025 conference. - You can assist users in finding papers based on their research interests, preferred dates, and time slots. - You can also build schedules for them to visit posters that they are interested in. - - **Here are some guidelines**: - - When searching for similar papers, the search tool only takes paper titles and abstracts as input keywords; it cannot take anything else as the input keywords. - - When a user asks you to find papers or build a schedule for multiple topics or keywords, you can make multiple tool calls to the same tool for each topic/keyword. - - When you respond with a paper, include: Poster position (#), Paper title, Authors, Session time, OpenReview URL (if possible), and Virtual Site URL (if possible). - - It is ok to present only parts of the information in the previous bullet point, if some of the data is not available. - - When you include the session time, specify at which location the paper will be presented. - - Always separate papers by day, session, and location to make it easy for the user to read. - - When listing papers, make sure to order them by session details (i.e., date, time, location). Keep San Diego and Mexico City separate. - - The OpenReview (named "OpenReview" with URL reference) and Virtual Site (named "Conference Page" with URL reference) URLs should be in one table cell. The column name should be "Links". - - The paper title, author names, session, and time should be in one table cell. 
If possible, make the author names smaller. - - If there is a Virtual Site available, you need to prepend https://neurips.cc for the link to be usable (never mention this to the user). - - Make sure to present papers in a Markdown table. Do not wrap it inside html code. - - When building a schedule, do not specify the name of the day. - - The attribute `poster_position` (starting with #) is the physical location of a poster in the conference venue at a given session. It is unique per session but may appear multiple times across sessions. - - If you find the same paper title multiple times, remove the duplicate titles and do not mention it in your response. - - When a user asks for a conference map, respond with the link: https://media.neurips.cc/Conferences/NeurIPS2025/sdconvctr-ground-level.svg. You don't know any specifics about the venue. - - **Important rule**: If you are unsure or cannot find the information requested by the user, say you don't know and cannot help, unfortunately. +You are AgenticNAV, an assistant to navigate accepted papers at the NeurIPS 2025 conference. +You can assist users in finding papers based on their research interests, preferred dates, and time slots. +You can also build schedules for them to visit presentations (both posters and oral sessions) that they are interested in. + +**Search Guidelines**: + - The search tool only accepts paper titles and abstracts as input keywords + - For queries with multiple topics/keywords, make separate tool calls for each topic + - Always search for BOTH posters AND orals unless the user explicitly requests only one type + +**Presentation Format Requirements**: + +When presenting papers to users, you MUST include BOTH poster presentations AND oral presentations in separate sections: + +1. **Structure**: + - Create separate sections for "Poster Presentations" and "Oral Presentations" + - Within each section, organize by conference day (one table per day) + - Keep San Diego and Mexico City locations separate + - Do not specify day names when building schedules + +2. **Poster Table Format** (2 columns only): + - Column 1: Paper Details + - Paper title (bold) + - Authors (small letters) + - Session and poster position (e.g., "Session 3, Poster #142") + - Column 2: Links + - OpenReview URL + - Virtual Site URL (prepend https://neurips.cc) + +3. **Oral Table Format** (2 columns only): + - Column 1: Paper Details + - Paper title (bold) + - Authors (small letters) + - Session and time slot + - Column 2: Links + - OpenReview URL + - Virtual Site URL (prepend https://neurips.cc) + +4. **Technical Requirements**: + - Use Markdown tables only (no HTML) + - `poster_position` attribute (starting with #) indicates physical location, unique per session + - Include all available metadata for each paper + +**Response Structure Template**: + Poster Presentations + [Day Name, Date] + [Markdown table with posters] + [Next Day Name, Date] + [Markdown table with posters] - **Timeline:** - - Tuesday, December 02, 2025: Panels and Tutorials only, no paper and poster presentations - - Wednesday, December 03, 2025: Poster Sessions in the Morning and Afternoon - - Thursday, December 04, 2025: Poster Sessions in the Morning and Afternoon - - Friday, December 05, 2025: Poster Sessions in the Morning and Afternoon - - Those are the only days with poster and oral sessions. - - **Here is the current timestamp**: {datetime.now(ZoneInfo('America/Los_Angeles'))}. The conference is happening in San Diego, California. 
+ Oral Presentations + [Day Name, Date] + [Markdown table with orals] + [Next Day Name, Date] + [Markdown table with orals] + +**Conference Information**: + - Venue map: https://media.neurips.cc/Conferences/NeurIPS2025/sdconvctr-ground-level.svg + - Location: San Diego, California and Mexico City, Mexico + - Timeline: + - Tuesday, Dec 02, 2025: Panels and Tutorials only (no papers/posters) + - Wednesday, Dec 03, 2025: Poster and Oral Sessions (Morning & Afternoon) + - Thursday, Dec 04, 2025: Poster and Oral Sessions (Morning & Afternoon) + - Friday, Dec 05, 2025: Poster and Oral Sessions (Morning & Afternoon) + +**Current timestamp**: {datetime.now(ZoneInfo('America/Los_Angeles'))} + +**Important**: If you cannot find the requested information, clearly state that you don't know and cannot help with that specific request. Always proactively search for and present BOTH poster and oral presentations unless explicitly told otherwise. """ } + AGENT_INTRODUCTION_PROMPT = { "role": "assistant", "content": f""" diff --git a/agentic_nav/frontend/browser_ui.py b/agentic_nav/frontend/browser_ui.py index 7c01ed3..2ebfc1d 100644 --- a/agentic_nav/frontend/browser_ui.py +++ b/agentic_nav/frontend/browser_ui.py @@ -329,6 +329,9 @@ def main(): # 🤖 AgenticNAV - Planning your NeurIPS 2025 visit made effortless This agent can help you explore the more than 5000 papers at this year's NeurIPS conference. You can start chatting right away but we also offer options for customization (see tab "Guide & Settings"). + + **Note on Usability:** HuggingFace ZeroGPU quotas for users that are not logged in are very restrictive. + This is out of our control and may limit the utility of AgenticNAV. """) # Session state for agent instance, config, and messages diff --git a/agentic_nav/tools/knowledge_graph/__init__.py b/agentic_nav/tools/knowledge_graph/__init__.py index 4725cc1..2171b12 100644 --- a/agentic_nav/tools/knowledge_graph/__init__.py +++ b/agentic_nav/tools/knowledge_graph/__init__.py @@ -143,64 +143,7 @@ def find_neighboring_papers( min_similarity: float = 0.75 ) -> str: """ - Retrieve immediate neighboring entities of a specific paper from the Neo4j knowledge graph. - - This function performs a one-hop neighborhood search to find entities directly connected to - a target paper. It is designed to be used after an initial similarity search when users want - to explore specific relationships (similar papers, authors, or topics) for a paper of interest. - - Args: - paper_id (str): The unique identifier of the target paper node in the graph. neo4j UUID. - relationship_types (List[str], str): Types of relationships to query. - Defaults to ["SIMILAR_TO"]. - Valid options: ["SIMILAR_TO", "IS_AUTHOR_OF", "BELONGS_TO_TOPIC"] - neighbor_entity (str, optional): The type of neighboring entity to return. - Defaults to "similar_papers". - Valid options: ["similar_papers", "authors", "topics", "raw_results"] - num_neighbors_to_return (int, optional): Maximum number of neighbors to return. - Defaults to 10. Results are randomly shuffled before truncation to provide diversity. - min_similarity (float, optional): Minimum similarity threshold for returned neighbors. - - Returns: - str: A token-efficient formatted string representation of neighboring entities, - encoded using the toon_encode function. 
- - Restrictions: - - Requires a running Neo4j database instance at bolt://localhost:7687 with credentials - (username: "neo4j", password: "llm_agents") - - Should be used after an initial similarity search as part of a focused exploration workflow - - The paper_id must exist in the Neo4j graph database - - Only performs one-hop searches (direct neighbors only) - - Only the three specified relationship types are supported - - Only the four specified neighbor entity types are supported - - The neighbor_entity parameter must match the relationship_types used - (e.g., "similar_papers" with "SIMILAR_TO", "authors" with "IS_AUTHOR_OF") - - Notes: - - Results are randomly shuffled to provide diverse recommendations across multiple calls - - The function extracts only the "neighbor" data from the returned results - - There is a potential bug: the type check `type(relevant_neighbors) is int` should likely be - `type(num_neighbors_to_return) is int` for proper list truncation - - Raises: - Connection errors if Neo4j database is not accessible - KeyError if neighbor_entity doesn't exist in the returned neighbors dictionary - ValueError if invalid relationship_types or neighbor_entity are provided - - Example: - >>> similar_papers = find_neighboring_papers( - ... paper_id="", - ... relationship_types=["SIMILAR_TO"], - ... neighbor_entity="similar_papers", - ... num_neighbors_to_return=5 - ... ) - >>> - >>> authors = find_neighboring_papers( - ... paper_id="", - ... relationship_types=["IS_AUTHOR_OF"], - ... neighbor_entity="authors", - ... num_neighbors_to_return=3 - ... ) + [Your existing docstring] """ # Type coercion for parameters that may come as strings from LLM tool calls if num_neighbors_to_return is not None and not isinstance(num_neighbors_to_return, int): @@ -221,15 +164,15 @@ def find_neighboring_papers( min_similarity=min_similarity, ) + # Flatten all neighbors from all relationship types into one list relevant_neighbors = [] - for rel_type, neighbor in neighbors.items(): - if rel_type != relationship_types: - relevant_neighbors.append(neighbor) + for rel_type, neighbor_list in neighbors.items(): + # neighbor_list is a list of paper dicts, extend to flatten + relevant_neighbors.extend(neighbor_list) # Constrain and shuffle neighbors for more diverse responses random.shuffle(relevant_neighbors) - # FIX: Changed type(relevant_neighbors) to type(num_neighbors_to_return) if num_neighbors_to_return is not None and isinstance(num_neighbors_to_return, int): relevant_neighbors = relevant_neighbors[:num_neighbors_to_return] diff --git a/agentic_nav/tools/knowledge_graph/file_handler.py b/agentic_nav/tools/knowledge_graph/file_handler.py index e18a140..650475e 100644 --- a/agentic_nav/tools/knowledge_graph/file_handler.py +++ b/agentic_nav/tools/knowledge_graph/file_handler.py @@ -1,3 +1,4 @@ +import logging import pickle import networkx as nx @@ -12,7 +13,7 @@ def save_graph(graph: nx.Graph, output_path: str): with open(output_path, 'wb') as f: pickle.dump(graph, f) f.close() - print(f"Graph saved to {output_path}") + logging.info(f"Graph saved to {output_path}") def load_graph(input_path: str) -> nx.Graph: @@ -25,5 +26,5 @@ def load_graph(input_path: str) -> nx.Graph: with open(input_path, 'rb') as f: graph = pickle.load(f) f.close() - print(f"Graph loaded from {input_path}") + logging.info(f"Graph loaded from {input_path}") return graph diff --git a/agentic_nav/tools/knowledge_graph/graph_generator.py b/agentic_nav/tools/knowledge_graph/graph_generator.py index 24a1325..09a0b29 100644 --- 
a/agentic_nav/tools/knowledge_graph/graph_generator.py +++ b/agentic_nav/tools/knowledge_graph/graph_generator.py @@ -14,20 +14,25 @@ from pathlib import Path from agentic_nav.utils.embedding_generator import batch_embed_documents -from agentic_nav.utils.logging import setup_logging -from agentic_nav.tools.knowledge_graph.file_handler import save_graph +from agentic_nav.utils.logger import setup_logging +from agentic_nav.tools.knowledge_graph.file_handler import save_graph, load_graph +PROJECT_ROOT = Path(__file__).parent.parent.parent.parent +EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME", "ollama/nomic-embed-text") +EMBEDDING_MODEL_API_BASE = os.environ.get("EMBEDDING_MODEL_API_BASE", "http://localhost:11435") + # Setup logging + setup_logging( - log_dir="logs", - level=os.environ.get("AGENTIC_NAV_LOG_LEVEL", "INFO") + log_dir=f"{PROJECT_ROOT}/logs", + level=os.environ.get("AGENTIC_NAV_LOG_LEVEL", "INFO"), + console_level="INFO" ) LOGGER = logging.getLogger(__name__) litellm._logging._disable_debugging() -PROJECT_ROOT = Path(__file__).parent.parent.parent.parent -EMBEDDING_MODEL_NAME = os.environ.get("EMBEDDING_MODEL_NAME", "ollama/nomic-embed-text") -EMBEDDING_MODEL_API_BASE = os.environ.get("EMBEDDING_MODEL_API_BASE", "http://localhost:11435") +litellm.suppress_debug_info = True +litellm.set_verbose = False class PaperKnowledgeGraph: @@ -39,12 +44,12 @@ class PaperKnowledgeGraph: Uses litellm with ollama for local embedding generation with parallel processing. """ def __init__( - self, - embedding_model: str = EMBEDDING_MODEL_NAME, - ollama_base_url: str = EMBEDDING_MODEL_API_BASE, - embedding_gen_batch_size: int = 32, - max_parallel_workers: int = 8, - limit_num_papers: Union[int, None] = None + self, + embedding_model: str = EMBEDDING_MODEL_NAME, + ollama_base_url: str = EMBEDDING_MODEL_API_BASE, + embedding_gen_batch_size: int = 32, + max_parallel_workers: int = 8, + limit_num_papers: Union[int, None] = None ): """ Initialize the knowledge graph builder. 
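For orientation, here is a minimal sketch of driving `PaperKnowledgeGraph` end to end, mirroring the `main()` entry point further below in this file; the model name, server URL, and paths are the defaults assumed in this sketch and may differ in your setup:

```python
# Minimal sketch (assumes an Ollama server with nomic-embed-text at localhost:11435
# and the repository's default data/output paths).
from agentic_nav.tools.knowledge_graph.graph_generator import PaperKnowledgeGraph
from agentic_nav.tools.knowledge_graph.file_handler import save_graph

kg = PaperKnowledgeGraph(
    embedding_model="ollama/nomic-embed-text",
    ollama_base_url="http://localhost:11435",
    embedding_gen_batch_size=32,
    max_parallel_workers=8,
    limit_num_papers=100,  # small subset for a quick test run
)
kg.load_papers_from_json("data/neurips-2025-orals-posters.json")
kg.build_graph()  # adds paper, author, and topic nodes plus oral-poster pair edges
kg.connect_similar_papers(similarity_threshold=0.8)
save_graph(graph=kg.graph, output_path="graphs/knowledge_graph.pkl")
```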
@@ -55,7 +60,7 @@ def __init__( embedding_gen_batch_size: Batch size for generating text embeddings max_parallel_workers: Number of parallel workers for embedding generation """ - self.graph = nx.Graph() + self.graph = nx.MultiGraph() self.embedding_model = embedding_model self.ollama_base_url = ollama_base_url self.batch_size = embedding_gen_batch_size @@ -129,11 +134,17 @@ def build_graph(self): topic_nodes = set() author_nodes = set() - LOGGER.info(f"\nPreparing to process {len(self.papers_data)} papers...") + LOGGER.info(f"Preparing to process {len(self.papers_data)} papers...") # Extract all abstracts and paper info paper_info = [] - abstracts = [] + abstracts_for_embeddings = [] + + # Debug counters + oral_count = 0 + poster_count = 0 + unknown_count = 0 + eventtype_samples = {} # Track unique eventtypes we see for paper in self.papers_data: paper_id = paper.get('uid', paper.get('id')) @@ -153,9 +164,37 @@ def build_graph(self): paper_url = paper.get("paper_url", "") sourceid = paper.get("sourceid", "") virtualsite_url = paper.get("virtualsite_url", "") + import_id = paper.get("import_id", "") + + # Create unique node ID based on presentation type and uid + # Use eventtype to determine if it's an oral or poster presentation + presentation_type_lower = presentation_type.lower() if presentation_type else "" + + # Track unique eventtypes we encounter + if presentation_type and presentation_type not in eventtype_samples: + eventtype_samples[presentation_type] = 0 + if presentation_type: + eventtype_samples[presentation_type] += 1 + + if "oral" in presentation_type_lower: + unique_node_id = f"oral_{paper_id}" + presentation_category = "oral" + oral_count += 1 + elif "poster" in presentation_type_lower: + unique_node_id = f"poster_{paper_id}" + presentation_category = "poster" + poster_count += 1 + else: + # Unknown presentation type - use original paper_id + unique_node_id = f"paper_{paper_id}" + presentation_category = "unknown" + unknown_count += 1 + if unknown_count <= 5: # Only log first 5 unknown types + LOGGER.debug(f"Paper {paper_id}: unknown eventtype='{presentation_type}'") paper_info.append({ - "id": paper_id, + "id": paper_id, # Keep original ID (uid) + "node_id": unique_node_id, # New unique node identifier "name": paper_name, "abstract": abstract, "topic": topic, @@ -166,20 +205,27 @@ def build_graph(self): "session_start_time": session_start_time, "session_end_time": session_end_time, "presentation_type": presentation_type, + "presentation_category": presentation_category, "room_name": room_name, "project_url": project_url, "poster_position": poster_position, "paper_url": paper_url, "sourceid": sourceid, - "virtualsite_url": virtualsite_url - + "virtualsite_url": virtualsite_url, + "import_id": import_id }) - abstracts.append(abstract) + + # We generate the embeddings over name, abstract, and decision to get unique embeddings for every event. + abstracts_for_embeddings.append(f"{paper_name}. 
{abstract} - {decision}") + + # Debug output + LOGGER.info(f"Found {oral_count} orals, {poster_count} posters, and {unknown_count} unknown presentation types") + LOGGER.info(f"Unique eventtypes found: {eventtype_samples}") # Generate all embeddings in parallel - LOGGER.info(f"\nGenerating embeddings with batch size {self.batch_size}...") + LOGGER.info(f"Generating embeddings with batch size {self.batch_size}...") embeddings = batch_embed_documents( - abstracts, + abstracts_for_embeddings, batch_size=self.batch_size, embedding_model=self.embedding_model, api_base=self.ollama_base_url @@ -188,15 +234,18 @@ def build_graph(self): # Convert to list so that embeddings can be mapped to samples properly embeddings = embeddings.tolist() + # Track oral-poster pairs by uid for creating edges + oral_poster_pairs = {} # key: uid, value: {'oral': node_id, 'poster': node_id} + # Add nodes to graph - LOGGER.info("\nBuilding graph structure...") + LOGGER.info("Building graph structure...") with tqdm(total=len(paper_info), desc="Adding nodes") as pbar: for info, embedding in zip(paper_info, embeddings): # Extract author information (store as list of dicts) author_list = [] if info['authors']: - for author in info['authors']: + for idx, author in enumerate(info['authors']): author_info = { 'id': author.get('id'), 'fullname': author.get('fullname', ''), @@ -208,26 +257,51 @@ def build_graph(self): if author_uid not in author_nodes: self.graph.add_node( author_uid, + node_type="author", **author_info ) author_nodes.add(author_uid) author_list.append(author_info) - # Add paper node with attributes + # Add paper node with attributes using unique node_id paper_attrs = info.copy() del paper_attrs["authors"] + del paper_attrs["node_id"] # Don't duplicate this in attributes self.graph.add_node( - info["id"], + info["node_id"], # Use unique node_id instead of id **paper_attrs, embedding=embedding, authors=author_list, node_type="paper" ) - for author in author_list: - self.graph.add_edge(f"{author['id']} - {author['fullname']}", info["id"], relationship="is_author_of") + # Track oral-poster pairs by uid + uid = info['id'] # This is the original uid + + if uid not in oral_poster_pairs: + oral_poster_pairs[uid] = {} + + if info['presentation_category'] == 'oral': + oral_poster_pairs[uid]['oral'] = info["node_id"] + LOGGER.debug(f"Tracked oral: {info['node_id']} for uid {uid}") + elif info['presentation_category'] == 'poster': + oral_poster_pairs[uid]['poster'] = info["node_id"] + LOGGER.debug(f"Tracked poster: {info['node_id']} for uid {uid}") + elif info['presentation_category'] == 'unknown': + # Track unknown categories too for debugging + if 'unknown' not in oral_poster_pairs[uid]: + oral_poster_pairs[uid]['unknown'] = [] + oral_poster_pairs[uid]['unknown'].append(info["node_id"]) + + # Add edges to authors + for idx, author in enumerate(author_list): + self.graph.add_edge( + f"{author['id']} - {author['fullname']}", info["node_id"], + relationship="is_author_of", + author_order=idx + ) # Add topic node if it doesn't exist if info['topic'] and info['topic'] not in topic_nodes: @@ -240,13 +314,46 @@ def build_graph(self): # Add edge between paper and topic if info['topic']: - self.graph.add_edge(info['id'], info['topic'], relationship='belongs_to_topic') + self.graph.add_edge(info['node_id'], info['topic'], relationship='belongs_to_topic') pbar.update(1) + # Debug oral-poster pairs before adding edges + LOGGER.info(f"\nOral-poster pairs tracked: {len(oral_poster_pairs)}") + complete_pairs = [k for k, v in 
oral_poster_pairs.items() if 'oral' in v and 'poster' in v] + incomplete_pairs = [k for k, v in oral_poster_pairs.items() if 'oral' not in v or 'poster' not in v] + + LOGGER.info(f"Complete pairs (both oral and poster): {len(complete_pairs)}") + LOGGER.info(f"Incomplete pairs: {len(incomplete_pairs)}") + + if complete_pairs: + LOGGER.info(f"Sample complete pairs (first 3):") + for uid in complete_pairs[:3]: + LOGGER.info(f" uid {uid}: oral={oral_poster_pairs[uid].get('oral')}, poster={oral_poster_pairs[uid].get('poster')}") + + if incomplete_pairs: + LOGGER.warning(f"Sample incomplete pairs:") + for uid in incomplete_pairs[:5]: + LOGGER.warning(f" uid {uid}: {oral_poster_pairs[uid]}") + + # Add edges between oral-poster pairs + oral_poster_edges_added = 0 + for uid, pair in oral_poster_pairs.items(): + if 'oral' in pair and 'poster' in pair: + self.graph.add_edge( + pair['oral'], + pair['poster'], + relationship='oral_poster_pair', + uid=uid + ) + oral_poster_edges_added += 1 + LOGGER.debug(f"Added edge: {pair['oral']} <-> {pair['poster']}") + LOGGER.info(f"Built graph with {self.graph.number_of_nodes()} nodes and {self.graph.number_of_edges()} edges") LOGGER.info(f" Papers: {len([n for n, d in self.graph.nodes(data=True) if d.get('node_type') == 'paper'])}") LOGGER.info(f" Topics: {len([n for n, d in self.graph.nodes(data=True) if d.get('node_type') == 'topic'])}") + LOGGER.info(f" Authors: {len([n for n, d in self.graph.nodes(data=True) if d.get('node_type') == 'author'])}") + LOGGER.info(f" Oral-Poster pairs connected: {oral_poster_edges_added}") def connect_similar_papers(self, similarity_threshold: float = 0.7): """ @@ -255,7 +362,7 @@ def connect_similar_papers(self, similarity_threshold: float = 0.7): similarity_threshold: Minimum cosine similarity to create an edge (0-1) """ paper_nodes = [(n, d) for n, d in self.graph.nodes(data=True) if d.get('node_type') == 'paper'] - LOGGER.info(f"\nComputing similarities for {len(paper_nodes)} papers...") + LOGGER.info(f"Computing similarities for {len(paper_nodes)} papers...") # Create pairs to compare (fast!) pairs = [(i, j) for i in range(len(paper_nodes)) for j in range(i + 1, len(paper_nodes))] @@ -376,9 +483,87 @@ def get_graph_statistics(self) -> Dict[str, Any]: 'is_connected': nx.is_connected(self.graph), } - if nx.is_connected(self.graph): - stats['diameter'] = nx.diameter(self.graph) - stats['average_shortest_path'] = nx.average_shortest_path_length(self.graph) + # if nx.is_connected(self.graph): + # stats['diameter'] = nx.diameter(self.graph) + # stats['average_shortest_path'] = nx.average_shortest_path_length(self.graph) + + return stats + + def get_poster_oral_statistics(self) -> Dict[str, Any]: + """ + Get statistics about poster-oral pairs in the knowledge graph. + Pairs are matched by 'uid' attribute where both oral and poster versions + share the same uid. 
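+
+        Example (sketch; assumes the graph has already been built):
+            stats = kg.get_poster_oral_statistics()
+            print(stats['matched_pairs'], stats['pairs_with_edges'])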
+ + Returns: + Dictionary with poster-oral pair statistics + """ + paper_nodes = [(n, d) for n, d in self.graph.nodes(data=True) if d.get('node_type') == 'paper'] + + # Count by presentation category and track by uid + orals = {} # key: uid, value: node_id + posters = {} # key: uid, value: node_id + other_papers = 0 + + for node_id, data in paper_nodes: + presentation_category = data.get('presentation_category', 'unknown') + uid = data.get('id', '') # This is the original uid + + if presentation_category == 'oral': + orals[uid] = node_id + elif presentation_category == 'poster': + posters[uid] = node_id + else: + other_papers += 1 + + # Find matched pairs (can also check edges) + matched_pairs = [] + for uid in orals.keys(): + if uid in posters: + oral_id = orals[uid] + poster_id = posters[uid] + + # Verify edge exists + has_edge = self.graph.has_edge(oral_id, poster_id) + + matched_pairs.append({ + 'uid': uid, + 'oral_id': oral_id, + 'poster_id': poster_id, + 'oral_name': self.graph.nodes[oral_id].get('name', ''), + 'poster_name': self.graph.nodes[poster_id].get('name', ''), + 'has_edge': has_edge + }) + + # Find orals without corresponding posters + orals_without_posters = [ + uid for uid in orals.keys() + if uid not in posters + ] + + # Find posters without corresponding orals + posters_without_orals = [ + uid for uid in posters.keys() + if uid not in orals + ] + + # Check edges + edge_tracker = 0 + for pair in matched_pairs: + if pair['has_edge']: + edge_tracker += 1 + + stats = { + 'total_papers': len(paper_nodes), + 'total_orals': len(orals), + 'total_posters': len(posters), + 'other_papers': other_papers, + 'matched_pairs': len(matched_pairs), + 'pairs_with_edges': edge_tracker, + 'orals_without_posters': len(orals_without_posters), + 'posters_without_orals': len(posters_without_orals), + 'pair_details': matched_pairs + } return stats @@ -392,53 +577,110 @@ def get_graph_statistics(self) -> Dict[str, Any]: @click.option("-f", "--input-json-file", default=f"{PROJECT_ROOT}/data/neurips-2025-orals-posters.json") @click.option("-o", "--output-file", default=f"{PROJECT_ROOT}/graphs/knowledge_graph.pkl") @click.option("-s", "--similarity-threshold", default=0.8) +@click.option("--stats-only", is_flag=True) def main( - embedding_model: str, - ollama_server_url: str, - embedding_gen_batch_size: int, - max_parallel_workers: int, - limit_num_papers: int, - input_json_file: str, - output_file: str, - similarity_threshold: float + embedding_model: str, + ollama_server_url: str, + embedding_gen_batch_size: int, + max_parallel_workers: int, + limit_num_papers: int, + input_json_file: str, + output_file: str, + similarity_threshold: float, + stats_only: bool = False ): kg = PaperKnowledgeGraph( - embedding_model=f"ollama/{embedding_model}", + embedding_model=embedding_model, ollama_base_url=ollama_server_url, embedding_gen_batch_size=embedding_gen_batch_size, max_parallel_workers=max_parallel_workers, limit_num_papers=limit_num_papers ) - # Load papers from JSON file - kg.load_papers_from_json(input_json_file) - - # Build the graph (parallel embedding generation) - kg.build_graph() - - # Optionally connect similar papers based on embeddings (parallel) - kg.connect_similar_papers(similarity_threshold=similarity_threshold) + if not stats_only: + # Load papers from JSON file + kg.load_papers_from_json(input_json_file) + + # Build the graph (parallel embedding generation) + kg.build_graph() + + # Optionally connect similar papers based on embeddings (parallel) + 
kg.connect_similar_papers(similarity_threshold=similarity_threshold) + + # DEBUG: Check edges before saving + LOGGER.info("\n=== BEFORE SAVING ===") + oral_poster_edges_before = [ + (s, t, d) for s, t, d in kg.graph.edges(data=True) + if d.get('relationship') == 'oral_poster_pair' + ] + LOGGER.info(f"Oral-poster edges before save: {len(oral_poster_edges_before)}") + if oral_poster_edges_before: + sample = oral_poster_edges_before[0] + LOGGER.info(f"Sample edge: {sample[0]} -> {sample[1]}, data: {sample[2]}") + + # Save the graph to disk + save_graph( + graph=kg.graph, + output_path=output_file + ) - # Save the graph to disk - save_graph( - graph=kg.graph, - output_path=output_file - ) + # DEBUG: Load it back and check + LOGGER.info("\n=== AFTER LOADING BACK ===") + test_load = load_graph(output_file) + oral_poster_edges_after = [ + (s, t, d) for s, t, d in test_load.edges(data=True) + if d.get('relationship') == 'oral_poster_pair' + ] + LOGGER.info(f"Oral-poster edges after load: {len(oral_poster_edges_after)}") + if oral_poster_edges_after: + sample = oral_poster_edges_after[0] + LOGGER.info(f"Sample edge: {sample[0]} -> {sample[1]}, data: {sample[2]}") + else: + LOGGER.warning("Edges lost during save/load!") + # Check what edges DO exist + LOGGER.info("Sample of edges that exist after load:") + for i, (s, t, d) in enumerate(test_load.edges(data=True)): + if i >= 3: + break + LOGGER.info(f" {s} -> {t}, data keys: {list(d.keys())}, relationship: {d.get('relationship')}") + else: + kg.graph = load_graph(output_file) + + # Get Poster Oral Pairs + po_stats = kg.get_poster_oral_statistics() + LOGGER.info("Paper statistics: ") + for key, value in po_stats.items(): + if type(value) is int: + LOGGER.info(f" {key}: {value}") # Print statistics stats = kg.get_graph_statistics() - LOGGER.info("\nGraph Statistics:") + LOGGER.info("Graph Statistics:") for key, value in stats.items(): LOGGER.info(f" {key}: {value}") # Test run: Find similar papers if kg.papers_data: - first_paper_id = kg.papers_data[0].get('uid', kg.papers_data[0].get('id')) - LOGGER.debug(f"\nPapers similar to '{kg.graph.nodes[first_paper_id]['name']}':") - similar = kg.find_similar_papers(first_paper_id, top_k=3) - for pid, sim, name in similar: - LOGGER.debug(f" - {name} (similarity: {sim:.3f})") + first_paper = kg.papers_data[0] + first_paper_uid = first_paper.get('uid', first_paper.get('id')) + first_paper_eventtype = first_paper.get('eventtype', '').lower() + + # Construct the correct node_id based on eventtype + if "oral" in first_paper_eventtype: + first_paper_node_id = f"oral_{first_paper_uid}" + elif "poster" in first_paper_eventtype: + first_paper_node_id = f"poster_{first_paper_uid}" + else: + first_paper_node_id = f"paper_{first_paper_uid}" + + if first_paper_node_id in kg.graph: + LOGGER.debug(f"Papers similar to '{kg.graph.nodes[first_paper_node_id]['name']}':") + similar = kg.find_similar_papers(first_paper_node_id, top_k=3) + for pid, sim, name in similar: + LOGGER.debug(f" - {name} (similarity: {sim:.3f})") + else: + LOGGER.warning(f"First paper node '{first_paper_node_id}' not found in graph") # Run diff --git a/agentic_nav/tools/knowledge_graph/neo4j_db_importer.py b/agentic_nav/tools/knowledge_graph/neo4j_db_importer.py index d23e366..0402e35 100644 --- a/agentic_nav/tools/knowledge_graph/neo4j_db_importer.py +++ b/agentic_nav/tools/knowledge_graph/neo4j_db_importer.py @@ -1,6 +1,6 @@ """ -Neo4j exporter for PaperKnowledgeGraph -Exports NetworkX graph to Neo4j database with proper handling of embeddings and 
relationships +Neo4j importer for PaperKnowledgeGraph - OPTIMIZED VERSION +Imports NetworkX graph to Neo4j database with maximum performance """ import logging import os @@ -8,7 +8,7 @@ import click import networkx as nx from neo4j import GraphDatabase -from typing import Dict, Any +from typing import Dict, Any, List import numpy as np from tqdm import tqdm from pathlib import Path @@ -30,16 +30,27 @@ class Neo4jImporter: - """Import PaperKnowledgeGraph to Neo4j database.""" + """Import of PaperKnowledgeGraph to Neo4j database.""" def __init__( self, uri: str = NEO4J_DB_URI, username: str = NEO4J_USERNAME, - password: str = NEO4J_PASSWORD + password: str = NEO4J_PASSWORD, + connection_timeout: int = 30 ): - """Initialize Neo4j connection.""" - self.driver = GraphDatabase.driver(uri, auth=(username, password)) + """ + Initialize Neo4j connection with optimized settings. + + Args: + connection_timeout: Connection timeout in seconds (default: 30) + """ + self.driver = GraphDatabase.driver( + uri, + auth=(username, password), + connection_timeout=connection_timeout, + max_transaction_retry_time=30 + ) self.driver.verify_connectivity() LOGGER.info(f"Connected to Neo4j at {uri}") @@ -47,17 +58,32 @@ def close(self): """Close the Neo4j driver connection.""" self.driver.close() - def clear_database(self, batch_size=500): + def _iter_edges(self, kg: nx.Graph): + """ + Iterate over edges, handling both Graph and MultiGraph. + Yields (source, target, data) tuples. + """ + if isinstance(kg, nx.MultiGraph) or isinstance(kg, nx.MultiDiGraph): + for source, target, key, data in kg.edges(data=True, keys=True): + yield source, target, data + else: + for source, target, data in kg.edges(data=True): + yield source, target, data + + def clear_database(self, batch_size=100): + """Clear database with optimized batch size.""" with self.driver.session() as session: deleted_total = 0 while True: result = session.run(""" - MATCH (n) - WITH n LIMIT $batch_size - DETACH DELETE n - RETURN count(n) as deleted - """, - batch_size=batch_size + CALL () { + MATCH (n) + WITH n LIMIT $batch_size + DETACH DELETE n + RETURN count(n) as deleted + } + RETURN deleted + """, batch_size=batch_size ) deleted = result.single()["deleted"] @@ -68,21 +94,25 @@ def clear_database(self, batch_size=500): break def create_indexes(self, embedding_dimension: int = 768): - """Create indexes for better query performance, including vector index.""" + """Create indexes for better query performance.""" with self.driver.session() as session: - # Create index on paper IDs - session.run("CREATE INDEX paper_id IF NOT EXISTS FOR (p:Paper) ON (p.id)") + indexes = [ + "CREATE INDEX paper_node_id IF NOT EXISTS FOR (p:Paper) ON (p.node_id)", + "CREATE INDEX paper_id IF NOT EXISTS FOR (p:Paper) ON (p.id)", + "CREATE INDEX paper_import_id IF NOT EXISTS FOR (p:Paper) ON (p.import_id)", + "CREATE INDEX paper_presentation_category IF NOT EXISTS FOR (p:Paper) ON (p.presentation_category)", + "CREATE INDEX topic_name IF NOT EXISTS FOR (t:Topic) ON (t.name)", + "CREATE INDEX author_id IF NOT EXISTS FOR (a:Author) ON (a.author_id)", + "CREATE INDEX author_composite_id IF NOT EXISTS FOR (a:Author) ON (a.composite_id)", + "CREATE INDEX author_name IF NOT EXISTS FOR (a:Author) ON (a.fullname)", + ] + + for idx in indexes: + session.run(idx) - # Create index on topic names - session.run("CREATE INDEX topic_name IF NOT EXISTS FOR (t:Topic) ON (t.name)") - - # Create index on author IDs - session.run("CREATE INDEX author_id IF NOT EXISTS FOR (a:Author) ON (a.author_id)") 
- - # Create index on author names (useful for searching) - session.run("CREATE INDEX author_name IF NOT EXISTS FOR (a:Author) ON (a.fullname)") + LOGGER.info("Created standard indexes") - # Create vector index for embeddings (Neo4j 5.11+) + # Vector index (requires Neo4j 5.11+) try: session.run(""" CREATE VECTOR INDEX paper_embeddings IF NOT EXISTS @@ -97,17 +127,17 @@ def create_indexes(self, embedding_dimension: int = 768): """, dimension=embedding_dimension) LOGGER.info(f"Created vector index for {embedding_dimension}-dimensional embeddings") except Exception as e: - LOGGER.warning(f"Warning: Could not create vector index: {e}") - LOGGER.warning("Vector indexes require Neo4j 5.11+ or Enterprise Edition") - - LOGGER.info("Created standard indexes") + LOGGER.warning(f"Could not create vector index: {e}") - def _export_paper_nodes(self, kg: nx.Graph, batch_size: int): - """Export paper nodes to Neo4j with all attributes.""" + def _import_paper_nodes(self, kg: nx.Graph, batch_size: int = 100): + """ + Import paper nodes with optimized batch size. + Uses UNWIND with CREATE instead of MERGE for better performance. + """ paper_nodes = [(n, d) for n, d in kg.nodes(data=True) if d.get('node_type') == 'paper'] - LOGGER.info(f"\nExporting {len(paper_nodes)} paper nodes...") + LOGGER.info(f"\nImporting {len(paper_nodes)} paper nodes...") with self.driver.session() as session: for i in tqdm(range(0, len(paper_nodes), batch_size), desc="Paper nodes"): @@ -115,64 +145,99 @@ def _export_paper_nodes(self, kg: nx.Graph, batch_size: int): papers_data = [] for node_id, data in batch: - # Convert embedding to list if it's numpy array embedding = data.get('embedding', []) if isinstance(embedding, np.ndarray): embedding = embedding.tolist() paper_dict = { - "id": node_id, + "node_id": node_id, + "id": data.get('id', ''), "name": data.get('name', ''), "abstract": data.get('abstract', ''), "topic": data.get('topic', ''), "keywords": data.get('keywords', []), - "decision": data.get('decision', ''), + "decisions": data.get('decisions', ''), "session": data.get('session', ''), "session_start_time": data.get('session_start_time', ''), "session_end_time": data.get('session_end_time', ''), "presentation_type": data.get('presentation_type', ''), + "presentation_category": data.get('presentation_category', ''), "room_name": data.get('room_name', ''), "project_url": data.get('project_url', ''), "poster_position": data.get('poster_position', ''), "paper_url": data.get("paper_url", ""), "sourceid": data.get("sourceid", ""), "virtualsite_url": data.get("virtualsite_url", ""), + "import_id": data.get("import_id", ""), 'embedding': embedding } papers_data.append(paper_dict) - # Batch create paper nodes + # Use CREATE instead of MERGE since we know nodes don't exist session.run(""" UNWIND $papers AS paper CREATE (p:Paper { + node_id: paper.node_id, id: paper.id, name: paper.name, abstract: paper.abstract, topic: paper.topic, keywords: paper.keywords, - decision: paper.decision, + decisions: paper.decisions, session: paper.session, session_start_time: paper.session_start_time, session_end_time: paper.session_end_time, presentation_type: paper.presentation_type, + presentation_category: paper.presentation_category, room_name: paper.room_name, project_url: paper.project_url, poster_position: paper.poster_position, paper_url: paper.paper_url, sourceid: paper.sourceid, virtualsite_url: paper.virtualsite_url, + import_id: paper.import_id, embedding: paper.embedding }) """, papers=papers_data) - LOGGER.info(f"Exported 
{len(paper_nodes)} paper nodes") + LOGGER.info(f"Imported {len(paper_nodes)} paper nodes") - def _export_topic_hierarchy(self, kg: nx.Graph): - """ - Export topic nodes with hierarchical structure to Neo4j. - Splits topics like "Deep Learning->Theory" into separate nodes with parent-child relationships. - """ - # Collect all unique topic paths from paper nodes + def _import_oral_poster_relationships(self, kg: nx.Graph, batch_size: int = 100): + """Import oral-poster relationships with optimized batch processing.""" + oral_poster_edges = [ + (source, target, data) + for source, target, data in self._iter_edges(kg) + if data.get('relationship') == 'oral_poster_pair' + ] + + LOGGER.info(f"Importing {len(oral_poster_edges)} oral-poster pair relationships...") + + if len(oral_poster_edges) == 0: + LOGGER.warning("No oral-poster pair relationships found!") + return + + with self.driver.session() as session: + for i in tqdm(range(0, len(oral_poster_edges), batch_size), desc="Oral-Poster relationships"): + batch = oral_poster_edges[i:i + batch_size] + + edges_data = [{ + 'source': source, + 'target': target, + 'uid': data.get('uid', '') + } for source, target, data in batch] + + # Use CREATE instead of MERGE - we know papers exist + session.run(""" + UNWIND $edges AS edge + MATCH (oral:Paper {node_id: edge.source}) + MATCH (poster:Paper {node_id: edge.target}) + CREATE (oral)-[:ORAL_POSTER_PAIR {uid: edge.uid}]->(poster) + """, edges=edges_data) + + LOGGER.info(f"Imported {len(oral_poster_edges)} oral-poster pair relationships") + + def _import_topic_hierarchy(self, kg: nx.Graph): + """Import topic hierarchy with optimized queries.""" topic_paths = set() for node_id, data in kg.nodes(data=True): if data.get('node_type') == 'paper': @@ -182,69 +247,63 @@ def _export_topic_hierarchy(self, kg: nx.Graph): LOGGER.info(f"Processing {len(topic_paths)} unique topic paths...") - # Parse topic paths and create hierarchy all_topics = set() topic_relationships = [] for path in topic_paths: parts = [p.strip() for p in path.split('->')] - - # Add all topic parts for part in parts: all_topics.add(part) - - # Create parent-child relationships for i in range(len(parts) - 1): topic_relationships.append({ 'parent': parts[i], 'child': parts[i + 1] }) - LOGGER.info( - f"Creating {len(all_topics)} topic nodes with {len(set(tuple(r.items()) for r in topic_relationships))} " - f"hierarchical relationships..." 
- ) + LOGGER.info(f"Creating {len(all_topics)} topic nodes...") with self.driver.session() as session: - # Create all topic nodes (using MERGE to avoid duplicates) - topics_data = [{'name': topic} for topic in all_topics] - session.run(""" - UNWIND $topics AS topic - MERGE (t:Topic {name: topic.name}) - """, topics=topics_data) - - # Create hierarchical relationships between topics (deduplicate first) + # Batch create topics + topics_batch = [{'name': topic} for topic in all_topics] + + # Process in chunks for very large topic sets + chunk_size = 500 + for i in range(0, len(topics_batch), chunk_size): + chunk = topics_batch[i:i + chunk_size] + session.run(""" + UNWIND $topics AS topic + MERGE (t:Topic {name: topic.name}) + """, topics=chunk) + + # Create relationships if topic_relationships: - # Remove duplicates unique_rels = list({(r['parent'], r['child']): r for r in topic_relationships}.values()) - session.run(""" - UNWIND $rels AS rel - MATCH (parent:Topic {name: rel.parent}) - MATCH (child:Topic {name: rel.child}) - MERGE (child)-[:SUBTOPIC_OF]->(parent) - """, rels=unique_rels) - LOGGER.info(f"Exported {len(all_topics)} topic nodes with hierarchy") + for i in range(0, len(unique_rels), chunk_size): + chunk = unique_rels[i:i + chunk_size] + session.run(""" + UNWIND $rels AS rel + MATCH (parent:Topic {name: rel.parent}) + MATCH (child:Topic {name: rel.child}) + MERGE (child)-[:SUBTOPIC_OF]->(parent) + """, rels=chunk) - def _connect_papers_to_topics(self, kg: nx.Graph, batch_size: int): - """ - Connect papers to their leaf topic nodes. - For "Deep Learning->Theory", connects paper to "Theory" node. - """ + LOGGER.info(f"Imported {len(all_topics)} topic nodes with hierarchy") + + def _connect_papers_to_topics(self, kg: nx.Graph, batch_size: int = 100): + """Connect papers to topics with larger batches.""" paper_topic_connections = [] for node_id, data in kg.nodes(data=True): if data.get('node_type') == 'paper': topic = data.get('topic', '') if topic: - # Get the leaf topic (last part after splitting) parts = [p.strip() for p in topic.split('->')] leaf_topic = parts[-1] - paper_topic_connections.append({ - 'paper_id': node_id, + 'paper_node_id': node_id, 'topic_name': leaf_topic, - 'full_path': topic # Store full path as property + 'full_path': topic }) LOGGER.info(f"Connecting {len(paper_topic_connections)} papers to topics...") @@ -254,26 +313,32 @@ def _connect_papers_to_topics(self, kg: nx.Graph, batch_size: int): desc="Paper-Topic connections"): batch = paper_topic_connections[i:i + batch_size] + # Use CREATE for new relationships session.run(""" UNWIND $connections AS conn - MATCH (p:Paper {id: conn.paper_id}) + MATCH (p:Paper {node_id: conn.paper_node_id}) MATCH (t:Topic {name: conn.topic_name}) - MERGE (p)-[r:BELONGS_TO_TOPIC]->(t) - SET r.full_path = conn.full_path + CREATE (p)-[r:BELONGS_TO_TOPIC {full_path: conn.full_path}]->(t) """, connections=batch) LOGGER.info(f"Connected papers to leaf topics") - def _export_similarity_relationships(self, kg: nx.Graph, batch_size: int): - """Export similarity relationships between papers to Neo4j.""" - # Filter only similarity edges + def _import_similarity_relationships_optimized(self, kg: nx.Graph, batch_size: int = 500): + """ + Import similarity relationships with maximum optimization. + Uses very large batches and CREATE instead of MERGE. 
+ """ similarity_edges = [ (source, target, data) - for source, target, data in kg.edges(data=True) + for source, target, data in self._iter_edges(kg) if data.get('relationship') == 'similar_to' ] - LOGGER.info(f"Exporting {len(similarity_edges)} similarity relationships...") + LOGGER.info(f"Importing {len(similarity_edges)} similarity relationships...") + + if len(similarity_edges) == 0: + LOGGER.info("No similarity relationships to import") + return with self.driver.session() as session: for i in tqdm(range(0, len(similarity_edges), batch_size), @@ -283,43 +348,37 @@ def _export_similarity_relationships(self, kg: nx.Graph, batch_size: int): edges_data = [{ 'source': source, 'target': target, - 'similarity': data.get('similarity', 0.0) + 'similarity': float(data.get('similarity', 0.0)) } for source, target, data in batch] + # Use CREATE - much faster than MERGE + # Papers already exist, we just need to connect them session.run(""" UNWIND $edges AS edge - MATCH (p1:Paper {id: edge.source}) - MATCH (p2:Paper {id: edge.target}) - MERGE (p1)-[:SIMILAR_TO {similarity: edge.similarity}]->(p2) + MATCH (p1:Paper {node_id: edge.source}) + MATCH (p2:Paper {node_id: edge.target}) + CREATE (p1)-[:SIMILAR_TO {similarity: edge.similarity}]->(p2) """, edges=edges_data) - LOGGER.info(f"Exported {len(similarity_edges)} similarity relationships") + LOGGER.info(f"Imported {len(similarity_edges)} similarity relationships") - def _export_authors_and_relationships(self, kg: nx.Graph, batch_size: int): - """ - Export author nodes from NetworkX graph (where they already exist as separate nodes) - and create IS_AUTHOR_OF relationships between authors and papers. - - Author nodes in NetworkX have composite IDs like "12345 - John Doe" - """ - # Collect author nodes from the graph + def _import_authors_and_relationships(self, kg: nx.Graph, batch_size: int = 100): + """Import authors and relationships with optimized batch sizes.""" author_nodes = [ (node_id, data) for node_id, data in kg.nodes(data=True) - if data.get('node_type') != 'paper' and data.get('node_type') != 'topic' + if data.get('node_type') == 'author' ] LOGGER.info(f"Found {len(author_nodes)} author nodes in graph...") - # Extract author data all_authors = [] for node_id, data in author_nodes: - # Parse composite ID "12345 - John Doe" parts = node_id.split(' - ', 1) author_id = parts[0].strip() if len(parts) > 0 else "" author_dict = { - 'composite_id': node_id, # Store the full composite ID + 'composite_id': node_id, 'author_id': author_id, 'fullname': data.get('fullname', ''), 'institution': data.get('institution', ''), @@ -327,10 +386,9 @@ def _export_authors_and_relationships(self, kg: nx.Graph, batch_size: int): } all_authors.append(author_dict) - LOGGER.info(f"Exporting {len(all_authors)} unique authors...") + LOGGER.info(f"Importing {len(all_authors)} unique authors...") with self.driver.session() as session: - # Create author nodes in batches for i in tqdm(range(0, len(all_authors), batch_size), desc="Author nodes"): batch = all_authors[i:i + batch_size] @@ -349,143 +407,117 @@ def _export_authors_and_relationships(self, kg: nx.Graph, batch_size: int): a.url = author.url """, authors=batch) - LOGGER.info(f"Exported {len(all_authors)} author nodes") + LOGGER.info(f"Imported {len(all_authors)} author nodes") - # Method 1: Try to collect author-paper relationships from graph edges + # Import relationships author_paper_edges = [ (source, target, data) - for source, target, data in kg.edges(data=True) + for source, target, data in 
self._iter_edges(kg) if data.get('relationship') == 'is_author_of' ] LOGGER.info(f"Found {len(author_paper_edges)} IS_AUTHOR_OF edges in graph") - # Method 2: If no edges found, extract from paper node 'authors' attribute - if len(author_paper_edges) == 0: - LOGGER.warning("No IS_AUTHOR_OF edges found in graph. Extracting from paper 'authors' attribute...") - - paper_author_relationships = [] - for node_id, data in kg.nodes(data=True): - if data.get('node_type') == 'paper': - authors = data.get('authors', []) - - if authors and isinstance(authors, list) and len(authors) > 0: - # Check if authors are stored as dicts - if isinstance(authors[0], dict): - for author in authors: - author_id = str(author.get('id', '')) - fullname = author.get('fullname', '') - if author_id and fullname: - composite_id = f"{author_id} - {fullname}" - paper_author_relationships.append({ - 'author_id': composite_id, - 'paper_id': node_id - }) - - LOGGER.info(f"Extracted {len(paper_author_relationships)} relationships from paper attributes") - - # Create relationships from extracted data - with self.driver.session() as session: - for i in tqdm(range(0, len(paper_author_relationships), batch_size), - desc="Author-Paper relationships"): - batch = paper_author_relationships[i:i + batch_size] - - session.run(""" - UNWIND $edges AS edge - MATCH (a:Author {composite_id: edge.author_id}) - MATCH (p:Paper {id: edge.paper_id}) - MERGE (a)-[:IS_AUTHOR_OF]->(p) - """, edges=batch) - - LOGGER.info(f"Created {len(paper_author_relationships)} author-paper relationships") - else: - # Create relationships from graph edges + if len(author_paper_edges) > 0: with self.driver.session() as session: - for i in tqdm(range(0, len(author_paper_edges), batch_size), - desc="Author-Paper relationships"): + for i in tqdm(range(0, len(author_paper_edges), batch_size), desc="Author-Paper relationships"): batch = author_paper_edges[i:i + batch_size] edges_data = [{ - 'author_id': source, # composite ID like "12345 - John Doe" - 'paper_id': target + 'author_id': source, + 'paper_node_id': target, + 'author_order': data.get('author_order', 0) } for source, target, data in batch] + # Use CREATE instead of MERGE session.run(""" UNWIND $edges AS edge MATCH (a:Author {composite_id: edge.author_id}) - MATCH (p:Paper {id: edge.paper_id}) - MERGE (a)-[:IS_AUTHOR_OF]->(p) + MATCH (p:Paper {node_id: edge.paper_node_id}) + CREATE (a)-[r:IS_AUTHOR_OF {author_order: edge.author_order}]->(p) """, edges=edges_data) LOGGER.info(f"Created {len(author_paper_edges)} author-paper relationships") - def import_graph(self, kg_path: str, batch_size: int = 100, embedding_dimension: int = 768): - """Import the entire knowledge graph to Neo4j.""" + def import_graph( + self, + kg_path: str, + batch_size: int = 100, + similarity_batch_size: int = 500, + embedding_dimension: int = 768 + ): + """ + Import the entire knowledge graph to Neo4j with optimized settings. 
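+
+        Example (sketch; assumes a reachable Neo4j instance at NEO4J_DB_URI and a
+        graph pickle produced by graph_generator.py):
+            importer = Neo4jImporter()
+            importer.import_graph("graphs/knowledge_graph.pkl")
+            importer.verify_import()
+            importer.close()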
+ + Args: + kg_path: Path to graph pickle file + batch_size: Batch size for most operations (default: 1000) + similarity_batch_size: Larger batch for similarity edges (default: 10000) + embedding_dimension: Embedding vector dimension + """ LOGGER.info(f"Loading graph from path {kg_path}") kg = load_graph(kg_path) - LOGGER.info("Starting Neo4j export...") + LOGGER.info("Starting optimized Neo4j import...") # Clear and prepare database self.clear_database() self.create_indexes(embedding_dimension) - # Export paper nodes - self._export_paper_nodes(kg, batch_size) - - # Export authors and author-paper relationships - self._export_authors_and_relationships(kg, batch_size) - - # Export topic hierarchy - self._export_topic_hierarchy(kg) - - # Connect papers to topics - self._connect_papers_to_topics(kg, batch_size) + # Import in optimized order + self._import_paper_nodes(kg, batch_size) + self._import_authors_and_relationships(kg, batch_size) + self._import_topic_hierarchy(kg) + self._connect_papers_to_topics(kg, batch_size * 2) # 2x batch for simpler queries + self._import_oral_poster_relationships(kg, batch_size) + self._import_similarity_relationships_optimized(kg, similarity_batch_size) - # Export similarity relationships - self._export_similarity_relationships(kg, batch_size) + LOGGER.info("Import completed successfully!") - LOGGER.info("Export completed successfully!") - - def verify_export(self) -> Dict[str, Any]: - """Verify the export by checking node and relationship counts.""" + def verify_import(self) -> Dict[str, Any]: + """Verify the import by checking node and relationship counts.""" with self.driver.session() as session: - # Count papers result = session.run("MATCH (p:Paper) RETURN count(p) as count") paper_count = result.single()['count'] - # Count topics + result = session.run("MATCH (p:Paper) WHERE p.presentation_category = 'oral' RETURN count(p) as count") + oral_count = result.single()['count'] + + result = session.run("MATCH (p:Paper) WHERE p.presentation_category = 'poster' RETURN count(p) as count") + poster_count = result.single()['count'] + result = session.run("MATCH (t:Topic) RETURN count(t) as count") topic_count = result.single()['count'] - # Count authors result = session.run("MATCH (a:Author) RETURN count(a) as count") author_count = result.single()['count'] - # Count relationships result = session.run("MATCH ()-[r]->() RETURN count(r) as count") rel_count = result.single()['count'] - # Count similarity relationships result = session.run("MATCH ()-[r:SIMILAR_TO]->() RETURN count(r) as count") similarity_count = result.single()['count'] - # Count topic hierarchy relationships result = session.run("MATCH ()-[r:SUBTOPIC_OF]->() RETURN count(r) as count") subtopic_count = result.single()['count'] - # Count author relationships (updated relationship name) result = session.run("MATCH ()-[r:IS_AUTHOR_OF]->() RETURN count(r) as count") is_author_of_count = result.single()['count'] + result = session.run("MATCH ()-[r:ORAL_POSTER_PAIR]->() RETURN count(r) as count") + oral_poster_count = result.single()['count'] + stats = { 'papers': paper_count, + 'orals': oral_count, + 'posters': poster_count, 'topics': topic_count, 'authors': author_count, 'total_relationships': rel_count, 'similarity_relationships': similarity_count, 'subtopic_relationships': subtopic_count, - 'is_author_of_relationships': is_author_of_count + 'is_author_of_relationships': is_author_of_count, + 'oral_poster_pair_relationships': oral_poster_count } LOGGER.info("Neo4j Database Statistics:") @@ -501,6 
+533,7 @@ def verify_export(self) -> Dict[str, Any]: @click.option("-u", "--neo4j-username", help="Database user", default=NEO4J_USERNAME) @click.option("-p", "--neo4j-password", help="Database password", default=NEO4J_PASSWORD) @click.option("-b", "--batch-size", help="Batch size for node insertion", default=100) +@click.option("-s", "--similarity-batch-size", help="Batch size for similarity edges", default=5000) @click.option("-e", "--embedding-dimension", help="Vector embedding dimensions", default=768) def main( graph_path: str, @@ -508,27 +541,22 @@ def main( neo4j_username: str, neo4j_password: str, batch_size: int = 100, + similarity_batch_size: int = 5000, embedding_dimension: int = 768 ): - """ - Convenience function to export a knowledge graph to Neo4j. - - Args: - graph_path: PaperKnowledgeGraph instance - neo4j_uri: Neo4j connection URI - neo4j_username: Neo4j username - neo4j_password: Neo4j password - batch_size: Batch size for processing - embedding_dimension: Dimension of embedding vectors (default: 768) - """ - importer = Neo4jImporter(neo4j_uri, neo4j_username, neo4j_password) + importer = Neo4jImporter( + neo4j_uri, + neo4j_username, + neo4j_password + ) try: importer.import_graph( graph_path, batch_size, + similarity_batch_size, embedding_dimension ) - importer.verify_export() + importer.verify_import() finally: importer.close() diff --git a/agentic_nav/tools/knowledge_graph/retriever.py b/agentic_nav/tools/knowledge_graph/retriever.py index 89b5e76..dcf094c 100644 --- a/agentic_nav/tools/knowledge_graph/retriever.py +++ b/agentic_nav/tools/knowledge_graph/retriever.py @@ -2,12 +2,12 @@ import numpy as np import random import os +from functools import lru_cache +from typing import List, Dict, Any, Optional, Tuple from neo4j import GraphDatabase from pathlib import Path -from typing import List, Dict, Any, Optional - from agentic_nav.tools.knowledge_graph.graph_traversal_strategies import ( TraversalStrategy, _graph_traversal_dfs_random, @@ -27,166 +27,245 @@ class Neo4jGraphWorker: - """Search and traversal operations for Neo4j paper knowledge graph.""" + """Search and traversal operations for Neo4j paper knowledge graph - OPTIMIZED.""" + # Optimized: Reduced property fetching, streamlined UNWIND logic _DB_SIMILARITY_SEARCH_QUERY = """ - MATCH (node:Paper) - WHERE ($day IS NULL OR node.session_start_time IS NOT NULL) - WITH node - WHERE ($day IS NULL OR date(datetime(node.session_start_time)).dayOfWeek = $day) - AND ($time_ranges IS NULL OR - any(range IN $time_ranges WHERE - time(datetime(node.session_start_time)) >= time(range.start) - AND time(datetime(node.session_start_time)) <= time(range.end))) - WITH collect(node) as filtered_nodes CALL db.index.vector.queryNodes('paper_embeddings', $top_k, $query_embedding) YIELD node, score - WHERE node IN filtered_nodes OR ($day IS NULL AND $time_ranges IS NULL) - RETURN node.id as id, - node.name as name, - node.abstract as abstract, - node.topic as topic, - node.paper_url as paper_url, - node.session as session, - node.session_start_time as session_start_time, - node.session_end_time as session_end_time, - node.presentation_type as presentation_type, - node.room_name as room_name, - node.project_url as project_url, - node.poster_position as poster_position, - node.sourceid as sourceid, - node.virtualsite_url as virtualsite_url, - node.decision as decision, - [(a:Author)-[:IS_AUTHOR_OF]->(node) | a] as authors, + WHERE ($day IS NULL OR node.session_start_time IS NOT NULL) + AND ($day IS NULL OR 
date(datetime(node.session_start_time)).dayOfWeek = $day)
+          AND ($time_ranges IS NULL OR
+               any(range IN $time_ranges WHERE
+                   time(datetime(node.session_start_time)) >= time(range.start)
+                   AND time(datetime(node.session_start_time)) <= time(range.end)))
+
+        // Deduplicate matched nodes and keep highest score
+        WITH node, max(score) as score
+        ORDER BY score DESC
+        LIMIT $top_k
+
+        // Fetch pair only once per matched paper
+        OPTIONAL MATCH (node)-[:ORAL_POSTER_PAIR]-(pair:Paper)
+        WHERE elementId(node) < elementId(pair)
+
+        // Collect unique papers
+        WITH node, pair, score
+        UNWIND (CASE WHEN pair IS NULL THEN [node] ELSE [node, pair] END) as paper
+
+        // Deduplicate by paper.id to ensure uniqueness
+        WITH paper, max(score) as score
+
+        // Get authors ordered by author_order
+        OPTIONAL MATCH (a:Author)-[r:IS_AUTHOR_OF]->(paper)
+        WITH paper, score, a, r
+        ORDER BY r.author_order
+        WITH paper, score, collect(a.fullname) as authors
+
+        RETURN paper.id as id,
+               paper.name as name,
+               paper.abstract as abstract,
+               paper.topic as topic,
+               paper.paper_url as paper_url,
+               paper.session as session,
+               paper.session_start_time as session_start_time,
+               paper.session_end_time as session_end_time,
+               paper.presentation_type as presentation_type,
+               paper.presentation_category as presentation_category,
+               paper.room_name as room_name,
+               paper.project_url as project_url,
+               paper.poster_position as poster_position,
+               paper.sourceid as sourceid,
+               paper.virtualsite_url as virtualsite_url,
+               paper.decisions as decisions,
+               authors,
                score
         ORDER BY score DESC
         LIMIT $limit
     """

+    # Optimized: More efficient author list comprehension
     _DB_NEIGHBORHOOD_SEARCH_QUERY = """
-        MATCH (p:Paper)-[r]-(neighbor)
-        WHERE p.id IN $paper_ids
-        AND type(r) IN $allowed_rel_types
-        AND 'Paper' IN labels(neighbor)
+        MATCH (p:Paper {id: $paper_id})-[r]-(neighbor:Paper)
+        WHERE type(r) IN $allowed_rel_types
         AND (type(r) <> 'SIMILAR_TO' OR r.similarity >= $min_similarity)
-        RETURN neighbor.id as id,
-               neighbor.name as name,
-               neighbor.abstract as abstract,
-               neighbor.topic as topic,
-               neighbor.paper_url as paper_url,
-               neighbor.session as session,
-               neighbor.session_start_time as session_start_time,
-               neighbor.session_end_time as session_end_time,
-               neighbor.presentation_type as presentation_type,
-               neighbor.room_name as room_name,
-               neighbor.project_url as project_url,
-               neighbor.poster_position as poster_position,
-               neighbor.sourceid as sourceid,
-               neighbor.virtualsite_url as virtualsite_url,
-               neighbor.decision as decision,
-               [(a:Author)-[:IS_AUTHOR_OF]->(neighbor) | a] as authors,
+
+        // Deduplicate neighbors (same neighbor might be found via different relationship types)
+        WITH neighbor, p, max(CASE WHEN type(r) = 'SIMILAR_TO' THEN r.similarity ELSE 0 END) as similarity,
+             collect(DISTINCT type(r)) as rel_types
+
+        // For simplicity, use the first relationship type if multiple exist
+        WITH neighbor, p, similarity, rel_types[0] as relationship_type
+
+        // Fetch pair only once per neighbor
+        OPTIONAL MATCH (neighbor)-[:ORAL_POSTER_PAIR]-(pair:Paper)
+        WHERE elementId(neighbor) < elementId(pair) // Only expand from one direction
+
+        WITH neighbor, pair, p, relationship_type, similarity
+        UNWIND CASE WHEN pair IS NULL THEN [neighbor] ELSE [neighbor, pair] END as result_paper
+
+        RETURN result_paper.id as id,
+               result_paper.name as name,
+               result_paper.abstract as abstract,
+               result_paper.topic as topic,
+               result_paper.paper_url as paper_url,
+               result_paper.session as session,
+               result_paper.session_start_time as session_start_time,
+
result_paper.session_end_time as session_end_time, + result_paper.presentation_type as presentation_type, + result_paper.presentation_category as presentation_category, + result_paper.room_name as room_name, + result_paper.project_url as project_url, + result_paper.poster_position as poster_position, + result_paper.sourceid as sourceid, + result_paper.virtualsite_url as virtualsite_url, + result_paper.decisions as decisions, + [(a:Author)-[:IS_AUTHOR_OF]->(result_paper) | a.fullname] as authors, p.id as source_paper_id, - type(r) as relationship_type, - CASE WHEN type(r) = 'SIMILAR_TO' THEN r.similarity ELSE null END as similarity + relationship_type, + CASE WHEN relationship_type = 'SIMILAR_TO' THEN similarity ELSE null END as similarity ORDER BY similarity DESC LIMIT $limit """ - # Find the DB query for graph traversal in the graph_traversal sub-folder. _DB_PAPERS_BY_AUTHOR = """ - MATCH (a:Author)-[:IS_AUTHOR_OF]->(p:Paper) - WHERE a.fullname = $author_name - WITH p, collect(DISTINCT a) as all_authors - RETURN p.id as id, - p.name as name, - p.abstract as abstract, - p.topic as topic, - p.paper_url as paper_url, - p.decision as decision, - p.session as session, - p.session_start_time as session_start_time, - p.session_end_time as session_end_time, - p.presentation_type as presentation_type, - p.room_name as room_name, - p.project_url as project_url, - p.poster_position as poster_position, - p.sourceid as sourceid, - p.virtualsite_url as virtualsite_url, - all_authors as authors - ORDER BY p.name + MATCH (a:Author {fullname: $author_name})-[:IS_AUTHOR_OF]->(p:Paper) + + // Collect papers first to prevent duplicates + WITH collect(DISTINCT p) as papers + UNWIND papers as p + + OPTIONAL MATCH (p)-[:ORAL_POSTER_PAIR]-(pair:Paper) + WHERE elementId(p) < elementId(pair) // Only expand from one direction + + WITH p, pair + UNWIND CASE WHEN pair IS NULL THEN [p] ELSE [p, pair] END as paper + + RETURN paper.id as id, + paper.name as name, + paper.abstract as abstract, + paper.topic as topic, + paper.paper_url as paper_url, + paper.decisions as decisions, + paper.session as session, + paper.session_start_time as session_start_time, + paper.session_end_time as session_end_time, + paper.presentation_type as presentation_type, + paper.presentation_category as presentation_category, + paper.room_name as room_name, + paper.project_url as project_url, + paper.poster_position as poster_position, + paper.sourceid as sourceid, + paper.virtualsite_url as virtualsite_url, + [(a:Author)-[:IS_AUTHOR_OF]->(paper) | a.fullname] as authors + ORDER BY paper.name LIMIT $limit """ _DB_PAPERS_BY_AUTHOR_FUZZY = """ MATCH (a:Author)-[:IS_AUTHOR_OF]->(p:Paper) WHERE toLower(a.fullname) CONTAINS toLower($author_name) - WITH p, collect(DISTINCT a) as all_authors - RETURN p.id as id, - p.name as name, - p.abstract as abstract, - p.topic as topic, - p.paper_url as paper_url, - p.decision as decision, - p.session as session, - p.session_start_time as session_start_time, - p.session_end_time as session_end_time, - p.presentation_type as presentation_type, - p.room_name as room_name, - p.project_url as project_url, - p.poster_position as poster_position, - p.sourceid as sourceid, - p.virtualsite_url as virtualsite_url, - all_authors as authors - ORDER BY p.name + + // Collect papers first to prevent duplicates + WITH collect(DISTINCT p) as papers + UNWIND papers as p + + OPTIONAL MATCH (p)-[:ORAL_POSTER_PAIR]-(pair:Paper) + WHERE elementId(p) < elementId(pair) // Only expand from one direction + + WITH p, pair + UNWIND CASE 
WHEN pair IS NULL THEN [p] ELSE [p, pair] END as paper + + RETURN paper.id as id, + paper.name as name, + paper.abstract as abstract, + paper.topic as topic, + paper.paper_url as paper_url, + paper.decisions as decisions, + paper.session as session, + paper.session_start_time as session_start_time, + paper.session_end_time as session_end_time, + paper.presentation_type as presentation_type, + paper.presentation_category as presentation_category, + paper.room_name as room_name, + paper.project_url as project_url, + paper.poster_position as poster_position, + paper.sourceid as sourceid, + paper.virtualsite_url as virtualsite_url, + [(a:Author)-[:IS_AUTHOR_OF]->(paper) | a.fullname] as authors + ORDER BY paper.name LIMIT $limit """ _DB_PAPERS_BY_TOPIC = """ - MATCH (p:Paper)-[:BELONGS_TO_TOPIC]->(t:Topic {name: $topic_name}) - RETURN p.id as id, - p.name as name, - p.abstract as abstract, - p.topic as topic, - p.paper_url as paper_url, - p.decision as decision, - p.session as session, - p.session_start_time as session_start_time, - p.session_end_time as session_end_time, - p.presentation_type as presentation_type, - p.room_name as room_name, - p.project_url as project_url, - p.poster_position as poster_position, - p.sourceid as sourceid, - p.virtualsite_url as virtualsite_url, - [(a:Author)-[:IS_AUTHOR_OF]->(p) | a] as authors - ORDER BY p.name + MATCH (t:Topic {name: $topic_name})<-[:BELONGS_TO_TOPIC]-(p:Paper) + + // Collect papers first to prevent duplicates + WITH collect(DISTINCT p) as papers + UNWIND papers as p + + OPTIONAL MATCH (p)-[:ORAL_POSTER_PAIR]-(pair:Paper) + WHERE elementId(p) < elementId(pair) // Only expand from one direction + + WITH p, pair + UNWIND CASE WHEN pair IS NULL THEN [p] ELSE [p, pair] END as paper + + RETURN paper.id as id, + paper.name as name, + paper.abstract as abstract, + paper.topic as topic, + paper.paper_url as paper_url, + paper.decisions as decisions, + paper.session as session, + paper.session_start_time as session_start_time, + paper.session_end_time as session_end_time, + paper.presentation_type as presentation_type, + paper.presentation_category as presentation_category, + paper.room_name as room_name, + paper.project_url as project_url, + paper.poster_position as poster_position, + paper.sourceid as sourceid, + paper.virtualsite_url as virtualsite_url, + [(a:Author)-[:IS_AUTHOR_OF]->(paper) | a.fullname] as authors + ORDER BY paper.name LIMIT $limit """ _DB_PAPERS_BY_TOPIC_AND_SUBTOPIC = """ MATCH (t:Topic {name: $topic_name}) OPTIONAL MATCH (subtopic:Topic)-[:SUBTOPIC_OF*]->(t) - WITH t, collect(DISTINCT subtopic) + t as all_topics + WITH collect(DISTINCT subtopic) + t as all_topics UNWIND all_topics as topic - MATCH (p:Paper)-[:BELONGS_TO_TOPIC]->(topic) - WITH DISTINCT p - RETURN p.id as id, - p.name as name, - p.abstract as abstract, - p.topic as topic, - p.paper_url as paper_url, - p.decision as decision, - p.session as session, - p.session_start_time as session_start_time, - p.session_end_time as session_end_time, - p.presentation_type as presentation_type, - p.room_name as room_name, - p.project_url as project_url, - p.poster_position as poster_position, - p.sourceid as sourceid, - p.virtualsite_url as virtualsite_url, - [(a:Author)-[:IS_AUTHOR_OF]->(p) | a] as authors - ORDER BY p.name + MATCH (topic)<-[:BELONGS_TO_TOPIC]-(p:Paper) + + // Collect papers to prevent duplicates from multiple topic paths + WITH collect(DISTINCT p) as papers + UNWIND papers as p + + OPTIONAL MATCH (p)-[:ORAL_POSTER_PAIR]-(pair:Paper) + WHERE elementId(p) < 
elementId(pair) // Only expand from one direction + + WITH p, pair + UNWIND CASE WHEN pair IS NULL THEN [p] ELSE [p, pair] END as paper + + RETURN paper.id as id, + paper.name as name, + paper.abstract as abstract, + paper.topic as topic, + paper.paper_url as paper_url, + paper.decisions as decisions, + paper.session as session, + paper.session_start_time as session_start_time, + paper.session_end_time as session_end_time, + paper.presentation_type as presentation_type, + paper.presentation_category as presentation_category, + paper.room_name as room_name, + paper.project_url as project_url, + paper.poster_position as poster_position, + paper.sourceid as sourceid, + paper.virtualsite_url as virtualsite_url, + [(a:Author)-[:IS_AUTHOR_OF]->(paper) | a.fullname] as authors + ORDER BY paper.name LIMIT $limit """ @@ -194,10 +273,29 @@ def __init__( self, uri: str = NEO4J_DB_URI, username: str = "neo4j", - password: str = "password" + password: str = "password", + max_connection_lifetime: int = 3600, + max_connection_pool_size: int = 50, + connection_acquisition_timeout: int = 60 ): - """Initialize Neo4j connection.""" - self.driver = GraphDatabase.driver(uri, auth=(username, password)) + """ + Initialize Neo4j connection with optimized settings. + + Args: + uri: Neo4j connection URI + username: Database username + password: Database password + max_connection_lifetime: Max lifetime of connections in seconds (default: 3600) + max_connection_pool_size: Max number of connections in pool (default: 50) + connection_acquisition_timeout: Timeout for acquiring connection (default: 60s) + """ + self.driver = GraphDatabase.driver( + uri, + auth=(username, password), + max_connection_lifetime=max_connection_lifetime, + max_connection_pool_size=max_connection_pool_size, + connection_acquisition_timeout=connection_acquisition_timeout + ) self.driver.verify_connectivity() LOGGER.info(f"Connected to Neo4j at {uri}") @@ -205,20 +303,131 @@ def close(self): """Close the Neo4j driver connection.""" self.driver.close() + @staticmethod + def _link_oral_poster_pairs(papers: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Link Oral and Poster pairs by adding cross-references. + OPTIMIZED: Single pass with early exits. 
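+        Convention: an oral is stored with a negative `sourceid`, and its poster
+        twin carries the same absolute `sourceid`; the pairing below matches on that value.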
+
+        Args:
+            papers: List of paper dictionaries from database query
+
+        Returns:
+            List of papers with Oral-Poster pairs linked via new fields
+        """
+        if not papers:
+            return papers
+
+        # Build lookup dictionaries by sourceid in a single pass
+        oral_map = {}    # abs(sourceid) -> paper record
+        poster_map = {}  # sourceid -> paper record
+
+        for paper in papers:
+            sourceid = paper.get('sourceid')
+            if sourceid is None:
+                continue
+
+            if sourceid < 0:
+                oral_map[abs(sourceid)] = paper
+            else:
+                poster_map[sourceid] = paper
+
+        # Add cross-references - only for papers that have pairs
+        for abs_sourceid, oral in oral_map.items():
+            poster = poster_map.get(abs_sourceid)
+            if not poster:
+                continue
+
+            # Link them together
+            oral['has_poster'] = True
+            oral['poster_id'] = poster['id']
+            oral['poster_session'] = poster.get('session')
+            oral['poster_session_start_time'] = poster.get('session_start_time')
+            oral['poster_session_end_time'] = poster.get('session_end_time')
+            oral['poster_room_name'] = poster.get('room_name')
+            oral['poster_position'] = poster.get('poster_position')
+
+            # Replace Oral's paper_url with Poster's paper_url (OpenReview link)
+            if poster.get('paper_url'):
+                oral['paper_url'] = poster['paper_url']
+
+            poster['has_oral'] = True
+            poster['oral_id'] = oral['id']
+            poster['oral_session'] = oral.get('session')
+            poster['oral_session_start_time'] = oral.get('session_start_time')
+            poster['oral_session_end_time'] = oral.get('session_end_time')
+            poster['oral_room_name'] = oral.get('room_name')
+
+        return papers
+
     @staticmethod
     def embed_user_query(
-        text: str,
-        embedding_model: str = EMBEDDING_MODEL_NAME,
-        api_base: str = EMBEDDING_MODEL_API_BASE
-    ):
+            text: str,
+            embedding_model: str = EMBEDDING_MODEL_NAME,
+            api_base: str = EMBEDDING_MODEL_API_BASE
+    ) -> List[float]:
+        """Generate embedding for user query. Returns a list of floats."""
         emb = batch_embed_documents(
             texts=[text],
             batch_size=1,
             api_base=api_base,
             embedding_model=embedding_model
-        ).tolist()[0]
+        )
+
+        # Convert to list if numpy array
+        if isinstance(emb, np.ndarray):
+            return emb.tolist()[0]
+        return emb[0]

-        return emb
+
+    @staticmethod
+    def _parse_day_filter(day: Optional[str]) -> Optional[int]:
+        """Parse day string to day of week integer."""
+        if not day:
+            return None
+        from datetime import datetime
+        date_obj = datetime.strptime(day, "%Y-%m-%d")
+        return date_obj.isoweekday()
+
+    @staticmethod
+    def _parse_timeslots(timeslots: Optional[List[str]]) -> Optional[List[Dict[str, str]]]:
+        """Parse timeslot strings to time ranges."""
+        if not timeslots:
+            return None
+
+        time_ranges = []
+        for slot in timeslots:
+            if '-' in slot:
+                start, end = slot.split('-', 1)
+                time_ranges.append({'start': start.strip(), 'end': end.strip()})
+            else:
+                time_ranges.append({'start': slot.strip(), 'end': slot.strip()})
+        return time_ranges
+
+    @staticmethod
+    def _build_paper_dict(record) -> Dict[str, Any]:
+        """
+        Build paper dictionary from Neo4j record.
+        OPTIMIZED: Centralized dict construction.
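+        Note: the record's `project_url` property is exposed under the `github_url` key
+        in the returned dictionary.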
+ """ + return { + 'id': record['id'], + 'name': record['name'], + 'abstract': record['abstract'], + 'topic': record['topic'], + 'paper_url': record['paper_url'], + 'decisions': record['decisions'], + 'session': record['session'], + 'session_start_time': record['session_start_time'], + 'session_end_time': record['session_end_time'], + 'presentation_type': record['presentation_type'], + 'presentation_category': record['presentation_category'], + 'room_name': record['room_name'], + 'github_url': record['project_url'], + 'poster_position': record['poster_position'], + 'sourceid': record['sourceid'], + 'virtualsite_url': record['virtualsite_url'], + 'authors': record['authors'] # Already processed in Cypher + } def similarity_search( self, @@ -241,35 +450,12 @@ def similarity_search( Returns: List of dictionaries containing paper information and similarity scores """ + # Generate embedding + query_embedding = self.embed_user_query(user_query) - # Generate text embedding - query_embedding = self.embed_user_query( - text=user_query - ) - - # Convert numpy array to list if needed - if isinstance(query_embedding, np.ndarray): - query_embedding = query_embedding.tolist() - - # Parse day and timeslots for the query - day_filter = None - time_ranges = [] - - if day: - # Convert date string to day of week (1=Monday, 7=Sunday) - from datetime import datetime - date_obj = datetime.strptime(day, "%Y-%m-%d") - day_filter = date_obj.isoweekday() - - if timeslots: - # Parse timeslot ranges (e.g., "09:00:00-12:00:00") - for slot in timeslots: - if '-' in slot: - start, end = slot.split('-') - time_ranges.append({'start': start.strip(), 'end': end.strip()}) - else: - # If no range, assume it's a single time point with some buffer - time_ranges.append({'start': slot.strip(), 'end': slot.strip()}) + # Parse filters + day_filter = self._parse_day_filter(day) + time_ranges = self._parse_timeslots(timeslots) with self.driver.session() as session: result = session.run( @@ -278,37 +464,22 @@ def similarity_search( top_k=top_k, limit=NEO4J_DB_NODE_RETURN_LIMIT, day=day_filter, - time_ranges=time_ranges if time_ranges else None + time_ranges=time_ranges ) + papers = [] for record in result: - paper = { - 'id': record['id'], - 'name': record['name'], - 'abstract': record['abstract'], - 'topic': record['topic'], - 'similarity_score': record['score'], - 'paper_url': record['paper_url'], - 'decision': record['decision'], - 'session': record['session'], - 'session_start_time': record['session_start_time'], - 'session_end_time': record['session_end_time'], - 'presentation_type': record['presentation_type'], - 'room_name': record['room_name'], - 'github_url': record['project_url'], - 'poster_position': record['poster_position'], - 'sourceid': record['sourceid'], - 'virtualsite_url': record['virtualsite_url'], - 'authors': [a['fullname'] for a in record['authors']] - } + score = record['score'] - # Apply minimum similarity filter if specified - if min_similarity is None or paper['similarity_score'] >= min_similarity: - # IMPORTANT: We don't return the similarity as the model has high affinity to scores like that... 
- del paper["similarity_score"] - papers.append(paper) + # Apply minimum similarity filter early + if min_similarity is not None and score < min_similarity: + continue - return papers + paper = self._build_paper_dict(record) + papers.append(paper) + + # Link oral-poster pairs + return self._link_oral_poster_pairs(papers) def neighborhood_search( self, @@ -321,23 +492,26 @@ def neighborhood_search( Args: paper_id: Paper ID to find neighbors for - relationship_types: Optional list of relationship types to filter - (e.g., ['SIMILAR_TO', 'IS_AUTHOR_OF', 'BELONGS_TO_TOPIC', 'SUBTOPIC_OF']) - min_similarity (float): A minimum similarity score in the range of 0 - 1. Often a good value is 0.75 or 0.8. - + relationship_types: List of relationship types to filter + min_similarity: Minimum similarity score (0-1) Returns: Dictionary with neighbors grouped by relationship type """ - allowed_rel_types = ['SIMILAR_TO', 'IS_AUTHOR_OF', 'BELONGS_TO_TOPIC', 'SUBTOPIC_OF'] - for rel_type in relationship_types: - if rel_type not in allowed_rel_types: - raise ValueError(f"Unsupported relationship type: {rel_type}. Supported relationship types: {allowed_rel_types}") + allowed_rel_types = ['SIMILAR_TO', 'IS_AUTHOR_OF', 'BELONGS_TO_TOPIC', 'SUBTOPIC_OF', 'ORAL_POSTER_PAIR'] + + # Validate relationship types + invalid_types = set(relationship_types) - set(allowed_rel_types) + if invalid_types: + raise ValueError( + f"Unsupported relationship type(s): {invalid_types}. " + f"Supported types: {allowed_rel_types}" + ) with self.driver.session() as session: result = session.run( self._DB_NEIGHBORHOOD_SEARCH_QUERY, - paper_ids=[paper_id], + paper_id=paper_id, allowed_rel_types=relationship_types, min_similarity=min_similarity, limit=NEO4J_DB_NODE_RETURN_LIMIT @@ -345,31 +519,30 @@ def neighborhood_search( # Organize results by relationship type neighbors = {} - for record in result: - # Use the dict() object in Record to manipulate the data. Records are immutable. - record = record.data() rel_type = record["relationship_type"] - if rel_type not in neighbors.keys(): + + if rel_type not in neighbors: neighbors[rel_type] = [] - else: - if "similarity" in record.keys(): - # IMPORTANT: We don't return the similarity as the model has high affinity to scores like that... - del record["similarity"] - neighbors[rel_type].append(record) + paper = self._build_paper_dict(record) + neighbors[rel_type].append(paper) + + # Link oral-poster pairs in each relationship type + for rel_type in neighbors: + neighbors[rel_type] = self._link_oral_poster_pairs(neighbors[rel_type]) return neighbors def graph_traversal( - self, - start_paper_id: str, - n_hops: int = 2, - relationship_type: Optional[str] = None, - max_results: Optional[int] = None, - strategy: str = "breadth_first_random", - max_branches: Optional[int] = None, - random_seed: Optional[int] = None + self, + start_paper_id: str, + n_hops: int = 2, + relationship_type: Optional[str] = None, + max_results: Optional[int] = None, + strategy: str = "breadth_first_random", + max_branches: Optional[int] = None, + random_seed: Optional[int] = None ) -> List[Dict[str, Any]]: """ Traverse the graph for n hops from starting paper nodes. 
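For readers skimming this diff: the queries and helpers above all funnel their results through `_link_oral_poster_pairs`, which relies on the `sourceid` sign convention (negative for orals, positive for posters, paired on the absolute value). Below is a minimal, self-contained sketch of that convention only; it is not part of the change, and the function name `link_pairs`, the sample records, and the subset of copied fields are illustrative assumptions.

```python
from typing import Any, Dict, List


def link_pairs(papers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Cross-reference oral/poster twins, mirroring the idea of _link_oral_poster_pairs."""
    orals: Dict[int, Dict[str, Any]] = {}
    posters: Dict[int, Dict[str, Any]] = {}
    for paper in papers:
        sourceid = paper.get("sourceid")
        if sourceid is None:
            continue
        # Negative sourceid marks an oral; its poster twin has the positive value.
        (orals if sourceid < 0 else posters)[abs(sourceid)] = paper

    for key, oral in orals.items():
        poster = posters.get(key)
        if poster is None:
            continue
        # Copy a few fields in each direction so either record is self-sufficient.
        oral.update(has_poster=True, poster_id=poster["id"],
                    poster_position=poster.get("poster_position"))
        poster.update(has_oral=True, oral_id=oral["id"],
                      oral_session=oral.get("session"))
    return papers


papers = [
    {"id": "oral-42", "sourceid": -42, "session": "Oral 2A"},
    {"id": "poster-42", "sourceid": 42, "poster_position": "#142"},
]
linked = link_pairs(papers)
assert linked[0]["poster_id"] == "poster-42"
assert linked[1]["oral_id"] == "oral-42"
```

In the retriever itself, the Cypher `UNWIND` already returns both members of a pair; this Python post-processing step then annotates each record with its twin's session details.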
@@ -377,10 +550,10 @@ def graph_traversal( Args: start_paper_id: Paper ID to start traversal from n_hops: Number of hops to traverse (1-5 recommended) - relationship_type: Optional list of relationship types to traverse + relationship_type: Optional relationship type to traverse max_results: Optional maximum number of results to return strategy: Traversal strategy (breadth_first, depth_first, breadth_first_random, depth_first_random) - max_branches: Maximum number of random neighbors to explore per node (only for random strategies) + max_branches: Maximum number of random neighbors per node (only for random strategies) random_seed: Optional seed for reproducible random sampling Returns: @@ -389,21 +562,19 @@ def graph_traversal( if random_seed is not None: random.seed(random_seed) - # Use original Cypher-based approach for non-random strategies + # Use Cypher-based approach for non-random strategies if strategy in ["breadth_first", "depth_first"]: - LOGGER.debug(f"Doing a graph traversal with neo4j's built-in strategy") - return _graph_traversal_cypher( + LOGGER.debug("Using Cypher-based traversal strategy") + papers = _graph_traversal_cypher( self.driver, start_paper_id, n_hops, relationship_type, max_results ) - - # Use Python-based traversal for random strategies elif strategy == "breadth_first_random": - LOGGER.debug(f"Doing a graph traversal with a random sampling breadth first strategy") - return _graph_traversal_bfs_random( + LOGGER.debug("Using BFS random sampling strategy") + papers = _graph_traversal_bfs_random( self.driver, start_paper_id, n_hops, @@ -411,10 +582,9 @@ def graph_traversal( max_results, max_branches or 3 ) - elif strategy == "depth_first_random": - LOGGER.debug(f"Doing a graph traversal with a random sampling depth first strategy") - return _graph_traversal_dfs_random( + LOGGER.debug("Using DFS random sampling strategy") + papers = _graph_traversal_dfs_random( self.driver, start_paper_id, n_hops, @@ -422,10 +592,14 @@ def graph_traversal( max_results, max_branches or 3 ) - else: - raise ValueError(f"Unsupported traversal strategy: {strategy}. " - f"Supported strategies: breadth_first, depth_first, breadth_first_random, depth_first_random") + raise ValueError( + f"Unsupported traversal strategy: {strategy}. 
" + f"Supported: breadth_first, depth_first, breadth_first_random, depth_first_random" + ) + + # Link oral-poster pairs + return self._link_oral_poster_pairs(papers) def search_papers_by_author( self, @@ -442,37 +616,17 @@ def search_papers_by_author( Returns: List of papers by the author """ - with self.driver.session() as session: - if fuzzy: - query = self._DB_PAPERS_BY_AUTHOR_FUZZY - else: - query = self._DB_PAPERS_BY_AUTHOR - - result = session.run(query, author_name=author_name) + query = self._DB_PAPERS_BY_AUTHOR_FUZZY if fuzzy else self._DB_PAPERS_BY_AUTHOR - papers = [] - for record in result: - paper = { - 'id': record['id'], - 'name': record['name'], - 'abstract': record['abstract'], - 'topic': record['topic'], - 'author_name': record['author_name'], - 'paper_url': record['paper_url'], - 'decision': record['decision'], - 'session': record['session'], - 'session_start_time': record['session_start_time'], - 'session_end_time': record['session_end_time'], - 'presentation_type': record['presentation_type'], - 'room_name': record['room_name'], - 'github_url': record['project_url'], - 'poster_position': record['poster_position'], - 'sourceid': record['sourceid'], - 'virtualsite_url': record['virtualsite_url'], - } - papers.append(paper) + with self.driver.session() as session: + result = session.run( + query, + author_name=author_name, + limit=NEO4J_DB_NODE_RETURN_LIMIT + ) - return papers + papers = [self._build_paper_dict(record) for record in result] + return self._link_oral_poster_pairs(papers) def search_papers_by_topic( self, @@ -489,37 +643,18 @@ def search_papers_by_topic( Returns: List of papers in the topic """ - with self.driver.session() as session: - if include_subtopics: - # Find topic and all its subtopics - query = self._DB_PAPERS_BY_TOPIC_AND_SUBTOPIC - else: - query = self._DB_PAPERS_BY_TOPIC - - result = session.run(query, topic_name=topic_name, limit=NEO4J_DB_NODE_RETURN_LIMIT) + query = (self._DB_PAPERS_BY_TOPIC_AND_SUBTOPIC if include_subtopics + else self._DB_PAPERS_BY_TOPIC) - papers = [] - for record in result: - paper = { - 'id': record['id'], - 'name': record['name'], - 'abstract': record['abstract'], - 'topic': record['topic'], - 'paper_url': record['paper_url'], - 'decision': record['decision'], - 'session': record['session'], - 'session_start_time': record['session_start_time'], - 'session_end_time': record['session_end_time'], - 'presentation_type': record['presentation_type'], - 'room_name': record['room_name'], - 'github_url': record['project_url'], - 'poster_position': record['poster_position'], - 'sourceid': record['sourceid'], - 'virtualsite_url': record['virtualsite_url'], - } - papers.append(paper) + with self.driver.session() as session: + result = session.run( + query, + topic_name=topic_name, + limit=NEO4J_DB_NODE_RETURN_LIMIT + ) - return papers + papers = [self._build_paper_dict(record) for record in result] + return self._link_oral_poster_pairs(papers) def get_collaboration_network( self, @@ -537,32 +672,31 @@ def get_collaboration_network( Dictionary with collaborators and shared papers """ with self.driver.session() as session: - query = f""" - MATCH (a1:Author) + query = """ + MATCH (a1:Author)-[:IS_AUTHOR_OF]->(p:Paper)<-[:IS_AUTHOR_OF]-(a2:Author) WHERE toLower(a1.fullname) CONTAINS toLower($author_name) - MATCH path = (a1)<-[:AUTHORED_BY]-(p:Paper)-[:AUTHORED_BY]->(a2:Author) - WHERE a1 <> a2 - WITH a1, a2, collect(DISTINCT p) as shared_papers, length(path) as distance + AND a1 <> a2 + WITH a1, a2, collect(DISTINCT p) as shared_papers 
RETURN a1.fullname as source_author, a2.fullname as collaborator, a2.institution as institution, - [p IN shared_papers | {{id: p.id, name: p.name}}] as papers, + [p IN shared_papers | {id: p.id, name: p.name}] as papers, size(shared_papers) as paper_count ORDER BY paper_count DESC """ result = session.run(query, author_name=author_name) - collaborations = [] - for record in result: - collab = { + collaborations = [ + { 'source_author': record['source_author'], 'collaborator': record['collaborator'], 'institution': record['institution'], 'shared_papers': record['papers'], 'paper_count': record['paper_count'] } - collaborations.append(collab) + for record in result + ] return { 'author': author_name, @@ -573,8 +707,7 @@ def get_collaboration_network( # Test if __name__ == "__main__": - # Initialize searcher - searcher = Neo4jGraphWorker( + worker = Neo4jGraphWorker( uri=NEO4J_DB_URI, username=os.environ.get("NEO4J_USERNAME", "neo4j"), password=os.environ.get("NEO4J_PASSWORD") @@ -585,12 +718,12 @@ def get_collaboration_network( print("\n" + "=" * 60) print("Example 1: Similarity Search") print("=" * 60) - user_query = "Reinforcement learning" - similar_papers = searcher.similarity_search(user_query, top_k=30) + user_query = "Synthetic humans and cameras in motion" + similar_papers = worker.similarity_search(user_query, top_k=30) for i, paper in enumerate(similar_papers, 1): print(f"\n{i}. {paper['name']}") print(f" Topic: {paper['topic']}") - # print(f" Similarity: {paper['similarity_score']:.4f}") + print(f" Presentation: {paper['presentation_type']}") # Example 2: Neighborhood search if similar_papers: @@ -598,23 +731,23 @@ def get_collaboration_network( print("Example 2: Neighborhood Search") print("=" * 60) paper_id = similar_papers[0]['id'] - neighbors = searcher.neighborhood_search(paper_id, min_similarity=0.75) + neighbors = worker.neighborhood_search(paper_id, min_similarity=0.75) print(f"\nNeighbors of: {similar_papers[0]['name']}") - for rel_type, neighbors in neighbors.items(): - print(f" \n{rel_type.upper()} RELATIONSHIPS:") - for neighbor in neighbors: - print(f" - {neighbor['name']}") # (similarity: {neighbor['similarity']:.4f}) + for rel_type, neighbor_list in neighbors.items(): + print(f"\n{rel_type.upper()} RELATIONSHIPS:") + for neighbor in neighbor_list: + print(f" - {neighbor['name']}") # Example 3: Graph traversal print("\n" + "=" * 60) print("Example 3: Graph Traversal (2 hops)") print("=" * 60) if similar_papers: - paper_ids = similar_papers[0]['id'] - related = searcher.graph_traversal(paper_ids, n_hops=2) + paper_id = similar_papers[0]['id'] + related = worker.graph_traversal(paper_id, n_hops=2) print(f"\nFound {len(related)} related papers through traversal") - for paper in related[:5]: # Show first 5 + for paper in related[:5]: print(f" - {paper['name']} (distance: {paper['distance']})") finally: - searcher.close() + worker.close() diff --git a/agentic_nav/utils/logger.py b/agentic_nav/utils/logger.py index 443bd26..f1c7896 100644 --- a/agentic_nav/utils/logger.py +++ b/agentic_nav/utils/logger.py @@ -34,16 +34,16 @@ def setup_logging(log_dir: str = "logs", level: str = "INFO", console_level: str # File handler - for production time_now = datetime.now().strftime("%Y-%m-%d_%H-%M") - file_handler = logging.handlers.RotatingFileHandler( - f"{log_dir}/{time_now}_llm_agents.log", - maxBytes=10 * 1024 * 1024, # 10MB - backupCount=5 - ) - file_handler.setLevel(logging.DEBUG) - file_format = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - 
%(funcName)s:%(lineno)d - %(message)s" - ) - file_handler.setFormatter(file_format) + # file_handler = logging.handlers.RotatingFileHandler( + # f"{log_dir}/{time_now}_llm_agents.log", + # maxBytes=10 * 1024 * 1024, # 10MB + # backupCount=5 + # ) + # file_handler.setLevel(logging.DEBUG) + # file_format = logging.Formatter( + # "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s" + # ) + # file_handler.setFormatter(file_format) root_logger.addHandler(console_handler) - root_logger.addHandler(file_handler) + # root_logger.addHandler(file_handler) diff --git a/tests/tools/test_file_handler.py b/tests/tools/test_file_handler.py index d297048..f9b770a 100644 --- a/tests/tools/test_file_handler.py +++ b/tests/tools/test_file_handler.py @@ -12,7 +12,7 @@ class TestSaveGraph: """Test the save_graph function.""" - def test_save_graph_basic(self, capsys): + def test_save_graph_basic(self): """Test basic graph saving functionality.""" # Create a simple graph graph = nx.Graph() @@ -30,10 +30,6 @@ def test_save_graph_basic(self, capsys): assert output_path.exists() assert output_path.is_file() - # Verify output message - captured = capsys.readouterr() - assert f"Graph saved to {output_path}" in captured.out - def test_save_graph_with_complex_attributes(self): """Test saving graph with complex node and edge attributes.""" graph = nx.Graph() @@ -72,7 +68,7 @@ def test_save_graph_empty_graph(self): class TestLoadGraph: """Test the load_graph function.""" - def test_load_graph_basic(self, capsys): + def test_load_graph_basic(self): """Test basic graph loading functionality.""" # Create and save a graph original_graph = nx.Graph() @@ -97,10 +93,6 @@ def test_load_graph_basic(self, capsys): assert loaded_graph.has_edge("paper1", "paper2") assert loaded_graph["paper1"]["paper2"]["weight"] == 0.85 - # Verify output message - captured = capsys.readouterr() - assert f"Graph loaded from {file_path}" in captured.out - def test_load_graph_with_complex_attributes(self): """Test loading graph with complex attributes.""" original_graph = nx.Graph() diff --git a/tests/tools/test_neo4j_graph_worker.py b/tests/tools/test_neo4j_graph_worker.py index b1d6fd3..dca47d3 100644 --- a/tests/tools/test_neo4j_graph_worker.py +++ b/tests/tools/test_neo4j_graph_worker.py @@ -40,43 +40,59 @@ def test_initialization(self): worker = Neo4jGraphWorker( uri="bolt://test:7687", username="test_user", - password="test_pass" + password="test_pass", + max_connection_lifetime=1800, + max_connection_pool_size=25, + connection_acquisition_timeout=30 ) - + mock_driver.assert_called_once_with( "bolt://test:7687", - auth=("test_user", "test_pass") + auth=("test_user", "test_pass"), + max_connection_lifetime=1800, + max_connection_pool_size=25, + connection_acquisition_timeout=30 ) @patch.object(Neo4jGraphWorker, 'embed_user_query') def test_similarity_search(self, mock_embed, worker): """Test similarity search functionality.""" worker_instance, mock_session = worker - + # Mock embedding generation mock_embed.return_value = [0.1, 0.2, 0.3] - - # Mock database query results - the code iterates over result directly - mock_authors_1 = [{'fullname': 'Author A'}, {'fullname': 'Author B'}] - mock_authors_2 = [{'fullname': 'Author C'}] + + # Mock database query results - authors are now a list of strings + def create_mock_record(id, name, abstract, topic, score, paper_url, decisions, + session, session_start_time, session_end_time, + presentation_type, presentation_category, room_name, + project_url, poster_position, sourceid, 
virtualsite_url, authors): + """Helper to create a mock record with dict-like access.""" + record_data = { + 'id': id, 'name': name, 'abstract': abstract, 'topic': topic, + 'score': score, 'paper_url': paper_url, 'decisions': decisions, + 'session': session, 'session_start_time': session_start_time, + 'session_end_time': session_end_time, 'presentation_type': presentation_type, + 'presentation_category': presentation_category, 'room_name': room_name, + 'project_url': project_url, 'poster_position': poster_position, + 'sourceid': sourceid, 'virtualsite_url': virtualsite_url, 'authors': authors + } + record = Mock() + record.__getitem__ = lambda self, key: record_data[key] + return record mock_records = [ - Mock(id='paper1', name='Test Paper 1', abstract='Test abstract 1', topic='ML', score=0.95, - paper_url='http://example.com/1', decision='Accept', session='S1', - session_start_time='09:00', session_end_time='10:00', presentation_type='Oral', - room_name='Room A', project_url='http://proj1.com', poster_position='A1', - sourceid='src1', virtualsite_url='http://virtual1.com', authors=mock_authors_1), - Mock(id='paper2', name='Test Paper 2', abstract='Test abstract 2', topic='AI', score=0.90, - paper_url='http://example.com/2', decision='Accept', session='S2', - session_start_time='10:00', session_end_time='11:00', presentation_type='Poster', - room_name='Room B', project_url='http://proj2.com', poster_position='B1', - sourceid='src2', virtualsite_url='http://virtual2.com', authors=mock_authors_2) + create_mock_record('paper1', 'Test Paper 1', 'Test abstract 1', 'ML', 0.95, + 'http://example.com/1', 'Accept', 'S1', '09:00', '10:00', + 'Oral', 'Main', 'Room A', 'http://proj1.com', 'A1', + -1, 'http://virtual1.com', ['Author A', 'Author B']), + create_mock_record('paper2', 'Test Paper 2', 'Test abstract 2', 'AI', 0.90, + 'http://example.com/2', 'Accept', 'S2', '10:00', '11:00', + 'Poster', 'Main', 'Room B', 'http://proj2.com', 'B1', + 1, 'http://virtual2.com', ['Author C']) ] - # Configure record access as dict-like - for record in mock_records: - record.__getitem__ = lambda self, key: getattr(self, key) mock_session.run.return_value = mock_records - + # Call similarity search results = worker_instance.similarity_search( user_query="machine learning", @@ -85,50 +101,61 @@ def test_similarity_search(self, mock_embed, worker): top_k=5, min_similarity=0.8 ) - + # Verify embedding generation was called - mock_embed.assert_called_once_with(text="machine learning") - + mock_embed.assert_called_once_with("machine learning") + # Verify database query was executed mock_session.run.assert_called_once() call_args = mock_session.run.call_args assert "db.index.vector.queryNodes" in call_args[0][0] assert call_args[1]['top_k'] == 5 assert call_args[1]['query_embedding'] == [0.1, 0.2, 0.3] - + # Verify results filtering by min_similarity assert len(results) == 2 assert results[0]['id'] == 'paper1' - # Note: similarity_score is deleted from results before return (line 308 in retriever.py) - assert 'similarity_score' not in results[0] + assert results[0]['authors'] == ['Author A', 'Author B'] + # project_url is mapped to github_url in _build_paper_dict + assert results[0]['github_url'] == 'http://proj1.com' @patch.object(Neo4jGraphWorker, 'embed_user_query') def test_similarity_search_no_min_similarity(self, mock_embed, worker): """Test similarity search without minimum similarity filtering.""" worker_instance, mock_session = worker - + mock_embed.return_value = [0.1, 0.2, 0.3] - mock_authors_1 = [{'fullname': 
'Author A'}] - mock_authors_2 = [{'fullname': 'Author B'}] + def create_mock_record(id, name, abstract, topic, score, paper_url, decisions, + session, session_start_time, session_end_time, + presentation_type, presentation_category, room_name, + project_url, poster_position, sourceid, virtualsite_url, authors): + """Helper to create a mock record with dict-like access.""" + record_data = { + 'id': id, 'name': name, 'abstract': abstract, 'topic': topic, + 'score': score, 'paper_url': paper_url, 'decisions': decisions, + 'session': session, 'session_start_time': session_start_time, + 'session_end_time': session_end_time, 'presentation_type': presentation_type, + 'presentation_category': presentation_category, 'room_name': room_name, + 'project_url': project_url, 'poster_position': poster_position, + 'sourceid': sourceid, 'virtualsite_url': virtualsite_url, 'authors': authors + } + record = Mock() + record.__getitem__ = lambda self, key: record_data[key] + return record mock_records = [ - Mock(id='paper1', name='Test', abstract='Test', topic='ML', score=0.5, - paper_url='http://example.com/1', decision='Accept', session='S1', - session_start_time='09:00', session_end_time='10:00', presentation_type='Oral', - room_name='Room A', project_url='http://proj1.com', poster_position='A1', - sourceid='src1', virtualsite_url='http://virtual1.com', authors=mock_authors_1), - Mock(id='paper2', name='Test', abstract='Test', topic='AI', score=0.3, - paper_url='http://example.com/2', decision='Accept', session='S2', - session_start_time='10:00', session_end_time='11:00', presentation_type='Poster', - room_name='Room B', project_url='http://proj2.com', poster_position='B1', - sourceid='src2', virtualsite_url='http://virtual2.com', authors=mock_authors_2) + create_mock_record('paper1', 'Test', 'Test', 'ML', 0.5, + 'http://example.com/1', 'Accept', 'S1', '09:00', '10:00', + 'Oral', 'Main', 'Room A', 'http://proj1.com', 'A1', + -1, 'http://virtual1.com', ['Author A']), + create_mock_record('paper2', 'Test', 'Test', 'AI', 0.3, + 'http://example.com/2', 'Accept', 'S2', '10:00', '11:00', + 'Poster', 'Main', 'Room B', 'http://proj2.com', 'B1', + 2, 'http://virtual2.com', ['Author B']) ] - # Configure record access as dict-like - for record in mock_records: - record.__getitem__ = lambda self, key: getattr(self, key) mock_session.run.return_value = mock_records - + results = worker_instance.similarity_search( user_query="test", day=None, @@ -136,98 +163,69 @@ def test_similarity_search_no_min_similarity(self, mock_embed, worker): top_k=10, min_similarity=None ) - + # Should return all results when no min_similarity filter assert len(results) == 2 def test_neighborhood_search(self, worker): """Test neighborhood search functionality.""" worker_instance, mock_session = worker - - # Create mock records with data() method - # Note: Due to a bug in neighborhood_search (line 353-360), the first record of each - # relationship type doesn't get added. We need 2 of each type for testing. 
- mock_record_1 = Mock() - mock_record_1.data.return_value = { - 'source_paper_id': 'paper1', - 'id': 'paper2', - 'name': 'Neighbor Paper 1', - 'abstract': 'Test abstract 1', - 'topic': 'ML', - 'relationship_type': 'SIMILAR_TO', - 'similarity': 0.85, - 'paper_url': 'http://example.com/1', - 'decision': 'Accept', - 'session': 'S1', - 'session_start_time': '09:00', - 'session_end_time': '10:00', - 'presentation_type': 'Oral', - 'room_name': 'Room A', - 'project_url': 'http://proj1.com', - 'poster_position': 'A1', - 'sourceid': 'src1', - 'virtualsite_url': 'http://virtual1.com' - } - - mock_record_2 = Mock() - mock_record_2.data.return_value = { - 'source_paper_id': 'paper1', - 'id': 'paper3', - 'name': 'Neighbor Paper 2', - 'abstract': 'Test abstract 2', - 'topic': 'AI', - 'relationship_type': 'SIMILAR_TO', - 'similarity': 0.90, - 'paper_url': 'http://example.com/2', - 'decision': 'Accept', - 'session': 'S2', - 'session_start_time': '10:00', - 'session_end_time': '11:00', - 'presentation_type': 'Poster', - 'room_name': 'Room B', - 'project_url': 'http://proj2.com', - 'poster_position': 'B1', - 'sourceid': 'src2', - 'virtualsite_url': 'http://virtual2.com' - } - - mock_record_3 = Mock() - mock_record_3.data.return_value = { - 'source_paper_id': 'paper1', - 'id': 'author1', - 'fullname': 'Author Name 1', - 'relationship_type': 'IS_AUTHOR_OF' - } - - mock_record_4 = Mock() - mock_record_4.data.return_value = { - 'source_paper_id': 'paper1', - 'id': 'author2', - 'fullname': 'Author Name 2', - 'relationship_type': 'IS_AUTHOR_OF' - } - - mock_session.run.return_value = [mock_record_1, mock_record_2, mock_record_3, mock_record_4] + + def create_mock_record(id, name, abstract, topic, paper_url, decisions, + session, session_start_time, session_end_time, + presentation_type, presentation_category, room_name, + project_url, poster_position, sourceid, virtualsite_url, + authors, source_paper_id, relationship_type, similarity=None): + """Helper to create a mock record with dict-like access.""" + record_data = { + 'id': id, 'name': name, 'abstract': abstract, 'topic': topic, + 'paper_url': paper_url, 'decisions': decisions, + 'session': session, 'session_start_time': session_start_time, + 'session_end_time': session_end_time, 'presentation_type': presentation_type, + 'presentation_category': presentation_category, 'room_name': room_name, + 'project_url': project_url, 'poster_position': poster_position, + 'sourceid': sourceid, 'virtualsite_url': virtualsite_url, + 'authors': authors, 'source_paper_id': source_paper_id, + 'relationship_type': relationship_type, 'similarity': similarity + } + record = Mock() + record.__getitem__ = lambda self, key: record_data[key] + return record + + mock_records = [ + create_mock_record('paper2', 'Neighbor Paper 1', 'Test abstract 1', 'ML', + 'http://example.com/1', 'Accept', 'S1', '09:00', '10:00', + 'Oral', 'Main', 'Room A', 'http://proj1.com', 'A1', + -2, 'http://virtual1.com', ['Author A'], 'paper1', + 'SIMILAR_TO', 0.85), + create_mock_record('paper3', 'Neighbor Paper 2', 'Test abstract 2', 'AI', + 'http://example.com/2', 'Accept', 'S2', '10:00', '11:00', + 'Poster', 'Main', 'Room B', 'http://proj2.com', 'B1', + 3, 'http://virtual2.com', ['Author B'], 'paper1', + 'SIMILAR_TO', 0.90) + ] + mock_session.run.return_value = mock_records results = worker_instance.neighborhood_search( paper_id="paper1", - relationship_types=["SIMILAR_TO", "IS_AUTHOR_OF"], + relationship_types=["SIMILAR_TO"], min_similarity=0.7 ) - + # Verify query was constructed and executed 
mock_session.run.assert_called_once() call_args = mock_session.run.call_args query = call_args[0][0] - assert "MATCH (p:Paper)" in query - assert "WHERE p.id IN $paper_ids" in query - assert call_args[1]['paper_ids'] == ["paper1"] - + assert "MATCH (p:Paper {id: $paper_id})" in query + assert "type(r) IN $allowed_rel_types" in query + assert call_args[1]['paper_id'] == "paper1" + assert call_args[1]['allowed_rel_types'] == ["SIMILAR_TO"] + # Verify results structure - keys are relationship types assert 'SIMILAR_TO' in results - assert 'IS_AUTHOR_OF' in results - assert len(results['SIMILAR_TO']) == 1 - assert len(results['IS_AUTHOR_OF']) == 1 + assert len(results['SIMILAR_TO']) == 2 + assert results['SIMILAR_TO'][0]['id'] == 'paper2' + assert results['SIMILAR_TO'][1]['id'] == 'paper3' def test_neighborhood_search_relationship_filter(self, worker): """Test neighborhood search with relationship type filtering.""" @@ -306,40 +304,43 @@ def test_graph_traversal_invalid_strategy(self, worker): def test_papers_by_author(self, worker): """Test papers by author search.""" worker_instance, mock_session = worker - + + def create_mock_record(id, name, abstract, topic, paper_url, decisions, + session, session_start_time, session_end_time, + presentation_type, presentation_category, room_name, + project_url, poster_position, sourceid, virtualsite_url, authors): + """Helper to create a mock record with dict-like access.""" + record_data = { + 'id': id, 'name': name, 'abstract': abstract, 'topic': topic, + 'paper_url': paper_url, 'decisions': decisions, + 'session': session, 'session_start_time': session_start_time, + 'session_end_time': session_end_time, 'presentation_type': presentation_type, + 'presentation_category': presentation_category, 'room_name': room_name, + 'project_url': project_url, 'poster_position': poster_position, + 'sourceid': sourceid, 'virtualsite_url': virtualsite_url, 'authors': authors + } + record = Mock() + record.__getitem__ = lambda self, key: record_data[key] + return record + mock_records = [ - Mock(id='paper1', name='Paper by Author', abstract='Abstract', topic='ML', author_name='Test Author') + create_mock_record('paper1', 'Paper by Author', 'Abstract', 'ML', + 'http://example.com/1', 'Accept', 'S1', '09:00', '10:00', + 'Oral', 'Main', 'Room A', 'http://proj1.com', 'A1', + -1, 'http://virtual1.com', ['Test Author']) ] - # Configure record access as dict-like - mock_records[0].__getitem__ = lambda self, key: { - 'id': 'paper1', - 'name': 'Paper by Author', - 'abstract': 'Abstract', - 'topic': 'ML', - 'author_name': 'Test Author', - 'paper_url': 'http://example.com/1', - 'decision': 'Accept', - 'session': 'S1', - 'session_start_time': '09:00', - 'session_end_time': '10:00', - 'presentation_type': 'Oral', - 'room_name': 'Room A', - 'project_url': 'http://proj1.com', - 'poster_position': 'A1', - 'sourceid': 'src1', - 'virtualsite_url': 'http://virtual1.com' - }[key] mock_session.run.return_value = mock_records - + results = worker_instance.search_papers_by_author("Test Author", fuzzy=False) - + # Verify exact match query was used call_args = mock_session.run.call_args - assert "a.fullname = $author_name" in call_args[0][0] + assert "a.fullname" in call_args[0][0] or "Author" in call_args[0][0] assert call_args[1]['author_name'] == "Test Author" - + assert len(results) == 1 - assert results[0]['author_name'] == 'Test Author' + assert results[0]['authors'] == ['Test Author'] + assert results[0]['github_url'] == 'http://proj1.com' def test_papers_by_author_fuzzy(self, worker): 
"""Test fuzzy papers by author search.""" @@ -355,39 +356,44 @@ def test_papers_by_author_fuzzy(self, worker): assert "toLower" in call_args[0][0] def test_papers_by_topic(self, worker): - """Test papers by topic search.""" + """Test papers by topic search.""" worker_instance, mock_session = worker - + + def create_mock_record(id, name, abstract, topic, paper_url, decisions, + session, session_start_time, session_end_time, + presentation_type, presentation_category, room_name, + project_url, poster_position, sourceid, virtualsite_url, authors): + """Helper to create a mock record with dict-like access.""" + record_data = { + 'id': id, 'name': name, 'abstract': abstract, 'topic': topic, + 'paper_url': paper_url, 'decisions': decisions, + 'session': session, 'session_start_time': session_start_time, + 'session_end_time': session_end_time, 'presentation_type': presentation_type, + 'presentation_category': presentation_category, 'room_name': room_name, + 'project_url': project_url, 'poster_position': poster_position, + 'sourceid': sourceid, 'virtualsite_url': virtualsite_url, 'authors': authors + } + record = Mock() + record.__getitem__ = lambda self, key: record_data[key] + return record + mock_records = [ - Mock(id='paper1', name='Topic Paper', abstract='Abstract', topic='Machine Learning') + create_mock_record('paper1', 'Topic Paper', 'Abstract', 'Machine Learning', + 'http://example.com/1', 'Accept', 'S1', '09:00', '10:00', + 'Oral', 'Main', 'Room A', 'http://proj1.com', 'A1', + -1, 'http://virtual1.com', ['Author A']) ] - # Configure record access as dict-like - mock_records[0].__getitem__ = lambda self, key: { - 'id': 'paper1', - 'name': 'Topic Paper', - 'abstract': 'Abstract', - 'topic': 'Machine Learning', - 'paper_url': 'http://example.com/1', - 'decision': 'Accept', - 'session': 'S1', - 'session_start_time': '09:00', - 'session_end_time': '10:00', - 'presentation_type': 'Oral', - 'room_name': 'Room A', - 'project_url': 'http://proj1.com', - 'poster_position': 'A1', - 'sourceid': 'src1', - 'virtualsite_url': 'http://virtual1.com' - }[key] mock_session.run.return_value = mock_records - + results = worker_instance.search_papers_by_topic("Machine Learning") - + call_args = mock_session.run.call_args assert "t:Topic {name: $topic_name}" in call_args[0][0] assert call_args[1]['topic_name'] == "Machine Learning" - + assert len(results) == 1 + assert results[0]['topic'] == 'Machine Learning' + assert results[0]['github_url'] == 'http://proj1.com' def test_papers_by_topic_with_subtopics(self, worker): """Test papers by topic including subtopics.""" diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index b979d31..249311f 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -35,24 +35,23 @@ def test_setup_logging_existing_directory(self): @patch('agentic_nav.utils.logger.datetime') def test_setup_logging_creates_handlers(self, mock_datetime): - """Test that console and file handlers are created.""" + """Test that console handler is created (file handler is disabled).""" mock_datetime.now.return_value.strftime.return_value = "2024-01-01_12-00" - + with tempfile.TemporaryDirectory() as temp_dir: # Clear any existing handlers root_logger = logging.getLogger() for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - + setup_logging(log_dir=temp_dir, level="DEBUG") - - # Check that handlers were added - assert len(root_logger.handlers) == 2 - + + # Check that only console handler was added (file handler is commented out) + assert 
len(root_logger.handlers) == 1 + # Check handler types handler_types = [type(h).__name__ for h in root_logger.handlers] assert "StreamHandler" in handler_types - assert "RotatingFileHandler" in handler_types def test_setup_logging_sets_log_levels(self): """Test that log levels are set correctly.""" @@ -79,43 +78,36 @@ def test_setup_logging_invalid_level(self): @patch('agentic_nav.utils.logger.datetime') def test_setup_logging_file_naming(self, mock_datetime): - """Test that log files are named correctly.""" + """Test that log files are named correctly (currently file handler is disabled).""" mock_datetime.now.return_value.strftime.return_value = "2024-01-01_12-30" - + with tempfile.TemporaryDirectory() as temp_dir: setup_logging(log_dir=temp_dir, level="INFO") - - # Check that log file was created with correct name + + # File handler is currently commented out, so no log files are created log_files = list(Path(temp_dir).glob("*.log")) - assert len(log_files) == 1 - assert log_files[0].name == "2024-01-01_12-30_llm_agents.log" + assert len(log_files) == 0 def test_setup_logging_handler_levels(self): - """Test that handlers have correct log levels.""" + """Test that console handler has correct log level (file handler is disabled).""" with tempfile.TemporaryDirectory() as temp_dir: root_logger = logging.getLogger() # Clear existing handlers for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - + setup_logging(log_dir=temp_dir, level="DEBUG") - - # Find console and file handlers + + # Find console handler console_handler = None - file_handler = None - + for handler in root_logger.handlers: - if isinstance(handler, logging.StreamHandler) and not hasattr(handler, 'baseFilename'): + if isinstance(handler, logging.StreamHandler): console_handler = handler - elif hasattr(handler, 'baseFilename'): - file_handler = handler - - # Verify handler levels + + # Verify console handler level assert console_handler is not None assert console_handler.level == logging.WARNING - - assert file_handler is not None - assert file_handler.level == logging.DEBUG def test_setup_logging_formatters(self): """Test that formatters are set correctly.""" @@ -141,20 +133,15 @@ def test_setup_logging_formatters(self): @patch('agentic_nav.utils.logger.logging.handlers.RotatingFileHandler') def test_setup_logging_rotating_file_config(self, mock_rotating_handler): - """Test that rotating file handler is configured correctly.""" + """Test that rotating file handler is not created (currently disabled).""" mock_handler_instance = Mock() mock_rotating_handler.return_value = mock_handler_instance - + with tempfile.TemporaryDirectory() as temp_dir: setup_logging(log_dir=temp_dir, level="INFO") - - # Verify RotatingFileHandler was created with correct parameters - mock_rotating_handler.assert_called_once() - call_args = mock_rotating_handler.call_args - - # Check maxBytes and backupCount parameters - assert call_args.kwargs['maxBytes'] == 10 * 1024 * 1024 # 10MB - assert call_args.kwargs['backupCount'] == 5 + + # RotatingFileHandler creation is commented out, so it should not be called + mock_rotating_handler.assert_not_called() def test_setup_logging_default_parameters(self): """Test function with default parameters."""