diff --git a/docker-compose.yml b/docker-compose.yml index 11870127c..4c56b62d4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,8 @@ services: - "9200:9200" - "9600:9600" volumes: - - ${OPENSEARCH_DATA_PATH:-./opensearch-data}:/usr/share/opensearch/data:U,z + # If OPENSEARCH_DATA_PATH is set, use host path; otherwise use named volume + - ${OPENSEARCH_DATA_PATH:-opensearch-data}:/usr/share/opensearch/data dashboards: image: opensearchproject/opensearch-dashboards:3.0.0 @@ -68,6 +69,7 @@ services: - OPENSEARCH_USERNAME=admin - OPENSEARCH_PASSWORD=${OPENSEARCH_PASSWORD} - OPENAI_API_KEY=${OPENAI_API_KEY} + - OPENAI_API_BASE=${OPENAI_API_BASE:-None} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} - WATSONX_API_KEY=${WATSONX_API_KEY} - WATSONX_ENDPOINT=${WATSONX_ENDPOINT} @@ -81,12 +83,15 @@ services: - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID} - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY} - OPENSEARCH_INDEX_NAME=${OPENSEARCH_INDEX_NAME:-documents} + - LOG_LEVEL=${LOG_LEVEL} volumes: - ${OPENRAG_DOCUMENTS_PATH:-./openrag-documents}:/app/openrag-documents:Z - ${OPENRAG_KEYS_PATH:-./keys}:/app/keys:U,z - ${OPENRAG_FLOWS_PATH:-./flows}:/app/flows:U,z - ${OPENRAG_CONFIG_PATH:-./config}:/app/config:Z - ${OPENRAG_DATA_PATH:-./data}:/app/data:Z + ports: + - "8000:8000" openrag-frontend: image: langflowai/openrag-frontend:${OPENRAG_VERSION:-latest} @@ -117,6 +122,7 @@ services: - LANGFUSE_HOST=${LANGFUSE_HOST:-} - LANGFLOW_DEACTIVATE_TRACING - OPENAI_API_KEY=${OPENAI_API_KEY:-None} + - OPENAI_API_BASE=${OPENAI_API_BASE:-None} - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY:-None} - WATSONX_API_KEY=${WATSONX_API_KEY:-None} - WATSONX_ENDPOINT=${WATSONX_ENDPOINT:-None} @@ -145,8 +151,8 @@ services: - MIMETYPE=None - FILESIZE=0 - SELECTED_EMBEDDING_MODEL=${SELECTED_EMBEDDING_MODEL:-} - - 
LANGFLOW_VARIABLES_TO_GET_FROM_ENVIRONMENT=JWT,OPENRAG-QUERY-FILTER,OPENSEARCH_PASSWORD,OPENSEARCH_URL,DOCLING_SERVE_URL,OWNER,OWNER_NAME,OWNER_EMAIL,CONNECTOR_TYPE,DOCUMENT_ID,SOURCE_URL,ALLOWED_USERS,ALLOWED_GROUPS,FILENAME,MIMETYPE,FILESIZE,SELECTED_EMBEDDING_MODEL,OPENAI_API_KEY,ANTHROPIC_API_KEY,WATSONX_API_KEY,WATSONX_ENDPOINT,WATSONX_PROJECT_ID,OLLAMA_BASE_URL,OPENSEARCH_INDEX_NAME - - LANGFLOW_LOG_LEVEL=DEBUG + - LANGFLOW_VARIABLES_TO_GET_FROM_ENVIRONMENT=JWT,OPENRAG-QUERY-FILTER,OPENSEARCH_PASSWORD,OPENSEARCH_URL,DOCLING_SERVE_URL,OWNER,OWNER_NAME,OWNER_EMAIL,CONNECTOR_TYPE,DOCUMENT_ID,SOURCE_URL,ALLOWED_USERS,ALLOWED_GROUPS,FILENAME,MIMETYPE,FILESIZE,SELECTED_EMBEDDING_MODEL,OPENAI_API_KEY,OPENAI_API_BASE,ANTHROPIC_API_KEY,WATSONX_API_KEY,WATSONX_ENDPOINT,WATSONX_PROJECT_ID,OLLAMA_BASE_URL,OPENSEARCH_INDEX_NAME + - LANGFLOW_LOG_LEVEL=${LOG_LEVEL} - LANGFLOW_WORKERS=${LANGFLOW_WORKERS:-1} - LANGFLOW_AUTO_LOGIN=${LANGFLOW_AUTO_LOGIN} - LANGFLOW_SUPERUSER=${LANGFLOW_SUPERUSER} @@ -155,3 +161,6 @@ services: - LANGFLOW_ENABLE_SUPERUSER_CLI=${LANGFLOW_ENABLE_SUPERUSER_CLI} # - DEFAULT_FOLDER_NAME=OpenRAG - HIDE_GETTING_STARTED_PROGRESS=true + +volumes: + opensearch-data: \ No newline at end of file diff --git a/flows/components/split_text.py b/flows/components/split_text.py new file mode 100644 index 000000000..ef1100885 --- /dev/null +++ b/flows/components/split_text.py @@ -0,0 +1,515 @@ +import copy +import re +from typing import Iterable + +from langchain_text_splitters import CharacterTextSplitter + +from lfx.custom.custom_component.component import Component +from lfx.io import DropdownInput, HandleInput, IntInput, MessageTextInput, Output +from lfx.schema.data import Data +from lfx.schema.dataframe import DataFrame +from lfx.schema.message import Message +from lfx.utils.util import unescape_string +from lfx.log import logger + +from langchain_core.documents import Document + + +class SplitTextComponent(Component): + display_name: str = "Split Text" + 
description: str = "Split text into chunks based on specified criteria." + documentation: str = "https://docs.langflow.org/components-processing#split-text" + icon = "scissors-line-dashed" + name = "SplitText" + + inputs = [ + HandleInput( + name="data_inputs", + display_name="Input", + info="The data with texts to split in chunks.", + input_types=["Data", "DataFrame", "Message"], + required=True, + ), + IntInput( + name="chunk_overlap", + display_name="Chunk Overlap", + info="Number of characters to overlap between chunks.", + value=200, + ), + IntInput( + name="chunk_size", + display_name="Chunk Size", + info=( + "The maximum length of each chunk. Text is first split by separator, " + "then chunks are merged up to this size. " + "Individual splits larger than this won't be further divided." + ), + value=1000, + ), + MessageTextInput( + name="separator", + display_name="Separator", + info=( + "The character to split on. Use \\n for newline. " + "Examples: \\n\\n for paragraphs, \\n for lines, . 
for sentences" + ), + value="\n", + ), + MessageTextInput( + name="text_key", + display_name="Text Key", + info="The key to use for the text column.", + value="text", + advanced=True, + ), + DropdownInput( + name="keep_separator", + display_name="Keep Separator", + info="Whether to keep the separator in the output chunks and where to place it.", + options=["False", "True", "Start", "End"], + value="False", + advanced=True, + ), + DropdownInput( + name="splitter_type", + display_name="Splitter Type", + info="Which text splitter to use to chunk the documents.", + options=["CharacterTextSplitter", "TableAwareTextSplitter", "LineBasedTextSplitter"], + value="CharacterTextSplitter", + advanced=True, + ), + MessageTextInput( + name="model_id", + display_name="Model ID", + info="The name of the model that will be used for computing embeddings.", + value="ibm-granite/granite-embedding-30m-english", + advanced=True, + ), + DropdownInput( + name="use_document_title", + display_name="Use Document Title", + info="Whether to use the document title as a prefix in each chunk.", + options=["False", "True"], + value="False", + advanced=True, + ), + ] + + outputs = [ + Output(display_name="Chunks", name="dataframe", method="split_text"), + ] + + def _docs_to_data(self, docs) -> list[Data]: + return [Data(text=doc.page_content, data=doc.metadata) for doc in docs] + + def _fix_separator(self, separator: str) -> str: + """Fix common separator issues and convert to proper format.""" + if separator == "/n": + return "\n" + if separator == "/t": + return "\t" + return separator + + @staticmethod + def to_bool(val): + if isinstance(val, str): + if val.lower() == "false": + return False + elif val.lower() == "true": + return True + elif isinstance(val, bool): + return val + raise RuntimeError(f"Cannot convert value {val} to a boolean value. 
Expected 'True' or 'False'.") + + def split_text_base(self): + separator = self._fix_separator(self.separator) + separator = unescape_string(separator) + + if isinstance(self.data_inputs, DataFrame): + if not len(self.data_inputs): + msg = "DataFrame is empty" + raise TypeError(msg) + + self.data_inputs.text_key = self.text_key + try: + documents = self.data_inputs.to_lc_documents() + except Exception as e: + msg = f"Error converting DataFrame to documents: {e}" + raise TypeError(msg) from e + elif isinstance(self.data_inputs, Message): + self.data_inputs = [self.data_inputs.to_data()] + return self.split_text_base() + else: + if not self.data_inputs: + msg = "No data inputs provided" + raise TypeError(msg) + + documents = [] + if isinstance(self.data_inputs, Data): + self.data_inputs.text_key = self.text_key + documents = [self.data_inputs.to_lc_document()] + else: + try: + documents = [input_.to_lc_document() for input_ in self.data_inputs if isinstance(input_, Data)] + if not documents: + msg = f"No valid Data inputs found in {type(self.data_inputs)}" + raise TypeError(msg) + except AttributeError as e: + msg = f"Invalid input type in collection: {e}" + raise TypeError(msg) from e + try: + if self.splitter_type == "CharacterTextSplitter": + # Convert string 'False'/'True' to boolean + keep_sep = self.to_bool(self.keep_separator) + logger.debug("SPLIT: Creating a CharacterTextSplitter..") + splitter = CharacterTextSplitter( + chunk_overlap=self.chunk_overlap, + chunk_size=self.chunk_size, + separator=separator, + keep_separator=keep_sep, + ) + elif self.splitter_type == "LineBasedTextSplitter": + use_document_title = self.to_bool(self.use_document_title) + splitter = LineBasedTextSplitter( + chunk_size=self.chunk_size, + model_id=self.model_id, + use_document_title=use_document_title, + ) + elif self.splitter_type == "TableAwareTextSplitter": + logger.debug(f"SPLIT: Creating a TableAwareTextSplitter with chunk_size={self.chunk_size} and model_id 
'{self.model_id}'.") + splitter = TableAwareTextSplitter( + chunk_size=self.chunk_size, + model_id=self.model_id + ) + else: + raise RuntimeError(f"Unknown splitter type value '{self.splitter_type}'.") + return splitter.split_documents(documents) + except Exception as e: + msg = f"Error splitting text: {e}" + raise TypeError(msg) from e + + def split_text(self) -> DataFrame: + return DataFrame(self._docs_to_data(self.split_text_base())) + +class LineBasedTextSplitter: + def __init__( + self, + chunk_size: int, + model_id: str, + prefix: str = "", + use_document_title: bool = False, + ): + self._chunk_size = chunk_size + self.use_tiktoken = False + if model_id in ["text-embedding-3-small", "text-embedding-3-large"]: + self.use_tiktoken = True + logger.debug(f"SPLIT: Initializing LineBasedTextSplitter, chunk_size = {chunk_size}, model_id = '{model_id}', use_tiktoken = {self.use_tiktoken}, use_document_title={use_document_title}.") + if self.use_tiktoken: + import tiktoken + # The tokenizer for text-embedding-3-small, text-embedding-3-large + self._tokenizer = tiktoken.get_encoding("cl100k_base") + else: + from transformers import AutoTokenizer + self._tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=model_id, + ) + self._prefix = "" + self._prefix_len = 0 + self.set_prefix(prefix) + self.use_document_title = use_document_title + + def set_prefix(self, prefix): + logger.debug(f"SPLIT: setting prefix to '{prefix}'..") + prefix_len = len(self.tokenize(prefix)) + if prefix_len >= self._chunk_size: + raise RuntimeError( + f"Chunks prefix: {prefix} is too long for chunk size {self._chunk_size}" + ) + else: + self._prefix = prefix + self._prefix_len = prefix_len + + def tokenize(self, text: str) -> list[int]: + if self.use_tiktoken: + return self._tokenizer.encode(text) + else: + return self._tokenizer.encode(text, add_special_tokens=False) + + def decode_tokens(self, tokens: list[int]): + return self._tokenizer.decode(tokens) + + def 
split_documents(self, documents: Iterable[Document]) -> list[Document]: + """Given Documents, chunk the text to smaller pieces and return them as list of Documents""" + + chunks = [] + for document in documents: + chunks.extend(self._chunk_document(document)) + return chunks + + def _chunk_document(self, document: Document): + document_text = document.page_content + document_metadata = document.metadata + chunks = [] + chunk_seq_num = 0 + + first_character_index = document_metadata.get("start_index", 0) + if self.use_document_title: + file_name = document_metadata.get("filename", "unknown-file-name") + logger.debug(f"SPLIT: Chunking document with file name '{file_name}'..") + document_title = get_title(file_name) + logger.debug(f"SPLIT: Found title '{document_title}'..") + self.set_prefix(document_title) + + current = self._prefix + current_len = self._prefix_len + + new_line_token_count = len(self.tokenize("\n")) + lines = document_text.split("\n") + for line in lines: + line_tokens = self.tokenize(line) + + while ( + len(line_tokens) > self._chunk_size - current_len + ): # line cannot fit into current + num_available_tokens_in_chunk = ( + self._chunk_size - current_len + if len(line_tokens) + self._prefix_len > self._chunk_size + else 0 + ) # if whole line can fit into a new chunk, do not add anything to current chunk, + # otherwise, split the line between current and next chunks. 
+ + if num_available_tokens_in_chunk > 0: + # split line + if current: + current += "\n" + current_len += new_line_token_count + current += self.decode_tokens( + line_tokens[:num_available_tokens_in_chunk] + ) + current_len += num_available_tokens_in_chunk + + # add current chunk + chunks.append( + self._new_chunk( + current, chunk_seq_num, first_character_index, document_metadata + ) + ) + + # new current chunk + first_character_index += len(current) + chunk_seq_num += 1 + current = self._prefix + current_len = self._prefix_len + line_tokens = line_tokens[num_available_tokens_in_chunk:] + + # rest of line fits into current + if len(line_tokens) > 0: + if current: + current += "\n" + current_len += new_line_token_count + current += self.decode_tokens(line_tokens) + current_len += len(line_tokens) + + # final chunk + chunks.append( + self._new_chunk( + current, chunk_seq_num, first_character_index, document_metadata + ) + ) + + return chunks + + @staticmethod + def _new_chunk( + text: str, seq_no: int, start_index: int, doc_metadata: dict + ) -> Document: + chunk_metadata = copy.deepcopy(doc_metadata) + chunk_metadata["sequence_number"] = seq_no + chunk_metadata["start_index"] = start_index + return Document(page_content=text, metadata=chunk_metadata) + + +class TableAwareTextSplitter: + + def __init__(self, chunk_size: int, model_id: str): + self.chunk_size = chunk_size + self.model_id = model_id + + def split_documents(self, documents: Iterable[Document]) -> list[Document]: + """Given Documents, chunk the text to smaller pieces and return them as list of Documents""" + + chunks = [] + for document in documents: + chunks.extend(self._chunk_document(document)) + return chunks + + def _chunk_document(self, document: Document) -> list[Document]: + segments = self._get_segments(document) + + chunks = [] + for segment in segments: + prefix = self.get_prefix(segment) + line_splitter = LineBasedTextSplitter( + chunk_size=self.chunk_size, + model_id=self.model_id, + 
prefix=prefix + ) + + chunks.extend(line_splitter.split_documents([segment])) + + return chunks + + # fix me: does not indicate sub headers + def _get_segments(self, doc): + segments = [] + doc_metadata = doc.metadata + segments_count = 0 + start_index = doc.metadata.get("start_index", 0) + current_segment = Document( + page_content="", + metadata={"type": "text", "seq_no": segments_count, "start_index": start_index} + | doc_metadata, + ) + separator_found = False + lines = doc.page_content.split("\n") + for line in lines: + + if self._is_table_line(line): + if current_segment.metadata["type"] != "table": # first table line + segments.append(current_segment) + segments_count += 1 + start_index += len(current_segment.page_content) + current_segment = Document( + page_content="", + metadata={ + "type": "table", + "caption": self.get_caption(current_segment), + "header": self.condense_table_row(line), + "seq_no": segments_count, + "start_index": start_index, + } + | doc_metadata, + ) + separator_found = False + elif self._is_table_seperator(line): + + separator_found = True + current_segment.metadata[ + "header" + ] += "\n" + self.condense_separator(line) + elif not separator_found: + + current_segment.metadata[ + "header" + ] += "\n" + self.condense_table_row(line) + else: + current_segment.page_content += "\n" + line + + else: # text line + if current_segment.metadata["type"] == "table": + segments.append(current_segment) + segments_count += 1 + start_index += len(current_segment.page_content) + current_segment = Document( + page_content="", + metadata={ + "type": "text", + "seq_no": segments_count, + "start_index": start_index, + } + | doc_metadata, + ) + current_segment.page_content += "\n" + line + + # last segment + segments.append(current_segment) + return [c for c in segments if len(c.page_content.strip()) > 0] + + @staticmethod + def get_prefix(segment: Document) -> str: + if segment.metadata["type"] == "text": + return "" + elif segment.metadata["type"] == 
"table": + result = segment.metadata["caption"] + if result: + result += "\n" + result += segment.metadata["header"] + return result + else: + raise RuntimeError(f"Internal error: unknown segment type '{segment['type']}' for segment {segment}.") + + # returns last sentence before table + @staticmethod + def get_caption(prev_segment) -> str: + last_sentence = prev_segment.page_content.strip().split("\n")[-1].split(".")[-1] + return last_sentence + + # each line starting with | is included in table + @staticmethod + def _is_table_line(line: str): + return line.startswith("|") + + @staticmethod + def _is_table_seperator(line: str): + cells = [c.strip() for c in line.strip().split("|")] + return all( + re.match(r"[-]+", cell.strip()) for cell in cells if len(cell.strip()) > 0 + ) + + @staticmethod + def condense_separator(line: str): + numCells = len(line.strip().split("|")) - 2 + return "| --- " * numCells + "|" + + @staticmethod + def condense_table_row(line: str) -> str: + if sum([t.isalnum() for t in line]) == 0: + return "" + cells = [c.strip() for c in line.strip().split("|")] + + return " | ".join(cells).strip() + +def get_title(file_name: str) -> str: + file_name_to_title = { + "docling.pdf": "Docling Technical Report" + } + file_name_to_title.update(filename_to_output) + return file_name_to_title.get(file_name, "") + + +filename_to_output = { + "Alaska-2017.pdf": """This document is the 2017 annual report (Form 10-K) of Alaska Air Group, Inc., filed with the United States Securities and Exchange Commission (SEC). The report covers the fiscal year ended December 31, 2017. Important entities mentioned include: + +* Alaska Air Group, Inc. 
(the company) +* United States Securities and Exchange Commission (SEC) +* New York Stock Exchange (where the company's common stock is registered) + +Important dates mentioned include: + +* December 31, 2017 (end of the fiscal year) +* January 31, 2018 (date of share outstanding total) +* June 30, 2017 (date used to calculate aggregate market value of shares held by nonaffiliates)""", + "Alaska-2018.pdf": """This document is the 2018 annual report (Form 10-K) of Alaska Air Group, Inc., filed with the United States Securities and Exchange Commission (SEC). The report covers the fiscal year ended December 31, 2018. Important entities mentioned include: + +* Alaska Air Group, Inc. (the company) +* United States Securities and Exchange Commission (SEC) +* New York Stock Exchange (where the company's common stock is listed) + +Important dates mentioned include: + +* December 31, 2018 (end of the fiscal year) +* January 31, 2019 (date of share outstanding total) +* June 30, 2018 (date used to calculate aggregate market value of shares held by nonaffiliates)""", + "AmericanAirlines-2017.pdf": "This document is the 2017 annual report (Form 10-K) of American Airlines Group Inc., filed with the United States Securities and Exchange Commission (SEC).", + "AmericanAirlines-2018.pdf": "The document \"AmericanAirlines-2018.pdf\" is the 2018 Annual Report on Form 10-K for American Airlines Group Inc.", + "AmericanAirlines-2019.pdf": "This document is the 2020 Annual Report on Form 10-K for American Airlines Group Inc., filed for the year ending 2019.", + "Delta-2017.pdf": "This document is the 2017 annual report (Form 10-K) of Delta Air Lines, Inc. for the fiscal year ended December 31, 2017.", + "Delta-2018.pdf": "This document is the 2018 annual report (Form 10-K) of Delta Air Lines, Inc. for the fiscal year ended December 31, 2018.", + "Delta-2019.pdf": "This document is the 2019 annual report (Form 10-K) of Delta Air Lines, Inc. 
for the fiscal year ended December 31, 2019.", + "Southwest-2017.pdf": "This document is the 2017 Annual Report to Shareholders of Southwest Airlines Co.", + "Southwest-2018.pdf": "This document is the 2018 Annual Report to Shareholders of Southwest Airlines Co.", + "Southwest-2019.pdf": "This document is the 2019 Annual Report to Shareholders of Southwest Airlines Co.", + "United-2017.pdf": "This document is the 2017 annual report (Form 10-K) of United Continental Holdings, Inc. and United Airlines, Inc.", + "United-2018.pdf": "This document is the 2018 annual report (Form 10-K) of United Continental Holdings, Inc. and United Airlines, Inc.", + "United-2019.pdf": "This document is the 2019 annual report (Form 10-K) of United Airlines Holdings, Inc. and United Airlines, Inc." +} diff --git a/flows/ingestion_flow.json b/flows/ingestion_flow.json index c5c2fc375..4555f40a3 100644 --- a/flows/ingestion_flow.json +++ b/flows/ingestion_flow.json @@ -5564,7 +5564,7 @@ "_type": "Component", "api_base": { "_input_type": "MessageTextInput", - "advanced": true, + "advanced": false, "display_name": "OpenAI API Base URL", "dynamic": false, "info": "Base URL for the API. Leave empty for default.", @@ -5573,7 +5573,7 @@ ], "list": false, "list_add_label": "Add More", - "load_from_db": false, + "load_from_db": true, "name": "api_base", "override_skip": false, "placeholder": "", @@ -5585,7 +5585,7 @@ "trace_as_metadata": true, "track_in_telemetry": false, "type": "str", - "value": "" + "value": "OPENAI_API_BASE" }, "api_key": { "_input_type": "SecretStrInput", diff --git a/flows/openrag_agent.json b/flows/openrag_agent.json index b2f41b1b6..1e6309032 100644 --- a/flows/openrag_agent.json +++ b/flows/openrag_agent.json @@ -1808,14 +1808,14 @@ }, "openai_api_base": { "_input_type": "StrInput", - "advanced": true, + "advanced": false, "display_name": "OpenAI API Base", "dynamic": false, "info": "The base URL of the OpenAI API. Defaults to https://api.openai.com/v1. 
You can change this to use other APIs like JinaChat, LocalAI and Prem.", "input_types": [], "list": false, "list_add_label": "Add More", - "load_from_db": false, + "load_from_db": true, "name": "openai_api_base", "override_skip": false, "placeholder": "", @@ -1826,7 +1826,7 @@ "trace_as_metadata": true, "track_in_telemetry": false, "type": "str", - "value": "" + "value": "OPENAI_API_BASE" }, "output_schema": { "_input_type": "TableInput", @@ -2344,7 +2344,7 @@ ], "list": false, "list_add_label": "Add More", - "load_from_db": false, + "load_from_db": true, "name": "api_base", "override_skip": false, "placeholder": "", @@ -2356,7 +2356,7 @@ "trace_as_metadata": true, "track_in_telemetry": false, "type": "str", - "value": "" + "value": "OPENAI_API_BASE" }, "api_key": { "_input_type": "SecretStrInput", diff --git a/scripts/update_split_text_component.py b/scripts/update_split_text_component.py new file mode 100644 index 000000000..ec3c3b541 --- /dev/null +++ b/scripts/update_split_text_component.py @@ -0,0 +1,52 @@ +# !/usr/bin/env python3 +import subprocess +import sys + +flows_dir = "../flows" +flow_file = flows_dir + "/ingestion_flow.json" +code_file = flows_dir + "/components/split_text.py" +display_name = "Split Text" + + +def main(read: bool): + if read: + read_component() + else: + write_component() + +def read_component(): + metadata_module = None # "mypkg.flow_meta" # OPTIONAL + match_index = None # OPTIONAL + output = code_file # OPTIONAL + + # Build the command + cmd = [sys.executable, "extract_flow_component.py", "--flow-file", flow_file] + + if display_name: + cmd += ["--display-name", display_name] + if metadata_module: + cmd += ["--metadata-module", metadata_module] + if match_index is not None: + cmd += ["--match-index", str(match_index)] + if output: + cmd += ["--output", output] + + # Run the command + print("Running:", " ".join(cmd)) + subprocess.run(cmd) + +def write_component(): + # Build the command + cmd = [sys.executable, 
"update_flow_components.py", + "--code-file", code_file, + "--display-name", display_name, + "--flows-dir", flows_dir + ] + + # Run the command + print("Running:", " ".join(cmd)) + subprocess.run(cmd) + + +if __name__ == "__main__": + main(read=False) diff --git a/src/api/models.py b/src/api/models.py index 118126029..8d4ceb281 100644 --- a/src/api/models.py +++ b/src/api/models.py @@ -1,3 +1,5 @@ +import os + from starlette.responses import JSONResponse from utils.logging_config import get_logger from config.settings import get_openrag_config @@ -10,9 +12,11 @@ async def get_openai_models(request, models_service, session_manager): try: # Get API key from request body api_key = None + api_base = None try: body = await request.json() api_key = body.get("api_key") if body else None + api_base = body.get("api_base") if body else None except Exception: # Body might be empty or invalid JSON, continue to fallback pass @@ -36,7 +40,24 @@ async def get_openai_models(request, models_service, session_manager): status_code=400, ) - models = await models_service.get_openai_models(api_key=api_key) + if not api_base: + try: + config = get_openrag_config() + api_base = config.providers.openai.endpoint + logger.info( + f"Retrieved OpenAI API base from config: {'yes' if api_base else 'no'}" + ) + except Exception as e: + logger.error(f"Failed to get config: {e}") + if not api_base: + return JSONResponse( + { + "error": "OpenAI API base is required either in request body or in configuration" + }, + status_code=400, + ) + + models = await models_service.get_openai_models(api_key=api_key, api_base=api_base) return JSONResponse(models) except Exception as e: logger.error(f"Failed to get OpenAI models: {str(e)}") diff --git a/src/api/provider_validation.py b/src/api/provider_validation.py index 813826a18..458190193 100644 --- a/src/api/provider_validation.py +++ b/src/api/provider_validation.py @@ -1,6 +1,8 @@ """Provider validation utilities for testing API keys and models during 
onboarding.""" import json +import os + import httpx from utils.container_utils import transform_localhost_url from utils.logging_config import get_logger @@ -107,7 +109,6 @@ def _extract_error_details(response: httpx.Response) -> str: return parsed return response_text - async def validate_provider_setup( provider: str, api_key: str = None, @@ -125,7 +126,7 @@ async def validate_provider_setup( api_key: API key for the provider (optional for ollama) embedding_model: Embedding model to test llm_model: LLM model to test - endpoint: Provider endpoint (required for ollama and watsonx) + endpoint: Provider endpoint (required for ollama and watsonx, optional for openai) project_id: Project ID (required for watsonx) test_completion: If True, performs full validation with completion/embedding tests (consumes credits). If False, performs lightweight validation (no credits consumed). Default: False. @@ -138,11 +139,14 @@ async def validate_provider_setup( try: logger.info(f"Starting validation for provider: {provider_lower} (test_completion={test_completion})") + if provider == "openai" and not endpoint: + endpoint = os.environ.get("OPENAI_API_BASE", "https://api.openai.com") + if test_completion: # Full validation with completion/embedding tests (consumes credits) if embedding_model: # Test embedding - await test_embedding( + await _test_embedding( provider=provider_lower, api_key=api_key, embedding_model=embedding_model, @@ -151,7 +155,7 @@ async def validate_provider_setup( ) elif llm_model: # Test completion with tool calling - await test_completion_with_tools( + await _test_completion_with_tools( provider=provider_lower, api_key=api_key, llm_model=llm_model, @@ -160,7 +164,7 @@ async def validate_provider_setup( ) else: # Lightweight validation (no credits consumed) - await test_lightweight_health( + await _test_lightweight_health( provider=provider_lower, api_key=api_key, endpoint=endpoint, @@ -175,7 +179,7 @@ async def validate_provider_setup( raise -async def 
test_lightweight_health( +async def _test_lightweight_health( provider: str, api_key: str = None, endpoint: str = None, @@ -184,7 +188,7 @@ async def test_lightweight_health( """Test provider health with lightweight check (no credits consumed).""" if provider == "openai": - await _test_openai_lightweight_health(api_key) + await _test_openai_lightweight_health(api_key, endpoint) elif provider == "watsonx": await _test_watsonx_lightweight_health(api_key, endpoint, project_id) elif provider == "ollama": @@ -195,7 +199,7 @@ async def test_lightweight_health( raise ValueError(f"Unknown provider: {provider}") -async def test_completion_with_tools( +async def _test_completion_with_tools( provider: str, api_key: str = None, llm_model: str = None, @@ -205,7 +209,7 @@ async def test_completion_with_tools( """Test completion with tool calling for the provider.""" if provider == "openai": - await _test_openai_completion_with_tools(api_key, llm_model) + await _test_openai_completion_with_tools(api_key, llm_model, endpoint) elif provider == "watsonx": await _test_watsonx_completion_with_tools(api_key, llm_model, endpoint, project_id) elif provider == "ollama": @@ -216,7 +220,7 @@ async def test_completion_with_tools( raise ValueError(f"Unknown provider: {provider}") -async def test_embedding( +async def _test_embedding( provider: str, api_key: str = None, embedding_model: str = None, @@ -226,7 +230,7 @@ async def test_embedding( """Test embedding generation for the provider.""" if provider == "openai": - await _test_openai_embedding(api_key, embedding_model) + await _test_openai_embedding(api_key, embedding_model, endpoint) elif provider == "watsonx": await _test_watsonx_embedding(api_key, embedding_model, endpoint, project_id) elif provider == "ollama": @@ -236,7 +240,7 @@ async def test_embedding( # OpenAI validation functions -async def _test_openai_lightweight_health(api_key: str) -> None: +async def _test_openai_lightweight_health(api_key: str, endpoint: str) -> None: 
"""Test OpenAI API key validity with lightweight check. Only checks if the API key is valid without consuming credits. @@ -248,10 +252,12 @@ async def _test_openai_lightweight_health(api_key: str) -> None: "Content-Type": "application/json", } + url = f"{endpoint}/v1/models" + logger.info("Testing openai lightweight health", url=url) async with httpx.AsyncClient() as client: # Use /v1/models endpoint which validates the key without consuming credits response = await client.get( - "https://api.openai.com/v1/models", + url=url, headers=headers, timeout=10.0, # Short timeout for lightweight check ) @@ -271,7 +277,7 @@ async def _test_openai_lightweight_health(api_key: str) -> None: raise -async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> None: +async def _test_openai_completion_with_tools(api_key: str, llm_model: str, endpoint: str) -> None: """Test OpenAI completion with tool calling.""" try: headers = { @@ -309,8 +315,10 @@ async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> No async with httpx.AsyncClient() as client: # Try with max_tokens first payload = {**base_payload, "max_tokens": 50} + url = f"{endpoint}/v1/chat/completions" + logger.info("Test openai completion tools", url=url) response = await client.post( - "https://api.openai.com/v1/chat/completions", + url=url, headers=headers, json=payload, timeout=30.0, @@ -320,8 +328,9 @@ async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> No if response.status_code != 200: logger.info("max_tokens parameter failed, trying max_completion_tokens instead") payload = {**base_payload, "max_completion_tokens": 50} + logger.info("Test openai completion tools", url=url) response = await client.post( - "https://api.openai.com/v1/chat/completions", + url=url, headers=headers, json=payload, timeout=30.0, @@ -342,7 +351,7 @@ async def _test_openai_completion_with_tools(api_key: str, llm_model: str) -> No raise -async def 
_test_openai_embedding(api_key: str, embedding_model: str) -> None: +async def _test_openai_embedding(api_key: str, embedding_model: str, endpoint: str) -> None: """Test OpenAI embedding generation.""" try: headers = { @@ -356,8 +365,9 @@ async def _test_openai_embedding(api_key: str, embedding_model: str) -> None: } async with httpx.AsyncClient() as client: + url = f"{endpoint}/v1/embeddings" response = await client.post( - "https://api.openai.com/v1/embeddings", + url=url, headers=headers, json=payload, timeout=30.0, diff --git a/src/api/settings.py b/src/api/settings.py index dedba0f7c..21e35b35c 100644 --- a/src/api/settings.py +++ b/src/api/settings.py @@ -2,6 +2,7 @@ import platform import time from starlette.responses import JSONResponse + from utils.container_utils import transform_localhost_url from utils.logging_config import get_logger from utils.telemetry import TelemetryClient, Category, MessageId @@ -191,6 +192,48 @@ async def get_settings(request, session_manager): {"error": f"Failed to retrieve settings: {str(e)}"}, status_code=500 ) +async def validate_enum_str( + body: dict, + field_name: str, + allowed_values: list[str], +) -> JSONResponse | None: + """ + Validate that body[field_name], if present, is: + - a string, + - non-empty after stripping whitespace, + - and one of the allowed_values. 
+ + """ + # If not provided, no validation needed + if field_name not in body: + return None + + value = body[field_name] + + # Must be a string + if not isinstance(value, str): + return JSONResponse( + {"error": f"{field_name} must be a string"}, + status_code=400, + ) + + # Must be non-empty after trimming + trimmed = value.strip() + if not trimmed: + return JSONResponse( + {"error": f"{field_name} must be a non-empty string"}, + status_code=400, + ) + + # Must be among allowed values + if trimmed not in allowed_values: + allowed_str = ", ".join(allowed_values) + return JSONResponse( + {"error": f"{field_name} must be one of: {allowed_str}"}, + status_code=400, + ) + + return None async def update_settings(request, session_manager): """Update application settings""" @@ -217,6 +260,8 @@ async def update_settings(request, session_manager): "system_prompt", "chunk_size", "chunk_overlap", + "splitter_type", + "use_document_title", "table_structure", "ocr", "picture_descriptions", @@ -480,6 +525,19 @@ async def update_settings(request, session_manager): Category.SETTINGS_OPERATIONS, MessageId.ORB_SETTINGS_EMBED_MODEL ) + + # Also update the ingest flow with a new model id + try: + flows_service = _get_flows_service() + await flows_service.update_ingest_flow_model_id_in_text_splitter(model_id=new_embedding_model) + logger.info( + f"Successfully updated ingest flow model id in text splitter to {new_embedding_model}" + ) + except Exception as e: + logger.error(f"Failed to update ingest flow model id in text splitter: {str(e)}") + # Don't fail the entire settings update if flow update fails + # The config will still be saved + logger.info(f"Embedding model changed from {old_model} to {new_embedding_model}") if "embedding_provider" in body: @@ -537,6 +595,48 @@ async def update_settings(request, session_manager): except Exception as e: logger.error(f"Failed to update docling settings in flow: {str(e)}") + if "splitter_type" in body: + new_splitter_type = 
body["splitter_type"] + current_config.knowledge.splitter_type = new_splitter_type + config_updated = True + await TelemetryClient.send_event( + Category.SETTINGS_OPERATIONS, + MessageId.ORB_SETTINGS_CHUNK_UPDATED + ) + + # Also update the ingest flow with the new splitter type + try: + flows_service = _get_flows_service() + await flows_service.update_ingest_flow_splitter_type(new_splitter_type) + logger.info( + f"Successfully updated ingest flow splitter type to {new_splitter_type}" + ) + except Exception as e: + logger.error(f"Failed to update ingest flow splitter type: {str(e)}") + # Don't fail the entire settings update if flow update fails + # The config will still be saved + + if "use_document_title" in body: + new_use_document_title = body["use_document_title"] + current_config.knowledge.use_document_title = new_use_document_title + config_updated = True + await TelemetryClient.send_event( + Category.SETTINGS_OPERATIONS, + MessageId.ORB_SETTINGS_CHUNK_UPDATED + ) + + # Also update the ingest flow with the new use_document_title value + try: + flows_service = _get_flows_service() + await flows_service.update_ingest_flow_use_document_title(new_use_document_title) + logger.info( + f"Successfully updated ingest flow use_document_title to {new_use_document_title}" + ) + except Exception as e: + logger.error(f"Failed to update ingest flow use_document_title: {str(e)}") + # Don't fail the entire settings update if flow update fails + # The config will still be saved + if "chunk_size" in body: current_config.knowledge.chunk_size = body["chunk_size"] config_updated = True @@ -731,6 +831,7 @@ async def onboarding(request, flows_service, session_manager=None): "llm_model", "embedding_provider", "embedding_model", + "delete_existing_index", # Provider-specific fields "openai_api_key", "anthropic_api_key", @@ -743,9 +844,11 @@ # Check for invalid fields invalid_fields = set(body.keys()) - allowed_fields if
invalid_fields: + error_message = f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}" + logger.error(error_message) return JSONResponse( { - "error": f"Invalid fields: {', '.join(invalid_fields)}. Allowed fields: {', '.join(allowed_fields)}" + "error": error_message }, status_code=400, ) @@ -1012,10 +1115,26 @@ async def onboarding(request, flows_service, session_manager=None): # Import here to avoid circular imports from main import init_index + # Handle delete_existing_index + delete_existing_index = False + if "delete_existing_index" in body: + delete_existing_index = body["delete_existing_index"] + if not isinstance(delete_existing_index, bool): + return JSONResponse( + {"error": "delete_existing_index must be a boolean value"}, status_code=400 + ) + if delete_existing_index: + await TelemetryClient.send_event( + Category.ONBOARDING, + MessageId.ORB_ONBOARD_DELETE_EXISTING_INDEX + ) + logger.info("Delete existing index requested during onboarding") + logger.info( - "Initializing OpenSearch index after onboarding configuration" + f"Initializing OpenSearch index after onboarding configuration (delete_existing_index={delete_existing_index})", ) - await init_index() + await init_index(delete_existing=delete_existing_index) + logger.info("OpenSearch index initialization completed successfully") except Exception as e: logger.error( diff --git a/src/config/config_manager.py b/src/config/config_manager.py index 8b3bb38b8..6556aa05f 100644 --- a/src/config/config_manager.py +++ b/src/config/config_manager.py @@ -14,6 +14,7 @@ class OpenAIConfig: """OpenAI provider configuration.""" api_key: str = "" + endpoint: str = "" configured: bool = False @@ -71,6 +72,7 @@ class KnowledgeConfig: embedding_provider: str = "openai" # Which provider to use for embeddings chunk_size: int = 1000 chunk_overlap: int = 200 + splitter_type: str = "CharacterTextSplitter" table_structure: bool = True ocr: bool = False picture_descriptions: bool = False 
@@ -224,6 +226,8 @@ def _load_env_overrides( # OpenAI provider settings if os.getenv("OPENAI_API_KEY"): config_data["providers"]["openai"]["api_key"] = os.getenv("OPENAI_API_KEY") + if os.getenv("OPENAI_API_BASE"): + config_data["providers"]["openai"]["endpoint"] = os.getenv("OPENAI_API_BASE") # Anthropic provider settings if os.getenv("ANTHROPIC_API_KEY"): diff --git a/src/main.py b/src/main.py index aa8217f7c..3dc397b50 100644 --- a/src/main.py +++ b/src/main.py @@ -154,6 +154,7 @@ async def configure_alerting_security(): async def _ensure_opensearch_index(): """Ensure OpenSearch index exists when using traditional connector service.""" + import config.settings as settings try: index_name = get_index_name() # Check if index already exists @@ -183,8 +184,9 @@ async def _ensure_opensearch_index(): # The service can still function, document operations might fail later -async def init_index(): +async def init_index(delete_existing: bool = False): """Initialize OpenSearch index and security roles""" + import config.settings as settings await wait_for_opensearch() # Get the configured embedding model from user configuration @@ -201,9 +203,17 @@ async def init_index(): endpoint=getattr(embedding_provider_config, "endpoint", None) ) - # Create documents index index_name = get_index_name() - if not await clients.opensearch.indices.exists(index=index_name): + index_exists = await clients.opensearch.indices.exists(index=index_name) + if index_exists and delete_existing: + # Asked to delete the existing index .. 
+ logger.info(f"Deleting index '{index_name}'...") + resp = await clients.opensearch.indices.delete(index=index_name) + logger.info(f"Deleted index '{index_name}', response: {resp}") + index_exists = False + + # Create documents index + if not index_exists: await clients.opensearch.indices.create( index=index_name, body=dynamic_index_body ) @@ -603,6 +613,7 @@ async def startup_tasks(services): async def initialize_services(): + import config.settings as settings """Initialize all services and their dependencies""" await TelemetryClient.send_event(Category.SERVICE_INITIALIZATION, MessageId.ORB_SVC_INIT_START) # Generate JWT keys if they don't exist diff --git a/src/services/flows_service.py b/src/services/flows_service.py index e97ac2d3a..1fb4ff806 100644 --- a/src/services/flows_service.py +++ b/src/services/flows_service.py @@ -953,6 +953,39 @@ async def update_ingest_flow_chunk_overlap(self, chunk_overlap: int): node_display_name="Split Text", ) + async def update_ingest_flow_splitter_type(self, splitter_type: str): + """Helper function to update splitter type in the ingest flow""" + if not LANGFLOW_INGEST_FLOW_ID: + raise ValueError("LANGFLOW_INGEST_FLOW_ID is not configured") + await self._update_flow_field( + LANGFLOW_INGEST_FLOW_ID, + "splitter_type", + splitter_type, + node_display_name="Split Text", + ) + + async def update_ingest_flow_use_document_title(self, use_document_title: bool): + """Helper function to update use_document_title in the ingest flow""" + if not LANGFLOW_INGEST_FLOW_ID: + raise ValueError("LANGFLOW_INGEST_FLOW_ID is not configured") + await self._update_flow_field( + LANGFLOW_INGEST_FLOW_ID, + "use_document_title", + str(use_document_title), + node_display_name="Split Text", + ) + + async def update_ingest_flow_model_id_in_text_splitter(self, model_id: str): + """Helper function to update the embedding model id in the ingest flow""" + if not LANGFLOW_INGEST_FLOW_ID: + raise ValueError("LANGFLOW_INGEST_FLOW_ID is not configured") + await
self._update_flow_field( + LANGFLOW_INGEST_FLOW_ID, + "model_id", + model_id, + node_display_name="Split Text", + ) + async def update_ingest_flow_embedding_model(self, embedding_model: str, provider: str): """Helper function to update embedding model in the ingest flow""" if not LANGFLOW_INGEST_FLOW_ID: @@ -1393,11 +1426,17 @@ async def _update_component_fields( template["api_key"]["advanced"] = False updated = True if provider == "openai" and "api_base" in template: - template["api_base"]["value"] = "" - template["api_base"]["load_from_db"] = False + template["api_base"]["value"] = "OPENAI_API_BASE" + template["api_base"]["load_from_db"] = True template["api_base"]["show"] = True template["api_base"]["advanced"] = False updated = True + if provider == "openai" and "openai_api_base" in template: + template["openai_api_base"]["value"] = "OPENAI_API_BASE" + template["openai_api_base"]["load_from_db"] = True + template["openai_api_base"]["show"] = True + template["openai_api_base"]["advanced"] = False + updated = True if provider == "anthropic" and "api_key" in template: template["api_key"]["value"] = "ANTHROPIC_API_KEY" diff --git a/src/services/models_service.py b/src/services/models_service.py index cd08b7085..d99e91e33 100644 --- a/src/services/models_service.py +++ b/src/services/models_service.py @@ -1,3 +1,5 @@ +import os + import httpx from typing import Dict, List from utils.container_utils import transform_localhost_url @@ -28,6 +30,7 @@ class ModelsService: "o3-pro", "o4-mini", "o4-mini-high", + "claude-opus-4-5-20251101", ] ANTHROPIC_MODELS = [ @@ -43,7 +46,7 @@ class ModelsService: def __init__(self): self.session_manager = None - async def get_openai_models(self, api_key: str) -> Dict[str, List[Dict[str, str]]]: + async def get_openai_models(self, api_key: str, api_base: str) -> Dict[str, List[Dict[str, str]]]: """Fetch available models from OpenAI API with lightweight validation""" try: headers = { @@ -54,8 +57,10 @@ async def get_openai_models(self, 
api_key: str) -> Dict[str, List[Dict[str, str] async with httpx.AsyncClient() as client: # Lightweight validation: just check if API key is valid # This doesn't consume credits, only validates the key + url = f"{api_base}/v1/models" + logger.debug("Getting openai models.", url=url) response = await client.get( - "https://api.openai.com/v1/models", headers=headers, timeout=10.0 + url, headers=headers, timeout=10.0 ) if response.status_code == 200: diff --git a/src/utils/langflow_headers.py b/src/utils/langflow_headers.py index e3447e611..bfc5f1271 100644 --- a/src/utils/langflow_headers.py +++ b/src/utils/langflow_headers.py @@ -14,7 +14,10 @@ def add_provider_credentials_to_headers(headers: Dict[str, str], config) -> None # Add OpenAI credentials if config.providers.openai.api_key: headers["X-LANGFLOW-GLOBAL-VAR-OPENAI_API_KEY"] = str(config.providers.openai.api_key) - + + if config.providers.openai.endpoint: + headers["X-LANGFLOW-GLOBAL-VAR-OPENAI_API_BASE"] = str(config.providers.openai.endpoint) + # Add Anthropic credentials if config.providers.anthropic.api_key: headers["X-LANGFLOW-GLOBAL-VAR-ANTHROPIC_API_KEY"] = str(config.providers.anthropic.api_key) @@ -47,6 +50,9 @@ def build_mcp_global_vars_from_config(config) -> Dict[str, str]: if config.providers.openai.api_key: global_vars["OPENAI_API_KEY"] = config.providers.openai.api_key + if config.providers.openai.endpoint: + global_vars["OPENAI_API_BASE"] = config.providers.openai.endpoint + # Add Anthropic credentials if config.providers.anthropic.api_key: global_vars["ANTHROPIC_API_KEY"] = config.providers.anthropic.api_key diff --git a/src/utils/telemetry/message_id.py b/src/utils/telemetry/message_id.py index a5b17656e..e7d96c0be 100644 --- a/src/utils/telemetry/message_id.py +++ b/src/utils/telemetry/message_id.py @@ -199,6 +199,8 @@ class MessageId: ORB_ONBOARD_EMBED_MODEL = "ORB_ONBOARD_EMBED_MODEL" # Message: Sample data ingestion requested ORB_ONBOARD_SAMPLE_DATA = "ORB_ONBOARD_SAMPLE_DATA" + # Message: 
Delete existing index requested + ORB_ONBOARD_DELETE_EXISTING_INDEX = "ORB_ONBOARD_DELETE_EXISTING_INDEX" # Message: Configuration marked as edited ORB_ONBOARD_CONFIG_EDITED = "ORB_ONBOARD_CONFIG_EDITED" # Message: Onboarding rolled back due to all files failing