Skip to content

Commit 5872274

Browse files
committed
Image classification model development: added several new training scripts including one which iterates multiple times per model and/or over multiple filter counts, transfer learning, transfer fine tuning and knowledge distillation, modularized existing ResNet-style model to include argument for first-stage filter count, updated pip requirements and readme for TF2.14
1 parent e4eac49 commit 5872274

File tree

9 files changed

+1585
-55
lines changed

9 files changed

+1585
-55
lines changed

benchmark/training/image_classification/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Dataset: Cifar10
1313
Run the following commands to go through the whole training and validation process
1414

1515
``` Bash
16-
# Prepare Python venv (Python 3.7+ and pip>20 required)
16+
# Prepare Python venv (TF2.14, Python 3.11 and pip 26 required)
1717
./prepare_training_env.sh
1818

1919
# Download training, train model, test the model

benchmark/training/image_classification/download_cifar10_train_resnet.sh

100644100755
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
44
tar -xvf cifar-10-python.tar.gz
55

66
# load performance subset
7-
. venv/bin/activate
7+
# . venv/bin/activate
88
python3 perf_samples_loader.py
99

1010
# train and test the model

benchmark/training/image_classification/keras_model.py

Lines changed: 165 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212

1313
import tensorflow as tf
1414
from tensorflow.keras.models import Model
15+
from tensorflow.keras.applications import EfficientNetV2S
1516
from tensorflow.keras.layers import Input, Dense, Activation, Flatten, BatchNormalization
16-
from tensorflow.keras.layers import Conv2D, AveragePooling2D, MaxPooling2D
17+
from tensorflow.keras.layers import Conv2D, AveragePooling2D, MaxPooling2D, Resizing
1718
from tensorflow.keras.regularizers import l2
1819

1920
#get model
@@ -29,12 +30,14 @@ def get_quant_model_name():
2930
else:
3031
return "pretrainedResnet"
3132

32-
#define model
33-
def resnet_v1_eembc():
33+
#define models
34+
35+
# 200k params
36+
def resnet_v1_eembc(conv_filters=26):
3437
# Resnet parameters
3538
input_shape=[32,32,3] # default size for cifar10
3639
num_classes=10 # default class number for cifar10
37-
num_filters = 16 # this should be 64 for an official resnet model
40+
num_filters = conv_filters # this should be 64 for an official resnet model
3841

3942
# Input layer, change kernel size to 7x7 and strides to 2 for an official resnet
4043
inputs = Input(shape=input_shape)
@@ -76,7 +79,7 @@ def resnet_v1_eembc():
7679
# Second stack
7780

7881
# Weight layers
79-
num_filters = 32 # Filters need to be double for each stack
82+
num_filters = conv_filters * 2 # Filters need to be double for each stack
8083
y = Conv2D(num_filters,
8184
kernel_size=3,
8285
strides=2,
@@ -109,7 +112,7 @@ def resnet_v1_eembc():
109112
# Third stack
110113

111114
# Weight layers
112-
num_filters = 64
115+
num_filters = conv_filters * 4
113116
y = Conv2D(num_filters,
114117
kernel_size=3,
115118
strides=2,
@@ -144,7 +147,7 @@ def resnet_v1_eembc():
144147
# Uncomments to use it
145148

146149
# # Weight layers
147-
# num_filters = 128
150+
# num_filters = conv_filters * 8
148151
# y = Conv2D(num_filters,
149152
# kernel_size=3,
150153
# strides=2,
@@ -185,3 +188,158 @@ def resnet_v1_eembc():
185188
# Instantiate model.
186189
model = Model(inputs=inputs, outputs=outputs)
187190
return model
191+
192+
# EfficientNet V2S
def effnet_v2s(transfer=True):
    """Build an EfficientNetV2-S based 10-class classifier.

    Args:
        transfer: When True, freeze every layer of the pretrained backbone so
            that only the new classification head trains (transfer learning).
            When False, the whole network remains trainable (fine tuning).

    Returns:
        An uncompiled tf.keras Model mapping (224, 224, 3) images to a
        10-way softmax output.
    """
    # EffNet parameters
    input_shape = [224, 224, 3]  # EfficientNetV2-S expected input size (CIFAR-10 images must be upscaled)
    num_classes = 10             # default class number for cifar10

    # Input layer
    inputs = Input(shape=input_shape)
    # x = Resizing(224, 224)(inputs)  # enable if callers feed native 32x32 CIFAR images directly

    # Use the EfficientNetV2S name imported at the top of this module for
    # consistency, instead of the full tf.keras.applications path.
    effnet = EfficientNetV2S(include_top=False,
                             weights="imagenet",
                             include_preprocessing=True)
    if transfer:
        # Freeze the pretrained backbone; only the head below is trained.
        for layer in effnet.layers:
            layer.trainable = False
    x = effnet(inputs)

    # Final classification layer: pool the backbone's feature map down to
    # 1x1 using the smaller spatial dimension as the pool size. Builtin
    # min() is used instead of np.amin so this function does not depend on
    # numpy, which this module does not visibly import.
    pool_size = int(min(x.shape[1], x.shape[2]))
    x = AveragePooling2D(pool_size=pool_size)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)

    # Instantiate model.
    model = Model(inputs=inputs, outputs=outputs)
    return model
218+
219+
# class Distiller(tf.keras.Model):
220+
# def __init__(self, student, teacher):
221+
# super().__init__()
222+
# self.teacher = teacher
223+
# self.student = student
224+
225+
# def compile(
226+
# self,
227+
# optimizer,
228+
# metrics,
229+
# student_loss_fn,
230+
# distillation_loss_fn,
231+
# alpha=0.1,
232+
# temperature=3,
233+
# ):
234+
# """Configure the distiller.
235+
236+
# Args:
237+
# optimizer: Keras optimizer for the student weights
238+
# metrics: Keras metrics for evaluation
239+
# student_loss_fn: Loss function of difference between student
240+
# predictions and ground-truth
241+
# distillation_loss_fn: Loss function of difference between soft
242+
# student predictions and soft teacher predictions
243+
# alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
244+
# temperature: Temperature for softening probability distributions.
245+
# Larger temperature gives softer distributions.
246+
# """
247+
# super().compile(optimizer=optimizer, metrics=metrics)
248+
# self.student_loss_fn = student_loss_fn
249+
# self.distillation_loss_fn = distillation_loss_fn
250+
# self.alpha = alpha
251+
# self.temperature = temperature
252+
253+
# def compute_loss(
254+
# self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
255+
# ):
256+
# teacher_pred = self.teacher(x, training=False)
257+
# student_loss = self.student_loss_fn(y, y_pred)
258+
259+
# distillation_loss = self.distillation_loss_fn(
260+
# tf.nn.softmax(teacher_pred / self.temperature, axis=1),
261+
# tf.nn.softmax(y_pred / self.temperature, axis=1),
262+
# ) * (self.temperature**2)
263+
264+
# loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
265+
# return loss
266+
267+
# def call(self, x):
268+
# return self.student(x)
269+
270+
# Reference:
# https://keras.io/examples/vision/knowledge_distillation/
class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper that trains a student model to match a
    (frozen) teacher model's softened logits.

    Unlike the keras.io example this variant uses ONLY the distillation loss
    during training (the ground-truth labels in the training data are
    discarded); labels are used solely to update metrics in test_step.
    """

    def __init__(self, student, teacher, batch_size):
        # student: trainable model being distilled into.
        # teacher: reference model, always run with training=False.
        # batch_size: global batch size used to average the per-example loss
        #             (required by tf.nn.compute_average_loss).
        super(Distiller, self).__init__()
        self.student = student
        self.teacher = teacher
        self.batch_size = batch_size

    def compile(
        self,
        optimizer,
        metrics,
        distillation_loss_fn,
        temperature=2,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights.
            metrics: Keras metrics for evaluation.
            distillation_loss_fn: loss between temperature-scaled teacher and
                student predictions; must return per-example values (no
                reduction) so tf.nn.compute_average_loss can average them.
            temperature: softening factor applied to both prediction tensors.
                Larger values give softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.distillation_loss_fn = distillation_loss_fn
        self.temperature = temperature

    def train_step(self, data):
        """One optimization step: match the student to the teacher on x.

        NOTE(review): the label tensor is deliberately discarded, so the
        compiled metrics are NOT updated here — the metric values reported
        during training reflect their last state. Presumably intentional for
        pure-distillation training; confirm.
        """
        # Unpack data
        x, _ = data

        # Forward pass of teacher (inference mode; teacher is not trained)
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute loss on temperature-softened predictions; the loss must
            # be computed inside the tape so gradients can flow to the student.
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions / self.temperature,
                student_predictions / self.temperature
            )
            distillation_loss = tf.nn.compute_average_loss(distillation_loss,
                                                           global_batch_size=self.batch_size)

        # Compute gradients (student weights only; teacher stays fixed)
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(distillation_loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Report progress
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluation step: report distillation loss plus compiled metrics
        computed from the ground-truth labels against the student output."""
        # Unpack data
        x, y = data

        # Forward pass of teacher and student (both in inference mode)
        teacher_predictions = self.teacher(x, training=False)
        student_predictions = self.student(x, training=False)

        # Calculate the loss (same temperature-scaled form as training)
        distillation_loss = self.distillation_loss_fn(
            teacher_predictions / self.temperature,
            student_predictions / self.temperature
        )
        distillation_loss = tf.nn.compute_average_loss(distillation_loss,
                                                       global_batch_size=self.batch_size)

        # Report progress — here the labels ARE used, to update the compiled
        # metrics on the student's predictions.
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"distillation_loss": distillation_loss}
        )
        return results

0 commit comments

Comments
 (0)