Skip to content

实际预测的时候图注意力层报错。 #2

@leiyinghanguang

Description

@leiyinghanguang

# Cross-validation loop: each fold builds and trains a fresh perCLTV model.
# The model always receives the full arrays B, C, P and A; the boolean masks
# built below are passed as sample weights to select which rows contribute to
# the training / validation loss.
for train_index, test_index in kfold.split(B, y1):
    print('train_index', train_index)
    print('test_index', test_index)

    # Carve a validation subset out of this fold's training indices.
    train_index, val_index = train_test_split(
        train_index, test_size=0.1, random_state=seed_value)

    # One boolean mask per split, each of length N.
    mask_train = np.zeros(N, dtype=bool)
    mask_val = np.zeros(N, dtype=bool)
    mask_test = np.zeros(N, dtype=bool)
    mask_train[train_index] = True
    mask_val[val_index] = True
    mask_test[test_index] = True  # built here, not consumed in this excerpt

    checkpoint_path = './model/checkpoint-{epoch:04d}.ckpt'
    checkpoint_dir = os.path.dirname(checkpoint_path)

    # Remove any existing checkpoint directory before training this fold.
    if os.path.exists(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=5,
                                                      mode='min')

    best_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                         monitor='val_loss',
                                                         verbose=1,
                                                         save_best_only=True,
                                                         save_weights_only=True,
                                                         mode='auto')

    model = perCLTV(timestep=timestep, behavior_maxlen=maxlen)

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss={'output_1': tf.keras.losses.BinaryCrossentropy(),
                        'output_2': tf.keras.losses.MeanSquaredError()},
                  loss_weights={'output_1': beta1, 'output_2': beta2},
                  metrics={'output_1': tf.keras.metrics.AUC(),
                           'output_2': 'mae'})

    # batch_size=N with shuffle=False feeds all N rows in a single, unshuffled
    # batch, so the rows of B/C/P stay aligned with the node indices in A.
    model.fit([B, C, P, A], [y1, y2],
              validation_data=([B, C, P, A], [y1, y2], mask_val),
              sample_weight=mask_train,
              batch_size=N,
              epochs=1,
              shuffle=False,
              callbacks=[early_stopping, best_checkpoint],
              verbose=1)

    # Prediction must also run on the whole graph in one batch. With Keras'
    # default batch_size=32 the inputs are cut into 32-row slices while A still
    # references all N nodes, and the GAT layer's gather fails with
    # "indices[3] = 33 is not in [0, 32)".
    # NOTE(review): for the same reason, do not slice the inputs with a mask
    # (e.g. B[mask_val]) before predicting; predict on everything and select
    # the rows of interest from the outputs afterwards.
    predictions = model.predict([B, C, P, A], batch_size=N)
print('predictions:',predictions)
Traceback (most recent call last):

File "J:\MetajoyAlogrithm\perCLTV-master\main.py", line 133, in <module>
predictions = model.predict([B, C, P,A])
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\tensorflow\python\eager\execute.py", line 52, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'per_cltv/social_behavior_net/gat_conv/GatherV2' defined at (most recent call last):
File "J:\MetajoyAlogrithm\perCLTV-master\main.py", line 133, in <module>
predictions = model.predict([B, C, P,A])
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2382, in predict
tmp_batch_outputs = self.predict_function(iterator)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2169, in predict_function
return step_function(self, iterator)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2155, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2143, in run_step
outputs = model.predict_step(data)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 2111, in predict_step
return self(x, training=False)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 558, in call
return super().call(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call
outputs = call_fn(inputs, *args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm\perCLTV-master\src\model.py", line 75, in call
O = self.social_behavior_net([X, A])
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\training.py", line 558, in call
return super().call(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call
outputs = call_fn(inputs, *args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\sequential.py", line 427, in call
outputs = layer(inputs, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\engine\base_layer.py", line 1145, in call
outputs = call_fn(inputs, *args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
return fn(*args, **kwargs)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\conv.py", line 167, in _inner_check_dtypes
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 168, in call
if mode == modes.SINGLE and K.is_sparse(a):
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 169, in call
output, attn_coef = self._call_single(x, a)
File "J:\MetajoyAlogrithm.venv2\Lib\site-packages\spektral\layers\convolutional\gat_conv.py", line 213, in _call_single
attn_for_self = tf.gather(attn_for_self, targets)
Node: 'per_cltv/social_behavior_net/gat_conv/GatherV2'
indices[3] = 33 is not in [0, 32)
[[{{node per_cltv/social_behavior_net/gat_conv/GatherV2}}]] [Op:__inference_predict_function_29622]
报错,这一行 predictions = model.predict([B, C, P,A])

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions