diff --git a/auditor/generations/paraphrase.py b/auditor/generations/paraphrase.py index 1af1d0f..8069799 100644 --- a/auditor/generations/paraphrase.py +++ b/auditor/generations/paraphrase.py @@ -1,6 +1,7 @@ from typing import List, Optional import os import openai +import litellm from auditor.perturbations.constants import OPENAI_CHAT_COMPLETION @@ -37,7 +38,7 @@ def generate_similar_sentences( engine = model api_version = api_version - response = openai.ChatCompletion.create( + response = litellm.completion( model=model, messages=payload, temperature=temperature, diff --git a/auditor/perturbations/paraphrase.py b/auditor/perturbations/paraphrase.py index 9324b1d..a521cd9 100644 --- a/auditor/perturbations/paraphrase.py +++ b/auditor/perturbations/paraphrase.py @@ -3,6 +3,7 @@ import re import openai +import litellm from auditor.perturbations.base import TransformBase from auditor.perturbations.constants import OPENAI_CHAT_COMPLETION @@ -73,7 +74,7 @@ def transform( "content": prompt } ] - response = openai.ChatCompletion.create( + response = litellm.completion( model=self.model, messages=payload, temperature=self.temperature, diff --git a/pyproject.toml b/pyproject.toml index 2082f54..a131877 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "jinja2==3.1.2", "langchain>=0.0.158,<=0.0.330", "openai>=0.27.0,<=0.28.1", + "litellm==0.13.2", "sentence-transformers>=2.2.2", "tqdm>=4.66.1", "httplib2~=0.22.0"