-
Notifications
You must be signed in to change notification settings - Fork 1
Add semantic search with Voyage AI embeddings #185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from intelstream.database.models import Base, ContentItem, DiscordConfig, Source | ||
| from intelstream.database.models import Base, ContentEmbedding, ContentItem, DiscordConfig, Source | ||
| from intelstream.database.repository import Repository | ||
|
|
||
| __all__ = ["Base", "ContentItem", "DiscordConfig", "Repository", "Source"] | ||
| __all__ = ["Base", "ContentEmbedding", "ContentItem", "DiscordConfig", "Repository", "Source"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| from datetime import UTC, datetime, timedelta | ||
|
|
||
| import structlog | ||
| from sqlalchemy import exists, func, select, text | ||
| from sqlalchemy import delete, exists, func, select, text | ||
| from sqlalchemy.exc import IntegrityError, OperationalError | ||
| from sqlalchemy.ext.asyncio import ( | ||
| AsyncConnection, | ||
|
|
@@ -18,6 +18,7 @@ | |
| ) | ||
| from intelstream.database.models import ( | ||
| Base, | ||
| ContentEmbedding, | ||
| ContentItem, | ||
| DiscordConfig, | ||
| ExtractionCache, | ||
|
|
@@ -890,3 +891,84 @@ async def set_github_repo_active(self, repo_id: str, is_active: bool) -> bool: | |
| await session.commit() | ||
| return True | ||
| return False | ||
|
|
||
| async def add_content_embedding( | ||
| self, content_item_id: str, embedding_json: str, model_name: str, text_hash: str | ||
| ) -> ContentEmbedding: | ||
| async with self.session() as session: | ||
| result = await session.execute( | ||
| select(ContentEmbedding).where(ContentEmbedding.content_item_id == content_item_id) | ||
| ) | ||
| existing = result.scalar_one_or_none() | ||
| if existing: | ||
| existing.embedding_json = embedding_json | ||
| existing.model_name = model_name | ||
| existing.text_hash = text_hash | ||
| else: | ||
| existing = ContentEmbedding( | ||
| content_item_id=content_item_id, | ||
| embedding_json=embedding_json, | ||
| model_name=model_name, | ||
| text_hash=text_hash, | ||
| ) | ||
| session.add(existing) | ||
| await session.commit() | ||
| await session.refresh(existing) | ||
| return existing | ||
|
|
||
| async def get_all_embeddings(self) -> list[ContentEmbedding]: | ||
| async with self.session() as session: | ||
| result = await session.execute(select(ContentEmbedding)) | ||
| return list(result.scalars().all()) | ||
|
|
||
| async def get_items_without_embeddings(self, limit: int = 100) -> list[ContentItem]: | ||
| async with self.session() as session: | ||
| subquery = select(ContentEmbedding.content_item_id) | ||
| result = await session.execute( | ||
| select(ContentItem) | ||
| .where(ContentItem.id.notin_(subquery)) | ||
| .where(ContentItem.summary.isnot(None)) | ||
| .order_by(ContentItem.created_at.desc()) | ||
| .limit(limit) | ||
| ) | ||
| return list(result.scalars().all()) | ||
|
|
||
| async def get_embeddings_with_items( | ||
| self, | ||
| guild_id: str | None = None, | ||
| source_type: str | None = None, | ||
| since: datetime | None = None, | ||
| ) -> list[tuple[ContentEmbedding, ContentItem, Source]]: | ||
| async with self.session() as session: | ||
| query = ( | ||
| select(ContentEmbedding, ContentItem, Source) | ||
| .join(ContentItem, ContentEmbedding.content_item_id == ContentItem.id) | ||
| .join(Source, ContentItem.source_id == Source.id) | ||
| ) | ||
| if guild_id: | ||
| query = query.where(Source.guild_id == guild_id) | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Suggestion] The `return` expression could be written as `[tuple(row) for row in result.all()]`, or even more compactly. Minor style point only -- the current code is clear enough. Not blocking. |
||
| if source_type: | ||
| try: | ||
| st = SourceType(source_type) | ||
| except ValueError: | ||
| return [] | ||
| query = query.where(Source.type == st) | ||
| if since: | ||
| query = query.where(ContentItem.published_at >= since) | ||
| result = await session.execute(query) | ||
| return [(row[0], row[1], row[2]) for row in result.all()] | ||
|
|
||
| async def get_latest_embedding(self) -> ContentEmbedding | None: | ||
| async with self.session() as session: | ||
| result = await session.execute( | ||
| select(ContentEmbedding).order_by(ContentEmbedding.created_at.desc()).limit(1) | ||
| ) | ||
| return result.scalar_one_or_none() | ||
|
|
||
| async def clear_all_embeddings(self) -> int: | ||
| async with self.session() as session: | ||
| result = await session.execute(select(func.count()).select_from(ContentEmbedding)) | ||
| count = result.scalar_one() | ||
| await session.execute(delete(ContentEmbedding)) | ||
| await session.commit() | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Blocking - Bug] Wait, re-reading: the code DOES do a count first (line ~958-959 in the actual file), then executes the delete. That's correct and matches the design. The […] |
||
| return count | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,10 +36,23 @@ async def cog_load(self) -> None: | |
| max_input_length=self.bot.settings.summary_max_input_length, | ||
| ) | ||
|
|
||
| search_service = None | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Good] Clean wiring. The conditional import and creation of […] One minor note: the […] |
||
| if self.bot.settings.voyage_api_key: | ||
| from intelstream.services.search import SearchService, VoyageEmbeddingProvider | ||
|
|
||
| provider = VoyageEmbeddingProvider( | ||
| self.bot.settings.voyage_api_key, | ||
| model=self.bot.settings.search_embedding_model, | ||
| ) | ||
| search_service = SearchService(self.bot.repository, provider) | ||
|
|
||
| self.bot.search_service = search_service | ||
|
|
||
| self._pipeline = ContentPipeline( | ||
| settings=self.bot.settings, | ||
| repository=self.bot.repository, | ||
| summarizer=summarizer, | ||
| search_service=search_service, | ||
| ) | ||
| await self._pipeline.initialize() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import discord | ||
| import structlog | ||
| from discord import app_commands | ||
| from discord.ext import commands | ||
|
|
||
| from intelstream.database.models import SourceType | ||
| from intelstream.services.content_poster import SOURCE_TYPE_LABELS | ||
|
|
||
| if TYPE_CHECKING: | ||
| from intelstream.bot import IntelStreamBot | ||
|
|
||
| logger = structlog.get_logger() | ||
|
|
||
|
|
||
| class SearchCog(commands.Cog): | ||
| def __init__(self, bot: IntelStreamBot) -> None: | ||
| self.bot = bot | ||
|
|
||
| async def cog_load(self) -> None: | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Suggestion] The […] However, there's a missing `setup` function:
```python
async def setup(bot: "IntelStreamBot") -> None:
    await bot.add_cog(SearchCog(bot))
```
While it's not strictly required since […] |
||
| if not self.bot.settings.voyage_api_key: | ||
| logger.info("Voyage API key not configured, search disabled") | ||
| return | ||
|
|
||
| @app_commands.command(name="search", description="Search across all content") | ||
| @app_commands.describe( | ||
| query="What to search for (3-200 characters)", | ||
| days="Limit to last N days (optional)", | ||
| source_type="Filter by source type (optional)", | ||
| ) | ||
| @app_commands.choices( | ||
| source_type=[ | ||
| app_commands.Choice(name=label, value=st.value) | ||
| for st, label in SOURCE_TYPE_LABELS.items() | ||
| ] | ||
| ) | ||
| @app_commands.checks.cooldown(5, 60.0) | ||
| async def search( | ||
| self, | ||
| interaction: discord.Interaction, | ||
| query: str, | ||
| days: int | None = None, | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Blocking] The `days` parameter […] Add validation:
```python
if days is not None and days < 1:
    await interaction.response.send_message(
        "Days must be at least 1.", ephemeral=True
    )
    return
```
Also consider an upper bound (e.g., […] |
||
| source_type: str | None = None, | ||
| ) -> None: | ||
| if not self.bot.search_service: | ||
| await interaction.response.send_message( | ||
| "Semantic search is not configured.", ephemeral=True | ||
| ) | ||
| return | ||
|
|
||
| if len(query) < 3 or len(query) > 200: | ||
| await interaction.response.send_message( | ||
| "Query must be between 3 and 200 characters.", ephemeral=True | ||
| ) | ||
| return | ||
|
|
||
| if days is not None and days < 1: | ||
| await interaction.response.send_message( | ||
| "Days must be a positive number.", ephemeral=True | ||
| ) | ||
| return | ||
|
|
||
| await interaction.response.defer(ephemeral=True) | ||
|
|
||
| guild_id = str(interaction.guild_id) if interaction.guild_id else None | ||
|
|
||
| try: | ||
| results = await self.bot.search_service.search( | ||
| query=query, | ||
| guild_id=guild_id, | ||
| source_type=source_type, | ||
| days=days, | ||
| limit=self.bot.settings.search_max_results, | ||
| threshold=self.bot.settings.search_similarity_threshold, | ||
| ) | ||
| except Exception: | ||
| logger.exception("Search failed", query=query) | ||
| await interaction.followup.send( | ||
| "An error occurred while searching. Please try again.", ephemeral=True | ||
| ) | ||
| return | ||
|
|
||
| if not results: | ||
| await interaction.followup.send("No relevant results found.", ephemeral=True) | ||
| return | ||
|
|
||
| embed = discord.Embed( | ||
|
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Blocking] The query string is embedded directly into the Discord embed title without sanitization. While Discord embeds don't execute code, a very long query (up to 200 chars) combined with the […] Actually, re-checking: 200 + 11 = 211 < 256, so this is fine. Disregard this concern. However, the […] One real concern: if […] |
||
| title=f'Search: "{query}"', | ||
| color=0x3498DB, | ||
| description=f"{len(results)} results found", | ||
| ) | ||
| for i, result in enumerate(results, 1): | ||
| source_label = SOURCE_TYPE_LABELS.get(SourceType(result.source_type), "Unknown") | ||
| date_str = result.published_at.strftime("%b %d") | ||
| embed.add_field( | ||
| name=f"{i}. {result.title[:250]}", | ||
| value=f"[Link]({result.original_url}) | {source_label} | {date_str} | Score: {result.score:.2f}", | ||
| inline=False, | ||
| ) | ||
| await interaction.followup.send(embed=embed, ephemeral=True) | ||
|
|
||
| @search.error | ||
| async def search_error( | ||
| self, interaction: discord.Interaction, error: app_commands.AppCommandError | ||
| ) -> None: | ||
| if isinstance(error, app_commands.CommandOnCooldown): | ||
| await interaction.response.send_message( | ||
| f"Search is on cooldown. Try again in {error.retry_after:.0f}s.", | ||
| ephemeral=True, | ||
| ) | ||
| else: | ||
| raise error | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Blocking]
`SourceType(source_type)` here will raise `ValueError` if the user passes an invalid source type string that doesn't map to a `SourceType` enum value. This would cause an unhandled exception that bubbles up through the search path.

The search cog uses `@app_commands.choices`, which constrains the Discord UI, but a malformed API request could still pass an invalid string. The exception would be caught by the cog's broad `except Exception` handler, which would show a generic error message -- acceptable but not ideal for UX.

Consider wrapping in a try/except in either the repository method or the `_get_candidates` caller: