diff --git a/libs/community/langchain_community/llms/cohere.py b/libs/community/langchain_community/llms/cohere.py index 17960b2dee866..ad41a0d3a4c80 100644 --- a/libs/community/langchain_community/llms/cohere.py +++ b/libs/community/langchain_community/llms/cohere.py @@ -71,7 +71,7 @@ async def _completion_with_retry(**kwargs: Any) -> Any: @deprecated( - since="0.0.30", removal="0.2.0", alternative_import="langchain_cohere.BaseCohere" + since="0.0.30", removal="0.2.0" ) class BaseCohere(Serializable): """Base class for Cohere models.""" diff --git a/libs/partners/cohere/langchain_cohere/__init__.py b/libs/partners/cohere/langchain_cohere/__init__.py index 1f554a006e258..59ff14ca622c2 100644 --- a/libs/partners/cohere/langchain_cohere/__init__.py +++ b/libs/partners/cohere/langchain_cohere/__init__.py @@ -5,7 +5,6 @@ __all__ = [ "ChatCohere", - "CohereVectorStore", "CohereEmbeddings", "CohereRagRetriever", "CohereRerank", diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py index 25a537abe1ebb..6df0f13249a3f 100644 --- a/libs/partners/cohere/langchain_cohere/chat_models.py +++ b/libs/partners/cohere/langchain_cohere/chat_models.py @@ -1,5 +1,6 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +import cohere from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -9,6 +10,7 @@ agenerate_from_stream, generate_from_stream, ) +from langchain_core.load.serializable import Serializable from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -18,30 +20,8 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult - -from langchain_cohere.llms import BaseCohere - - -def get_role(message: BaseMessage) -> str: - """Get the role of the message. - - Args: - message: The message. - - Returns: - The role of the message. - - Raises: - ValueError: If the message is of an unknown type. - """ - if isinstance(message, ChatMessage) or isinstance(message, HumanMessage): - return "User" - elif isinstance(message, AIMessage): - return "Chatbot" - elif isinstance(message, SystemMessage): - return "System" - else: - raise ValueError(f"Got unknown type {message}") +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env def get_cohere_chat_request( @@ -83,7 +63,7 @@ def get_cohere_chat_request( req = { "message": messages[-1].content, "chat_history": [ - {"role": get_role(x), "message": x.content} for x in messages[:-1] + {"role": _get_role(x), "message": x.content} for x in messages[:-1] ], "documents": documents, "connectors": maybe_connectors, @@ -94,7 +74,7 @@ def get_cohere_chat_request( return {k: v for k, v in req.items() if v is not None} -class ChatCohere(BaseChatModel, BaseCohere): +class ChatCohere(BaseChatModel, Serializable): """`Cohere` chat large language models. To use, you should have the ``cohere`` python package installed, and the @@ -113,12 +93,49 @@ class ChatCohere(BaseChatModel, BaseCohere): chat.invoke(messages) """ + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: + + model: Optional[str] = Field(default=None) + """Model name to use.""" + + temperature: Optional[float] = None + """A non-negative float that tunes the degree of randomness in generation.""" + + cohere_api_key: Optional[SecretStr] = None + """Cohere API key. If not provided, will be read from the environment variable.""" + + stop: Optional[List[str]] = None + + streaming: bool = Field(default=False) + """Whether to stream the results.""" + + user_agent: str = "langchain" + """Identifier for the application making the request.""" + class Config: """Configuration for this pydantic object.""" allow_population_by_field_name = True arbitrary_types_allowed = True + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validates that the Cohere API key exists in the environment and instantiates the API clients.""" + values["cohere_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY") + ) + client_name = values["user_agent"] + values["client"] = cohere.Client( + api_key=values["cohere_api_key"].get_secret_value(), + client_name=client_name, + ) + values["async_client"] = cohere.AsyncClient( + api_key=values["cohere_api_key"].get_secret_value(), + client_name=client_name, + ) + return values + @property def _llm_type(self) -> str: """Return type of chat model.""" @@ -130,6 +147,7 @@ def _default_params(self) -> Dict[str, Any]: base_params = { "model": self.model, "temperature": self.temperature, + "stop_sequences": self.stop, } return {k: v for k, v in base_params.items() if v is not None} @@ -147,11 +165,7 @@ def _stream( ) -> Iterator[ChatGenerationChunk]: request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - if hasattr(self.client, "chat_stream"): # detect and support sdk v5 - stream = self.client.chat_stream(**request) - else: - stream = self.client.chat(**request, stream=True) - + stream = self.client.chat_stream(**request) for data in stream: if data.event_type == "text-generation": delta = data.text @@ -169,11 +183,7 @@ async def _astream( ) -> AsyncIterator[ChatGenerationChunk]: request = get_cohere_chat_request(messages, **self._default_params, **kwargs) - if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5 - stream = self.async_client.chat_stream(**request) - else: - stream = self.async_client.chat(**request, stream=True) - + stream = self.async_client.chat_stream(**request) async for data in stream: if data.event_type == "text-generation": delta = data.text @@ -247,3 +257,26 @@ async def _agenerate( def get_num_tokens(self, text: str) -> int: """Calculate number of tokens.""" return len(self.client.tokenize(text).tokens) + + +def _get_role(message: BaseMessage) -> str: + """ + Get the Cohere API representation of a role. + + Args: + message: The message. + + Returns: + The role of the message. + + Raises: + ValueError: If the message is of an unknown type. + """ + if isinstance(message, ChatMessage) or isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "CHATBOT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + else: + raise ValueError(f"Got unknown type {message}") diff --git a/libs/partners/cohere/langchain_cohere/llms.py b/libs/partners/cohere/langchain_cohere/llms.py deleted file mode 100644 index 4cf30e42a3b01..0000000000000 --- a/libs/partners/cohere/langchain_cohere/llms.py +++ /dev/null @@ -1,234 +0,0 @@ -from __future__ import annotations - -import logging -import re -from typing import Any, Dict, List, Optional - -import cohere -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) -from langchain_core.language_models.llms import LLM -from langchain_core.load.serializable import Serializable -from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env - -from .utils import _create_retry_decorator - - -def enforce_stop_tokens(text: str, stop: List[str]) -> str: - """Cut off the text as soon as any stop words occur.""" - return re.split("|".join(stop), text, maxsplit=1)[0] - - -logger = logging.getLogger(__name__) - - -def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any: - """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(llm.max_retries) - - @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: - return llm.client.generate(**kwargs) - - return _completion_with_retry(**kwargs) - - -def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any: - """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(llm.max_retries) - - @retry_decorator - async def _completion_with_retry(**kwargs: Any) -> Any: - return await llm.async_client.generate(**kwargs) - - return _completion_with_retry(**kwargs) - - -class BaseCohere(Serializable): - """Base class for Cohere models.""" - - client: Any = None #: :meta private: - async_client: Any = None #: :meta private: - model: Optional[str] = Field(default=None) - """Model name to use.""" - - temperature: Optional[float] = None - """A non-negative float that tunes the degree of randomness in generation.""" - - cohere_api_key: Optional[SecretStr] = None - """Cohere API key. If not provided, will be read from the environment variable.""" - - stop: Optional[List[str]] = None - - streaming: bool = Field(default=False) - """Whether to stream the results.""" - - user_agent: str = "langchain" - """Identifier for the application making the request.""" - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - values["cohere_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY") - ) - client_name = values["user_agent"] - values["client"] = cohere.Client( - api_key=values["cohere_api_key"].get_secret_value(), - client_name=client_name, - ) - values["async_client"] = cohere.AsyncClient( - api_key=values["cohere_api_key"].get_secret_value(), - client_name=client_name, - ) - return values - - -class Cohere(LLM, BaseCohere): - """Cohere large language models. - - To use, you should have the ``cohere`` python package installed, and the - environment variable ``COHERE_API_KEY`` set with your API key, or pass - it as a named parameter to the constructor. - - Example: - .. code-block:: python - - from langchain_cohere import Cohere - - cohere = Cohere(cohere_api_key="my-api-key") - """ - - max_tokens: Optional[int] = None - """Denotes the number of tokens to predict per generation.""" - - k: Optional[int] = None - """Number of most likely tokens to consider at each step.""" - - p: Optional[int] = None - """Total probability mass of tokens to consider at each step.""" - - frequency_penalty: Optional[float] = None - """Penalizes repeated tokens according to frequency. Between 0 and 1.""" - - presence_penalty: Optional[float] = None - """Penalizes repeated tokens. Between 0 and 1.""" - - truncate: Optional[str] = None - """Specify how the client handles inputs longer than the maximum token - length: Truncate from START, END or NONE""" - - max_retries: int = 10 - """Maximum number of retries to make when generating.""" - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - extra = Extra.forbid - - @property - def _default_params(self) -> Dict[str, Any]: - """Configurable parameters for calling Cohere's generate API.""" - base_params = { - "model": self.model, - "temperature": self.temperature, - "max_tokens": self.max_tokens, - "k": self.k, - "p": self.p, - "frequency_penalty": self.frequency_penalty, - "presence_penalty": self.presence_penalty, - "truncate": self.truncate, - } - return {k: v for k, v in base_params.items() if v is not None} - - @property - def lc_secrets(self) -> Dict[str, str]: - return {"cohere_api_key": "COHERE_API_KEY"} - - @property - def _identifying_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" - return self._default_params - - @property - def _llm_type(self) -> str: - """Return type of llm.""" - return "cohere" - - def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict: - params = self._default_params - if self.stop is not None and stop is not None: - raise ValueError("`stop` found in both the input and default params.") - elif self.stop is not None: - params["stop_sequences"] = self.stop - else: - params["stop_sequences"] = stop - return {**params, **kwargs} - - def _process_response(self, response: Any, stop: Optional[List[str]]) -> str: - text = response.generations[0].text - # If stop tokens are provided, Cohere's endpoint returns them. - # In order to make this consistent with other endpoints, we strip them. - if stop: - text = enforce_stop_tokens(text, stop) - return text - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Call out to Cohere's generate endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - .. code-block:: python - - response = cohere("Tell me a joke.") - """ - params = self._invocation_params(stop, **kwargs) - response = completion_with_retry( - self, model=self.model, prompt=prompt, **params - ) - _stop = params.get("stop_sequences") - return self._process_response(response, _stop) - - async def _acall( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Async call out to Cohere's generate endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - .. code-block:: python - - response = await cohere("Tell me a joke.") - """ - params = self._invocation_params(stop, **kwargs) - response = await acompletion_with_retry( - self, model=self.model, prompt=prompt, **params - ) - _stop = params.get("stop_sequences") - return self._process_response(response, _stop) diff --git a/libs/partners/cohere/tests/unit_tests/test_chat_models.py b/libs/partners/cohere/tests/unit_tests/test_chat_models.py index eecfe33f3311a..592dade895e60 100644 --- a/libs/partners/cohere/tests/unit_tests/test_chat_models.py +++ b/libs/partners/cohere/tests/unit_tests/test_chat_models.py @@ -4,6 +4,7 @@ import pytest from langchain_cohere.chat_models import ChatCohere +from langchain_core.pydantic_v1 import SecretStr def test_initialization() -> None: @@ -11,15 +12,26 @@ def test_initialization() -> None: ChatCohere(cohere_api_key="test") +def test_cohere_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that cohere api key is a secret key.""" + # test initialization from init + assert isinstance(ChatCohere(cohere_api_key="1").cohere_api_key, SecretStr) + + # test initialization from env variable + monkeypatch.setenv("COHERE_API_KEY", "secret-api-key") + assert isinstance(ChatCohere().cohere_api_key, SecretStr) + + @pytest.mark.parametrize( "chat_cohere,expected", [ pytest.param(ChatCohere(cohere_api_key="test"), {}, id="defaults"), pytest.param( - ChatCohere(cohere_api_key="test", model="foo", temperature=1.0), + ChatCohere(cohere_api_key="test", model="foo", temperature=1.0, stop=["bar"]), { "model": "foo", "temperature": 1.0, + "stop_sequences": ["bar"], }, id="values are set", ), diff --git a/libs/partners/cohere/tests/unit_tests/test_llms.py b/libs/partners/cohere/tests/unit_tests/test_llms.py deleted file mode 100644 index 44cbe1a9e609d..0000000000000 --- a/libs/partners/cohere/tests/unit_tests/test_llms.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Test Cohere API wrapper.""" -import typing - -import pytest -from langchain_core.pydantic_v1 import SecretStr - -from langchain_cohere.llms import BaseCohere, Cohere - - -def test_cohere_api_key(monkeypatch: pytest.MonkeyPatch) -> None: - """Test that cohere api key is a secret key.""" - # test initialization from init - assert isinstance(BaseCohere(cohere_api_key="1").cohere_api_key, SecretStr) - - # test initialization from env variable - monkeypatch.setenv("COHERE_API_KEY", "secret-api-key") - assert isinstance(BaseCohere().cohere_api_key, SecretStr) - - -@pytest.mark.parametrize( - "cohere,expected", - [ - pytest.param(Cohere(cohere_api_key="test"), {}, id="defaults"), - pytest.param( - Cohere( - # the following are arbitrary testing values which shouldn't be used: - cohere_api_key="test", - model="foo", - temperature=0.1, - max_tokens=2, - k=3, - p=4, - frequency_penalty=0.5, - presence_penalty=0.6, - truncate="START", - ), - { - "model": "foo", - "temperature": 0.1, - "max_tokens": 2, - "k": 3, - "p": 4, - "frequency_penalty": 0.5, - "presence_penalty": 0.6, - "truncate": "START", - }, - id="with values set", - ), - ], -) -def test_default_params(cohere: Cohere, expected: typing.Dict) -> None: - actual = cohere._default_params - assert expected == actual - - -# def test_saving_loading_llm(tmp_path: Path) -> None: -# """Test saving/loading an Cohere LLM.""" -# llm = BaseCohere(max_tokens=10) -# llm.save(file_path=tmp_path / "cohere.yaml") -# loaded_llm = load_llm(tmp_path / "cohere.yaml") -# assert_llm_equality(llm, loaded_llm)