Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 0 additions & 66 deletions .github/workflows/runner_requirements.txt

This file was deleted.

4 changes: 2 additions & 2 deletions RealtimeClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tensorflow as tf
import time

from tensorflow.keras.models import load_model
from tensorflow.keras.models import load_model # type: ignore
from include.Logger import Logger

class LiveCameraClassifier:
Expand Down Expand Up @@ -56,6 +56,6 @@ def run(self):
# Example Usage
if __name__ == "__main__":
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
model = load_model("models/batch_norm_model.h5")
model = load_model("models/batch_norm_model_rmsprop.keras")
classifier = LiveCameraClassifier(model, class_names)
classifier.run()
26 changes: 15 additions & 11 deletions Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
from include.TensorModel import TensorModel
from include.ModelProfiler import ModelProfiler

NUM_EPOCHS = 5
NUM_EPOCHS = 10
BATCH_FITTING = 128
BATCH_PROFILING = [32, 64, 128]

MODELS = ["base_model", "batch_norm_model", "batch_norm_model_sgd", "batch_norm_model_rmsprop"]
USE_EXISTING_MODELS = True

SAVE_MODELS = True
SAVE_MODELS_AS_H5 = True
SAVE_MODELS_AS_H5 = False
SAVE_MODELS_AS_KERAS = True
SAVE_MODELS_AS_SavedModel = True
SAVE_MODELS_AS_SavedModel = False

PROFILE_MODELS = True

def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser, model, x_test, y_test, str_model):
conf_matrix = model_handler.compute_confusion_matrix(model, x_test, y_test)
Expand All @@ -28,9 +30,9 @@ def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser

def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualiser, logger: Logger, x_train, y_train, x_test, y_test, batch_size) -> tuple:
# Check if model exists
if USE_EXISTING_MODELS and os.path.exists(f"models/{model_name}.h5"):
model = model = tf.keras.models.load_model(f"models/{model_name}.h5")
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
if USE_EXISTING_MODELS and os.path.exists(f"models/{model_name}.keras"):
model = model = tf.keras.models.load_model(f"models/{model_name}.keras")
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])
model.build()
model.summary()
else:
Expand All @@ -39,8 +41,8 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
else:
model = model_handler.create_cnn(batch_normalisation=True)

history = model.fit(x_train, y_train, epochs=NUM_EPOCHS, batch_size=batch_size, validation_data=(x_test, y_test))
test_loss, test_acc = model.evaluate(x_test, y_test)
history = model.fit(x_train, y_train, steps_per_epoch=x_train.shape[0], epochs=NUM_EPOCHS, batch_size=batch_size, validation_data=(x_test, y_test))
test_loss, test_acc = model.evaluate(x_test, y_test, batch_size=batch_size)

if SAVE_MODELS:
if SAVE_MODELS_AS_H5:
Expand All @@ -52,7 +54,7 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
if SAVE_MODELS_AS_SavedModel:
model.export(f"models/{model_name}_saved_model")

logger.info(f"Model accuracy: {test_acc * 100:.2f}%")
logger.info(f"{model_name} Model accuracy: {test_acc * 100:.2f}%")
visualiser.plot_training_history(history, model_name)
return (model, model_name), (test_acc)

Expand All @@ -72,7 +74,7 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
# Load models
for model_name, accuracy in model_acc_results.items():
# Load the model and data
model = model_handler.load_model(f"models/{model_name}.h5")
model = model_handler.load_model(f"models/{model_name}.keras")
(_, _), (x_test, _) = model_handler.load_data()
(batch_time, throughput_time), (single_image_time) = profiler.measure_average_inference_time(batch_size, model, x_test, show_single_image_inference=True)

Expand Down Expand Up @@ -108,6 +110,8 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge

# Train and profile models
model_acc_results = train_models(model_handler, visualiser, logger, x_train, y_train, x_test, y_test, model_acc_results)
profile_models(model_acc_results, visualiser, logger)

if PROFILE_MODELS:
profile_models(model_acc_results, visualiser, logger)

logger.info("Done!")
Binary file modified images/Inference/base_model/base_model_128_inference_timings.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/Inference/base_model/base_model_32_inference_timings.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/Inference/base_model/base_model_64_inference_timings.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/Training/base_model/base_model_training_history.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/augmented_sample_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
47 changes: 40 additions & 7 deletions include/CNNBuilder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,44 @@
from tensorflow.keras import layers, models, optimizers, callbacks # type: ignore
from include.Logger import Logger
from tensorflow.keras import layers, models, optimizers, callbacks, regularizers # type: ignore

class CustomEarlyStopping(callbacks.Callback):
def __init__(self, target_accuracy=0.95, patience=5, restore_best_weights=True):
super(CustomEarlyStopping, self).__init__()
self.target_accuracy = target_accuracy
self.patience = patience
self.restore_best_weights = restore_best_weights
self.best_accuracy = 0
self.wait = 0
self.logger = Logger(__name__)

def on_epoch_end(self, epoch, logs=None):
current_accuracy = logs["val_accuracy"]

# Check the current accuracy against the best accuracy
if current_accuracy > self.best_accuracy:
self.best_accuracy = current_accuracy
self.wait = 0
else:
self.wait += 1
self.logger.debug(f"Accuracy: {current_accuracy}, Best Accuracy: {self.best_accuracy}, Wait: {self.wait}")

if self.wait >= self.patience:
self.model.stop_training = True
self.logger.info(f"Early stopping triggered at epoch {epoch}")

if self.restore_best_weights:
self.model.set_weights(self.best_weights)
self.logger.info("Restored best weights")

class CNNBuilder:
def __init__(self, input_shape):
self.model = models.Sequential()
self.model.add(layers.InputLayer(input_shape=input_shape))
self.logger = Logger(__name__)

def add_conv_layer(self, filters, kernel_size, activation="relu", padding="same", kernel_reguliser=None):
self.model.add(layers.Conv2D(filters, kernel_size, padding=padding, kernel_regularizer=kernel_reguliser, kernel_initializer="he_normal"))
def add_conv_layer(self, filters, kernel_size, activation="relu", padding="same", kernel_reguliser=None, weight_decay=1e-4):
kernel_reguliser = kernel_reguliser if kernel_reguliser else None
self.model.add(layers.Conv2D(filters, kernel_size, strides=1, padding=padding, kernel_regularizer=regularizers.l2(weight_decay)))
self.model.add(layers.BatchNormalization()) # BatchNorm before activation
self.model.add(layers.Activation(activation)) # Separate activation
return self
Expand Down Expand Up @@ -68,15 +98,18 @@ def compile_model(self, optimiser="adam", learning_rate=0.001, decay_factor=0.95
self.lr_warmup = None
if use_lr_warmup:
def lr_schedule(epoch, lr):
if epoch < 5: # Warm-up for first 5 epochs
return lr * (epoch + 1) / 5
return lr
if epoch < 10:
return 0.001
elif epoch < 30:
return 0.0005
else:
return 0.0002
self.lr_warmup = callbacks.LearningRateScheduler(lr_schedule)

# Early Stopping
self.early_stopping = None
if use_early_stopping:
self.early_stopping = callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
self.early_stopping = CustomEarlyStopping(target_accuracy=0.95, patience=5, restore_best_weights=True)

# Get all callbacks
self.callbacks_list = self.get_callbacks()
Expand Down
21 changes: 12 additions & 9 deletions include/TensorModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def get_augmentation(self) -> tf.keras.Sequential:
tf.keras.layers.RandomFlip("horizontal"), # Randomly flip images horizontally
tf.keras.layers.RandomRotation(0.1), # Randomly rotate image up to 10%
tf.keras.layers.RandomZoom(0.1), # Randomly zoom image up to 10%
tf.keras.layers.RandomContrast(0.1), # Randomly adjust contrast up to 10%
tf.keras.layers.Lambda(lambda x: tf.image.random_brightness(x, max_delta=0.2)), # Randomly adjust brightness
tf.keras.layers.Lambda(lambda x: tf.image.random_crop(x, size=[tf.shape(x)[0], tf.shape(x)[1] - 8, tf.shape(x)[2] - 8, tf.shape(x)[3]])), # Cutout augmentation
tf.keras.layers.Lambda(lambda x: 0.5 * x + 0.5 * tf.roll(x, shift=1, axis=0)) # MixUp augmentation (approximate)
tf.keras.layers.RandomContrast(0.1) # Randomly adjust contrast up to 10%
])

return data_augmentation
Expand All @@ -68,13 +65,19 @@ def create_cnn(self, optimiser="adam", batch_normalisation=False, learning_rate=
if batch_normalisation:
model = (builder
.add_data_augmentation(data_augmentation)
.add_conv_layer(32, (3,3))
.add_pooling_layer()
.add_conv_layer(64, (3,3))
.add_conv_layer(64, (3, 3))
.add_pooling_layer()
.add_pooling_layer(strides=(2, 2))
.add_conv_layer(64, (3, 3))
.add_conv_layer(128, (3, 3))
.add_pooling_layer(strides=(2, 2))
.add_conv_layer(64, (3, 3))
.add_conv_layer(64, (3, 3))
.add_pooling_layer(strides=(2, 2))
.add_dropout(0.3)
.add_flaten_layer()
.add_dense_layer(128, activation="relu")
.add_dropout(0.5)
.add_dense_layer(256, activation="relu")
.add_dropout(0.3)
.add_dense_layer(10, activation="softmax")
.compile_model(optimiser=optimiser, learning_rate=learning_rate, decay_factor=decay_factor, use_lr_warmup=True, use_early_stopping=True)
.build()
Expand Down
Binary file removed models/base_model.h5
Binary file not shown.
Binary file modified models/base_model.keras
Binary file not shown.
1 change: 0 additions & 1 deletion models/base_model_saved_model/fingerprint.pb

This file was deleted.

Binary file removed models/base_model_saved_model/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed models/batch_norm_model.h5
Binary file not shown.
Binary file modified models/batch_norm_model.keras
Binary file not shown.
Binary file removed models/batch_norm_model_rmsprop.h5
Binary file not shown.
Binary file modified models/batch_norm_model_rmsprop.keras
Binary file not shown.
1 change: 0 additions & 1 deletion models/batch_norm_model_rmsprop_saved_model/fingerprint.pb

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 0 additions & 1 deletion models/batch_norm_model_saved_model/fingerprint.pb

This file was deleted.

Binary file removed models/batch_norm_model_saved_model/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed models/batch_norm_model_sgd.h5
Binary file not shown.
Binary file modified models/batch_norm_model_sgd.keras
Binary file not shown.
1 change: 0 additions & 1 deletion models/batch_norm_model_sgd_saved_model/fingerprint.pb

This file was deleted.

Binary file not shown.
Binary file not shown.
Binary file not shown.