[MRG+1] Make csr row norms support fused types #6785
Tests are failing; you have to reduce the precision for the sqrt check as well.
sklearn/utils/sparsefuncs_fast.pyx
np.ndarray[int, ndim=1, mode="c"] indptr = X.indptr
unsigned int n_samples = shape[0]
unsigned int n_features = shape[1]
np.ndarray[double, ndim=1, mode="c"] norms
Why is this no longer DOUBLE?
You mean your comment or the dtype?
Bah.
Sorry, it's my mistake.
BTW, since the type double comes from import numpy as np and the type np.float64_t comes from cimport numpy as np, is there a big performance difference between them?
I don't see how double comes from import numpy as np. double is a C type whose size is officially unspecified, but to be honest I'm not sure if that's the only reason the more precise numpy type is preferred over the C type.
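A quick way to check the sizes being discussed, as a sketch rather than anything from the PR itself. It assumes a mainstream platform where the C `double` is 8 bytes; the variable names are illustrative only.

```python
import ctypes

import numpy as np

# At the Python level, np.double is simply another name for np.float64.
same_type = np.double is np.float64

# np.float64 is 8 bytes wide by definition.
numpy_size = np.dtype(np.float64).itemsize

# Size of the C double as reported for this platform (assumed to be 8).
c_size = ctypes.sizeof(ctypes.c_double)

print(same_type, numpy_size, c_size)
```

When the two sizes agree, the choice between `double` and `np.float64_t` in Cython is about being explicit, not about speed.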
8930eb6 to
ad211a1
Compare
@jnothman Sorry, my earlier execution-time test was not correct. After debugging and running it again, the result shows that the running time increases from 9.9s to 13s if we explicitly cast every entry as we multiply it. It seems that it does indeed cause a big runtime hit...
However, the test passes if I change its precision from 1e-6 to 1e-4, as @TomDLT suggested.
Oh. I thought I'd replied to this. But perhaps I didn't, for the same reason that I'm still not sure what to say. I did briefly try to look for a BLAS or similar function (I'm not very familiar with what's available) that might support fast sum-of-squares, potentially with a result in higher precision than the input. I agree that 30% seems a substantial runtime hit. I suppose we can accept the loss in precision. :/
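The trade-off in this thread can be illustrated with plain NumPy. This is only a sketch under assumptions: it uses `np.dot` on a made-up float32 vector, not the Cython loop from the PR, and it does not reproduce the timings quoted above.

```python
import numpy as np

rng = np.random.RandomState(0)
data = rng.rand(10_000).astype(np.float32)

# Option 1: multiply and accumulate in the input precision (float32).
sq_norm_f32 = np.dot(data, data)

# Option 2: upcast first, so products and the running sum are float64.
data64 = data.astype(np.float64)
sq_norm_f64 = np.dot(data64, data64)

# Relative gap between the two results.
rel_err = abs(float(sq_norm_f32) - float(sq_norm_f64)) / float(sq_norm_f64)
print(sq_norm_f32.dtype, sq_norm_f64.dtype, rel_err)
```

The float32 result is cheaper to compute but carries rounding error from every product and every partial sum, which is why the test tolerance has to be relaxed if the cast is dropped.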
sklearn/utils/tests/test_extmath.py
assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 5)
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X))
assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 4)
It might be a good idea to test for the float64 and float32 dtype separately, with a higher precision for the float64 dtype. WDYT?
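One possible shape for such a test, as a sketch only. `dense_to_csr`, `csr_row_norms_ref` and `check_row_norms` are hypothetical pure-NumPy stand-ins, not scikit-learn functions, and the tolerances (4 decimals for float32, 5 for float64) are taken from the values discussed in this PR.

```python
import numpy as np
from numpy.testing import assert_array_almost_equal


def dense_to_csr(X):
    """Return the (data, indptr) arrays of a CSR encoding of dense X."""
    mask = X != 0
    data = X[mask]  # boolean indexing walks X in row-major order
    indptr = np.concatenate(([0], np.cumsum(mask.sum(axis=1))))
    return data, indptr


def csr_row_norms_ref(data, indptr):
    """Hypothetical stand-in: squared L2 norm per row, in data's dtype."""
    norms = np.zeros(len(indptr) - 1, dtype=data.dtype)
    for i in range(len(indptr) - 1):
        row = data[indptr[i]:indptr[i + 1]]
        norms[i] = np.dot(row, row)
    return norms


def check_row_norms(dtype, decimal):
    rng = np.random.RandomState(42)
    X = rng.rand(50, 20)
    X[X < 0.5] = 0.0  # zero out roughly half of the entries
    X = X.astype(dtype)
    data, indptr = dense_to_csr(X)
    # Reference computed in float64 from the same (already rounded) values.
    expected = (X.astype(np.float64) ** 2).sum(axis=1)
    actual = csr_row_norms_ref(data, indptr)
    assert_array_almost_equal(expected, actual, decimal)
    return actual


# Looser tolerance for float32, stricter for float64.
for dtype, decimal in [(np.float32, 4), (np.float64, 5)]:
    check_row_norms(dtype, decimal)
```

Keeping the tolerance per dtype means the float64 path is still held to the original precision while only the float32 path is relaxed.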
Just a minor comment, +1 otherwise.
ad211a1 to
c7d6f9f
Compare
@jnothman merge?
Let's do it.
Since `csr_row_norms` is called by the `row_norms` function defined in `sklearn/utils/tests/test_extmath.py`, and `row_norms` is used widely in `k_means_.py`, it would be useful if `csr_row_norms` also supported Cython fused types.
However, making this change would degrade the precision of the function. In order to pass the local test, I have to relax its strictness by changing the number of decimal places checked by `assert_array_almost_equal` from 5 to 4.
Could @MechCoder and @jnothman give me some advice on this trade-off?