-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathserver.py
More file actions
53 lines (40 loc) · 1.43 KB
/
server.py
File metadata and controls
53 lines (40 loc) · 1.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from concurrent import futures
import click
import grpc
import signal
from time import time
import os
from utils import check_checkpoint_config
from rpc import service_pb2, service_pb2_grpc
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # noqa
from transformer.model import Model # noqa
class Transliterator(service_pb2_grpc.TransServicer):
    """gRPC servicer that answers ``infer`` calls by delegating to a model."""

    def __init__(self, model) -> None:
        """Keep a reference to the model used to serve requests.

        Args:
            model: Object exposing ``infer(word, to)`` that returns an
                iterable of strings (the result is passed to ``str.join``).
        """
        super().__init__()
        self.model = model

    def infer(self, request, context):
        """Run the model on one request and report how long inference took.

        Args:
            request: Incoming message; its ``word`` and ``to`` fields are
                forwarded to the model unchanged.
            context: Call context supplied by the gRPC runtime; not used here.

        Returns:
            A ``service_pb2.Output`` whose ``output`` is the model's results
            joined by single spaces and whose ``time`` is the elapsed
            inference time in seconds.
        """
        # perf_counter is monotonic, so the reported duration cannot turn
        # negative or jump when the system clock is adjusted mid-request
        # (which can happen with time.time()).
        from time import perf_counter

        started = perf_counter()
        pieces = self.model.infer(request.word, request.to)
        elapsed = perf_counter() - started
        return service_pb2.Output(output=' '.join(pieces), time=elapsed)
@click.command()
@click.option(
    '--checkpoint',
    default='./checkpoints',
    help='Path to checkpoints to restore model.',
)
def serve(checkpoint):
    # CLI entry point: restore the model found under ``checkpoint`` and serve
    # it over gRPC until the process is interrupted.
    # (Documented with comments on purpose: click would show a docstring as
    # the command's --help text.)

    # The configuration file is looked up inside the checkpoint directory.
    config_path = os.path.join(checkpoint, 'config.json')
    model_config = check_checkpoint_config(config_path)

    translit_model = Model(model_config, checkpoint)
    translit_model.restore_checkpoint()

    # Requests are handled on a pool of 10 worker threads.
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(worker_pool)
    servicer = Transliterator(translit_model)
    service_pb2_grpc.add_TransServicer_to_server(servicer, grpc_server)
    grpc_server.add_insecure_port('[::]:50051')
    grpc_server.start()
    print('Server Started...')

    def _handle_interrupt(_signum, _frame):
        # SIGINT handler: announce shutdown, then stop the server, passing
        # None as the grace argument.
        print()
        print('Stopping Server.')
        grpc_server.stop(None)

    signal.signal(signal.SIGINT, _handle_interrupt)

    # Block the main thread here; the handler above stops the server.
    grpc_server.wait_for_termination()
# Launch the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    serve()