Merged
Changes from all commits
115 changes: 79 additions & 36 deletions saist/main.py
@@ -2,39 +2,42 @@
 import asyncio
 import logging
 import os
 
 from typing import Optional
 
 from dotenv import load_dotenv
+from latex import Latex
 
 from llm.adapters import BaseLlmAdapter
+from llm.adapters.anthropic import AnthropicAdapter
+from llm.adapters.bedrock import BedrockAdapter
 from llm.adapters.deepseek import DeepseekAdapter
+from llm.adapters.faike import FaikeAdapter
 from llm.adapters.gemini import GeminiAdapter
+from llm.adapters.openai import OpenAiAdapter
 from llm.adapters.ollama import OllamaAdapter
-from llm.adapters.faike import FaikeAdapter
-from models import FindingContext, FindingEnriched, Finding, Findings
-from llm.adapters.anthropic import AnthropicAdapter
-from llm.adapters.bedrock import BedrockAdapter
-from web import FindingsServer
-from llm.adapters.openai import OpenAiAdapter
 
+from models import Finding, FindingContext, FindingEnriched, Findings
 
+from scm import BaseScmAdapter, Scm
 from scm.adapters.filesystem import FilesystemAdapter
-from scm import BaseScmAdapter
 from scm.adapters.git import GitAdapter
-from util.git import parse_unified_diff
-from util.filtering import FilterRules
-from util.prompts import prompts
 from scm.adapters.github import Github
-from scm import Scm
 
 from shell import Shell
-from latex import Latex
 
 from util.argparsing import parse_args
 
+from util.caching import *
+from util.filtering import FilterRules
+from util.git import parse_unified_diff
+from util.output import print_banner, write_csv
 from util.poem import poem
+from util.prompts import prompts
 
-from util.output import print_banner, write_csv
+from web import FindingsServer
 
-from util.caching import *
+from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn, MofNCompleteColumn
+from rich.console import Group
+from rich.live import Live
 
 prompts = prompts()
 load_dotenv(".env")
@@ -232,6 +235,7 @@ async def main():
     # 3) Analyze each file in parallel
     print("🔍 Analyzing files for security issues...")
     max_workers = min(args.llm_rate_limit, len(app_files))
+    logging.debug(f"{max_workers=}")
     all_findings = await generate_findings(scm, llm, app_files, max_workers, args.disable_tools, args.disable_caching, args.cache_folder)
 
     if not all_findings:
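A note on the cap above: min(args.llm_rate_limit, len(app_files)) keeps the worker pool no larger than the file count, and that number becomes the size of the asyncio.Semaphore created in generate_findings below. A minimal self-contained sketch of the same bounded fan-out pattern, using hypothetical names (bounded_gather, fake_analyze) that are not part of this PR:

import asyncio

async def bounded_gather(items, worker, limit):
    # At most min(limit, len(items)) coroutines hold the semaphore at once,
    # mirroring the max_workers computation in the hunk above.
    semaphore = asyncio.Semaphore(min(limit, len(items)))

    async def run(item):
        async with semaphore:
            return await worker(item)

    return await asyncio.gather(*(run(item) for item in items))

async def demo():
    async def fake_analyze(name):
        await asyncio.sleep(0.1)  # stand-in for a rate-limited LLM call
        return f"{name}: ok"
    print(await bounded_gather(["a.py", "b.py", "c.py"], fake_analyze, limit=2))

asyncio.run(demo())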
@@ -363,22 +367,22 @@ async def main():
     if args.ci and len(all_findings) > 0:
         exit(1)
 
-async def process_file(scm: Scm, llm, filename, patch_text, semaphore, disable_tools, disable_caching, cache_folder):
-    async with semaphore:
-        start = asyncio.get_event_loop().time()
-        if disable_caching is True:
-            result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
-        else:
-            hash: str = await hash_file(scm, filename)
-            cache_file = os.path.join(cache_folder, hash + ".json")
-            if not os.path.exists(cache_file):
-                result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
-                store_findings_to_cache_file(filename, result, cache_file)
-            else:
-                result = findings_from_cache_file(cache_file)
-        elapsed = asyncio.get_event_loop().time() - start
-        if elapsed < 1:
-            await asyncio.sleep(1 - elapsed)
+async def process_file(scm: Scm, llm, filename, patch_text, disable_tools, disable_caching, cache_folder):
+    start = asyncio.get_event_loop().time()
+    if disable_caching is True:
+        result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
+    else:
+        hash: str = await hash_file(scm, filename)
+        cache_file = os.path.join(cache_folder, hash + ".json")
+        if not os.path.exists(cache_file):
+            result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
+            store_findings_to_cache_file(filename, result, cache_file)
+        else:
+            result = findings_from_cache_file(cache_file)
+    elapsed = asyncio.get_event_loop().time() - start
+    if elapsed < 1:
+        await asyncio.sleep(1 - elapsed)
 
     return result
 
async def generate_findings(scm, llm, app_files, max_concurrent, disable_tools, disable_caching, cache_folder):
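process_file now leaves throttling to its caller and keeps only the per-file work: analyze, or reuse findings cached under a content hash of the file, so an unchanged file never repeats the LLM round-trip. The one-second floor before returning also means each concurrency slot turns over at most once per second, so the slot count effectively doubles as a requests-per-second budget. hash_file, store_findings_to_cache_file, and findings_from_cache_file are saist helpers whose bodies are not shown in this diff; the sketch below is an assumed shape for that content-addressed cache, not the repo's implementation:

import hashlib
import json
import os

def cache_path(cache_folder: str, content: bytes) -> str:
    # Content-addressed key: identical bytes always map to the same path,
    # so any edit to the source file invalidates the cache automatically.
    digest = hashlib.sha256(content).hexdigest()
    return os.path.join(cache_folder, digest + ".json")

def load_or_compute(cache_folder: str, content: bytes, compute):
    # Serve a cached result when one exists, otherwise compute and store it.
    path = cache_path(cache_folder, content)
    if os.path.exists(path):
        with open(path) as handle:
            return json.load(handle)
    result = compute(content)
    os.makedirs(cache_folder, exist_ok=True)
    with open(path, "w") as handle:
        json.dump(result, handle)
    return result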
@@ -388,13 +392,52 @@ async def generate_findings(scm, llm, app_files, max_concurrent, disable_tools, disable_caching, cache_folder):
 
     semaphore = asyncio.Semaphore(max_concurrent)
 
-    tasks = [
-        process_file(scm, llm, filename, patch_text, semaphore, disable_tools, disable_caching, cache_folder)
-        for filename, patch_text in app_files
-    ]
+    overall_progress = Progress(
+        TextColumn("[bold blue]{task.description}"),
+        BarColumn(),
+        MofNCompleteColumn(),
+        TimeElapsedColumn(),
+        TimeRemainingColumn())
+
+    file_progress = Progress(
+        SpinnerColumn(),
+        TextColumn("[blue]{task.description}"),
+        transient=True
+    )
+
+    progress_group = Group(
+        overall_progress,
+        file_progress,
+    )
+
+    def task_progress_wrapper(func, overall_progress, overall_task, file_progress, filename, semaphore):
+        async def sub_func(*args, **kwargs):
+            async with semaphore:
+                task_description_text = f"{filename}..."
+                file_task = file_progress.add_task(description=task_description_text, transient=True)
+                task_result = await func(*args, **kwargs)
+                file_progress.remove_task(file_task)
+                file_progress.refresh()
+                overall_progress.update(overall_task, advance=1)
+                return task_result
+        return sub_func
 
-    all_findings = []
-    results = await asyncio.gather(*tasks)
+    with Live(progress_group):
+        tasks = []
+        overall_task = overall_progress.add_task(f"Analyzing {len(app_files)} files...", total=len(app_files), start=True)  # Add a task
+        for filename, patch_text in app_files:
+            wrapper_func = task_progress_wrapper(process_file, overall_progress, overall_task, file_progress, filename, semaphore)(scm, llm, filename, patch_text, disable_tools, disable_caching, cache_folder)
+            tasks.append(
+                wrapper_func
+            )
+
+        all_findings = []
+        try:
+            results = await asyncio.gather(*tasks)
+        finally:
+            overall_progress.stop()
+            file_progress.stop()
 
     for result in results:
         if result:
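task_progress_wrapper is where the semaphore now lives: each wrapped call waits for a slot, shows a transient spinner row while its file is in flight, removes the row when done, and advances the overall bar. A self-contained sketch of the same rich Live/Group layout, runnable outside saist (the file list and sleep are stand-ins for real work):

import asyncio

from rich.console import Group
from rich.live import Live
from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn

async def demo(files):
    overall = Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
    )
    per_file = Progress(SpinnerColumn(), TextColumn("[blue]{task.description}"), transient=True)
    semaphore = asyncio.Semaphore(2)  # at most two files in flight

    async def work(name):
        async with semaphore:
            row = per_file.add_task(description=f"{name}...")
            await asyncio.sleep(0.5)  # stand-in for the real per-file analysis
            per_file.remove_task(row)
            overall.update(overall_task, advance=1)

    # One Live renders both Progress instances stacked in a single Group.
    with Live(Group(overall, per_file)):
        overall_task = overall.add_task(f"Analyzing {len(files)} files...", total=len(files))
        await asyncio.gather(*(work(f) for f in files))

asyncio.run(demo(["a.py", "b.py", "c.py", "d.py"]))

One detail in the merged code: the transient=True passed to file_progress.add_task(...) is absorbed into rich's per-task **fields (data for custom columns) rather than acting as a flag; the per-row transience actually comes from transient=True on the file_progress constructor, so behavior is unaffected.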