diff --git a/app.py b/app.py index 45f5902..6f4bc3e 100644 --- a/app.py +++ b/app.py @@ -26,7 +26,7 @@ ) from generate import CLI from offline import EpisodeCache -from security import get_api_key +from security import get_api_key, verify_credentials cache = EpisodeCache() batch_builder = SampleBatchBuilder() # or MultiAgentSampleBatchBuilder @@ -95,7 +95,9 @@ async def startup_event(): if config.observations: # Note: for basic auth, use "logged_in: bool = Depends(verify_credentials)" as parameter @app.post("/predict/", response_model=Action, tags=["Predictions"]) - async def predict(payload: Observation, api_key: APIKey = Depends(get_api_key)): + async def predict( + payload: Observation, logged_in: bool = Depends(verify_credentials) + ): lists = [ [getattr(payload, obs)] if not isinstance(getattr(payload, obs), List) @@ -114,7 +116,7 @@ async def predict_deterministic( @app.post("/distribution/", tags=["Predictions"]) async def distribution( - payload: Observation, api_key: APIKey = Depends(get_api_key) + payload: Observation, logged_in: bool = Depends(verify_credentials) ): return _distribution(payload) @@ -187,7 +189,7 @@ async def clients(api_key: APIKey = Depends(get_api_key)): @app.get("/schema", tags=["Clients"]) -async def server_schema(api_key: APIKey = Depends(get_api_key)): +async def server_schema(logged_in: bool = Depends(verify_credentials)): with open(config.PATHMIND_SCHEMA, "r") as schema_file: schema_str = schema_file.read() schema = yaml.safe_load(schema_str) diff --git a/frontend.py b/frontend.py index cdbe052..b15a503 100644 --- a/frontend.py +++ b/frontend.py @@ -13,8 +13,9 @@ def run_the_app(): st.sidebar.markdown("# Server authentication") auth = None - user = st.sidebar.text_input("User name") - password = st.sidebar.text_input("Password") + user = st.sidebar.text_input("User name", "admin") + password = st.sidebar.text_input("Password", "admin") + token = st.sidebar.text_input("Token", "1234567asdfgh") if user and password: auth = (user, password) @@ -32,17 +33,20 @@ def run_the_app(): st.markdown(f"## Action: {response.get('actions')[0]}") st.markdown(f"## Probability: {int(100 *response.get('probability'))}%") - # compute_action_distro = st.checkbox(label="What's the variance of my actions?", value=False) - # if compute_action_distro and obs: - # distro_dict: dict = distro(obs, auth, url).json() - # - # import matplotlib.pyplot as plt - # import numpy as np - # arr = np.asarray(list(distro_dict.values())) - # x_range = np.arange(len(distro_dict)) - # plt.bar(x_range, arr) - # plt.xticks(x_range, list(distro_dict.keys())) - # st.pyplot(plt) + compute_action_distro = st.checkbox( + label="What's the variance of my actions?", value=False + ) + if compute_action_distro and obs: + distro_dict: dict = distro(obs, auth, url).json() + + import matplotlib.pyplot as plt + import numpy as np + + arr = np.asarray(list(distro_dict.values())) + x_range = np.arange(len(distro_dict)) + plt.bar(x_range, arr) + plt.xticks(x_range, list(distro_dict.keys())) + st.pyplot(plt) def server_schema(auth, url): @@ -65,6 +69,7 @@ def generate_frontend_from_observations(schema: dict): """ properties = schema.get("observations") result = {} + print(f"schema: {schema}") for key, values in properties.items(): prop_type = values.get("type") # example = values.get("example") diff --git a/requirements-dev.txt b/requirements-dev.txt index 2db9db1..3996987 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,4 @@ pre-commit streamlit~=0.72.0 requests~=2.25.0 pytest +matplotlib