diff --git a/.gitignore b/.gitignore index b6e4761..9b5e917 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +# Misc / user +.history diff --git a/Dockerfile b/Dockerfile index 888a8da..950601a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,3 +1,5 @@ +# This is a potassium-standard dockerfile, compatible with Banana + # Must use a Cuda version 11+ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime @@ -11,18 +13,14 @@ RUN pip3 install --upgrade pip ADD requirements.txt requirements.txt RUN pip3 install -r requirements.txt -# We add the banana boilerplate here -ADD server.py . # Add your model weight files # (in this case we have a python script) ADD download.py . RUN python3 download.py - -# Add your custom app code, init() and inference() -ADD app.py . +ADD . . EXPOSE 8000 -CMD python3 -u server.py +CMD python3 -u app.py \ No newline at end of file diff --git a/README.md b/README.md index 89068bd..28815aa 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,14 @@ - -# 🍌 Banana Serverless - -This repo gives a framework to serve ML models in production using simple HTTP servers. - -# Quickstart -**[Follow the quickstart guide in Banana's documentation to use this repo](https://docs.banana.dev/banana-docs/quickstart).** - -*(choose "GitHub Repository" deployment method)* - -
- -# Helpful Links -Understand the 🍌 [Serverless framework](https://docs.banana.dev/banana-docs/core-concepts/inference-server/serverless-framework) and functionality of each file within it. - -Generalize this framework to [deploy anything on Banana](https://docs.banana.dev/banana-docs/resources/how-to-serve-anything-on-banana). - -
- -## Use Banana for scale. +# My Potassium App +This is a Potassium HTTP server, created with `banana init` CLI + +### Testing +Start a local dev server with `banana dev` + +### Deployment +1. Create empty repo on [Github](https://github.com) +2. Push this repo to github +``` +git remote add origin https://github.com/{username}/{repo-name}.git +``` +3. [Log into Banana](https://app.banana.dev/onboard) +4. Select this repo to build and deploy! \ No newline at end of file diff --git a/app.py b/app.py index 7f6b061..e21af78 100644 --- a/app.py +++ b/app.py @@ -1,26 +1,37 @@ -from transformers import pipeline -import torch +from potassium import Potassium, Request, Response -# Init is ran on server startup -# Load your model to GPU as a global variable here using the variable name "model" +from sentence_transformers import SentenceTransformer +from sklearn.preprocessing import normalize + +app = Potassium("my_app") + +# @app.init runs at startup, and loads models into the app's context +@app.init def init(): - global model - - device = 0 if torch.cuda.is_available() else -1 - model = pipeline('fill-mask', model='bert-base-uncased', device=device) - -# Inference is ran for every server call -# Reference your preloaded global model variable here. -def inference(model_inputs:dict) -> dict: - global model - - # Parse out your arguments - prompt = model_inputs.get('prompt', None) - if prompt == None: - return {'message': "No prompt provided"} - + model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2") + + context = { + "model": model + } + + return context + +# @app.handler runs for every call +@app.handler() +def handler(context: dict, request: Request) -> Response: + prompt = request.json.get("prompt") + model = context.get("model") # Run the model - result = model(prompt) + sentence_embeddings = model.encode(prompt) + normalized_embeddings = normalize(sentence_embeddings) + + # Convert the output array to a list + output = normalized_embeddings.tolist() + + return Response( + json = {"data": output}, + status=200 + ) - # Return the results as a dictionary - return result +if __name__ == "__main__": + app.serve() \ No newline at end of file diff --git a/banana_config.json b/banana_config.json new file mode 100644 index 0000000..cd9edbe --- /dev/null +++ b/banana_config.json @@ -0,0 +1,24 @@ +{ + "name": "", + "category": "", + "example_input": { + "prompt": "Hello I am a [MASK] model." + }, + "example_output": { + "outputs":[ + { + "score":0.13177461922168732, + "token":4827, + "token_str":"fashion", + "sequence":"hello i am a fashion model." + }, + { + "score":0.1120428815484047, + "token":2535, + "token_str":"role", + "sequence":"hello i am a role model." + } + ] + }, + "version": "1" +} \ No newline at end of file diff --git a/download.py b/download.py index 9f2956d..d693872 100644 --- a/download.py +++ b/download.py @@ -3,11 +3,11 @@ # In this example: A Huggingface BERT model -from transformers import pipeline +from sentence_transformers import SentenceTransformer def download_model(): # do a dry run of loading the huggingface model, which will download weights - pipeline('fill-mask', model='bert-base-uncased') + SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2') if __name__ == "__main__": download_model() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f9cbeac..10f94c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -sanic==22.6.2 -transformers +potassium +sentence-transformers==2.2.2 accelerate diff --git a/server.py b/server.py deleted file mode 100644 index d66cf1a..0000000 --- a/server.py +++ /dev/null @@ -1,42 +0,0 @@ -# Do not edit if deploying to Banana Serverless -# This file is boilerplate for the http server, and follows a strict interface. - -# Instead, edit the init() and inference() functions in app.py - -from sanic import Sanic, response -import subprocess -import app as user_src - -# We do the model load-to-GPU step on server startup -# so the model object is available globally for reuse -user_src.init() - -# Create the http server app -server = Sanic("my_app") - -# Healthchecks verify that the environment is correct on Banana Serverless -@server.route('/healthcheck', methods=["GET"]) -def healthcheck(request): - # dependency free way to check if GPU is visible - gpu = False - out = subprocess.run("nvidia-smi", shell=True) - if out.returncode == 0: # success state on shell command - gpu = True - - return response.json({"state": "healthy", "gpu": gpu}) - -# Inference POST handler at '/' is called for every http call from Banana -@server.route('/', methods=["POST"]) -def inference(request): - try: - model_inputs = response.json.loads(request.json) - except: - model_inputs = request.json - - output = user_src.inference(model_inputs) - - return response.json(output) - - -if __name__ == '__main__': - server.run(host='0.0.0.0', port=8000, workers=1) diff --git a/test.py b/test.py deleted file mode 100644 index 3c88413..0000000 --- a/test.py +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to verify your http server acts as expected -# Run it with `python3 test.py`` - -import requests - -model_inputs = {'prompt': 'Hello I am a [MASK] model.'} - -res = requests.post('http://localhost:8000/', json = model_inputs) - -print(res.json()) \ No newline at end of file