diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fc32e3b..f42d56c 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -39,7 +39,8 @@ jobs: run: | pre-commit run --all-files - name: Test with pytest + # config.py won't import if schema.yaml doesn't exist yet run: | - cp examples/mouse_and_cheese/schema.yaml . - python generate.py copy_server_files examples/mouse_and_cheese - pytest tests/test_app.py + cp examples/lpoc/schema.yaml . + python generate.py copy_server_files "examples/mouse_and_cheese" + pytest tests/ diff --git a/README.md b/README.md index 65c3710..f07d50b 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,8 @@ parameters: discrete: True tuple: False api_key: "1234567asdfgh" + project_id: 1284 + model_id: 2817 ``` With this configuration of the policy server, the user will only have access to one predictive endpoint, namely @@ -117,6 +119,12 @@ With this configuration of the policy server, the user will only have access to have to have precisely the right cardinality and ordering. In other words, this endpoint strips all validation that the other endpoints have, but is quicker to set up due to not having to specify the structure of observations. +Note that if `discrete` is `True` is for models with all discrete actions, while setting this flag to `False` means +that the policy emits continuous actions. + +Providing `project_id` and `model_id` is optional. If you provide both, the `/docs` and `/redoc` endpoints will have +feature a link to go back to the respective Pathmind experiment this policy came from. + ### Starting the app Once you have both `saved_model.zip` and `schema.yaml` ready, you can start the policy server like this: diff --git a/app.py b/app.py index 685862f..67d11c4 100644 --- a/app.py +++ b/app.py @@ -10,22 +10,43 @@ from fastapi.responses import FileResponse from fastapi.security.api_key import APIKey from ray import serve -from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder -from ray.rllib.offline.json_writer import JsonWriter import config from api import Action, Observation, RawObservation +from docs import get_redoc_html, get_swagger_ui_html from generate import CLI -from offline import EpisodeCache from security import get_api_key -cache = EpisodeCache() -batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder -writer = JsonWriter(config.EXPERIENCE_LOCATION) - url_path = config.parameters.get("url_path") -app = FastAPI(root_path=f"/{url_path}") if url_path else FastAPI() +app = ( + FastAPI(root_path=f"/{url_path}", docs_url=None, redoc_url=None) + if url_path + else FastAPI(docs_url=None, redoc_url=None) +) + + +@app.get("/docs", include_in_schema=False) +def overridden_swagger(): + return get_swagger_ui_html( + openapi_url="/openapi.json", + title="Pathmind Policy Server", + swagger_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com", + project_id=config.project_id, + model_id=config.model_id, + ) + + +@app.get("/redoc", include_in_schema=False) +def overridden_redoc(): + return get_redoc_html( + openapi_url="/openapi.json", + title="Pathmind Policy Server", + redoc_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com", + project_id=config.project_id, + model_id=config.model_id, + ) + tags_metadata = [ { diff --git a/config.py b/config.py index fd2bdd8..e51f4e4 100644 --- a/config.py +++ b/config.py @@ -43,6 +43,9 @@ def base_path(local_file): parameters = schema.get("parameters") action_type = int if parameters.get("discrete") else float +model_id = parameters.get("model_id", None) +project_id = parameters.get("project_id", None) + payload_data = {} # If the schema includes `max_items` set the constraints for the array if observations: diff --git a/docs.py b/docs.py new file mode 100644 index 0000000..7597baa --- /dev/null +++ b/docs.py @@ -0,0 +1,208 @@ +import json +from typing import Optional + +from fastapi.encoders import jsonable_encoder +from starlette.responses import HTMLResponse + + +def get_swagger_ui_html( + *, + openapi_url: str, + title: str, + swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js", + swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css", + swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png", + oauth2_redirect_url: Optional[str] = None, + init_oauth: Optional[dict] = None, + project_id: Optional[int] = None, + model_id: Optional[int] = None, +) -> HTMLResponse: + + html = f""" + + + + + + {title} + + + """ + + if project_id and model_id: + html += f""" +
+ ← Back to Pathmind experiment +
+ """ + + html += f""" +
+
+ + + + + + """ + return HTMLResponse(html) + + +def get_redoc_html( + *, + openapi_url: str, + title: str, + redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js", + redoc_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png", + with_google_fonts: bool = True, + project_id: Optional[int] = None, + model_id: Optional[int] = None, +) -> HTMLResponse: + html = f""" + + + + {title} + + + + """ + if with_google_fonts: + html += """ + + """ + html += f""" + + + + + + """ + + if project_id and model_id: + html += f""" +
+ ← Back to Pathmind experiment +
+ """ + + html += f""" + + + + + """ + return HTMLResponse(html) + + +def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse: + html = """ + + + + + + + """ + return HTMLResponse(content=html) diff --git a/tests/test_api_speed.py b/tests/test_api_speed.py index 5a39ffd..53c908d 100644 --- a/tests/test_api_speed.py +++ b/tests/test_api_speed.py @@ -1,9 +1,16 @@ -"""This assumes the server is fully configured for the LPoC example""" import timeit -import requests +import ray +from fastapi.testclient import TestClient -data = { +from app import app +from generate import CLI + +CLI.copy_server_files("examples/lpoc") +client = TestClient(app) + + +payload = { "coordinates": [1, 1], "has_core": True, "has_down_neighbour": True, @@ -21,15 +28,17 @@ def predict(): - return requests.post( - "https://localhost:8080/api/predict", - verify=False, - auth=("foo", "bar"), - json=data, + return client.post( + "http://localhost:8000/predict/", + json=payload, + headers={"access-token": "1234567asdfgh"}, ) -predict() -number = 1000 -res = timeit.timeit(predict, number=1000) -print(f"A total of {number} requests took {res} milliseconds to process on average.") +def test_predict(): + number = 1000 + res = timeit.timeit(predict, number=number) + print( + f"A total of {number} requests took {res} milliseconds to process on average." + ) + ray.shutdown() diff --git a/tests/test_app.py b/tests/test_app.py index ce50ab6..d0a0457 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -9,6 +9,10 @@ client = TestClient(app) +def setup_function(): + ray.shutdown() + + def test_health_check(): response = client.get("/") assert response.status_code == 200 diff --git a/tests/test_predict.py b/tests/test_predict.py index 6c2f8f2..463b9eb 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,4 +1,12 @@ -import requests +import ray +from fastapi.testclient import TestClient + +from app import app +from generate import CLI + +# Set up server +CLI.copy_server_files("examples/mouse_and_cheese") +client = TestClient(app) payload = { "mouse_row": 1, @@ -9,15 +17,23 @@ def predict(): - res = requests.post("http://localhost:8000/predict", verify=False, json=payload) - print(res.json()) - return res + return client.post( + "http://localhost:8000/predict/", + json=payload, + headers={"access-token": "1234567asdfgh"}, + ) + +def test_predict_simple(): + res = predict() + assert res is not None + ray.shutdown() -predict() +def test_write_openapi_json(): + res = client.get("http://localhost:8000/openapi.json") -res = requests.get("http://localhost:8000/openapi.json") + with open("../openapi.json", "w") as f: + f.write(res.content.decode("utf-8")) -with open("../openapi.json", "w") as f: - f.write(res.content.decode("utf-8")) + ray.shutdown()