Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions libs/executors/garf/executors/entrypoints/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Optional, Union

import fastapi
import garf.core
import garf.executors
import pydantic
import typer
Expand Down Expand Up @@ -116,6 +117,20 @@ def model_post_init(self, __context__) -> None:
self.title = str(self.query_path)


class ApiExecutorBatchRequest(pydantic.BaseModel):
  """Request for executing multiple queries in a single call.

  Attributes:
    source: Type of API to interact with.
    batch: Mapping between query_title and its text.
    context: Execution context.
  """

  # Identifier of the API / fetcher backend the queries run against.
  source: str
  # query_title -> query_text; each entry is executed as one query.
  batch: dict[str, str]
  # Fetcher and writer parameters shared by every query in the batch.
  context: garf.executors.execution_context.ExecutionContext


class ApiExecutorResponse(pydantic.BaseModel):
"""Response after executing a query.

Expand All @@ -126,6 +141,23 @@ class ApiExecutorResponse(pydantic.BaseModel):
results: list[Union[str, Any]]


@app.exception_handler(garf.core.exceptions.GarfError)
async def error_handlier(
  request: fastapi.Request, exc: garf.core.exceptions.GarfError
):
  """Converts uncaught GarfError exceptions into JSON error responses.

  Args:
    request: Incoming request that triggered the exception (unused).
    exc: The GarfError (or subclass) raised while handling the request.

  Returns:
    JSON response carrying the exception type name and its message.
  """
  # Placeholder for mapping specific GarfError subclasses to HTTP status
  # codes; currently empty, so every error falls through to 400 below.
  error_mapping = {}

  status_code = error_mapping.get(type(exc), 400)

  return fastapi.responses.JSONResponse(
    status_code=status_code,
    content={
      'error_type': type(exc).__name__,
      'detail': str(exc),
    },
  )


@app.get('/api/version')
async def version() -> str:
  """Returns the installed garf.executors package version."""
  return garf.executors.__version__
Expand Down Expand Up @@ -157,15 +189,13 @@ def execute(

@app.post('/api/execute:batch')
def execute_batch(
  request: ApiExecutorBatchRequest,
  dependencies: Annotated[GarfDependencies, fastapi.Depends(GarfDependencies)],
) -> ApiExecutorResponse:
  """Executes a batch of queries against the requested source.

  Queries arrive pre-read in `request.batch` (title -> text), so no file
  reading happens server-side.

  Args:
    request: Batch request with source, query mapping and execution context.
    dependencies: Injected shared dependencies for this request.

  Returns:
    Combined execution results for every query in the batch.
  """
  # Build an executor for the requested backend with the caller's
  # fetcher parameters.
  query_executor = setup.setup_executor(
    request.source, request.context.fetcher_parameters
  )
  results = query_executor.execute_batch(request.batch, request.context)
  return ApiExecutorResponse(results=results)


Expand Down
83 changes: 64 additions & 19 deletions libs/executors/garf/executors/entrypoints/typer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Optional

import garf.executors
import requests
import rich
import typer
from garf.executors import exceptions, setup
Expand All @@ -35,6 +36,9 @@
from garf.io import reader, writer
from opentelemetry import trace
from opentelemetry.instrumentation.logging import LoggingInstrumentor
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
from typing_extensions import Annotated

LoggingInstrumentor().instrument(set_logging_format=False)
Expand Down Expand Up @@ -137,6 +141,7 @@ def execute(
enable_cache: EnableCache = False,
cache_ttl_seconds: CacheTTL = 3600,
simulate: Simulate = False,
server_url: str | None = None,
) -> None:
"""Runs queries."""
span = trace.get_current_span()
Expand Down Expand Up @@ -195,28 +200,68 @@ def execute(
raise exceptions.GarfExecutorError(
f'No execution context found for source {source} in {config}'
)
query_executor = setup.setup_executor(
source=source,
fetcher_parameters=context.fetcher_parameters,
enable_cache=enable_cache,
cache_ttl_seconds=cache_ttl_seconds,
simulate=simulate,
writers=context.writer,
writer_parameters=context.writer_parameters,
)
parallel_queries = parallel_threshold > 1
with tracer.start_as_current_span('read_queries'):
batch = {query: reader_client.read(query) for query in found_queries}
parallel_queries = parallel_threshold > 1
if parallel_queries and len(found_queries) > 1:
garf_logger.info('Running queries in parallel')
query_executor.execute_batch(batch, context, parallel_threshold)
if server_url:
rest_context = context.model_dump()
if rest_context.get('writer') == ['console']:
del rest_context['writer']
headers = {}
TraceContextTextMapPropagator().inject(headers)
if parallel_queries and len(batch) > 1:
endpoint = f'{server_url}/api/execute:batch'
request = {
'batch': batch,
'source': source,
'context': rest_context,
}
try:
response = requests.post(url=endpoint, json=request, headers=headers)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise exceptions.GarfExecutorError(
f'Server error: {e.response.status_code} - {e.response.json()}'
) from e

with tracer.start_as_current_span('parse_results batch'):
typer.secho(response.json().get('results'))
else:
endpoint = f'{server_url}/api/execute'
for title, text in batch.items():
request = {
'source': source,
'title': title,
'query': text,
'context': rest_context,
}
try:
response = requests.post(url=endpoint, json=request, headers=headers)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise exceptions.GarfExecutorError(
f'Server error: {e.response.status_code} - {e.response.json()}'
) from e
with tracer.start_as_current_span(f'parse_results {title}'):
typer.secho(response.json().get('results'))
else:
if len(found_queries) > 1:
garf_logger.info('Running queries sequentially')
for query in found_queries:
query_executor.execute(
query=reader_client.read(query), title=query, context=context
)
query_executor = setup.setup_executor(
source=source,
fetcher_parameters=context.fetcher_parameters,
enable_cache=enable_cache,
cache_ttl_seconds=cache_ttl_seconds,
simulate=simulate,
writers=context.writer,
writer_parameters=context.writer_parameters,
)
if parallel_queries and len(batch) > 1:
garf_logger.info('Running queries in parallel')
query_executor.execute_batch(batch, context, parallel_threshold)
else:
if len(batch) > 1:
garf_logger.info('Running queries sequentially')
for title, text in batch.items():
query_executor.execute(query=text, title=title, context=context)


@workflow_app.command(
Expand Down
8 changes: 4 additions & 4 deletions libs/executors/tests/end-to-end/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_batch_fake_source_from_query_path(self, tmp_path):
fake_data = _SCRIPT_PATH / 'test.json'
request = {
'source': 'fake',
'query_path': [
str(query_path1),
str(query_path2),
],
'batch': {
'query1': self.query,
'query2': self.query,
},
'context': {
'fetcher_parameters': {
'data_location': str(fake_data),
Expand Down
Loading