diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed1626c..151ee0f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: - name: install poetry run: | curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python; - echo "::add-path::$HOME/.poetry/bin/" + echo "$HOME/.poetry/bin/" >> $GITHUB_PATH - name: install deps run: poetry install -v - name: flake8 diff --git a/dploy_kickstart/annotations.py b/dploy_kickstart/annotations.py index de94e84..8da1e41 100644 --- a/dploy_kickstart/annotations.py +++ b/dploy_kickstart/annotations.py @@ -20,9 +20,7 @@ def __init__(self, callble: typing.Callable) -> None: self.request_method = "post" self.accepts_json = True self.returns_json = True - self.response_mime_type = ( - "application/json" # functionality deprecated, to be removed ` - ) + self.response_mime_type = None self.request_content_type = ( "application/json" # functionality deprecated, to be removed ` ) @@ -69,7 +67,6 @@ def evaluate_comment_args(self) -> None: self.endpoint = True self.endpoint_path = p - # functionality deprecated, to be removed if c[0] == "response_mime_type": self.response_mime_type = c[1].lower() diff --git a/dploy_kickstart/errors.py b/dploy_kickstart/errors.py index 2795473..91391aa 100644 --- a/dploy_kickstart/errors.py +++ b/dploy_kickstart/errors.py @@ -16,7 +16,6 @@ class ScriptImportError(ServerException): status_code = 500 def __init__(self, message: str): - super().__init__(self) self.message = message def to_dict(self) -> dict: @@ -27,7 +26,6 @@ class UnsupportedEntrypoint(ServerException): status_code = 500 def __init__(self, entrypoint: str): - super().__init__(self) self.message = f"entrypoint '{entrypoint}' not supported" def to_dict(self) -> dict: diff --git a/dploy_kickstart/server.py b/dploy_kickstart/server.py index ca5c495..5298f6e 100644 --- a/dploy_kickstart/server.py +++ b/dploy_kickstart/server.py @@ -40,10 +40,7 @@ def append_entrypoint(app: Flask, entrypoint: str, location: str) -> Flask: po.path_spec(openapi_spec, f) app.add_url_rule( - "/openapi.yaml", - "/openapi.yaml", - openapi_spec.to_yaml, - methods=["GET"], + "/openapi.yaml", "/openapi.yaml", openapi_spec.to_yaml, methods=["GET"], ) return app diff --git a/dploy_kickstart/transformers.py b/dploy_kickstart/transformers.py index a041de1..36aaa18 100644 --- a/dploy_kickstart/transformers.py +++ b/dploy_kickstart/transformers.py @@ -1,25 +1,73 @@ """Utilities to transform requests and responses.""" - +import io import typing from flask import jsonify, Response, Request import dploy_kickstart.annotations as da -def bytes_resp(func_result: typing.Any) -> Response: - return Response(func_result, mimetype="application/octet-stream") +def bytes_resp(func_result: typing.Any, mimetype=None) -> Response: + if mimetype is None: + return Response(func_result, mimetype="application/octet-stream") + else: + return Response(func_result, mimetype=mimetype) -def bytes_io_resp(func_result: typing.Any) -> Response: - return Response(func_result.getvalue(), mimetype="application/octet-stream") +def bytes_io_resp(func_result: typing.Any, mimetype=None) -> Response: + if mimetype is None: + return Response(func_result.getvalue(), mimetype="application/octet-stream") + else: + return Response(func_result.getvalue(), mimetype=mimetype) -def default_req(f: da.AnnotatedCallable, req: Request) -> typing.Any: - return f(req.data) +def pil_image_resp(func_result: typing.Any, mimetype=None) -> Response: + # create file-object in memory + file_object = io.BytesIO() + img_format = func_result.format + + # write image to a file-object + # Don't change quality and subsampling since save decrease the quality by default + func_result.save(file_object, img_format, quality=100, subsampling=0) + auto_mimetype = f"image/{img_format.lower()}" + + # move to beginning of file so `send_file()` it will read from start + file_object.seek(0) + + if mimetype is None: + return Response(file_object, mimetype=auto_mimetype) + else: + return Response(file_object, mimetype=mimetype) + + +def np_tolist_resp(func_result: typing.Any, mimetype=None) -> Response: + response = jsonify(func_result.tolist()) + if mimetype is None: + return response + else: + response.mimetype = mimetype + return response + + +def np_item_resp(func_result: typing.Any, mimetype=None) -> Response: + response = jsonify(func_result.item()) + if mimetype is None: + return response + else: + response.mimetype = mimetype + return response -def json_resp(func_result: typing.Any) -> Response: +def json_resp(func_result: typing.Any, mimetype=None) -> Response: """Transform json response.""" - return jsonify(func_result) + response = jsonify(func_result) + if mimetype is None: + return response + else: + response.mimetype = mimetype + return response + + +def default_req(f: da.AnnotatedCallable, req: Request) -> typing.Any: + return f(req.data) def json_req(f: da.AnnotatedCallable, req: Request) -> typing.Any: @@ -39,5 +87,22 @@ def json_req(f: da.AnnotatedCallable, req: Request) -> typing.Any: "list": json_resp, "dict": json_resp, "bytes": bytes_resp, + # io.BytesIO return type "BytesIO": bytes_io_resp, + # Pillow Image Return dtype + "Image": pil_image_resp, + # Numpy Return dtypes + "ndarray": np_tolist_resp, + "matrix": np_tolist_resp, + "int8": np_item_resp, + "uint8": np_item_resp, + "int16": np_item_resp, + "uint16": np_item_resp, + "int32": np_item_resp, + "uint32": np_item_resp, + "int64": np_item_resp, + "uint64": np_item_resp, + "float16": np_item_resp, + "float32": np_item_resp, + "float64": np_item_resp, } diff --git a/dploy_kickstart/wrapper.py b/dploy_kickstart/wrapper.py index e836cf5..c393e8c 100644 --- a/dploy_kickstart/wrapper.py +++ b/dploy_kickstart/wrapper.py @@ -114,7 +114,9 @@ def exposed_func() -> typing.Callable: # determine whether or not to process response before sending it back to caller try: - return pt.MIME_TYPE_RES_MAPPER[res.__class__.__name__](res) + return pt.MIME_TYPE_RES_MAPPER[res.__class__.__name__]( + res, f.response_mime_type + ) except Exception: raise pe.UserApplicationError( message=f"error in executing '{f.__name__()}' method, the return type " diff --git a/poetry.lock b/poetry.lock index 2664a42..89e2883 100644 --- a/poetry.lock +++ b/poetry.lock @@ -363,6 +363,14 @@ traitlets = ">=4.1" [package.extras] test = ["testpath", "pytest", "pytest-cov"] +[[package]] +category = "dev" +description = "NumPy is the fundamental package for array computing with Python." +name = "numpy" +optional = false +python-versions = ">=3.6" +version = "1.19.4" + [[package]] category = "dev" description = "Core utilities for Python packages" @@ -675,7 +683,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "53d7a38c129a42e2bbc3d1c6bdf5d3324357bdccbb52df804587345d962a6e3b" +content-hash = "2ef9447096b18b6ee1beeb58643e2a36f21e585cbe5bce924dcb53abd43183d5" lock-version = "1.0" python-versions = "^3.7" @@ -864,6 +872,42 @@ nbformat = [ {file = "nbformat-5.0.6-py3-none-any.whl", hash = "sha256:276343c78a9660ab2a63c28cc33da5f7c58c092b3f3a40b6017ae2ce6689320d"}, {file = "nbformat-5.0.6.tar.gz", hash = "sha256:049af048ed76b95c3c44043620c17e56bc001329e07f83fec4f177f0e3d7b757"}, ] +numpy = [ + {file = "numpy-1.19.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e9b30d4bd69498fc0c3fe9db5f62fffbb06b8eb9321f92cc970f2969be5e3949"}, + {file = "numpy-1.19.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fedbd128668ead37f33917820b704784aff695e0019309ad446a6d0b065b57e4"}, + {file = "numpy-1.19.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:8ece138c3a16db8c1ad38f52eb32be6086cc72f403150a79336eb2045723a1ad"}, + {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:64324f64f90a9e4ef732be0928be853eee378fd6a01be21a0a8469c4f2682c83"}, + {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:ad6f2ff5b1989a4899bf89800a671d71b1612e5ff40866d1f4d8bcf48d4e5764"}, + {file = "numpy-1.19.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:d6c7bb82883680e168b55b49c70af29b84b84abb161cbac2800e8fcb6f2109b6"}, + {file = "numpy-1.19.4-cp36-cp36m-win32.whl", hash = "sha256:13d166f77d6dc02c0a73c1101dd87fdf01339febec1030bd810dcd53fff3b0f1"}, + {file = "numpy-1.19.4-cp36-cp36m-win_amd64.whl", hash = "sha256:448ebb1b3bf64c0267d6b09a7cba26b5ae61b6d2dbabff7c91b660c7eccf2bdb"}, + {file = "numpy-1.19.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:27d3f3b9e3406579a8af3a9f262f5339005dd25e0ecf3cf1559ff8a49ed5cbf2"}, + {file = "numpy-1.19.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16c1b388cc31a9baa06d91a19366fb99ddbe1c7b205293ed072211ee5bac1ed2"}, + {file = "numpy-1.19.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e5b6ed0f0b42317050c88022349d994fe72bfe35f5908617512cd8c8ef9da2a9"}, + {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:18bed2bcb39e3f758296584337966e68d2d5ba6aab7e038688ad53c8f889f757"}, + {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:fe45becb4c2f72a0907c1d0246ea6449fe7a9e2293bb0e11c4e9a32bb0930a15"}, + {file = "numpy-1.19.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:6d7593a705d662be5bfe24111af14763016765f43cb6923ed86223f965f52387"}, + {file = "numpy-1.19.4-cp37-cp37m-win32.whl", hash = "sha256:6ae6c680f3ebf1cf7ad1d7748868b39d9f900836df774c453c11c5440bc15b36"}, + {file = "numpy-1.19.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9eeb7d1d04b117ac0d38719915ae169aa6b61fca227b0b7d198d43728f0c879c"}, + {file = "numpy-1.19.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cb1017eec5257e9ac6209ac172058c430e834d5d2bc21961dceeb79d111e5909"}, + {file = "numpy-1.19.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:edb01671b3caae1ca00881686003d16c2209e07b7ef8b7639f1867852b948f7c"}, + {file = "numpy-1.19.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f29454410db6ef8126c83bd3c968d143304633d45dc57b51252afbd79d700893"}, + {file = "numpy-1.19.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:ec149b90019852266fec2341ce1db513b843e496d5a8e8cdb5ced1923a92faab"}, + {file = "numpy-1.19.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:1aeef46a13e51931c0b1cf8ae1168b4a55ecd282e6688fdb0a948cc5a1d5afb9"}, + {file = "numpy-1.19.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:08308c38e44cc926bdfce99498b21eec1f848d24c302519e64203a8da99a97db"}, + {file = "numpy-1.19.4-cp38-cp38-win32.whl", hash = "sha256:5734bdc0342aba9dfc6f04920988140fb41234db42381cf7ccba64169f9fe7ac"}, + {file = "numpy-1.19.4-cp38-cp38-win_amd64.whl", hash = "sha256:09c12096d843b90eafd01ea1b3307e78ddd47a55855ad402b157b6c4862197ce"}, + {file = "numpy-1.19.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e452dc66e08a4ce642a961f134814258a082832c78c90351b75c41ad16f79f63"}, + {file = "numpy-1.19.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:a5d897c14513590a85774180be713f692df6fa8ecf6483e561a6d47309566f37"}, + {file = "numpy-1.19.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a09f98011236a419ee3f49cedc9ef27d7a1651df07810ae430a6b06576e0b414"}, + {file = "numpy-1.19.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:50e86c076611212ca62e5a59f518edafe0c0730f7d9195fec718da1a5c2bb1fc"}, + {file = "numpy-1.19.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f0d3929fe88ee1c155129ecd82f981b8856c5d97bcb0d5f23e9b4242e79d1de3"}, + {file = "numpy-1.19.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c42c4b73121caf0ed6cd795512c9c09c52a7287b04d105d112068c1736d7c753"}, + {file = "numpy-1.19.4-cp39-cp39-win32.whl", hash = "sha256:8cac8790a6b1ddf88640a9267ee67b1aee7a57dfa2d2dd33999d080bc8ee3a0f"}, + {file = "numpy-1.19.4-cp39-cp39-win_amd64.whl", hash = "sha256:4377e10b874e653fe96985c05feed2225c912e328c8a26541f7fc600fb9c637b"}, + {file = "numpy-1.19.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:2a2740aa9733d2e5b2dfb33639d98a64c3b0f24765fed86b0fd2aec07f6a0a08"}, + {file = "numpy-1.19.4.zip", hash = "sha256:141ec3a3300ab89c7f2b0775289954d193cc8edb621ea05f99db9cb181530512"}, +] packaging = [ {file = "packaging-20.3-py2.py3-none-any.whl", hash = "sha256:82f77b9bee21c1bafbf35a84905d604d5d1223801d639cf3ed140bd651c08752"}, {file = "packaging-20.3.tar.gz", hash = "sha256:3c292b474fda1671ec57d46d739d072bfd495a4f51ad01a055121d81e952b7a3"}, diff --git a/pyproject.toml b/pyproject.toml index 1d5d60a..721f39a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ pytest-cov = "^2.8.1" requests = "^2.23.0" flynt = "^0.46.1" pillow = "^8.0.1" +numpy = "^1.19.4" [tool.black] line-length = 88 diff --git a/tests/assets/cropped_golf.png b/tests/assets/cropped_golf.png new file mode 100644 index 0000000..c4b4598 Binary files /dev/null and b/tests/assets/cropped_golf.png differ diff --git a/tests/assets/deps_tests/fake_requirements.txt b/tests/assets/deps_tests/fake_requirements.txt new file mode 100644 index 0000000..a0a5db1 --- /dev/null +++ b/tests/assets/deps_tests/fake_requirements.txt @@ -0,0 +1 @@ +some_random_package_kickstart==4.0.0 \ No newline at end of file diff --git a/tests/assets/deps_tests/my_pkg/setup.py b/tests/assets/deps_tests/my_pkg/setup.py index e0ab863..40ac4b2 100644 --- a/tests/assets/deps_tests/my_pkg/setup.py +++ b/tests/assets/deps_tests/my_pkg/setup.py @@ -3,10 +3,5 @@ from distutils.core import setup setup( - name="dummy", - version="1.0", - packages=[], - install_requires=[ - "dummy_test==0.1.3", - ], + name="dummy", version="1.0", packages=[], install_requires=["dummy_test==0.1.3",], ) diff --git a/tests/assets/deps_tests/my_pkg/stp.py b/tests/assets/deps_tests/my_pkg/stp.py new file mode 100644 index 0000000..40ac4b2 --- /dev/null +++ b/tests/assets/deps_tests/my_pkg/stp.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python + +from distutils.core import setup + +setup( + name="dummy", version="1.0", packages=[], install_requires=["dummy_test==0.1.3",], +) diff --git a/tests/assets/golf.png b/tests/assets/golf.png new file mode 100644 index 0000000..39cf309 Binary files /dev/null and b/tests/assets/golf.png differ diff --git a/tests/assets/server_default.py b/tests/assets/server_default.py index d913118..5150ad8 100644 --- a/tests/assets/server_default.py +++ b/tests/assets/server_default.py @@ -1,3 +1,4 @@ +import numpy as np from PIL import Image import io @@ -97,3 +98,34 @@ def f7(some_string): def f8(raw_image): image = Image.open(io.BytesIO(raw_image)) return image + + +# @dploy endpoint f9 +def f9(raw_image): + image = Image.open(io.BytesIO(raw_image)) + + # Size of the image in pixels (size of original image) + # (This is not mandatory) + width, height = image.size + + # Setting the points for cropped image + left = 5 + top = height / 4 + right = 164 + bottom = 3 * height / 4 + + # Cropped image of above dimension + # (It will not change orginal image) + im1 = image.crop((left, top, right, bottom)) + im1.format = image.format # original image extension + return im1 + + +# @dploy endpoint f10 +def f10(input): + return np.uint(input) + + +# @dploy endpoint f11 +def f11(input): + return np.array(input) diff --git a/tests/test_callable_annotation.py b/tests/test_callable_annotation.py index fd2c3ac..eb690d3 100644 --- a/tests/test_callable_annotation.py +++ b/tests/test_callable_annotation.py @@ -73,6 +73,12 @@ def t11(): return img1 +# @dploy endpoint get_model_name +# @dploy response_method get +def t12(): + return "kickstart" + + @pytest.mark.parametrize( "callable,endpoint,endpoint_path,has_args,output, error", [ @@ -87,6 +93,7 @@ def t11(): (t9, True, "/get_image2/", True, img1, True), (t10, True, "/get_image3/", True, img1, True), (t11, True, "/get_image4/", True, img1, True), + (t12, True, "/get_model_name/", True, "kickstart", False), ], ) def test_callable_annotation( @@ -111,17 +118,10 @@ def test_callable_annotation( @pytest.mark.parametrize( "py_file,expected", [ - ( - "c1.py", - [["endpoint", "predict"], ["endpoint", "train2"]], - ), + ("c1.py", [["endpoint", "predict"], ["endpoint", "train2"]],), ( "nb_with_comments.ipynb", - [ - ["endpoint", "predict"], - ["endpoint", "train"], - ["trigger", "train"], - ], + [["endpoint", "predict"], ["endpoint", "train"], ["trigger", "train"],], ), ], ) @@ -180,9 +180,7 @@ def test_annotated_scripts(py_file, expected): # # irrelevant other stuff """, - [ - ["arg", "foo bar", "arg2", "bar the foos"], - ], + [["arg", "foo bar", "arg2", "bar the foos"],], ], [ """ diff --git a/tests/test_dep_install.py b/tests/test_dep_install.py index 1f5e542..8be6ff1 100644 --- a/tests/test_dep_install.py +++ b/tests/test_dep_install.py @@ -10,6 +10,7 @@ @pytest.mark.parametrize( "req_file, error_expected", [ + ("fake_requirements.txt", True), ("req1.txt", True), ("non_existing.txt", True), ("requirements.txt", False), @@ -28,6 +29,7 @@ def test_req_install(req_file, error_expected): "setup_py, error_expected", [ ("my_pkg/setup.py", False), + ("my_pkg/stp.py", True), ("non_existing_pkg/setup.py", True), ], ) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..5f77c76 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,43 @@ +import os +import dploy_kickstart.errors as pe +import dploy_kickstart.wrapper as pw +import dploy_kickstart.server as ps +import pytest + + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +@pytest.mark.parametrize( + "package_name, expected_err_message, error_expected", + [ + ( + "import some_random_package", + "{'message': 'Cannot import some_random_package'}", + True, + ), + ("import os", None, False), + ], +) +def test_script_import_error(package_name, expected_err_message, error_expected): + try: + exec(package_name) + except Exception as e: + error_message = pe.ScriptImportError(f"Cannot {package_name}",).to_dict() + assert str(error_message) == str(expected_err_message) + + +@pytest.mark.parametrize( + "entrypoint, expected_err_message, error_expected", + [ + ("golf.png", "{'message': \"entrypoint 'golf.png' not supported\"}", True), + ("c1.py", None, False), + ], +) +def test_unsupported_entrypoint_error(entrypoint, expected_err_message, error_expected): + try: + p = os.path.join(THIS_DIR, "assets") + _ = pw.import_entrypoint(entrypoint, p) + except Exception as e: + assert isinstance(e, pe.UnsupportedEntrypoint) + assert str(e.to_dict()) == str(expected_err_message) diff --git a/tests/test_openapi_gen.py b/tests/test_openapi_gen.py index 11274e4..822fc35 100644 --- a/tests/test_openapi_gen.py +++ b/tests/test_openapi_gen.py @@ -20,11 +20,7 @@ def t2(): @pytest.mark.parametrize( - "callable", - [ - t1, - t2, - ], + "callable", [t1, t2,], ) def test_openapi_generation(callable): ca = pa.AnnotatedCallable(callable) diff --git a/tests/test_server.py b/tests/test_server.py index 7b717a1..6b6660c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,7 +1,7 @@ import os import logging import re - +import numpy as np import pytest import dploy_kickstart.server as ps @@ -13,6 +13,10 @@ logging.basicConfig(level=os.environ.get("LOGLEVEL", os.getenv("LOGLEVEL", "INFO"))) PNG_IMG = open(os.path.join(THIS_DIR, "assets", "test.png"), "rb").read() JPG_IMG = open(os.path.join(THIS_DIR, "assets", "test.jpg"), "rb").read() +GOLF_IMG = open(os.path.join(THIS_DIR, "assets", "golf.png"), "rb").read() +CROPPED_GOLF_IMG = open( + os.path.join(THIS_DIR, "assets", "cropped_golf.png"), "rb" +).read() def test_client(): @@ -146,6 +150,42 @@ def test_client(): True, 500, ), + # PIL IMAGE RETURN + ( + "server_default.py", + "post", + "/f9/", + GOLF_IMG, + CROPPED_GOLF_IMG, + "image/png", + "image/png", + False, + 200, + ), + # Numpy Data Types + ( + "server_default.py", + "post", + "/f10/", + 61, + 61, + "application/json", + "application/json", + False, + 200, + ), + # Numpy Arrays + ( + "server_default.py", + "post", + "/f11/", + [61, 61], + [61, 61], + "application/json", + "application/json", + False, + 200, + ), ( "server_t1.py", "post", @@ -346,30 +386,10 @@ def test_server_logs( @pytest.mark.parametrize( "path, status_code, method, error", [ - ( - "/healthz/", - 200, - "get", - False, - ), - ( - "/healthz", - 200, - "get", - False, - ), - ( - "/healthz/", - 200, - "post", - True, - ), - ( - "/health/", - 200, - "get", - True, - ), + ("/healthz/", 200, "get", False,), + ("/healthz", 200, "get", False,), + ("/healthz/", 200, "post", True,), + ("/health/", 200, "get", True,), ], ) @pytest.mark.usefixtures("restore_wd") diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 2068568..ba47f46 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -1,8 +1,15 @@ import pytest -from flask import Response - +import os import dploy_kickstart.transformers as dt import dploy_kickstart.server as ds +import numpy as np +from io import BytesIO +from flask import Response +from PIL import Image + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +BIN_IMG = open(os.path.join(THIS_DIR, "assets", "test.png"), "rb").read() +PIL_IMG = Image.open(BytesIO(BIN_IMG)) def test_client(): @@ -12,11 +19,185 @@ def test_client(): @pytest.mark.parametrize( - "i, o", - [("bla", '"bla"\n'), ({"foo": "bar"}, '{"foo":"bar"}\n')], + "i, o, mimetype, expected_mimetype, error_expected", + [ + ("bla", '"bla"\n', None, "application/json", False), + ({"foo": "bar"}, '{"foo":"bar"}\n', None, "application/json", False), + ( + {"foo1": "bar1"}, + '{"foo1":"bar1"}\n', + "application/json", + "application/json", + False, + ), + ({"foo2": "bar2"}, '{"foo2":"bar2"}\n', "text/data", "text/data", False), + ({"foo2": "bar2"}, '{"foo2":"bar2"}\n', "application/json", "text/data", True), + ({"foo2": "bar2"}, '{"foo2":"bar2"}\n', None, "text/data", True), + ({"foo2": "bar2"}, '{"foo2":"bar2"}\n', None, None, True), + ], +) +def test_json_resp(i, o, mimetype, expected_mimetype, error_expected): + with test_client().application.test_request_context(): + r = dt.json_resp(i, mimetype=mimetype) + try: + assert r.get_data().decode() == o + assert r.mimetype == expected_mimetype + assert isinstance(r, Response) + except Exception as e: + assert error_expected + + +@pytest.mark.parametrize( + "i, o, mimetype, expected_mimetype, error_expected", + [ + (BytesIO(BIN_IMG), BIN_IMG, None, "application/octet-stream", False), + (BytesIO(BIN_IMG), b"", None, "application/octet-stream", True), + ( + BytesIO(BIN_IMG), + BIN_IMG, + "application/octet-stream", + "application/octet-stream", + False, + ), + (BytesIO(BIN_IMG), BIN_IMG, "image/png", "application/octet-stream", True), + (BytesIO(BIN_IMG), BIN_IMG, "image/png", "image/png", False), + (BytesIO(BIN_IMG), BIN_IMG, "image/png", "image/jpg", True), + ], +) +def test_bytes_io_resp(i, o, mimetype, expected_mimetype, error_expected): + with test_client().application.test_request_context(): + r = dt.bytes_io_resp(i, mimetype=mimetype) + try: + assert r.get_data() == o + assert r.mimetype == expected_mimetype + assert isinstance(r, Response) + except Exception as e: + assert error_expected + + +@pytest.mark.parametrize( + "i, o, mimetype, expected_mimetype, error_expected", + [ + (BIN_IMG, BIN_IMG, None, "application/octet-stream", False), + (BIN_IMG, b"", None, "application/octet-stream", True), + ( + BIN_IMG, + BIN_IMG, + "application/octet-stream", + "application/octet-stream", + False, + ), + (BIN_IMG, BIN_IMG, "image/png", "application/octet-stream", True), + (BIN_IMG, BIN_IMG, "image/png", "image/png", False), + (BIN_IMG, BIN_IMG, "image/png", "image/jpg", True), + ], +) +def test_bytes_resp(i, o, mimetype, expected_mimetype, error_expected): + with test_client().application.test_request_context(): + r = dt.bytes_resp(i, mimetype=mimetype) + try: + assert r.get_data() == o + assert r.mimetype == expected_mimetype + assert isinstance(r, Response) + except Exception as e: + assert error_expected + + +@pytest.mark.parametrize( + "i, mimetype, expected_mimetype, error_expected", + [ + (PIL_IMG, None, "image/png", False), + (PIL_IMG, "application/octet-stream", "application/octet-stream", False), + (PIL_IMG, "image/png", "application/octet-stream", True), + (PIL_IMG, "image/png", "image/png", False), + (PIL_IMG, "image/jpeg", "image/jpeg", False), + (PIL_IMG, "image/png", "image/jpg", True), + ], +) +def test_pil_img_resp(i, mimetype, expected_mimetype, error_expected): + with test_client().application.test_request_context(): + r = dt.pil_image_resp(i, mimetype=mimetype) + o = Image.open(BytesIO(r.get_data())) + try: + assert list(i.getdata()) == list(o.getdata()) + assert r.mimetype == expected_mimetype + assert isinstance(r, Response) + except Exception as e: + assert error_expected + + +@pytest.mark.parametrize( + "i, o, mimetype, expected_mimetype, error_expected", + [ + # Tests for ndarray + (np.array([1, 2, 3]), [1, 2, 3], None, "application/json", False), + (np.array([1, 2, 3]), [222, 2222, 222], None, "application/json", True), + (np.array([1, 2, 3]), [1, 2, 3], "application/json", "application/json", False), + (np.array([1, 2, 3]), [1, 2, 3], "application/json", "data/text", True), + (np.array([1, 2, 3]), b"[1,2,3]\n", "data/text", "data/text", False), + (np.array([1, 2, 3]), b"[1,2,3]\n", None, "data/text", True), + # Tests for matrix + (np.matrix([1, 2, 3]), [[1, 2, 3]], None, "application/json", False), + (np.matrix([1, 2, 3]), [[222, 2222, 222]], None, "application/json", True), + ( + np.matrix([1, 2, 3]), + [[1, 2, 3]], + "application/json", + "application/json", + False, + ), + (np.matrix([1, 2, 3]), [[1, 2, 3]], "application/json", "data/text", True), + (np.matrix([1, 2, 3]), b"[[1,2,3]]\n", "data/text", "data/text", False), + (np.matrix([1, 2, 3]), b"[[1,2,3]]\n", None, "data/text", True), + ], +) +def test_numpy_list_resp(i, o, mimetype, expected_mimetype, error_expected): + with test_client().application.test_request_context(): + r = dt.np_tolist_resp(i, mimetype=mimetype) + if mimetype is None or mimetype == "application/json": + data = r.json + else: + data = r.get_data() + try: + assert data == o + assert r.mimetype == expected_mimetype + assert isinstance(r, Response) + except Exception as e: + assert error_expected + + +@pytest.mark.parametrize( + "i, o, mimetype, expected_mimetype, error_expected", + [ + # Check all numpy dtypes + (np.int8(61), 61, None, "application/json", False), + (np.uint8(61), 61, None, "application/json", False), + (np.int16(61), 61, None, "application/json", False), + (np.uint16(61), 61, None, "application/json", False), + (np.int32(61), 61, None, "application/json", False), + (np.uint32(61), 61, None, "application/json", False), + (np.int64(61), 61, None, "application/json", False), + (np.uint64(61), 61, None, "application/json", False), + (np.float16(61), 61, None, "application/json", False), + (np.float32(61), 61, None, "application/json", False), + (np.float64(61), 61, None, "application/json", False), + # Some problematic examples + (np.float64(61), np.float64(61), None, "application/json", True), + (np.float64(61), 61, "text/data", "application/json", True), + # Different Mimetypes + (np.float64(61), b"61.0\n", "text/data", "text/data", False), + ], ) -def test_json_resp(i, o): +def test_numpy_item_resp(i, o, mimetype, expected_mimetype, error_expected): with test_client().application.test_request_context(): - r = dt.json_resp(i) - assert r.get_data().decode() == o - assert isinstance(r, Response) + r = dt.np_item_resp(i, mimetype=mimetype) + if mimetype is None or mimetype == "application/json": + data = r.json + else: + data = r.get_data() + try: + assert data == o + assert r.mimetype == expected_mimetype + assert isinstance(r, Response) + except Exception as e: + assert error_expected