1212
1313import tensorflow as tf
1414from tensorflow .keras .models import Model
15+ from tensorflow .keras .applications import EfficientNetV2S
1516from tensorflow .keras .layers import Input , Dense , Activation , Flatten , BatchNormalization
16- from tensorflow .keras .layers import Conv2D , AveragePooling2D , MaxPooling2D
17+ from tensorflow .keras .layers import Conv2D , AveragePooling2D , MaxPooling2D , Resizing
1718from tensorflow .keras .regularizers import l2
1819
1920#get model
@@ -29,12 +30,14 @@ def get_quant_model_name():
2930 else :
3031 return "pretrainedResnet"
3132
32- #define model
33- def resnet_v1_eembc ():
33+ #define models
34+
35+ # 200k params
36+ def resnet_v1_eembc (conv_filters = 26 ):
3437 # Resnet parameters
3538 input_shape = [32 ,32 ,3 ] # default size for cifar10
3639 num_classes = 10 # default class number for cifar10
37- num_filters = 16 # this should be 64 for an official resnet model
40+ num_filters = conv_filters # this should be 64 for an official resnet model
3841
3942 # Input layer, change kernel size to 7x7 and strides to 2 for an official resnet
4043 inputs = Input (shape = input_shape )
@@ -76,7 +79,7 @@ def resnet_v1_eembc():
7679 # Second stack
7780
7881 # Weight layers
79- num_filters = 32 # Filters need to be double for each stack
82+ num_filters = conv_filters * 2 # Filters need to be double for each stack
8083 y = Conv2D (num_filters ,
8184 kernel_size = 3 ,
8285 strides = 2 ,
@@ -109,7 +112,7 @@ def resnet_v1_eembc():
109112 # Third stack
110113
111114 # Weight layers
112- num_filters = 64
115+ num_filters = conv_filters * 4
113116 y = Conv2D (num_filters ,
114117 kernel_size = 3 ,
115118 strides = 2 ,
@@ -144,7 +147,7 @@ def resnet_v1_eembc():
144147 # Uncomments to use it
145148
146149# # Weight layers
147- # num_filters = 128
150+ # num_filters = conv_filters * 8
148151# y = Conv2D(num_filters,
149152# kernel_size=3,
150153# strides=2,
@@ -185,3 +188,158 @@ def resnet_v1_eembc():
185188 # Instantiate model.
186189 model = Model (inputs = inputs , outputs = outputs )
187190 return model
191+
# EfficientNetV2-S backbone with a small classification head
def effnet_v2s(transfer=True):
    """Build an EfficientNetV2-S classifier for 10 classes.

    Args:
        transfer: if True, freeze every backbone layer so only the new
            classification head is trained (transfer learning). If False,
            the whole network is fine-tuned.

    Returns:
        A Keras Model mapping (224, 224, 3) images to 10 softmax scores.
    """
    # EffNet parameters
    input_shape = [224, 224, 3]  # native input size for EfficientNetV2-S
    num_classes = 10             # class count for cifar10

    # Input layer
    inputs = Input(shape=input_shape)
    # x = Resizing(224, 224)(inputs)

    # Pretrained ImageNet backbone. include_preprocessing=True means the
    # network normalizes raw [0, 255] pixels internally, so no manual
    # preprocessing layer is needed. Uses the EfficientNetV2S name already
    # imported at the top of this module.
    effnet = EfficientNetV2S(include_top=False,
                             weights="imagenet",
                             include_preprocessing=True)
    if transfer:
        # Freeze the backbone; only the head below stays trainable.
        for layer in effnet.layers:
            layer.trainable = False
    x = effnet(inputs)

    # Final classification layer.
    # NOTE(review): the original used np.amin here, but numpy is not
    # visibly imported in this module; builtin min() over the two spatial
    # dims produces the identical pool size without that dependency.
    pool_size = int(min(x.shape[1:3]))
    x = AveragePooling2D(pool_size=pool_size)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)

    # Instantiate model.
    model = Model(inputs=inputs, outputs=outputs)
    return model
218+
219+ # class Distiller(tf.keras.Model):
220+ # def __init__(self, student, teacher):
221+ # super().__init__()
222+ # self.teacher = teacher
223+ # self.student = student
224+
225+ # def compile(
226+ # self,
227+ # optimizer,
228+ # metrics,
229+ # student_loss_fn,
230+ # distillation_loss_fn,
231+ # alpha=0.1,
232+ # temperature=3,
233+ # ):
234+ # """Configure the distiller.
235+
236+ # Args:
237+ # optimizer: Keras optimizer for the student weights
238+ # metrics: Keras metrics for evaluation
239+ # student_loss_fn: Loss function of difference between student
240+ # predictions and ground-truth
241+ # distillation_loss_fn: Loss function of difference between soft
242+ # student predictions and soft teacher predictions
243+ # alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
244+ # temperature: Temperature for softening probability distributions.
245+ # Larger temperature gives softer distributions.
246+ # """
247+ # super().compile(optimizer=optimizer, metrics=metrics)
248+ # self.student_loss_fn = student_loss_fn
249+ # self.distillation_loss_fn = distillation_loss_fn
250+ # self.alpha = alpha
251+ # self.temperature = temperature
252+
253+ # def compute_loss(
254+ # self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
255+ # ):
256+ # teacher_pred = self.teacher(x, training=False)
257+ # student_loss = self.student_loss_fn(y, y_pred)
258+
259+ # distillation_loss = self.distillation_loss_fn(
260+ # tf.nn.softmax(teacher_pred / self.temperature, axis=1),
261+ # tf.nn.softmax(y_pred / self.temperature, axis=1),
262+ # ) * (self.temperature**2)
263+
264+ # loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
265+ # return loss
266+
267+ # def call(self, x):
268+ # return self.student(x)
269+
# Reference:
# https://keras.io/examples/vision/knowledge_distillation/
class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper.

    Trains a `student` model to reproduce the (temperature-softened)
    outputs of a frozen `teacher` model. Ground-truth labels are ignored
    during training and used only to update evaluation metrics.
    """

    def __init__(self, student, teacher, batch_size):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.batch_size = batch_size

    def compile(
        self,
        optimizer,
        metrics,
        distillation_loss_fn,
        temperature=2,
    ):
        """Configure optimizer and metrics, plus the distillation loss
        function and softening temperature."""
        super().compile(optimizer=optimizer, metrics=metrics)
        self.distillation_loss_fn = distillation_loss_fn
        self.temperature = temperature

    def _average_distillation_loss(self, teacher_out, student_out):
        # Temperature-scaled loss between teacher and student outputs,
        # averaged over the (global) batch size supplied at construction.
        per_example = self.distillation_loss_fn(
            teacher_out / self.temperature,
            student_out / self.temperature
        )
        return tf.nn.compute_average_loss(per_example,
                                          global_batch_size=self.batch_size)

    def train_step(self, data):
        # Labels are deliberately discarded: the teacher provides targets.
        x, _ = data

        # Teacher runs in inference mode and receives no gradient updates.
        teacher_out = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_out = self.student(x, training=True)
            loss = self._average_distillation_loss(teacher_out, student_out)

        # Update only the student's weights.
        student_vars = self.student.trainable_variables
        grads = tape.gradient(loss, student_vars)
        self.optimizer.apply_gradients(zip(grads, student_vars))

        # Report progress
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"distillation_loss": loss}
        )
        return results

    def test_step(self, data):
        x, y = data

        # Both networks run in inference mode during evaluation.
        teacher_out = self.teacher(x, training=False)
        student_out = self.student(x, training=False)

        loss = self._average_distillation_loss(teacher_out, student_out)

        # Labels are used here (and only here) to score the student.
        self.compiled_metrics.update_state(y, student_out)
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"distillation_loss": loss}
        )
        return results
0 commit comments