-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_and_export.py
More file actions
27 lines (23 loc) · 972 Bytes
/
train_and_export.py
File metadata and controls
27 lines (23 loc) · 972 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# train_and_export.py
import numpy as np
from keras.datasets import imdb
from keras.models import Sequential
from keras.layers import Embedding, Flatten, Dense
from keras.preprocessing import sequence
import keras2onnx
import onnx
# Vocabulary size: passed to imdb.load_data(num_words=...) and used as the
# Embedding layer's input dimension, so the two must stay in sync.
top_words = 5000
# Fixed review length: every review is padded (or truncated) to this many
# tokens, and the Embedding layer is built with this input_length.
max_words = 500


def _load_padded_imdb():
    """Load the IMDB review dataset and pad every review to ``max_words`` tokens.

    Returns:
        ``((X_train, y_train), (X_test, y_test))`` where each ``X`` has been
        passed through ``sequence.pad_sequences(..., maxlen=max_words)``.
    """
    (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=top_words)
    X_train = sequence.pad_sequences(X_train, maxlen=max_words)
    X_test = sequence.pad_sequences(X_test, maxlen=max_words)
    return (X_train, y_train), (X_test, y_test)


def _build_model():
    """Build and compile the Embedding -> Flatten -> Dense sentiment classifier.

    Returns:
        A compiled ``Sequential`` model with a single sigmoid output unit,
        trained with binary cross-entropy and the Adam optimizer.
    """
    model = Sequential()
    model.add(Embedding(top_words, 32, input_length=max_words))
    model.add(Flatten())
    model.add(Dense(250, activation='relu'))
    # One sigmoid unit + binary cross-entropy: a single probability per review.
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


def main(output_path="imdb_sentiment.onnx", epochs=2, batch_size=128):
    """Train the IMDB sentiment model and export it to an ONNX file.

    The defaults reproduce the original script exactly.

    Args:
        output_path: Destination file for the serialized ONNX model.
        epochs: Number of training epochs passed to ``model.fit``.
        batch_size: Mini-batch size passed to ``model.fit``.
    """
    (X_train, y_train), (X_test, y_test) = _load_padded_imdb()
    model = _build_model()
    # NOTE(review): the test split doubles as validation data. With a fixed
    # epoch count and no callbacks it only reports metrics, but do not add
    # early stopping or tuning against it without carving out a separate split.
    model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size,
              validation_data=(X_test, y_test))
    # NOTE(review): keras2onnx is no longer maintained and tf2onnx is its
    # successor for newer TensorFlow/Keras versions -- confirm before upgrading.
    onnx_model = keras2onnx.convert_keras(model, model.name)
    onnx.save_model(onnx_model, output_path)
    print(f"✅ Model saved to {output_path}")


# Guard the entry point so importing this module does not download data,
# train a model, or overwrite the exported file as a side effect.
if __name__ == "__main__":
    main()