diff --git a/saist/llm/adapters/azureopenai.py b/saist/llm/adapters/azureopenai.py
new file mode 100644
index 0000000..3bb9523
--- /dev/null
+++ b/saist/llm/adapters/azureopenai.py
@@ -0,0 +1,21 @@
+from typing import Optional
+import os
+from llm.adapters import BaseLlmAdapter
+
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.azure import AzureProvider
+
+class AzureOpenAiAdapter(BaseLlmAdapter):
+    def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
+        if model is None:
+            model = "gpt-4o"
+        if api_key:
+            raise ValueError("Do not provide an API key for Azure OpenAI; use the AZURE_OPENAI_* ENV variables")
+        self.model = OpenAIModel(
+            model,
+            provider=AzureProvider(
+                api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
+                azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
+                api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
+            )
+        )
diff --git a/saist/main.py b/saist/main.py
index 947393c..4128b83 100644
--- a/saist/main.py
+++ b/saist/main.py
@@ -14,6 +14,7 @@
 from models import FindingEnriched, Finding, Findings
 from llm.adapters.anthropic import AnthropicAdapter
 from llm.adapters.bedrock import BedrockAdapter
+from llm.adapters.azureopenai import AzureOpenAiAdapter
 from web import FindingsServer
 from scm.adapters.filesystem import FilesystemAdapter
 from scm import BaseScmAdapter
@@ -97,7 +98,10 @@ async def _get_llm_adapter(args) -> BaseLlmAdapter:
     if args.llm == 'anthropic':
         llm = AnthropicAdapter( api_key = args.llm_api_key, model=model)
         logger.debug(f"Using LLM: anthropic Model: {llm.model}")
-    if args.llm == 'bedrock':
+    elif args.llm == 'azureopenai':
+        llm = AzureOpenAiAdapter( api_key = args.llm_api_key, model=model)
+        logger.debug(f"Using LLM: azureopenai Model: {llm.model}")
+    elif args.llm == 'bedrock':
         llm = BedrockAdapter( api_key = args.llm_api_key, model=model)
         logger.debug(f"Using LLM: AWS bedrock Model: {llm.model}")
     elif args.llm == 'deepseek':
diff --git a/saist/util/argparsing.py b/saist/util/argparsing.py
index ef105ae..8ab6a75 100644
--- a/saist/util/argparsing.py
+++ b/saist/util/argparsing.py
@@ -120,7 +120,7 @@ def __call__(self, parser, namespace, values, option_string=None):
 parser.add_argument(
     "--llm",
     type=str,
-    choices=["anthropic", "bedrock", "deepseek", "gemini", "ollama", "openai"],
+    choices=["anthropic", "azureopenai", "bedrock", "deepseek", "gemini", "ollama", "openai"],
     required=True,
     action=EnvDefault,
     envvar="SAIST_LLM"
@@ -201,10 +201,13 @@ def parse_args():
     if args.llm == "bedrock" and args.llm_api_key:
         parser.error(f"Do not provide an API key for bedrock, use AWS ENV variables https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html")
 
+    if args.llm == "azureopenai" and args.llm_api_key:
+        parser.error("Do not provide an API key for azureopenai; use the AZURE_OPENAI_* ENV variables")
+
     if args.llm == "bedrock" and args.interactive:
         parser.error("Sorry, we dont support interactive mode with bedrock as AWS tool calling is a bit broken")
 
-    if args.llm not in [ "ollama", "bedrock" ] and args.llm_api_key is None:
+    if args.llm not in [ "ollama", "bedrock", "azureopenai" ] and args.llm_api_key is None:
         parser.error(f"You must provide an api key with --llm-api-key if using {args.llm}")
 
     if args.llm == "ollama" and args.interactive:
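For reviewers, a minimal sketch of how the new adapter would be exercised once this patch is applied. The endpoint, API version, and deployment name below are illustrative placeholders, not values shipped by the patch; the three `AZURE_OPENAI_*` variables and the `SAIST_LLM` env var are the ones the changed code actually reads.

```python
# Smoke test for the new Azure OpenAI adapter (placeholder values;
# the real endpoint/version/deployment depend on your Azure resource).
import os

os.environ["AZURE_OPENAI_API_KEY"] = "<key from the Azure portal>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-06-01"  # placeholder version

from llm.adapters.azureopenai import AzureOpenAiAdapter

adapter = AzureOpenAiAdapter()  # falls back to the "gpt-4o" default
# adapter = AzureOpenAiAdapter(model="my-deployment")  # named Azure deployment
# AzureOpenAiAdapter(api_key="...") raises ValueError by design

# CLI equivalent, per the argparsing change (note: no --llm-api-key):
#   SAIST_LLM=azureopenai python saist/main.py ...
```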