diff --git a/modules/src/langchain_mlrun/item.yaml b/modules/src/langchain_mlrun/item.yaml new file mode 100644 index 00000000..8dcad023 --- /dev/null +++ b/modules/src/langchain_mlrun/item.yaml @@ -0,0 +1,23 @@ +apiVersion: v1 +categories: +- langchain +- langgraph +- tracing +- monitoring +- llm +description: LangChain x MLRun integration - Orchestrate your LangChain code with MLRun. +example: langchain_mlrun.ipynb +generationDate: 2026-01-08:12-25 +hidden: false +labels: + author: Iguazio +mlrunVersion: 1.10.0 +name: langchain_mlrun +spec: + filename: langchain_mlrun.py + image: mlrun/mlrun + kind: generic + requirements: + - langchain + - pydantic-settings +version: 0.0.1 \ No newline at end of file diff --git a/modules/src/langchain_mlrun/langchain_mlrun.ipynb b/modules/src/langchain_mlrun/langchain_mlrun.ipynb new file mode 100644 index 00000000..a1f00541 --- /dev/null +++ b/modules/src/langchain_mlrun/langchain_mlrun.ipynb @@ -0,0 +1,899 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7955da79-02cc-42fe-aee0-5456d3e386fd", + "metadata": {}, + "source": [ + "# LangChain ✕ MLRun Integration\n", + "\n", + "`langchain_mlrun` is a hub module that implements LangChain integration with MLRun. Using the module allows MLRun to orchestrate LangChain and LangGraph code, enabling tracing and monitoring batch workflows and realtime deployments.\n", + "___" + ] + }, + { + "cell_type": "markdown", + "id": "8392a3e1-d0a1-409a-ae68-fcc36858d30a", + "metadata": {}, + "source": [ + "## Main Components\n", + "\n", + "This is a short brief of the components available to import from the `langchain_mlrun` module. For full docs, see the documentation page.\n", + "\n", + "### Settings\n", + "\n", + "The module uses Pydantic settings classes that can be configured programmatically or via environment variables. The main class is `MLRunTracerSettings`. It contains two sub-settings:\n", + "* `MLRunTracerClientSettings` - Connection settings (stream path, container, endpoint info). Env prefix: `\"MLRUN_TRACER_CLIENT_\"`\n", + "* `MLRunTracerMonitorSettings` - Controls what/how runs are captured (filters, labels, debug mode). 
Env prefix: `\"MLRUN_TRACER_MONITOR_\"`\n", + "\n", + "For more information about each setting, see the class docstrings.\n", + "\n", + "#### Example - via code configuration\n", + "\n", + "```python\n", + "from langchain_mlrun import MLRunTracerSettings, MLRunTracerClientSettings, MLRunTracerMonitorSettings\n", + "\n", + "settings = MLRunTracerSettings(\n", + " client=MLRunTracerClientSettings(\n", + " stream_path=\"my-project/model-endpoints/stream-v1\",\n", + " container=\"projects\",\n", + " model_endpoint_name=\"my_endpoint\",\n", + " model_endpoint_uid=\"abc123\",\n", + " serving_function=\"my_function\",\n", + " ),\n", + " monitor=MLRunTracerMonitorSettings(\n", + " label=\"production\",\n", + " root_run_only=True, # Only monitor root runs, not child runs\n", + " tags_filter=[\"important\"], # Only monitor runs with this tag\n", + " ),\n", + ")\n", + "```\n", + "\n", + "#### Example - environment variable configuration\n", + "\n", + "```bash\n", + "export MLRUN_TRACER_CLIENT_STREAM_PATH=\"my-project/model-endpoints/stream-v1\"\n", + "export MLRUN_TRACER_CLIENT_CONTAINER=\"projects\"\n", + "export MLRUN_TRACER_MONITOR_LABEL=\"production\"\n", + "export MLRUN_TRACER_MONITOR_ROOT_RUN_ONLY=\"true\"\n", + "```\n", + "\n", + "### MLRun Tracer\n", + "\n", + "`MLRunTracer` is a LangChain-compatible tracer that converts LangChain `Run` objects into MLRun monitoring events and publishes them to a V3IO stream. \n", + "\n", + "Key points:\n", + "* **No inheritance required** - use it directly without subclassing.\n", + "* **Fully customizable via settings** - control filtering, summarization, and output format.\n", + "* **Custom summarizer support** - pass your own `run_summarizer_function` via settings to customize how runs are converted to events.\n", + "\n", + "### Monitoring Setup Utility Function\n", + "\n", + "`setup_langchain_monitoring()` is a utility function that creates the necessary MLRun infrastructure for LangChain monitoring. This is a **temporary workaround** until custom endpoint creation support is added to MLRun.\n", + "\n", + "The function returns a dictionary of environment variables to configure auto-tracing. 
See how to use it in the tutorial section below.\n", + "\n", + "### LangChain Monitoring Application\n", + "\n", + "`LangChainMonitoringApp` is a base class (inheriting from MLRun's `ModelMonitoringApplicationBase`) for building monitoring applications that process events from the MLRun Tracer.\n", + "\n", + "It offers several built-in helper methods and metrics for analyzing LangChain runs:\n", + "\n", + "* Helper methods:\n", + " * `get_structured_runs()` - Parse raw monitoring samples into structured run dictionaries with filtering options\n", + " * `iterate_structured_runs()` - Iterate over all runs including nested child runs\n", + "* Metric methods:\n", + " * `calculate_average_latency()` - Average latency across root runs\n", + " * `calculate_success_rate()` - Percentage of runs without errors\n", + " * `count_token_usage()` - Total input/output tokens from LLM runs\n", + " * `count_run_names()` - Count occurrences of each run name\n", + "\n", + "The base app can be used as-is, but it is recommended to extend it with your own custom monitoring logic.\n", + "___" + ] + }, + { + "cell_type": "markdown", + "id": "7e24e1a5-d80a-4b7e-9b94-57b24e8b39d7", + "metadata": {}, + "source": [ + "## How to Apply MLRun?\n", + "\n", + "### Auto Tracing\n", + "\n", + "Auto tracing automatically instruments all LangChain code by setting the `MLRUN_MONITORING_ENABLED` environment variable and importing the module:\n", + "\n", + "```python\n", + "import os\n", + "os.environ[\"MLRUN_MONITORING_ENABLED\"] = \"1\"\n", + "# Set other MLRUN_TRACER_* environment variables as needed...\n", + "\n", + "# Import the module BEFORE any LangChain code\n", + "langchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")\n", + "\n", + "# All LangChain/LangGraph code below will be automatically traced\n", + "chain.invoke(...)\n", + "```\n", + "\n", + "### Manual Tracing\n", + "\n", + "For more control, use the `mlrun_monitoring()` context manager to trace specific code blocks:\n", + "\n", + "```python\n", + "from langchain_mlrun import mlrun_monitoring, MLRunTracerSettings\n", + "\n", + "# Optional: customize settings\n", + "settings = MLRunTracerSettings(...)\n", + "\n", + "with mlrun_monitoring(settings=settings) as tracer:\n", + " # Only LangChain code within this block will be traced\n", + " result = chain.invoke({\"topic\": \"MLRun\"})\n", + "```\n", + "___" + ] + }, + { + "cell_type": "markdown", + "id": "68b52d3d-a431-44fb-acd6-ea33fec37a49", + "metadata": {}, + "source": [ + "## Tutorial\n", + "\n", + "In this tutorial we'll show how to orchestrate LangChain based code with MLRun using the `langchain_mlrun` hub module.\n", + "\n", + "### Prerequisites\n", + "\n", + "Install MLRun and the `langchain_mlrun` requirements." 
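If you also plan to run the tutorial against real OpenAI models instead of the built-in mock chat model (optional — see the note in the "Run `langchain` or `langgraph` Code" section below), you will additionally need the `langchain-openai` package and an OpenAI API key in your environment. A minimal sketch:

```bash
pip install langchain-openai
export OPENAI_API_KEY="<your-key>"
```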
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "caf72aa6-06e8-4a04-bfc4-409b39d255fe", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install mlrun langchain pydantic-settings" + ] + }, + { + "cell_type": "markdown", + "id": "8aa18266-d3b5-40bd-a8b9-65345e419d8c", + "metadata": {}, + "source": [ + "### Create Project\n", + "\n", + "We'll first create an MLRun project" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2664df3e-d9c6-40dd-a215-29d60e4b4208", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-01-08 14:48:52,259 [info] Project loaded successfully: {\"project_name\":\"langchain-mlrun-7\"}\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "import datetime\n", + "import mlrun\n", + "\n", + "project = mlrun.get_or_create_project(\"langchain-mlrun-tutorial\")" + ] + }, + { + "cell_type": "markdown", + "id": "33f28986-c158-47fd-97a6-74f69892b4eb", + "metadata": {}, + "source": "### Enable Monitoring\n\nTo use MLRun's monitoring feature in our project we first need to set up the monitoring infrastructure. If you use MLRun CE, you'll need to create a Kafka stream, if you use MLRun enterprise, you can use V3IO." + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d9d2fa66-0498-445d-ab4a-8370f46aec1e", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Add here MLRun CE handler with Kafka, currently the tutorial is only with V3IO.\n", + "from mlrun.datastore import DatastoreProfileV3io\n", + "\n", + "# Create a V3IO data store:\n", + "v3io_ds = DatastoreProfileV3io(name=\"v3io-ds\",v3io_access_key=os.environ[\"V3IO_ACCESS_KEY\"])\n", + "project.register_datastore_profile(profile=v3io_ds)\n", + "\n", + "# Set the monitoring credentials:\n", + "project.set_model_monitoring_credentials(\n", + " stream_profile_name=v3io_ds.name,\n", + " tsdb_profile_name=v3io_ds.name\n", + ")\n", + "\n", + "# Enable monitoring for our project:\n", + "project.enable_model_monitoring(\n", + " base_period=1,\n", + " wait_for_deployment=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f23117fa-7b67-470c-80ca-976d14c2120e", + "metadata": {}, + "source": [ + "### Import `langchain_mlrun`\n", + "\n", + "Now we'll import `langchain_mlrun` from the hub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2360cd49-b260-4140-bd16-138349e000b3", + "metadata": {}, + "outputs": [], + "source": "# Import the module from the hub:\nlangchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")\n\n# Import the utility function and monitoring application from the module:\nsetup_langchain_monitoring = langchain_mlrun.setup_langchain_monitoring\nLangChainMonitoringApp = langchain_mlrun.LangChainMonitoringApp" + }, + { + "cell_type": "markdown", + "id": "de030131-ebaf-48f8-96ed-3c1013b5e260", + "metadata": {}, + "source": "### Create Monitorable Endpoint\n\nEndpoints are the entities being monitored by MLRun. As such we'll use the `setup_langchain_monitoring` utility function to create the model monitoring endpoint. By default, our endpoint name will be `\"langchain_mlrun_endpoint\"` but feel free to change it by using the required arguments." + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0e9baf78-3d38-46bd-89dd-6f83760eaeb0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating LangChain model endpoint\n", + "\n", + " [✓] Loading Project......................... 
Done (0.00s)\u001B[K\n", + " [✓] Creating Model.......................... Done (0.02s) \u001B[K\n", + " [✓] Creating Function....................... Done (0.02s) \u001B[K\n", + " [✓] Creating Model Endpoint................. Done (0.02s) \u001B[K\n", + "\n", + "✨ Done! LangChain monitoring model endpoint created successfully.\n", + "You can now set the following environment variables to enable MLRun tracing in your LangChain code:\n", + "\n", + "{\n", + " \"MLRUN_MONITORING_ENABLED\": \"1\",\n", + " \"MLRUN_TRACER_CLIENT_PROJECT\": \"langchain-mlrun-7\",\n", + " \"MLRUN_TRACER_CLIENT_STREAM_PATH\": \"langchain-mlrun-7/model-endpoints/stream-v1\",\n", + " \"MLRUN_TRACER_CLIENT_CONTAINER\": \"projects\",\n", + " \"MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME\": \"langchain_mlrun_endpoint\",\n", + " \"MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_UID\": \"bb81af2058c14e7cbf58455aed3d69fc\",\n", + " \"MLRUN_TRACER_CLIENT_SERVING_FUNCTION\": \"langchain_mlrun_function\"\n", + "}\n", + "\n", + "To customize the monitoring behavior, you can also set additional environment variables prefixed with 'MLRUN_TRACER_MONITOR_'. Refer to the MLRun tracer documentation for more details.\n", + "\n" + ] + } + ], + "source": [ + "env_vars = setup_langchain_monitoring()" + ] + }, + { + "cell_type": "markdown", + "id": "dd45c94b-ee05-449c-9336-0aa659e66bda", + "metadata": {}, + "source": [ + "### Setup Environment Variables for Auto Tracing\n", + "\n", + "We'll use the environment variables returned from `setup_langchain_monitoring` to setup the environment for auto-tracing. Read the printed outputs for more information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c1988f8-c80a-4bf2-bfb1-d43523fc161f", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ.update(env_vars)" + ] + }, + { + "cell_type": "markdown", + "id": "d3f3b8e5-3538-4153-95da-e6d8776be3ac", + "metadata": {}, + "source": "### Run `langchain` or `langgraph` Code\n\nHere we have 3 functions, each using different method utilizing LLMs with `langchain` and `langgraph`:\n* `run_simple_chain` - Using `langchain`'s chains.\n* `run_simple_agent` - Using `langchain`'s `create_agent` function and `tool`s.\n* `run_langgraph_graph` - Using pure `langgraph`.\n\n> **Notice**: You don't need to set OpenAI API credentials, there is a mock `ChatModel` that will replace it if the credentials are not set in the environment. If you wish to use OpenAI models, make sure you `pip install langchain_openai` and set the `OPENAI_API_KEY` environment variable before continue to the next cell.\n\nBecause the auto-tracing environment is set, any run will be automatically traced and monitored!\n\nFeel free to adjust the code as you like.\n\n> **Remember**: To enable auto-tracing you do need to set the environment variables and import the `langchain_mlrun` module before any LangChain code. For batch jobs and realtime functions, make sure you set env vars in the MLRun function and add the import line `langchain_mlrun = mlrun.import_module(\"hub://langchain_mlrun\")` at the top of your code." 
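For example, a rough sketch of wiring the auto-tracing environment into an MLRun batch function is shown below. The file name `my_langchain_job.py` and the handler name are hypothetical placeholders — adapt them to your own code. The `env_vars` dictionary is the one returned by `setup_langchain_monitoring()` and already contains `MLRUN_MONITORING_ENABLED`.

```python
# Assumes my_langchain_job.py adds the hub import at the very top, before any
# LangChain code runs:
#     import mlrun
#     langchain_mlrun = mlrun.import_module("hub://langchain_mlrun")
batch_fn = project.set_function(
    func="my_langchain_job.py",  # hypothetical handler file
    name="my-langchain-job",
    kind="job",
    image="mlrun/mlrun",
    handler="handler",  # hypothetical handler name
    requirements=["langchain", "pydantic-settings"],
)

# Pass the auto-tracing environment variables to the function:
batch_fn.set_envs(env_vars)

# project.run_function("my-langchain-job")  # LangChain calls inside the handler are traced
```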
+ }, + { + "cell_type": "code", + "execution_count": 7, + "id": "94b4d4b0-8d10-4ad3-8f16-7b1b7daeac11", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "from typing import Literal, TypedDict, Annotated, Sequence, Any, Callable\n", + "from operator import add\n", + "\n", + "from langchain_core.language_models import LanguageModelInput\n", + "from langchain_core.runnables import Runnable, RunnableLambda\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "from langchain_core.language_models.fake_chat_models import FakeListChatModel, GenericFakeChatModel\n", + "from langchain.agents import create_agent\n", + "from langchain_core.messages import AIMessage, HumanMessage\n", + "from langchain_core.tools import tool, BaseTool\n", + "\n", + "from langgraph.graph import StateGraph, START, END\n", + "from langchain_core.messages import BaseMessage\n", + "\n", + "\n", + "def _check_openai_credentials() -> bool:\n", + " \"\"\"\n", + " Check if OpenAI API key is set in environment variables.\n", + "\n", + " :return: True if OPENAI_API_KEY is set, False otherwise.\n", + " \"\"\"\n", + " return \"OPENAI_API_KEY\" in os.environ\n", + "\n", + "\n", + "# Import ChatOpenAI only if OpenAI credentials are available (meaning `langchain-openai` must be installed).\n", + "if _check_openai_credentials():\n", + " from langchain_openai import ChatOpenAI\n", + "\n", + " \n", + "class _ToolEnabledFakeModel(GenericFakeChatModel):\n", + " \"\"\"\n", + " A fake chat model that supports tool binding for running agent tracing tests.\n", + " \"\"\"\n", + "\n", + " def bind_tools(\n", + " self,\n", + " tools: Sequence[\n", + " dict[str, Any] | type | Callable | BaseTool # noqa: UP006\n", + " ],\n", + " *,\n", + " tool_choice: str | None = None,\n", + " **kwargs: Any,\n", + " ) -> Runnable[LanguageModelInput, AIMessage]:\n", + " return self\n", + "\n", + "\n", + "#: Tag value for testing tag filtering.\n", + "_dummy_tag = \"dummy_tag\"\n", + "\n", + "\n", + "def run_simple_chain() -> str:\n", + " \"\"\"\n", + " Run a simple LangChain chain that gets a fact about a topic.\n", + " \"\"\"\n", + " # Build a simple chain: prompt -> llm -> str output parser\n", + " llm = ChatOpenAI(\n", + " model=\"gpt-4o-mini\",\n", + " tags=[_dummy_tag]\n", + " ) if _check_openai_credentials() else (\n", + " FakeListChatModel(\n", + " responses=[\n", + " \"MLRun is an open-source orchestrator for machine learning pipelines.\"\n", + " ],\n", + " tags=[_dummy_tag]\n", + " )\n", + " )\n", + " prompt = ChatPromptTemplate.from_template(\"Tell me a short fact about {topic}\")\n", + " chain = prompt | llm | StrOutputParser()\n", + "\n", + " # Run the chain:\n", + " response = chain.invoke({\"topic\": \"MLRun\"})\n", + " return response\n", + "\n", + "\n", + "def run_simple_agent():\n", + " \"\"\"\n", + " Run a simple LangChain agent that uses two tools to get weather and stock price.\n", + " \"\"\"\n", + " # Define the tools:\n", + " @tool\n", + " def get_weather(city: str) -> str:\n", + " \"\"\"Get the current weather for a specific city.\"\"\"\n", + " return f\"The weather in {city} is 22°C and sunny.\"\n", + "\n", + " @tool\n", + " def get_stock_price(symbol: str) -> str:\n", + " \"\"\"Get the current stock price for a symbol.\"\"\"\n", + " return f\"The stock price for {symbol} is $150.25.\"\n", + "\n", + " # Define the model:\n", + " model = ChatOpenAI(\n", + " model=\"gpt-4o-mini\",\n", + " tags=[_dummy_tag]\n", + " ) if 
_check_openai_credentials() else (\n", + " _ToolEnabledFakeModel(\n", + " messages=iter(\n", + " [\n", + " AIMessage(\n", + " content=\"\",\n", + " tool_calls=[\n", + " {\"name\": \"get_weather\", \"args\": {\"city\": \"London\"}, \"id\": \"call_abc123\"},\n", + " {\"name\": \"get_stock_price\", \"args\": {\"symbol\": \"AAPL\"}, \"id\": \"call_def456\"}\n", + " ]\n", + " ),\n", + " AIMessage(content=\"The weather in London is 22°C and AAPL is trading at $150.25.\")\n", + " ]\n", + " ),\n", + " tags=[_dummy_tag]\n", + " )\n", + " )\n", + "\n", + " # Create the agent:\n", + " agent = create_agent(\n", + " model=model,\n", + " tools=[get_weather, get_stock_price],\n", + " system_prompt=\"You are a helpful assistant with access to tools.\"\n", + " )\n", + "\n", + " # Run the agent:\n", + " return agent.invoke({\"messages\": [\"What is the weather in London and the stock price of AAPL?\"]})\n", + "\n", + "\n", + "def run_langgraph_graph():\n", + " \"\"\"\n", + " Run a LangGraph agent that uses reflection to correct its answer.\n", + " \"\"\"\n", + " # Define the graph state:\n", + " class AgentState(TypedDict):\n", + " messages: Annotated[list[BaseMessage], add]\n", + " attempts: int\n", + "\n", + " # Define the model:\n", + " model = ChatOpenAI(model=\"gpt-4o-mini\") if _check_openai_credentials() else (\n", + " _ToolEnabledFakeModel(\n", + " messages=iter(\n", + " [\n", + " AIMessage(content=\"There are 2 'r's in Strawberry.\"), # Mocking the failure\n", + " AIMessage(content=\"I stand corrected. S-t-r-a-w-b-e-r-r-y. There are 3 'r's.\"), # Mocking the fix\n", + " ]\n", + " )\n", + " )\n", + " )\n", + "\n", + " # Define the graph nodes and router:\n", + " def call_model(state: AgentState):\n", + " response = model.invoke(state[\"messages\"])\n", + " return {\"messages\": [response], \"attempts\": state[\"attempts\"] + 1}\n", + "\n", + " def reflect_node(state: AgentState):\n", + " prompt = \"Wait, count the 'r's again slowly, letter by letter. Are you sure?\"\n", + " return {\"messages\": [HumanMessage(content=prompt)]}\n", + "\n", + " def router(state: AgentState) -> Literal[\"reflect\", END]:\n", + " # Make sure there are 2 attempts at least for an answer:\n", + " if state[\"attempts\"] == 1:\n", + " return \"reflect\"\n", + " return END\n", + "\n", + " # Build the graph:\n", + " builder = StateGraph(AgentState)\n", + " builder.add_node(\"model\", call_model)\n", + " tagged_reflect_node = RunnableLambda(reflect_node).with_config(tags=[_dummy_tag])\n", + " builder.add_node(\"reflect\", tagged_reflect_node)\n", + " builder.add_edge(START, \"model\")\n", + " builder.add_conditional_edges(\"model\", router)\n", + " builder.add_edge(\"reflect\", \"model\")\n", + " graph = builder.compile()\n", + "\n", + " # Run the graph:\n", + " return graph.invoke({\"messages\": [HumanMessage(content=\"How many 'r's in Strawberry?\")], \"attempts\": 0})" + ] + }, + { + "cell_type": "markdown", + "id": "49964f96-89ba-4f61-8788-38290a877aa2", + "metadata": {}, + "source": [ + "Let's create some traffic, we'll run whatever function you want in a loop to get some events. We take timestamps in order to use them later to run the monitoring application on the data we'll send." 
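If you prefer not to rely on the auto-tracing environment variables, the same functions can also be traced manually with the `mlrun_monitoring` context manager shown in the Manual Tracing section above. A hedged sketch, assuming the context manager is exposed as an attribute of the imported hub module like the other components:

```python
# Manual tracing alternative - only LangChain code inside the block is traced:
mlrun_monitoring = langchain_mlrun.mlrun_monitoring

with mlrun_monitoring() as tracer:
    run_simple_chain()
```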
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b7e6418d-76f4-4b18-9ef9-c5bb40b20545", + "metadata": {}, + "outputs": [], + "source": [ + "# Run LangChain code and now it should be tracked and monitored in MLRun:\n", + "start_timestamp = datetime.datetime.now() - datetime.timedelta(minutes=1)\n", + "for i in range(20):\n", + " run_simple_agent()\n", + "end_timestamp = datetime.datetime.now() + datetime.timedelta(minutes=5)" + ] + }, + { + "cell_type": "markdown", + "id": "d9085765-91fd-4d31-84b4-927ecf9cc455", + "metadata": {}, + "source": "> **Note**: Please wait a minute or two until the events are processed." + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85fae3e4-5f1b-4f0c-ba71-81060f10804f", + "metadata": {}, + "outputs": [], + "source": [ + "time.sleep(60)" + ] + }, + { + "cell_type": "markdown", + "id": "2475ebec-fc32-4884-9723-3ca9cfde577f", + "metadata": {}, + "source": [ + "### Test the LangChain Monitoring Application\n", + "\n", + "To test a monitoring application, we use the `evaluate` class method. We'll run an evaluation on the data we just sent. It is a small local job and should run fast.\n", + "\n", + "Keep an eye for the returned metrics from the monitoring application." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "3d046755-9153-497a-a024-5d63316e1f91", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-01-08 14:49:22,970 [info] Changing function name - adding `\"-batch\"` suffix: {\"func_name\":\"testi-batch\"}\n", + "> 2026-01-08 14:49:23,143 [info] Storing function: {\"db\":\"http://mlrun-api:8080\",\"name\":\"testi-batch--handler\",\"uid\":\"43b34f848b6049c0949f04adc1090f10\"}\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
(MLRun run results widget — condensed from the notebook's HTML output)
project=langchain-mlrun-7 | iter=0 | start=Jan 08 14:49:23 | end=NaT | state=completed | kind=run | name=testi-batch--handler
labels: v3io_user=guyl, kind=local, owner=guyl, host=jupyter-guyl-66647f988c-4kjd9
parameters: endpoints=['langchain_mlrun_endpoint'], start=2026-01-08T10:19:47.452879, end=2026-01-08T10:26:28.861851, base_period=None, write_output=False, existing_data_handling=fail_on_overlap, stream_profile=None
results (langchain_mlrun_endpoint-bb81af2058c14e7cbf58455aed3d69fc, 2026-01-08T10:19:47.452879+00:00 to 2026-01-08T10:26:28.861851+00:00): average_latency=1949.3444, success_rate=1.0, total_input_tokens=5480.0, total_output_tokens=1404.0, combined_total_tokens=6884.0, run_name_counts_ChatOpenAI=40.0, run_name_counts_model=40.0, run_name_counts_get_weather=20.0, run_name_counts_tools=40.0, run_name_counts_get_stock_price=20.0, run_name_counts_LangGraph=20.0
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + " > to track results use the .show() or .logs() methods or click here to open in UI" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-01-08 14:49:23,944 [info] Run execution finished: {\"name\":\"testi-batch--handler\",\"status\":\"completed\"}\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LangChainMonitoringApp.evaluate(\n", + " func_name=\"langchain_monitoring_app_test\",\n", + " func_path=\"langchain_mlrun.py\",\n", + " run_local=True,\n", + " endpoints=[env_vars[\"MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME\"]],\n", + " start=start_timestamp.isoformat(),\n", + " end=end_timestamp.isoformat(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "eda724c3-27f3-4d28-a7ba-1e59b9be2a37", + "metadata": {}, + "source": "### Deploy the Monitoring Application\n\nAll that's left to do now is to deploy our monitoring application!" + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "652b00d4-070d-4849-9784-4d461cb83eae", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> 2026-01-08 17:06:50,801 [info] Starting remote function deploy\n", + "2026-01-08 17:06:51 (info) Deploying function\n", + "2026-01-08 17:06:51 (info) Building\n", + "2026-01-08 17:06:52 (info) Staging files and preparing base images\n", + "2026-01-08 17:06:52 (warn) Using user provided base image, runtime interpreter version is provided by the base image\n", + "2026-01-08 17:06:52 (info) Building processor image\n", + "2026-01-08 17:08:52 (info) Build complete\n", + "2026-01-08 17:09:06 (info) Function deploy complete\n", + "> 2026-01-08 17:09:13,972 [info] Model endpoint creation task completed with state succeeded\n", + "> 2026-01-08 17:09:13,973 [info] Successfully deployed function: {\"external_invocation_urls\":[],\"internal_invocation_urls\":[\"nuclio-langchain-mlrun-7-langchain-monitoring-app.default-tenant.svc.cluster.local:8080\"]}\n" + ] + } + ], + "source": [ + "# Deploy the monitoring app:\n", + "LangChainMonitoringApp.deploy(\n", + " func_name=\"langchain_monitoring_app\",\n", + " func_path=\"langchain_mlrun.py\",\n", + " image=\"mlrun/mlrun\",\n", + " requirements=[\n", + " \"langchain\",\n", + " \"pydantic-settings\",\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c23bef7a-cbdb-4b22-a2d9-2edbfde5eb04", + "metadata": {}, + "source": [ + "Once it is deployed, you can run events again and see the monitoring application in MLRun UI in action:\n", + "\n", + "![mlrun ui example](./notebook_images/mlrun_ui.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f799f06-2e62-4e2f-a42f-c94b5fc18623", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mlrun-py311", + "language": "python", + "name": "conda-env-.conda-mlrun-py311-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +} diff --git a/modules/src/langchain_mlrun/langchain_mlrun.py b/modules/src/langchain_mlrun/langchain_mlrun.py new file mode 100644 index 00000000..f7eb5f16 --- /dev/null +++ b/modules/src/langchain_mlrun/langchain_mlrun.py @@ -0,0 +1,1792 @@ +# Copyright 2026 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MLRun to LangChain integration - a tracer that converts LangChain Run objects into serializable event and send them to +MLRun monitoring. +""" + +from abc import ABC, abstractmethod +import copy +import importlib +import orjson +import os +import socket +from uuid import UUID +import threading +from contextlib import contextmanager +from contextvars import ContextVar +import datetime +from typing import Any, Callable, Generator, Optional + +from langchain_core.tracers import BaseTracer, Run +from langchain_core.tracers.context import register_configure_hook + +from pydantic import Field, field_validator, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict +from uuid_utils import uuid7 + +import mlrun +from mlrun.runtimes import RemoteRuntime +from mlrun.model_monitoring.applications import ( + ModelMonitoringApplicationBase, ModelMonitoringApplicationMetric, + ModelMonitoringApplicationResult, MonitoringApplicationContext, +) +import mlrun.common.schemas.model_monitoring.constants as mm_constants + +#: Environment variable name to use MLRun monitoring tracer via LangChain global tracing system: +mlrun_monitoring_env_var = "MLRUN_MONITORING_ENABLED" + + +class _MLRunEndPointClient(ABC): + """ + An MLRun model endpoint monitoring client base class to connect and send events on a monitoring stream. + """ + + def __init__( + self, + model_endpoint_name: str, + model_endpoint_uid: str, + serving_function: str | RemoteRuntime, + serving_function_tag: str | None = None, + project: str | mlrun.projects.MlrunProject = None, + ): + """ + Initialize an MLRun model endpoint monitoring client. + + :param model_endpoint_name: The monitoring endpoint related model name. + :param model_endpoint_uid: Model endpoint unique identifier. + :param serving_function: Serving function name or ``RemoteRuntime`` object. + :param serving_function_tag: Optional function tag (defaults to 'latest'). + :param project: Project name or ``MlrunProject``. If ``None``, uses the current project. + + raise: MLRunInvalidArgumentError: If there is no current active project and no `project` argument was provided. + """ + # Store the provided info: + self._model_endpoint_name = model_endpoint_name + self._model_endpoint_uid = model_endpoint_uid + + # Load project: + if project is None: + try: + self._project_name = mlrun.get_current_project(silent=False).name + except mlrun.errors.MLRunInvalidArgumentError: + raise mlrun.errors.MLRunInvalidArgumentError( + "There is no current active project. Either use `mlrun.get_or_create_project` prior to " + "initializing the monitoring tracer or pass a project name to load. 
You can also set the " + "environment variable: 'MLRUN_MONITORING_PROJECT'." + ) + elif isinstance(project, str): + self._project_name = project + else: + self._project_name = project.name + + # Load function: + if isinstance(serving_function, str): + self._serving_function_name = serving_function + self._serving_function_tag = serving_function_tag or "latest" + else: + self._serving_function_name = serving_function.metadata.name + self._serving_function_tag = ( + serving_function_tag or serving_function.metadata.tag + ) + + # Prepare the sample: + self._event_sample = { + "class": "CustomStream", + "worker": "0", + "model": self._model_endpoint_name, + "host": socket.gethostname(), + "function_uri": f"{self._project_name}/{self._serving_function_name}:{self._serving_function_tag}", + "endpoint_id": self._model_endpoint_uid, + "sampling_percentage": 100, + "request": {"inputs": [], "background_task_state": "succeeded"}, + "op": "infer", + "resp": { + "id": None, + "model_name": self._model_endpoint_name, + "outputs": [], + "timestamp": None, + "model_endpoint_uid": self._model_endpoint_uid, + }, + "when": None, + "microsec": 496, + "effective_sample_count": 1, + } + + @abstractmethod + def monitor( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ): + """ + Monitor the provided event, sending it to the model endpoint monitoring stream. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + """ + pass + + def _create_event( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ) -> dict: + """ + Create a new event out of the stored event sample. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + + :return: The event to send to the monitoring stream. + """ + # Copy the sample: + event = copy.deepcopy(self._event_sample) + + # Edit event with given parameters: + event["when"] = request_timestamp + event["request"]["inputs"].append(orjson.dumps({"label": label, "input": input_data}).decode('utf-8')) + event["resp"]["timestamp"] = response_timestamp + event["resp"]["outputs"].append(orjson.dumps(output_data).decode('utf-8')) + event["resp"]["id"] = event_id + + return event + + +class _V3IOMLRunEndPointClient(_MLRunEndPointClient): + """ + An MLRun model endpoint monitoring client to connect and send events on a V3IO stream. + """ + + def __init__( + self, + monitoring_stream_path: str, + monitoring_container: str, + model_endpoint_name: str, + model_endpoint_uid: str, + serving_function: str | RemoteRuntime, + serving_function_tag: str | None = None, + project: str | mlrun.projects.MlrunProject = None, + ): + """ + Initialize an MLRun model endpoint monitoring client. 
+ + :param monitoring_stream_path: V3IO stream path. + :param monitoring_container: V3IO container name. + :param model_endpoint_name: The monitoring endpoint related model name. + :param model_endpoint_uid: Model endpoint unique identifier. + :param serving_function: Serving function name or ``RemoteRuntime`` object. + :param serving_function_tag: Optional function tag (defaults to 'latest'). + :param project: Project name or ``MlrunProject``. If ``None``, uses the current project. + + raise: MLRunInvalidArgumentError: If there is no current active project and no `project` argument was provided. + """ + super().__init__( + model_endpoint_name=model_endpoint_name, + model_endpoint_uid=model_endpoint_uid, + serving_function=serving_function, + serving_function_tag=serving_function_tag, + project=project, + ) + + import v3io + + # Store the provided info: + self._monitoring_stream_path = monitoring_stream_path + self._monitoring_container = monitoring_container + + # Initialize a V3IO client: + self._v3io_client = v3io.Client() + + def monitor( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ): + """ + Monitor the provided event, sending it to the model endpoint monitoring stream. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + """ + # Copy the sample: + event = self._create_event( + event_id=event_id, + label=label, + input_data=input_data, + output_data=output_data, + request_timestamp=request_timestamp, + response_timestamp=response_timestamp, + ) + + # Push to stream: + self._v3io_client.stream.put_records( + container=self._monitoring_container, + stream_path=self._monitoring_stream_path, + records=[{"data": orjson.dumps(event).decode('utf-8')}], + ) + + +class _KafkaMLRunEndPointClient(_MLRunEndPointClient): + """ + An MLRun model endpoint monitoring client to connect and send events on a Kafka stream. + """ + + def __init__( + self, + monitoring_broker: str, + monitoring_topic: str, + # TODO: Add more Kafka producer options if needed... + model_endpoint_name: str, + model_endpoint_uid: str, + serving_function: str | RemoteRuntime, + serving_function_tag: str | None = None, + project: str | mlrun.projects.MlrunProject = None, + ): + """ + Initialize an MLRun model endpoint monitoring client. + + :param monitoring_broker: Kafka broker name. + :param monitoring_topic: Kafka topic name. + TODO: Add more Kafka producer options if needed... + :param model_endpoint_name: The monitoring endpoint related model name. + :param model_endpoint_uid: Model endpoint unique identifier. + :param serving_function: Serving function name or ``RemoteRuntime`` object. + :param serving_function_tag: Optional function tag (defaults to 'latest'). + :param project: Project name or ``MlrunProject``. If ``None``, uses the current project. + + raise: MLRunInvalidArgumentError: If there is no current active project and no `project` argument was provided. 
+ """ + super().__init__( + model_endpoint_name=model_endpoint_name, + model_endpoint_uid=model_endpoint_uid, + serving_function=serving_function, + serving_function_tag=serving_function_tag, + project=project, + ) + + import kafka + + # Store the provided info: + self._monitoring_broker = monitoring_broker + self._monitoring_topic = monitoring_topic + + # Initialize a Kafka producer: + self._kafka_producer = kafka.KafkaProducer( + ... + ) + + def monitor( + self, + event_id: str, + label: str, + input_data: dict, + output_data: dict, + request_timestamp: str, + response_timestamp: str, + ): + """ + Monitor the provided event, sending it to the model endpoint monitoring stream. + + :param event_id: Unique event identifier used as the monitored record id. + :param label: Label for the run/event. + :param input_data: Serialized input data for the run. + :param output_data: Serialized output data for the run. + :param request_timestamp: Request/start timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + :param response_timestamp: Response/end timestamp in the format of '%Y-%m-%d %H:%M:%S%z'. + """ + # Copy the sample: + event = self._create_event( + event_id=event_id, + label=label, + input_data=input_data, + output_data=output_data, + request_timestamp=request_timestamp, + response_timestamp=response_timestamp, + ) + + # Push to stream: + self._kafka_producer.send( + topic=self._monitoring_topic, + value=orjson.dumps(event), + key=self._model_endpoint_uid, + ) + + +class MLRunTracerClientSettings(BaseSettings): + """ + MLRun tracer monitoring client configurations. These are mandatory arguments for allowing MLRun to send monitoring + events to a specific model endpoint stream. + """ + + v3io_stream_path: str | None = None + """ + The V3IO stream path to send the events to. + """ + + v3io_container: str | None = None + """ + The V3IO stream container. + """ + + kafka_broker: str | None = None + """ + The Kafka broker address. + """ + + kafka_topic: str | None = None + """ + The Kafka topic name. + """ + + # TODO: Add more Kafka producer options if needed... + + model_endpoint_name: str = ... + """ + The model endpoint name. + """ + + model_endpoint_uid: str = ... + """ + The model endpoint UID. + """ + + serving_function: str = ... + """ + The serving function name. + """ + + serving_function_tag: str | None = None + """ + The serving function tag. If not set, it will be 'latest' by default. + """ + + project: str | None = None + """ + The MLRun project name related to the serving function and model endpoint. + """ + + #: Pydantic model configuration to set the environment variable prefix. + model_config = SettingsConfigDict(env_prefix="MLRUN_TRACER_CLIENT_") + + @model_validator(mode='after') + def check_exclusive_sets(self) -> 'MLRunTracerClientSettings': + """ + Validate that either V3IO settings or Kafka settings are provided, but not both or none. + + :return: The validated settings instance. + """ + # Define the sets + v3io_settings = all([self.v3io_container, self.v3io_stream_path]) + kafka_settings = all([self.kafka_topic, self.kafka_broker]) # TODO: Add mandatory other kafka settings + + # Make sure only one set is provided: + if v3io_settings and kafka_settings: + raise ValueError("Provide either V3IO settings OR Kafka settings, not both.") + if not v3io_settings and not kafka_settings: + raise ValueError( + "You must provide either a complete V3IO settings or complete Kafka settings. 
See docs for more " + "information" + ) + + return self + +class MLRunTracerMonitorSettings(BaseSettings): + """ + MLRun tracer monitoring configurations. These are optional arguments to customize the LangChain runs summarization + into monitorable MLRun endpoint events. If needed, a custom summarization can be passed. + """ + + label: str = "default" + """ + Label to use for all monitored runs. Can be used to differentiate between different monitored sources on the same + endpoint. + """ + + tags_filter: list[str] | None = None + """ + Filter runs by tags. Only runs with at least one tag in this list will be monitored. + If None, no tag-based filtering is applied and runs with any tags are considered. + Default: None. + """ + + run_types_filter: list[str] | None = None + """ + Filter runs by run types (e.g. "chain", "llm", "chat", "tool"). + Only runs whose `run_type` appears in this list will be monitored. + If None, no run-type filtering is applied. + Default: None. + """ + + names_filter: list[str] | None = None + """ + Filter runs by class/name. Only runs whose `name` appears in this list will be monitored. + If None, no name-based filtering is applied. + Default: None. + """ + + include_full_run: bool = False + """ + If True, include the complete serialized run dict (the output of `run._get_dicts_safe()`) + in the event outputs under the key `full_run`. Useful for debugging or when consumers need + the raw run payload. Default: False. + """ + + include_errors: bool = True + """ + If True, include run error information in the outputs under the `error` key. + If False, runs that contain an error may be skipped by the summarizer filters. + Default: True. + """ + + include_metadata: bool = True + """ + If True, include run metadata (environment, tool metadata, etc.) in the inputs under + the `metadata` key. Default: True. + """ + + include_latency: bool = True + """ + If True, include latency information in the outputs under the `latency` key. + Default: True. + """ + + root_run_only: bool = False + """ + If True, only the root/top-level run will be monitored and any child runs will be + ignored/removed from monitoring. Use when only the top-level run should produce events. + Default: False. + """ + + split_runs: bool = False + """ + If True, child runs are emitted as separate monitoring events (each run summarized and + sent individually). If False, child runs are nested inside the parent/root run event under + `child_runs`. Default: False. + """ + + run_summarizer_function: ( + str + | Callable[ + [Run, Optional[BaseSettings]], + Generator[tuple[dict, dict] | None, None, None], + ] + | None + ) = None + """ + A function to summarize a `Run` object into a tuple of inputs and outputs. Can be passed directly or via a full + module path ("a.b.c.my_summarizer" will be imported as `from a.b.c import my_summarizer`). + + A summarizer is a function that will be used to process a run into monitoring events. The function is expected to be + of type: + `Callable[[Run, Optional[BaseSettings]], Generator[tuple[dict, dict] | None, None, None]]`, meaning + get a run object and optionally a settings object and return a generator yielding tuples of serialized dictionaries, + the (inputs, outputs) to send to MLRun monitoring as events or `None` to skip monitoring this run. + """ + + run_summarizer_settings: str | BaseSettings | None = None + """ + Settings to pass to the run summarizer function. Can be passed directly or via a full module path to be imported + and initialized. 
If the summarizer function does not require settings, this can be left as None. + """ + + debug: bool = False + """ + If True, disable sending events to MLRun and instead route events to `debug_target_list` + or print them as JSON to stdout. Useful for unit tests and local debugging. Default: False. + """ + + debug_target_list: list[dict] | bool = False + """ + Optional list to which debug events will be appended when `debug` is True. + If set, each generated event dict will be appended to this list. If not set and `debug` is True, + events will be printed to stdout as JSON. Default: False. + """ + + #: Pydantic model configuration to set the environment variable prefix. + model_config = SettingsConfigDict(env_prefix="MLRUN_TRACER_MONITOR_") + + @field_validator('debug_target_list', mode='before') + @classmethod + def convert_bool_to_list(cls, v): + """ + Convert a boolean `True` value to an empty list for `debug_target_list`. + + :param v: The value to validate. + + :returns: An empty list if `v` is True, otherwise the original value. + """ + if v is True: + return [] + return v + + +class MLRunTracerSettings(BaseSettings): + """ + MLRun tracer settings to configure the tracer. The settings are split into two groups: + + * `client`: settings required to connect and send events to the MLRun monitoring stream. + * `monitor`: settings controlling which LangChain runs are summarized and sent and how. + """ + + client: MLRunTracerClientSettings = Field(default_factory=MLRunTracerClientSettings) + """ + Client configuration group (``MLRunTracerClientSettings``). + + Contains the mandatory connection and endpoint information required to publish monitoring + events. Values may be supplied programmatically or via environment variables prefixed with + `MLRUN_TRACER_CLIENT_`. See more at ``MLRunTracerClientSettings``. + """ + + monitor: MLRunTracerMonitorSettings = Field(default_factory=MLRunTracerMonitorSettings) + """ + Monitoring configuration group (``MLRunTracerMonitorSettings``). + + Controls what runs are captured, how they are summarized (including custom summarizer import + options), whether child runs are split or nested, and debug behavior. Values may be supplied + programmatically or via environment variables prefixed with `MLRUN_TRACER_MONITOR_`. + See more at ``MLRunTracerMonitorSettings``. + """ + + #: Pydantic model configuration to set the environment variable prefix. + model_config = SettingsConfigDict(env_prefix="MLRUN_TRACER_") + + +class MLRunTracer(BaseTracer): + """ + MLRun tracer for LangChain runs allowing monitoring LangChain and LangGraph in production using MLRun's monitoring. + + There are two usage modes for the MLRun tracer following LangChain tracing best practices: + + 1. **Manual Mode** - Using the ``mlrun_monitoring`` context manager:: + + from mlrun_tracer import mlrun_monitoring + + with mlrun_monitoring(...) as tracer: + # LangChain code here. + pass + + 2. **Auto Mode** - Setting the `MLRUN_MONITORING_ENABLED="1"` environment variable:: + + import mlrun_integration.tracer + + # All LangChain code will be automatically traced and monitored. + pass + + To control how runs are being summarized into the events being monitored, the ``MLRunTracerSettings`` can be set. + As it is a Pydantic ``BaseSettings`` class, it can be done in two ways: + + 1. 
Initializing the settings classes and passing them to the context manager:: + + from mlrun_tracer import ( + mlrun_monitoring, + MLRunTracerSettings, + MLRunTracerClientSettings, + MLRunTracerMonitorSettings, + ) + + my_settings = MLRunTracerSettings( + client=MLRunTracerClientSettings(), + monitor=MLRunTracerMonitorSettings(root_run_only=True), + ) + + with mlrun_monitoring(settings=my_settings) as tracer: + # LangChain code here. + pass + + 2. Or via environment variables following the prefix 'MLRUN_TRACER_CLIENT_' for client settings and + 'MLRUN_TRACER_MONITOR_' for monitoring settings. + """ + + #: A singleton tracer for when using the tracer via environment variable to activate global tracing. + _singleton_tracer: "MLRunTracer | None" = None + #: A thread lock for initializing the tracer singleton safely. + _lock = threading.Lock() + #: A boolean flag to know whether the singleton was initialized. + _initialized = False + + def __new__(cls, *args, **kwargs) -> "MLRunTracer": + """ + Create or return an ``MLRunTracer`` instance. + + When ``MLRUN_MONITORING_ENABLED`` is not set to ``"1"``, a normal instance is returned. + When the env var is ``"1"``, a process-wide singleton is returned. Creation is thread-safe. + + :returns: MLRunTracer instance (singleton if 'auto' mode is active). + """ + # Check if needed to use a singleton as the user is using the MLRun tracer by setting the environment variable + # and not manually (via context manager): + if not cls._check_for_env_var_usage(): + return super(MLRunTracer, cls).__new__(cls) + + # Check if the singleton is set: + if cls._singleton_tracer is None: + # Acquire lock to initialize the singleton: + with cls._lock: + # Double-check after acquiring lock: + if cls._singleton_tracer is None: + cls._singleton_tracer = super(MLRunTracer, cls).__new__(cls) + + return cls._singleton_tracer + + def __init__(self, settings: MLRunTracerSettings = None, **kwargs): + """ + Initialize the tracer. + + :param settings: Settings to use for the tracer. If not passed, defaults are used and environment variables are + applied per Pydantic settings behavior. + :param kwargs: Passed to the base initializer. + """ + # Proceed with initialization only if singleton mode is not required or the singleton was not initialized: + if self._check_for_env_var_usage() and self._initialized: + return + + # Call the base tracer init: + super().__init__(**kwargs) + + # Set a UID for this instance: + self._uid = uuid7() + + # Set the settings: + self._settings = settings or MLRunTracerSettings() + self._client_settings = self._settings.client + self._monitor_settings = self._settings.monitor + + # Initialize the MLRun endpoint client: + self._mlrun_client = ( + self._get_mlrun_client() + if not self._monitor_settings.debug + else None + ) + + # In case the user passed a custom summarizer, import it: + self._custom_run_summarizer_function: ( + Callable[ + [Run, Optional[BaseSettings]], + Generator[tuple[dict, dict] | None, None, None], + ] + | None + ) = None + self._custom_run_summarizer_settings: BaseSettings | None = None + self._import_custom_run_summarizer() + + # Mark the initialization flag (for the singleton case): + self._initialized = True + + @property + def settings(self) -> MLRunTracerSettings: + """ + Access the effective settings. + + :returns: The settings used by this tracer. 
+ """ + return self._settings + + def _get_mlrun_client(self) -> _MLRunEndPointClient: + """ + Create and return an MLRun model endpoint monitoring client based on the MLRun (CE or not) and current + configuration. + + :returns: An MLRun model endpoint monitoring client. + """ + if mlrun.mlconf.is_ce_mode(): + return _KafkaMLRunEndPointClient( + # TODO: Add more Kafka producer options if needed... + model_endpoint_name=self._client_settings.model_endpoint_name, + model_endpoint_uid=self._client_settings.model_endpoint_uid, + serving_function=self._client_settings.serving_function, + serving_function_tag=self._client_settings.serving_function_tag, + project=self._client_settings.project, + ) + return _V3IOMLRunEndPointClient( + monitoring_stream_path=self._client_settings.v3io_stream_path, + monitoring_container=self._client_settings.v3io_container, + model_endpoint_name=self._client_settings.model_endpoint_name, + model_endpoint_uid=self._client_settings.model_endpoint_uid, + serving_function=self._client_settings.serving_function, + serving_function_tag=self._client_settings.serving_function_tag, + project=self._client_settings.project, + ) + + def _import_custom_run_summarizer(self): + """ + Import or assign a custom run summarizer (and its custom settings) if configured. + """ + # If the user did not pass a run summarizer function, return: + if not self._monitor_settings.run_summarizer_function: + return + + # Check if the function needs to be imported: + if isinstance(self._monitor_settings.run_summarizer_function, str): + self._custom_run_summarizer_function = self._import_from_module_path( + module_path=self._monitor_settings.run_summarizer_function + ) + else: + self._custom_run_summarizer_function = ( + self._monitor_settings.run_summarizer_function + ) + + # Check if the user passed settings as well: + if self._monitor_settings.run_summarizer_settings: + # Check if the settings need to be imported: + if isinstance(self._monitor_settings.run_summarizer_settings, str): + self._custom_run_summarizer_settings = self._import_from_module_path( + module_path=self._monitor_settings.run_summarizer_settings + )() + else: + self._custom_run_summarizer_settings = ( + self._monitor_settings.run_summarizer_settings + ) + + def _persist_run(self, run: Run, level: int = 0) -> None: + """ + Summarize the run (and its children) into MLRun monitoring events. + + Note: This will use the MLRun tracer's default summarization that can be configured via + ``MLRunTracerMonitorSettings``, unless a custom summarizer was provided (via the same settings). + + :param run: LangChain run object to process holding all the nested tree of runs. + :param level: The nesting level of the run (0 for root runs, incremented for child runs). 
+ """ + # Serialize the run: + serialized_run = self._serialize_run( + run=run, + include_child_runs=not (self._settings.monitor.root_run_only or self._settings.monitor.split_runs) + ) + + # Check for a user custom run summarizer function: + if self._custom_run_summarizer_function: + for summarized_run in self._custom_run_summarizer_function( + run, self._custom_run_summarizer_settings + ): + if summarized_run: + inputs, outputs = summarized_run + self._send_run_event( + event_id=serialized_run["id"], + inputs=inputs, + outputs=outputs, + start_time=run.start_time, + end_time=run.end_time, + ) + return + + # Check how to deal with the child runs, monitor them in separate events or as a single event: + if self._monitor_settings.split_runs and not self._settings.monitor.root_run_only: + # Monitor as separate events: + for child_run in run.child_runs: + self._persist_run(run=child_run, level=level + 1) + summarized_run = self._summarize_run(serialized_run=serialized_run, include_children=False) + if summarized_run: + inputs, outputs = summarized_run + inputs["child_level"] = level + self._send_run_event( + event_id=serialized_run["id"], + inputs=inputs, + outputs=outputs, + start_time=run.start_time, + end_time=run.end_time, + ) + return + + # Monitor the root event (include child runs if `root_run_only` is False): + summarized_run = self._summarize_run( + serialized_run=serialized_run, + include_children=not self._monitor_settings.root_run_only + ) + if not summarized_run: + return + inputs, outputs = summarized_run + inputs["child_level"] = level + self._send_run_event( + event_id=serialized_run["id"], + inputs=inputs, + outputs=outputs, + start_time=run.start_time, + end_time=run.end_time, + ) + + + def _serialize_run(self, run: Run, include_child_runs: bool) -> dict: + """ + Serialize a LangChain run into a dictionary. + + :param run: The run to serialize. + :param include_child_runs: Whether to include child runs in the serialization. + + :returns: The serialized run dictionary. + """ + # In LangChain 1.2.3+, the Run model uses Pydantic v2 with child_runs marked as Field(exclude=True), so we + # must manually serialize child runs. Still excluding manually for future compatibility. In previous + # LangChain versions, Run was Pydantic v1, so we use dict. + serialized_run = ( + run.model_dump(exclude={"child_runs"}) + if hasattr(run, "model_dump") + else run.dict(exclude={"child_runs"}) + ) + + # Manually serialize child runs if needed: + if include_child_runs and run.child_runs: + serialized_run["child_runs"] = [ + self._serialize_run(child_run, include_child_runs=True) + for child_run in run.child_runs + ] + + return orjson.loads(orjson.dumps(serialized_run, default=self._serialize_default)) + + def _serialize_default(self, obj: Any): + """ + Default serializer for objects present in LangChain run that are not serializable by default JSON encoder. It + includes handling Pydantic v1 and v2 models, UUIDs, and datetimes. + + :param obj: The object to serialize. + + :returns: The serialized object. + """ + if isinstance(obj, UUID): + return str(obj) + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if hasattr(obj, "model_dump"): + return orjson.loads(orjson.dumps(obj.model_dump(), default=self._serialize_default)) + if hasattr(obj, "dict"): + return orjson.loads(orjson.dumps(obj.dict(), default=self._serialize_default)) + return str(obj) + + def _filter_by_tags(self, serialized_run: dict) -> bool: + """ + Apply tag-based filtering. 
+ + :param serialized_run: Serialized run dictionary. + + :returns: True if the run passes tag filters or if no tag filter is configured. + """ + # Check if the user enabled filtering by tags: + if not self._monitor_settings.tags_filter: + return True + + # Filter the run: + return not set(self._monitor_settings.tags_filter).isdisjoint( + serialized_run["tags"] + ) + + def _filter_by_run_types(self, serialized_run: dict) -> bool: + """ + Apply run-type filtering. + + :param serialized_run: Serialized run dictionary. + + :returns: True if the run's ``run_type`` is allowed or if no run-type filter is configured. + """ + # Check if the user enabled filtering by run types: + if not self._monitor_settings.run_types_filter: + return True + + # Filter the run: + return serialized_run["run_type"] in self._monitor_settings.run_types_filter + + def _filter_by_names(self, serialized_run: dict) -> bool: + """ + Apply class/name filtering. + + :param serialized_run: Serialized run dictionary. + + :returns: True if the run's ``name`` is allowed or if no name filter is configured. + """ + # Check if the user enabled filtering by class names: + if not self._monitor_settings.names_filter: + return True + + # Filter the run: + return serialized_run["name"] in self._monitor_settings.names_filter + + def _get_run_inputs(self, serialized_run: dict) -> dict[str, Any]: + """ + Build the inputs dictionary for a monitoring event. + + :param serialized_run: Serialized run dictionary. + + :returns: A dictionary containing inputs, run metadata and (optionally) additional metadata. + """ + inputs = { + "inputs": serialized_run["inputs"], + "run_type": serialized_run["run_type"], + "run_name": serialized_run["name"], + "tags": serialized_run["tags"], + "run_id": serialized_run["id"], + "start_timestamp": serialized_run["start_time"], + } + if "parent_run_id" in serialized_run: + # Parent run ID is excluded when child runs are joined in the same event. When child runs are split, it is + # included and can be used to reconstruct the run tree if needed. + inputs = {**inputs, "parent_run_id": serialized_run["parent_run_id"]} + if self._monitor_settings.include_metadata and "metadata" in serialized_run: + inputs = {**inputs, "metadata": serialized_run["metadata"]} + + return inputs + + def _get_run_outputs(self, serialized_run: dict) -> dict[str, Any]: + """ + Build the outputs dictionary for a monitoring event. + + :param serialized_run: Serialized run dictionary. + + :returns: A dictionary with outputs and optional other collected info depending on monitor settings. + """ + outputs = {"outputs": serialized_run["outputs"], "end_timestamp": serialized_run["end_time"]} + if self._monitor_settings.include_latency and "latency" in serialized_run: + outputs = {**outputs, "latency": serialized_run["latency"]} + if self._monitor_settings.include_errors: + outputs = {**outputs, "error": serialized_run["error"]} + if self._monitor_settings.include_full_run: + outputs = {**outputs, "full_run": serialized_run} + + return outputs + + def _summarize_run(self, serialized_run: dict, include_children: bool) -> tuple[dict, dict] | None: + """ + Summarize a single run into (inputs, outputs) if it passes filters. + + :param serialized_run: Serialized run dictionary. + :param include_children: Whether to include child runs. + + :returns: The summarized run (inputs, outputs) tuple if the run should be monitored, otherwise ``None``. 
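+
+        For illustration (values are hypothetical), a summarized chain run takes roughly the form::
+
+            inputs = {
+                "inputs": {"topic": "MLRun"},
+                "run_type": "chain",
+                "run_name": "RunnableSequence",
+                "tags": [],
+                "run_id": "...",
+                "start_timestamp": "...",
+            }
+            outputs = {"outputs": {...}, "end_timestamp": "...", "error": None}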
+ """ + # Pass filters: + if not ( + self._filter_by_tags(serialized_run=serialized_run) + and self._filter_by_run_types(serialized_run=serialized_run) + and self._filter_by_names(serialized_run=serialized_run) + ): + return None + + # Check if needed to include errors: + if serialized_run["error"] and not self._monitor_settings.include_errors: + return None + + # Prepare the inputs and outputs: + inputs = self._get_run_inputs(serialized_run=serialized_run) + outputs = self._get_run_outputs(serialized_run=serialized_run) + + # Check if needed to include child runs: + if include_children: + outputs["child_runs"] = [] + for child_run in serialized_run.get("child_runs", []): + # Recursively summarize the child run: + summarized_child_run = self._summarize_run(serialized_run=child_run, include_children=True) + if summarized_child_run: + inputs_child, outputs_child = summarized_child_run + outputs["child_runs"].append( + { + "input_data": inputs_child, + "output_data": outputs_child, + } + ) + + return inputs, outputs + + def _send_run_event( + self, event_id: str, inputs: dict, outputs: dict, start_time: datetime.datetime, end_time: datetime.datetime + ): + """ + Send a monitoring event for a single run. + + Note: If monitor debug mode is enabled, appends to ``debug_target_list`` or prints JSON. + + :param event_id: Unique event identifier. + :param inputs: Inputs dictionary for the event. + :param outputs: Outputs dictionary for the event. + :param start_time: Request/start timestamp. + :param end_time: Response/end timestamp. + """ + event = { + "event_id": event_id, + "label": self._monitor_settings.label, + "input_data": {"input_data": inputs}, # So it will be a single "input feature" in MLRun monitoring. + "output_data": {"output_data": outputs}, # So it will be a single "output feature" in MLRun monitoring. + "request_timestamp": start_time.strftime("%Y-%m-%d %H:%M:%S%z"), + "response_timestamp": end_time.strftime("%Y-%m-%d %H:%M:%S%z"), + } + if self._monitor_settings.debug: + if isinstance(self._monitor_settings.debug_target_list, list): + self._monitor_settings.debug_target_list.append(event) + else: + print(orjson.dumps(event, option=orjson.OPT_INDENT_2 | orjson.OPT_APPEND_NEWLINE)) + return + + self._mlrun_client.monitor(**event) + + @staticmethod + def _check_for_env_var_usage() -> bool: + """ + Check whether global env-var activated tracing is requested. + + :returns: True when ``MLRUN_MONITORING_ENABLED`` environment variable equals ``"1"``. + """ + return os.environ.get(mlrun_monitoring_env_var, "0") == "1" + + @staticmethod + def _import_from_module_path(module_path: str) -> Any: + """ + Import an object from a full module path string. + + :param module_path: Full dotted path, e.g. ``a.b.module.object``. + + :returns: The imported object. + + raise: ValueError: If ``module_path`` is not a valid Python module path. + raise: ImportError: If module cannot be imported. + raise: AttributeError: If the object name is not found in the module. + """ + try: + module_name, object_name = module_path.rsplit(".", 1) + module = importlib.import_module(module_name) + obj = getattr(module, object_name) + except ValueError as value_error: + raise ValueError( + f"The provided '{module_path}' is not valid: it must have at least one '.'. " + f"If the class is locally defined, please add '__main__.MyObject' to the path." + ) from value_error + except ImportError as import_error: + raise ImportError( + f"Could not import '{module_path}'. 
Tried to import '{module_name}' and failed with the following "
+                f"error: {import_error}."
+            ) from import_error
+        except AttributeError as attribute_error:
+            raise AttributeError(
+                f"Could not import '{object_name}'. Tried to run 'from {module_name} import {object_name}' and could "
+                f"not find it: {attribute_error}"
+            ) from attribute_error
+
+        return obj
+
+
+#: MLRun monitoring context variable to set when the user wraps their code with `mlrun_monitoring`. From this context
+# variable LangChain will get the tracer in a thread-safe way.
+mlrun_monitoring_var: ContextVar[MLRunTracer | None] = ContextVar(
+    "mlrun_monitoring", default=None
+)
+
+
+@contextmanager
+def mlrun_monitoring(settings: MLRunTracerSettings | None = None):
+    """
+    Context manager that enables MLRun tracing for LangChain code, monitoring the LangChain runs executed within it.
+
+    Example usage::
+
+        from langchain_mlrun import mlrun_monitoring, MLRunTracerSettings
+
+        settings = MLRunTracerSettings(...)
+        with mlrun_monitoring(settings=settings) as tracer:
+            # LangChain execution within this block will be traced by `tracer`.
+            ...
+
+    :param settings: The settings to use to configure the tracer.
+    """
+    mlrun_tracer = MLRunTracer(settings=settings)
+    token = mlrun_monitoring_var.set(mlrun_tracer)
+    try:
+        yield mlrun_tracer
+    finally:
+        mlrun_monitoring_var.reset(token)
+
+
+# Register a hook for LangChain to apply the MLRun tracer:
+register_configure_hook(
+    context_var=mlrun_monitoring_var,
+    inheritable=True,  # To allow inner runs (an agent that uses a tool that uses an LLM...) to be traced.
+    env_var=mlrun_monitoring_env_var,
+    handle_class=MLRunTracer,
+)
+
+
+# Temporary convenience function to set up the monitoring infrastructure required for the tracer.
+def setup_langchain_monitoring(
+    project: str | mlrun.MlrunProject = None,
+    function_name: str = "langchain_mlrun_function",
+    model_name: str = "langchain_mlrun_model",
+    model_endpoint_name: str = "langchain_mlrun_endpoint",
+    v3io_container: str = "projects",
+    v3io_stream_path: str = None,
+    # TODO: Add Kafka parameters when Kafka monitoring is supported.
+) -> dict:
+    """
+    Creates a model endpoint in the given project to be used for LangChain monitoring with MLRun and returns the
+    necessary environment variables to configure the MLRun tracer client. The project should already exist and have
+    monitoring enabled::
+
+        project.set_model_monitoring_credentials(
+            stream_profile_name=...,
+            tsdb_profile_name=...
+        )
+
+    This function creates and logs a dummy model and function in the specified project in order to create the model
+    endpoint for monitoring. It is a temporary workaround; custom endpoint creation will be added as a feature in a
+    future MLRun version.
+
+    :param project: The MLRun project name or object where to create the model endpoint. If None, the current active
+                    project will be used.
+    :param function_name: The name of the serving function to create.
+    :param model_name: The name of the model to create.
+    :param model_endpoint_name: The name of the model endpoint to create.
+    :param v3io_container: The V3IO container where the monitoring stream is located.
+    :param v3io_stream_path: The V3IO stream path for monitoring. If None,
+                             ``<project name>/model-endpoints/stream-v1`` will be used.
+                             TODO: Add Kafka parameters when Kafka monitoring is supported.
+
+    :returns: A dictionary with the necessary environment variables to configure the MLRun tracer client.
+
+    :raise MLRunInvalidArgumentError: If no project is provided and there is no current active project.
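+
+    Example (illustrative; assumes an existing MLRun project with monitoring enabled, and the project name is
+    hypothetical)::
+
+        import os
+        import mlrun
+        from langchain_mlrun import setup_langchain_monitoring
+
+        project = mlrun.get_or_create_project("my-project")
+        env_vars = setup_langchain_monitoring(project=project)
+        os.environ.update(env_vars)  # Enables env-var activated auto-tracing for LangChain code in this process.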
+ """ + import io + import time + import sys + from contextlib import redirect_stdout, redirect_stderr + import tempfile + import pickle + import json + + from mlrun.common.helpers import parse_versioned_object_uri + from mlrun.features import Feature + + class ProgressStep: + """ + A context manager to display progress of a code block with timing and optional output suppression. + """ + + def __init__(self, label: str, indent: int = 2, width: int = 40, clean: bool = True): + """ + Initialize the ProgressStep context manager. + + :param label: The label to display for the progress step. + :param indent: The number of spaces to indent the label. + :param width: The width to pad the label for alignment. + :param clean: Whether to suppress stdout and stderr during the block execution. + """ + # Store parameters: + self._label = label + self._indent = indent + self._width = width + self._clean = clean + + # Internal state: + self._start_time = None + self._sink = io.StringIO() + self._stdout_redirect = None + self._stderr_redirect = None + self._last_line_length = 0 # To track the line printed when terminals don't support '\033[K'. + + # Capture the stream currently in use (before and if clean is true and we redirect it): + self._terminal = sys.stdout + + def __enter__(self): + """ + Enter the context manager, starting the timer and printing the initial status. + """ + # Start timer: + self._start_time = time.perf_counter() + + # Print without newline (using \r to allow overwriting): + self._write(icon=" ", status="Running", new_line=False) + + # Silence all internal noise: + if self._clean: + self._stdout_redirect = redirect_stdout(self._sink) + self._stderr_redirect = redirect_stderr(self._sink) + self._stdout_redirect.__enter__() + self._stderr_redirect.__enter__() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit the context manager, stopping the timer and printing the final status. + + :param exc_type: The exception type, if any. + :param exc_val: The exception value, if any. + :param exc_tb: The exception traceback, if any. + """ + # Restore stdout/stderr: + if self._clean: + self._stdout_redirect.__exit__(exc_type, exc_val, exc_tb) + self._stderr_redirect.__exit__(exc_type, exc_val, exc_tb) + + # Calculate elapsed time: + elapsed = time.perf_counter() - self._start_time + + # Move cursor back to start of line ('\r') and overwrite ('\033[K' clears the line to the right): + if exc_type is None: + self._write(icon="✓", status=f"Done ({elapsed:.2f}s)", new_line=True) + else: + self._write(icon="✕", status="Failed", new_line=True) + + def update(self, status: str): + """ + Update the status message displayed for the progress step. + + :param status: The new status message to display. + """ + self._write(icon=" ", status=status, new_line=False) + + def _write(self, icon: str, status: str, new_line: bool): + """ + Write the progress line to the terminal, handling line clearing for terminals that do not support it. + + :param icon: The icon to display (e.g., checkmark, cross, space). + :param status: The status message to display. + :param new_line: Whether to end the line with a newline character. 
+ """ + # Construct the basic line + line = f"\r{' ' * self._indent}[{icon}] {self._label.ljust(self._width, '.')} {status}" + + # Calculate if we need to pad with spaces to clear the old, longer line: + padding = max(0, self._last_line_length - len(line)) + + # Add spaces to clear old text (add the ANSI clear for terminals that support it): + line = f"{line}{' ' * padding}\033[K" + + # Add newline if needed: + if new_line: + line += "\n" + + # Write to terminal: + self._terminal.write(line) + self._terminal.flush() + + # Update the max length seen so far: + self._last_line_length = len(line) + + print("Creating LangChain model endpoint\n") + + # Get the project: + with ProgressStep("Loading Project"): + if project is None: + try: + project = mlrun.get_current_project(silent=False) + except mlrun.errors.MLRunInvalidArgumentError: + raise mlrun.errors.MLRunInvalidArgumentError( + "There is no current active project. Either use `mlrun.get_or_create_project` prior to " + "creating the monitoring endpoint or pass a project name to load." + ) + if isinstance(project, str): + project = mlrun.load_project(name=project) + + # Create and log the dummy model: + with ProgressStep(f"Creating Model") as progress_step: + # Check if the model already exists: + progress_step.update("Checking if model exists") + try: + dummy_model = project.get_artifact(key=model_name) + except mlrun.MLRunNotFoundError: + dummy_model = None + # If not, create and log it: + if not dummy_model: + progress_step.update(f"Logging model '{model_name}'") + with tempfile.TemporaryDirectory() as tmpdir: + # Create a dummy model file: + dummy_model_path = os.path.join(tmpdir, "for_langchain_mlrun_tracer.pkl") + with open(dummy_model_path, "wb") as f: + pickle.dump({"dummy": "model"}, f) + # Log the model: + dummy_model = project.log_model( + key=model_name, + model_file=dummy_model_path, + inputs=[Feature(value_type="str", name="input")], + outputs=[Feature(value_type='str', name="output")] + ) + + # Create and set the dummy function: + with ProgressStep("Creating Function") as progress_step: + # Check if the function already exists: + progress_step.update("Checking if function exists") + try: + dummy_function = project.get_function(key=function_name) + except mlrun.MLRunNotFoundError: + dummy_function = None + # If not, create and save it: + if not dummy_function: + progress_step.update(f"Setting function '{function_name}'") + with tempfile.TemporaryDirectory() as tmpdir: + # Create a dummy function file: + dummy_function_code = """ +def handler(context, event): + return "ok" +""" + dummy_function_path = os.path.join(tmpdir, "dummy_function.py") + with open(dummy_function_path, "w") as f: + f.write(dummy_function_code) + # Set the function in the project: + dummy_function = project.set_function( + func=dummy_function_path, name=function_name, image="mlrun/mlrun", kind="nuclio" + ) + dummy_function.save() + + # Create the model endpoint: + with ProgressStep("Creating Model Endpoint") as progress_step: + # Get the MLRun DB: + progress_step.update("Getting MLRun DB") + db = mlrun.get_run_db() + # Check if the model endpoint already exists: + progress_step.update("Checking if endpoint exists") + model_endpoint = project.list_model_endpoints(names=[model_endpoint_name]).endpoints + if model_endpoint: + model_endpoint = model_endpoint[0] + else: + progress_step.update("Creating model endpoint") + model_endpoint = mlrun.common.schemas.ModelEndpoint( + metadata=mlrun.common.schemas.ModelEndpointMetadata( + project=project.name, + 
name=model_endpoint_name, + endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.NODE_EP, + ), + spec=mlrun.common.schemas.ModelEndpointSpec( + function_name=dummy_function.metadata.name, + function_tag="latest", + model_path=dummy_model.uri, + model_class="CustomStream", + ), + status=mlrun.common.schemas.ModelEndpointStatus( + monitoring_mode=mm_constants.ModelMonitoringMode.enabled, + ), + ) + db.create_model_endpoint(model_endpoint=model_endpoint) + # Wait for the model endpoint UID to be set: + progress_step.update("Waiting for model endpoint") + uid_exist_flag = False + while not uid_exist_flag: + model_endpoint = project.list_model_endpoints(names=[model_endpoint_name]) + model_endpoint = model_endpoint.endpoints[0] + if model_endpoint.metadata.uid: + uid_exist_flag = True + + # Set parameters defaults: + v3io_stream_path = v3io_stream_path or f"{project.name}/model-endpoints/stream-v1" + # TODO: Support Kafka monitoring parameters defaults when Kafka monitoring is supported. + + if mlrun.mlconf.is_ce_mode(): + client_env_vars = { + "MLRUN_TRACER_CLIENT_KAFKA_...": ... + } + else: + client_env_vars = { + "MLRUN_TRACER_CLIENT_V3IO_STREAM_PATH": v3io_stream_path, + "MLRUN_TRACER_CLIENT_V3IO_CONTAINER": v3io_container, + } + + # Prepare the environment variables: + env_vars = { + "MLRUN_MONITORING_ENABLED": "1", + "MLRUN_TRACER_CLIENT_PROJECT": project.name, + "MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME": model_endpoint.metadata.name, + "MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_UID": model_endpoint.metadata.uid, + "MLRUN_TRACER_CLIENT_SERVING_FUNCTION": function_name, + **client_env_vars + } + print("\n✨ Done! LangChain monitoring model endpoint created successfully.") + print("You can now set the following environment variables to enable MLRun tracing in your LangChain code:\n") + print(json.dumps(env_vars, indent=4)) + print( + "\nTo customize the monitoring behavior, you can also set additional environment variables prefixed with " + "'MLRUN_TRACER_MONITOR_'. Refer to the MLRun tracer documentation for more details.\n" + ) + + return env_vars + + +class LangChainMonitoringApp(ModelMonitoringApplicationBase): + """ + A base monitoring application for LangChain that calculates common metrics on LangChain runs traced with the MLRun + tracer. + + The class is inheritable and can be extended to add custom metrics or override existing ones. It provides methods to + extract structured runs from the monitoring context and calculate metrics such as average latency, success rate, + token usage, and run name counts. + + If inheriting, the main method to override is `do_tracking`, which performs the tracking on the monitoring context. + """ + + def do_tracking(self, monitoring_context: MonitoringApplicationContext) -> ( + ModelMonitoringApplicationResult | + list[ModelMonitoringApplicationResult | ModelMonitoringApplicationMetric] | + dict[str, Any] + ): + """ + The main function that performs tracking on the monitoring context. The LangChain monitoring app by default + will calculate all the provided metrics on the structured runs extracted from the monitoring context sample + dataframe. + + :param monitoring_context: The monitoring context containing the sample dataframe. + + :returns: The monitoring artifacts, metrics and results. 
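+
+        For illustration, a subclass could override this method to report a single custom metric (the class name
+        below is hypothetical)::
+
+            class MyLangChainApp(LangChainMonitoringApp):
+                def do_tracking(self, monitoring_context):
+                    structured_runs, _ = self.get_structured_runs(monitoring_context=monitoring_context)
+                    return [
+                        ModelMonitoringApplicationMetric(
+                            name="average_latency",
+                            value=self.calculate_average_latency(structured_runs=structured_runs),
+                        ),
+                    ]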
+ """ + # Get the structured runs from the monitoring context: + structured_runs, _ = self.get_structured_runs(monitoring_context=monitoring_context) + + # Calculate the metrics: + average_latency = self.calculate_average_latency(structured_runs=structured_runs) + success_rate = self.calculate_success_rate(structured_runs=structured_runs) + token_usage = self.count_token_usage(structured_runs=structured_runs) + run_name_counts = self.count_run_names(structured_runs=structured_runs) + + return [ + ModelMonitoringApplicationMetric( + name="average_latency", + value=average_latency, + ), + ModelMonitoringApplicationMetric( + name="success_rate", + value=success_rate, + ), + ModelMonitoringApplicationMetric( + name="total_input_tokens", + value=token_usage["total_input_tokens"], + ), + ModelMonitoringApplicationMetric( + name="total_output_tokens", + value=token_usage["total_output_tokens"], + ), + ModelMonitoringApplicationMetric( + name="combined_total_tokens", + value=token_usage["combined_total"], + ), + *[ModelMonitoringApplicationMetric( + name=f"run_name_counts_{run_name}", + value=count, + ) for run_name, count in run_name_counts.items()], + ] + + @staticmethod + def get_structured_runs( + monitoring_context: MonitoringApplicationContext, + labels_filter: list[str] = None, + tags_filter: list[str] = None, + run_name_filter: list[str] = None, + run_type_filter: list[str] = None, + flatten_child_runs: bool = False, + ignore_child_runs: bool = False, + ignore_errored_runs: bool = False, + ) -> tuple[list[dict], list[dict]]: + """ + Get the structured runs from the monitoring context sample dataframe. The sample dataframe contains the raw + input and output data as JSON strings - the way the MLRun tracer sends them as events to MLRun monitoring. This + function parses the JSON strings into structured dictionaries that can be used for further metrics calculations + and analysis. + + :param monitoring_context: The monitoring context containing the sample dataframe. + :param labels_filter: List of labels to filter the runs. Only runs with a label appearing in this list will + remain. If None, no filtering is applied. + :param tags_filter: List of tags to filter the runs. Only runs containing at least one tag from this list will + remain. If None, no filtering is applied. + :param run_name_filter: List of run names to filter the runs. Only runs with a name appearing in this list will + remain. If None, no filtering is applied. + :param run_type_filter: List of run types to filter the runs. Only runs with a type appearing in this list will + remain. If None, no filtering is applied. + :param flatten_child_runs: Whether to flatten child runs into the main runs list. If True, all child runs will + be extracted and added to the main runs list. If False, child runs will be kept nested within their parent + runs. + :param ignore_child_runs: Whether to ignore child runs completely. If True, child runs will be removed from the + output. If False, child runs will be processed according to the other parameters. + :param ignore_errored_runs: Whether to ignore runs that resulted in errors. If True, runs with errors will be + excluded from the output. If False, errored runs will be included. + + :returns: A list of structured run dictionaries that passed the filters and a list of samples that could not be + parsed due to errors. 
+ """ + # Retrieve the input and output samples from the monitoring context: + samples = monitoring_context.sample_df[['input', 'output']].to_dict('records') + + # Prepare to collect structured samples: + structured_samples = [] + errored_samples = [] + + # Go over all samples: + for sample in samples: + try: + # Parse the input data into structured format: + parsed_input = orjson.loads(sample['input']) + label = parsed_input['label'] + parsed_input = parsed_input["input"]["input_data"] + # Parse the output data into structured format: + parsed_output = orjson.loads(sample['output'])["output_data"] + structured_samples.extend( + LangChainMonitoringApp._collect_run( + structured_input=parsed_input, + structured_output=parsed_output, + label=label, + labels_filter=labels_filter, + tags_filter=tags_filter, + run_name_filter=run_name_filter, + run_type_filter=run_type_filter, + flatten_child_runs=flatten_child_runs, + ignore_child_runs=ignore_child_runs, + ignore_errored_runs=ignore_errored_runs, + ) + ) + except Exception: + errored_samples.append(sample) + + return structured_samples, errored_samples + + @staticmethod + def _collect_run( + structured_input: dict, + structured_output: dict, + label: str, + child_level: int = 0, + labels_filter: list[str] = None, + tags_filter: list[str] = None, + run_name_filter: list[str] = None, + run_type_filter: list[str] = None, + flatten_child_runs: bool = False, + ignore_child_runs: bool = False, + ignore_errored_runs: bool = False, + ) -> list[dict]: + """ + Recursively collect runs from the structured input and output data, applying filters as specified. + + :param structured_input: The structured input data of the run. + :param structured_output: The structured output data of the run. + :param label: The label of the run. + :param child_level: The current child level of the run (0 for root runs). + :param labels_filter: Label filter as described in `get_structured_runs`. + :param tags_filter: Tag filter as described in `get_structured_runs`. + :param run_name_filter: Run name filter as described in `get_structured_runs`. + :param run_type_filter: Run type filter as described in `get_structured_runs`. + :param flatten_child_runs: Flag to flatten child runs as described in `get_structured_runs`. + :param ignore_child_runs: Flag to ignore child runs as described in `get_structured_runs`. + :param ignore_errored_runs: Flag to ignore errored runs as described in `get_structured_runs`. + + :returns: A list of structured run dictionaries that passed the filters. 
+ """ + # Prepare to collect runs: + runs = [] + + # Filter by label: + if labels_filter and label not in labels_filter: + return runs + + # Handle child runs: + if "child_runs" in structured_output: + # Check if we need to ignore or flatten child runs: + if ignore_child_runs: + structured_output.pop("child_runs") + elif flatten_child_runs: + # Recursively collect child runs: + child_runs = structured_output.pop("child_runs") + flattened_runs = [] + for child_run in child_runs: + flattened_runs.extend( + LangChainMonitoringApp._collect_run( + structured_input=child_run["input_data"], + structured_output=child_run["output_data"], + label=label, + child_level=child_level + 1, + tags_filter=tags_filter, + run_name_filter=run_name_filter, + run_type_filter=run_type_filter, + flatten_child_runs=flatten_child_runs, + ignore_child_runs=ignore_child_runs, + ignore_errored_runs=ignore_errored_runs, + ) + ) + runs.extend(flattened_runs) + + # Filter by tags, run name, run type, and errors: + if tags_filter and not set(structured_input["tags"]).isdisjoint(tags_filter): + return runs + if run_name_filter and structured_input["run_name"] not in run_name_filter: + return runs + if run_type_filter and structured_input["run_type"] not in run_type_filter: + return runs + if ignore_errored_runs and structured_output.get("error", None): + return runs + + # Collect the current run: + runs.append({"label": label, "input_data": structured_input, "output_data": structured_output, + "child_level": child_level}) + return runs + + @staticmethod + def iterate_structured_runs(structured_runs: list[dict]) -> Generator[dict, None, None]: + """ + Iterates over all runs in the structured samples, including child runs. + + :param structured_runs: List of structured run samples. + + :returns: A generator yielding each run structure. + """ + # TODO: Add an option to stop at a certain child level. + for structured_run in structured_runs: + if "child_runs" in structured_run['output_data']: + for child_run in structured_run['output_data']['child_runs']: + yield from LangChainMonitoringApp.iterate_structured_runs([{ + "label": structured_run['label'], + "input_data": child_run['input_data'], + "output_data": child_run['output_data'], + "child_level": structured_run['child_level'] + 1 + }]) + yield structured_run + + @staticmethod + def count_run_names(structured_runs: list[dict]) -> dict[str, int]: + """ + Counts occurrences of each run name in the structured samples. + + :param structured_runs: List of structured run samples. + + :returns: A dictionary with run names as keys and their counts as values. + """ + # TODO: Add a nice plot artifact that will draw the bar chart for what is being used the most. + # Prepare to count run names: + run_name_counts = {} + + # Go over all the runs: + for structured_run in LangChainMonitoringApp.iterate_structured_runs(structured_runs): + run_name = structured_run['input_data']['run_name'] + if run_name in run_name_counts: + run_name_counts[run_name] += 1 + else: + run_name_counts[run_name] = 1 + + return run_name_counts + + @staticmethod + def count_token_usage(structured_runs: list[dict]) -> dict: + """ + Calculates total tokens by only counting unique 'llm' type runs. + + :param structured_runs: List of structured run samples. + + :returns: A dictionary with total input tokens, total output tokens, and combined total tokens. + """ + # TODO: Add a token count per model breakdown (a dictionary of : to token counts) + # including an artifact that will plot it nicely. 
Pay attention that different providers use different + # keys in the response metadata. We should implement a mapping for that so each provider will have its own + # handler that will know how to extract the relevant info out of a run. + # Prepare to count tokens: + total_input_tokens = 0 + total_output_tokens = 0 + + # Go over all the LLM typed runs: + for structured_run in LangChainMonitoringApp.iterate_structured_runs(structured_runs): + # Count only LLM type runs as chain runs may include duplicative information as they accumulate the tokens + # from the child runs: + if structured_run['input_data']['run_type'] != 'llm': + continue + # Look for the token count information: + outputs = structured_run['output_data']["outputs"] + # Newer implementations should have the metadata in the `AIMessage` kwargs under generations: + if "generations" in outputs: + for generation in outputs["generations"]: # Iterate over generations. + for sample in generation: # Iterate over the generation batch. + token_usage = sample.get("message", {}).get("kwargs", {}).get("usage_metadata", {}) + if token_usage: + total_input_tokens += ( + token_usage.get('input_tokens', 0) + or token_usage.get('prompt_tokens', 0) + ) + total_output_tokens += ( + token_usage.get('output_tokens', 0) or + token_usage.get('completion_tokens', 0) + ) + continue + # Older implementations may have the metadata under `llm_output`: + if "llm_output" in outputs: + token_usage = outputs["llm_output"].get("token_usage", {}) + if token_usage: + total_input_tokens += token_usage.get('input_tokens', 0) or token_usage.get('prompt_tokens', 0) + total_output_tokens += ( + token_usage.get('output_tokens', 0) or + token_usage.get('completion_tokens', 0) + ) + + return { + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "combined_total": total_input_tokens + total_output_tokens + } + + @staticmethod + def calculate_success_rate(structured_runs: list[dict]) -> float: + """ + Calculates the success rate across all runs. + + :param structured_runs: List of structured run samples. + + :returns: Success rate as a float percentage between 0 and 1. + """ + # TODO: Add an option to see errors breakdown by kind of error and maybe an option to show which run name yielded + # most of the errors with artifacts showcasing it. + successful_count = 0 + for structured_run in structured_runs: + if 'error' not in structured_run['output_data'] or structured_run['output_data']['error'] is None: + successful_count += 1 + return successful_count / len(structured_runs) if structured_runs else 0.0 + + @staticmethod + def calculate_average_latency(structured_runs: list[dict]) -> float: + """ + Calculates the average latency across all runs. + + :param structured_runs: List of structured run samples. + + :returns: Average latency in milliseconds. + """ + # TODO: Add an option to calculate latency per run name (to know which runs are slower/faster) and then return an + # artifact showcasing it. 
+ # Prepare to calculate average latency: + total_latency = 0.0 + count = 0 + + # Go over all the root runs: + for structured_run in structured_runs: + # Skip child runs: + if structured_run["child_level"] > 0: + continue + # Check if latency is already provided: + if "latency" in structured_run['output_data']: + total_latency += structured_run['output_data']['latency'] + count += 1 + continue + # Calculate latency from timestamps: + start_time = datetime.datetime.fromisoformat(structured_run['input_data']['start_timestamp']) + end_time = datetime.datetime.fromisoformat(structured_run['output_data']['end_timestamp']) + total_latency += (end_time - start_time).total_seconds() * 1000 # Convert to milliseconds + count += 1 + + return total_latency / count if count > 0 else 0.0 diff --git a/modules/src/langchain_mlrun/notebook_images/mlrun_ui.png b/modules/src/langchain_mlrun/notebook_images/mlrun_ui.png new file mode 100644 index 00000000..9785eeae Binary files /dev/null and b/modules/src/langchain_mlrun/notebook_images/mlrun_ui.png differ diff --git a/modules/src/langchain_mlrun/requirements.txt b/modules/src/langchain_mlrun/requirements.txt new file mode 100644 index 00000000..fe350503 --- /dev/null +++ b/modules/src/langchain_mlrun/requirements.txt @@ -0,0 +1,3 @@ +pytest +langchain +pydantic-settings \ No newline at end of file diff --git a/modules/src/langchain_mlrun/test_langchain_mlrun.py b/modules/src/langchain_mlrun/test_langchain_mlrun.py new file mode 100644 index 00000000..bae27ce2 --- /dev/null +++ b/modules/src/langchain_mlrun/test_langchain_mlrun.py @@ -0,0 +1,1029 @@ +# Copyright 2026 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Literal, TypedDict, Annotated, Sequence, Any, Callable +from concurrent.futures import ThreadPoolExecutor +from operator import add + +import pytest +from langchain_core.language_models import LanguageModelInput +from langchain_core.runnables import Runnable, RunnableLambda +from pydantic import ValidationError + +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_core.tracers import Run +from langchain_core.language_models.fake_chat_models import FakeListChatModel, GenericFakeChatModel +from langchain.agents import create_agent +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool, BaseTool + +from langgraph.graph import StateGraph, START, END +from langchain_core.messages import BaseMessage +from pydantic_settings import BaseSettings, SettingsConfigDict + +from langchain_mlrun import ( + mlrun_monitoring, + MLRunTracer, + MLRunTracerSettings, + MLRunTracerClientSettings, + MLRunTracerMonitorSettings, + mlrun_monitoring_env_var, + LangChainMonitoringApp, +) + + +def _check_openai_credentials() -> bool: + """ + Check if OpenAI API key is set in environment variables. + + :return: True if OPENAI_API_KEY is set, False otherwise. 
+ """ + return "OPENAI_API_KEY" in os.environ + + +# Import ChatOpenAI only if OpenAI credentials are available (meaning `langchain-openai` must be installed). +if _check_openai_credentials(): + from langchain_openai import ChatOpenAI + + +class _ToolEnabledFakeModel(GenericFakeChatModel): + """ + A fake chat model that supports tool binding for running agent tracing tests. + """ + + def bind_tools( + self, + tools: Sequence[ + dict[str, Any] | type | Callable | BaseTool # noqa: UP006 + ], + *, + tool_choice: str | None = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, AIMessage]: + return self + + +#: Tag value for testing tag filtering. +_dummy_tag = "dummy_tag" + + +def _run_simple_chain() -> str: + """ + Run a simple LangChain chain that gets a fact about a topic. + """ + # Build a simple chain: prompt -> llm -> str output parser + llm = ChatOpenAI( + model="gpt-4o-mini", + tags=[_dummy_tag] + ) if _check_openai_credentials() else ( + FakeListChatModel( + responses=[ + "MLRun is an open-source orchestrator for machine learning pipelines." + ], + tags=[_dummy_tag] + ) + ) + prompt = ChatPromptTemplate.from_template("Tell me a short fact about {topic}") + chain = prompt | llm | StrOutputParser() + + # Run the chain: + response = chain.invoke({"topic": "MLRun"}) + return response + + +def _run_simple_agent(): + """ + Run a simple LangChain agent that uses two tools to get weather and stock price. + """ + # Define the tools: + @tool + def get_weather(city: str) -> str: + """Get the current weather for a specific city.""" + return f"The weather in {city} is 22°C and sunny." + + @tool + def get_stock_price(symbol: str) -> str: + """Get the current stock price for a symbol.""" + return f"The stock price for {symbol} is $150.25." + + # Define the model: + model = ChatOpenAI( + model="gpt-4o-mini", + tags=[_dummy_tag] + ) if _check_openai_credentials() else ( + _ToolEnabledFakeModel( + messages=iter( + [ + AIMessage( + content="", + tool_calls=[ + {"name": "get_weather", "args": {"city": "London"}, "id": "call_abc123"}, + {"name": "get_stock_price", "args": {"symbol": "AAPL"}, "id": "call_def456"} + ] + ), + AIMessage(content="The weather in London is 22°C and AAPL is trading at $150.25.") + ] + ), + tags=[_dummy_tag] + ) + ) + + # Create the agent: + agent = create_agent( + model=model, + tools=[get_weather, get_stock_price], + system_prompt="You are a helpful assistant with access to tools." + ) + + # Run the agent: + return agent.invoke({"messages": ["What is the weather in London and the stock price of AAPL?"]}) + + +def _run_langgraph_graph(): + """ + Run a LangGraph agent that uses reflection to correct its answer. + """ + + # Define the graph state: + class AgentState(TypedDict): + messages: Annotated[list[BaseMessage], add] + attempts: int + + # Define the model: + model = ChatOpenAI(model="gpt-4o-mini") if _check_openai_credentials() else ( + _ToolEnabledFakeModel( + messages=iter( + [ + AIMessage(content="There are 2 'r's in Strawberry."), # Mocking the failure + AIMessage(content="I stand corrected. S-t-r-a-w-b-e-r-r-y. There are 3 'r's."), # Mocking the fix + ] + ) + ) + ) + + # Define the graph nodes and router: + def call_model(state: AgentState): + response = model.invoke(state["messages"]) + return {"messages": [response], "attempts": state["attempts"] + 1} + + def reflect_node(state: AgentState): + prompt = "Wait, count the 'r's again slowly, letter by letter. Are you sure?" 
+ return {"messages": [HumanMessage(content=prompt)]} + + def router(state: AgentState) -> Literal["reflect", END]: + # Make sure there are 2 attempts at least for an answer: + if state["attempts"] == 1: + return "reflect" + return END + + # Build the graph: + builder = StateGraph(AgentState) + builder.add_node("model", call_model) + tagged_reflect_node = RunnableLambda(reflect_node).with_config(tags=[_dummy_tag]) + builder.add_node("reflect", tagged_reflect_node) + builder.add_edge(START, "model") + builder.add_conditional_edges("model", router) + builder.add_edge("reflect", "model") + graph = builder.compile() + + # Run the graph: + return graph.invoke({"messages": [HumanMessage(content="How many 'r's in Strawberry?")], "attempts": 0}) + + +#: List of example functions to run in tests along the full (split-run enabled) expected monitor events. +_run_suites: list[tuple[Callable, int]] = [ + (_run_simple_chain, 4), + (_run_simple_agent, 9), + (_run_langgraph_graph, 9), +] + + +#: Dummy environment variables for testing. +_dummy_environment_variables = { + "MLRUN_TRACER_CLIENT_V3IO_STREAM_PATH": "dummy_stream_path", + "MLRUN_TRACER_CLIENT_V3IO_CONTAINER": "dummy_container", + "MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_NAME": "dummy_model_name", + "MLRUN_TRACER_CLIENT_MODEL_ENDPOINT_UID": "dummy_model_endpoint_uid", + "MLRUN_TRACER_CLIENT_SERVING_FUNCTION": "dummy_serving_function", + "MLRUN_TRACER_MONITOR_DEBUG": "true", + "MLRUN_TRACER_MONITOR_DEBUG_TARGET_LIST": "true", + "MLRUN_TRACER_MONITOR_SPLIT_RUNS": "true", +} + + +@pytest.fixture() +def auto_mode_settings(monkeypatch): + """ + Sets the environment variables to enable mlrun monitoring in 'auto' mode. + """ + # Set environment variables for the duration of the test: + monkeypatch.setenv(mlrun_monitoring_env_var, "1") + for key, value in _dummy_environment_variables.items(): + monkeypatch.setenv(key, value) + + # Reset the singleton tracer to ensure fresh initialization: + MLRunTracer._singleton_tracer = None + MLRunTracer._initialized = False + + yield + + # Reset the singleton tracer after the test: + MLRunTracer._singleton_tracer = None + MLRunTracer._initialized = False + + +@pytest.fixture +def manual_mode_settings(): + """ + Sets the mandatory client settings and debug flag for the tests. + """ + settings = MLRunTracerSettings( + client=MLRunTracerClientSettings( + v3io_stream_path="dummy_stream_path", + v3io_container="dummy_container", + model_endpoint_name="dummy_model_name", + model_endpoint_uid="dummy_model_endpoint_uid", + serving_function="dummy_serving_function", + ), + monitor=MLRunTracerMonitorSettings( + debug=True, + debug_target_list=[], + split_runs=True, # Easier to test with split runs (filters can filter per run instead of inner events) + ), + ) + + yield settings + + +def test_settings_init_via_env_vars(monkeypatch): + """ + Test that settings are correctly initialized from environment variables. 
+ """ + #: First, ensure that without env vars, validation fails due to missing required fields: + with pytest.raises(ValidationError): + MLRunTracerSettings() + + # Now, set the environment variables for the client settings and debug flag: + for key, value in _dummy_environment_variables.items(): + monkeypatch.setenv(key, value) + + # Ensure that settings are now correctly initialized from env vars: + settings = MLRunTracerSettings() + assert settings.client.v3io_stream_path == "dummy_stream_path" + assert settings.client.v3io_container == "dummy_container" + assert settings.client.model_endpoint_name == "dummy_model_name" + assert settings.client.model_endpoint_uid == "dummy_model_endpoint_uid" + assert settings.client.serving_function == "dummy_serving_function" + assert settings.monitor.debug is True + + +@pytest.mark.parametrize( + "test_suite", [ + # Valid case: only v3io settings provided + ( + { + "v3io_stream_path": "dummy_stream_path", + "v3io_container": "dummy_container", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + True, + ), + # Invalid case: partial v3io settings provided + ( + { + "v3io_stream_path": "dummy_stream_path", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + # Valid case: only kafka settings provided + ( + { + "kafka_broker": "dummy_bootstrap_servers", + "kafka_topic": "dummy_topic", + # TODO: Add more mandatory kafka settings + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + True, + ), + # Invalid case: partial kafka settings provided + ( + { + "kafka_broker": "dummy_bootstrap_servers", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + # Invalid case: both v3io and kafka settings provided + ( + { + "v3io_stream_path": "dummy_stream_path", + "v3io_container": "dummy_container", + "kafka_broker": "dummy_bootstrap_servers", + "kafka_topic": "dummy_topic", + # TODO: Add more mandatory kafka settings + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + # Invalid case: both v3io and kafka settings provided (partial) + ( + { + "v3io_container": "dummy_container", + "kafka_broker": "dummy_bootstrap_servers", + "model_endpoint_name": "dummy_model_name", + "model_endpoint_uid": "dummy_model_endpoint_uid", + "serving_function": "dummy_serving_function", + }, + False, + ), + ] +) +def test_settings_v3io_kafka_combination(test_suite: tuple[dict[str, str], bool]): + """ + Test that settings validation enforces mutual exclusivity between v3io and kafka configurations. + + :param test_suite: A tuple containing environment variable overrides and a flag indicating + whether validation should pass. + """ + settings, should_pass = test_suite + + if should_pass: + MLRunTracerClientSettings(**settings) + else: + with pytest.raises(ValidationError): + MLRunTracerClientSettings(**settings) + + +def test_auto_mode_singleton_thread_safety(auto_mode_settings): + """ + Test that MLRunTracer singleton initialization is thread-safe in 'auto' mode. + + :param auto_mode_settings: Fixture to set up 'auto' mode environment and settings. 
+ """ + # Initialize a list to hold tracer instances created in different threads: + tracer_instances = [] + + # Function to initialize the tracer in a thread: + def _init_tracer(): + tracer = MLRunTracer() + return tracer + + # Use ThreadPoolExecutor to simulate concurrent tracer initialization: + num_threads = 50 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(_init_tracer) for _ in range(num_threads)] + tracer_instances = [f.result() for f in futures] + + # Check if every single reference in the list is the exact same object: + unique_instances = set(tracer._uid for tracer in tracer_instances) + + assert len(tracer_instances) == num_threads, "Not all threads returned a tracer instance. Test cannot proceed." + assert len(unique_instances) == 1, ( + f"Thread-safety failure! {len(unique_instances)} different instances were created under high concurrency." + ) + assert tracer_instances[0] is MLRunTracer(), "The global access point should return the same singleton." + + +def test_manual_mode_multi_instances(manual_mode_settings: MLRunTracerSettings): + """ + Test that MLRunTracer allows multiple instances in 'manual' mode. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + # Initialize a list to hold tracer instances created in different iterations: + tracer_instances = [] + + # Create multiple tracer instances: + num_instances = 50 + for _ in range(num_instances): + tracer = MLRunTracer(settings=manual_mode_settings) + tracer_instances.append(tracer) + + # Check if every single reference in the list is a different object: + unique_instances = set(tracer._uid for tracer in tracer_instances) + + assert len(tracer_instances) == num_instances, "Not all instances were created. Test cannot proceed." + assert len(unique_instances) == num_instances, ( + f"Manual mode failure! {len(unique_instances)} unique instances were created instead of {num_instances}." + ) + + +@pytest.mark.parametrize("run_suites", _run_suites) +def test_auto_mode(auto_mode_settings, run_suites: tuple[Callable, int]): + """ + Test that MLRunTracer in 'auto' mode captures debug target list after running a LangChain / LangGraph example code. + + :param auto_mode_settings: Fixture to set up 'auto' mode environment and settings. + + :param run_suites: The function to run with the expected monitored events. + """ + run_func, expected_events = run_suites + + tracer = MLRunTracer() + assert len(tracer.settings.monitor.debug_target_list) == 0 + + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == expected_events + + +@pytest.mark.parametrize("run_suites", _run_suites) +def test_manual_mode(manual_mode_settings: MLRunTracerSettings, run_suites: tuple[Callable, int]): + """ + Test that MLRunTracer in 'auto' mode captures debug target list after running a LangChain / LangGraph example code. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events. + """ + run_func, expected_events = run_suites + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == expected_events + + +def test_labeling(manual_mode_settings: MLRunTracerSettings): + """ + Test that MLRunTracer in 'auto' mode captures debug target list after running a LangChain / LangGraph example code. 
+ + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + for i, (run_func, expected_events) in enumerate(_run_suites): + label = f"label_{i}" + manual_mode_settings.monitor.label = label + manual_mode_settings.monitor.debug_target_list.clear() + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == expected_events + for event in tracer.settings.monitor.debug_target_list: + assert event["label"] == label + + +@pytest.mark.parametrize( + "run_suites", [ + run_suite + (filtered_events,) + for run_suite, filtered_events in zip(_run_suites, [1, 2, 1]) + ] +) +def test_monitor_settings_tags_filter( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int, int], +): + """ + Test the `tags_filter` setting of MLRunTracer. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events and filtered events. + """ + run_func, expected_events, filtered_events = run_suites + + manual_mode_settings.monitor.tags_filter = [_dummy_tag] + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == filtered_events + for event in tracer.settings.monitor.debug_target_list: + assert not set(manual_mode_settings.monitor.tags_filter).isdisjoint(event["input_data"]["input_data"]["tags"]) + + +@pytest.mark.parametrize( + "run_suites", [ + run_suite + (filtered_events,) + for run_suite, filtered_events in zip(_run_suites, [1, 3, 4]) + ] +) +def test_monitor_settings_name_filter( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int, int], +): + """ + Test the `names_filter` setting of MLRunTracer. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events and filtered events. + """ + run_func, expected_events, filtered_events = run_suites + + manual_mode_settings.monitor.names_filter = ["StrOutputParser", "get_weather", "model", "router"] + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == filtered_events + for event in tracer.settings.monitor.debug_target_list: + assert event["input_data"]["input_data"]["run_name"] in manual_mode_settings.monitor.names_filter + + +@pytest.mark.parametrize( + "run_suites", [ + run_suite + (filtered_events,) + for run_suite, filtered_events in zip(_run_suites, [2, 7, 9]) + ] +) +@pytest.mark.parametrize("split_runs", [True, False]) +def test_monitor_settings_run_type_filter( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int, int], + split_runs: bool +): + """ + Test the `run_types_filter` setting of MLRunTracer. Will also test with split runs enabled and disabled - meaning + that when disabled, if a parent run is filtered, all its child runs are also filtered by default. In the test we + made sure that the root run is always passing the filter (hence the equal one). + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events and filtered events. + :param split_runs: Whether to enable split runs in the monitor settings. 
+ """ + run_func, expected_events, filtered_events = run_suites + filtered_events = filtered_events if split_runs else 1 + + manual_mode_settings.monitor.run_types_filter = ["llm", "chain"] + manual_mode_settings.monitor.split_runs = split_runs + + def recursive_check_run_types(run: dict): + assert run["input_data"]["run_type"] in manual_mode_settings.monitor.run_types_filter + if "child_runs" in run["output_data"]: + for child_run in run["output_data"]["child_runs"]: + recursive_check_run_types(child_run) + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == filtered_events + + for event in tracer.settings.monitor.debug_target_list: + event_run = { + "input_data": event["input_data"]["input_data"], + "output_data": event["output_data"]["output_data"], + } + recursive_check_run_types(run=event_run) + +@pytest.mark.parametrize("run_suites", _run_suites) +@pytest.mark.parametrize("split_runs", [True, False]) +def test_monitor_settings_full_filter( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int], + split_runs: bool +): + """ + Test that a complete filter (not allowing any events to pass) won't fail the tracer. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events. + :param split_runs: Whether to enable split runs in the monitor settings. + """ + run_func, _ = run_suites + + manual_mode_settings.monitor.run_types_filter = ["dummy_run_type"] + manual_mode_settings.monitor.split_runs = split_runs + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == 0 + + +@pytest.mark.parametrize("run_suites", _run_suites) +@pytest.mark.parametrize("split_runs", [True, False]) +@pytest.mark.parametrize("root_run_only", [True, False]) +def test_monitor_settings_split_runs_and_root_run_only( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int], + split_runs: bool, + root_run_only: bool, +): + """ + Test the `split_runs` setting of MLRunTracer. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events. + :param split_runs: Whether to enable split runs in the monitor settings. + :param root_run_only: Whether to enable `root_run_only` in the monitor settings. 
+ """ + run_func, expected_events = run_suites + + manual_mode_settings.monitor.split_runs = split_runs + manual_mode_settings.monitor.root_run_only = root_run_only + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + for run_iteration in range(1, 3): + print(run_func()) + if root_run_only: + assert len(tracer.settings.monitor.debug_target_list) == 1 * run_iteration + assert "child_runs" not in tracer.settings.monitor.debug_target_list[-1]["output_data"]["output_data"] + elif split_runs: + assert len(tracer.settings.monitor.debug_target_list) == expected_events * run_iteration + assert "child_runs" not in tracer.settings.monitor.debug_target_list[-1]["output_data"]["output_data"] + else: # split_runs disabled + assert len(tracer.settings.monitor.debug_target_list) == 1 * run_iteration + assert len(tracer.settings.monitor.debug_target_list[-1]["output_data"]["output_data"]["child_runs"]) != 0 + + +class _CustomRunSummarizerSettings(BaseSettings): + """ + Settings for the custom summarizer function. + """ + dummy_value: int = 21 + + model_config = SettingsConfigDict(env_prefix="TEST_CUSTOM_SUMMARIZER_SETTINGS_") + + +def _custom_run_summarizer(run: Run, settings: _CustomRunSummarizerSettings = None): + """ + A custom summarizer function for testing. + + :param run: The LangChain / LangGraph run to summarize. + :param settings: Optional settings for the summarizer. + """ + inputs = { + "run_id": run.id, + "input": run.inputs, + "from_settings": settings.dummy_value if settings else 0, + } + + def count_llm_calls(r: Run) -> int: + if not r.child_runs: + return 1 if r.run_type == "llm" else 0 + return sum(count_llm_calls(child) for child in r.child_runs) + + def count_tool_calls(r: Run) -> int: + if not r.child_runs: + return 1 if r.run_type == "tool" else 0 + return sum(count_tool_calls(child) for child in r.child_runs) + + outputs = { + "llm_calls": count_llm_calls(run), + "tool_calls": count_tool_calls(run), + "output": run.outputs + } + + yield inputs, outputs + + +@pytest.mark.parametrize("run_suites", _run_suites) +@pytest.mark.parametrize("run_summarizer_function", [ + _custom_run_summarizer, + "test_langchain_mlrun._custom_run_summarizer", +]) +@pytest.mark.parametrize("run_summarizer_settings", [ + _CustomRunSummarizerSettings(dummy_value=12), + "test_langchain_mlrun._CustomRunSummarizerSettings", + None, +]) +def test_monitor_settings_custom_run_summarizer( + manual_mode_settings: MLRunTracerSettings, + run_suites: tuple[Callable, int], + run_summarizer_function: Callable | str, + run_summarizer_settings: BaseSettings | str | None, +): + """ + Test the custom run summarizer that can be passed to MLRunTracer. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param run_suites: The function to run with the expected monitored events. + :param run_summarizer_function: The custom summarizer function or its import path. + :param run_summarizer_settings: The settings for the custom summarizer or its import path. 
+ """ + run_func, _ = run_suites + manual_mode_settings.monitor.run_summarizer_function = run_summarizer_function + manual_mode_settings.monitor.run_summarizer_settings = run_summarizer_settings + dummy_value_for_settings_from_env = 26 + os.environ["TEST_CUSTOM_SUMMARIZER_SETTINGS_DUMMY_VALUE"] = str(dummy_value_for_settings_from_env) + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + print(run_func()) + assert len(tracer.settings.monitor.debug_target_list) == 1 + + event = tracer.settings.monitor.debug_target_list[0] + if run_summarizer_settings: + if isinstance(run_summarizer_settings, str): + assert event["input_data"]["input_data"]["from_settings"] == dummy_value_for_settings_from_env + else: + assert event["input_data"]["input_data"]["from_settings"] == run_summarizer_settings.dummy_value + else: + assert event["input_data"]["input_data"]["from_settings"] == 0 + + +def test_monitor_settings_include_errors_field_presence(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_errors` is True, the error field is present in outputs. + When `include_errors` is False, the error field is not added to outputs. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + # Run with include_errors=True (default) and verify error field is present: + manual_mode_settings.monitor.include_errors = True + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + output_data = event["output_data"]["output_data"] + assert "error" in output_data, "error field should be present when include_errors is True" + + # Now run with include_errors=False and verify error field is excluded: + manual_mode_settings.monitor.include_errors = False + manual_mode_settings.monitor.debug_target_list.clear() + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + output_data = event["output_data"]["output_data"] + assert "error" not in output_data, "error field should be excluded when include_errors is False" + + +def test_monitor_settings_include_full_run(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_full_run` is True, the complete serialized run is included in outputs. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + manual_mode_settings.monitor.include_full_run = True + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + output_data = event["output_data"]["output_data"] + assert "full_run" in output_data, "full_run should be included in outputs when include_full_run is True" + # Verify the full_run contains expected run structure: + assert "inputs" in output_data["full_run"] + assert "outputs" in output_data["full_run"] + + +def test_monitor_settings_include_metadata(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_metadata` is False, metadata is excluded from inputs. + + Note: The fake models used in tests don't produce runs with metadata, so we can only + verify the "exclude" behavior. The code only adds metadata if the run actually contains it. 
+ + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + # Run with include_metadata=False and verify metadata is excluded: + manual_mode_settings.monitor.include_metadata = False + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + # Check that metadata is not present in inputs: + for event in tracer.settings.monitor.debug_target_list: + input_data = event["input_data"]["input_data"] + assert "metadata" not in input_data, "metadata should be excluded when include_metadata is False" + + +def test_monitor_settings_include_latency(manual_mode_settings: MLRunTracerSettings): + """ + Test that when `include_latency` is False, latency is excluded from outputs. + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + """ + manual_mode_settings.monitor.include_latency = False + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + assert len(tracer.settings.monitor.debug_target_list) > 0 + + for event in tracer.settings.monitor.debug_target_list: + assert "latency" not in event["output_data"]["output_data"], \ + "latency should be excluded when include_latency is False" + + +def test_import_from_module_path_errors(): + """ + Test that `_import_from_module_path` raises appropriate errors for invalid paths. + """ + # Test ValueError for path without a dot: + with pytest.raises(ValueError) as exc_info: + MLRunTracer._import_from_module_path("no_dot_path") + assert "must have at least one '.'" in str(exc_info.value) + + # Test ImportError for non-existent module: + with pytest.raises(ImportError) as exc_info: + MLRunTracer._import_from_module_path("nonexistent_module_xyz.SomeClass") + assert "Could not import" in str(exc_info.value) + + # Test AttributeError for non-existent attribute in existing module: + with pytest.raises(AttributeError) as exc_info: + MLRunTracer._import_from_module_path("os.nonexistent_attribute_xyz") + assert "Could not import" in str(exc_info.value) + + +#: Sample structured runs for testing LangChainMonitoringApp methods. 
+_sample_structured_runs = [ + { + "label": "test_label", + "child_level": 0, + "input_data": { + "run_name": "RunnableSequence", + "run_type": "chain", + "tags": ["tag1"], + "inputs": {"topic": "MLRun"}, + "start_timestamp": "2024-01-01T10:00:00+00:00", + }, + "output_data": { + "outputs": {"result": "test output"}, + "end_timestamp": "2024-01-01T10:00:01+00:00", + "error": None, + "child_runs": [ + { + "input_data": { + "run_name": "FakeListChatModel", + "run_type": "llm", + "tags": ["tag2"], + "inputs": {"prompt": "test"}, + "start_timestamp": "2024-01-01T10:00:00.100+00:00", + }, + "output_data": { + "outputs": { + "generations": [[{ + "message": { + "kwargs": { + "usage_metadata": { + "input_tokens": 10, + "output_tokens": 20, + } + } + } + }]] + }, + "end_timestamp": "2024-01-01T10:00:00.500+00:00", + "error": None, + }, + }, + ], + }, + }, + { + "label": "test_label", + "child_level": 0, + "input_data": { + "run_name": "SimpleAgent", + "run_type": "chain", + "tags": ["tag1"], + "inputs": {"query": "test query"}, + "start_timestamp": "2024-01-01T10:00:02+00:00", + }, + "output_data": { + "outputs": {"result": "agent output"}, + "end_timestamp": "2024-01-01T10:00:04+00:00", + "error": "SomeError: something went wrong", + }, + }, +] + + +def test_langchain_monitoring_app_iterate_structured_runs(): + """ + Test that `iterate_structured_runs` yields all runs including nested child runs. + """ + # Iterate over all runs: + all_runs = list(LangChainMonitoringApp.iterate_structured_runs(_sample_structured_runs)) + + # Should yield parent runs and child runs: + # - First sample: 1 parent + 1 child = 2 runs + # - Second sample: 1 parent = 1 run + # Total: 3 runs + assert len(all_runs) == 3 + + # Verify run names are as expected: + run_names = [r["input_data"]["run_name"] for r in all_runs] + assert "RunnableSequence" in run_names + assert "FakeListChatModel" in run_names + assert "SimpleAgent" in run_names + + +def test_langchain_monitoring_app_count_run_names(): + """ + Test that `count_run_names` correctly counts occurrences of each run name. + """ + counts = LangChainMonitoringApp.count_run_names(_sample_structured_runs) + + assert counts["RunnableSequence"] == 1 + assert counts["FakeListChatModel"] == 1 + assert counts["SimpleAgent"] == 1 + + +def test_langchain_monitoring_app_count_token_usage(): + """ + Test that `count_token_usage` correctly calculates total tokens from LLM runs. + """ + token_usage = LangChainMonitoringApp.count_token_usage(_sample_structured_runs) + + assert token_usage["total_input_tokens"] == 10 + assert token_usage["total_output_tokens"] == 20 + assert token_usage["combined_total"] == 30 + + +def test_langchain_monitoring_app_calculate_success_rate(): + """ + Test that `calculate_success_rate` returns the correct percentage of successful runs. 
+ """ + success_rate = LangChainMonitoringApp.calculate_success_rate(_sample_structured_runs) + + # First run has no error, second run has error: + # Success rate should be 1/2 = 0.5 + assert success_rate == 0.5 + + # Test with empty list: + empty_rate = LangChainMonitoringApp.calculate_success_rate([]) + assert empty_rate == 0.0 + + # Test with all successful runs: + successful_runs = [_sample_structured_runs[0]] # Only the first run which has no error + all_success_rate = LangChainMonitoringApp.calculate_success_rate(successful_runs) + assert all_success_rate == 1.0 + + +def test_langchain_monitoring_app_calculate_average_latency(): + """ + Test that `calculate_average_latency` returns the correct average latency across root runs. + """ + # Calculate average latency: + avg_latency = LangChainMonitoringApp.calculate_average_latency(_sample_structured_runs) + + # First run: 10:00:00 to 10:00:01 = 1000ms + # Second run: 10:00:02 to 10:00:04 = 2000ms + # Average: (1000 + 2000) / 2 = 1500ms + assert avg_latency == 1500.0 + + # Test with empty list: + empty_latency = LangChainMonitoringApp.calculate_average_latency([]) + assert empty_latency == 0.0 + + +def test_langchain_monitoring_app_calculate_average_latency_skips_child_runs(): + """ + Test that `calculate_average_latency` skips child runs (only calculates for root runs). + """ + # Create a sample with a child run that has child_level > 0: + runs_with_child = [ + { + "label": "test", + "child_level": 0, + "input_data": {"start_timestamp": "2024-01-01T10:00:00+00:00"}, + "output_data": {"end_timestamp": "2024-01-01T10:00:01+00:00"}, + }, + { + "label": "test", + "child_level": 1, # This is a child run, should be skipped + "input_data": {"start_timestamp": "2024-01-01T10:00:00+00:00"}, + "output_data": {"end_timestamp": "2024-01-01T10:00:10+00:00"}, # 10 seconds - would skew average + }, + ] + + # Calculate average latency: + avg_latency = LangChainMonitoringApp.calculate_average_latency(runs_with_child) + + # Should only consider the root run (1000ms), not the child run: + assert avg_latency == 1000.0 + + +def test_debug_mode_stdout(manual_mode_settings: MLRunTracerSettings, capsys): + """ + Test that debug mode prints to stdout when `debug_target_list` is not set (is False). + + :param manual_mode_settings: Fixture to set up 'manual' mode environment and settings. + :param capsys: Pytest fixture to capture stdout/stderr. + """ + # Set debug mode with debug_target_list=False (should print to stdout): + manual_mode_settings.monitor.debug = True + manual_mode_settings.monitor.debug_target_list = False + + with mlrun_monitoring(settings=manual_mode_settings) as tracer: + _run_simple_chain() + + # Capture stdout: + captured = capsys.readouterr() + + # Verify that JSON output was printed to stdout: + assert "event_id" in captured.out, "Event should be printed to stdout when debug_target_list is False" + assert "input_data" in captured.out + assert "output_data" in captured.out