diff --git a/libs/executors/garf/executors/entrypoints/server.py b/libs/executors/garf/executors/entrypoints/server.py index 9119008..e869dc0 100644 --- a/libs/executors/garf/executors/entrypoints/server.py +++ b/libs/executors/garf/executors/entrypoints/server.py @@ -17,6 +17,7 @@ from typing import Any, Optional, Union import fastapi +import garf.core import garf.executors import pydantic import typer @@ -116,6 +117,20 @@ def model_post_init(self, __context__) -> None: self.title = str(self.query_path) +class ApiExecutorBatchRequest(pydantic.BaseModel): + """Request for executing multiple queries. + + Attributes: + source: Type of API to interact with. + batch: Mapping between query_title and its text. + context: Execution context. + """ + + source: str + batch: dict[str, str] + context: garf.executors.execution_context.ExecutionContext + + class ApiExecutorResponse(pydantic.BaseModel): """Response after executing a query. @@ -126,6 +141,23 @@ class ApiExecutorResponse(pydantic.BaseModel): results: list[Union[str, Any]] +@app.exception_handler(garf.core.exceptions.GarfError) +async def error_handler( + request: fastapi.Request, exc: garf.core.exceptions.GarfError +): + error_mapping = {} + + status_code = error_mapping.get(type(exc), 400) + + return fastapi.responses.JSONResponse( + status_code=status_code, + content={ + 'error_type': type(exc).__name__, + 'detail': str(exc), + }, + ) + + @app.get('/api/version') async def version() -> str: return garf.executors.__version__ @@ -157,15 +189,13 @@ def execute( @app.post('/api/execute:batch') def execute_batch( - request: ApiExecutorRequest, + request: ApiExecutorBatchRequest, dependencies: Annotated[GarfDependencies, fastapi.Depends(GarfDependencies)], ) -> ApiExecutorResponse: query_executor = setup.setup_executor( request.source, request.context.fetcher_parameters ) - reader_client = reader.FileReader() - batch = {query: reader_client.read(query) for query in request.query_path} - results = 
query_executor.execute_batch(batch, request.context) + results = query_executor.execute_batch(request.batch, request.context) return ApiExecutorResponse(results=results) diff --git a/libs/executors/garf/executors/entrypoints/typer_cli.py b/libs/executors/garf/executors/entrypoints/typer_cli.py index da60f18..b3c9840 100644 --- a/libs/executors/garf/executors/entrypoints/typer_cli.py +++ b/libs/executors/garf/executors/entrypoints/typer_cli.py @@ -21,6 +21,7 @@ from typing import Optional import garf.executors +import requests import rich import typer from garf.executors import exceptions, setup @@ -35,6 +36,9 @@ from garf.io import reader, writer from opentelemetry import trace from opentelemetry.instrumentation.logging import LoggingInstrumentor +from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, +) from typing_extensions import Annotated LoggingInstrumentor().instrument(set_logging_format=False) @@ -137,6 +141,7 @@ def execute( enable_cache: EnableCache = False, cache_ttl_seconds: CacheTTL = 3600, simulate: Simulate = False, + server_url: str | None = None, ) -> None: """Runs queries.""" span = trace.get_current_span() @@ -195,28 +200,68 @@ def execute( raise exceptions.GarfExecutorError( f'No execution context found for source {source} in {config}' ) - query_executor = setup.setup_executor( - source=source, - fetcher_parameters=context.fetcher_parameters, - enable_cache=enable_cache, - cache_ttl_seconds=cache_ttl_seconds, - simulate=simulate, - writers=context.writer, - writer_parameters=context.writer_parameters, - ) + parallel_queries = parallel_threshold > 1 with tracer.start_as_current_span('read_queries'): batch = {query: reader_client.read(query) for query in found_queries} - parallel_queries = parallel_threshold > 1 - if parallel_queries and len(found_queries) > 1: - garf_logger.info('Running queries in parallel') - query_executor.execute_batch(batch, context, parallel_threshold) + if server_url: + rest_context = 
context.model_dump() + if rest_context.get('writer') == ['console']: + del rest_context['writer'] + headers = {} + TraceContextTextMapPropagator().inject(headers) + if parallel_queries and len(batch) > 1: + endpoint = f'{server_url}/api/execute:batch' + request = { + 'batch': batch, + 'source': source, + 'context': rest_context, + } + try: + response = requests.post(url=endpoint, json=request, headers=headers) + response.raise_for_status() + except requests.exceptions.HTTPError as e: + raise exceptions.GarfExecutorError( + f'Server error: {e.response.status_code} - {e.response.json()}' + ) from e + + with tracer.start_as_current_span('parse_results batch'): + typer.secho(response.json().get('results')) + else: + endpoint = f'{server_url}/api/execute' + for title, text in batch.items(): + request = { + 'source': source, + 'title': title, + 'query': text, + 'context': rest_context, + } + try: + response = requests.post(url=endpoint, json=request, headers=headers) + response.raise_for_status() + except requests.exceptions.HTTPError as e: + raise exceptions.GarfExecutorError( + f'Server error: {e.response.status_code} - {e.response.json()}' + ) from e + with tracer.start_as_current_span(f'parse_results {title}'): + typer.secho(response.json().get('results')) else: - if len(found_queries) > 1: - garf_logger.info('Running queries sequentially') - for query in found_queries: - query_executor.execute( - query=reader_client.read(query), title=query, context=context - ) + query_executor = setup.setup_executor( + source=source, + fetcher_parameters=context.fetcher_parameters, + enable_cache=enable_cache, + cache_ttl_seconds=cache_ttl_seconds, + simulate=simulate, + writers=context.writer, + writer_parameters=context.writer_parameters, + ) + if parallel_queries and len(batch) > 1: + garf_logger.info('Running queries in parallel') + query_executor.execute_batch(batch, context, parallel_threshold) + else: + if len(batch) > 1: + garf_logger.info('Running queries sequentially') + 
for title, text in batch.items(): + query_executor.execute(query=text, title=title, context=context) @workflow_app.command( diff --git a/libs/executors/tests/end-to-end/test_server.py b/libs/executors/tests/end-to-end/test_server.py index 2385727..81c2495 100644 --- a/libs/executors/tests/end-to-end/test_server.py +++ b/libs/executors/tests/end-to-end/test_server.py @@ -81,10 +81,10 @@ def test_batch_fake_source_from_query_path(self, tmp_path): fake_data = _SCRIPT_PATH / 'test.json' request = { 'source': 'fake', - 'query_path': [ - str(query_path1), - str(query_path2), - ], + 'batch': { + 'query1': self.query, + 'query2': self.query, + }, 'context': { 'fetcher_parameters': { 'data_location': str(fake_data),