diff --git a/Makefile b/Makefile index 8968c04..6ad0e96 100644 --- a/Makefile +++ b/Makefile @@ -72,3 +72,31 @@ pypi: run_locally: @uvicorn api.fast:app --reload --host 0.0.0.0 + +# ---------------------------------- +# GCLOUD TRAINING +# ---------------------------------- + + +BUCKET_NAME=taxifare_bucket_fast-drake-318911 +BUCKET_TRAINING_FOLDER = 'trainings' +REGION=europe-west1 +PYTHON_VERSION=3.7 +RUNTIME_VERSION=1.15 + +PACKAGE_NAME=cookit +FILENAME=trainer + +JOB_NAME=cookit_training_$(shell date +'%Y%m%d_%H%M%S') + + + +gcp_submit_training: + gcloud ai-platform jobs submit training ${JOB_NAME} \ + --job-dir gs://${BUCKET_NAME}/${BUCKET_TRAINING_FOLDER} \ + --package-path ${PACKAGE_NAME} \ + --module-name ${PACKAGE_NAME}.${FILENAME} \ + --python-version=${PYTHON_VERSION} \ + --runtime-version=${RUNTIME_VERSION} \ + --region ${REGION} \ + --stream-logs diff --git a/api/fast.py b/api/fast.py index 266f931..c092bdf 100644 --- a/api/fast.py +++ b/api/fast.py @@ -33,7 +33,7 @@ def index(): # response = requests.post('http://localhost:8000/predict', files=files, data=payload) # --> make sure to close the opened file again or del the files dictionary in this example case! @app.post("/predict") -async def predict(image: UploadFile = File(...), threshold: float = Form(0.5)): +async def predict(image: UploadFile = File(...), threshold: float = Form(0.1)): """ Executes prediction based on sent file and threshold image: a file in multipart/form-data format threshold: a float value (default: 0.25) to filter predicted classes diff --git a/class_labels.pkl b/class_labels.pkl new file mode 100644 index 0000000..a13aa62 --- /dev/null +++ b/class_labels.pkl @@ -0,0 +1,128 @@ +(dp0 +I1 +Vlabel +p1 +sI2 +VPumpkin +p2 +sI3 +VIce cream +p3 +sI4 +VSalad +p4 +sI5 +VBread +p5 +sI6 +VCoconut +p6 +sI7 +VGrape +p7 +sI8 +VMushroom +p8 +sI9 +VHoneycomb +p9 +sI10 +VFish +p10 +sI11 +VOyster +p11 +sI12 +VPomegranate +p12 +sI13 +VRadish +p13 +sI14 +VWatermelon +p14 +sI15 +VPasta +p15 +sI16 +VCabbage +p16 +sI17 +VStrawberry +p17 +sI18 +VApple +p18 +sI19 +VOrange +p19 +sI20 +VPotato +p20 +sI21 +VBanana +p21 +sI22 +VPear +p22 +sI23 +VShellfish +p23 +sI24 +VTomato +p24 +sI25 +VCheese +p25 +sI26 +VCarrot +p26 +sI27 +VShrimp +p27 +sI28 +VLemon +p28 +sI29 +VArtichoke +p29 +sI30 +VBroccoli +p30 +sI31 +VBell pepper +p31 +sI32 +VPineapple +p32 +sI33 +VLobster +p33 +sI34 +VMilk +p34 +sI35 +VMango +p35 +sI36 +VGrapefruit +p36 +sI37 +VCantaloupe +p37 +sI38 +VPeach +p38 +sI39 +VCream +p39 +sI40 +VZucchini +p40 +sI41 +VCucumber +p41 +sI42 +VWinter melon +p42 +s. \ No newline at end of file diff --git a/cookit/data.py b/cookit/data.py index a52e8c9..fe61aa5 100644 --- a/cookit/data.py +++ b/cookit/data.py @@ -1,8 +1,142 @@ +import csv +import json +import math +import pandas as pd +from google.cloud import storage +from cookit.utils import OIv4_INGREDIENTS_ONLY, TEST_FOOD_CLASSES, OIv4_MIN_SET +from cookit.params import BUCKET_NAME -def get_data(): - pass + +def download_file_from_bucket(path='labels.json'): + client = storage.Client() + bucket = client.bucket(BUCKET_NAME) + blob = bucket.blob(path) + blob.download_to_filename(path) + + +def upload_file_to_bucket(path): + client = storage.Client() + bucket = client.bucket(BUCKET_NAME) + blob = bucket.blob(path) + blob.upload_from_filename(path) + + +def get_oi_dataset_df(path='oi_food.csv', nrows=1000): + """method to get the training data (or a portion of it) from google cloud bucket""" + df = pd.read_csv(f"gs://{BUCKET_NAME}/{path}", nrows=nrows) + return df + + +# Should we perhaps keep non-food-related labels in the test set? +def convert_oi_metadata(labelfile_path, baseurl=f'gs://{BUCKET_NAME}/data', csv_path='tf_training.csv', + test_split=0.2, val_split=0.1, train_classes=OIv4_INGREDIENTS_ONLY): + """ Converts a json file in format fiftyone.types.FiftyOneImageDetectionDataset + to a format as it is expected by the tflite_model_maker.object_detector + + labelfile_path: path to json file [String] + baseurl: url to files referenced in the json file [String] + csv_path: path to output csv file [String] + test_split: share of images for test [float] + val_split: share of images for validation [float] + + Returns dict containing basic dataset information + """ + with open(labelfile_path) as json_file: + oi_data = json.load(json_file) + + classes = oi_data['classes'] + food_classes = [] + + total_images = len(oi_data['labels']) + val_count = math.floor(total_images* val_split) + test_count = math.floor(total_images * test_split) + + bbox_count, uuid_count = 0, 0 + test_bboxes, val_bboxes, train_bboxes = 0, 0, 0 + + with open(csv_path, "w") as csv_file: + writer = csv.writer(csv_file, delimiter=',', dialect='excel') + line = ['set', 'path', 'label', 'x_min', 'y_min', 'x_max', 'y_min', 'x_max', 'y_max', 'x_min', 'y_max'] + writer.writerow(line) + for uuid, labels in oi_data['labels'].items(): + uuid_count += 1 + for label_items in labels: + label = classes[label_items['label']] + if label in train_classes: + food_classes.append(label) + bbox_count += 1 + bb = label_items['bounding_box'] + if uuid_count <= val_count: + split = "TEST" + val_bboxes += 1 + elif uuid_count >= val_count and uuid_count < int(test_count + val_count): + split = "VALIDATION" + test_bboxes += 1 + elif uuid_count >= int(test_count + val_count): + split = "TRAINING" + train_bboxes += 1 + line = [split, f"{baseurl}/{uuid}.jpg", label, bb[0], bb[1], "", "", + min(bb[0] + bb[2], 1.0), min(bb[1] + bb[3], 1.0), "", ""] # min() should actually not be neccesarry, but never say never.... + writer.writerow(line) + + unique_food_classes = list(sorted(set(food_classes))) + print(f"Found {len(unique_food_classes)} food-related classes of total {len(classes)} classes.") + print(f"Found {bbox_count} bounding boxes. Train/Val/Test split: {train_bboxes} / {val_bboxes} / {test_bboxes}") + print(f"Total number of images in dataset: {total_images}") + + return { + 'food_classes': unique_food_classes, + 'bbox_count': bbox_count, + 'train_bbox_count': train_bboxes, + 'val_bbox_count': val_bboxes, + 'test_bbox_count': test_bboxes, + 'images_count': total_images + } + + +def get_random_slice(csv_path=f'gs://{BUCKET_NAME}/oi_food.csv', + out_csv_name='oi_food_sample.csv', + size=1000, + return_df=True): + df = pd.read_csv(csv_path) + sample = df.sample(size) + sample.to_csv(out_csv_name, index=False) + upload_file_to_bucket(out_csv_name) + print(f"Uploaded CSV file containing {size} samples to gs://{BUCKET_NAME}/{out_csv_name}") + if return_df: + return sample + +def get_random_slice_balanced(csv_file='oi_food_minimal.csv', + out_csv_file='oi_food_minimal_balanced.csv', + label_size=50, + gcloud_upload=False): + df = get_oi_dataset_df(csv_file, nrows=100_000) + out_df = pd.DataFrame(columns=df.columns) + for label in df.label.unique(): + sample = df[(df.set == 'TRAINING') & (df.label == label)].sample(int(label_size*0.7), replace=True) + out_df = pd.concat([out_df, sample]) + sample = df[(df.set == 'VALIDATION') & (df.label == label)].sample(int(label_size*0.2), replace=True) + out_df = pd.concat([out_df, sample]) + sample = df[(df.set == 'TEST') & (df.label == label)].sample(int(label_size*0.1), replace=True) + out_df = pd.concat([out_df, sample]) + out_df.to_csv(out_csv_file, index=False) + if gcloud_upload: + upload_file_to_bucket(out_csv_file) + return out_df + + + + +def create_dataset(json_path='labels.json', csv_path='oi_food.csv', classes=OIv4_MIN_SET): + download_file_from_bucket(json_path) + ds = convert_oi_metadata(json_path, csv_path=csv_path, train_classes=classes) + upload_file_to_bucket(csv_path) + print(f"Uploaded new dataset to gs://{BUCKET_NAME}/{csv_path}") if __name__ == '__main__': - print("Nothing to do here...") + create_dataset() + print("Print 5 random samples") + print(get_oi_dataset_df('oi_food.csv', 5)) + #sample = get_random_slice('oi_food.csv', 'labels_slice.csv', 1000) diff --git a/cookit/params.py b/cookit/params.py index e69de29..fb86ec1 100644 --- a/cookit/params.py +++ b/cookit/params.py @@ -0,0 +1,2 @@ +PROJECT_NAME = 'fast-drake-318911' +BUCKET_NAME = 'taxifare_bucket_fast-drake-318911' diff --git a/cookit/predict.py b/cookit/predict.py index 300a5a7..e4bbdaa 100644 --- a/cookit/predict.py +++ b/cookit/predict.py @@ -2,7 +2,6 @@ import tensorflow as tf import tensorflow_hub as hub from PIL import Image, ImageOps -from cookit.data import get_data from cookit.utils import OIv4_FOOD_CLASSES os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' diff --git a/cookit/predict_lite.py b/cookit/predict_lite.py index e32aa75..c93ed95 100644 --- a/cookit/predict_lite.py +++ b/cookit/predict_lite.py @@ -1,8 +1,7 @@ import tensorflow as tf import numpy as np import os - -from cookit.data import get_data +import pickle from PIL import Image from cookit.utils import OIv4_FOOD_CLASSES @@ -15,15 +14,12 @@ def __init__(self): """ A basic call for predictions. """ - self.model = self._get_model() + self.model = self._get_model('oi_food_balanced_400_lite4_ll.tflite') + + with open('class_labels.pkl', 'rb') as f: + self.classes_map = pickle.load(f) + #self.classes = ['Baked Goods', 'Salad', 'Cheese', 'Seafood', 'Tomato'] - # NOTE: The order of this list hardcoded here, and needs to be changed when re-training the model! - # When exporting the model in tflite format, the model_spec is lost, so we cannot do it like that: - # classes = ['???'] * model.model_spec.config.num_classes - # label_map = model.model_spec.config.label_map - # for label_id, label_name in label_map.as_dict().items(): - # classes[label_id-1] = label_name - self.classes = ['Baked Goods', 'Salad', 'Cheese', 'Seafood', 'Tomato'] def _get_model(self, model_path='model.tflite'): """ Load the model from a local path """ @@ -59,10 +55,10 @@ def detect_objects(self, image): self.model.invoke() # Get all outputs from the model - boxes = self.get_output_tensor(0) - classes = self.get_output_tensor(1) - scores = self.get_output_tensor(2) - count = int(self.get_output_tensor(3)) + boxes = self.get_output_tensor(1) + classes = self.get_output_tensor(3) + scores = self.get_output_tensor(0) + count = int(self.get_output_tensor(2)) results = [] for i in range(count): @@ -86,7 +82,7 @@ def run_detection(self, image_path, threshold=0.5): results = self.detect_objects(preprocessed_image) return results - def predict(self, image_path, threshold=0.5): + def predict(self, image_path, threshold=0.1): detection_result_image = self.run_detection(image_path, threshold) print(f"Received file for prediction: {image_path}") @@ -96,7 +92,7 @@ def predict(self, image_path, threshold=0.5): for result in detection_result_image: if result['score'] > threshold: - res_class = self.classes[result['class_id']] + res_class = self.classes_map[result['class_id']] # do not return redundant ingredients if res_class not in ingredients: ingredients.append(res_class) diff --git a/cookit/trainer.py b/cookit/trainer.py index 196d3e4..1ca995d 100644 --- a/cookit/trainer.py +++ b/cookit/trainer.py @@ -1,57 +1,108 @@ -import joblib +import os +import pickle +import argparse +import tensorflow as tf +from tflite_model_maker.config import ExportFormat +from tflite_model_maker import model_spec +from tflite_model_maker import object_detector +import pycocotools # this is needed for model.evaluate, even if not explicitley used in code! from termcolor import colored -from sklearn.compose import ColumnTransformer -from sklearn.linear_model import LinearRegression -from sklearn.model_selection import train_test_split -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import OneHotEncoder, StandardScaler +from cookit.data import upload_file_to_bucket +from cookit.params import BUCKET_NAME + +# make tensorflow less verbose +tf.get_logger().setLevel('ERROR') +from absl import logging + +logging.set_verbosity(logging.ERROR) +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + -from cookit.data import get_data class Trainer(object): - def __init__(self, X, y): + def __init__(self, spec='efficientdet_lite0'): """ - X: pandas DataFrame - y: pandas Series """ - self.pipeline = None - self.X = X - self.y = y + self.model = None + self._spec = model_spec.get(spec) + self._cache_prefix = 'cookit_trainer' + + def load_data(self, csv_path=f'gs://{BUCKET_NAME}/oi_food_converted_sample.csv', + force_download=False): + cache_dir = csv_path.split('/')[-1].rstrip('.csv') + if not os.path.isdir(cache_dir) or force_download == True: + print(f"Downloading images from {csv_path}") - def set_pipeline(self): - """defines the pipeline as a class attribute""" - pass + data = object_detector.DataLoader.from_csv(csv_path, + cache_dir=cache_dir, + cache_prefix_filename=self._cache_prefix) + self.train_data = data[0] + self.val_data = data[1] + self.test_data = data[2] + self.label_map = self.train_data.label_map + else: + self.load_data_from_cache(cache_dir) - def run(self): + def load_data_from_cache(self, cache_dir): + print(f"Load images from {cache_dir}") + cache_prefix = f"{cache_dir}/train_{self._cache_prefix}" + self.train_data = object_detector.DataLoader.from_cache(cache_prefix) + cache_prefix = f"{cache_dir}/test_{self._cache_prefix}" + self.test_data = object_detector.DataLoader.from_cache(cache_prefix) + cache_prefix = f"{cache_dir}/val_{self._cache_prefix}" + self.val_data = object_detector.DataLoader.from_cache(cache_prefix) + self.label_map = self.train_data.label_map + + def run(self, epochs=50, batch_size=32, train_whole_model=True): """fits model""" - pass + self.model = object_detector.create(self.train_data, + epochs=epochs, + model_spec=self._spec, + batch_size=batch_size, + train_whole_model=train_whole_model, + validation_data=self.val_data) - def evaluate(self, X_test, y_test): + def evaluate(self): """evaluates the pipeline on df_test""" + eval_dict = self.model.evaluate(self.test_data) + print(f"Model evaluation: {eval_dict}") + return eval_dict - def save_model_locally(self): + def save_model_locally(self, model_name='model.tflite', label_filename='class_labels'): """Save the model into a .joblib format""" + # see https://www.tensorflow.org/lite/tutorials/model_maker_object_detection#export_to_different_formats + # model.export(export_dir='.', export_format=[ExportFormat.SAVED_MODEL, ExportFormat.LABEL]) + self.model.export(export_dir='.', + tflite_filename=model_name, + label_filename=label_filename, + #saved_model_filename=model_name, + #export_format=None, + ) + pickle_name = label_filename.split('.')[0] + '.pkl' + with open(f"{pickle_name}", 'wb') as f: + pickle.dump(self.label_map, f, 0) - joblib.dump(self.pipeline, 'model.joblib') - print(colored("model.joblib saved locally", "green")) - + print(colored(f"Saved trained model to {model_name}", "green")) + print(colored(f"Saved class labels to {pickle_name}", "green")) + return model_name, pickle_name if __name__ == "__main__": - # Get and clean data - df = get_data(nrows=1000) - - y = df["classes"] - X = df.drop("classes", axis=1) - - X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) + parser = argparse.ArgumentParser() + parser.add_argument("csv", help="Path to dataset CSV file") + parser.add_argument("model_name", help="Filename for trained model") + parser.add_argument("-s", "--spec", default='efficientdet_lite0', help="TF Lite Specification") + parser.add_argument("-e", "--epochs", type=int, default=50, help="Nr of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size") + parser.add_argument("-w", "--train_whole_model", type=bool, default=True, help="Train whole model or only last layers") - # Train and save model, locally and - trainer = Trainer(X_train, y_train) - trainer.set_experiment_name('xp2') - trainer.run() - score = trainer.evaluate(X_test, y_test) - print(f"Score of model : {score}") + args = parser.parse_args() - trainer.save_model_locally() + trainer = Trainer(args.spec) + trainer.load_data(args.csv) + trainer.run(args.epochs, args.batch_size, args.train_whole_model) + eval_dict = trainer.evaluate() + model_name, pickle_name = trainer.save_model_locally(args.model_name + '.tflite') + upload_file_to_bucket(model_name) + upload_file_to_bucket(pickle_name) diff --git a/cookit/utils.py b/cookit/utils.py index 12232d8..dd1004a 100644 --- a/cookit/utils.py +++ b/cookit/utils.py @@ -1,17 +1,66 @@ OIv4_FOOD_CLASSES = [ 'Apple', 'Artichoke', 'Asparagus', 'Bagel', 'Baked goods', 'Banana', 'Beer', 'Bell pepper', 'Bread', 'Broccoli', 'Burrito', 'Cabbage', 'Cake', - 'Candy', 'Cantaloupe', 'Carrot', 'Cheese', 'Chicken', 'Chopsticks', + 'Candy', 'Cantaloupe', 'Carrot', 'Cheese', 'Chicken', 'Coconut', 'Coffee', 'Cookie', 'Crab', 'Cream', 'Croissant', 'Cucumber', 'Dairy', 'Dessert', 'Doughnut', 'Duck', 'Egg', 'Fast food', 'Fish', 'Food', - 'Food processor', 'French fries', 'Fruit', 'Goose', 'Grape', 'Grapefruit', + 'French fries', 'Fruit', 'Goose', 'Grape', 'Grapefruit', 'Guacamole', 'Hamburger', 'Honeycomb', 'Hot dog', 'Ice cream', 'Jellyfish', 'Juice', 'Lemon', 'Lobster', 'Mammal', 'Mango', 'Maple', 'Milk', 'Muffin', 'Mushroom', 'Orange', 'Oyster', 'Pancake', 'Pasta', 'Pastry', 'Peach', 'Pear', 'Pineapple', 'Pizza', 'Plant', 'Pomegranate', 'Popcorn', 'Potato', 'Pretzel', 'Pumpkin', 'Radish', 'Salad', 'Salt and pepper shakers', 'Sandwich', 'Seafood', 'Shellfish', 'Shrimp', 'Snail', 'Strawberry', - 'Submarine sandwich', 'Sushi', 'Taco', 'Tart', 'Tea', 'Tomato', 'Towel', + 'Submarine sandwich', 'Sushi', 'Taco', 'Tart', 'Tea', 'Tomato', 'Turkey', 'Vegetable', 'Waffle', 'Watermelon', 'Whisk', 'Wine', 'Winter melon', 'Zucchini' ] + +OIv4_INGREDIENTS_ONLY = [ + 'Apple', 'Artichoke', 'Asparagus', 'Banana', + 'Beer', 'Bell pepper', 'Bread', 'Broccoli', 'Burrito', 'Cabbage', + 'Cantaloupe', 'Carrot', 'Cheese', 'Chicken', 'Coconut', + 'Crab', 'Cream', 'Croissant', 'Cucumber', 'Dairy', + 'Duck', 'Egg', 'Fish', 'French fries', + 'Grape', 'Grapefruit', 'Guacamole', 'Hamburger', + 'Honeycomb', 'Ice cream', 'Jellyfish', 'Juice', 'Lemon', + 'Lobster', 'Mango', 'Maple', 'Milk', 'Muffin', 'Mushroom', + 'Orange', 'Oyster', 'Pancake', 'Pasta', 'Pastry', 'Peach', 'Pear', + 'Pineapple', 'Pizza', 'Pomegranate', 'Popcorn', 'Potato', + 'Pretzel', 'Pumpkin', 'Radish', 'Salad', + 'Sandwich', 'Seafood', 'Shellfish', 'Shrimp', 'Strawberry', + 'Submarine sandwich', 'Sushi', 'Taco', 'Tea', 'Tomato', 'Turkey', + 'Waffle', 'Watermelon','Wine', 'Winter melon', 'Zucchini' +] + +OIv4_MIN_SET = [ + 'Apple', 'Artichoke', 'Asparagus', 'Banana', 'Bell pepper', + 'Bread', 'Broccoli', 'Cabbage', 'Cantaloupe', 'Carrot', + 'Cheese', 'Coconut', 'Cream', 'Cucumber', 'Egg', 'Fish', 'Grape', + 'Grapefruit', 'Honeycomb', 'Ice cream', 'Lemon', 'Lobster', 'Mango', + 'Milk', 'Mushroom', 'Orange', 'Oyster', 'Pasta', 'Peach', 'Pear', + 'Pineapple', 'Pomegranate', 'Potato', + 'Pumpkin', 'Radish', 'Salad', 'Shellfish', 'Shrimp', + 'Strawberry', 'Tomato', 'Watermelon', 'Winter melon', 'Zucchini' +] + +TEST_FOOD_CLASSES = [ + 'Apple', 'Apricot', 'Artichoke', 'Asparagus', 'Aubergine', 'Avocado', + 'Banana', 'Basil', 'Basilicum', 'Bean', 'Beans', 'Beet', 'Beetroot', + 'Bell pepper', 'Bread', 'Broccoli', 'Brussel sprout', 'Brussels sprout', + 'Brussels sprouts', 'Butternut', 'Cabbage', 'Capsicum', 'Carrot', + 'Cauliflower', 'Celery', 'Cheese', 'Chickpea', 'Chili', 'Chilli', 'Chive', + 'Coconut milk', 'Coriander', 'Corn', 'Cream', 'Cucumber', 'Date', + 'Dragon fruit', 'Dried bean', 'Dried fruit', 'Egg', 'Emmental', 'Endive', + 'Fenel', 'Fennel', 'Fetta', 'Fig', 'Fish', 'Flour', 'Garlic', 'Ginger', + 'Grained cheese', 'Grape', 'Great beans', 'Green bean', 'Green beans', + 'Juice', 'Kaki', 'Kiwi', 'Leek', 'Lemon', 'Lentil', 'Lettuce', 'Litchi', + 'Mango', 'Meat', 'Melon', 'Milk', 'Mozzarella', 'Mushroom', 'Nut', 'Oat', + 'Oil', 'Olive', 'Olive oil', 'Onion', 'Orange', 'Papaya', 'Paprika', + 'Parmesan', 'Parsley', 'Pasta', 'Peanut', 'Peanut butter', 'Pear', 'Peas', + 'Pepper', 'Persley', 'Pineapple', 'Plum', 'Potato', 'Potatoe', 'Radish', + 'Rice', 'Romanesco', 'Runner bean', 'Salad', 'Salmon', 'Salt', 'Soya', + 'Spinach', 'Spring onion', 'Strawberry', 'Sunflower seeds', + 'Sweet potatoe', 'Tangerine', 'Tomato', 'Tomatoe', 'Tuna', 'Turmeric', + 'Turnip', 'Water', 'Water melon', 'Zucchini' +] diff --git a/metrics/metric.py b/metrics/metric.py index 97dbefa..d8666e5 100644 --- a/metrics/metric.py +++ b/metrics/metric.py @@ -5,7 +5,7 @@ import shutil import argparse from cookit.utils import OIv4_FOOD_CLASSES -from cookit.predict import Predictor +from cookit.predict_lite import Predictor oi_classes = [ingr.lower() for ingr in OIv4_FOOD_CLASSES] diff --git a/oi_food_balanced_400_lite4_ll.tflite b/oi_food_balanced_400_lite4_ll.tflite new file mode 100644 index 0000000..11ffc1c Binary files /dev/null and b/oi_food_balanced_400_lite4_ll.tflite differ diff --git a/requirements.txt b/requirements.txt index 4f0301d..4c21f74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,10 @@ wheel>=0.29 numpy pandas scikit-learn -tensorflow +tensorflow>=2.5.0 tensorflow_hub +tflite-model-maker +pycocotools # tests/linter black diff --git a/scripts/cookit-run b/scripts/cookit-run deleted file mode 100644 index faa18be..0000000 --- a/scripts/cookit-run +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- diff --git a/scripts/train-cookit b/scripts/train-cookit new file mode 100644 index 0000000..681c4a2 --- /dev/null +++ b/scripts/train-cookit @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse +from cookit.data import get_random_slice +from cookit.utils import OIv4_MIN_SET +from cookit.trainer import Trainer + + + +# if __name__ == "__main__": +# parser = argparse.ArgumentParser(description='Train a TFLite Model for cookit') + +# parser.add_argument('csv_file', help='CSV File ') +# parser.add_argument('-s','--spec', help='TFLite model configuration String, see \ +# https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/object_detector', +# default='efficientdet_lite0') +# parser.add_argument('-e','--epochs', help='Description for bar argument', required=True) +# parser.add_argument('-b','--batch_size', help='Description for bar argument', required=True) diff --git a/setup.py b/setup.py index 59b2c0d..50dee06 100644 --- a/setup.py +++ b/setup.py @@ -13,5 +13,5 @@ test_suite='tests', # include_package_data: to install data from MANIFEST.in include_package_data=True, - scripts=['scripts/cookit-run'], + scripts=['scripts/train-cookit'], zip_safe=False)