
Fails when a metric is given as a function #18

@mpekalski


I have a model that I compile with

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[matthews_correlation])

and it fails at the end of the first epoch:

17/18 [===========================>..] - ETA: 0s - loss: 14.9240 - <function matthews_correlation at 0x7f61611d0c80>: 0.0000e+00
Traceback (most recent call last):
  File "<stdin>", line 19, in <module>
  File "/opt/conda/lib/python3.6/site-packages/importance_sampling/training.py", line 137, in fit
    on_scores=on_scores
  File "/opt/conda/lib/python3.6/site-packages/importance_sampling/training.py", line 289, in fit_dataset
    batch_size=batch_size
  File "/opt/conda/lib/python3.6/site-packages/importance_sampling/model_wrappers.py", line 75, in evaluate
    for xi, yi in self._iterate_batches(x, y, batch_size)
  File "/opt/conda/lib/python3.6/site-packages/importance_sampling/model_wrappers.py", line 75, in <listcomp>
    for xi, yi in self._iterate_batches(x, y, batch_size)
  File "/opt/conda/lib/python3.6/site-packages/importance_sampling/model_wrappers.py", line 298, in evaluate_batch
    print(len(outputs))
  File "/opt/conda/lib/python3.6/site-packages/numpy/core/shape_base.py", line 288, in hstack
    return _nx.concatenate(arrs, 1)
ValueError: all the input arrays must have same number of dimensions

When I use 'accuracy' as the metric instead, e.g.

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

everything works fine.
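The last frames of the traceback show NumPy's hstack calling concatenate along axis 1, which hstack only does when the first array it receives is at least 2-D. A minimal NumPy-only sketch of that failure mode is shown below. It assumes (unconfirmed against the library source) that evaluate_batch stacks a 2-D per-sample loss of shape (batch, 1) together with each metric's output; `stack_outputs` is a hypothetical stand-in, not the library's actual code.

```python
import numpy as np

def stack_outputs(per_sample_loss, metric_value):
    # Hypothetical stand-in for the np.hstack call reached from evaluate_batch
    return np.hstack([per_sample_loss, metric_value])

batch = 4
loss = np.zeros((batch, 1), dtype=np.float32)              # 2-D, one value per sample (assumed shape)
scalar_metric = np.float32(1.0)                            # 0-D, like tf.constant(1.0)
per_sample_metric = np.ones((batch, 1), dtype=np.float32)  # 2-D, one value per sample

try:
    # hstack promotes the scalar to 1-D, then concatenates a 2-D and a 1-D array along axis 1
    stack_outputs(loss, scalar_metric)
except ValueError as err:
    print("scalar metric:", err)

print("per-sample metric:", stack_outputs(loss, per_sample_metric).shape)  # (4, 2)
```

If this assumption holds, any metric that reduces over the whole batch to a single number would hit the same ValueError, regardless of what it computes.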

I've even tried a dummy metric:

import tensorflow as tf
def test_metric(y_true, y_pred):
    # Ignores its inputs and returns a single scalar for the whole batch
    return tf.constant(1.0, dtype=tf.float32)

and it also fails.
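One visible difference between the working and failing cases is the shape of the metric's return value. Keras 2.2.4's built-in binary accuracy averages only over the last axis (axis=-1), so it keeps one entry per sample, while test_metric returns one number per batch. The sketch below mirrors both in NumPy so the shapes are easy to inspect; `per_sample_dummy_metric` is a hypothetical variant that might be worth trying (in a real model it would be written with keras.backend ops), and it has not been tested with ImportanceTraining.

```python
import numpy as np

def binary_accuracy(y_true, y_pred):
    # Same reduction as Keras 2.2.4's built-in: mean over the last axis only,
    # so the result keeps one entry per sample.
    return np.mean(np.equal(y_true, np.round(y_pred)), axis=-1)

def scalar_dummy_metric(y_true, y_pred):
    # Like test_metric: a single number for the whole batch.
    return np.float32(1.0)

def per_sample_dummy_metric(y_true, y_pred):
    # Hypothetical alternative: a constant 1.0 for every sample in the batch.
    return np.ones(y_pred.shape[:-1], dtype=np.float32)

y_true = np.array([[1.0], [0.0], [1.0], [0.0]])
y_pred = np.array([[0.9], [0.2], [0.4], [0.1]])

print(binary_accuracy(y_true, y_pred))                 # [1. 1. 0. 1.]
print(np.shape(scalar_dummy_metric(y_true, y_pred)))   # ()
print(per_sample_dummy_metric(y_true, y_pred).shape)   # (4,)
```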

I am using CUDA 10.0, TensorFlow 1.13.0-rc1, Keras 2.2.4 and the latest keras-importance-sampling installed via pip. Without ImportanceTraining everything runs fine.
