Closed
Description
The mean and sum implementations in this package are extremely slow as they rely on generators. Could this be changed to a faster implementation? For example, here is how the current implementation would compute sum-squared-error:
square(x) = x ^ 2
# example of generator-based approach:
function sse(outputs, targets)
    sum(square(ŷ - y) for (ŷ, y) in zip(outputs, targets))
end
(i.e., like this code)
which gives us the following time:
julia> @btime sse(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
92.833 μs (0 allocations: 0 bytes)
but if we change this to an approach using sum(<function>, <indices>), it's much faster:
function sse2(outputs, targets)
    sum(i -> square(outputs[i] - targets[i]), eachindex(outputs, targets))
end
julia> @btime sse2(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
26.708 μs (0 allocations: 0 bytes)
which is a 3.5x speedup.
Could this be implemented as the default loss calculation? I thought this was the approach used previously. Perhaps it was changed in the recent refactoring?
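For reference, here is a minimal sketch of how the same index-based pattern might cover the mean as well, with a quick check that it agrees with the generator version. Note that mse2 is a hypothetical name for illustration, not something in the package, and I'm assuming both inputs are arrays with matching indices (eachindex throws an error otherwise).

```julia
square(x) = x ^ 2

# Generator-based reference, as in the example above.
sse(outputs, targets) = sum(square(ŷ - y) for (ŷ, y) in zip(outputs, targets))

# Index-based sum of squared errors, same as sse2 above.
function sse2(outputs, targets)
    sum(i -> square(outputs[i] - targets[i]), eachindex(outputs, targets))
end

# Hypothetical mean counterpart: divide the index-based sum by the element count.
function mse2(outputs, targets)
    sse2(outputs, targets) / length(outputs)
end

# Sanity check: both approaches should agree up to floating-point rounding,
# since the summation order may differ between them.
outputs = randn(1_000)
targets = randn(1_000)
@assert sse(outputs, targets) ≈ sse2(outputs, targets)
```

The comparison uses ≈ rather than == because the two versions are not guaranteed to add the terms in the same order.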