diff --git a/Gopkg.lock b/Gopkg.lock index 52fcc173bda..fd5cc56e7ec 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -33,14 +33,6 @@ pruneopts = "T" revision = "de5bf2ad457846296e2031421a34e2568e304e35" -[[projects]] - digest = "1:607c4f1f646bfe5ec1a4eac4a505608f280829550ed546a243698a525d3c5fe8" - name = "github.com/cpuguy83/go-md2man" - packages = ["md2man"] - pruneopts = "T" - revision = "20f5889cbdc3c73dbd2862796665e7c465ade7d1" - version = "v1.0.8" - [[projects]] digest = "1:9f42202ac457c462ad8bb9642806d275af9ab4850cf0b1960b9c6f083d4a309a" name = "github.com/davecgh/go-spew" @@ -60,14 +52,6 @@ revision = "85d198d05a92d31823b852b4a5928114912e8949" version = "v2.9.0" -[[projects]] - digest = "1:7fc160b460a6fc506b37fcca68332464c3f2cd57b6e3f111f26c5bbfd2d5518e" - name = "github.com/fsnotify/fsnotify" - packages = ["."] - pruneopts = "T" - revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" - version = "v1.4.7" - [[projects]] digest = "1:2cd7915ab26ede7d95b8749e6b1f933f1c6d5398030684e6505940a10f31cfda" name = "github.com/ghodss/yaml" @@ -276,25 +260,6 @@ revision = "20f1fb78b0740ba8c3cb143a61e86ba5c8669768" version = "v0.5.0" -[[projects]] - digest = "1:071bcbf82c289fba4d3f63c876bf4f0ba7eda625cd60795e0a03ccbf949e517a" - name = "github.com/hashicorp/hcl" - packages = [ - ".", - "hcl/ast", - "hcl/parser", - "hcl/printer", - "hcl/scanner", - "hcl/strconv", - "hcl/token", - "json/parser", - "json/scanner", - "json/token", - ] - pruneopts = "T" - revision = "8cb6e5b959231cc1119e43259c4a608f9c51a241" - version = "v1.0.0" - [[projects]] digest = "1:8f20c8dd713564fa97299fbcb77d729c6de9c33f3222812a76e6ecfaef80fd61" name = "github.com/hpcloud/tail" @@ -354,20 +319,14 @@ name = "github.com/kubeflow/tf-operator" packages = [ "pkg/apis/common/v1beta1", + "pkg/apis/common/v1beta2", "pkg/apis/tensorflow/v1beta1", + "pkg/apis/tensorflow/v1beta2", ] pruneopts = "T" revision = "c2849477dffdeacc2ebc11de66f826a6ce5cf690" version = "v0.4.0" -[[projects]] - digest = "1:53e8c5c79716437e601696140e8b1801aae4204f4ec54a504333702a49572c4f" - name = "github.com/magiconair/properties" - packages = ["."] - pruneopts = "T" - revision = "c2353362d570a7bfa228149c62842019201cfb71" - version = "v1.8.0" - [[projects]] branch = "master" digest = "1:4be65cb3a11626a0d89fc72b34f62a8768040512d45feb086184ac30fdfbef65" @@ -396,14 +355,6 @@ pruneopts = "T" revision = "81af80346b1a01caae0cbc27fd3c1ba5b11e189f" -[[projects]] - digest = "1:53bc4cd4914cd7cd52139990d5170d6dc99067ae31c56530621b18b35fc30318" - name = "github.com/mitchellh/mapstructure" - packages = ["."] - pruneopts = "T" - revision = "3536a929edddb9a5b34bd6861dc4a9647cb459fe" - version = "v1.1.2" - [[projects]] digest = "1:33422d238f147d247752996a26574ac48dcf472976eda7f5134015f06bf16563" name = "github.com/modern-go/concurrent" @@ -478,14 +429,6 @@ revision = "adf5a7427709b9deb95d29d3fa8a2bf9cfd388f1" version = "v1.2" -[[projects]] - digest = "1:ccf9949c9c53e85dcb7e2905fc620571422567040925381e6baa62f0b7b850fe" - name = "github.com/pelletier/go-toml" - packages = ["."] - pruneopts = "T" - revision = "c01d1270ff3e442a8a57cddc1c92dc1138598194" - version = "v1.2.0" - [[projects]] branch = "master" digest = "1:0c29d499ffc3b9f33e7136444575527d0c3a9463a89b3cbeda0523b737f910b3" @@ -530,14 +473,6 @@ revision = "2ea7272aa4c7de5c8a568dff504404e8dc49945d" version = "v1.2.1" -[[projects]] - digest = "1:40e527269f1feb16b3069bfe80ff05a462d190eacfe07eb0a59fa25c381db7af" - name = "github.com/russross/blackfriday" - packages = ["."] - pruneopts = "T" - revision = "05f3235734ad95d0016f6a23902f06461fcf567a" - version = "v1.5.2" - [[projects]] digest = "1:2c31831b97353515f0bb77a3e7f2df69e50c4173116aad4250e22b1c664a1484" name = "github.com/spf13/afero" @@ -549,33 +484,14 @@ revision = "f4711e4db9e9a1d3887343acb72b2bbfc2f686f5" version = "v1.2.1" -[[projects]] - digest = "1:08d65904057412fc0270fc4812a1c90c594186819243160dc779a402d4b6d0bc" - name = "github.com/spf13/cast" - packages = ["."] - pruneopts = "T" - revision = "8c9545af88b134710ab1cd196795e7f2388358d7" - version = "v1.3.0" - [[projects]] digest = "1:8be8b3743fc9795ec21bbd3e0fc28ff6234018e1a269b0a7064184be95ac13e0" name = "github.com/spf13/cobra" - packages = [ - ".", - "doc", - ] + packages = ["."] pruneopts = "T" revision = "ef82de70bb3f60c65fb8eebacbb2d122ef517385" version = "v0.0.3" -[[projects]] - digest = "1:68ea4e23713989dc20b1bded5d9da2c5f9be14ff9885beef481848edd18c26cb" - name = "github.com/spf13/jwalterweatherman" - packages = ["."] - pruneopts = "T" - revision = "4a4406e478ca629068e7768fc33f3f044173c0a6" - version = "v1.0.0" - [[projects]] digest = "1:0f775ea7a72e30d5574267692aaa9ff265aafd15214a7ae7db26bc77f2ca04dc" name = "github.com/spf13/pflag" @@ -584,14 +500,6 @@ revision = "298182f68c66c05229eb03ac171abe6e309ee79a" version = "v1.0.3" -[[projects]] - digest = "1:fefab651e5663d271bda04e3fac63e52d642b454ef88dc1b78deca05f267e94e" - name = "github.com/spf13/viper" - packages = ["."] - pruneopts = "T" - revision = "6d33b5a963d922d182c91e8a1c88d81fd150cfd4" - version = "v1.3.1" - [[projects]] digest = "1:365b8ecb35a5faf5aa0ee8d798548fc9cd4200cb95d77a5b0b285ac881bae499" name = "go.uber.org/atomic" @@ -900,6 +808,9 @@ packages = [ "pkg/apis/apiextensions", "pkg/apis/apiextensions/v1beta1", + "pkg/client/clientset/clientset", + "pkg/client/clientset/clientset/scheme", + "pkg/client/clientset/clientset/typed/apiextensions/v1beta1", ] pruneopts = "T" revision = "408db4a50408e2149acbd657bceb2480c13cb0a4" @@ -909,6 +820,7 @@ digest = "1:e2035d60919705a125feeb8b13926383aff8817634b2c4550743c19ca21f3b08" name = "k8s.io/apimachinery" packages = [ + "pkg/api/equality", "pkg/api/errors", "pkg/api/meta", "pkg/api/resource", @@ -939,6 +851,7 @@ "pkg/util/json", "pkg/util/mergepatch", "pkg/util/net", + "pkg/util/rand", "pkg/util/runtime", "pkg/util/sets", "pkg/util/strategicpatch", @@ -1088,6 +1001,8 @@ "pkg/client/config", "pkg/controller", "pkg/controller/controllerutil", + "pkg/envtest", + "pkg/envtest/printer", "pkg/event", "pkg/handler", "pkg/internal/controller", @@ -1104,8 +1019,14 @@ "pkg/runtime/signals", "pkg/source", "pkg/source/internal", + "pkg/webhook", "pkg/webhook/admission", + "pkg/webhook/admission/builder", "pkg/webhook/admission/types", + "pkg/webhook/internal/cert", + "pkg/webhook/internal/cert/generator", + "pkg/webhook/internal/cert/writer", + "pkg/webhook/internal/cert/writer/atomic", "pkg/webhook/types", ] pruneopts = "T" @@ -1159,12 +1080,10 @@ "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1beta1", "github.com/kubeflow/tf-operator/pkg/apis/common/v1beta1", "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1beta1", + "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1beta2", "github.com/onsi/ginkgo", "github.com/onsi/gomega", "github.com/pressly/chi", - "github.com/spf13/cobra", - "github.com/spf13/cobra/doc", - "github.com/spf13/viper", "golang.org/x/net/context", "google.golang.org/genproto/googleapis/api/annotations", "google.golang.org/grpc", @@ -1174,36 +1093,51 @@ "google.golang.org/grpc/status", "gopkg.in/DATA-DOG/go-sqlmock.v1", "gopkg.in/yaml.v2", + "k8s.io/api/admissionregistration/v1beta1", + "k8s.io/api/apps/v1", "k8s.io/api/batch/v1", "k8s.io/api/batch/v1beta1", "k8s.io/api/core/v1", "k8s.io/api/extensions/v1beta1", "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1beta1", + "k8s.io/apimachinery/pkg/api/equality", "k8s.io/apimachinery/pkg/api/errors", "k8s.io/apimachinery/pkg/api/meta", "k8s.io/apimachinery/pkg/apis/meta/v1", + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured", "k8s.io/apimachinery/pkg/labels", "k8s.io/apimachinery/pkg/runtime", "k8s.io/apimachinery/pkg/runtime/schema", "k8s.io/apimachinery/pkg/runtime/serializer", + "k8s.io/apimachinery/pkg/selection", "k8s.io/apimachinery/pkg/types", + "k8s.io/apimachinery/pkg/util/rand", "k8s.io/apimachinery/pkg/util/uuid", "k8s.io/apimachinery/pkg/util/yaml", "k8s.io/client-go/kubernetes", "k8s.io/client-go/kubernetes/scheme", "k8s.io/client-go/plugin/pkg/client/auth/gcp", "k8s.io/client-go/rest", + "k8s.io/client-go/tools/clientcmd", + "k8s.io/client-go/tools/record", "k8s.io/code-generator/cmd/deepcopy-gen", "sigs.k8s.io/controller-runtime/pkg/client", "sigs.k8s.io/controller-runtime/pkg/client/config", "sigs.k8s.io/controller-runtime/pkg/controller", "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil", + "sigs.k8s.io/controller-runtime/pkg/envtest", "sigs.k8s.io/controller-runtime/pkg/handler", "sigs.k8s.io/controller-runtime/pkg/manager", "sigs.k8s.io/controller-runtime/pkg/reconcile", + "sigs.k8s.io/controller-runtime/pkg/runtime/inject", + "sigs.k8s.io/controller-runtime/pkg/runtime/log", "sigs.k8s.io/controller-runtime/pkg/runtime/scheme", "sigs.k8s.io/controller-runtime/pkg/runtime/signals", "sigs.k8s.io/controller-runtime/pkg/source", + "sigs.k8s.io/controller-runtime/pkg/webhook", + "sigs.k8s.io/controller-runtime/pkg/webhook/admission", + "sigs.k8s.io/controller-runtime/pkg/webhook/admission/builder", + "sigs.k8s.io/controller-runtime/pkg/webhook/admission/types", "sigs.k8s.io/controller-tools/cmd/controller-gen", "sigs.k8s.io/testing_frameworks/integration", ] diff --git a/cmd/suggestion/v1alpha3/darts/Dockerfile b/cmd/suggestion/v1alpha3/darts/Dockerfile new file mode 100644 index 00000000000..15fdc77ca9e --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/Dockerfile @@ -0,0 +1,12 @@ +FROM python:3 + +ADD . /usr/src/app/github.com/kubeflow/katib +WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/v1alpha3/darts +RUN pip install --no-cache-dir -r ./requirements.txt \ + && curl -O https://repo.anaconda.com/archive/Anaconda3-5.2.0-Linux-x86_64.sh \ + && sh Anaconda3-5.2.0-Linux-x86_64.sh -b \ + && . ~/.bashrc \ + && conda install -y pytorch-cpu torchvision-cpu -c pytorch +ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/api/suggestion/v1alpha3/python:/usr/src/app/github.com/kubeflow/katib/cmd/suggestion/v1alpha3/darts/cnn + +ENTRYPOINT ["python", "service.py"] diff --git a/cmd/suggestion/v1alpha3/darts/client.py b/cmd/suggestion/v1alpha3/darts/client.py new file mode 100644 index 00000000000..7dc5a240fc8 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/client.py @@ -0,0 +1,23 @@ +# sample client for GRPC getSuggestions +import grpc +import time +from concurrent import futures + +from pkg.api.suggestion.v1alpha3.python import oneshot_pb2_grpc +from pkg.api.suggestion.v1alpha3.python import oneshot_pb2 + +ADDRESS = '127.0.0.1:6789' + +def main(): + channel = grpc.insecure_channel(ADDRESS) + stub = oneshot_pb2_grpc.OneshotSuggestionStub(channel) + request = oneshot_pb2.GetOneshotSuggestionRequest() + reply = stub.GetSuggestions(request) + cnt = 0 + for reply_item in reply: + cnt += 1 + print(reply_item) + print(cnt) # how many chunks are received + +if __name__ == "__main__": + main() diff --git a/cmd/suggestion/v1alpha3/darts/cnn/genotypes.py b/cmd/suggestion/v1alpha3/darts/cnn/genotypes.py new file mode 100644 index 00000000000..a7089f41442 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/cnn/genotypes.py @@ -0,0 +1,122 @@ +from collections import namedtuple + +Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') + +PRIMITIVES = [ + 'none', + 'max_pool_3x3', + 'avg_pool_3x3', + 'skip_connect', + 'sep_conv_3x3', + 'sep_conv_5x5', + 'dil_conv_3x3', + 'dil_conv_5x5' +] + +NASNet = Genotype( + normal=[ + ('sep_conv_5x5', 1), + ('sep_conv_3x3', 0), + ('sep_conv_5x5', 0), + ('sep_conv_3x3', 0), + ('avg_pool_3x3', 1), + ('skip_connect', 0), + ('avg_pool_3x3', 0), + ('avg_pool_3x3', 0), + ('sep_conv_3x3', 1), + ('skip_connect', 1), + ], + normal_concat=[2, 3, 4, 5, 6], + reduce=[ + ('sep_conv_5x5', 1), + ('sep_conv_7x7', 0), + ('max_pool_3x3', 1), + ('sep_conv_7x7', 0), + ('avg_pool_3x3', 1), + ('sep_conv_5x5', 0), + ('skip_connect', 3), + ('avg_pool_3x3', 2), + ('sep_conv_3x3', 2), + ('max_pool_3x3', 1), + ], + reduce_concat=[4, 5, 6], +) + +AmoebaNet = Genotype( + normal=[ + ('avg_pool_3x3', 0), + ('max_pool_3x3', 1), + ('sep_conv_3x3', 0), + ('sep_conv_5x5', 2), + ('sep_conv_3x3', 0), + ('avg_pool_3x3', 3), + ('sep_conv_3x3', 1), + ('skip_connect', 1), + ('skip_connect', 0), + ('avg_pool_3x3', 1), + ], + normal_concat=[4, 5, 6], + reduce=[ + ('avg_pool_3x3', 0), + ('sep_conv_3x3', 1), + ('max_pool_3x3', 0), + ('sep_conv_7x7', 2), + ('sep_conv_7x7', 0), + ('avg_pool_3x3', 1), + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('conv_7x1_1x7', 0), + ('sep_conv_3x3', 5), + ], + reduce_concat=[3, 4, 6] +) + +DARTS_V1 = Genotype( + normal=[ + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), + ('skip_connect', 0), + ('sep_conv_3x3', 1), + ('skip_connect', 0), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), + ('skip_connect', 2) + ], + normal_concat=[2, 3, 4, 5], + reduce=[ + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('skip_connect', 2), + ('max_pool_3x3', 0), + ('max_pool_3x3', 0), + ('skip_connect', 2), + ('skip_connect', 2), + ('avg_pool_3x3', 0) + ], + reduce_concat=[2, 3, 4, 5]) + +DARTS_V2 = Genotype( + normal=[ + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 1), + ('skip_connect', 0), + ('skip_connect', 0), + ('dil_conv_3x3', 2) + ], + normal_concat=[2, 3, 4, 5], + reduce=[ + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('skip_connect', 2), + ('max_pool_3x3', 1), + ('max_pool_3x3', 0), + ('skip_connect', 2), + ('skip_connect', 2), + ('max_pool_3x3', 1) + ], + reduce_concat=[2, 3, 4, 5]) + +DARTS = DARTS_V2 diff --git a/cmd/suggestion/v1alpha3/darts/cnn/model.py b/cmd/suggestion/v1alpha3/darts/cnn/model.py new file mode 100644 index 00000000000..2c1ccd6f551 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/cnn/model.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn +from operations import * +from torch.autograd import Variable +from utils import drop_path + + +class Cell(nn.Module): + + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): + super(Cell, self).__init__() + print(C_prev_prev, C_prev, C) + + if reduction_prev: + self.preprocess0 = FactorizedReduce(C_prev_prev, C) + else: + self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) + self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) + + if reduction: + op_names, indices = zip(*genotype.reduce) + concat = genotype.reduce_concat + else: + op_names, indices = zip(*genotype.normal) + concat = genotype.normal_concat + self._compile(C, op_names, indices, concat, reduction) + + def _compile(self, C, op_names, indices, concat, reduction): + assert len(op_names) == len(indices) + self._steps = len(op_names) // 2 + self._concat = concat + self.multiplier = len(concat) + + self._ops = nn.ModuleList() + for name, index in zip(op_names, indices): + stride = 2 if reduction and index < 2 else 1 + op = OPS[name](C, stride, True) + self._ops += [op] + self._indices = indices + + def forward(self, s0, s1, drop_prob): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) + + states = [s0, s1] + for i in range(self._steps): + h1 = states[self._indices[2*i]] + h2 = states[self._indices[2*i+1]] + op1 = self._ops[2*i] + op2 = self._ops[2*i+1] + h1 = op1(h1) + h2 = op2(h2) + if self.training and drop_prob > 0.: + if not isinstance(op1, Identity): + h1 = drop_path(h1, drop_prob) + if not isinstance(op2, Identity): + h2 = drop_path(h2, drop_prob) + s = h1 + h2 + states += [s] + return torch.cat([states[i] for i in self._concat], dim=1) + + +class AuxiliaryHeadCIFAR(nn.Module): + + def __init__(self, C, num_classes): + """assuming input size 8x8""" + super(AuxiliaryHeadCIFAR, self).__init__() + self.features = nn.Sequential( + nn.ReLU(inplace=True), + nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 + nn.Conv2d(C, 128, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 768, 2, bias=False), + nn.BatchNorm2d(768), + nn.ReLU(inplace=True) + ) + self.classifier = nn.Linear(768, num_classes) + + def forward(self, x): + x = self.features(x) + x = self.classifier(x.view(x.size(0),-1)) + return x + + +class NetworkCIFAR(nn.Module): + + def __init__(self, C, num_classes, layers, auxiliary, genotype): + super(NetworkCIFAR, self).__init__() + self._layers = layers + self._auxiliary = auxiliary + + self.drop_path_prob = 0 # add for onnx first + + stem_multiplier = 3 + C_curr = stem_multiplier*C + self.stem = nn.Sequential( + nn.Conv2d(3, C_curr, 3, padding=1, bias=False), + nn.BatchNorm2d(C_curr) + ) + + C_prev_prev, C_prev, C_curr = C_curr, C_curr, C + self.cells = nn.ModuleList() + reduction_prev = False + for i in range(layers): + if i in [layers//3, 2*layers//3]: + C_curr *= 2 + reduction = True + else: + reduction = False + cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) + reduction_prev = reduction + self.cells += [cell] + C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr + if i == 2*layers//3: + C_to_auxiliary = C_prev + + if auxiliary: + self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + + def forward(self, input): + logits_aux = None + s0 = s1 = self.stem(input) + for i, cell in enumerate(self.cells): + s0, s1 = s1, cell(s0, s1, self.drop_path_prob) + if i == 2*self._layers//3: + if self._auxiliary and self.training: + logits_aux = self.auxiliary_head(s1) + out = self.global_pooling(s1) + logits = self.classifier(out.view(out.size(0),-1)) + # return logits, logits_aux + # used for onnx exporting, mask the second output when fine-tuning + return logits diff --git a/cmd/suggestion/v1alpha3/darts/cnn/operations.py b/cmd/suggestion/v1alpha3/darts/cnn/operations.py new file mode 100644 index 00000000000..b0c62c57511 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/cnn/operations.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn + +OPS = { + 'none' : lambda C, stride, affine: Zero(stride), + 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), + 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), + 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), + 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), + 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), + 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), + 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), + 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), + 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), + nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), + nn.BatchNorm2d(C, affine=affine) + ), +} + +class ReLUConvBN(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super(ReLUConvBN, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), + nn.BatchNorm2d(C_out, affine=affine) + ) + + def forward(self, x): + return self.op(x) + +class DilConv(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): + super(DilConv, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), + nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=affine), + ) + + def forward(self, x): + return self.op(x) + + +class SepConv(nn.Module): + + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): + super(SepConv, self).__init__() + self.op = nn.Sequential( + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), + nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_in, affine=affine), + nn.ReLU(inplace=False), + nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), + nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), + nn.BatchNorm2d(C_out, affine=affine), + ) + + def forward(self, x): + return self.op(x) + + +class Identity(nn.Module): + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +class Zero(nn.Module): + + def __init__(self, stride): + super(Zero, self).__init__() + self.stride = stride + + def forward(self, x): + if self.stride == 1: + return x.mul(0.) + return x[:,:,::self.stride,::self.stride].mul(0.) + + +class FactorizedReduce(nn.Module): + + def __init__(self, C_in, C_out, affine=True): + super(FactorizedReduce, self).__init__() + assert C_out % 2 == 0 + self.relu = nn.ReLU(inplace=False) + self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) + self.bn = nn.BatchNorm2d(C_out, affine=affine) + + def forward(self, x): + x = self.relu(x) + out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) + out = self.bn(out) + return out + diff --git a/cmd/suggestion/v1alpha3/darts/cnn/train.py b/cmd/suggestion/v1alpha3/darts/cnn/train.py new file mode 100644 index 00000000000..fbc43429429 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/cnn/train.py @@ -0,0 +1,59 @@ +import os +import sys +import time +import glob +import numpy as np +import torch +import utils +import logging +import argparse +import torch.nn as nn +import genotypes +import torch.utils +import torchvision.datasets as dset + +from torch.autograd import Variable +from model import NetworkCIFAR as Network + +import faulthandler +import torch.onnx + + +CIFAR_CLASSES = 10 + +class DartsTrain(): + # TODO(AnchovYu): modify API + def __init__(self, init_channels=36, layers=8, auxiliary=False, save='EXP', seed=0, arch='DARTS'): + self.init_channels_ = init_channels + self.layers_ = layers + self.auxiliary_ = auxiliary + self.save_ = 'eval-{}-{}'.format(save, time.strftime("%Y%m%d-%H%M%S")) + self.seed_ = seed + self.arch_ = arch + + utils.create_exp_dir(self.save_, scripts_to_save=glob.glob('*.py')) + log_format = '%(asctime)s %(message)s' + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + fh = logging.FileHandler(os.path.join(self.save_, 'log.txt')) + fh.setFormatter(logging.Formatter(log_format)) + logging.getLogger().addHandler(fh) + + def export_darts(self, onnx_export_path='./model/darts.proto'): + faulthandler.enable() + + np.random.seed(self.seed_) + torch.manual_seed(self.seed_) + + # load trained model, without trained weights.. + genotype = eval("genotypes.%s" % self.arch_) + model = Network(self.init_channels_, CIFAR_CLASSES, self.layers_, self.auxiliary_, genotype) + + # export to onnx + dummy_input = Variable(torch.randn(96, 3, 32, 32)) + torch.onnx.export(model, dummy_input, onnx_export_path, verbose=True) + + + + + diff --git a/cmd/suggestion/v1alpha3/darts/cnn/utils.py b/cmd/suggestion/v1alpha3/darts/cnn/utils.py new file mode 100644 index 00000000000..38d05a20668 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/cnn/utils.py @@ -0,0 +1,121 @@ +import os +import numpy as np +import torch +import shutil +import torchvision.transforms as transforms +from torch.autograd import Variable + + +class AvgrageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def accuracy(output, target, topk=(1,)): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0/batch_size)) + return res + + +class Cutout(object): + def __init__(self, length): + self.length = length + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img *= mask + return img + + +def _data_transforms_cifar10(args): + CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] + CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] + + train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + if args.cutout: + train_transform.transforms.append(Cutout(args.cutout_length)) + + valid_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + return train_transform, valid_transform + + +def count_parameters_in_MB(model): + return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 + + +def save_checkpoint(state, is_best, save): + filename = os.path.join(save, 'checkpoint.pth.tar') + torch.save(state, filename) + if is_best: + best_filename = os.path.join(save, 'model_best.pth.tar') + shutil.copyfile(filename, best_filename) + + +def save(model, model_path): + torch.save(model.state_dict(), model_path) + + +def load(model, model_path): + model.load_state_dict(torch.load(model_path)) + + +def drop_path(x, drop_prob): + if drop_prob > 0.: + keep_prob = 1.-drop_prob + mask = Variable(torch.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) + x.div_(keep_prob) + x.mul_(mask) + return x + + +def create_exp_dir(path, scripts_to_save=None): + if not os.path.exists(path): + os.mkdir(path) + print('Experiment dir : {}'.format(path)) + + if scripts_to_save is not None: + os.mkdir(os.path.join(path, 'scripts')) + for script in scripts_to_save: + dst_file = os.path.join(path, 'scripts', os.path.basename(script)) + shutil.copyfile(script, dst_file) + diff --git a/cmd/suggestion/v1alpha3/darts/model/darts.proto b/cmd/suggestion/v1alpha3/darts/model/darts.proto new file mode 100644 index 00000000000..47f8dac6a3f Binary files /dev/null and b/cmd/suggestion/v1alpha3/darts/model/darts.proto differ diff --git a/cmd/suggestion/v1alpha3/darts/requirements.txt b/cmd/suggestion/v1alpha3/darts/requirements.txt new file mode 100644 index 00000000000..8d2c9d4bda7 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/requirements.txt @@ -0,0 +1,9 @@ +grpcio +duecredit +cloudpickle==0.5.6 +numpy>=1.13.3 +scikit-learn>=0.19.0 +scipy>=0.19.1 +forestci +protobuf +googleapis-common-protos diff --git a/cmd/suggestion/v1alpha3/darts/service.py b/cmd/suggestion/v1alpha3/darts/service.py new file mode 100644 index 00000000000..4ba98746224 --- /dev/null +++ b/cmd/suggestion/v1alpha3/darts/service.py @@ -0,0 +1,53 @@ +# sample suggestion server for darts, +# which should written by user and run in suggestion resource + +import grpc +import time +from concurrent import futures + +from pkg.api.suggestion.v1alpha3.python import oneshot_pb2_grpc +from pkg.api.suggestion.v1alpha3.python import oneshot_pb2 +from pkg.suggestion.v1alpha1.types import DEFAULT_PORT +from cnn import train + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +CHUNK_SIZE = 64 * 1024 # 64 K +LISTEN_ADDRESS = '127.0.0.1:6789' + + +class DartsService(oneshot_pb2_grpc.OneshotSuggestionServicer): + def __init__(self): + self.train_ = train.DartsTrain() + + def GetSuggestions(self, request, context): + # do the training (fake) and export darts + self.train_.export_darts() + + with open("./model/darts.proto", "rb") as f: + content = f.read(CHUNK_SIZE) + yield oneshot_pb2.GetOneshotSuggestionReply(onnx_model=content) + + while content != b"": + # Do stuff with byte. + content = f.read(CHUNK_SIZE) + yield oneshot_pb2.GetOneshotSuggestionReply(onnx_model=content) + + +def serve(): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + oneshot_pb2_grpc.add_OneshotSuggestionServicer_to_server(DartsService(), server) + server.add_insecure_port(LISTEN_ADDRESS) + print("Listening...") + server.start() + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + serve() + + + + diff --git a/manifests/v1alpha3/sample/suggestion.yaml b/manifests/v1alpha3/sample/suggestion.yaml new file mode 100644 index 00000000000..38078c73673 --- /dev/null +++ b/manifests/v1alpha3/sample/suggestion.yaml @@ -0,0 +1,21 @@ +apiVersion: kubeflow.org/v1alpha2 +kind: Suggestion +metadata: + labels: + controller-tools.k8s.io: "1.0" + name: sample +spec: + replicas: 1 + type: HyperParameter + needHistory: false + template: + metadata: + name: suggestion-random + spec: + containers: + - name: suggestion-random + image: katib/suggestion-random:latest + + + + diff --git a/pkg/api/suggestion/v1alpha3/oneshot.pb.go b/pkg/api/suggestion/v1alpha3/oneshot.pb.go new file mode 100644 index 00000000000..b447b0dc6e3 --- /dev/null +++ b/pkg/api/suggestion/v1alpha3/oneshot.pb.go @@ -0,0 +1,181 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: oneshot.proto + +/* +Package api_v1_alpha3 is a generated protocol buffer package. + +It is generated from these files: + oneshot.proto + +It has these top-level messages: + GetOneshotSuggestionRequest + GetOneshotSuggestionReply +*/ +package api_v1_alpha3 + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "google.golang.org/genproto/googleapis/api/annotations" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type GetOneshotSuggestionRequest struct { +} + +func (m *GetOneshotSuggestionRequest) Reset() { *m = GetOneshotSuggestionRequest{} } +func (m *GetOneshotSuggestionRequest) String() string { return proto.CompactTextString(m) } +func (*GetOneshotSuggestionRequest) ProtoMessage() {} +func (*GetOneshotSuggestionRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +type GetOneshotSuggestionReply struct { + OnnxModel []byte `protobuf:"bytes,1,opt,name=onnx_model,json=onnxModel,proto3" json:"onnx_model,omitempty"` +} + +func (m *GetOneshotSuggestionReply) Reset() { *m = GetOneshotSuggestionReply{} } +func (m *GetOneshotSuggestionReply) String() string { return proto.CompactTextString(m) } +func (*GetOneshotSuggestionReply) ProtoMessage() {} +func (*GetOneshotSuggestionReply) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func (m *GetOneshotSuggestionReply) GetOnnxModel() []byte { + if m != nil { + return m.OnnxModel + } + return nil +} + +func init() { + proto.RegisterType((*GetOneshotSuggestionRequest)(nil), "api.v1.alpha3.GetOneshotSuggestionRequest") + proto.RegisterType((*GetOneshotSuggestionReply)(nil), "api.v1.alpha3.GetOneshotSuggestionReply") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for OneshotSuggestion service + +type OneshotSuggestionClient interface { + GetSuggestions(ctx context.Context, in *GetOneshotSuggestionRequest, opts ...grpc.CallOption) (OneshotSuggestion_GetSuggestionsClient, error) +} + +type oneshotSuggestionClient struct { + cc *grpc.ClientConn +} + +func NewOneshotSuggestionClient(cc *grpc.ClientConn) OneshotSuggestionClient { + return &oneshotSuggestionClient{cc} +} + +func (c *oneshotSuggestionClient) GetSuggestions(ctx context.Context, in *GetOneshotSuggestionRequest, opts ...grpc.CallOption) (OneshotSuggestion_GetSuggestionsClient, error) { + stream, err := grpc.NewClientStream(ctx, &_OneshotSuggestion_serviceDesc.Streams[0], c.cc, "/api.v1.alpha3.OneshotSuggestion/GetSuggestions", opts...) + if err != nil { + return nil, err + } + x := &oneshotSuggestionGetSuggestionsClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type OneshotSuggestion_GetSuggestionsClient interface { + Recv() (*GetOneshotSuggestionReply, error) + grpc.ClientStream +} + +type oneshotSuggestionGetSuggestionsClient struct { + grpc.ClientStream +} + +func (x *oneshotSuggestionGetSuggestionsClient) Recv() (*GetOneshotSuggestionReply, error) { + m := new(GetOneshotSuggestionReply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for OneshotSuggestion service + +type OneshotSuggestionServer interface { + GetSuggestions(*GetOneshotSuggestionRequest, OneshotSuggestion_GetSuggestionsServer) error +} + +func RegisterOneshotSuggestionServer(s *grpc.Server, srv OneshotSuggestionServer) { + s.RegisterService(&_OneshotSuggestion_serviceDesc, srv) +} + +func _OneshotSuggestion_GetSuggestions_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(GetOneshotSuggestionRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(OneshotSuggestionServer).GetSuggestions(m, &oneshotSuggestionGetSuggestionsServer{stream}) +} + +type OneshotSuggestion_GetSuggestionsServer interface { + Send(*GetOneshotSuggestionReply) error + grpc.ServerStream +} + +type oneshotSuggestionGetSuggestionsServer struct { + grpc.ServerStream +} + +func (x *oneshotSuggestionGetSuggestionsServer) Send(m *GetOneshotSuggestionReply) error { + return x.ServerStream.SendMsg(m) +} + +var _OneshotSuggestion_serviceDesc = grpc.ServiceDesc{ + ServiceName: "api.v1.alpha3.OneshotSuggestion", + HandlerType: (*OneshotSuggestionServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "GetSuggestions", + Handler: _OneshotSuggestion_GetSuggestions_Handler, + ServerStreams: true, + }, + }, + Metadata: "oneshot.proto", +} + +func init() { proto.RegisterFile("oneshot.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 180 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcd, 0xcf, 0x4b, 0x2d, + 0xce, 0xc8, 0x2f, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x4d, 0x2c, 0xc8, 0xd4, 0x2b, + 0x33, 0xd4, 0x4b, 0xcc, 0x29, 0xc8, 0x48, 0x34, 0x96, 0x92, 0x49, 0xcf, 0xcf, 0x4f, 0xcf, 0x49, + 0xd5, 0x4f, 0x2c, 0xc8, 0xd4, 0x4f, 0xcc, 0xcb, 0xcb, 0x2f, 0x49, 0x2c, 0xc9, 0xcc, 0xcf, 0x2b, + 0x86, 0x28, 0x56, 0x92, 0xe5, 0x92, 0x76, 0x4f, 0x2d, 0xf1, 0x87, 0x18, 0x10, 0x5c, 0x9a, 0x9e, + 0x9e, 0x5a, 0x0c, 0x92, 0x0e, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x51, 0xb2, 0xe2, 0x92, 0xc4, + 0x2e, 0x5d, 0x90, 0x53, 0x29, 0x24, 0xcb, 0xc5, 0x95, 0x9f, 0x97, 0x57, 0x11, 0x9f, 0x9b, 0x9f, + 0x92, 0x9a, 0x23, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0x13, 0xc4, 0x09, 0x12, 0xf1, 0x05, 0x09, 0x18, + 0xd5, 0x72, 0x09, 0x62, 0x68, 0x14, 0xca, 0xe0, 0xe2, 0x73, 0x4f, 0x45, 0x12, 0x28, 0x16, 0xd2, + 0xd2, 0x43, 0x71, 0xaf, 0x1e, 0x1e, 0xe7, 0x48, 0x69, 0x10, 0xa5, 0xb6, 0x20, 0xa7, 0xd2, 0x80, + 0x31, 0x89, 0x0d, 0xec, 0x41, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0x90, 0x33, 0x0a, 0xd1, + 0x1e, 0x01, 0x00, 0x00, +} diff --git a/pkg/api/suggestion/v1alpha3/oneshot.proto b/pkg/api/suggestion/v1alpha3/oneshot.proto new file mode 100644 index 00000000000..5cabfb3efa4 --- /dev/null +++ b/pkg/api/suggestion/v1alpha3/oneshot.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package api.v1.alpha3; + +import "google/api/annotations.proto"; + +service OneshotSuggestion { + rpc GetSuggestions(GetOneshotSuggestionRequest) returns (stream GetOneshotSuggestionReply); +} + +message GetOneshotSuggestionRequest { +} + +message GetOneshotSuggestionReply { + bytes onnx_model = 1; +} + + diff --git a/pkg/api/suggestion/v1alpha3/oneshot.swagger.json b/pkg/api/suggestion/v1alpha3/oneshot.swagger.json new file mode 100644 index 00000000000..07982684329 --- /dev/null +++ b/pkg/api/suggestion/v1alpha3/oneshot.swagger.json @@ -0,0 +1,29 @@ +{ + "swagger": "2.0", + "info": { + "title": "oneshot.proto", + "version": "version not set" + }, + "schemes": [ + "http", + "https" + ], + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "paths": {}, + "definitions": { + "alpha3GetOneshotSuggestionReply": { + "type": "object", + "properties": { + "onnx_model": { + "type": "string", + "format": "byte" + } + } + } + } +} diff --git a/pkg/api/suggestion/v1alpha3/python/oneshot_pb2.py b/pkg/api/suggestion/v1alpha3/python/oneshot_pb2.py new file mode 100644 index 00000000000..2d07136e110 --- /dev/null +++ b/pkg/api/suggestion/v1alpha3/python/oneshot_pb2.py @@ -0,0 +1,245 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: oneshot.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='oneshot.proto', + package='api.v1.alpha3', + syntax='proto3', + serialized_pb=_b('\n\roneshot.proto\x12\rapi.v1.alpha3\x1a\x1cgoogle/api/annotations.proto\"\x1d\n\x1bGetOneshotSuggestionRequest\"/\n\x19GetOneshotSuggestionReply\x12\x12\n\nonnx_model\x18\x01 \x01(\x0c\x32}\n\x11OneshotSuggestion\x12h\n\x0eGetSuggestions\x12*.api.v1.alpha3.GetOneshotSuggestionRequest\x1a(.api.v1.alpha3.GetOneshotSuggestionReply0\x01\x62\x06proto3') + , + dependencies=[google_dot_api_dot_annotations__pb2.DESCRIPTOR,]) + + + + +_GETONESHOTSUGGESTIONREQUEST = _descriptor.Descriptor( + name='GetOneshotSuggestionRequest', + full_name='api.v1.alpha3.GetOneshotSuggestionRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=62, + serialized_end=91, +) + + +_GETONESHOTSUGGESTIONREPLY = _descriptor.Descriptor( + name='GetOneshotSuggestionReply', + full_name='api.v1.alpha3.GetOneshotSuggestionReply', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='onnx_model', full_name='api.v1.alpha3.GetOneshotSuggestionReply.onnx_model', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=93, + serialized_end=140, +) + +DESCRIPTOR.message_types_by_name['GetOneshotSuggestionRequest'] = _GETONESHOTSUGGESTIONREQUEST +DESCRIPTOR.message_types_by_name['GetOneshotSuggestionReply'] = _GETONESHOTSUGGESTIONREPLY +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +GetOneshotSuggestionRequest = _reflection.GeneratedProtocolMessageType('GetOneshotSuggestionRequest', (_message.Message,), dict( + DESCRIPTOR = _GETONESHOTSUGGESTIONREQUEST, + __module__ = 'oneshot_pb2' + # @@protoc_insertion_point(class_scope:api.v1.alpha3.GetOneshotSuggestionRequest) + )) +_sym_db.RegisterMessage(GetOneshotSuggestionRequest) + +GetOneshotSuggestionReply = _reflection.GeneratedProtocolMessageType('GetOneshotSuggestionReply', (_message.Message,), dict( + DESCRIPTOR = _GETONESHOTSUGGESTIONREPLY, + __module__ = 'oneshot_pb2' + # @@protoc_insertion_point(class_scope:api.v1.alpha3.GetOneshotSuggestionReply) + )) +_sym_db.RegisterMessage(GetOneshotSuggestionReply) + + + +_ONESHOTSUGGESTION = _descriptor.ServiceDescriptor( + name='OneshotSuggestion', + full_name='api.v1.alpha3.OneshotSuggestion', + file=DESCRIPTOR, + index=0, + options=None, + serialized_start=142, + serialized_end=267, + methods=[ + _descriptor.MethodDescriptor( + name='GetSuggestions', + full_name='api.v1.alpha3.OneshotSuggestion.GetSuggestions', + index=0, + containing_service=None, + input_type=_GETONESHOTSUGGESTIONREQUEST, + output_type=_GETONESHOTSUGGESTIONREPLY, + options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_ONESHOTSUGGESTION) + +DESCRIPTOR.services_by_name['OneshotSuggestion'] = _ONESHOTSUGGESTION + +try: + # THESE ELEMENTS WILL BE DEPRECATED. + # Please use the generated *_pb2_grpc.py files instead. + import grpc + from grpc.beta import implementations as beta_implementations + from grpc.beta import interfaces as beta_interfaces + from grpc.framework.common import cardinality + from grpc.framework.interfaces.face import utilities as face_utilities + + + class OneshotSuggestionStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GetSuggestions = channel.unary_stream( + '/api.v1.alpha3.OneshotSuggestion/GetSuggestions', + request_serializer=GetOneshotSuggestionRequest.SerializeToString, + response_deserializer=GetOneshotSuggestionReply.FromString, + ) + + + class OneshotSuggestionServicer(object): + # missing associated documentation comment in .proto file + pass + + def GetSuggestions(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + + def add_OneshotSuggestionServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GetSuggestions': grpc.unary_stream_rpc_method_handler( + servicer.GetSuggestions, + request_deserializer=GetOneshotSuggestionRequest.FromString, + response_serializer=GetOneshotSuggestionReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'api.v1.alpha3.OneshotSuggestion', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + class BetaOneshotSuggestionServicer(object): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. This class was generated + only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" + # missing associated documentation comment in .proto file + pass + def GetSuggestions(self, request, context): + # missing associated documentation comment in .proto file + pass + context.code(beta_interfaces.StatusCode.UNIMPLEMENTED) + + + class BetaOneshotSuggestionStub(object): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. This class was generated + only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0.""" + # missing associated documentation comment in .proto file + pass + def GetSuggestions(self, request, timeout, metadata=None, with_call=False, protocol_options=None): + # missing associated documentation comment in .proto file + pass + raise NotImplementedError() + + + def beta_create_OneshotSuggestion_server(servicer, pool=None, pool_size=None, default_timeout=None, maximum_timeout=None): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. This function was + generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" + request_deserializers = { + ('api.v1.alpha3.OneshotSuggestion', 'GetSuggestions'): GetOneshotSuggestionRequest.FromString, + } + response_serializers = { + ('api.v1.alpha3.OneshotSuggestion', 'GetSuggestions'): GetOneshotSuggestionReply.SerializeToString, + } + method_implementations = { + ('api.v1.alpha3.OneshotSuggestion', 'GetSuggestions'): face_utilities.unary_stream_inline(servicer.GetSuggestions), + } + server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout) + return beta_implementations.server(method_implementations, options=server_options) + + + def beta_create_OneshotSuggestion_stub(channel, host=None, metadata_transformer=None, pool=None, pool_size=None): + """The Beta API is deprecated for 0.15.0 and later. + + It is recommended to use the GA API (classes and functions in this + file not marked beta) for all further purposes. This function was + generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0""" + request_serializers = { + ('api.v1.alpha3.OneshotSuggestion', 'GetSuggestions'): GetOneshotSuggestionRequest.SerializeToString, + } + response_deserializers = { + ('api.v1.alpha3.OneshotSuggestion', 'GetSuggestions'): GetOneshotSuggestionReply.FromString, + } + cardinalities = { + 'GetSuggestions': cardinality.Cardinality.UNARY_STREAM, + } + stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size) + return beta_implementations.dynamic_stub(channel, 'api.v1.alpha3.OneshotSuggestion', cardinalities, options=stub_options) +except ImportError: + pass +# @@protoc_insertion_point(module_scope) diff --git a/pkg/api/suggestion/v1alpha3/python/oneshot_pb2_grpc.py b/pkg/api/suggestion/v1alpha3/python/oneshot_pb2_grpc.py new file mode 100644 index 00000000000..9f0ca68538a --- /dev/null +++ b/pkg/api/suggestion/v1alpha3/python/oneshot_pb2_grpc.py @@ -0,0 +1,46 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +import oneshot_pb2 as oneshot__pb2 + + +class OneshotSuggestionStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GetSuggestions = channel.unary_stream( + '/api.v1.alpha3.OneshotSuggestion/GetSuggestions', + request_serializer=oneshot__pb2.GetOneshotSuggestionRequest.SerializeToString, + response_deserializer=oneshot__pb2.GetOneshotSuggestionReply.FromString, + ) + + +class OneshotSuggestionServicer(object): + # missing associated documentation comment in .proto file + pass + + def GetSuggestions(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_OneshotSuggestionServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GetSuggestions': grpc.unary_stream_rpc_method_handler( + servicer.GetSuggestions, + request_deserializer=oneshot__pb2.GetOneshotSuggestionRequest.FromString, + response_serializer=oneshot__pb2.GetOneshotSuggestionReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'api.v1.alpha3.OneshotSuggestion', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) diff --git a/pkg/controller/v1alpha3/recorder/recorder.go b/pkg/controller/v1alpha3/recorder/recorder.go index 4d762746503..44e6995eeec 100644 --- a/pkg/controller/v1alpha3/recorder/recorder.go +++ b/pkg/controller/v1alpha3/recorder/recorder.go @@ -12,6 +12,7 @@ import ( type Recorder interface { record.EventRecorder ReportChange(o runtime.Object, operator, typ string) + ReportError(o runtime.Object, reason, message string) } // GeneralRecorder is the general-purpose implementation of Recorder. @@ -25,6 +26,10 @@ func (g *GeneralRecorder) ReportChange(o runtime.Object, operator, typ string) { g.EventRecorder.Event(o, corev1.EventTypeNormal, "SuccessfulChange", msg) } +func (g *GeneralRecorder) ReportError(o runtime.Object, reason, message string) { + g.EventRecorder.Event(o, corev1.EventTypeWarning, reason, message) +} + // New returns a new Recorder. func New(r record.EventRecorder) Recorder { return &GeneralRecorder{ diff --git a/pkg/controller/v1alpha3/suggestion/clientset/deployment.go b/pkg/controller/v1alpha3/suggestion/clientset/deployment.go new file mode 100644 index 00000000000..aa641ab988e --- /dev/null +++ b/pkg/controller/v1alpha3/suggestion/clientset/deployment.go @@ -0,0 +1,61 @@ +package clientset + +import ( + "context" + + suggestionv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + "github.com/kubeflow/katib/pkg/controller/v1alpha3/recorder" + "github.com/kubeflow/katib/pkg/controller/v1alpha3/util" + appsv1 "k8s.io/api/apps/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + logf "sigs.k8s.io/controller-runtime/pkg/runtime/log" +) + +const ( + clientsetName = "suggestion-deployment-clientset" +) + +var log = logf.Log.WithName(clientsetName) + +type DeploymentClient struct { + client.Client + recorder.Recorder +} + +func New(c client.Client, r recorder.Recorder) DeploymentClient { + return DeploymentClient{ + Client: c, + Recorder: r, + } +} + +func (dc *DeploymentClient) CreateOrUpdateDeployment(suggestion *suggestionv1alpha2.Suggestion, desired *appsv1.Deployment) error { + found := &appsv1.Deployment{} + err := dc.Get(context.TODO(), types.NamespacedName{ + Name: desired.Name, + Namespace: desired.Namespace, + }, found) + if err != nil && errors.IsNotFound(err) { + log.Info("Creating Deployment", "namespace", desired.Namespace, "name", desired.Name) + if err = dc.Create(context.TODO(), desired); err != nil { + return err + } + dc.ReportChange(suggestion, util.FlagCreate, util.TypeDeployment) + } else if err != nil { + return err + } + + // TODO(anchovYu): check and update + // if !reflect.DeepEqual(desired.Spec, found.Spec) { + // // found.Spec = desired.Spec + // log.Info("Updating Deployment", "namespace", desired.Namespace, "name", desired.Name) + // if err = dc.Update(context.TODO(), desired); err != nil { + // return err + // } + // dc.ReportChange(suggestion, util.FlagUpdate, util.TypeDeployment) + // } + desired.Status = found.Status + return nil +} diff --git a/pkg/controller/v1alpha3/suggestion/condition.go b/pkg/controller/v1alpha3/suggestion/condition.go new file mode 100644 index 00000000000..ed6c06ef9fd --- /dev/null +++ b/pkg/controller/v1alpha3/suggestion/condition.go @@ -0,0 +1,55 @@ +package suggestion + +import ( + suggestionv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func createOrUpdateCondition(suggestionStatus *suggestionv1alpha2.SuggestionStatus, + conditionType suggestionv1alpha2.SuggestionConditionType, + conditionStatus corev1.ConditionStatus) { + createOrUpdateConditionWithReason(suggestionStatus, conditionType, conditionStatus, "", "") +} + +func createOrUpdateConditionWithReason(suggestionStatus *suggestionv1alpha2.SuggestionStatus, + conditionType suggestionv1alpha2.SuggestionConditionType, + conditionStatus corev1.ConditionStatus, + reason, message string) { + conditions := suggestionStatus.Conditions + for i, cond := range conditions { + if cond.Type == conditionType { + updateCondition(&suggestionStatus.Conditions[i], conditionStatus, reason, message) + return + } + } + c := createCondition(conditionType, conditionStatus, reason, message) + suggestionStatus.Conditions = append(suggestionStatus.Conditions, c) + +} + +func createCondition(conditionType suggestionv1alpha2.SuggestionConditionType, + status corev1.ConditionStatus, + reason string, + message string) suggestionv1alpha2.SuggestionCondition { + return suggestionv1alpha2.SuggestionCondition{ + Type: conditionType, + Status: status, + Reason: reason, + Message: message, + LastUpdateTime: metav1.Now(), + LastTransitionTime: metav1.Now(), + } +} + +func updateCondition(condition *suggestionv1alpha2.SuggestionCondition, + status corev1.ConditionStatus, reason string, message string) { + if condition.Status != status { + condition.LastTransitionTime = metav1.Now() + } + condition.Status = status + condition.Reason = reason + condition.Message = message + condition.LastUpdateTime = metav1.Now() + +} diff --git a/pkg/controller/v1alpha3/suggestion/event.go b/pkg/controller/v1alpha3/suggestion/event.go new file mode 100644 index 00000000000..a1bf8628afa --- /dev/null +++ b/pkg/controller/v1alpha3/suggestion/event.go @@ -0,0 +1,20 @@ +package suggestion + +import ( + "fmt" + + suggestionv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + corev1 "k8s.io/api/core/v1" +) + +func (r *ReconcileSuggestion) reportChange(s *suggestionv1alpha2.Suggestion, operator, typ string) { + msg := fmt.Sprintf("%s the %s", operator, typ) + log.Info(msg) + r.Recorder.ReportChange(s, operator, typ) +} + +func (r *ReconcileSuggestion) reportError(s *suggestionv1alpha2.Suggestion, err error, reason, msg string) { + log.Error(err, msg) + r.Recorder.ReportError(s, reason, err.Error()) + createOrUpdateConditionWithReason(&s.Status, suggestionv1alpha2.SuggestionDeploymentAvailable, corev1.ConditionFalse, reason, msg) +} diff --git a/pkg/controller/v1alpha3/suggestion/handle.go b/pkg/controller/v1alpha3/suggestion/handle.go new file mode 100644 index 00000000000..93a79e965d7 --- /dev/null +++ b/pkg/controller/v1alpha3/suggestion/handle.go @@ -0,0 +1,52 @@ +package suggestion + +import ( + suggestionsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + "github.com/kubeflow/katib/pkg/controller/v1alpha3/util" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +func (r *ReconcileSuggestion) handle(suggestion *suggestionsv1alpha2.Suggestion) (result reconcile.Result, err error) { + oldStatus := suggestion.Status.DeepCopy() + result = reconcile.Result{} + logger := log.WithName(types.NamespacedName{ + Namespace: suggestion.Namespace, + Name: suggestion.Name, + }.String()) + + defer func() { + err = r.updateStatusIfChanged(oldStatus, suggestion) + }() + + desired, err := getDesiredDeployment(suggestion) + if err != nil { + r.reportError(suggestion, err, util.FailReason, "Failed to get desired deployment") + return result, err + } + logger.V(0).Info("OK: Get desired deployment of suggestion") + + if err = controllerutil.SetControllerReference(suggestion, desired, r.scheme); err != nil { + r.reportError(suggestion, err, util.FailReason, "Failed to set controller reference") + return result, err + } + logger.V(0).Info("OK: Set controller reference between deployment and suggestion") + + // if suggestion spec changes, create or update deployment + // desired deployment status is updated + if err = r.CreateOrUpdateDeployment(suggestion, desired); err != nil { + r.reportError(suggestion, err, util.FailReason, "Failed to create or update deployment of suggestion") + return result, err + } + logger.V(0).Info("OK: Create or update deployment") + + // if deployment changes, sync status of suggestion + if err = r.syncStatus(&desired.Status, suggestion); err != nil { + r.reportError(suggestion, err, util.FailReason, "Failed to sync status of suggestion") + return result, err + } + logger.V(0).Info("OK: Sync status of suggestion") + + return result, err +} diff --git a/pkg/controller/v1alpha3/suggestion/status.go b/pkg/controller/v1alpha3/suggestion/status.go new file mode 100644 index 00000000000..92aa35ed007 --- /dev/null +++ b/pkg/controller/v1alpha3/suggestion/status.go @@ -0,0 +1,33 @@ +package suggestion + +import ( + "context" + + appsv1 "k8s.io/api/apps/v1" + "k8s.io/apimachinery/pkg/api/equality" + + suggestionsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + suggestionv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" +) + +func (r *ReconcileSuggestion) syncStatus(deployStatus *appsv1.DeploymentStatus, suggestion *suggestionv1alpha2.Suggestion) error { + for _, cond := range deployStatus.Conditions { + if cond.Type == appsv1.DeploymentAvailable { + createOrUpdateCondition(&suggestion.Status, suggestionv1alpha2.SuggestionDeploymentAvailable, cond.Status) + } else if cond.Type == appsv1.DeploymentProgressing { + createOrUpdateCondition(&suggestion.Status, suggestionv1alpha2.SuggestionDeploymentProgressing, cond.Status) + } else if cond.Type == appsv1.DeploymentReplicaFailure { + createOrUpdateCondition(&suggestion.Status, suggestionv1alpha2.SuggestionDeploymentReplicaFailure, cond.Status) + } + } + return nil +} + +func (r *ReconcileSuggestion) updateStatusIfChanged(oldStatus *suggestionv1alpha2.SuggestionStatus, + suggestion *suggestionsv1alpha2.Suggestion) error { + if !equality.Semantic.DeepEqual(oldStatus, &suggestion.Status) { + return r.Status().Update(context.TODO(), suggestion) + } + return nil + +} diff --git a/pkg/controller/v1alpha3/suggestion/suggestion_controller.go b/pkg/controller/v1alpha3/suggestion/suggestion_controller.go index 813de264ded..1d65f40a46d 100644 --- a/pkg/controller/v1alpha3/suggestion/suggestion_controller.go +++ b/pkg/controller/v1alpha3/suggestion/suggestion_controller.go @@ -18,17 +18,12 @@ package suggestion import ( "context" - "reflect" appsv1 "k8s.io/api/apps/v1" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -36,14 +31,15 @@ import ( "sigs.k8s.io/controller-runtime/pkg/source" suggestionsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + "github.com/kubeflow/katib/pkg/controller/v1alpha3/recorder" + "github.com/kubeflow/katib/pkg/controller/v1alpha3/suggestion/clientset" ) -var log = logf.Log.WithName("controller") +const ( + controllerName = "suggestion-controller" +) -/** -* USER ACTION REQUIRED: This is a scaffold file intended for the user to modify with their own Controller -* business logic. Delete these comments after modifying this file.* - */ +var log = logf.Log.WithName(controllerName) // Add creates a new Suggestion Controller and adds it to the Manager with default RBAC. The Manager will set fields on the Controller // and Start it when the Manager is Started. @@ -53,13 +49,19 @@ func Add(mgr manager.Manager) error { // newReconciler returns a new reconcile.Reconciler func newReconciler(mgr manager.Manager) reconcile.Reconciler { - return &ReconcileSuggestion{Client: mgr.GetClient(), scheme: mgr.GetScheme()} + r := &ReconcileSuggestion{ + Client: mgr.GetClient(), + Recorder: recorder.New(mgr.GetRecorder(controllerName)), + scheme: mgr.GetScheme(), + } + r.DeploymentClient = clientset.New(r.Client, r.Recorder) + return r } // add adds a new Controller to mgr with r as the reconcile.Reconciler func add(mgr manager.Manager, r reconcile.Reconciler) error { // Create a new controller - c, err := controller.New("suggestion-controller", mgr, controller.Options{Reconciler: r}) + c, err := controller.New(controllerName, mgr, controller.Options{Reconciler: r}) if err != nil { return err } @@ -70,8 +72,7 @@ func add(mgr manager.Manager, r reconcile.Reconciler) error { return err } - // TODO(user): Modify this to be the types you create - // Uncomment watch a Deployment created by Suggestion - change this for objects you create + // Watch for changes to Deployment created by Suggestion err = c.Watch(&source.Kind{Type: &appsv1.Deployment{}}, &handler.EnqueueRequestForOwner{ IsController: true, OwnerType: &suggestionsv1alpha2.Suggestion{}, @@ -88,13 +89,13 @@ var _ reconcile.Reconciler = &ReconcileSuggestion{} // ReconcileSuggestion reconciles a Suggestion object type ReconcileSuggestion struct { client.Client + recorder.Recorder scheme *runtime.Scheme + clientset.DeploymentClient } // Reconcile reads that state of the cluster for a Suggestion object and makes changes based on the state read // and what is in the Suggestion.Spec -// TODO(user): Modify this Reconcile function to implement your Controller logic. The scaffolding writes -// a Deployment as an example // Automatically generate RBAC rules to allow the Controller to read and write Deployments // +kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=apps,resources=deployments/status,verbs=get;update;patch @@ -114,55 +115,5 @@ func (r *ReconcileSuggestion) Reconcile(request reconcile.Request) (reconcile.Re return reconcile.Result{}, err } - // TODO(user): Change this to be the object type created by your controller - // Define the desired Deployment object - deploy := &appsv1.Deployment{ - ObjectMeta: metav1.ObjectMeta{ - Name: instance.Name + "-deployment", - Namespace: instance.Namespace, - }, - Spec: appsv1.DeploymentSpec{ - Selector: &metav1.LabelSelector{ - MatchLabels: map[string]string{"deployment": instance.Name + "-deployment"}, - }, - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{"deployment": instance.Name + "-deployment"}}, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "nginx", - Image: "nginx", - }, - }, - }, - }, - }, - } - if err := controllerutil.SetControllerReference(instance, deploy, r.scheme); err != nil { - return reconcile.Result{}, err - } - - // TODO(user): Change this for the object type created by your controller - // Check if the Deployment already exists - found := &appsv1.Deployment{} - err = r.Get(context.TODO(), types.NamespacedName{Name: deploy.Name, Namespace: deploy.Namespace}, found) - if err != nil && errors.IsNotFound(err) { - log.Info("Creating Deployment", "namespace", deploy.Namespace, "name", deploy.Name) - err = r.Create(context.TODO(), deploy) - return reconcile.Result{}, err - } else if err != nil { - return reconcile.Result{}, err - } - - // TODO(user): Change this for the object type created by your controller - // Update the found object and write the result back if there are any changes - if !reflect.DeepEqual(deploy.Spec, found.Spec) { - found.Spec = deploy.Spec - log.Info("Updating Deployment", "namespace", deploy.Namespace, "name", deploy.Name) - err = r.Update(context.TODO(), found) - if err != nil { - return reconcile.Result{}, err - } - } - return reconcile.Result{}, nil + return r.handle(instance) } diff --git a/pkg/controller/v1alpha3/suggestion/util.go b/pkg/controller/v1alpha3/suggestion/util.go new file mode 100644 index 00000000000..d8651695542 --- /dev/null +++ b/pkg/controller/v1alpha3/suggestion/util.go @@ -0,0 +1,30 @@ +package suggestion + +import ( + suggestionsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/suggestion/v1alpha2" + appsv1 "k8s.io/api/apps/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func getDesiredDeployment(instance *suggestionsv1alpha2.Suggestion) (*appsv1.Deployment, error) { + deploy := &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: instance.Name + "-deployment", + Namespace: instance.Namespace, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: instance.Spec.Replicas, + // do we need that? + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{"deployment": instance.Name + "-deployment"}, + }, + Template: instance.Spec.Template, + }, + } + if deploy.Spec.Template.ObjectMeta.Labels == nil { + deploy.Spec.Template.ObjectMeta.Labels = make(map[string]string) + } + deploy.Spec.Template.ObjectMeta.Labels["deployment"] = instance.Name + "-deployment" + + return deploy, nil +} diff --git a/pkg/controller/v1alpha3/util/util.go b/pkg/controller/v1alpha3/util/util.go index 84ab10d0711..246bf8ce196 100644 --- a/pkg/controller/v1alpha3/util/util.go +++ b/pkg/controller/v1alpha3/util/util.go @@ -8,6 +8,7 @@ const ( TypeTrial = "Trial" TypeTFJob = "TFJob" TypePyTorchJob = "PyTorchJob" + TypeDeployment = "Deployment" LabelExperiment = "experiment-name" LabelTrial = "trial-name" diff --git a/pkg/ui/ui.go b/pkg/ui/ui.go index 239ec288990..acfe83c7abd 100644 --- a/pkg/ui/ui.go +++ b/pkg/ui/ui.go @@ -375,7 +375,7 @@ func (k *KatibUIHandler) MetricsCollectorTemplate(w http.ResponseWriter, r *http MetricsCollectorTemplate map[string]string } mtv := MetricsCollectorTemplateView{ - IDList: &IDList{}, + IDList: &IDList{}, MetricsCollectorTemplate: mt, } if err != nil {