diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fc32e3b..f42d56c 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -39,7 +39,8 @@ jobs: run: | pre-commit run --all-files - name: Test with pytest + # config.py won't import if schema.yaml doesn't exist yet run: | - cp examples/mouse_and_cheese/schema.yaml . - python generate.py copy_server_files examples/mouse_and_cheese - pytest tests/test_app.py + cp examples/lpoc/schema.yaml . + python generate.py copy_server_files "examples/mouse_and_cheese" + pytest tests/ diff --git a/README.md b/README.md index 65c3710..f07d50b 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,8 @@ parameters: discrete: True tuple: False api_key: "1234567asdfgh" + project_id: 1284 + model_id: 2817 ``` With this configuration of the policy server, the user will only have access to one predictive endpoint, namely @@ -117,6 +119,12 @@ With this configuration of the policy server, the user will only have access to have to have precisely the right cardinality and ordering. In other words, this endpoint strips all validation that the other endpoints have, but is quicker to set up due to not having to specify the structure of observations. +Note that if `discrete` is `True` is for models with all discrete actions, while setting this flag to `False` means +that the policy emits continuous actions. + +Providing `project_id` and `model_id` is optional. If you provide both, the `/docs` and `/redoc` endpoints will have +feature a link to go back to the respective Pathmind experiment this policy came from. + ### Starting the app Once you have both `saved_model.zip` and `schema.yaml` ready, you can start the policy server like this: diff --git a/app.py b/app.py index 685862f..67d11c4 100644 --- a/app.py +++ b/app.py @@ -10,22 +10,43 @@ from fastapi.responses import FileResponse from fastapi.security.api_key import APIKey from ray import serve -from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder -from ray.rllib.offline.json_writer import JsonWriter import config from api import Action, Observation, RawObservation +from docs import get_redoc_html, get_swagger_ui_html from generate import CLI -from offline import EpisodeCache from security import get_api_key -cache = EpisodeCache() -batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder -writer = JsonWriter(config.EXPERIENCE_LOCATION) - url_path = config.parameters.get("url_path") -app = FastAPI(root_path=f"/{url_path}") if url_path else FastAPI() +app = ( + FastAPI(root_path=f"/{url_path}", docs_url=None, redoc_url=None) + if url_path + else FastAPI(docs_url=None, redoc_url=None) +) + + +@app.get("/docs", include_in_schema=False) +def overridden_swagger(): + return get_swagger_ui_html( + openapi_url="/openapi.json", + title="Pathmind Policy Server", + swagger_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com", + project_id=config.project_id, + model_id=config.model_id, + ) + + +@app.get("/redoc", include_in_schema=False) +def overridden_redoc(): + return get_redoc_html( + openapi_url="/openapi.json", + title="Pathmind Policy Server", + redoc_favicon_url="https://www.google.com/s2/favicons?domain_url=pathmind.com", + project_id=config.project_id, + model_id=config.model_id, + ) + tags_metadata = [ { diff --git a/config.py b/config.py index fd2bdd8..e51f4e4 100644 --- a/config.py +++ b/config.py @@ -43,6 +43,9 @@ def base_path(local_file): parameters = schema.get("parameters") action_type = int if parameters.get("discrete") else float +model_id = parameters.get("model_id", None) +project_id = parameters.get("project_id", None) + payload_data = {} # If the schema includes `max_items` set the constraints for the array if observations: diff --git a/docs.py b/docs.py new file mode 100644 index 0000000..7597baa --- /dev/null +++ b/docs.py @@ -0,0 +1,208 @@ +import json +from typing import Optional + +from fastapi.encoders import jsonable_encoder +from starlette.responses import HTMLResponse + + +def get_swagger_ui_html( + *, + openapi_url: str, + title: str, + swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js", + swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css", + swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png", + oauth2_redirect_url: Optional[str] = None, + init_oauth: Optional[dict] = None, + project_id: Optional[int] = None, + model_id: Optional[int] = None, +) -> HTMLResponse: + + html = f""" + + +
+ + +