Replace the Literal["all"] annotations on OpenAIEmbeddings with a str-backed Enum.

Adds `from enum import Enum` and a `SpecialAllow.ValueField` enum whose only
member is VALID_VALUE = "all", then uses that enum in the type annotations of
`allowed_special` and `disallowed_special`. The field defaults (`set()` and the
plain string "all") are left exactly as they were.

NOTE(review): the whitespace-only context lines and the indentation below were
reconstructed from the hunk line counts and Python syntax; confirm them against
the target file before applying.

diff --git a/gentopia/memory/embeddings.py b/gentopia/memory/embeddings.py
index 3055b74..4136530 100644
--- a/gentopia/memory/embeddings.py
+++ b/gentopia/memory/embeddings.py
@@ -27,6 +27,8 @@
 )
 from gentopia.memory.utils import get_from_dict_or_env
 
+from enum import Enum
+
 logger = logging.getLogger(__name__)
 
 
@@ -156,6 +158,9 @@ async def _async_embed_with_retry(**kwargs: Any) -> Any:
     return await _async_embed_with_retry(**kwargs)
 
 
+class SpecialAllow:
+    class ValueField(str, Enum):
+        VALID_VALUE = "all"
 class OpenAIEmbeddings(BaseModel, Embeddings):
     """Wrapper around OpenAI embedding models."""
 
@@ -173,8 +178,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     embedding_ctx_length: int = 8191
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
-    allowed_special: Union[Literal["all"], Set[str]] = set()
-    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
+    allowed_special: Union[SpecialAllow.ValueField, Set[str]] = set()
+    disallowed_special: Union[SpecialAllow.ValueField, Set[str], Sequence[str]] = "all"
     chunk_size: int = 1000
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6