-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
88 lines (74 loc) · 2.22 KB
/
model.py
File metadata and controls
88 lines (74 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
import os

# TensorFlow reads these variables while it initialises: C++ logging is
# configured during import, and CUDA reads CUDA_VISIBLE_DEVICES when the GPU
# runtime starts. Set them BEFORE `import tensorflow` so they take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide C++ INFO and WARNING logs
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to TensorFlow

import tensorflow as tf  # noqa: E402  (must follow the env setup above)
from tensorflow.keras.models import Sequential  # noqa: E402
from tensorflow.keras.layers import (  # noqa: E402
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Dropout,
    BatchNormalization,
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # noqa: E402
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping  # noqa: E402

# --- Training configuration -------------------------------------------------
train_dir = "dataset/full/train"  # one subdirectory per class
val_dir = "dataset/full/val"      # expected to mirror train_dir's class folders
img_size = (69, 69)               # (height, width) every image is resized to
batch_size = 64
num_classes = 13                  # expected to match the class folders in train_dir
epochs = 9
model_path = "square_classifier.h5"  # output path for the trained model (HDF5)
# Training images: rescale pixels to [0, 1] and apply mild geometric and
# brightness augmentation. Flips are explicitly disabled.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    brightness_range=(0.8, 1.2),
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=False,
    vertical_flip=False,
)
# Validation images: same rescaling, no augmentation.
val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
)

# Fail fast with a clear message instead of a shape mismatch inside fit():
# the model's softmax layer is sized for num_classes outputs.
if train_generator.num_classes != num_classes:
    raise ValueError(
        f"Expected {num_classes} classes but found "
        f"{train_generator.num_classes} class folders in {train_dir!r}"
    )

# Persist the folder-name -> label-index mapping so inference code can
# decode predictions with exactly the indices used for training.
with open("class_indices.json", "w", encoding="utf-8") as f:
    json.dump(train_generator.class_indices, f)
print("Saved class indices:", train_generator.class_indices)

# Reuse the training class order for validation. Otherwise the mapping is
# re-inferred from val_dir, and a missing or extra class folder there would
# silently shift every label index after it.
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    classes=class_names,
)
# Small CNN classifier: two Conv -> BatchNorm -> MaxPool -> Dropout stages
# followed by a dense head. Input and output sizes come from the config
# (img_size, num_classes) rather than repeating the literals 69 and 13, so
# editing the config cannot leave the model out of sync with the data.
model = Sequential([
    # RGB input (flow_from_directory's default color_mode is 'rgb').
    Input(shape=img_size + (3,)),
    # Conv Layer 1: 'same' padding keeps H x W; 2x2 pooling halves it
    # (69 -> 34 for the default 69x69 input).
    Conv2D(16, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(2, 2),
    Dropout(0.2),
    # Conv Layer 2 (34 -> 17 for the default 69x69 input).
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(2, 2),
    Dropout(0.2),
    # Classifier head.
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.3),
    # One softmax probability per class, paired with one-hot
    # ('categorical') labels and categorical_crossentropy below.
    Dense(num_classes, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# Write the model to model_path whenever val_accuracy improves, and stop
# once it has not improved for 5 consecutive epochs.
callbacks = [
    ModelCheckpoint(model_path, monitor='val_accuracy', save_best_only=True, verbose=1),
    EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True),
]
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
    callbacks=callbacks,
)

# ModelCheckpoint has already written the best-val_accuracy model to
# model_path. Calling model.save() unconditionally would overwrite it with the
# in-memory weights. Those can be the LAST epoch's: depending on the Keras
# version, restore_best_weights may only take effect when early stopping
# actually fires, which often won't happen with epochs=9 / patience=5.
# Only save here as a fallback when no checkpoint file exists.
if not os.path.exists(model_path):
    model.save(model_path)
print(f"Model saved to {model_path}")