diff --git a/saist/main.py b/saist/main.py
index 408922b..114881f 100755
--- a/saist/main.py
+++ b/saist/main.py
@@ -2,39 +2,42 @@
 import asyncio
 import logging
 import os
-
 from typing import Optional
 
 from dotenv import load_dotenv
+from latex import Latex
 from llm.adapters import BaseLlmAdapter
+from llm.adapters.anthropic import AnthropicAdapter
+from llm.adapters.bedrock import BedrockAdapter
 from llm.adapters.deepseek import DeepseekAdapter
+from llm.adapters.faike import FaikeAdapter
 from llm.adapters.gemini import GeminiAdapter
-from llm.adapters.openai import OpenAiAdapter
 from llm.adapters.ollama import OllamaAdapter
-from llm.adapters.faike import FaikeAdapter
-from models import FindingContext, FindingEnriched, Finding, Findings
-from llm.adapters.anthropic import AnthropicAdapter
-from llm.adapters.bedrock import BedrockAdapter
-from web import FindingsServer
+from llm.adapters.openai import OpenAiAdapter
+
+from models import Finding, FindingContext, FindingEnriched, Findings
+
+from scm import BaseScmAdapter, Scm
 from scm.adapters.filesystem import FilesystemAdapter
-from scm import BaseScmAdapter
 from scm.adapters.git import GitAdapter
-from util.git import parse_unified_diff
-from util.filtering import FilterRules
-from util.prompts import prompts
 from scm.adapters.github import Github
-from scm import Scm
+
 from shell import Shell
-from latex import Latex
 from util.argparsing import parse_args
-
+from util.caching import *
+from util.filtering import FilterRules
+from util.git import parse_unified_diff
+from util.output import print_banner, write_csv
 from util.poem import poem
+from util.prompts import prompts
 
-from util.output import print_banner, write_csv
+from web import FindingsServer
 
-from util.caching import *
+from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn, MofNCompleteColumn
+from rich.console import Group
+from rich.live import Live
 
 
 prompts = prompts()
 load_dotenv(".env")
 
@@ -232,6 +235,7 @@ async def main():
     # 3) Analyze each file in parallel
     print("🔍 Analyzing files for security issues...")
     max_workers = min(args.llm_rate_limit, len(app_files))
+    logging.debug(f"{max_workers=}")
     all_findings = await generate_findings(scm, llm, app_files, max_workers, args.disable_tools, args.disable_caching, args.cache_folder)
 
     if not all_findings:
@@ -363,22 +367,22 @@ async def main():
     if args.ci and len(all_findings) > 0:
         exit(1)
 
-async def process_file(scm: Scm, llm, filename, patch_text, semaphore, disable_tools, disable_caching, cache_folder):
-    async with semaphore:
-        start = asyncio.get_event_loop().time()
-        if disable_caching is True:
+async def process_file(scm: Scm, llm, filename, patch_text, disable_tools, disable_caching, cache_folder):
+    start = asyncio.get_event_loop().time()
+    if disable_caching is True:
+        result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
+    else:
+        hash: str = await hash_file(scm, filename)
+        cache_file = os.path.join(cache_folder, hash + ".json")
+        if not os.path.exists(cache_file):
             result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
+            store_findings_to_cache_file(filename, result, cache_file)
         else:
-            hash: str = await hash_file(scm, filename)
-            cache_file = os.path.join(cache_folder, hash + ".json")
-            if not os.path.exists(cache_file):
-                result = await analyze_single_file(scm, llm, filename, patch_text, disable_tools)
-                store_findings_to_cache_file(filename, result, cache_file)
-            else:
-                result = findings_from_cache_file(cache_file)
-        elapsed = asyncio.get_event_loop().time() - start
-        if elapsed < 1:
-            await asyncio.sleep(1 - elapsed)
+            result = findings_from_cache_file(cache_file)
+    elapsed = asyncio.get_event_loop().time() - start
+    if elapsed < 1:
+        await asyncio.sleep(1 - elapsed)
+    return result
 
 
 async def generate_findings(scm, llm, app_files, max_concurrent, disable_tools, disable_caching, cache_folder):
@@ -388,13 +392,52 @@ async def generate_findings(scm, llm, app_files, max_concurrent, disable_tools,
 
     semaphore = asyncio.Semaphore(max_concurrent)
 
-    tasks = [
-        process_file(scm, llm, filename, patch_text, semaphore, disable_tools, disable_caching, cache_folder)
-        for filename, patch_text in app_files
-    ]
+    overall_progress = Progress(
+        TextColumn("[bold blue]{task.description}"),
+        BarColumn(),
+        MofNCompleteColumn(),
+        TimeElapsedColumn(),
+        TimeRemainingColumn())
+
+    file_progress = Progress(
+        SpinnerColumn(),
+        TextColumn("[blue]{task.description}"),
+        transient=True
+    )
+
+    progress_group = Group(
+        overall_progress,
+        file_progress,
+    )
+
+    def task_progress_wrapper(func, overall_progress, overall_task, file_progress, filename, semaphore):
+        async def sub_func(*args, **kwargs):
+            async with semaphore:
+                task_description_text = f"{filename}..."
+                file_task = file_progress.add_task(description=task_description_text, transient=True)
+                task_result = await func(*args, **kwargs)
+                file_progress.remove_task(file_task)
+                file_progress.refresh()
+                overall_progress.update(overall_task, advance=1)
+                return task_result
+        return sub_func
+
 
-    all_findings = []
-    results = await asyncio.gather(*tasks)
+    with Live(progress_group):
+        tasks = []
+        overall_task = overall_progress.add_task(f"Analyzing {len(app_files)} files...", total=len(app_files), start=True) # Add a task
+        for filename, patch_text in app_files:
+            wrapper_func = task_progress_wrapper(process_file, overall_progress, overall_task, file_progress, filename, semaphore)(scm, llm, filename, patch_text, disable_tools, disable_caching, cache_folder)
+            tasks.append(
+                wrapper_func
+            )
+
+        all_findings = []
+        try:
+            results = await asyncio.gather(*tasks)
+        finally:
+            overall_progress.stop()
+            file_progress.stop()
 
     for result in results:
         if result:
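
As a companion to the generate_findings changes above, here is a minimal, self-contained sketch of the same rich progress pattern: each per-file coroutine acquires the shared semaphore, registers a row in a per-file Progress while it runs, and advances the overall bar when it completes, with both bars rendered under a single Live group. The analyze coroutine and the file list are hypothetical stand-ins, not part of saist.

import asyncio
import random

from rich.console import Group
from rich.live import Live
from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn


async def analyze(filename: str) -> str:
    # Hypothetical stand-in for the real per-file LLM analysis call.
    await asyncio.sleep(random.uniform(0.5, 2.0))
    return f"findings for {filename}"


async def main() -> None:
    files = [f"src/module_{i}.py" for i in range(8)]
    semaphore = asyncio.Semaphore(3)  # cap concurrent analyses, like max_workers

    overall = Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
    )
    file_progress = Progress(SpinnerColumn(), TextColumn("[blue]{task.description}"))

    def wrap(filename, overall_task):
        # Mirrors task_progress_wrapper: show a spinner row while the file is
        # being analyzed, remove it afterwards, and advance the overall bar.
        async def run():
            async with semaphore:
                file_task = file_progress.add_task(description=f"{filename}...")
                try:
                    return await analyze(filename)
                finally:
                    file_progress.remove_task(file_task)
                    overall.update(overall_task, advance=1)
        return run()

    # One Live display renders both progress tables stacked in a Group.
    with Live(Group(overall, file_progress)):
        overall_task = overall.add_task(f"Analyzing {len(files)} files...", total=len(files))
        results = await asyncio.gather(*(wrap(f, overall_task) for f in files))

    print(results)


if __name__ == "__main__":
    asyncio.run(main())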