Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions dploy_kickstart/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
import inspect
import typing
import re
import traceback
from flask import request

import dploy_kickstart.transformers as dt
import dploy_kickstart.errors as de

HEADER_MODEL_REPORT = "X-Dployai-Model-Report"


class AnnotatedCallable:
"""Wrap a callable and allow annotation (comments) extraction."""
Expand All @@ -26,6 +30,8 @@ def __init__(self, callble: typing.Callable) -> None:
self.response_mime_type = "application/json"
self.request_content_type = "application/json"
self.json_to_kwargs = False
self.report_return_value = False
self.custom_headers = dict()

if not callable(callble):
raise Exception("Trying to parse annotations on non-callable object")
Expand Down Expand Up @@ -92,6 +98,9 @@ def evaluate_comment_args(self) -> None:
if c[0] == "json_to_kwargs":
self.json_to_kwargs = True

if c[0] == "report_return_value":
self.report_return_value = True

def has_args(self) -> bool:
"""Return if callable has comment annotation arguments."""
return len(self.comment_args) > 0
Expand All @@ -100,6 +109,51 @@ def __call__(self, *args, **kwargs) -> typing.Any:
"""Allow calling of original callable."""
return self.callble(*args, **kwargs)

def http_call(self, request_transformers, response_transformers) -> typing.Callable:
"""Allow calling in context of a Flask request."""

def f() -> typing.Any:
# some sanity checking
if request.content_type.lower() != self.request_content_type:
raise de.UnsupportedMediaType(
"Function doesn't provide support"
" for 'Content-Type' {}, supported: {}".format(
request.content_type.lower(), self.request_content_type
)
)

# preprocess input for callable
try:
res = request_transformers[self.request_content_type](self, request)
except Exception:
raise de.UserApplicationError(
message=f"error in executing '{self.__name__}'",
traceback=traceback.format_exc(),
)

# determine whether or not to process
# response before sending it back to caller
# resp should be a flask Response object
resp = response_transformers[self.response_mime_type](res)

if self.report_return_value:
if "application/json" not in self.response_mime_type:
raise de.UserApplicationError(
message=f"requesting reporting of return value but"
" not using json response mimetype (requirements)"
)

self.custom_headers[HEADER_MODEL_REPORT] = (
resp.get_data(as_text=True).replace("\r", "").replace("\n", "")
)
resp.headers[HEADER_MODEL_REPORT] = self.custom_headers[
HEADER_MODEL_REPORT
]

return resp

return f

def __name__(self) -> str:
"""Return name of original callable."""
return self.callble.__name__
Expand Down
5 changes: 3 additions & 2 deletions dploy_kickstart/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dploy_kickstart.wrapper as pw
import dploy_kickstart.errors as pe
import dploy_kickstart.openapi as po
import dploy_kickstart.transformers as pt

log = logging.getLogger(__name__)

Expand All @@ -33,9 +34,9 @@ def append_entrypoint(
app.add_url_rule(
f.endpoint_path,
f.endpoint_path,
pw.func_wrapper(f),
f.http_call(pt.MIME_TYPE_REQ_MAPPER, pt.MIME_TYPE_RES_MAPPER),
methods=[f.request_method.upper()],
strict_slashes=False
strict_slashes=False,
)

# add info about endpoint to api spec
Expand Down
4 changes: 1 addition & 3 deletions dploy_kickstart/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import typing
from flask import jsonify, Response, Request

import dploy_kickstart.annotations as da


def json_resp(func_result: typing.Any) -> Response:
"""Transform json response."""
return jsonify(func_result)


def json_req(f: da.AnnotatedCallable, req: Request):
def json_req(f: typing.Callable, req: Request):
"""Preprocess application/json request."""
if f.json_to_kwargs:
return f(**req.json)
Expand Down
32 changes: 0 additions & 32 deletions dploy_kickstart/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import atexit
import functools
import typing
import traceback

from flask import request
import dploy_kickstart.errors as pe
import dploy_kickstart.transformers as pt
import dploy_kickstart.annotations as pa

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,32 +93,3 @@ def import_entrypoint(entrypoint: str, location: str) -> typing.Generic:
raise pe.ScriptImportError(f"{msg}: {e}")

return mod


def func_wrapper(f: pa.AnnotatedCallable) -> typing.Callable:
"""Wrap functions with request logic."""

def exposed_func() -> typing.Callable:
# some sanity checking
if request.content_type.lower() != f.request_content_type:
raise pe.UnsupportedMediaType(
"Please provide a valid 'Content-Type' header, valid: {}".format(
f.request_content_type
)
)

# preprocess input for callable
try:
res = pt.MIME_TYPE_REQ_MAPPER[f.response_mime_type](f, request)
except Exception:
raise pe.UserApplicationError(
message=f"error in executing '{f.__name__}'",
traceback=traceback.format_exc(),
)

# determine whether or not to process response before sending it back to caller
wrapped_res = pt.MIME_TYPE_RES_MAPPER[request.content_type](res)

return wrapped_res

return exposed_func
6 changes: 6 additions & 0 deletions tests/assets/server_t1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@ def f1(x):
#' @dploy endpoint train
def f2(x):
return x


# @dploy endpoint performance
# @dploy report_return_value
def my_func4(x):
return {"_performance": x}
2 changes: 1 addition & 1 deletion tests/test_callable_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ def t5():
def t6():
return t1()


# root path / endpoint
# @dploy endpoint
def t7():
return t1()



@pytest.mark.parametrize(
"callable,endpoint,endpoint_path,has_args,output, error",
[
Expand Down
49 changes: 49 additions & 0 deletions tests/test_modelreport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import logging
import pytest
import dploy_kickstart.server as ps
import dploy_kickstart.annotations as da
from .fixtures import restore_wd
import json

THIS_DIR = os.path.dirname(os.path.abspath(__file__))
logging.basicConfig(level=os.environ.get("LOGLEVEL", os.getenv("LOGLEVEL", "INFO")))


def test_client():
app = ps.generate_app()
app.config["TESTING"] = True
return app.test_client()


@pytest.mark.parametrize(
"entrypoint,method,path,payload,report_header_value",
[
(
"server_t1.py",
"post",
"/performance/",
{"val": 1},
{"_performance": {"val": 1}},
),
],
)
@pytest.mark.usefixtures("restore_wd")
def test_report_header(entrypoint, method, path, payload, report_header_value):
p = os.path.join(THIS_DIR, "assets")

try:
app = ps.generate_app()
app = ps.append_entrypoint(app, entrypoint, p)

except:
assert error
return

test_client = app.test_client()
r = getattr(test_client, method)(
path, json=payload, headers={"Content-Type": "application/json"}
)

assert r.status_code == 200
assert json.loads(r.headers.get(da.HEADER_MODEL_REPORT)) == report_header_value
2 changes: 1 addition & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_client():
(
"server_t1.py",
"post",
"/predict", # test without trailing slash
"/predict", # test without trailing slash
{"val": 1},
1,
"application/json",
Expand Down