diff --git a/ignite/metrics/binary_accuracy.py b/ignite/metrics/binary_accuracy.py index e64a86d53eca..94176584681c 100644 --- a/ignite/metrics/binary_accuracy.py +++ b/ignite/metrics/binary_accuracy.py @@ -11,7 +11,8 @@ class BinaryAccuracy(Metric): Calculates the binary accuracy. `update` must receive output of the form (y_pred, y). - `y_pred` must be in the following shape (batch_size, ...) + `y_pred` must be in the following shape (batch_size, ...) and it's + elements must be between 0 and 1. `y` must be in the following shape (batch_size, ...) """ def reset(self): @@ -20,7 +21,7 @@ def reset(self): def update(self, output): y_pred, y = output - correct = torch.eq(torch.round(y_pred).type(torch.LongTensor), y).view(-1) + correct = torch.eq(torch.round(y_pred).type(y.type()), y).view(-1) self._num_correct += torch.sum(correct) self._num_examples += correct.shape[0]