diff --git a/.gitignore b/.gitignore
index b6e4761..9b5e917 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
+
+# Misc / user
+.history
diff --git a/Dockerfile b/Dockerfile
index 888a8da..950601a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,3 +1,5 @@
+# This is a Potassium-standard Dockerfile, compatible with Banana
+
# Must use a Cuda version 11+
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
@@ -11,18 +13,14 @@ RUN pip3 install --upgrade pip
ADD requirements.txt requirements.txt
RUN pip3 install -r requirements.txt
-# We add the banana boilerplate here
-ADD server.py .
# Add your model weight files
# (in this case we have a python script)
ADD download.py .
RUN python3 download.py
-
-# Add your custom app code, init() and inference()
-ADD app.py .
+ADD . .
EXPOSE 8000
-CMD python3 -u server.py
+CMD python3 -u app.py
\ No newline at end of file
diff --git a/README.md b/README.md
index 89068bd..28815aa 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,14 @@
-
-# 🍌 Banana Serverless
-
-This repo gives a framework to serve ML models in production using simple HTTP servers.
-
-# Quickstart
-**[Follow the quickstart guide in Banana's documentation to use this repo](https://docs.banana.dev/banana-docs/quickstart).**
-
-*(choose "GitHub Repository" deployment method)*
-
-
-
-# Helpful Links
-Understand the 🍌 [Serverless framework](https://docs.banana.dev/banana-docs/core-concepts/inference-server/serverless-framework) and functionality of each file within it.
-
-Generalize this framework to [deploy anything on Banana](https://docs.banana.dev/banana-docs/resources/how-to-serve-anything-on-banana).
-
-
-
-## Use Banana for scale.
+# My Potassium App
+This is a Potassium HTTP server, created with the `banana init` CLI
+
+### Testing
+Start a local dev server with `banana dev`
+
+### Deployment
+1. Create an empty repo on [GitHub](https://github.com)
+2. Push this repo to GitHub
+```
+git remote add origin https://github.com/{username}/{repo-name}.git
+```
+3. [Log into Banana](https://app.banana.dev/onboard)
+4. Select this repo to build and deploy!
\ No newline at end of file
diff --git a/app.py b/app.py
index 7f6b061..e21af78 100644
--- a/app.py
+++ b/app.py
@@ -1,26 +1,37 @@
-from transformers import pipeline
-import torch
+from potassium import Potassium, Request, Response
-# Init is ran on server startup
-# Load your model to GPU as a global variable here using the variable name "model"
+from sentence_transformers import SentenceTransformer
+from sklearn.preprocessing import normalize
+
+app = Potassium("my_app")
+
+# @app.init runs at startup, and loads models into the app's context
+@app.init
def init():
- global model
-
- device = 0 if torch.cuda.is_available() else -1
- model = pipeline('fill-mask', model='bert-base-uncased', device=device)
-
-# Inference is ran for every server call
-# Reference your preloaded global model variable here.
-def inference(model_inputs:dict) -> dict:
- global model
-
- # Parse out your arguments
- prompt = model_inputs.get('prompt', None)
- if prompt == None:
- return {'message': "No prompt provided"}
-
+ model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
+
+ context = {
+ "model": model
+ }
+
+ return context
+
+# @app.handler runs for every call
+@app.handler()
+def handler(context: dict, request: Request) -> Response:
+ prompt = request.json.get("prompt")
+ model = context.get("model")
# Run the model
- result = model(prompt)
+ sentence_embeddings = model.encode(prompt)
+ normalized_embeddings = normalize(sentence_embeddings)
+
+ # Convert the output array to a list
+ output = normalized_embeddings.tolist()
+
+ return Response(
+ json = {"data": output},
+ status=200
+ )
- # Return the results as a dictionary
- return result
+if __name__ == "__main__":
+ app.serve()
\ No newline at end of file
diff --git a/banana_config.json b/banana_config.json
new file mode 100644
index 0000000..cd9edbe
--- /dev/null
+++ b/banana_config.json
@@ -0,0 +1,24 @@
+{
+ "name": "",
+ "category": "",
+ "example_input": {
+ "prompt": "Hello I am a [MASK] model."
+ },
+ "example_output": {
+ "outputs":[
+ {
+ "score":0.13177461922168732,
+ "token":4827,
+ "token_str":"fashion",
+ "sequence":"hello i am a fashion model."
+ },
+ {
+ "score":0.1120428815484047,
+ "token":2535,
+ "token_str":"role",
+ "sequence":"hello i am a role model."
+ }
+ ]
+ },
+ "version": "1"
+}
\ No newline at end of file
diff --git a/download.py b/download.py
index 9f2956d..d693872 100644
--- a/download.py
+++ b/download.py
@@ -3,11 +3,11 @@
# In this example: A Huggingface BERT model
-from transformers import pipeline
+from sentence_transformers import SentenceTransformer
def download_model():
# do a dry run of loading the huggingface model, which will download weights
- pipeline('fill-mask', model='bert-base-uncased')
+ SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2')
if __name__ == "__main__":
download_model()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index f9cbeac..10f94c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-sanic==22.6.2
-transformers
+potassium
+sentence-transformers==2.2.2
accelerate
diff --git a/server.py b/server.py
deleted file mode 100644
index d66cf1a..0000000
--- a/server.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Do not edit if deploying to Banana Serverless
-# This file is boilerplate for the http server, and follows a strict interface.
-
-# Instead, edit the init() and inference() functions in app.py
-
-from sanic import Sanic, response
-import subprocess
-import app as user_src
-
-# We do the model load-to-GPU step on server startup
-# so the model object is available globally for reuse
-user_src.init()
-
-# Create the http server app
-server = Sanic("my_app")
-
-# Healthchecks verify that the environment is correct on Banana Serverless
-@server.route('/healthcheck', methods=["GET"])
-def healthcheck(request):
- # dependency free way to check if GPU is visible
- gpu = False
- out = subprocess.run("nvidia-smi", shell=True)
- if out.returncode == 0: # success state on shell command
- gpu = True
-
- return response.json({"state": "healthy", "gpu": gpu})
-
-# Inference POST handler at '/' is called for every http call from Banana
-@server.route('/', methods=["POST"])
-def inference(request):
- try:
- model_inputs = response.json.loads(request.json)
- except:
- model_inputs = request.json
-
- output = user_src.inference(model_inputs)
-
- return response.json(output)
-
-
-if __name__ == '__main__':
- server.run(host='0.0.0.0', port=8000, workers=1)
diff --git a/test.py b/test.py
deleted file mode 100644
index 3c88413..0000000
--- a/test.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# This file is used to verify your http server acts as expected
-# Run it with `python3 test.py``
-
-import requests
-
-model_inputs = {'prompt': 'Hello I am a [MASK] model.'}
-
-res = requests.post('http://localhost:8000/', json = model_inputs)
-
-print(res.json())
\ No newline at end of file