
Very slow loss aggregation #172

Closed
@MilesCranmer

Description


The mean and sum implementations in this package are extremely slow as they rely on generators. Could this be changed to a faster implementation? For example, here is how the current implementation would compute sum-squared-error:

square(x) = x ^ 2

# example of generator-based approach:
function sse(outputs, targets)
    sum(square(ŷ - y) for (ŷ, y) in zip(outputs, targets))
end

(i.e., like this code)

which gives us the following time:

julia> @btime sse(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
  92.833 μs (0 allocations: 0 bytes)

but if we change this to an approach using sum(<function>, <indices>), it's much faster:

function sse2(outputs, targets)
    sum(i -> square(outputs[i] - targets[i]), eachindex(outputs, targets))
end
julia> @btime sse2(outputs, targets) setup=(outputs=randn(100_000); targets=randn(100_000))
  26.708 μs (0 allocations: 0 bytes)

which is a 3.5x speedup.

Could this be made the default loss calculation? I thought this was the approach used previously; perhaps it was changed in the recent refactoring?
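
For reference, here is a minimal sketch of what an index-based aggregation helper in this style could look like (the aggregate_loss name and agg keyword are hypothetical illustrations, not this package's actual API):

using Statistics: mean

# Aggregate an elementwise loss over paired outputs/targets using
# eachindex instead of a zip-based generator; agg can be sum, mean, etc.
function aggregate_loss(loss, outputs::AbstractVector, targets::AbstractVector; agg=sum)
    agg(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

# Usage: sum-squared and mean-squared error via the same pattern
sse3(outputs, targets) = aggregate_loss((ŷ, y) -> abs2(ŷ - y), outputs, targets)
mse(outputs, targets)  = aggregate_loss((ŷ, y) -> abs2(ŷ - y), outputs, targets; agg=mean)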
