Precision/Recall Fixes #140
Conversation
tests/ignite/metrics/test_recall.py (Outdated)

```python
y = torch.ones(2).type(torch.LongTensor)
recall.update((y_pred, y))
result = list(recall.compute())
# assert isnan(result[0])
```
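For context, a hedged reconstruction of how this fragment might run end to end (the `y_pred` scores and the import are assumptions, not part of the diff):

```python
from math import isnan

import torch
from ignite.metrics import Recall

recall = Recall()
# Assumed two-class scores whose argmax is always class 1, so class 0
# never appears in the predictions or in the ground truth.
y_pred = torch.tensor([[0.0, 1.0], [0.0, 1.0]])
y = torch.ones(2).type(torch.LongTensor)
recall.update((y_pred, y))
result = list(recall.compute())
# With no ground-truth examples for class 0, its recall is 0/0,
# hence the (now commented-out) assertion:
# assert isnan(result[0])
```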
@jasonkriss what do you think about adding an option to compute the mean, e.g. `self.agg_fn = torch.mean if option else lambda x: x`, used in `compute`:

```python
def compute(self):
    if self._all_positives is None:
        raise NotComputableError('Precision must have at least one example before it can be computed')
    return self.agg_fn(self._true_positives / (self._all_positives + eps))
```
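A minimal sketch of where that `agg_fn` could be set, assuming a boolean constructor option (the `average` name is hypothetical):

```python
import torch

class Precision:
    def __init__(self, average=False):
        # torch.mean collapses the per-class precision vector to a scalar;
        # the identity lambda keeps the full per-class tensor.
        self.agg_fn = torch.mean if average else (lambda x: x)
        self._true_positives = None
        self._all_positives = None
```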
@vfdev-5 I'd be ok with that. Would it be better to have it take an optional reduce function?
For precision and recall, what other types of reducer functions would you want besides mean? If there are more than that, then I'm in favor of a constructor arg.
We'd go the more flexible way and leave the user the choice of a custom function :)
@vfdev-5 yep, I'll add a commit to take care of the NaNs and warn. For the reduce argument, I think we should make the API as clear as possible. So unless there are clear use cases for other reductions, I'd prefer the flag, similar to the loss functions in pytorch.
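As a sketch of the flag-based API being preferred here (function form only; `eps` and the `average` name are assumptions):

```python
import torch

eps = 1e-20  # assumed small constant to avoid division by zero

def compute_precision(true_positives, all_positives, average=False):
    # Per-class precision; the boolean flag optionally reduces it to a
    # scalar mean, much like the averaging flags on pytorch loss functions.
    result = true_positives / (all_positives + eps)
    return torch.mean(result) if average else result
```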
@vfdev-5 @alykhantejani I made a couple of updates if you don't mind taking a peek.
@jasonkriss yes, I saw your commits, looks good to me! Maybe just a small question on the performance of one variant vs the other. Sure, if we have 10-100 classes it won't be noticeable...
Yea, I don't think that will make much of a difference performance-wise (even with 1000s of classes).
ignite/metrics/precision.py (Outdated)

```python
raise NotComputableError('Precision must have at least one example before it can be computed')
return self._true_positives / self._all_positives
elif self._all_positives.eq(0.0).any():
    warnings.warn('Labels with no predicted examples are set to have precision of 0.0.', UndefinedMetricWarning)
```
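A hedged sketch of how the warning could fit into compute, with undefined entries set to 0.0 as the message says (the masking step is my assumption about the implementation; the local warning class stands in for `ignite.exceptions.UndefinedMetricWarning`):

```python
import warnings

import torch

class UndefinedMetricWarning(UserWarning):
    """Stand-in for ignite.exceptions.UndefinedMetricWarning."""

def precision_with_warning(true_positives, all_positives):
    if all_positives.eq(0.0).any():
        warnings.warn('Labels with no predicted examples are set to have '
                      'precision of 0.0.', UndefinedMetricWarning)
    result = true_positives / all_positives   # NaN where all_positives == 0
    result[all_positives == 0] = 0.0          # replace undefined entries
    return result
```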
ignite/metrics/precision.py (Outdated)

```diff
-from .metric import Metric
-from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric
+from ignite.exceptions import NotComputableError, UndefinedMetricWarning
```
ignite/metrics/recall.py (Outdated)

```python
actual = actual_onehot.sum(dim=0)
true_positives = correct_onehot.sum(dim=0)
if correct.sum() == 0:
    true_positives = torch.zeros(num_classes)
```
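To illustrate the counting above, a small hedged example (the one-hot construction via `torch.eye` is an assumption about how the `*_onehot` tensors are built):

```python
import torch

num_classes = 4
y = torch.tensor([0, 1, 1, 3])        # ground-truth labels
y_pred = torch.tensor([0, 1, 2, 3])   # predicted labels
correct = (y == y_pred)

eye = torch.eye(num_classes)
actual_onehot = eye[y]                                  # (N, num_classes)
correct_onehot = eye[y] * correct.float().unsqueeze(1)  # zero out wrong rows

actual = actual_onehot.sum(dim=0)            # per-class ground-truth counts
true_positives = correct_onehot.sum(dim=0)   # per-class true positives
if correct.sum() == 0:
    true_positives = torch.zeros(num_classes)
print(true_positives / actual)  # per-class recall; NaN where actual == 0
```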
ignite/metrics/precision.py (Outdated)

```python
all_positives = pred_onehot.sum(dim=0)
true_positives = correct_onehot.sum(dim=0)
if correct.sum() == 0:
    true_positives = torch.zeros(num_classes)
```
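One guess at the kind of issue a GPU run would surface in this line: `torch.zeros(num_classes)` always allocates on the CPU, so it would not match tensors living on the GPU. A hedged fix sketch using `torch.zeros_like` (this is my assumption, not necessarily the catch discussed below):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
true_positives = torch.ones(4, device=device)

# torch.zeros(4) would allocate on the CPU regardless of where the metric's
# tensors live; zeros_like preserves the device and dtype of its argument.
true_positives = torch.zeros_like(true_positives)
```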
What would I do without you @vfdev-5 😄. Great catch, this should be fixed. It also brings up a bigger question, not specific to this PR: does anyone have ideas on how we can catch these kinds of bugs earlier, i.e. is there any way we can run CI/tests on a GPU?
Probably it could be caught if you ran the tests locally on a machine with a GPU :) Otherwise, pytorch does this with jenkins here and there, and they probably have some machines with GPUs.
Thanks @jasonkriss!