-
Notifications
You must be signed in to change notification settings - Fork 8
Description
# K-fold cross-validation of perCLTV on a single graph.
#
# All N nodes are fed to the model at once. The train / validation / test
# split is expressed with boolean node masks passed as sample weights,
# rather than by slicing the inputs. Every call that runs the model over the
# graph (fit, predict, evaluate) therefore has to use batch_size=N.
#
# Why: social_behavior_net uses spektral's GATConv in single mode. If Keras
# cuts the inputs into smaller batches, the node features are split while the
# adjacency's edge indices still refer to nodes outside the batch.
# NOTE(review): assumes N equals the number of rows in B, C, P and A -- confirm.
for train_index, test_index in kfold.split(B, y1):
    print('train_index', train_index)
    print('test_index', test_index)

    # Hold out 10% of this fold's training nodes as a validation set.
    train_index, val_index = train_test_split(
        train_index, test_size=0.1, random_state=seed_value)

    # Boolean masks over all N nodes. They are used as per-sample weights,
    # so only the selected nodes contribute to the loss and metrics.
    mask_train = np.zeros(N, dtype=bool)
    mask_val = np.zeros(N, dtype=bool)
    mask_test = np.zeros(N, dtype=bool)
    mask_train[train_index] = True
    mask_val[val_index] = True
    mask_test[test_index] = True

    # Clear checkpoints left over from a previous fold.
    checkpoint_path = './model/checkpoint-{epoch:04d}.ckpt'
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=5,
                                                      mode='min')
    best_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                         monitor='val_loss',
                                                         verbose=1,
                                                         save_best_only=True,
                                                         save_weights_only=True,
                                                         mode='auto')

    model = perCLTV(timestep=timestep, behavior_maxlen=maxlen)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss={'output_1': tf.keras.losses.BinaryCrossentropy(),
                        'output_2': tf.keras.losses.MeanSquaredError()},
                  loss_weights={'output_1': beta1, 'output_2': beta2},
                  metrics={'output_1': tf.keras.metrics.AUC(),
                           'output_2': 'mae'})

    # The whole graph is a single batch (batch_size=N, no shuffling). The
    # node masks choose which nodes are trained on and which are validated.
    # NOTE(review): with epochs=1, EarlyStopping(patience=5) can never fire;
    # raise epochs for a real training run.
    model.fit([B, C, P, A], [y1, y2],
              validation_data=([B, C, P, A], [y1, y2], mask_val),
              sample_weight=mask_train,
              batch_size=N,
              epochs=1,
              shuffle=False,
              callbacks=[early_stopping, best_checkpoint],
              verbose=1)

    # predict() must also see the full graph in one batch.
    # - Without batch_size=N, Keras uses its default of 32 and GATConv fails:
    #   "InvalidArgumentError: indices[3] = 33 is not in [0, 32)".
    # - Do not slice the inputs with a mask either: A presumably spans all N
    #   nodes (TODO confirm). Predict on the full graph, then index each
    #   output with mask_test / mask_val.
    # NOTE(review): these are the last-epoch weights; the best checkpoint
    # saved by ModelCheckpoint is not reloaded before predicting.
    predictions = model.predict([B, C, P, A], batch_size=N)
    print('predictions:', predictions)
File "J:\MetajoyAlogrithm\perCLTV-master\main.py", line 133, in <module>
predictions = model.predict([B, C, P,A])
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\tensorflow\python\eager\execute.py", line 52, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node 'per_cltv/social_behavior_net/gat_conv/GatherV2' defined at (most recent call last):
File "J:\MetajoyAlogrithm\perCLTV-master\main.py", line 133, in <module>
predictions = model.predict([B, C, P,A])
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2382, in predict
tmp_batch_outputs = self.predict_function(iterator)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2169, in predict_function
return step_function(self, iterator)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2155, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2143, in run_step
outputs = model.predict_step(data)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2111, in predict_step
return self(x, training=False)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 558, in call
return super().call(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call
outputs = call_fn(inputs, *args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm\perCLTV-master\src\model.py", line 75, in call
O = self.social_behavior_net([X, A])
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 558, in call
return super().call(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call
outputs = call_fn(inputs, *args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\sequential.py", line 427, in call
outputs = layer(inputs, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call
outputs = call_fn(inputs, *args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\conv.py", line 167, in _inner_check_dtypes
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 168, in call
if mode == modes.SINGLE and K.is_sparse(a):
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 169, in call
output, attn_coef = self._call_single(x, a)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 213, in _call_single
attn_for_self = tf.gather(attn_for_self, targets)
Node: 'per_cltv/social_behavior_net/gat_conv/GatherV2'
indices[3] = 33 is not in [0, 32)
[[{{node per_cltv/social_behavior_net/gat_conv/GatherV2}}]] [Op:__inference_predict_function_29622]
报错出现在这一行：predictions = model.predict([B, C, P,A])（The error is raised on this line.）