-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathinline_prediction_executor.py
More file actions
91 lines (73 loc) · 2.96 KB
/
inline_prediction_executor.py
File metadata and controls
91 lines (73 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A thin shell to fit a predictor function into the PredictionExecutor interface.
This is a convenience wrapper to allow a predictor function to be used directly
as a PredictionExecutor.
Intended usage in launching a server:
predictor = Predictor()
executor = InlinePredictionExecutor(predictor.predict)
server_gunicorn.PredictionApplication(executor, ...).run()
"""
from collections.abc import Callable
from typing import Any
from typing_extensions import override
from serving.serving_framework import model_runner
from serving.serving_framework import server_gunicorn
class InlinePredictionExecutor(server_gunicorn.PredictionExecutor):
  """Adapts a bare predictor callable to the PredictionExecutor interface.

  Supplies the small amount of framing needed to run a predictor function
  inside the server worker process. When a plain function call with no
  setup is not enough, override start() and predict() here for richer
  behavior, or inherit from PredictionExecutor directly instead.
  """

  def __init__(
      self,
      predictor: Callable[
          [dict[str, Any], model_runner.ModelRunner], dict[str, Any]
      ],
      model_runner_source: Callable[[], model_runner.ModelRunner],
  ):
    """Stores the predictor and the factory used to build its model runner.

    Args:
      predictor: Function mapping a request payload plus a ModelRunner to a
        response payload.
      model_runner_source: Zero-argument factory producing the ModelRunner;
        invoked lazily in start() rather than here.
    """
    self._predictor = predictor
    self._model_runner_source = model_runner_source
    # Left unset until start(); creation is deferred to the worker process.
    self._model_runner = None

  @override
  def start(self) -> None:
    """Performs post-fork setup for the executor.

    Runs after the Gunicorn worker process has started, so the model runner
    (e.g. an RPC stub) is created in the child rather than the parent.
    """
    # Safer to instantiate the RPC stub post-fork.
    self._model_runner = self._model_runner_source()

  def predict(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Runs the wrapped predictor on the given request payload."""
    runner = self._model_runner
    if runner is None:
      raise RuntimeError(
          "Model runner is not initialized. Please call start() first."
      )
    return self._predictor(input_json, runner)

  @override
  def execute(self, input_json: dict[str, Any]) -> dict[str, Any]:
    """Executes the given request payload.

    Args:
      input_json: The full json prediction request payload.

    Returns:
      The json response to the prediction request.

    Raises:
      RuntimeError: Prediction failed in an unhandled way.
    """
    try:
      return self.predict(input_json)
    except Exception as e:
      raise RuntimeError("Unhandled exception from predictor.") from e