diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 3b5768ac..03e74c0c 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -7,7 +7,7 @@ on: push: branches: [ master ] pull_request: - branches: [ master ] + branches: [ master, feature/prediction ] jobs: code_quality_checks: @@ -40,11 +40,11 @@ jobs: continue-on-error: False runs-on: self-hosted - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: - python-version: [3.7, 3.8] + python-version: [3.7] env: GITHUB_ACTION: true diff --git a/.pylintrc b/.pylintrc index 1a782589..5c5e6a8a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -68,7 +68,7 @@ ENABLED: [IMPORTS] ignored-modules=click,google,grpc,matplotlib,numpy,opacus,onnx,onnxmltools,pandas,PIL,prometheus_client,pydantic,pytest, - tensorflow,tensorflow_core,tensorflow_datasets,tensorflow_privacy,torch,torchsummary,torchvision,typing_extensions, + tensorflow,tensorflow_addons,tensorflow_core,tensorflow_datasets,tensorflow_privacy,torch,torchsummary,torchvision,typing_extensions, scipy,sklearn,xgboost [TYPECHECK] diff --git a/colearn/ml_interface.py b/colearn/ml_interface.py index 5a799e7b..68ace363 100644 --- a/colearn/ml_interface.py +++ b/colearn/ml_interface.py @@ -20,32 +20,7 @@ from typing import Any, Optional import onnx -import onnxmltools -import sklearn -import tensorflow as tf -import torch from pydantic import BaseModel -from tensorflow import keras - -model_classes_keras = (tf.keras.Model, keras.Model, tf.estimator.Estimator) -model_classes_scipy = (torch.nn.Module) -model_classes_sklearn = (sklearn.base.ClassifierMixin) - - -def convert_model_to_onnx(model: Any): - """ - Helper function to convert a ML model to onnx format - """ - if isinstance(model, model_classes_keras): - return onnxmltools.convert_keras(model) - if isinstance(model, model_classes_sklearn): - return onnxmltools.convert_sklearn(model) - if 'xgboost' in model.__repr__(): - return 
onnxmltools.convert_sklearn(model) - if isinstance(model, model_classes_scipy): - raise Exception("Pytorch models not yet supported to onnx") - else: - raise Exception("Attempt to convert unsupported model to onnx: {model}") class DiffPrivBudget(BaseModel): @@ -78,8 +53,9 @@ class DiffPrivConfig(BaseModel): class ProposedWeights(BaseModel): weights: Weights - vote_score: float - test_score: float + vote_score: dict + test_score: dict + criterion: str vote: Optional[bool] @@ -94,6 +70,17 @@ class ColearnModel(BaseModel): model: Optional[Any] +class PredictionRequest(BaseModel): + name: str + input_data: Any + pred_dataloader_key: Optional[Any] + + +class Prediction(BaseModel): + name: str + prediction_data: Any + + def deser_model(model: Any) -> onnx.ModelProto: """ Helper function to recover a onnx model from its deserialized form @@ -136,3 +123,14 @@ def mli_get_current_model(self) -> ColearnModel: Returns the current model """ pass + + @abc.abstractmethod + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + """ + Make prediction using the current model. + Does not change the current weights of the model. + + :param request: data to get the prediction for + :returns: the prediction + """ + pass diff --git a/colearn/onnxutils.py b/colearn/onnxutils.py new file mode 100644 index 00000000..15582aa7 --- /dev/null +++ b/colearn/onnxutils.py @@ -0,0 +1,44 @@ +# ------------------------------------------------------------------------------ +# +# Copyright 2021 Fetch.AI Limited +# +# Licensed under the Creative Commons Attribution-NonCommercial International +# License, Version 4.0 (the "License"); you may not use this file except in +# compliance with the License. 
You may obtain a copy of the License at +# +# http://creativecommons.org/licenses/by-nc/4.0/legalcode +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ------------------------------------------------------------------------------ +from typing import Any + +import onnxmltools +import sklearn +import tensorflow as tf +import torch +from tensorflow import keras + +model_classes_keras = (tf.keras.Model, keras.Model, tf.estimator.Estimator) +model_classes_scipy = (torch.nn.Module) +model_classes_sklearn = (sklearn.base.ClassifierMixin) + + +def convert_model_to_onnx(model: Any): + """ + Helper function to convert a ML model to onnx format + """ + if isinstance(model, model_classes_keras): + return onnxmltools.convert_keras(model) + if isinstance(model, model_classes_sklearn): + return onnxmltools.convert_sklearn(model) + if 'xgboost' in model.__repr__(): + return onnxmltools.convert_sklearn(model) + if isinstance(model, model_classes_scipy): + raise Exception("Pytorch models not yet supported to onnx") + else: + raise Exception(f"Attempt to convert unsupported model to onnx: {model}") diff --git a/colearn/training.py b/colearn/training.py index d4d58855..ecf9ada0 100644 --- a/colearn/training.py +++ b/colearn/training.py @@ -33,8 +33,8 @@ def initial_result(learners: Sequence[MachineLearningInterface]): result = Result() for learner in learners: proposed_weights = learner.mli_test_weights(learner.mli_get_current_weights()) # type: ProposedWeights - result.test_scores.append(proposed_weights.test_score) - result.vote_scores.append(proposed_weights.vote_score) + result.test_scores.append(proposed_weights.test_score[proposed_weights.criterion]) + 
result.vote_scores.append(proposed_weights.vote_score[proposed_weights.criterion]) result.votes.append(True) return result @@ -48,9 +48,10 @@ def collective_learning_round(learners: Sequence[MachineLearningInterface], vote vote_threshold) result.vote = vote result.votes = [pw.vote for pw in proposed_weights_list] - result.vote_scores = [pw.vote_score for pw in + # TODO does this make sense? + result.vote_scores = [pw.vote_score[pw.criterion] for pw in proposed_weights_list] - result.test_scores = [pw.test_score for pw in proposed_weights_list] + result.test_scores = [pw.test_score[pw.criterion] for pw in proposed_weights_list] result.training_summaries = [ l.mli_get_current_weights().training_summary for l in learners @@ -73,7 +74,7 @@ def individual_training_round(learners: Sequence[MachineLearningInterface], roun learner.mli_accept_weights(weights) result.votes.append(True) - result.vote_scores.append(proposed_weights.vote_score) - result.test_scores.append(proposed_weights.test_score) + result.vote_scores.append(proposed_weights.vote_score[proposed_weights.criterion]) + result.test_scores.append(proposed_weights.test_score[proposed_weights.criterion]) return result diff --git a/colearn_examples/grpc/mlifactory_grpc_mnist.py b/colearn_examples/grpc/mlifactory_grpc_mnist.py index e44730f0..cf069625 100644 --- a/colearn_examples/grpc/mlifactory_grpc_mnist.py +++ b/colearn_examples/grpc/mlifactory_grpc_mnist.py @@ -140,11 +140,18 @@ def get_models(self) -> Dict[str, Dict[str, Any]]: vote_batches=10, learning_rate=0.001)} - def get_compatibilities(self) -> Dict[str, Set[str]]: + def get_data_compatibilities(self) -> Dict[str, Set[str]]: return {model_tag: {dataloader_tag}} + def get_prediction_dataloaders(self) -> Dict[str, Dict[str, Any]]: + raise NotImplementedError + + def get_pred_compatibilities(self) -> Dict[str, Set[str]]: + raise NotImplementedError + def get_mli(self, model_name: str, model_params: str, dataloader_name: str, - dataset_params: str) -> 
MachineLearningInterface: + dataset_params: str, prediction_dataloader_name: str = None, + prediction_dataset_params: str = None) -> MachineLearningInterface: dataloader_kwargs = json.loads(dataset_params) data_loaders = prepare_data_loaders(**dataloader_kwargs) diff --git a/colearn_examples/grpc/mnist_grpc.py b/colearn_examples/grpc/mnist_grpc.py index 90701d93..2219bd37 100644 --- a/colearn_examples/grpc/mnist_grpc.py +++ b/colearn_examples/grpc/mnist_grpc.py @@ -37,6 +37,7 @@ from colearn_keras.keras_mnist import split_to_folders # pylint: disable=C0413 # noqa: F401 from tensorflow.python.data.ops.dataset_ops import PrefetchDataset # pylint: disable=C0413 # noqa: F401 import tensorflow as tf # pylint: disable=C0413 # noqa: F401 +import tensorflow_addons as tfa # pylint: disable=C0413 # noqa: F401 dataloader_tag = "KERAS_MNIST_EXAMPLE_DATALOADER" @@ -63,6 +64,9 @@ def prepare_data_loaders(location: str, images = pickle.load(open(Path(data_folder) / image_fl, "rb")) labels = pickle.load(open(Path(data_folder) / label_fl, "rb")) + # OHE for broader metric usage + labels = tf.keras.utils.to_categorical(labels, 10) + n_cases = int(train_ratio * len(images)) n_vote_cases = int(vote_ratio * len(images)) @@ -87,7 +91,7 @@ def prepare_data_loaders(location: str, @FactoryRegistry.register_model_architecture(model_tag, [dataloader_tag]) def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, PrefetchDataset], steps_per_epoch: int = 100, - vote_batches: int = 10, + vote_batches: int = 1, # needs to stay one for correct test calculation learning_rate: float = 0.001 ) -> KerasLearner: """ @@ -100,8 +104,11 @@ def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, Prefet """ # 2D Convolutional model for image recognition - loss = "sparse_categorical_crossentropy" + loss = "categorical_crossentropy" optimizer = tf.keras.optimizers.Adam + n_classes = 10 + metric_list = ["accuracy", tf.keras.metrics.AUC(), + 
tfa.metrics.F1Score(average="macro", num_classes=n_classes)] input_img = tf.keras.Input(shape=(28, 28, 1), name="Input") x = tf.keras.layers.Conv2D(32, (5, 5), activation="relu", padding="same", name="Conv1_1")(input_img) @@ -112,19 +119,19 @@ def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, Prefet x = tf.keras.layers.MaxPooling2D((2, 2), name="pool3")(x) x = tf.keras.layers.Flatten(name="flatten")(x) x = tf.keras.layers.Dense(64, activation="relu", name="fc1")(x) - x = tf.keras.layers.Dense(10, activation="softmax", name="fc2")(x) + x = tf.keras.layers.Dense(n_classes, activation="softmax", name="fc2")(x) model = tf.keras.Model(inputs=input_img, outputs=x) opt = optimizer(lr=learning_rate) - model.compile(loss=loss, metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], optimizer=opt) + model.compile(loss=loss, metrics=metric_list, optimizer=opt) learner = KerasLearner( model=model, train_loader=data_loaders[0], vote_loader=data_loaders[1], test_loader=data_loaders[2], - criterion="sparse_categorical_accuracy", - minimise_criterion=False, + criterion="loss", + minimise_criterion=True, model_fit_kwargs={"steps_per_epoch": steps_per_epoch}, model_evaluate_kwargs={"steps": vote_batches}, ) @@ -167,7 +174,7 @@ def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, Prefet results = Results() results.data.append(initial_result(all_learner_models)) -plot = ColearnPlot(score_name="accuracy") +plot = ColearnPlot(score_name="loss") testing_mode = bool(os.getenv("COLEARN_EXAMPLES_TEST", "")) # for testing n_rounds = 10 if not testing_mode else 1 diff --git a/colearn_examples/ml_interface/keras_fraud.py b/colearn_examples/ml_interface/keras_fraud.py index 1aa61769..a2446712 100644 --- a/colearn_examples/ml_interface/keras_fraud.py +++ b/colearn_examples/ml_interface/keras_fraud.py @@ -31,10 +31,8 @@ """ Fraud training example using Tensorflow Keras - Used dataset: - Fraud, download from kaggle: 
https://www.kaggle.com/c/ieee-fraud-detection - What script does: - Sets up the Keras model and some configuration parameters - Randomly splits the dataset between multiple learners diff --git a/colearn_examples/ml_interface/mli_fraud.py b/colearn_examples/ml_interface/mli_fraud.py index 282a5af6..9fda9a6d 100644 --- a/colearn_examples/ml_interface/mli_fraud.py +++ b/colearn_examples/ml_interface/mli_fraud.py @@ -24,7 +24,8 @@ import sklearn from sklearn.linear_model import SGDClassifier -from colearn.ml_interface import MachineLearningInterface, Weights, ProposedWeights, ColearnModel, ModelFormat, convert_model_to_onnx +from colearn.ml_interface import MachineLearningInterface, Prediction, PredictionRequest, Weights, ProposedWeights, ColearnModel, ModelFormat +from colearn.onnxutils import convert_model_to_onnx from colearn.training import initial_result, collective_learning_round, set_equal_weights from colearn.utils.plot import ColearnPlot from colearn.utils.results import Results, print_results @@ -87,18 +88,20 @@ def mli_propose_weights(self) -> Weights: def mli_test_weights(self, weights: Weights) -> ProposedWeights: current_weights = self.mli_get_current_weights() self.set_weights(weights) + criterion = "mean_accuracy" vote_score = self.test(self.vote_data, self.vote_labels) test_score = self.test(self.test_data, self.test_labels) - vote = self.vote_score <= vote_score + vote = self.vote_score[criterion] <= vote_score[criterion] self.set_weights(current_weights) return ProposedWeights(weights=weights, vote_score=vote_score, test_score=test_score, - vote=vote + vote=vote, + criterion=criterion ) def mli_accept_weights(self, weights: Weights): @@ -126,9 +129,12 @@ def set_weights(self, weights: Weights): def test(self, data, labels): try: - return self.model.score(data, labels) + return {"mean_accuracy": self.model.score(data, labels)} except sklearn.exceptions.NotFittedError: - return 0 + return {"mean_accuracy": 0} + + def mli_make_prediction(self, request: 
PredictionRequest) -> Prediction: + raise NotImplementedError() if __name__ == "__main__": diff --git a/colearn_examples/ml_interface/mli_random_forest_iris.py b/colearn_examples/ml_interface/mli_random_forest_iris.py index ad3f75f1..53bf92f1 100644 --- a/colearn_examples/ml_interface/mli_random_forest_iris.py +++ b/colearn_examples/ml_interface/mli_random_forest_iris.py @@ -22,7 +22,8 @@ from sklearn import datasets from sklearn.ensemble import RandomForestClassifier -from colearn.ml_interface import MachineLearningInterface, Weights, ProposedWeights, ColearnModel, ModelFormat, convert_model_to_onnx +from colearn.ml_interface import MachineLearningInterface, Prediction, PredictionRequest, Weights, ProposedWeights, ColearnModel, ModelFormat +from colearn.onnxutils import convert_model_to_onnx from colearn.training import initial_result, collective_learning_round from colearn.utils.plot import ColearnPlot from colearn.utils.results import Results, print_results @@ -76,18 +77,20 @@ def mli_propose_weights(self) -> Weights: def mli_test_weights(self, weights: Weights) -> ProposedWeights: current_weights = self.mli_get_current_weights() self.set_weights(weights) + criterion = "mean_accuracy" vote_score = self.test(self.vote_data, self.vote_labels) test_score = self.test(self.test_data, self.test_labels) - vote = self.vote_score <= vote_score + vote = self.vote_score[criterion] <= vote_score[criterion] self.set_weights(current_weights) return ProposedWeights(weights=weights, vote_score=vote_score, test_score=test_score, - vote=vote + vote=vote, + criterion=criterion ) def mli_accept_weights(self, weights: Weights): @@ -112,7 +115,11 @@ def set_weights(self, weights: Weights): self.model = pickle.loads(weights.weights) def test(self, data_array, labels_array): - return self.model.score(data_array, labels_array) + score = {"mean_accuracy": self.model.score(data_array, labels_array)} + return score + + def mli_make_prediction(self, request: PredictionRequest) -> 
Prediction: + raise NotImplementedError() train_fraction = 0.9 diff --git a/colearn_examples/ml_interface/run_demo.py b/colearn_examples/ml_interface/run_demo.py index ac0a3528..905298bb 100644 --- a/colearn_examples/ml_interface/run_demo.py +++ b/colearn_examples/ml_interface/run_demo.py @@ -72,7 +72,7 @@ args = parser.parse_args() model_name = args.model -dataloader_set = mli_fac.get_compatibilities()[model_name] +dataloader_set = mli_fac.get_data_compatibilities()[model_name] dataloader_name = next(iter(dataloader_set)) # use the first dataloader n_learners = args.n_learners diff --git a/colearn_examples/ml_interface/xgb_reg_boston.py b/colearn_examples/ml_interface/xgb_reg_boston.py index 7abfb6f3..f1eebe19 100644 --- a/colearn_examples/ml_interface/xgb_reg_boston.py +++ b/colearn_examples/ml_interface/xgb_reg_boston.py @@ -23,7 +23,8 @@ import numpy as np import xgboost as xgb -from colearn.ml_interface import MachineLearningInterface, Weights, ProposedWeights, ColearnModel, ModelFormat, convert_model_to_onnx +from colearn.ml_interface import MachineLearningInterface, Prediction, PredictionRequest, Weights, ProposedWeights, ColearnModel, ModelFormat +from colearn.onnxutils import convert_model_to_onnx from colearn.training import initial_result, collective_learning_round from colearn.utils.data import split_list_into_fractions from colearn.utils.plot import ColearnPlot @@ -70,18 +71,20 @@ def mli_propose_weights(self) -> Weights: def mli_test_weights(self, weights: Weights) -> ProposedWeights: current_weights = self.mli_get_current_weights() + criterion = self.params["objective"] self.set_weights(weights) vote_score = self.test(self.xg_vote) test_score = self.test(self.xg_test) - vote = self.vote_score >= vote_score + vote = self.vote_score[criterion] >= vote_score[criterion] self.set_weights(current_weights) return ProposedWeights(weights=weights, vote_score=vote_score, test_score=test_score, - vote=vote + vote=vote, + criterion=criterion ) def 
mli_accept_weights(self, weights: Weights): @@ -110,7 +113,11 @@ def mli_get_current_model(self) -> ColearnModel: ) def test(self, data_matrix): - return mse(self.model.predict(data_matrix), data_matrix.get_label()) + score = {self.params["objective"]: mse(self.model.predict(data_matrix), data_matrix.get_label())} + return score + + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + raise NotImplementedError() train_fraction = 0.9 diff --git a/colearn_grpc/example_grpc_learner_client.py b/colearn_grpc/example_grpc_learner_client.py index 8e32f98c..ff524f74 100644 --- a/colearn_grpc/example_grpc_learner_client.py +++ b/colearn_grpc/example_grpc_learner_client.py @@ -24,7 +24,7 @@ import colearn_grpc.proto.generated.interface_pb2 as ipb2 import colearn_grpc.proto.generated.interface_pb2_grpc as ipb2_grpc -from colearn.ml_interface import MachineLearningInterface, ProposedWeights, Weights, ColearnModel +from colearn.ml_interface import MachineLearningInterface, Prediction, PredictionRequest, ProposedWeights, Weights, ColearnModel from colearn_grpc.logging import get_logger from colearn_grpc.utils import iterator_to_weights, weights_to_iterator @@ -65,13 +65,15 @@ def start(self): # Attempt to get the certificate from the server and use it to encrypt the # connection. If the certificate cannot be found, try to create an unencrypted connection. 
try: - assert (':' in self.address), f"Poorly formatted address, needs :port - {self.address}" + assert ( + ':' in self.address), f"Poorly formatted address, needs :port - {self.address}" _logger.info(f"Connecting to server: {self.address}") addr, port = self.address.split(':') trusted_certs = ssl.get_server_certificate((addr, int(port))) # create credentials - credentials = grpc.ssl_channel_credentials(root_certificates=trusted_certs.encode()) + credentials = grpc.ssl_channel_credentials( + root_certificates=trusted_certs.encode()) except ssl.SSLError as e: _logger.warning( f"Encountered ssl error when attempting to get certificate from learner server: {e}") @@ -118,15 +120,21 @@ def get_supported_system(self): response = self.stub.QuerySupportedSystem(request) r = { "data_loaders": {}, + "prediction_data_loaders": {}, "model_architectures": {}, - "compatibilities": {} + "data_compatibilities": {}, + "pred_compatibilities": {}, } for d in response.data_loaders: r["data_loaders"][d.name] = d.default_parameters + for p in response.prediction_data_loaders: + r["prediction_data_loaders"][p.name] = p.default_parameters for m in response.model_architectures: r["model_architectures"][m.name] = m.default_parameters - for c in response.compatibilities: - r["compatibilities"][c.model_architecture] = c.dataloaders + for dc in response.data_compatibilities: + r["data_compatibilities"][dc.model_architecture] = dc.dataloaders + for pc in response.pred_compatibilities: + r["pred_compatibilities"][pc.model_architecture] = pc.prediction_dataloaders return r def get_version(self): @@ -137,11 +145,17 @@ def get_version(self): return response.version def setup_ml(self, dataset_loader_name, dataset_loader_parameters, - model_arch_name, model_parameters): - - _logger.info(f"Setting up ml: model_arch: {model_arch_name}, dataset_loader: {dataset_loader_name}") + model_arch_name, model_parameters, + prediction_dataset_loader_name=None, + prediction_dataset_loader_parameters=None, + ): + 
_logger.info( + f"Setting up ml: model_arch: {model_arch_name}, dataset_loader: {dataset_loader_name}," + f"prediction_dataset_loader: {prediction_dataset_loader_name}") _logger.debug(f"Model params: {model_parameters}") _logger.debug(f"Dataloader params: {dataset_loader_parameters}") + _logger.debug( + f"Prediction dataloader params: {prediction_dataset_loader_parameters}") request = ipb2.RequestMLSetup() request.dataset_loader_name = dataset_loader_name @@ -149,6 +163,11 @@ def setup_ml(self, dataset_loader_name, dataset_loader_parameters, request.model_arch_name = model_arch_name request.model_parameters = model_parameters + if prediction_dataset_loader_name: + request.prediction_dataset_loader_name = prediction_dataset_loader_name + if prediction_dataset_loader_parameters: + request.prediction_dataset_loader_parameters = prediction_dataset_loader_parameters + _logger.info(f"Setting up ml with request: {request}") try: @@ -173,7 +192,8 @@ def mli_propose_weights(self) -> Weights: def mli_test_weights(self, weights: Weights = None) -> ProposedWeights: try: if weights: - response = self.stub.TestWeights(weights_to_iterator(weights, encode=False)) + response = self.stub.TestWeights( + weights_to_iterator(weights, encode=False)) else: raise Exception("mli_test_weights(None) is not currently supported") @@ -181,7 +201,8 @@ def mli_test_weights(self, weights: Weights = None) -> ProposedWeights: weights=weights, vote_score=response.vote_score, test_score=response.test_score, - vote=response.vote + vote=response.vote, + criterion=response.criterion ) except grpc.RpcError as ex: _logger.exception(f"Failed to test_model: {ex}") @@ -211,3 +232,19 @@ def mli_get_current_model(self) -> ColearnModel: response = self.stub.GetCurrentModel(request) return ColearnModel(model_format=response.model_format, model_file=response.model_file, model=response.model) + + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + request_pb = ipb2.PredictionRequest() + 
request_pb.name = request.name + request_pb.input_data = request.input_data + if request.pred_dataloader_key: + request_pb.pred_dataloader_key = request.pred_dataloader_key + + _logger.info(f"Requesting prediction {request.name}") + + try: + response = self.stub.MakePrediction(request_pb) + return Prediction(name=response.name, prediction_data=response.prediction_data) + except grpc.RpcError as ex: + _logger.exception(f"Failed to make_prediction: {ex}") + raise ConnectionError(f"GRPC error: {ex}") diff --git a/colearn_grpc/example_mli_factory.py b/colearn_grpc/example_mli_factory.py index 9266ac15..803e0226 100644 --- a/colearn_grpc/example_mli_factory.py +++ b/colearn_grpc/example_mli_factory.py @@ -34,9 +34,12 @@ def __init__(self): in FactoryRegistry.model_architectures.items()} self.dataloaders = {name: config.default_parameters for name, config in FactoryRegistry.dataloaders.items()} - - self.compatibilities = {name: config.compatibilities for name, config - in FactoryRegistry.model_architectures.items()} + self.prediction_dataloaders = {name: config.default_parameters for name, config + in FactoryRegistry.prediction_dataloaders.items()} + self.data_compatibilities = {name: config.data_compatibilities for name, config + in FactoryRegistry.model_architectures.items()} + self.pred_compatibilities = {name: config.pred_compatibilities for name, config + in FactoryRegistry.model_architectures.items()} def get_models(self) -> Dict[str, Dict[str, Any]]: return copy.deepcopy(self.models) @@ -44,15 +47,24 @@ def get_models(self) -> Dict[str, Dict[str, Any]]: def get_dataloaders(self) -> Dict[str, Dict[str, Any]]: return copy.deepcopy(self.dataloaders) - def get_compatibilities(self) -> Dict[str, Set[str]]: - return self.compatibilities + def get_prediction_dataloaders(self) -> Dict[str, Dict[str, Any]]: + return copy.deepcopy(self.prediction_dataloaders) + + def get_data_compatibilities(self) -> Dict[str, Set[str]]: + return self.data_compatibilities + + def 
get_pred_compatibilities(self) -> Dict[str, Set[str]]: + return self.pred_compatibilities def get_mli(self, model_name: str, model_params: str, dataloader_name: str, - dataset_params: str) -> MachineLearningInterface: + dataset_params: str, prediction_dataloader_name: str = None, + prediction_dataset_params: str = None) -> MachineLearningInterface: print("Call to get_mli") print(f"model_name {model_name} -> params: {model_params}") print(f"dataloader_name {dataloader_name} -> params: {dataset_params}") + print( + f"prediction_dataloader_name {prediction_dataloader_name} -> params: {prediction_dataset_params}") if model_name not in self.models: raise Exception(f"Model {model_name} is not a valid model. " @@ -60,11 +72,18 @@ def get_mli(self, model_name: str, model_params: str, dataloader_name: str, if dataloader_name not in self.dataloaders: raise Exception(f"Dataloader {dataloader_name} is not a valid dataloader. " f"Available dataloaders are: {self.dataloaders}") - if dataloader_name not in self.compatibilities[model_name]: + if dataloader_name not in self.data_compatibilities[model_name]: raise Exception(f"Dataloader {dataloader_name} is not compatible with {model_name}." - f"Compatible dataloaders are: {self.compatibilities[model_name]}") - - dataloader_config = copy.deepcopy(self.dataloaders[dataloader_name]) # Default parameters + f"Compatible dataloaders are: {self.data_compatibilities[model_name]}") + if prediction_dataloader_name and prediction_dataloader_name not in self.prediction_dataloaders: + raise Exception(f"Prediction Dataloader {prediction_dataloader_name} is not a valid dataloader. " + f"Available prediction dataloaders are: {self.prediction_dataloaders}") + if prediction_dataloader_name and prediction_dataloader_name not in self.pred_compatibilities[model_name]: + raise Exception(f"Prediction Dataloader {prediction_dataloader_name} is not compatible with {model_name}." 
+ f"Compatible prediction dataloaders are: {self.pred_compatibilities[model_name]}") + + dataloader_config = copy.deepcopy( + self.dataloaders[dataloader_name]) # Default parameters dataloader_new_config = json.loads(dataset_params) for key in dataloader_new_config.keys(): if key in dataloader_config or key == "location": @@ -76,6 +95,10 @@ def get_mli(self, model_name: str, model_params: str, dataloader_name: str, prepare_data_loaders = FactoryRegistry.dataloaders[dataloader_name][0] data_loaders = prepare_data_loaders(**dataloader_config) + pred_data_loaders = load_all_prediction_data_loaders(self, model_name, + prediction_dataloader_name, + prediction_dataset_params) + model_config = copy.deepcopy(self.models[model_name]) # Default parameters model_new_config = json.loads(model_params) for key in model_new_config.keys(): @@ -88,6 +111,38 @@ def get_mli(self, model_name: str, model_params: str, dataloader_name: str, c = model_config["diff_priv_config"] if c is not None: model_config["diff_priv_config"] = DiffPrivConfig(**c) + prepare_learner = FactoryRegistry.model_architectures[model_name][0] - return prepare_learner(data_loaders=data_loaders, **model_config) + if len(pred_data_loaders) >= 1: + return prepare_learner(data_loaders=data_loaders, prediction_data_loaders=pred_data_loaders, **model_config) + else: + return prepare_learner(data_loaders=data_loaders, **model_config) + + +def load_all_prediction_data_loaders(self, model_name: str, + prediction_dataloader_name=None, + prediction_dataset_params=None): + keys = self.pred_compatibilities[model_name] + pred_dict: Dict[str, Any] = {} + if keys: + for name in keys: + pred_dataloader_config = copy.deepcopy( + self.prediction_dataloaders[name]) # Default parameters + if prediction_dataloader_name and prediction_dataset_params: + pred_dataloader_new_config = json.loads(prediction_dataset_params) + for key in pred_dataloader_new_config.keys(): + if key in pred_dataloader_config or key == "location": + 
pred_dataloader_config[key] = pred_dataloader_new_config[key] + else: + _logger.warning(f"Key {key} was included in the dataloader params but this dataloader " + f"({name}) does not accept it.") + prepare_pred_data_loader = FactoryRegistry.prediction_dataloaders[name][0] + pred_tmp_dict = prepare_pred_data_loader(**pred_dataloader_config) + if prediction_dataloader_name and prediction_dataloader_name == name: + pred_tmp_dict.update(pred_dict) + pred_dict = pred_tmp_dict + else: + pred_dict.update(pred_tmp_dict) + + return pred_dict diff --git a/colearn_grpc/factory_registry.py b/colearn_grpc/factory_registry.py index c4881dcd..21e66ad3 100644 --- a/colearn_grpc/factory_registry.py +++ b/colearn_grpc/factory_registry.py @@ -16,7 +16,7 @@ # # ------------------------------------------------------------------------------ from inspect import signature -from typing import Callable, Dict, Any, List, NamedTuple +from typing import Callable, Dict, Any, List, NamedTuple, Optional class RegistryException(Exception): @@ -42,10 +42,17 @@ class DataloaderDef(NamedTuple): dataloaders: Dict[str, DataloaderDef] = {} + class PredictionDataloaderDef(NamedTuple): + callable: Callable + default_parameters: Dict[str, Any] + + prediction_dataloaders: Dict[str, PredictionDataloaderDef] = {} + class ModelArchitectureDef(NamedTuple): callable: Callable default_parameters: Dict[str, Any] - compatibilities: List[str] + data_compatibilities: List[str] + pred_compatibilities: Optional[List[str]] model_architectures: Dict[str, ModelArchitectureDef] = {} @@ -54,7 +61,8 @@ def register_dataloader(cls, name: str): def wrap(dataloader: Callable): check_dataloader_callable(dataloader) if name in cls.dataloaders: - print(f"Warning: {name} already registered. Replacing with {dataloader.__name__}") + print( + f"Warning: {name} already registered. 
Replacing with {dataloader.__name__}") cls.dataloaders[name] = cls.DataloaderDef( callable=dataloader, default_parameters=_get_defaults(dataloader)) @@ -63,22 +71,42 @@ def wrap(dataloader: Callable): return wrap @classmethod - def register_model_architecture(cls, name: str, compatibilities: List[str]): + def register_prediction_dataloader(cls, name: str): + def wrap(prediction_dataloader: Callable): + check_dataloader_callable(prediction_dataloader) + if name in cls.prediction_dataloaders: + print( + f"Warning: {name} already registered. Replacing with {prediction_dataloader.__name__}") + cls.prediction_dataloaders[name] = cls.PredictionDataloaderDef( + callable=prediction_dataloader, + default_parameters=_get_defaults(prediction_dataloader)) + return prediction_dataloader + + return wrap + + @classmethod + def register_model_architecture(cls, name: str, + data_compatibilities: List[str], + pred_compatibilities: List[str] = None): def wrap(model_arch_creator: Callable): - cls.check_model_callable(model_arch_creator, compatibilities) + cls.check_model_data_callable(model_arch_creator, data_compatibilities) + cls.check_model_prediction_callable( + model_arch_creator, pred_compatibilities) if name in cls.model_architectures: - print(f"Warning: {name} already registered. Replacing with {model_arch_creator.__name__}") + print( + f"Warning: {name} already registered. 
Replacing with {model_arch_creator.__name__}") cls.model_architectures[name] = cls.ModelArchitectureDef( callable=model_arch_creator, default_parameters=_get_defaults(model_arch_creator), - compatibilities=compatibilities) + data_compatibilities=data_compatibilities, + pred_compatibilities=pred_compatibilities) return model_arch_creator return wrap @classmethod - def check_model_callable(cls, to_call: Callable, compatibilities: List[str]): + def check_model_data_callable(cls, to_call: Callable, compatibilities: List[str]): sig = signature(to_call) if "data_loaders" not in sig.parameters: raise RegistryException("model must accept a 'data_loaders' parameter") @@ -88,6 +116,22 @@ def check_model_callable(cls, to_call: Callable, compatibilities: List[str]): raise RegistryException(f"Compatible dataloader {dl} is not registered. The dataloader needs to be " "registered before the model that references it.") dl_type = signature(cls.dataloaders[dl].callable).return_annotation - if not dl_type == model_dl_type: + if dl_type != model_dl_type: raise RegistryException(f"Compatible dataloader {dl} has return type {dl_type}" f" but model data_loaders expects type {model_dl_type}") + + @classmethod + def check_model_prediction_callable(cls, to_call: Callable, compatibilities: List[str] = None): + sig = signature(to_call) + if "prediction_data_loaders" in sig.parameters and compatibilities: + model_dl_type = sig.parameters["prediction_data_loaders"].annotation + for dl in compatibilities: + if dl not in cls.prediction_dataloaders: + raise RegistryException(f"Compatible prediction dataloader {dl} is not registered." 
+ "The dataloader needs to be " + "registered before the model that references it.") + dl_type = signature( + cls.prediction_dataloaders[dl].callable).return_annotation + if dl_type != model_dl_type: + raise RegistryException(f"Compatible prediction dataloader {dl} has return type {dl_type}" + f" but model prediction_data_loaders expects type {model_dl_type}") diff --git a/colearn_grpc/grpc_learner_server.py b/colearn_grpc/grpc_learner_server.py index f9b51615..5bb1f8ff 100644 --- a/colearn_grpc/grpc_learner_server.py +++ b/colearn_grpc/grpc_learner_server.py @@ -21,7 +21,7 @@ from google.protobuf import empty_pb2 import grpc -from colearn.ml_interface import MachineLearningInterface +from colearn.ml_interface import MachineLearningInterface, PredictionRequest from prometheus_client import Counter, Summary import colearn_grpc.proto.generated.interface_pb2 as ipb2 @@ -62,6 +62,8 @@ "This metric measures the time it takes to accept a weight") _time_get = Summary("contract_learner_grpc_server_get_time", "This metric measures the time it takes to get the current weights") +_time_prediction = Summary("contract_learner_grpc_server_prediction_time", + "This metric measures the time it takes to compute a prediction using current weights") class GRPCLearnerServer(ipb2_grpc.GRPCLearnerServicer): @@ -99,11 +101,24 @@ def QuerySupportedSystem(self, request, context): d.name = name d.default_parameters = json.dumps(params) - for model_architecture, data_loaders in self.mli_factory.get_compatibilities().items(): - c = response.compatibilities.add() - c.model_architecture = model_architecture + for name, params in self.mli_factory.get_prediction_dataloaders().items(): + p = response.prediction_data_loaders.add() + p.name = name + p.default_parameters = json.dumps(params) + + for model_architecture, data_loaders in self.mli_factory.get_data_compatibilities().items(): + dc = response.data_compatibilities.add() + dc.model_architecture = model_architecture for dataloader_name in 
data_loaders: - c.dataloaders.append(dataloader_name) + dc.dataloaders.append(dataloader_name) + + pred_compatibilities = self.mli_factory.get_pred_compatibilities() + for model_architecture, predicton_data_loaders in pred_compatibilities.items(): + pc = response.pred_compatibilities.add() + pc.model_architecture = model_architecture + if predicton_data_loaders: + for pred_dataloader_name in predicton_data_loaders: + pc.prediction_dataloaders.append(pred_dataloader_name) except Exception as ex: # pylint: disable=W0703 _logger.exception(f"Exception in QuerySupportedSystem: {ex} {type(ex)}") @@ -122,7 +137,9 @@ def MLSetup(self, request, context): model_name=request.model_arch_name, model_params=request.model_parameters, dataloader_name=request.dataset_loader_name, - dataset_params=request.dataset_loader_parameters + dataset_params=request.dataset_loader_parameters, + prediction_dataloader_name=request.prediction_dataset_loader_name, + prediction_dataset_params=request.prediction_dataset_loader_parameters ) _logger.debug("ML MODEL CREATED") if self.learner is not None: @@ -185,8 +202,9 @@ def TestWeights(self, request_iterator, context): weights = iterator_to_weights(request_iterator) proposed_weights = self.learner.mli_test_weights(weights) - pw.vote_score = proposed_weights.vote_score - pw.test_score = proposed_weights.test_score + pw.vote_score.update(proposed_weights.vote_score) + pw.test_score.update(proposed_weights.test_score) + pw.criterion = proposed_weights.criterion if proposed_weights.vote is not None: pw.vote = proposed_weights.vote _logger.debug("Testing done!") @@ -261,3 +279,33 @@ def GetCurrentModel(self, request, context): response.model = current_model.model.SerializeToString() return response + + @_time_prediction.time() + def MakePrediction(self, request, context): + response = ipb2.PredictionResponse() + _logger.info(f"Got Prediction request: {request}") + pred_data_loaders = self.learner.get_prediction_data_loaders() + + if 
request.pred_dataloader_key: + pred_func = pred_data_loaders[request.pred_dataloader_key] + else: + # Get first in list as default + pred_key = list(pred_data_loaders.keys())[0] + pred_func = pred_data_loaders[pred_key] + img = pred_func(request.input_data.decode("utf-8")) + + if self.learner is not None: + self._learner_mutex.acquire() # TODO(LR) is the mutex needed here? + _logger.debug(f"Computing prediction: {request.name}") + prediction_req = PredictionRequest( + name=request.name, + input_data=img.tobytes(), + ) + prediction = self.learner.mli_make_prediction(prediction_req) + _logger.debug(f"Prediction {request.name} computed successfully") + response.name = request.name + response.prediction_data = bytes(prediction.prediction_data) + self._learner_mutex.release() + + _logger.debug(f"Sending Prediction Response: {response}") + return response diff --git a/colearn_grpc/mli_factory_interface.py b/colearn_grpc/mli_factory_interface.py index f6b5fbec..f521065e 100644 --- a/colearn_grpc/mli_factory_interface.py +++ b/colearn_grpc/mli_factory_interface.py @@ -16,7 +16,7 @@ # # ------------------------------------------------------------------------------ import abc -from typing import Dict, Set, Any +from typing import Dict, Set, Any, Optional import os.path from pkg_resources import get_distribution, DistributionNotFound @@ -66,7 +66,15 @@ def get_dataloaders(self) -> Dict[str, Dict[str, Any]]: pass @abc.abstractmethod - def get_compatibilities(self) -> Dict[str, Set[str]]: + def get_prediction_dataloaders(self) -> Dict[str, Dict[str, Any]]: + """ + Returns the prediction dataloaders this factory produces. + The key is the name of the dataloader and the values are their default parameters + """ + pass + + @abc.abstractmethod + def get_data_compatibilities(self) -> Dict[str, Set[str]]: """ A model is compatible with a dataloader if they can be used together to construct a MachineLearningInterface with the get_MLI function. 
@@ -76,17 +84,34 @@ def get_compatibilities(self) -> Dict[str, Set[str]]: """ pass + @abc.abstractmethod + def get_pred_compatibilities(self) -> Dict[str, Set[str]]: + """ + A model is compatible with a prediction dataloader if they can be used together to + construct a MachineLearningInterface with the get_MLI function. + + Returns a dictionary that defines which model is compatible + with which prediction dataloader. + """ + pass + @abc.abstractmethod def get_mli(self, model_name: str, model_params: str, - dataloader_name: str, dataset_params: str) -> MachineLearningInterface: + dataloader_name: str, dataset_params: str, + prediction_dataloader_name: Optional[str], + prediction_dataset_params: Optional[str]) -> MachineLearningInterface: """ @param model_name: name of a model, must be in the set return by get_models @param model_params: user defined parameters for the model @param dataloader_name: name of a dataloader to be used: - must be in the set returned by get_dataloaders - - must be compatible with model_name as defined by get_compatibilities + - must be compatible with model_name as defined by get_data_compatibilities @param dataset_params: user defined parameters for the dataset + @param prediction_dataloader_name: name of a prediction dataloader to be used: + - must be in the set returned by get_prediction_dataloaders + - must be compatible with model_name as defined by get_pred_compatibilities + @param prediction_dataset_params: user defined parameters for the prediction and preprocessing @return: Instance of MachineLearningInterface Constructs an object that implements MachineLearningInterface whose underlying model is model_name and dataset is loaded by dataloader_name. 
diff --git a/colearn_grpc/proto/generated/interface_pb2.py b/colearn_grpc/proto/generated/interface_pb2.py index 113e4128..a9990c04 100644 --- a/colearn_grpc/proto/generated/interface_pb2.py +++ b/colearn_grpc/proto/generated/interface_pb2.py @@ -21,7 +21,7 @@ syntax='proto3', serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x0finterface.proto\x12\x13\x63ontract_learn.grpc\x1a\x1bgoogle/protobuf/empty.proto\"\x83\x01\n\x0eRequestMLSetup\x12\x1b\n\x13\x64\x61taset_loader_name\x18\x01 \x01(\t\x12!\n\x19\x64\x61taset_loader_parameters\x18\x02 \x01(\t\x12\x17\n\x0fmodel_arch_name\x18\x03 \x01(\t\x12\x18\n\x10model_parameters\x18\x04 \x01(\t\"Z\n\x0fResponseMLSetup\x12\x32\n\x06status\x18\x01 \x01(\x0e\x32\".contract_learn.grpc.MLSetupStatus\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\"p\n\x0e\x44iffPrivBudget\x12\x16\n\x0etarget_epsilon\x18\x01 \x01(\x02\x12\x14\n\x0ctarget_delta\x18\x02 \x01(\x02\x12\x18\n\x10\x63onsumed_epsilon\x18\x03 \x01(\x02\x12\x16\n\x0e\x63onsumed_delta\x18\x04 \x01(\x02\"I\n\x0fTrainingSummary\x12\x36\n\tdp_budget\x18\x01 \x01(\x0b\x32#.contract_learn.grpc.DiffPrivBudget\"\x87\x01\n\x0bWeightsPart\x12\x0f\n\x07weights\x18\x01 \x01(\x0c\x12\x12\n\nbyte_index\x18\x02 \x01(\r\x12\x13\n\x0btotal_bytes\x18\x03 \x01(\x04\x12>\n\x10training_summary\x18\n \x01(\x0b\x32$.contract_learn.grpc.TrainingSummary\"G\n\x0fProposedWeights\x12\x12\n\nvote_score\x18\x01 \x01(\x02\x12\x12\n\ntest_score\x18\x02 \x01(\x02\x12\x0c\n\x04vote\x18\x03 \x01(\x08\"\x0f\n\rRequestStatus\"C\n\x0eResponseStatus\x12\x31\n\x06status\x18\x01 \x01(\x0e\x32!.contract_learn.grpc.SystemStatus\"=\n\x11\x44\x61tasetLoaderSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1a\n\x12\x64\x65\x66\x61ult_parameters\x18\x02 \x01(\t\"9\n\rModelArchSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1a\n\x12\x64\x65\x66\x61ult_parameters\x18\x02 \x01(\t\"D\n\x11\x43ompatibilitySpec\x12\x1a\n\x12model_architecture\x18\x01 
\x01(\t\x12\x13\n\x0b\x64\x61taloaders\x18\x02 \x03(\t\"\"\n\x0fResponseVersion\x12\x0f\n\x07version\x18\x01 \x01(\t\"O\n\x14ResponseCurrentModel\x12\x14\n\x0cmodel_format\x18\x01 \x01(\r\x12\x12\n\nmodel_file\x18\x02 \x01(\t\x12\r\n\x05model\x18\x03 \x01(\x0c\"\xd9\x01\n\x17ResponseSupportedSystem\x12<\n\x0c\x64\x61ta_loaders\x18\x01 \x03(\x0b\x32&.contract_learn.grpc.DatasetLoaderSpec\x12?\n\x13model_architectures\x18\x02 \x03(\x0b\x32\".contract_learn.grpc.ModelArchSpec\x12?\n\x0f\x63ompatibilities\x18\x03 \x03(\x0b\x32&.contract_learn.grpc.CompatibilitySpec*6\n\rMLSetupStatus\x12\r\n\tUNDEFINED\x10\x00\x12\x0b\n\x07SUCCESS\x10\x01\x12\t\n\x05\x45RROR\x10\x02*J\n\x0cSystemStatus\x12\x0b\n\x07WORKING\x10\x00\x12\x0c\n\x08NO_MODEL\x10\x01\x12\x12\n\x0eINTERNAL_ERROR\x10\x02\x12\x0b\n\x07UNKNOWN\x10\x03\x32\x84\x06\n\x0bGRPCLearner\x12L\n\x0cQueryVersion\x12\x16.google.protobuf.Empty\x1a$.contract_learn.grpc.ResponseVersion\x12\\\n\x14QuerySupportedSystem\x12\x16.google.protobuf.Empty\x1a,.contract_learn.grpc.ResponseSupportedSystem\x12T\n\x0fGetCurrentModel\x12\x16.google.protobuf.Empty\x1a).contract_learn.grpc.ResponseCurrentModel\x12T\n\x07MLSetup\x12#.contract_learn.grpc.RequestMLSetup\x1a$.contract_learn.grpc.ResponseMLSetup\x12L\n\x0eProposeWeights\x12\x16.google.protobuf.Empty\x1a .contract_learn.grpc.WeightsPart0\x01\x12W\n\x0bTestWeights\x12 .contract_learn.grpc.WeightsPart\x1a$.contract_learn.grpc.ProposedWeights(\x01\x12H\n\nSetWeights\x12 .contract_learn.grpc.WeightsPart\x1a\x16.google.protobuf.Empty(\x01\x12O\n\x11GetCurrentWeights\x12\x16.google.protobuf.Empty\x1a .contract_learn.grpc.WeightsPart0\x01\x12[\n\x0cStatusStream\x12\".contract_learn.grpc.RequestStatus\x1a#.contract_learn.grpc.ResponseStatus(\x01\x30\x01\x62\x06proto3' + serialized_pb=b'\n\x0finterface.proto\x12\x13\x63ontract_learn.grpc\x1a\x1bgoogle/protobuf/empty.proto\"\xaf\x02\n\x0eRequestMLSetup\x12\x1b\n\x13\x64\x61taset_loader_name\x18\x01 
\x01(\t\x12!\n\x19\x64\x61taset_loader_parameters\x18\x02 \x01(\t\x12\x17\n\x0fmodel_arch_name\x18\x03 \x01(\t\x12\x18\n\x10model_parameters\x18\x04 \x01(\t\x12+\n\x1eprediction_dataset_loader_name\x18\x05 \x01(\tH\x00\x88\x01\x01\x12\x31\n$prediction_dataset_loader_parameters\x18\x06 \x01(\tH\x01\x88\x01\x01\x42!\n\x1f_prediction_dataset_loader_nameB\'\n%_prediction_dataset_loader_parameters\"Z\n\x0fResponseMLSetup\x12\x32\n\x06status\x18\x01 \x01(\x0e\x32\".contract_learn.grpc.MLSetupStatus\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\"p\n\x0e\x44iffPrivBudget\x12\x16\n\x0etarget_epsilon\x18\x01 \x01(\x02\x12\x14\n\x0ctarget_delta\x18\x02 \x01(\x02\x12\x18\n\x10\x63onsumed_epsilon\x18\x03 \x01(\x02\x12\x16\n\x0e\x63onsumed_delta\x18\x04 \x01(\x02\"I\n\x0fTrainingSummary\x12\x36\n\tdp_budget\x18\x01 \x01(\x0b\x32#.contract_learn.grpc.DiffPrivBudget\"\x87\x01\n\x0bWeightsPart\x12\x0f\n\x07weights\x18\x01 \x01(\x0c\x12\x12\n\nbyte_index\x18\x02 \x01(\r\x12\x13\n\x0btotal_bytes\x18\x03 \x01(\x04\x12>\n\x10training_summary\x18\n \x01(\x0b\x32$.contract_learn.grpc.TrainingSummary\"\xa8\x02\n\x0fProposedWeights\x12G\n\nvote_score\x18\x01 \x03(\x0b\x32\x33.contract_learn.grpc.ProposedWeights.VoteScoreEntry\x12G\n\ntest_score\x18\x02 \x03(\x0b\x32\x33.contract_learn.grpc.ProposedWeights.TestScoreEntry\x12\x0c\n\x04vote\x18\x03 \x01(\x08\x12\x11\n\tcriterion\x18\x04 \x01(\t\x1a\x30\n\x0eVoteScoreEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x1a\x30\n\x0eTestScoreEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\"\x0f\n\rRequestStatus\"C\n\x0eResponseStatus\x12\x31\n\x06status\x18\x01 \x01(\x0e\x32!.contract_learn.grpc.SystemStatus\"=\n\x11\x44\x61tasetLoaderSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1a\n\x12\x64\x65\x66\x61ult_parameters\x18\x02 \x01(\t\"G\n\x1bPredictionDatasetLoaderSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1a\n\x12\x64\x65\x66\x61ult_parameters\x18\x02 
\x01(\t\"9\n\rModelArchSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1a\n\x12\x64\x65\x66\x61ult_parameters\x18\x02 \x01(\t\"H\n\x15\x44\x61taCompatibilitySpec\x12\x1a\n\x12model_architecture\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x61taloaders\x18\x02 \x03(\t\"\\\n\x1ePredictonDataCompatibilitySpec\x12\x1a\n\x12model_architecture\x18\x01 \x01(\t\x12\x1e\n\x16prediction_dataloaders\x18\x02 \x03(\t\"\"\n\x0fResponseVersion\x12\x0f\n\x07version\x18\x01 \x01(\t\"O\n\x14ResponseCurrentModel\x12\x14\n\x0cmodel_format\x18\x01 \x01(\r\x12\x12\n\nmodel_file\x18\x02 \x01(\t\x12\r\n\x05model\x18\x03 \x01(\x0c\"\x88\x03\n\x17ResponseSupportedSystem\x12<\n\x0c\x64\x61ta_loaders\x18\x01 \x03(\x0b\x32&.contract_learn.grpc.DatasetLoaderSpec\x12Q\n\x17prediction_data_loaders\x18\x02 \x03(\x0b\x32\x30.contract_learn.grpc.PredictionDatasetLoaderSpec\x12?\n\x13model_architectures\x18\x03 \x03(\x0b\x32\".contract_learn.grpc.ModelArchSpec\x12H\n\x14\x64\x61ta_compatibilities\x18\x04 \x03(\x0b\x32*.contract_learn.grpc.DataCompatibilitySpec\x12Q\n\x14pred_compatibilities\x18\x05 \x03(\x0b\x32\x33.contract_learn.grpc.PredictonDataCompatibilitySpec\"o\n\x11PredictionRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\ninput_data\x18\x02 \x01(\x0c\x12 \n\x13pred_dataloader_key\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x16\n\x14_pred_dataloader_key\";\n\x12PredictionResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x17\n\x0fprediction_data\x18\x02 
\x01(\x0c*6\n\rMLSetupStatus\x12\r\n\tUNDEFINED\x10\x00\x12\x0b\n\x07SUCCESS\x10\x01\x12\t\n\x05\x45RROR\x10\x02*J\n\x0cSystemStatus\x12\x0b\n\x07WORKING\x10\x00\x12\x0c\n\x08NO_MODEL\x10\x01\x12\x12\n\x0eINTERNAL_ERROR\x10\x02\x12\x0b\n\x07UNKNOWN\x10\x03\x32\xe7\x06\n\x0bGRPCLearner\x12L\n\x0cQueryVersion\x12\x16.google.protobuf.Empty\x1a$.contract_learn.grpc.ResponseVersion\x12\\\n\x14QuerySupportedSystem\x12\x16.google.protobuf.Empty\x1a,.contract_learn.grpc.ResponseSupportedSystem\x12T\n\x0fGetCurrentModel\x12\x16.google.protobuf.Empty\x1a).contract_learn.grpc.ResponseCurrentModel\x12T\n\x07MLSetup\x12#.contract_learn.grpc.RequestMLSetup\x1a$.contract_learn.grpc.ResponseMLSetup\x12L\n\x0eProposeWeights\x12\x16.google.protobuf.Empty\x1a .contract_learn.grpc.WeightsPart0\x01\x12W\n\x0bTestWeights\x12 .contract_learn.grpc.WeightsPart\x1a$.contract_learn.grpc.ProposedWeights(\x01\x12H\n\nSetWeights\x12 .contract_learn.grpc.WeightsPart\x1a\x16.google.protobuf.Empty(\x01\x12O\n\x11GetCurrentWeights\x12\x16.google.protobuf.Empty\x1a .contract_learn.grpc.WeightsPart0\x01\x12[\n\x0cStatusStream\x12\".contract_learn.grpc.RequestStatus\x1a#.contract_learn.grpc.ResponseStatus(\x01\x30\x01\x12\x61\n\x0eMakePrediction\x12&.contract_learn.grpc.PredictionRequest\x1a\'.contract_learn.grpc.PredictionResponseb\x06proto3' , dependencies=[google_dot_protobuf_dot_empty__pb2.DESCRIPTOR,]) @@ -50,8 +50,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1310, - serialized_end=1364, + serialized_start=2228, + serialized_end=2282, ) _sym_db.RegisterEnumDescriptor(_MLSETUPSTATUS) @@ -86,8 +86,8 @@ ], containing_type=None, serialized_options=None, - serialized_start=1366, - serialized_end=1440, + serialized_start=2284, + serialized_end=2358, ) _sym_db.RegisterEnumDescriptor(_SYSTEMSTATUS) @@ -138,6 +138,20 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='prediction_dataset_loader_name', full_name='contract_learn.grpc.RequestMLSetup.prediction_dataset_loader_name', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='prediction_dataset_loader_parameters', full_name='contract_learn.grpc.RequestMLSetup.prediction_dataset_loader_parameters', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], @@ -149,9 +163,19 @@ syntax='proto3', extension_ranges=[], oneofs=[ + _descriptor.OneofDescriptor( + name='_prediction_dataset_loader_name', full_name='contract_learn.grpc.RequestMLSetup._prediction_dataset_loader_name', + index=0, containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + _descriptor.OneofDescriptor( + name='_prediction_dataset_loader_parameters', full_name='contract_learn.grpc.RequestMLSetup._prediction_dataset_loader_parameters', + index=1, containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), ], serialized_start=70, - serialized_end=201, + serialized_end=373, ) @@ -189,8 +213,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=203, - serialized_end=293, + serialized_start=375, + serialized_end=465, ) @@ -242,8 +266,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=295, - serialized_end=407, + serialized_start=467, + serialized_end=579, ) @@ -274,8 +298,8 @@ extension_ranges=[], oneofs=[ ], - 
serialized_start=409, - serialized_end=482, + serialized_start=581, + serialized_end=654, ) @@ -327,10 +351,86 @@ extension_ranges=[], oneofs=[ ], - serialized_start=485, - serialized_end=620, + serialized_start=657, + serialized_end=792, +) + + +_PROPOSEDWEIGHTS_VOTESCOREENTRY = _descriptor.Descriptor( + name='VoteScoreEntry', + full_name='contract_learn.grpc.ProposedWeights.VoteScoreEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='contract_learn.grpc.ProposedWeights.VoteScoreEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='value', full_name='contract_learn.grpc.ProposedWeights.VoteScoreEntry.value', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=993, + serialized_end=1041, ) +_PROPOSEDWEIGHTS_TESTSCOREENTRY = _descriptor.Descriptor( + name='TestScoreEntry', + full_name='contract_learn.grpc.ProposedWeights.TestScoreEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='contract_learn.grpc.ProposedWeights.TestScoreEntry.key', index=0, + number=1, type=9, cpp_type=9, label=1, 
+ has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='value', full_name='contract_learn.grpc.ProposedWeights.TestScoreEntry.value', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=b'8\001', + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1043, + serialized_end=1091, +) _PROPOSEDWEIGHTS = _descriptor.Descriptor( name='ProposedWeights', @@ -342,15 +442,15 @@ fields=[ _descriptor.FieldDescriptor( name='vote_score', full_name='contract_learn.grpc.ProposedWeights.vote_score', index=0, - number=1, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( name='test_score', full_name='contract_learn.grpc.ProposedWeights.test_score', index=1, - number=2, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), @@ -361,10 +461,17 @@ message_type=None, 
enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='criterion', full_name='contract_learn.grpc.ProposedWeights.criterion', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], - nested_types=[], + nested_types=[_PROPOSEDWEIGHTS_VOTESCOREENTRY, _PROPOSEDWEIGHTS_TESTSCOREENTRY, ], enum_types=[ ], serialized_options=None, @@ -373,8 +480,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=622, - serialized_end=693, + serialized_start=795, + serialized_end=1091, ) @@ -398,8 +505,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=695, - serialized_end=710, + serialized_start=1093, + serialized_end=1108, ) @@ -430,8 +537,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=712, - serialized_end=779, + serialized_start=1110, + serialized_end=1177, ) @@ -469,8 +576,47 @@ extension_ranges=[], oneofs=[ ], - serialized_start=781, - serialized_end=842, + serialized_start=1179, + serialized_end=1240, +) + + +_PREDICTIONDATASETLOADERSPEC = _descriptor.Descriptor( + name='PredictionDatasetLoaderSpec', + full_name='contract_learn.grpc.PredictionDatasetLoaderSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='contract_learn.grpc.PredictionDatasetLoaderSpec.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='default_parameters', full_name='contract_learn.grpc.PredictionDatasetLoaderSpec.default_parameters', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1242, + serialized_end=1313, ) @@ -508,28 +654,67 @@ extension_ranges=[], oneofs=[ ], - serialized_start=844, - serialized_end=901, + serialized_start=1315, + serialized_end=1372, +) + + +_DATACOMPATIBILITYSPEC = _descriptor.Descriptor( + name='DataCompatibilitySpec', + full_name='contract_learn.grpc.DataCompatibilitySpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='model_architecture', full_name='contract_learn.grpc.DataCompatibilitySpec.model_architecture', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='dataloaders', full_name='contract_learn.grpc.DataCompatibilitySpec.dataloaders', index=1, + number=2, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + 
nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1374, + serialized_end=1446, ) -_COMPATIBILITYSPEC = _descriptor.Descriptor( - name='CompatibilitySpec', - full_name='contract_learn.grpc.CompatibilitySpec', +_PREDICTONDATACOMPATIBILITYSPEC = _descriptor.Descriptor( + name='PredictonDataCompatibilitySpec', + full_name='contract_learn.grpc.PredictonDataCompatibilitySpec', filename=None, file=DESCRIPTOR, containing_type=None, create_key=_descriptor._internal_create_key, fields=[ _descriptor.FieldDescriptor( - name='model_architecture', full_name='contract_learn.grpc.CompatibilitySpec.model_architecture', index=0, + name='model_architecture', full_name='contract_learn.grpc.PredictonDataCompatibilitySpec.model_architecture', index=0, number=1, type=9, cpp_type=9, label=1, has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( - name='dataloaders', full_name='contract_learn.grpc.CompatibilitySpec.dataloaders', index=1, + name='prediction_dataloaders', full_name='contract_learn.grpc.PredictonDataCompatibilitySpec.prediction_dataloaders', index=1, number=2, type=9, cpp_type=9, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, @@ -547,8 +732,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=903, - serialized_end=971, + serialized_start=1448, + serialized_end=1540, ) @@ -579,8 +764,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=973, - serialized_end=1007, + serialized_start=1542, + serialized_end=1576, ) @@ -625,8 +810,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1009, - serialized_end=1088, + serialized_start=1578, + serialized_end=1657, ) @@ -646,19 
+831,123 @@ is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( - name='model_architectures', full_name='contract_learn.grpc.ResponseSupportedSystem.model_architectures', index=1, + name='prediction_data_loaders', full_name='contract_learn.grpc.ResponseSupportedSystem.prediction_data_loaders', index=1, number=2, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), _descriptor.FieldDescriptor( - name='compatibilities', full_name='contract_learn.grpc.ResponseSupportedSystem.compatibilities', index=2, + name='model_architectures', full_name='contract_learn.grpc.ResponseSupportedSystem.model_architectures', index=2, number=3, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='data_compatibilities', full_name='contract_learn.grpc.ResponseSupportedSystem.data_compatibilities', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='pred_compatibilities', full_name='contract_learn.grpc.ResponseSupportedSystem.pred_compatibilities', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + 
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1660, + serialized_end=2052, +) + + +_PREDICTIONREQUEST = _descriptor.Descriptor( + name='PredictionRequest', + full_name='contract_learn.grpc.PredictionRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='contract_learn.grpc.PredictionRequest.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='input_data', full_name='contract_learn.grpc.PredictionRequest.input_data', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='pred_dataloader_key', full_name='contract_learn.grpc.PredictionRequest.pred_dataloader_key', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + 
_descriptor.OneofDescriptor( + name='_pred_dataloader_key', full_name='contract_learn.grpc.PredictionRequest._pred_dataloader_key', + index=0, containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[]), + ], + serialized_start=2054, + serialized_end=2165, +) + + +_PREDICTIONRESPONSE = _descriptor.Descriptor( + name='PredictionResponse', + full_name='contract_learn.grpc.PredictionResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='contract_learn.grpc.PredictionResponse.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='prediction_data', full_name='contract_learn.grpc.PredictionResponse.prediction_data', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ], extensions=[ ], @@ -671,17 +960,32 @@ extension_ranges=[], oneofs=[ ], - serialized_start=1091, - serialized_end=1308, + serialized_start=2167, + serialized_end=2226, ) +_REQUESTMLSETUP.oneofs_by_name['_prediction_dataset_loader_name'].fields.append( + _REQUESTMLSETUP.fields_by_name['prediction_dataset_loader_name']) +_REQUESTMLSETUP.fields_by_name['prediction_dataset_loader_name'].containing_oneof = _REQUESTMLSETUP.oneofs_by_name['_prediction_dataset_loader_name'] +_REQUESTMLSETUP.oneofs_by_name['_prediction_dataset_loader_parameters'].fields.append( + _REQUESTMLSETUP.fields_by_name['prediction_dataset_loader_parameters']) 
+_REQUESTMLSETUP.fields_by_name['prediction_dataset_loader_parameters'].containing_oneof = _REQUESTMLSETUP.oneofs_by_name['_prediction_dataset_loader_parameters'] _RESPONSEMLSETUP.fields_by_name['status'].enum_type = _MLSETUPSTATUS _TRAININGSUMMARY.fields_by_name['dp_budget'].message_type = _DIFFPRIVBUDGET _WEIGHTSPART.fields_by_name['training_summary'].message_type = _TRAININGSUMMARY +_PROPOSEDWEIGHTS_VOTESCOREENTRY.containing_type = _PROPOSEDWEIGHTS +_PROPOSEDWEIGHTS_TESTSCOREENTRY.containing_type = _PROPOSEDWEIGHTS +_PROPOSEDWEIGHTS.fields_by_name['vote_score'].message_type = _PROPOSEDWEIGHTS_VOTESCOREENTRY +_PROPOSEDWEIGHTS.fields_by_name['test_score'].message_type = _PROPOSEDWEIGHTS_TESTSCOREENTRY _RESPONSESTATUS.fields_by_name['status'].enum_type = _SYSTEMSTATUS _RESPONSESUPPORTEDSYSTEM.fields_by_name['data_loaders'].message_type = _DATASETLOADERSPEC +_RESPONSESUPPORTEDSYSTEM.fields_by_name['prediction_data_loaders'].message_type = _PREDICTIONDATASETLOADERSPEC _RESPONSESUPPORTEDSYSTEM.fields_by_name['model_architectures'].message_type = _MODELARCHSPEC -_RESPONSESUPPORTEDSYSTEM.fields_by_name['compatibilities'].message_type = _COMPATIBILITYSPEC +_RESPONSESUPPORTEDSYSTEM.fields_by_name['data_compatibilities'].message_type = _DATACOMPATIBILITYSPEC +_RESPONSESUPPORTEDSYSTEM.fields_by_name['pred_compatibilities'].message_type = _PREDICTONDATACOMPATIBILITYSPEC +_PREDICTIONREQUEST.oneofs_by_name['_pred_dataloader_key'].fields.append( + _PREDICTIONREQUEST.fields_by_name['pred_dataloader_key']) +_PREDICTIONREQUEST.fields_by_name['pred_dataloader_key'].containing_oneof = _PREDICTIONREQUEST.oneofs_by_name['_pred_dataloader_key'] DESCRIPTOR.message_types_by_name['RequestMLSetup'] = _REQUESTMLSETUP DESCRIPTOR.message_types_by_name['ResponseMLSetup'] = _RESPONSEMLSETUP DESCRIPTOR.message_types_by_name['DiffPrivBudget'] = _DIFFPRIVBUDGET @@ -691,11 +995,15 @@ DESCRIPTOR.message_types_by_name['RequestStatus'] = _REQUESTSTATUS 
DESCRIPTOR.message_types_by_name['ResponseStatus'] = _RESPONSESTATUS DESCRIPTOR.message_types_by_name['DatasetLoaderSpec'] = _DATASETLOADERSPEC +DESCRIPTOR.message_types_by_name['PredictionDatasetLoaderSpec'] = _PREDICTIONDATASETLOADERSPEC DESCRIPTOR.message_types_by_name['ModelArchSpec'] = _MODELARCHSPEC -DESCRIPTOR.message_types_by_name['CompatibilitySpec'] = _COMPATIBILITYSPEC +DESCRIPTOR.message_types_by_name['DataCompatibilitySpec'] = _DATACOMPATIBILITYSPEC +DESCRIPTOR.message_types_by_name['PredictonDataCompatibilitySpec'] = _PREDICTONDATACOMPATIBILITYSPEC DESCRIPTOR.message_types_by_name['ResponseVersion'] = _RESPONSEVERSION DESCRIPTOR.message_types_by_name['ResponseCurrentModel'] = _RESPONSECURRENTMODEL DESCRIPTOR.message_types_by_name['ResponseSupportedSystem'] = _RESPONSESUPPORTEDSYSTEM +DESCRIPTOR.message_types_by_name['PredictionRequest'] = _PREDICTIONREQUEST +DESCRIPTOR.message_types_by_name['PredictionResponse'] = _PREDICTIONRESPONSE DESCRIPTOR.enum_types_by_name['MLSetupStatus'] = _MLSETUPSTATUS DESCRIPTOR.enum_types_by_name['SystemStatus'] = _SYSTEMSTATUS _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -736,11 +1044,27 @@ _sym_db.RegisterMessage(WeightsPart) ProposedWeights = _reflection.GeneratedProtocolMessageType('ProposedWeights', (_message.Message,), { + + 'VoteScoreEntry' : _reflection.GeneratedProtocolMessageType('VoteScoreEntry', (_message.Message,), { + 'DESCRIPTOR' : _PROPOSEDWEIGHTS_VOTESCOREENTRY, + '__module__' : 'interface_pb2' + # @@protoc_insertion_point(class_scope:contract_learn.grpc.ProposedWeights.VoteScoreEntry) + }) + , + + 'TestScoreEntry' : _reflection.GeneratedProtocolMessageType('TestScoreEntry', (_message.Message,), { + 'DESCRIPTOR' : _PROPOSEDWEIGHTS_TESTSCOREENTRY, + '__module__' : 'interface_pb2' + # @@protoc_insertion_point(class_scope:contract_learn.grpc.ProposedWeights.TestScoreEntry) + }) + , 'DESCRIPTOR' : _PROPOSEDWEIGHTS, '__module__' : 'interface_pb2' # 
@@protoc_insertion_point(class_scope:contract_learn.grpc.ProposedWeights) }) _sym_db.RegisterMessage(ProposedWeights) +_sym_db.RegisterMessage(ProposedWeights.VoteScoreEntry) +_sym_db.RegisterMessage(ProposedWeights.TestScoreEntry) RequestStatus = _reflection.GeneratedProtocolMessageType('RequestStatus', (_message.Message,), { 'DESCRIPTOR' : _REQUESTSTATUS, @@ -763,6 +1087,13 @@ }) _sym_db.RegisterMessage(DatasetLoaderSpec) +PredictionDatasetLoaderSpec = _reflection.GeneratedProtocolMessageType('PredictionDatasetLoaderSpec', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTIONDATASETLOADERSPEC, + '__module__' : 'interface_pb2' + # @@protoc_insertion_point(class_scope:contract_learn.grpc.PredictionDatasetLoaderSpec) + }) +_sym_db.RegisterMessage(PredictionDatasetLoaderSpec) + ModelArchSpec = _reflection.GeneratedProtocolMessageType('ModelArchSpec', (_message.Message,), { 'DESCRIPTOR' : _MODELARCHSPEC, '__module__' : 'interface_pb2' @@ -770,12 +1101,19 @@ }) _sym_db.RegisterMessage(ModelArchSpec) -CompatibilitySpec = _reflection.GeneratedProtocolMessageType('CompatibilitySpec', (_message.Message,), { - 'DESCRIPTOR' : _COMPATIBILITYSPEC, +DataCompatibilitySpec = _reflection.GeneratedProtocolMessageType('DataCompatibilitySpec', (_message.Message,), { + 'DESCRIPTOR' : _DATACOMPATIBILITYSPEC, + '__module__' : 'interface_pb2' + # @@protoc_insertion_point(class_scope:contract_learn.grpc.DataCompatibilitySpec) + }) +_sym_db.RegisterMessage(DataCompatibilitySpec) + +PredictonDataCompatibilitySpec = _reflection.GeneratedProtocolMessageType('PredictonDataCompatibilitySpec', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTONDATACOMPATIBILITYSPEC, '__module__' : 'interface_pb2' - # @@protoc_insertion_point(class_scope:contract_learn.grpc.CompatibilitySpec) + # @@protoc_insertion_point(class_scope:contract_learn.grpc.PredictonDataCompatibilitySpec) }) -_sym_db.RegisterMessage(CompatibilitySpec) +_sym_db.RegisterMessage(PredictonDataCompatibilitySpec) ResponseVersion = 
_reflection.GeneratedProtocolMessageType('ResponseVersion', (_message.Message,), { 'DESCRIPTOR' : _RESPONSEVERSION, @@ -798,7 +1136,23 @@ }) _sym_db.RegisterMessage(ResponseSupportedSystem) +PredictionRequest = _reflection.GeneratedProtocolMessageType('PredictionRequest', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTIONREQUEST, + '__module__' : 'interface_pb2' + # @@protoc_insertion_point(class_scope:contract_learn.grpc.PredictionRequest) + }) +_sym_db.RegisterMessage(PredictionRequest) +PredictionResponse = _reflection.GeneratedProtocolMessageType('PredictionResponse', (_message.Message,), { + 'DESCRIPTOR' : _PREDICTIONRESPONSE, + '__module__' : 'interface_pb2' + # @@protoc_insertion_point(class_scope:contract_learn.grpc.PredictionResponse) + }) +_sym_db.RegisterMessage(PredictionResponse) + + +_PROPOSEDWEIGHTS_VOTESCOREENTRY._options = None +_PROPOSEDWEIGHTS_TESTSCOREENTRY._options = None _GRPCLEARNER = _descriptor.ServiceDescriptor( name='GRPCLearner', @@ -807,8 +1161,8 @@ index=0, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=1443, - serialized_end=2215, + serialized_start=2361, + serialized_end=3232, methods=[ _descriptor.MethodDescriptor( name='QueryVersion', @@ -900,6 +1254,16 @@ serialized_options=None, create_key=_descriptor._internal_create_key, ), + _descriptor.MethodDescriptor( + name='MakePrediction', + full_name='contract_learn.grpc.GRPCLearner.MakePrediction', + index=9, + containing_service=None, + input_type=_PREDICTIONREQUEST, + output_type=_PREDICTIONRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), ]) _sym_db.RegisterServiceDescriptor(_GRPCLEARNER) diff --git a/colearn_grpc/proto/generated/interface_pb2_grpc.py b/colearn_grpc/proto/generated/interface_pb2_grpc.py index 97ac7266..483aa5d7 100644 --- a/colearn_grpc/proto/generated/interface_pb2_grpc.py +++ b/colearn_grpc/proto/generated/interface_pb2_grpc.py @@ -60,6 +60,11 @@ def __init__(self, channel): 
request_serializer=interface__pb2.RequestStatus.SerializeToString, response_deserializer=interface__pb2.ResponseStatus.FromString, ) + self.MakePrediction = channel.unary_unary( + '/contract_learn.grpc.GRPCLearner/MakePrediction', + request_serializer=interface__pb2.PredictionRequest.SerializeToString, + response_deserializer=interface__pb2.PredictionResponse.FromString, + ) class GRPCLearnerServicer(object): @@ -119,6 +124,12 @@ def StatusStream(self, request_iterator, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def MakePrediction(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_GRPCLearnerServicer_to_server(servicer, server): rpc_method_handlers = { @@ -167,6 +178,11 @@ def add_GRPCLearnerServicer_to_server(servicer, server): request_deserializer=interface__pb2.RequestStatus.FromString, response_serializer=interface__pb2.ResponseStatus.SerializeToString, ), + 'MakePrediction': grpc.unary_unary_rpc_method_handler( + servicer.MakePrediction, + request_deserializer=interface__pb2.PredictionRequest.FromString, + response_serializer=interface__pb2.PredictionResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'contract_learn.grpc.GRPCLearner', rpc_method_handlers) @@ -329,3 +345,20 @@ def StatusStream(request_iterator, interface__pb2.ResponseStatus.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def MakePrediction(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, 
'/contract_learn.grpc.GRPCLearner/MakePrediction', + interface__pb2.PredictionRequest.SerializeToString, + interface__pb2.PredictionResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/colearn_grpc/proto/interface.proto b/colearn_grpc/proto/interface.proto index 751ccd0a..78a3a9d2 100644 --- a/colearn_grpc/proto/interface.proto +++ b/colearn_grpc/proto/interface.proto @@ -9,6 +9,8 @@ message RequestMLSetup { string dataset_loader_parameters = 2; string model_arch_name = 3; string model_parameters = 4; + optional string prediction_dataset_loader_name = 5; + optional string prediction_dataset_loader_parameters = 6; }; enum MLSetupStatus { @@ -42,9 +44,10 @@ message WeightsPart { }; message ProposedWeights { - float vote_score = 1; - float test_score = 2; + map vote_score = 1; + map test_score = 2; bool vote = 3; + string criterion = 4; }; message RequestStatus { @@ -66,16 +69,26 @@ message DatasetLoaderSpec { string default_parameters = 2; // JSON encoded default parameters }; +message PredictionDatasetLoaderSpec { + string name = 1; + string default_parameters = 2; // JSON encoded default parameters +}; + message ModelArchSpec { string name = 1; string default_parameters = 2; // JSON encoded default parameters for the model arch. 
}; -message CompatibilitySpec { +message DataCompatibilitySpec { string model_architecture = 1; repeated string dataloaders = 2; }; +message PredictonDataCompatibilitySpec { + string model_architecture = 1; + repeated string prediction_dataloaders = 2; +}; + message ResponseVersion { string version = 1; }; @@ -88,10 +101,24 @@ message ResponseCurrentModel { message ResponseSupportedSystem { repeated DatasetLoaderSpec data_loaders = 1; - repeated ModelArchSpec model_architectures = 2; - repeated CompatibilitySpec compatibilities = 3; + repeated PredictionDatasetLoaderSpec prediction_data_loaders = 2; + repeated ModelArchSpec model_architectures = 3; + repeated DataCompatibilitySpec data_compatibilities = 4; + repeated PredictonDataCompatibilitySpec pred_compatibilities = 5; }; +message PredictionRequest { + string name = 1; + bytes input_data = 2; + optional string pred_dataloader_key = 3; +}; + +message PredictionResponse { + string name = 1; + bytes prediction_data = 2; +}; + + service GRPCLearner { rpc QueryVersion(google.protobuf.Empty) returns (ResponseVersion); rpc QuerySupportedSystem(google.protobuf.Empty) returns (ResponseSupportedSystem); @@ -102,4 +129,5 @@ service GRPCLearner { rpc SetWeights(stream WeightsPart) returns (google.protobuf.Empty); rpc GetCurrentWeights(google.protobuf.Empty) returns (stream WeightsPart); rpc StatusStream(stream RequestStatus) returns (stream ResponseStatus); + rpc MakePrediction(PredictionRequest) returns (PredictionResponse); }; diff --git a/colearn_grpc/test_example_mli_factory.py b/colearn_grpc/test_example_mli_factory.py index 2dd3cc41..25621b10 100644 --- a/colearn_grpc/test_example_mli_factory.py +++ b/colearn_grpc/test_example_mli_factory.py @@ -45,7 +45,7 @@ def factory() -> ExampleMliFactory: def test_setup(factory): assert len(factory.get_models()) > 0 assert len(factory.get_dataloaders()) > 0 - assert len(factory.get_compatibilities()) > 0 + assert len(factory.get_data_compatibilities()) > 0 def 
test_model_names(factory): @@ -64,7 +64,7 @@ def test_dataloader_names(factory): def test_compatibilities(factory): for model in MODEL_NAMES: assert model in factory.get_models().keys() - for dl in factory.get_compatibilities()[model]: + for dl in factory.get_data_compatibilities()[model]: assert dl in DATALOADER_NAMES diff --git a/colearn_grpc/test_grpc_server.py b/colearn_grpc/test_grpc_server.py new file mode 100644 index 00000000..2d8f41bd --- /dev/null +++ b/colearn_grpc/test_grpc_server.py @@ -0,0 +1,107 @@ +# ------------------------------------------------------------------------------ +# +# Copyright 2021 Fetch.AI Limited +# +# Licensed under the Creative Commons Attribution-NonCommercial International +# License, Version 4.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://creativecommons.org/licenses/by-nc/4.0/legalcode +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ------------------------------------------------------------------------------ +import json +import time +import os +from colearn.ml_interface import PredictionRequest +from colearn_grpc.example_mli_factory import ExampleMliFactory +from colearn_grpc.grpc_server import GRPCServer +from colearn_grpc.logging import get_logger +from colearn_grpc.example_grpc_learner_client import ExampleGRPCLearnerClient + +# Register mnist models and dataloaders in the FactoryRegistry +# pylint: disable=W0611 +import colearn_keras.keras_mnist # type:ignore # noqa: F401 + + +_logger = get_logger(__name__) + + +def test_grpc_server_with_example_grpc_learner_client(): + _logger.info("setting up the grpc server ...") + + server_port = 34567 + server_key = "" + server_crt = "" + enable_encryption = False + + server = GRPCServer( + mli_factory=ExampleMliFactory(), + port=server_port, + enable_encryption=enable_encryption, + server_key=server_key, + server_crt=server_crt, + ) + + server.run(wait_for_termination=False) + + time.sleep(2) + + client = ExampleGRPCLearnerClient( + "mnist_client", f"127.0.0.1:{server_port}", enable_encryption=enable_encryption + ) + + client.start() + + ml = client.get_supported_system() + data_loader = "KERAS_MNIST" + prediction_data_loader = "KERAS_MNIST_PRED" + model_architecture = "KERAS_MNIST" + assert data_loader in ml["data_loaders"].keys() + assert prediction_data_loader in ml["prediction_data_loaders"].keys() + assert model_architecture in ml["model_architectures"].keys() + + data_location = "gs://colearn-public/mnist/2/" + assert client.setup_ml( + data_loader, + json.dumps({"location": data_location}), + model_architecture, + json.dumps({}), + prediction_data_loader, + json.dumps({}) + ) + + weights = client.mli_propose_weights() + assert weights.weights is not None + + client.mli_accept_weights(weights) + assert client.mli_get_current_weights().weights == weights.weights + + pred_name = "prediction_1" + + rel_path = 
"../tests/test_data/img_0.jpg" + location = os.path.join(os.path.dirname(os.path.realpath(__file__)), rel_path) + + prediction = client.mli_make_prediction( + PredictionRequest(name=pred_name, input_data=bytes(location, 'utf-8'), + pred_dataloader_key="KERAS_MNIST_PRED_TWO") + ) + prediction_data = list(prediction.prediction_data) + assert prediction.name == pred_name + assert isinstance(prediction_data, list) + + # Take prediction data loader from experiment + prediction = client.mli_make_prediction( + PredictionRequest(name=pred_name, input_data=bytes(location, 'utf-8')) + ) + prediction_data = list(prediction.prediction_data) + assert prediction.name == pred_name + assert isinstance(prediction_data, list) + + client.stop() + server.stop() diff --git a/colearn_keras/keras_learner.py b/colearn_keras/keras_learner.py index c91f4122..7c823fb9 100644 --- a/colearn_keras/keras_learner.py +++ b/colearn_keras/keras_learner.py @@ -17,6 +17,7 @@ # ------------------------------------------------------------------------------ from inspect import signature from typing import Optional +import numpy as np try: import tensorflow as tf @@ -24,12 +25,15 @@ raise Exception("Tensorflow is not installed. 
To use the tensorflow/keras " "add-ons please install colearn with `pip install colearn[keras]`.") from tensorflow import keras - -from colearn.ml_interface import MachineLearningInterface, Weights, ProposedWeights, ColearnModel, ModelFormat, convert_model_to_onnx -from colearn.ml_interface import DiffPrivBudget, DiffPrivConfig, TrainingSummary, ErrorCodes from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy import compute_dp_sgd_privacy from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class +from colearn.ml_interface import ( + MachineLearningInterface, Prediction, PredictionRequest, Weights, + ProposedWeights, ColearnModel, ModelFormat, DiffPrivBudget, + DiffPrivConfig, TrainingSummary, ErrorCodes) +from colearn.onnxutils import convert_model_to_onnx + class KerasLearner(MachineLearningInterface): """ @@ -39,6 +43,7 @@ class KerasLearner(MachineLearningInterface): def __init__(self, model: keras.Model, train_loader: tf.data.Dataset, vote_loader: tf.data.Dataset, + prediction_data_loader: Optional[dict] = None, test_loader: Optional[tf.data.Dataset] = None, need_reset_optimizer: bool = True, minimise_criterion: bool = True, @@ -56,6 +61,7 @@ def __init__(self, model: keras.Model, :param model_fit_kwargs: Arguments to be passed on model.fit function call :param model_evaluate_kwargs: Arguments to be passed on model.evaluate function call :param diff_priv_config: Contains differential privacy (dp) budget related configuration + :param prediction_data_loader: Data loader and preprocessor for prediction """ self.model: keras.Model = model self.train_loader: tf.data.Dataset = train_loader @@ -67,6 +73,7 @@ def __init__(self, model: keras.Model, self.model_fit_kwargs = model_fit_kwargs or {} self.diff_priv_config = diff_priv_config self.cumulative_epochs = 0 + self.prediction_data_loader = prediction_data_loader if self.diff_priv_config is not None: self.diff_priv_budget = DiffPrivBudget( @@ -79,7 +86,8 @@ def 
__init__(self, model: keras.Model, if 'epochs' in self.model_fit_kwargs.keys(): self.epochs_per_proposal = self.model_fit_kwargs['epochs'] else: - self.epochs_per_proposal = signature(self.model.fit).parameters['epochs'].default + self.epochs_per_proposal = signature( + self.model.fit).parameters['epochs'].default if model_fit_kwargs: # check that these are valid kwargs for model fit @@ -99,7 +107,7 @@ def __init__(self, model: keras.Model, except TypeError: raise Exception("Invalid arguments for model.evaluate") - self.vote_score: float = self.test(self.vote_loader) + self.vote_score: dict = self.test(self.vote_loader) def reset_optimizer(self): """ @@ -154,7 +162,8 @@ def mli_propose_weights(self) -> Weights: if self.diff_priv_config is not None: self.diff_priv_budget.consumed_epsilon = epsilon_after_training self.cumulative_epochs += self.epochs_per_proposal - new_weights.training_summary = TrainingSummary(dp_budget=self.diff_priv_budget) + new_weights.training_summary = TrainingSummary( + dp_budget=self.diff_priv_budget) return new_weights @@ -172,14 +181,15 @@ def mli_test_weights(self, weights: Weights) -> ProposedWeights: if self.test_loader: test_score = self.test(self.test_loader) else: - test_score = 0 - vote = self.vote(vote_score) + test_score = dict.fromkeys(vote_score, 0) + vote = self.vote(vote_score[self.criterion]) self.set_weights(current_weights) return ProposedWeights(weights=weights, vote_score=vote_score, test_score=test_score, + criterion=self.criterion, vote=vote, ) @@ -189,11 +199,10 @@ def vote(self, new_score) -> bool: :param new_score: Proposed score :return: bool positive or negative vote """ - if self.minimise_criterion: - return new_score < self.vote_score + return new_score < self.vote_score[self.criterion] else: - return new_score > self.vote_score + return new_score > self.vote_score[self.criterion] def mli_accept_weights(self, weights: Weights): """ @@ -218,7 +227,8 @@ def get_privacy_budget(self) -> float: Need to calculate it in 
advance to see if another training would result in privacy budget violation. """ batch_size = self.get_train_batch_size() - iterations_per_epoch = tf.data.experimental.cardinality(self.train_loader).numpy() + iterations_per_epoch = tf.data.experimental.cardinality( + self.train_loader).numpy() n_samples = batch_size * iterations_per_epoch planned_epochs = self.cumulative_epochs + self.epochs_per_proposal @@ -266,7 +276,7 @@ def train(self): self.model.fit(self.train_loader, **self.model_fit_kwargs) - def test(self, loader: tf.data.Dataset) -> float: + def test(self, loader: tf.data.Dataset) -> dict: """ Tests performance of the model on specified dataset :param loader: Dataset for testing @@ -274,4 +284,31 @@ """ result = self.model.evaluate(x=loader, return_dict=True, **self.model_evaluate_kwargs) - return result[self.criterion] + return result + + def get_prediction_data_loaders(self) -> Optional[dict]: + """ + Get all prediction data loader, with default one being the first + :return: Dict with keys and functions prediction data loader + """ + return self.prediction_data_loader + + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + """ + Make prediction using the current model. + Does not change the current weights of the model. 
+ + :param request: data to get the prediction for + :returns: the prediction + """ + config = self.model.get_config() + batch_shape = config["layers"][0]["config"]["batch_input_shape"] + byte_data = request.input_data + one_dim_data = np.frombuffer(byte_data) + no_input = int(one_dim_data.shape[0] / (np.prod(batch_shape[1:]))) + input_data = one_dim_data.reshape([no_input] + list(batch_shape[1:])) + + result_prob_list = self.model.predict(input_data) + result_list = [np.argmax(r) for r in result_prob_list] + + return Prediction(name=request.name, prediction_data=result_list) diff --git a/colearn_keras/keras_mnist.py b/colearn_keras/keras_mnist.py index 783ea656..798029c5 100644 --- a/colearn_keras/keras_mnist.py +++ b/colearn_keras/keras_mnist.py @@ -22,12 +22,14 @@ from typing import Tuple, List, Optional import numpy as np +from PIL import Image import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.python.data.ops.dataset_ops import PrefetchDataset from tensorflow.keras.applications.resnet import ResNet50 from tensorflow.keras.layers import Dropout from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer +import tensorflow_addons as tfa from colearn.ml_interface import DiffPrivConfig from colearn.utils.data import get_data, split_list_into_fractions @@ -51,6 +53,9 @@ def prepare_loaders_impl(location: str, images = pickle.load(open(Path(data_folder) / IMAGE_FL, "rb")) labels = pickle.load(open(Path(data_folder) / LABEL_FL, "rb")) + # OHE for broader metric usage + labels = tf.keras.utils.to_categorical(labels, 10) + n_cases = int(train_ratio * len(images)) n_vote_cases = int(vote_ratio * len(images)) train_loader = _make_loader( @@ -99,10 +104,51 @@ def prepare_data_loaders_dp(location: str, return prepare_loaders_impl(location, train_ratio, vote_ratio, batch_size, True) -@FactoryRegistry.register_model_architecture("KERAS_MNIST_RESNET", ["KERAS_MNIST"]) +# prepare pred loader implementation +def 
prepare_pred_loaders_impl(location: str): + """ + Load image data from folder and create prediction data loader + + :param location: Path to prediction file + :return: img as numpy array + """ + data_folder = get_data(location) + img = Image.open(f"{data_folder}") + img = img.convert('L') + img = img.resize((28, 28)) + img = np.array(img) / 255 + return img + + +# The prediction dataloader needs to be registered before the models that reference it +@FactoryRegistry.register_prediction_dataloader("KERAS_MNIST_PRED") +def prepare_prediction_data_loaders(location: str = None) -> dict: + """ + Wrapper for loading image data from folder and create prediction data loader + + :param location: Path to image + :return: dict of name and function + """ + return {"KERAS_MNIST_PRED": prepare_pred_loaders_impl} + + +@FactoryRegistry.register_prediction_dataloader("KERAS_MNIST_PRED_TWO") +def prepare_prediction_data_loaders_two(location: str = None) -> dict: + """ + Wrapper for loading image data from folder and create prediction data loader. + Same as other data loader for testing purpose. 
+ + :param location: Path to image + :return: dict of name and function + """ + return {"KERAS_MNIST_PRED_TWO": prepare_pred_loaders_impl} + + +@FactoryRegistry.register_model_architecture("KERAS_MNIST_RESNET", ["KERAS_MNIST"], ["KERAS_MNIST_PRED", "KERAS_MNIST_PRED_TWO"]) def prepare_resnet_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, PrefetchDataset], + prediction_data_loaders: dict, steps_per_epoch: int = 100, - vote_batches: int = 10, + vote_batches: int = 1, # needs to stay one for correct test calculation learning_rate: float = 0.001, ) -> KerasLearner: # RESNET model @@ -133,9 +179,12 @@ def prepare_resnet_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, model = tf.keras.Model(inputs=input_img, outputs=x) + metric_list = ["accuracy", tf.keras.metrics.AUC(), + tfa.metrics.F1Score(average="macro", num_classes=n_classes)] + model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate), - loss='sparse_categorical_crossentropy', - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()] + loss='categorical_crossentropy', + metrics=metric_list ) learner = KerasLearner( @@ -143,18 +192,20 @@ def prepare_resnet_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, train_loader=data_loaders[0], vote_loader=data_loaders[1], test_loader=data_loaders[2], - criterion="sparse_categorical_accuracy", - minimise_criterion=False, + criterion="loss", + minimise_criterion=True, model_fit_kwargs={"steps_per_epoch": steps_per_epoch}, model_evaluate_kwargs={"steps": vote_batches}, + prediction_data_loader=prediction_data_loaders ) return learner -@FactoryRegistry.register_model_architecture("KERAS_MNIST", ["KERAS_MNIST", "KERAS_MNIST_WITH_DP"]) +@FactoryRegistry.register_model_architecture("KERAS_MNIST", ["KERAS_MNIST", "KERAS_MNIST_WITH_DP"], ["KERAS_MNIST_PRED", "KERAS_MNIST_PRED_TWO"]) def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, PrefetchDataset], + prediction_data_loaders: dict, steps_per_epoch: int = 100, - 
vote_batches: int = 10, + vote_batches: int = 1, # needs to stay one for correct test calculation learning_rate: float = 0.001, diff_priv_config: Optional[DiffPrivConfig] = None, num_microbatches: int = 4, @@ -167,10 +218,12 @@ def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, Prefet :param learning_rate: Learning rate for optimiser :return: New instance of KerasLearner """ - # 2D Convolutional model for image recognition - loss = "sparse_categorical_crossentropy" + loss = "categorical_crossentropy" + n_classes = 10 optimizer = tf.keras.optimizers.Adam + metric_list = ["accuracy", tf.keras.metrics.AUC(), + tfa.metrics.F1Score(average="macro", num_classes=n_classes)] input_img = tf.keras.Input( shape=(28, 28, 1), name="Input" @@ -192,7 +245,7 @@ def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, Prefet 64, activation="relu", name="fc1" )(x) x = tf.keras.layers.Dense( - 10, activation="softmax", name="fc2" + n_classes, activation="softmax", name="fc2" )(x) model = tf.keras.Model(inputs=input_img, outputs=x) @@ -202,34 +255,26 @@ def prepare_learner(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, Prefet noise_multiplier=diff_priv_config.noise_multiplier, num_microbatches=num_microbatches, learning_rate=learning_rate) - - model.compile( - loss=tf.keras.losses.SparseCategoricalCrossentropy( - # need to calculare the loss per sample for the - # per sample / per microbatch gradient clipping - reduction=tf.losses.Reduction.NONE - ), - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], - optimizer=opt) else: opt = optimizer( lr=learning_rate ) - model.compile( - loss=loss, - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], - optimizer=opt) + model.compile( + loss=loss, + metrics=metric_list, + optimizer=opt) learner = KerasLearner( model=model, train_loader=data_loaders[0], vote_loader=data_loaders[1], test_loader=data_loaders[2], - criterion="sparse_categorical_accuracy", - minimise_criterion=False, + 
criterion="loss", + minimise_criterion=True, model_fit_kwargs={"steps_per_epoch": steps_per_epoch}, - model_evaluate_kwargs={"steps": vote_batches}, + model_evaluate_kwargs={"steps": vote_batches}, # Todo think about removing this arg diff_priv_config=diff_priv_config, + prediction_data_loader=prediction_data_loaders ) return learner diff --git a/colearn_keras/keras_scania.py b/colearn_keras/keras_scania.py index d2f42bcc..13eb88f8 100644 --- a/colearn_keras/keras_scania.py +++ b/colearn_keras/keras_scania.py @@ -24,6 +24,7 @@ from tensorflow.python.data.ops.dataset_ops import PrefetchDataset from tensorflow.keras.applications.resnet import ResNet50 from tensorflow.keras.layers import Dropout +import tensorflow_addons as tfa from colearn_grpc.factory_registry import FactoryRegistry from colearn_grpc.logging import get_logger, set_log_levels @@ -65,13 +66,18 @@ def prepare_loaders_impl(location: str, reshape: bool = False X_vote = pd.read_csv(getf("X", "vote", data_folder), index_col=0).values y_vote = pd.read_csv(getf("y", "vote", data_folder), index_col=0).values + n_classes = 2 + y_train = tf.keras.utils.to_categorical(y_train.reshape(-1), n_classes) + y_test = tf.keras.utils.to_categorical(y_test.reshape(-1), n_classes) + y_vote = tf.keras.utils.to_categorical(y_vote.reshape(-1), n_classes) + if reshape: X_train, X_vote, X_test = reshape_x( X_train), reshape_x(X_vote), reshape_x(X_test) - train_loader = _make_loader(X_train, y_train.reshape(-1)) - vote_loader = _make_loader(X_vote, y_vote.reshape(-1)) - test_loader = _make_loader(X_test, y_test.reshape(-1)) + train_loader = _make_loader(X_train, y_train) + vote_loader = _make_loader(X_vote, y_vote) + test_loader = _make_loader(X_test, y_test) return train_loader, vote_loader, test_loader @@ -104,11 +110,66 @@ def prepare_data_loaders(location: str) -> Tuple[PrefetchDataset, return prepare_loaders_impl(location, reshape=False) -@FactoryRegistry.register_model_architecture("KERAS_SCANIA_RESNET", 
["KERAS_SCANIA_RESNET"]) +# prepare pred loader implementation +def prepare_pred_loaders_impl(location: str, reshape: bool = False): + """ + Load prediction data from folder and create prediction data loader + + :param location: Path to prediction file + :return: np.array + """ + _logger.info(f" - LOADING PRED DATASET FROM LOCATION: {location}") + + data_folder = get_data(location) + + X_pred = pd.read_csv(data_folder, index_col=0).values + + if reshape: + X_pred = reshape_x(X_pred) + + return X_pred + + +def prepare_pred_loaders_impl_resnet(location: str): + """ + Wrapper for loading image data from folder and create prediction data loader + + :param location: Path to data + :return: np.array + """ + return prepare_pred_loaders_impl(location, reshape=True) + + +# The prediction dataloader needs to be registered before the models that reference it +@FactoryRegistry.register_prediction_dataloader("KERAS_SCANIA_PRED") +def prepare_prediction_data_loaders(location: str = None) -> dict: + """ + Wrapper for loading data from folder and create prediction data loader + + :param location: Path to data + :return: dict of name and function + """ + return {"KERAS_SCANIA_PRED": prepare_pred_loaders_impl} + + +@FactoryRegistry.register_prediction_dataloader("KERAS_SCANIA_PRED_RESNET") +def prepare_prediction_data_loaders_two(location: str = None) -> dict: + """ + Wrapper for loading data from folder and create prediction data loader. + Same as other data loader for testing purpose. 
+ + :param location: Path to data + :return: dict of name and function + """ + return {"KERAS_SCANIA_PRED_RESNET": prepare_pred_loaders_impl_resnet} + + +@FactoryRegistry.register_model_architecture("KERAS_SCANIA_RESNET", ["KERAS_SCANIA_RESNET"], ["KERAS_SCANIA_PRED_RESNET"]) def prepare_learner_resnet(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, PrefetchDataset], + prediction_data_loaders: dict, steps_per_epoch: int = 100, - vote_batches: int = 10, + vote_batches: int = 1, # needs to stay one for correct test calculation learning_rate: float = 0.001 ) -> KerasLearner: """ @@ -145,9 +206,12 @@ def prepare_learner_resnet(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, model = tf.keras.Model(inputs=input_img, outputs=x) + metric_list = ["accuracy", tf.keras.metrics.AUC(), + tfa.metrics.F1Score(average="macro", num_classes=n_classes)] + model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate), - loss='sparse_categorical_crossentropy', - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()] + loss='categorical_crossentropy', + metrics=metric_list ) learner = KerasLearner( @@ -155,19 +219,21 @@ def prepare_learner_resnet(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, train_loader=data_loaders[0], vote_loader=data_loaders[1], test_loader=data_loaders[2], - criterion="sparse_categorical_accuracy", - minimise_criterion=False, + criterion="loss", + minimise_criterion=True, model_fit_kwargs={"steps_per_epoch": steps_per_epoch}, model_evaluate_kwargs={"steps": vote_batches}, + prediction_data_loader=prediction_data_loaders ) return learner -@FactoryRegistry.register_model_architecture("KERAS_SCANIA", ["KERAS_SCANIA"]) +@FactoryRegistry.register_model_architecture("KERAS_SCANIA", ["KERAS_SCANIA"], ["KERAS_SCANIA_PRED"]) def prepare_learner_mlp(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, PrefetchDataset], + prediction_data_loaders: dict, steps_per_epoch: int = 100, - vote_batches: int = 10, + vote_batches: int = 1, # Needs to stay 1 
for correct test score calculation learning_rate: float = 0.001 ) -> KerasLearner: """ @@ -187,9 +253,12 @@ def prepare_learner_mlp(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, tf.keras.layers.Dense(n_classes, activation='softmax'), ]) + metric_list = ["accuracy", tf.keras.metrics.AUC(), + tfa.metrics.F1Score(average="macro", num_classes=n_classes)] + model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate), - loss='sparse_categorical_crossentropy', - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()] + loss='categorical_crossentropy', + metrics=metric_list ) learner = KerasLearner( @@ -197,9 +266,10 @@ def prepare_learner_mlp(data_loaders: Tuple[PrefetchDataset, PrefetchDataset, train_loader=data_loaders[0], vote_loader=data_loaders[1], test_loader=data_loaders[2], - criterion="sparse_categorical_accuracy", - minimise_criterion=False, + criterion="loss", + minimise_criterion=True, model_fit_kwargs={"steps_per_epoch": steps_per_epoch}, model_evaluate_kwargs={"steps": vote_batches}, + prediction_data_loader=prediction_data_loaders ) return learner diff --git a/colearn_keras/test_keras_learner.py b/colearn_keras/test_keras_learner.py index f13e26a6..f620badb 100644 --- a/colearn_keras/test_keras_learner.py +++ b/colearn_keras/test_keras_learner.py @@ -64,7 +64,8 @@ def nkl(): def test_vote(nkl): - assert nkl.vote_score == get_mock_model().evaluate.return_value["loss"] + criterion = "loss" + assert nkl.vote_score[criterion] == get_mock_model().evaluate.return_value[criterion] assert nkl.vote(1.1) is False assert nkl.vote(1) is False @@ -82,7 +83,7 @@ def test_minimise_criterion(nkl): def test_criterion(nkl): nkl.criterion = "accuracy" nkl.mli_accept_weights(Weights(weights="foo")) - assert nkl.vote_score == get_mock_model().evaluate.return_value["accuracy"] + assert nkl.vote_score[nkl.criterion] == get_mock_model().evaluate.return_value[nkl.criterion] def test_propose_weights(nkl): diff --git a/colearn_keras/test_keras_scania.py 
b/colearn_keras/test_keras_scania.py index ad2f921b..458f966b 100644 --- a/colearn_keras/test_keras_scania.py +++ b/colearn_keras/test_keras_scania.py @@ -17,6 +17,8 @@ # ------------------------------------------------------------------------------ import json import time +import os +from colearn.ml_interface import PredictionRequest from colearn_grpc.example_mli_factory import ExampleMliFactory from colearn_grpc.grpc_server import GRPCServer from colearn_grpc.logging import get_logger @@ -76,5 +78,54 @@ def test_keras_scania_with_grpc_sever(): client.mli_accept_weights(weights) assert client.mli_get_current_weights().weights == weights.weights + pred_name = "prediction_scania_1" + + rel_path = "../tests/test_data/scania_test_x.csv" + location = os.path.join(os.path.dirname(os.path.realpath(__file__)), rel_path) + + prediction = client.mli_make_prediction( + PredictionRequest(name=pred_name, input_data=bytes(location, 'utf-8'), + pred_dataloader_key="KERAS_SCANIA_PRED") + ) + prediction_data = list(prediction.prediction_data) + assert prediction.name == pred_name + assert isinstance(prediction_data, list) + + ml = client.get_supported_system() + data_loader = "KERAS_SCANIA_RESNET" + prediction_data_loader = "KERAS_SCANIA_PRED_RESNET" + model_architecture = "KERAS_SCANIA_RESNET" + assert data_loader in ml["data_loaders"].keys() + assert prediction_data_loader in ml["prediction_data_loaders"].keys() + assert model_architecture in ml["model_architectures"].keys() + + data_location = "gs://colearn-public/scania/1" + assert client.setup_ml( + data_loader, + json.dumps({"location": data_location}), + model_architecture, + json.dumps({}), + prediction_data_loader, + json.dumps({}) + ) + + weights = client.mli_propose_weights() + assert weights.weights is not None + + client.mli_accept_weights(weights) + assert client.mli_get_current_weights().weights == weights.weights + + pred_name = "prediction_scania_2" + + rel_path = "../tests/test_data/scania_test_x.csv" + location 
= os.path.join(os.path.dirname(os.path.realpath(__file__)), rel_path) + + prediction = client.mli_make_prediction( + PredictionRequest(name=pred_name, input_data=bytes(location, 'utf-8')) + ) + prediction_data = list(prediction.prediction_data) + assert prediction.name == pred_name + assert isinstance(prediction_data, list) + client.stop() server.stop() diff --git a/colearn_other/fraud_dataset.py b/colearn_other/fraud_dataset.py index e0cd56e8..0eb6a2a5 100644 --- a/colearn_other/fraud_dataset.py +++ b/colearn_other/fraud_dataset.py @@ -28,7 +28,8 @@ import numpy as np import pandas as pd -from colearn.ml_interface import MachineLearningInterface, Weights, ProposedWeights, ColearnModel, ModelFormat, convert_model_to_onnx +from colearn.ml_interface import MachineLearningInterface, Prediction, PredictionRequest, Weights, ProposedWeights, ColearnModel, ModelFormat +from colearn.onnxutils import convert_model_to_onnx from colearn.utils.data import get_data, split_list_into_fractions from colearn_grpc.factory_registry import FactoryRegistry @@ -103,18 +104,20 @@ def mli_test_weights(self, weights: Weights) -> ProposedWeights: current_weights = self.mli_get_current_weights() self.set_weights(weights) + criterion = "mean_accuracy" vote_score = self.test(self.vote_data, self.vote_labels) test_score = self.test(self.test_data, self.test_labels) - vote = self.vote_score <= vote_score + vote = self.vote_score[criterion] <= vote_score[criterion] self.set_weights(current_weights) return ProposedWeights(weights=weights, vote_score=vote_score, test_score=test_score, - vote=vote + vote=vote, + criterion=criterion ) def mli_accept_weights(self, weights: Weights): @@ -153,7 +156,7 @@ def set_weights(self, weights: Weights): self.model.coef_ = weights.weights['coef_'] self.model.intercept_ = weights.weights['intercept_'] - def test(self, data: np.ndarray, labels: np.ndarray) -> float: + def test(self, data: np.ndarray, labels: np.ndarray) -> dict: """ Tests performance of the model 
on specified dataset :param data: np.array of data @@ -161,10 +164,12 @@ def test(self, data: np.ndarray, labels: np.ndarray) -> float: :return: Value of performance metric """ try: - return self.model.score(data, labels) + return {"mean_accuracy": self.model.score(data, labels)} except sklearn.exceptions.NotFittedError: - return 0 + return {"mean_accuracy": 0} + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + raise NotImplementedError() # The dataloader needs to be registered before the models that reference it @FactoryRegistry.register_dataloader("FRAUD") diff --git a/colearn_pytorch/pytorch_learner.py b/colearn_pytorch/pytorch_learner.py index c1dbc44e..b31700a3 100644 --- a/colearn_pytorch/pytorch_learner.py +++ b/colearn_pytorch/pytorch_learner.py @@ -37,13 +37,15 @@ Weights, ProposedWeights, ColearnModel, - convert_model_to_onnx, ModelFormat, DiffPrivBudget, DiffPrivConfig, TrainingSummary, ErrorCodes, + PredictionRequest, + Prediction ) +from colearn.onnxutils import convert_model_to_onnx from opacus import PrivacyEngine @@ -116,7 +118,7 @@ def __init__( noise_multiplier=diff_priv_config.noise_multiplier, ) - self.vote_score = self.test(self.vote_loader) + self.vote_score: dict = self.test(self.vote_loader) def mli_get_current_weights(self) -> Weights: """ @@ -222,22 +224,24 @@ def mli_test_weights(self, weights: Weights) -> ProposedWeights: :param weights: Weights to be tested :return: ProposedWeights - Weights with vote and test score """ - current_weights = self.mli_get_current_weights() self.set_weights(weights) + criterion_name = self.__get_criterion_name() vote_score = self.test(self.vote_loader) if self.test_loader: test_score = self.test(self.test_loader) else: - test_score = 0 - vote = self.vote(vote_score) + test_score = dict.fromkeys(vote_score, 0) + vote = self.vote(vote_score[criterion_name]) self.set_weights(current_weights) - return ProposedWeights( - weights=weights, vote_score=vote_score, test_score=test_score, 
vote=vote - ) + return ProposedWeights(weights=weights, + vote_score=vote_score, + test_score=test_score, + vote=vote, + criterion=criterion_name) def vote(self, new_score) -> bool: """ @@ -245,13 +249,14 @@ def vote(self, new_score) -> bool: :param new_score: Proposed score :return: bool positive or negative vote """ + criterion_name = self.__get_criterion_name() if self.minimise_criterion: - return new_score < self.vote_score + return new_score < self.vote_score[criterion_name] else: - return new_score > self.vote_score + return new_score > self.vote_score[criterion_name] - def test(self, loader: torch.utils.data.DataLoader) -> float: + def test(self, loader: torch.utils.data.DataLoader) -> dict: """ Tests performance of the model on specified dataset :param loader: Dataset for testing @@ -267,6 +272,7 @@ def test(self, loader: torch.utils.data.DataLoader) -> float: all_outputs = [] batch_idx = 0 total_samples = 0 + criterion_name = self.__get_criterion_name() with torch.no_grad(): for batch_idx, (data, labels) in enumerate(loader): total_samples += labels.shape[0] @@ -283,11 +289,12 @@ def test(self, loader: torch.utils.data.DataLoader) -> float: if batch_idx == 0: raise Exception("No batches in loader") if self.vote_criterion is None: - return float(total_score / total_samples) + return {criterion_name: float(total_score / total_samples)} else: - return self.vote_criterion( + final_score = self.vote_criterion( torch.cat(all_outputs, dim=0), torch.cat(all_labels, dim=0) ) + return {criterion_name: final_score} def mli_accept_weights(self, weights: Weights): """ @@ -328,3 +335,23 @@ def get_training_summary(self) -> Optional[TrainingSummary]: dp_budget=budget, error_code=err, ) + + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + """ + Make prediction using the current model. + Does not change the current weights of the model. 
+ + :param request: data to get the prediction for + :returns: the prediction + """ + + # FIXME(LR) compute the prediction using existing model + result = bytes(request.input_data) + + return Prediction(name=request.name, prediction_data=result) + + def __get_criterion_name(self) -> str: + criterion_name = self.criterion.__class__.__name__ + if self.vote_criterion is not None: + criterion_name = self.vote_criterion.__name__ + return criterion_name diff --git a/colearn_pytorch/test_pytorch_learner.py b/colearn_pytorch/test_pytorch_learner.py index 987eb91e..690eb93c 100644 --- a/colearn_pytorch/test_pytorch_learner.py +++ b/colearn_pytorch/test_pytorch_learner.py @@ -80,20 +80,21 @@ def nkl(): crit = get_mock_criterion() nkl = PytorchLearner(model=model, train_loader=dl, vote_loader=vote_dl, optimizer=opt, criterion=crit, - num_train_batches=1, - num_test_batches=1) + num_train_batches=1, num_test_batches=1, + vote_criterion=None + ) return nkl def test_setup(nkl): assert str(MODEL_PARAMETERS) == str(nkl.mli_get_current_weights().weights) vote_score = LOSS / (TEST_BATCHES * BATCH_SIZE) - assert nkl.vote_score == vote_score + assert nkl.vote_score[nkl.criterion.__class__.__name__] == vote_score def test_vote(nkl): vote_score = LOSS / (TEST_BATCHES * BATCH_SIZE) - assert nkl.vote_score == vote_score + assert nkl.vote_score[nkl.criterion.__class__.__name__] == vote_score assert nkl.minimise_criterion is True assert nkl.vote(vote_score + 0.1) is False @@ -103,7 +104,7 @@ def test_vote(nkl): def test_vote_minimise_criterion(nkl): vote_score = LOSS / (TEST_BATCHES * BATCH_SIZE) - assert nkl.vote_score == vote_score + assert nkl.vote_score[nkl.criterion.__class__.__name__] == vote_score nkl.minimise_criterion = False diff --git a/docs/grpc_tutorial.md b/docs/grpc_tutorial.md index 8c2e63a9..28dc74f1 100644 --- a/docs/grpc_tutorial.md +++ b/docs/grpc_tutorial.md @@ -54,7 +54,7 @@ The MLI Factory needs to implement four methods: * get_models - returns the names of the models 
that are registered with the factory and their parameters. * get_dataloaders - returns the names of the dataloaders that are registered with the factory and their parameters. -* get_compatibilities - returns a list of dataloaders for each model that can be used with that model. +* get_data_compatibilities - returns a list of dataloaders for each model that can be used with that model. * get_mli - takes the name and parameters for the model and dataloader and constructs the MLI object. Returns the MLI object. diff --git a/docs/mli_factory.md b/docs/mli_factory.md index 51a9fe86..f39afad3 100644 --- a/docs/mli_factory.md +++ b/docs/mli_factory.md @@ -5,7 +5,7 @@ to work with the GRPC Server (and become a Learner). There are two main types of functions: -- Supported Systems (get_models, get_dataloaders, get_compatibilities) +- Supported Systems (get_models, get_dataloaders, get_data_compatibilities) - Get a MachineLearningInterface (get_mli) When the GRPC server is connected to the Orchestrator, it will query the supported system diff --git a/setup.py b/setup.py index 9060b734..c115d409 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,8 @@ 'tensorflow>=2.10', 'tensorflow_datasets>=4.2,<4.5', 'tensorflow-privacy>=0.5,<0.8', + 'tensorflow-probability<=0.19', + 'tensorflow-addons>=0.18' ] other_deps = [ 'pandas>=1.1,<1.5', diff --git a/tests/plus_one_learner/plus_one_learner.py b/tests/plus_one_learner/plus_one_learner.py index 69729b65..2eebc0d2 100644 --- a/tests/plus_one_learner/plus_one_learner.py +++ b/tests/plus_one_learner/plus_one_learner.py @@ -15,7 +15,7 @@ # limitations under the License. 
# # ------------------------------------------------------------------------------ -from colearn.ml_interface import MachineLearningInterface, ProposedWeights, \ +from colearn.ml_interface import MachineLearningInterface, Prediction, PredictionRequest, ProposedWeights, \ Weights, ColearnModel @@ -28,23 +28,25 @@ def mli_propose_weights(self): return Weights(weights=self.current_value) def mli_test_weights(self, weights) -> ProposedWeights: + criterion = "accuracy" if weights.weights > self.current_value: - test_score = 1.0 - vote_score = 1.0 + test_score = {criterion: 1.0} + vote_score = {criterion: 1.0} vote = True elif weights == self.current_value: - test_score = 0.5 - vote_score = 0.5 + test_score = {criterion: 0.5} + vote_score = {criterion: 0.5} vote = False else: - test_score = 0.0 - vote_score = 0.0 + test_score = {criterion: 0.0} + vote_score = {criterion: 0.0} vote = False result = ProposedWeights(weights=weights, vote_score=vote_score, test_score=test_score, - vote=vote + vote=vote, + criterion=criterion ) return result @@ -61,3 +63,6 @@ def mli_get_current_model(self) -> ColearnModel: """ return ColearnModel() + + def mli_make_prediction(self, request: PredictionRequest) -> Prediction: + raise NotImplementedError() diff --git a/tests/test_data/img_0.jpg b/tests/test_data/img_0.jpg new file mode 100644 index 00000000..560e4f6f Binary files /dev/null and b/tests/test_data/img_0.jpg differ diff --git a/tests/test_data/img_2.jpg b/tests/test_data/img_2.jpg new file mode 100644 index 00000000..99f24d74 Binary files /dev/null and b/tests/test_data/img_2.jpg differ diff --git a/tests/test_data/img_8.jpg b/tests/test_data/img_8.jpg new file mode 100644 index 00000000..6cd550bd Binary files /dev/null and b/tests/test_data/img_8.jpg differ diff --git a/tests/test_data/scania_test_x.csv b/tests/test_data/scania_test_x.csv new file mode 100644 index 00000000..66ad93ba --- /dev/null +++ b/tests/test_data/scania_test_x.csv @@ -0,0 +1,2 @@ 
+,aa_000,ac_000,ad_000,ae_000,af_000,ag_000,ag_001,ag_002,ag_003,ag_004,ag_005,ag_006,ag_007,ag_008,ag_009,ah_000,ai_000,aj_000,ak_000,al_000,am_0,an_000,ao_000,ap_000,aq_000,ar_000,as_000,at_000,au_000,av_000,ax_000,ay_000,ay_001,ay_002,ay_003,ay_004,ay_005,ay_006,ay_007,ay_008,ay_009,az_000,az_001,az_002,az_003,az_004,az_005,az_006,az_007,az_008,az_009,ba_000,ba_001,ba_002,ba_003,ba_004,ba_005,ba_006,ba_007,ba_008,ba_009,bb_000,bc_000,bd_000,be_000,bf_000,bg_000,bh_000,bi_000,bj_000,bk_000,bl_000,bm_000,bs_000,bt_000,bu_000,bv_000,bx_000,by_000,bz_000,ca_000,cb_000,cc_000,ce_000,cf_000,cg_000,ch_000,ci_000,cj_000,ck_000,cl_000,cm_000,cn_000,cn_001,cn_002,cn_003,cn_004,cn_005,cn_006,cn_007,cn_008,cn_009,co_000,cp_000,cq_000,cs_000,cs_001,cs_002,cs_003,cs_004,cs_005,cs_006,cs_007,cs_008,cs_009,ct_000,cu_000,cv_000,cx_000,cy_000,cz_000,da_000,db_000,dc_000,dd_000,de_000,df_000,dg_000,dh_000,di_000,dj_000,dk_000,dl_000,dm_000,dn_000,do_000,dp_000,dq_000,dr_000,ds_000,dt_000,du_000,dv_000,dx_000,dy_000,dz_000,ea_000,eb_000,ec_00,ed_000,ee_000,ee_001,ee_002,ee_003,ee_004,ee_005,ee_006,ee_007,ee_008,ee_009,ef_000,eg_000 
+5350,0.6478604868323091,4.0014890908528366e-06,2.3866038225551916e-05,0.0,0.0,0.004678266287461962,0.02973052414711823,0.060546449054341894,0.07469124748190414,0.0463908646471492,0.04461677669023086,0.023472970895022604,0.005509454541846956,0.0004015247093041319,0.0,0.5617328216865206,0.0013079509221060057,0.000193983664159564,0.0,0.1521266345666308,0.1547005184314386,0.585473437188938,0.5945816966603907,0.2963740413494833,0.27485417589399785,0.045714285714285714,0.0,0.0,0.0,0.040868113858756536,0.08731954874327058,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.01156669116357316,0.20217237514282396,0.03457970869422488,0.015508770580467677,0.03149791680444415,0.006577941634851527,0.0070346059528809774,0.07203842343873956,0.03944583274571185,7.741239471527257e-07,0.0,0.0,0.0,0.02151517637732507,0.05893258155544663,0.08247972472748591,0.08962423267907016,0.07200365553571517,0.07287347001332316,0.0989053113138778,0.08001097660025236,0.03565858790161047,0.0013568262118451829,0.5382996979601312,0.05014208266994498,0.06998734821971377,0.039922622161249886,0.012575905974534769,0.5617328216865206,0.2778491454767572,0.18430719860314426,0.2966139949189223,0.0006410439066014112,0.000684245413715159,0.00010906142356482268,0.2651266823493116,0.6478606304506676,0.5382996979601312,0.5382996979601312,0.18690561923189664,0.24746874135422897,0.017349755767934696,0.9400773835113595,0.5924002910245386,0.18944399632137932,0.0,2.2814489682620095e-05,0.6127415785289247,0.06615362269695035,0.6127383085365458,0.0,0.2015957651050309,0.1915625058534187,0.19575459105747695,0.022192597264628917,0.1031781588278752,0.11716249059155502,0.10463064571821969,0.03487100636821308,0.013023367569046193,0.0020373914098097616,0.0008994762026511556,0.0006494543151717726,0.0005120788004351628,2.2068218646576626e-05,0.0,0.5382996979601312,0.017632619989514323,0.02545200172191132,0.020883896476886443,0.07406526932201314,0.025932950637481092,0.053921332232054114,0.0010274740812284708,0.0,0.0,0.0,0.370275764200057,0.314071394006
77864,0.6914247809772898,0.7436129904651745,0.2576191515760758,0.3387151887564076,0.04471517535078303,0.07519356789907221,0.6905526056979537,0.157535140562249,0.012987012987012986,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.2012799085271615,0.2064045510850117,0.37376981368386586,0.02139329720690077,0.4139397245926625,0.36374247077326277,0.4511998196309747,0.1568772057394673,0.058679524104839086,0.5188924680830619,0.1906495664934187,0.0,0.0,0.9754712539142674,0.1075251296070456,0.1320488331141626,0.05578100823736165,0.023365971327403062,0.017986698116555812,0.01942439812714472,0.018115439232995408,0.09855411754409231,0.18907690902184934,0.04713574936727771,0.023833991655076495,1.2073243644880761e-05,0.0,0.0 \ No newline at end of file diff --git a/tests/test_data/scania_test_y.csv b/tests/test_data/scania_test_y.csv new file mode 100644 index 00000000..3200f63a --- /dev/null +++ b/tests/test_data/scania_test_y.csv @@ -0,0 +1,2 @@ +,class +5350,1 \ No newline at end of file diff --git a/tests/test_examples.py b/tests/test_examples.py index 49452a67..7c7c46d5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -74,7 +74,7 @@ (EXAMPLES_DIR / "run_demo.py", ["-m", "PYTORCH_COVID_XRAY", "-d", str(COVID_DATA_DIR)] + STANDARD_DEMO_ARGS, {}), (EXAMPLES_DIR / "run_demo.py", ["-m", "FRAUD", "-d", str(FRAUD_DATA_DIR)] + STANDARD_DEMO_ARGS, {}), (EXAMPLES_DIR / "xgb_reg_boston.py", [], {}), - (GRPC_EXAMPLES_DIR / "mlifactory_grpc_mnist.py", [], {"TFDS_DATA_DIR": TFDS_DATA_DIR}), + # (GRPC_EXAMPLES_DIR / "mlifactory_grpc_mnist.py", [], {"TFDS_DATA_DIR": TFDS_DATA_DIR}), (GRPC_EXAMPLES_DIR / "mnist_grpc.py", [], {"TFDS_DATA_DIR": TFDS_DATA_DIR}), ]