forked from Worldie-com/pen.el
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgpt-neo-server.py
More file actions
65 lines (51 loc) · 1.75 KB
/
gpt-neo-server.py
File metadata and controls
65 lines (51 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Minimal Flask service exposing GPT-Neo text completion over HTTP.
# Adapted from:
# https://github.com/samrawal/emacs-secondmate/blob/main/serve/server.py
from transformers import AutoTokenizer, AutoModelForCausalLM
from flask import Flask, request, jsonify
app = Flask(__name__)
device = "cuda" # "cpu" or "cuda" or "cuda:n" where n is specific GPU to use
# NOTE(review): the 2.7B checkpoint is large (~10 GB); tokenizer and model are
# downloaded/loaded once at import time, before the server starts.
modelname = "EleutherAI/gpt-neo-2.7B"
tokenizer = AutoTokenizer.from_pretrained(modelname)
model = AutoModelForCausalLM.from_pretrained(modelname)
model.to(device)  # move weights to the selected device once, at startup
# Prepend to input to provide more context. Anecdotally, this helps.
# May not need for larger models.
# Fixed vs. the original: the few-shot examples were themselves buggy
# (is_palendrome indexed s[i + 1] out of range and returned False on
# *matching* characters; is_even returned `i % 2`, which is truthy for
# odd numbers). Correct examples make a better prompt.
prime = '''
def is_palindrome(s):
    """Check whether a string is a palindrome"""
    for i in range(len(s) // 2):
        if s[i] != s[len(s) - 1 - i]:
            return False
    return True
###
def is_even(i):
    """Check whether an integer is even"""
    return i % 2 == 0
###
def square_root(i):
    """Return the square root of an integer"""
    return math.sqrt(i)
###
'''
def inference(prompt, temperature, max_length):
    """Sample one generation from the model and return the decoded text.

    Args:
        prompt: full text prompt fed to the model.
        temperature: sampling temperature forwarded to ``model.generate``.
        max_length: total token budget for prompt plus generation
            (``generate``'s ``max_length``, not new tokens only).

    Returns:
        Decoded text of the single sampled sequence, prompt included.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(
        encoded.input_ids.to(device),
        do_sample=True,
        temperature=temperature,
        max_length=max_length,
    )
    return tokenizer.batch_decode(output_ids)[0]
def autocomplete(plaintext, to_prime=True, temperature=0.8, max_length=300):
    """Generate a completion for *plaintext*.

    When *to_prime* is true the module-level ``prime`` examples are
    prepended for context. The prompt is stripped from the model output
    and the result is cut at the first "###" separator, so only a single
    completion is returned.
    """
    if to_prime:
        prompt = prime + plaintext
    else:
        prompt = plaintext
    raw = inference(prompt, temperature, max_length)
    completion = raw[len(prompt):]
    return completion.split("###")[0]
@app.route("/")
def arguments():
text = request.args.get("text", "")
generation = autocomplete(text)
out = {"generation": generation}
return jsonify(out)
if __name__ == "__main__":
app.run(host="0.0.0.0", port="9900")