Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 102 additions & 29 deletions api/routes/pipelex/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,135 @@
from __future__ import annotations

import asyncio
import traceback
from typing import Annotated
from typing import TYPE_CHECKING, Annotated, Any, cast

from fastapi import APIRouter, Depends, HTTPException, Request
from kajson import kajson
from mthds.pipeline import PipelineRequest, PipelineState
from pipelex import log
from pipelex.client.pipeline_request_factory import PipelineRequestFactory
from pipelex.client.pipeline_response_factory import PipelineResponseFactory
from pipelex.client.protocol import PipelineRequest, PipelineResponse, PipelineState
from pipelex.pipeline.execute import execute_pipeline
from pipelex.config import get_config
from pipelex.hub import get_pipe_router
from pipelex.pipeline.pipeline_response import PipelexPipelineExecuteResponse, PipelexPipelineStartResponse
from pipelex.pipeline.pipeline_run_setup import pipeline_run_setup
from pipelex.pipeline.runner import PipelexRunner
from typing_extensions import override

from api.routes.pipelex.utils import get_current_iso_timestamp

if TYPE_CHECKING:
from mthds.models.pipe_output import VariableMultiplicity
from mthds.models.pipeline_inputs import PipelineInputs
from mthds.models.working_memory import WorkingMemoryAbstract
from pipelex.core.memory.working_memory import WorkingMemory
from pipelex.core.pipes.pipe_output import PipeOutput


router = APIRouter(tags=["pipeline"])


class ApiRunner(PipelexRunner):
"""API runner that extends PipelexRunner with start_pipeline support."""

@override
async def start_pipeline(
self,
pipe_code: str | None = None,
mthds_content: str | None = None,
inputs: PipelineInputs | WorkingMemoryAbstract[Any] | None = None,
output_name: str | None = None,
output_multiplicity: VariableMultiplicity | None = None,
dynamic_output_concept_code: str | None = None,
) -> PipelexPipelineStartResponse:
"""Start a pipeline execution asynchronously without waiting for completion.

Setup (validation, library loading) runs synchronously so errors are
returned to the caller immediately. Only the pipe execution itself
runs in the background.
"""
created_at = get_current_iso_timestamp()
pipelex_inputs: PipelineInputs | WorkingMemory | None = cast("PipelineInputs | WorkingMemory | None", inputs)

execution_config = self.execution_config or get_config().pipelex.pipeline_execution_config
pipe_job, pipeline_run_id, _ = await pipeline_run_setup(
execution_config=execution_config,
library_id=self.library_id,
library_dirs=self.library_dirs,
pipe_code=pipe_code,
plx_content=mthds_content,
bundle_uri=self.bundle_uri,
inputs=pipelex_inputs,
output_name=output_name,
output_multiplicity=output_multiplicity,
dynamic_output_concept_code=dynamic_output_concept_code,
pipe_run_mode=self.pipe_run_mode,
search_domain_codes=self.search_domain_codes,
user_id=self.user_id,
)

task: asyncio.Task[PipeOutput] = asyncio.create_task(get_pipe_router().run(pipe_job))
self._running_tasks[pipeline_run_id] = task

return PipelexPipelineStartResponse(
pipeline_run_id=pipeline_run_id,
created_at=created_at,
pipeline_state=PipelineState.STARTED,
)


async def request_deserialization(request: Request) -> PipelineRequest:
"""Dependency that deserializes the request body using kajson"""
"""Dependency that deserializes the request body using kajson."""
body = await request.body()
body_str = body.decode("utf-8")
request_data = kajson.loads(body_str)
return PipelineRequestFactory.make_from_body(request_data)
return PipelineRequest.from_body(request_data)


@router.post("/pipeline/execute", response_model=PipelineResponse)
@router.post("/pipeline/execute", response_model=PipelexPipelineExecuteResponse)
async def execute(
pipeline_request: Annotated[PipelineRequest, Depends(request_deserialization)],
):
"""Executes a pipe with the given memory and waits for completion.

This endpoint can operate in two modes:
1. If 'plx_content' is provided: validates, loads, and executes pipes from the PLX content
2. If 'plx_content' is not provided: executes an already-loaded pipe

This is a blocking operation that doesn't return until the pipe execution is complete.
"""
) -> PipelexPipelineExecuteResponse:
"""Execute a pipeline and wait for completion."""
try:
created_at = get_current_iso_timestamp()
pipe_output = await execute_pipeline(
runner = ApiRunner()
return await runner.execute_pipeline(
pipe_code=pipeline_request.pipe_code,
plx_content=pipeline_request.plx_content,
mthds_content=pipeline_request.mthds_content,
inputs=pipeline_request.inputs,
output_name=pipeline_request.output_name,
output_multiplicity=pipeline_request.output_multiplicity,
dynamic_output_concept_code=pipeline_request.dynamic_output_concept_code,
)

return PipelineResponseFactory.make_from_pipe_output(
pipeline_run_id=pipe_output.pipeline_run_id,
pipeline_state=PipelineState.COMPLETED,
created_at=created_at,
finished_at=get_current_iso_timestamp(),
pipe_output=pipe_output,
)

except Exception as exc:
log.error("Pipeline execution error details:")
traceback.print_exc()
raise HTTPException(
status_code=500,
detail={
"error_type": type(exc).__name__,
"message": str(exc),
},
) from exc


@router.post("/pipeline/start", response_model=PipelexPipelineStartResponse)
async def start(
pipeline_request: Annotated[PipelineRequest, Depends(request_deserialization)],
) -> PipelexPipelineStartResponse:
"""Start a pipeline execution asynchronously without waiting for completion."""
try:
runner = ApiRunner()
return await runner.start_pipeline(
pipe_code=pipeline_request.pipe_code,
mthds_content=pipeline_request.mthds_content,
inputs=pipeline_request.inputs,
output_name=pipeline_request.output_name,
output_multiplicity=pipeline_request.output_multiplicity,
dynamic_output_concept_code=pipeline_request.dynamic_output_concept_code,
)
except Exception as exc:
log.error("Pipeline start error details:")
traceback.print_exc()
raise HTTPException(
status_code=500,
detail={
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ dependencies = [
]

[tool.uv.sources]
pipelex = { git = "https://github.com/Pipelex/pipelex.git", rev = "moad" }

pipelex = { path = "../pipelex", editable = true }

[build-system]
requires = ["hatchling"]
Expand Down
Loading
Loading