diff --git a/inception_score/model.py b/inception_score/model.py index bf69254..cb9a642 100644 --- a/inception_score/model.py +++ b/inception_score/model.py @@ -41,6 +41,8 @@ def get_inception_score(images, splits=10): inp = inps[(i * bs):min((i + 1) * bs, len(inps))] inp = np.concatenate(inp, 0) pred = sess.run(softmax, {'ExpandDims:0': inp}) + # remove extra 8 classes from inception model + pred = pred[:, :1000] preds.append(pred) preds = np.concatenate(preds, 0) scores = []