
How to use t3f.nn.KerasDense with a 3D array? #223

@wuqingle


Hello, I tried to feed three-dimensional data into the model and replaced the Dense layer with t3f.nn.KerasDense, but the None (batch) dimension went missing from the layer's output shape. How can I use t3f.nn.KerasDense with three-dimensional data?
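A plain Keras `Dense` layer applies its kernel to the last axis of a 3D input. If `t3f.nn.KerasDense` instead expects a 2D `(batch, prod(input_dims))` input (an assumption here, consistent with the missing `None` dimension), one workaround is to fold the step axis into the batch axis before the layer and unfold it afterwards. The NumPy sketch below shows only that reshape pattern, with an ordinary weight matrix standing in for the TT matrix; `apply_dense_per_step` is a hypothetical helper, not part of t3f.

```python
import numpy as np

def apply_dense_per_step(x, weight, bias):
    """Hypothetical helper: apply a 2D-only dense op to every step of a 3D tensor.

    x:      (batch, steps, in_features)
    weight: (in_features, out_features)  # stand-in for the TT matrix
    bias:   (out_features,)
    Returns (batch, steps, out_features).
    """
    batch, steps, in_features = x.shape
    # Fold the step axis into the batch axis so the op only ever sees 2D input.
    x2d = x.reshape(batch * steps, in_features)
    y2d = x2d @ weight + bias
    # Unfold back to (batch, steps, out_features).
    return y2d.reshape(batch, steps, weight.shape[1])

rng = np.random.default_rng(0)
x = rng.random((16, 2, 32))   # e.g. a batch of Conv1D outputs of shape (2, 32)
w = rng.random((32, 32))
b = np.zeros(32)
y = apply_dense_per_step(x, w, b)
print(y.shape)  # (16, 2, 32)
```

In Keras the same folding would need a `tf.reshape` inside a custom layer or `Lambda`, because the `Reshape` layer never touches the batch axis.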

```python
import numpy as np
import t3f
from keras.models import Sequential
from keras.layers import Conv1D, MaxPooling1D, Dropout, Dense, Flatten, Reshape

# This generates some test samples for me to check your code
X_train = np.random.rand(100, 4, 400)
Y_train = np.random.rand(100, 2)

model = Sequential()
model.add(Conv1D(32, 3, activation='relu', input_shape=(4, 400)))  # -> (None, 2, 32)
model.add(MaxPooling1D(2))                                          # -> (None, 1, 32)
model.add(Dropout(0.5))
model.add(Flatten())  # <- You need a flatten here                  # -> (None, 32)

# prod(input_dims) = 4 * 4 * 2 * 1 = 32 must match the flattened feature size
tt_layer = t3f.nn.KerasDense(input_dims=[4, 4, 2, 1], output_dims=[4, 4, 2, 1],
                             tt_rank=16, activation='relu')
model.add(tt_layer)

model.add(Dense(32, activation='relu'))
model.add(Reshape((1, 32)))
model.add(Flatten())
model.add(Dense(2, activation='sigmoid'))  # <- the last dense must have output 2

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()
model.fit(X_train, Y_train, batch_size=16, epochs=10)
```
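In the listing above, `Flatten` gives the TT layer a 2D input, and the layer can only accept it if the flattened size equals the product of `input_dims`. A quick shape trace in plain Python, assuming the Keras defaults of `'valid'` padding and stride 1 for `Conv1D` and strides equal to `pool_size` for `MaxPooling1D` (the helper names are hypothetical):

```python
import math

def conv1d_valid_len(length, kernel_size, stride=1):
    # Output length of a 1D convolution with 'valid' padding.
    return (length - kernel_size) // stride + 1

def maxpool1d_valid_len(length, pool_size, stride=None):
    # Assumed MaxPooling1D defaults: strides = pool_size, padding = 'valid'.
    stride = pool_size if stride is None else stride
    return (length - pool_size) // stride + 1

steps, channels = 4, 400      # input_shape=(4, 400)
filters, kernel = 32, 3       # Conv1D(32, 3)

conv_steps = conv1d_valid_len(steps, kernel)        # 2
pooled_steps = maxpool1d_valid_len(conv_steps, 2)   # 1
flat_features = pooled_steps * filters              # 32

input_dims = [4, 4, 2, 1]
# The TT layer's row dimension must equal the flattened feature count.
assert math.prod(input_dims) == flat_features
```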
The model summary is:

```
Model: "sequential_11"

 Layer (type)                Output Shape              Param #
=================================================================
 conv1d_10 (Conv1D)          (None, 2, 32)             38432
 tt_dense_10 (KerasDense)    (2, 32)                   5424
 flatten_7 (Flatten)         (2, 32)                   0
 dense_9 (Dense)             (2, 2)                    66
=================================================================
Total params: 43922 (171.57 KB)
Trainable params: 43922 (171.57 KB)
Non-trainable params: 0 (0.00 Byte)
```
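The parameter counts in this summary are consistent with `KerasDense(input_dims=[4, 4, 2, 1], output_dims=[4, 4, 2, 1], tt_rank=16)` placed directly on the `(None, 2, 32)` Conv1D output, with feature size 32 on the last axis. A quick check, assuming the standard TT-matrix core layout `(r_{k-1}, m_k, n_k, r_k)` with boundary ranks 1 and a bias of length `prod(output_dims)`; `tt_matrix_params` is a hypothetical helper:

```python
def tt_matrix_params(input_dims, output_dims, tt_rank):
    # Assumed TT-matrix layout: core k has shape (r_{k-1}, m_k, n_k, r_k), r_0 = r_d = 1.
    d = len(input_dims)
    ranks = [1] + [tt_rank] * (d - 1) + [1]
    return sum(ranks[k] * m * n * ranks[k + 1]
               for k, (m, n) in enumerate(zip(input_dims, output_dims)))

dims = [4, 4, 2, 1]
tt_cores = tt_matrix_params(dims, dims, tt_rank=16)  # 256 + 4096 + 1024 + 16 = 5392
tt_dense = tt_cores + 32                              # plus a bias of size 32 -> 5424
conv = 3 * 400 * 32 + 32                              # kernel * in_channels * filters + bias -> 38432
dense = 32 * 2 + 2                                    # Dense(2) on 32 features -> 66
total = conv + tt_dense + dense                       # 43922
print(tt_dense, conv, dense, total)
```

So the weights themselves match the expected sizes; what is wrong is the Output Shape column, where the TT layer reports `(2, 32)` instead of keeping `None` as its leading (batch) dimension.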

An error occurred in model.fit:

```
File "/usr/local/lib/python3.10/dist-packages/t3f/ops.py", line 231, in tt_dense_matmul
Input to reshape is a tensor with 1024 values, but the requested shape has 64
[[{{node sequential_11/tt_dense_10/t3f_matmul/Reshape_4}}]] [Op:__inference_train_function_9879]
```
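The numbers in this error line up with the lost batch axis: with `batch_size=16` and a Conv1D output of `(2, 32)` per sample, one batch holds 16 * 2 * 32 = 1024 values, while 2 * 32 = 64 is exactly the `(2, 32)` output shape reported for the TT layer in the summary. This suggests (an inference, not something the traceback states) that the layer treated the step axis as a fixed batch size. NumPy reproduces the same mismatch, and shows that folding the steps into the batch axis keeps every value:

```python
import numpy as np

batch_size, steps, features = 16, 2, 32
conv_out = np.zeros((batch_size, steps, features))  # 16 * 2 * 32 = 1024 values

# Asking for a fixed (2, 32) shape requests only 2 * 32 = 64 values,
# which cannot hold 1024, so the reshape fails.
try:
    conv_out.reshape(steps, features)
    failed = False
except ValueError:
    failed = True

# Folding the step axis into the batch axis keeps all values: (16 * 2, 32).
folded = conv_out.reshape(-1, features)
print(failed, folded.shape)  # True (32, 32)
```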
