diff --git a/demo/train.py b/demo/train.py index 5e7ac15f84..504a285665 100644 --- a/demo/train.py +++ b/demo/train.py @@ -90,7 +90,7 @@ def get_confusion_image(predictions, dataset): for n, pred in enumerate(predictions): actual = pred["actual"] predicted = pred["predicted"] - image = np.array(dataset[n]) / 255 + image = np.array(dataset[n][0]) / 255 confusion[(actual, predicted)] = image max_i, max_j = 0, 0