Skip to content

Conversation

kshyatt
Copy link
Member

@kshyatt kshyatt commented Oct 8, 2025

This ports the (duplicated) sparse broadcasting support from CUDA.jl and AMDGPU.jl to GPUArrays.jl. It should allow all the "child" GPU libraries to use one unified set of on-device sparse types and broadcasting kernels. I implemented the appropriate types in JLArrays and tests there are passing. If this merges, we should be able to strip out most of the sparse broadcasting code from downstream packages.

@kshyatt kshyatt requested review from amontoison and maleadt October 8, 2025 13:56
@kshyatt kshyatt force-pushed the ksh/sparse branch 2 times, most recently from 844f20c to 9a74b4d Compare October 8, 2025 14:07
@kshyatt
Copy link
Member Author

kshyatt commented Oct 8, 2025

I had to @allowscalar around the accumulate! and sort! calls since KA doesn't have device-agnostic implementations of these IIRC. Could be added as part of this PR if people prefer.

@maleadt
Copy link
Member

maleadt commented Oct 8, 2025

I had to @allowscalar around the accumulate! and sort! calls since KA doesn't have device-agnostic implementations of these IIRC

You mean GPUArrays.jl itself? I wouldn't expect those to be defined in KA.jl (maybe AK.jl, but with different signatures).

@kshyatt kshyatt marked this pull request as ready for review October 8, 2025 18:03
@kshyatt
Copy link
Member Author

kshyatt commented Oct 8, 2025

Now with all tests uncommented and testing SparseVector and SparseMatrixCSC also

Copy link
Contributor

github-actions bot commented Oct 8, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl
index 5ada976..3d772fe 100644
--- a/lib/JLArrays/src/JLArrays.jl
+++ b/lib/JLArrays/src/JLArrays.jl
@@ -124,29 +124,33 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T
     len::Int
     nnz::Ti
 
-    function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
-                                    len::Integer) where {Tv, Ti <: Integer}
-        new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
+    function JLSparseVector{Tv, Ti}(
+            iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
+            len::Integer
+        ) where {Tv, Ti <: Integer}
+        return new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
     end
 end
 SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal))
-SparseArrays.nnz(x::JLSparseVector)          = x.nnz 
-SparseArrays.nonzeroinds(x::JLSparseVector)  = x.iPtr
-SparseArrays.nonzeros(x::JLSparseVector)     = x.nzVal
+SparseArrays.nnz(x::JLSparseVector) = x.nnz
+SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
+SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
 
 mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
     colPtr::JLArray{Ti, 1}
     rowVal::JLArray{Ti, 1}
     nzVal::JLArray{Tv, 1}
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 
-    function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
-                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
-        new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
+    function JLSparseMatrixCSC{Tv, Ti}(
+            colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
+            nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+        ) where {Tv, Ti <: Integer}
+        return new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
     end
 end
-function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
     return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
 end
 SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))
@@ -155,28 +159,30 @@ JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A
 
 function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
     r1 = Int(@inbounds A.colPtr[j])
-    r2 = Int(@inbounds A.colPtr[j+1]-1)
+    r2 = Int(@inbounds A.colPtr[j + 1] - 1)
     (r1 > r2) && return zero(Tv)
     r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
-    ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
+    return ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
 end
 
 mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
     rowPtr::JLArray{Ti, 1}
     colVal::JLArray{Ti, 1}
     nzVal::JLArray{Tv, 1}
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 
-    function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
-                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
-        new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
+    function JLSparseMatrixCSR{Tv, Ti}(
+            rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
+            nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+        ) where {Tv, Ti <: Integer}
+        return new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
     end
 end
-function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
     return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
 end
-function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR) 
+function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
     x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
     return SparseMatrixCSC(transpose(x_transpose))
 end
@@ -196,12 +202,12 @@ GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
 GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector
 GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
 
-GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray 
-GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray 
-GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray 
-GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray 
-GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray 
-GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray 
+GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
 
 # conversion of untyped data to a typed Array
 function typed_data(x::JLArray{T}) where {T}
@@ -304,11 +310,11 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
 JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
 
 function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
-    iPtr  = JLVector{Ti}(undef, length(xs.nzind))
+    iPtr = JLVector{Ti}(undef, length(xs.nzind))
     nzVal = JLVector{Tv}(undef, length(xs.nzval))
     copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
     copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
-    return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs),)
+    return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
 end
 Base.length(x::JLSparseVector) = x.len
 Base.size(x::JLSparseVector) = (x.len,)
@@ -316,10 +322,10 @@ Base.size(x::JLSparseVector) = (x.len,)
 function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
     colPtr = JLVector{Ti}(undef, length(xs.colptr))
     rowVal = JLVector{Ti}(undef, length(xs.rowval))
-    nzVal  = JLVector{Tv}(undef, length(xs.nzval))
+    nzVal = JLVector{Tv}(undef, length(xs.nzval))
     copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
     copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
-    copyto!(nzVal,  convert(Vector{Tv}, xs.nzval))
+    copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
     return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
 end
 Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
@@ -329,10 +335,10 @@ function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
     csr_xs = SparseMatrixCSC(transpose(xs))
     rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
     colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
-    nzVal  = JLVector{Tv}(undef, length(csr_xs.nzval))
+    nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
     copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
     copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
-    copyto!(nzVal,  convert(Vector{Tv}, csr_xs.nzval))
+    copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
     return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
 end
 Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
@@ -479,17 +485,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
     R
 end
 
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} =
+    GPUSparseDeviceMatrixCSC{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} =
+    GPUSparseDeviceMatrixCSR{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv, Ti}) where {Tv, Ti} =
+    GPUSparseDeviceVector{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
 
 ## KernelAbstractions interface
 
 KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
-KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()
+KernelAbstractions.get_backend(a::JLA) where {JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector}} = JLBackend()
 
 function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
     return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
diff --git a/src/device/sparse.jl b/src/device/sparse.jl
index 77abe0d..c9290cf 100644
--- a/src/device/sparse.jl
+++ b/src/device/sparse.jl
@@ -10,12 +10,12 @@ using SparseArrays
 # core types
 
 export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR,
-       GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
+    GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
 
 abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end
 
 
-struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
+struct GPUSparseDeviceVector{Tv, Ti, Vi, Vv} <: AbstractSparseVector{Tv, Ti}
     iPtr::Vi
     nzVal::Vv
     len::Int
@@ -28,19 +28,19 @@ SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz
 SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr
 SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal
 
-struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
+struct GPUSparseDeviceMatrixCSC{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     colPtr::Vi
     rowVal::Vi
     nzVal::Vv
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 end
 
 SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal
 SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr
-SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
+SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col + 1] - 1)
 
-struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCSR{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     rowPtr::Vi
     colVal::Vi
     nzVal::Vv
@@ -48,21 +48,21 @@ struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv
     nnz::Ti
 end
 
-struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixBSR{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     rowPtr::Vi
     colVal::Vi
     nzVal::Vv
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     blockDim::Ti
     dir::Char
     nnz::Ti
 end
 
-struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCOO{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     rowInd::Vi
     colInd::Vi
     nzVal::Vv
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 end
 
@@ -79,9 +79,9 @@ struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M} <: AbstractSparseArray{Tv,
     nnz::Ti
 end
 
-function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N}
+function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N, <:Integer}) where {Tv, Ti <: Integer, M, Vi <: AbstractDeviceArray{<:Integer, M}, Vv <: AbstractDeviceArray{Tv, M}, N}
     @assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
-    GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
+    return GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
 end
 
 Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims)
@@ -94,42 +94,42 @@ SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector)
     println(io, "$(length(A))-element device sparse vector at:")
     println(io, "  iPtr: $(A.iPtr)")
-    print(io,   "  nzVal: $(A.nzVal)")
+    return print(io, "  nzVal: $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR)
     println(io, "$(length(A))-element device sparse matrix CSR at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colVal: $(A.colVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC)
     println(io, "$(length(A))-element device sparse matrix CSC at:")
     println(io, "  colPtr: $(A.colPtr)")
     println(io, "  rowVal: $(A.rowVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR)
     println(io, "$(length(A))-element device sparse matrix BSR at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colVal: $(A.colVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO)
     println(io, "$(length(A))-element device sparse matrix COO at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colInd: $(A.colInd)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR)
     println(io, "$(length(A))-element device sparse array CSR at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colVal: $(A.colVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 # COV_EXCL_STOP
diff --git a/src/host/sparse.jl b/src/host/sparse.jl
index ff82030..6f1e77f 100644
--- a/src/host/sparse.jl
+++ b/src/host/sparse.jl
@@ -7,20 +7,20 @@ abstract type AbstractGPUSparseMatrixCSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, T
 abstract type AbstractGPUSparseMatrixCOO{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
 abstract type AbstractGPUSparseMatrixBSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
 
-const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector,AbstractGPUSparseMatrix}
+const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector, AbstractGPUSparseMatrix}
 
 Base.convert(T::Type{<:AbstractGPUSparseArray}, m::AbstractArray) = m isa T ? m : T(m)
 
-_dense_array_type(sa::SparseVector)     = SparseVector
+_dense_array_type(sa::SparseVector) = SparseVector
 _dense_array_type(::Type{SparseVector}) = SparseVector
 _sparse_array_type(sa::SparseVector) = SparseVector
 _dense_vector_type(sa::AbstractSparseArray) = Vector
-_dense_vector_type(sa::AbstractArray)       = Vector
+_dense_vector_type(sa::AbstractArray) = Vector
 _dense_vector_type(::Type{<:AbstractSparseArray}) = Vector
-_dense_vector_type(::Type{<:AbstractArray})       = Vector
-_dense_array_type(sa::SparseMatrixCSC)     = SparseMatrixCSC
+_dense_vector_type(::Type{<:AbstractArray}) = Vector
+_dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
 _dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC
-_sparse_array_type(sa::SparseMatrixCSC)    = SparseMatrixCSC
+_sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
 
 function _sparse_array_type(sa::AbstractGPUSparseArray) end
 function _dense_array_type(sa::AbstractGPUSparseArray) end
@@ -32,7 +32,7 @@ struct GPUSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
 struct GPUSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
 Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseVector}) = GPUSparseVecStyle()
 Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseMatrix}) = GPUSparseMatStyle()
-const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
+const SPVM = Union{GPUSparseVecStyle, GPUSparseMatStyle}
 
 # GPUSparseVecStyle handles 0-1 dimensions, GPUSparseMatStyle 0-2 dimensions.
 # GPUSparseVecStyle promotes to GPUSparseMatStyle for 2 dimensions.
@@ -40,11 +40,11 @@ const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
 GPUSparseVecStyle(::Val{0}) = GPUSparseVecStyle()
 GPUSparseVecStyle(::Val{1}) = GPUSparseVecStyle()
 GPUSparseVecStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseVecStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
 GPUSparseMatStyle(::Val{0}) = GPUSparseMatStyle()
 GPUSparseMatStyle(::Val{1}) = GPUSparseMatStyle()
 GPUSparseMatStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseMatStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
 
 Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{1}) = GPUSparseVecStyle()
 Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{2}) = GPUSparseMatStyle()
@@ -71,37 +71,37 @@ end
 
 # Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
 @inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
-    capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+    capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
 @inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} =
     # This definition is identical to the one above and necessary only for
     # avoiding method ambiguity.
-    capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+    capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
 @inline capturescalars(f, mixedargs::Tuple{AbstractGPUSparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
-    capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
-@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
-    capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
+    capturescalars((a1, args...) -> f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
+@inline capturescalars(f, mixedargs::Tuple{Union{Ref, AbstractArray{<:Any, 0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
+    capturescalars((args...) -> f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
 
 scalararg(::Number) = true
 scalararg(::Any) = false
-scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true
+scalarwrappedarg(::Union{AbstractArray{<:Any, 0}, Ref}) = true
 scalarwrappedarg(::Any) = false
 
 @inline function _capturescalars()
     return (), () -> ()
 end
 @inline function _capturescalars(arg, mixedargs...)
-    let (rest, f) = _capturescalars(mixedargs...)
+    return let (rest, f) = _capturescalars(mixedargs...)
         if scalararg(arg)
-            return rest, @inline function(tail...)
-                (arg, f(tail...)...)
+            return rest, @inline function (tail...)
+                    return (arg, f(tail...)...)
             end # add back scalararg after (in makeargs)
         elseif scalarwrappedarg(arg)
-            return rest, @inline function(tail...)
-                (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
+            return rest, @inline function (tail...)
+                    return (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
             end # unwrap and add back scalararg after (in makeargs)
         else
-            return (arg, rest...), @inline function(head, tail...)
-                (head, f(tail...)...)
+            return (arg, rest...), @inline function (head, tail...)
+                    return (head, f(tail...)...)
             end # pass-through to broadcast
         end
     end
@@ -137,13 +137,13 @@ access the elements themselves.
 For convenience, this iterator can be passed non-sparse arguments as well, which will be
 ignored (with the returned `col`/`ptr` values set to 0).
 """
-struct CSRIterator{Ti,N,ATs}
+struct CSRIterator{Ti, N, ATs}
     row::Ti
     col_ends::NTuple{N, Ti}
     args::ATs
 end
 
-function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
+function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti, N}
     # check that `row` is valid for all arguments
     @boundscheck begin
         ntuple(Val(N)) do i
@@ -155,16 +155,16 @@ function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
     col_ends = ntuple(Val(N)) do i
         arg = @inbounds args[i]
         if arg isa GPUSparseDeviceMatrixCSR
-            @inbounds(arg.rowPtr[row+1])
+            @inbounds(arg.rowPtr[row + 1])
         else
             zero(Ti)
         end
     end
 
-    CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
+    return CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
 end
 
-@inline function Base.iterate(iter::CSRIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSRIterator{Ti, N}, state = nothing) where {Ti, N}
     # helper function to get the column of a sparse array at a specific pointer
     @inline function get_col(i, ptr)
         arg = @inbounds iter.args[i]
@@ -174,13 +174,14 @@ end
                 return @inbounds arg.colVal[ptr] % Ti
             end
         end
-        typemax(Ti)
+        return typemax(Ti)
     end
 
     # initialize the state
     # - ptr: the current index into the colVal/nzVal arrays
     # - col: the current column index (cached so that we don't have to re-read each time)
-    state = something(state,
+    state = something(
+        state,
         ntuple(Val(N)) do i
             arg = @inbounds iter.args[i]
             if arg isa GPUSparseDeviceMatrixCSR
@@ -222,13 +223,13 @@ end
     return (cur_col, ptrs), new_state
 end
 
-struct CSCIterator{Ti,N,ATs}
+struct CSCIterator{Ti, N, ATs}
     col::Ti
     row_ends::NTuple{N, Ti}
     args::ATs
 end
 
-function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
+function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti, N}
     # check that `col` is valid for all arguments
     @boundscheck begin
         ntuple(Val(N)) do i
@@ -240,17 +241,17 @@ function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
     row_ends = ntuple(Val(N)) do i
         arg = @inbounds args[i]
         x = if arg isa GPUSparseDeviceMatrixCSC
-            @inbounds(arg.colPtr[col+1])
+            @inbounds(arg.colPtr[col + 1])
         else
             zero(Ti)
         end
         x
     end
 
-    CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
+    return CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
 end
 
-@inline function Base.iterate(iter::CSCIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSCIterator{Ti, N}, state = nothing) where {Ti, N}
     # helper function to get the column of a sparse array at a specific pointer
     @inline function get_col(i, ptr)
         arg = @inbounds iter.args[i]
@@ -260,13 +261,14 @@ end
                 return @inbounds arg.rowVal[ptr] % Ti
             end
         end
-        typemax(Ti)
+        return typemax(Ti)
     end
 
     # initialize the state
     # - ptr: the current index into the rowVal/nzVal arrays
     # - row: the current row index (cached so that we don't have to re-read each time)
-    state = something(state,
+    state = something(
+        state,
         ntuple(Val(N)) do i
             arg = @inbounds iter.args[i]
             if arg isa GPUSparseDeviceMatrixCSC
@@ -309,8 +311,8 @@ end
 end
 
 # helpers to index a sparse or dense array
-function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR,GPUSparseDeviceMatrixCSC}, I, ptr)
-    if ptr == 0
+function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR, GPUSparseDeviceMatrixCSC}, I, ptr)
+    return if ptr == 0
         zero(eltype(arg))
     else
         @inbounds arg.nzVal[ptr]
@@ -338,9 +340,11 @@ function _has_row(A::GPUSparseDeviceVector, offsets, row, ::Bool)
     return 0
 end
 
-@kernel function compute_offsets_kernel(::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
-                                        fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
-                                        args...) where {Ti, N}
+@kernel function compute_offsets_kernel(
+        ::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
+        fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+        args...
+    ) where {Ti, N}
     my_ix = @index(Global, Linear)
     row = my_ix + first_row - one(eltype(my_ix))
     if row ≤ last_row
@@ -359,11 +363,13 @@ end
 end
 
 # kernel to count the number of non-zeros in a row, to determine the row offsets
-@kernel function compute_offsets_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
-                                        offsets::AbstractVector{Ti}, args...) where Ti
+@kernel function compute_offsets_kernel(
+        T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
+        offsets::AbstractVector{Ti}, args...
+    ) where {Ti}
     # every thread processes an entire row
     leading_dim = @index(Global, Linear)
-    if leading_dim ≤ length(offsets)-1 
+    if leading_dim ≤ length(offsets) - 1
         iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
 
         # count the nonzero leading_dims of all inputs
@@ -378,19 +384,21 @@ end
             if leading_dim == 1
                 offsets[1] = 1
             end
-            offsets[leading_dim+1] = accum
+            offsets[leading_dim + 1] = accum
         end
     end
 end
 
-@kernel function sparse_to_sparse_broadcast_kernel(f::F, output::GPUSparseDeviceVector{Tv,Ti},
-                                                   offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
-                                                   args...) where {Tv, Ti, N, F}
+@kernel function sparse_to_sparse_broadcast_kernel(
+        f::F, output::GPUSparseDeviceVector{Tv, Ti},
+        offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+        args...
+    ) where {Tv, Ti, N, F}
     row_ix = @index(Global, Linear)
     if row_ix ≤ output.nnz
         row_and_ptrs = @inbounds offsets[row_ix]
-        row          = @inbounds row_and_ptrs[1]
-        arg_ptrs     = @inbounds row_and_ptrs[2]
+        row = @inbounds row_and_ptrs[1]
+        arg_ptrs = @inbounds row_and_ptrs[2]
         vals = ntuple(Val(N)) do i
             @inline
             arg = @inbounds args[i]
@@ -400,14 +408,20 @@ end
             _getindex(arg, row, ptr)
         end
         output_val = f(vals...)
-        @inbounds output.iPtr[row_ix]  = row
+        @inbounds output.iPtr[row_ix] = row
         @inbounds output.nzVal[row_ix] = output_val
     end
 end
 
-@kernel function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{<:AbstractArray,Nothing},
-                                                   args...) where {Ti, T<:Union{GPUSparseDeviceMatrixCSR{<:Any,Ti},
-                                                                                GPUSparseDeviceMatrixCSC{<:Any,Ti}}}
+@kernel function sparse_to_sparse_broadcast_kernel(
+        f, output::T, offsets::Union{<:AbstractArray, Nothing},
+        args...
+    ) where {
+        Ti, T <: Union{
+            GPUSparseDeviceMatrixCSR{<:Any, Ti},
+            GPUSparseDeviceMatrixCSC{<:Any, Ti},
+        },
+    }
     # every thread processes an entire row
     leading_dim = @index(Global, Linear)
     leading_dim_size = output isa GPUSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -415,19 +429,19 @@ end
         iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
 
 
-        output_ptrs  = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
+        output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
         output_ivals = output isa GPUSparseDeviceMatrixCSR ? output.colVal : output.rowVal
         # fetch the row offset, and write it to the output
         @inbounds begin
             output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
             if leading_dim == leading_dim_size
-                output_ptrs[leading_dim+one(eltype(leading_dim))] = offsets[leading_dim+one(eltype(leading_dim))]
+                output_ptrs[leading_dim + one(eltype(leading_dim))] = offsets[leading_dim + one(eltype(leading_dim))]
             end
         end
 
         # set the values for this row
         for (sub_leading_dim, ptrs) in iter
-            index_first  = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
+            index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
             index_second = output isa GPUSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim
             I = CartesianIndex(index_first, index_second)
             vals = ntuple(Val(length(args))) do i
@@ -442,9 +456,15 @@ end
         end
     end
 end
-@kernel function sparse_to_dense_broadcast_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR{Tv, Ti},
-                                                                  AbstractGPUSparseMatrixCSC{Tv, Ti}}},
-                                                  f, output::AbstractDeviceArray, args...) where {Tv, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+        T::Type{
+            <:Union{
+                AbstractGPUSparseMatrixCSR{Tv, Ti},
+                AbstractGPUSparseMatrixCSC{Tv, Ti},
+            },
+        },
+        f, output::AbstractDeviceArray, args...
+    ) where {Tv, Ti}
     # every thread processes an entire row
     leading_dim = @index(Global, Linear)
     leading_dim_size = T <: AbstractGPUSparseMatrixCSR ? size(output, 1) : size(output, 2)
@@ -453,7 +473,7 @@ end
 
         # set the values for this row
         for (sub_leading_dim, ptrs) in iter
-            index_first  = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
+            index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
             index_second = T <: AbstractGPUSparseMatrixCSR ? sub_leading_dim : leading_dim
             I = CartesianIndex(index_first, index_second)
             vals = ntuple(Val(length(args))) do i
@@ -467,16 +487,18 @@ end
     end
 end
 
-@kernel function sparse_to_dense_broadcast_kernel(::Type{<:AbstractGPUSparseVector}, f::F,
-                                                  output::AbstractDeviceArray{Tv},
-                                                  offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
-                                                  args...) where {Tv, F, N, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+        ::Type{<:AbstractGPUSparseVector}, f::F,
+        output::AbstractDeviceArray{Tv},
+        offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+        args...
+    ) where {Tv, F, N, Ti}
     # every thread processes an entire row
     row_ix = @index(Global, Linear)
     if row_ix ≤ length(output)
         row_and_ptrs = @inbounds offsets[row_ix]
-        row          = @inbounds row_and_ptrs[1]
-        arg_ptrs     = @inbounds row_and_ptrs[2]
+        row = @inbounds row_and_ptrs[1]
+        arg_ptrs = @inbounds row_and_ptrs[2]
         vals = ntuple(Val(length(args))) do i
             @inline
             arg = @inbounds args[i]
@@ -491,39 +513,41 @@ end
 end
 ## COV_EXCL_STOP
 
-function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatStyle}})
+function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle, GPUSparseMatStyle}})
     # find the sparse inputs
     bc = Broadcast.flatten(bc)
     sparse_args = findall(bc.args) do arg
         arg isa AbstractGPUSparseArray
     end
-    sparse_types = unique(map(i->nameof(typeof(bc.args[i])), sparse_args))
+    sparse_types = unique(map(i -> nameof(typeof(bc.args[i])), sparse_args))
     if length(sparse_types) > 1
         error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
     end
     sparse_typ = typeof(bc.args[first(sparse_args)])
-    sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC,AbstractGPUSparseVector} ||
+    sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC, AbstractGPUSparseVector} ||
         error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
     Ti = if sparse_typ <: AbstractGPUSparseMatrixCSR
-        reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
+        reduce(promote_type, map(i -> eltype(bc.args[i].rowPtr), sparse_args))
     elseif sparse_typ <: AbstractGPUSparseMatrixCSC
-        reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+        reduce(promote_type, map(i -> eltype(bc.args[i].colPtr), sparse_args))
     elseif sparse_typ <: AbstractGPUSparseVector
-        reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
+        reduce(promote_type, map(i -> eltype(bc.args[i].iPtr), sparse_args))
     end
 
     # determine the output type
     Tv = Broadcast.combine_eltypes(bc.f, eltype.(bc.args))
     if !Base.isconcretetype(Tv)
-        error("""GPU sparse broadcast resulted in non-concrete element type $Tv.
-                 This probably means that the function you are broadcasting contains an error or type instability.""")
+        error(
+            """GPU sparse broadcast resulted in non-concrete element type $Tv.
+            This probably means that the function you are broadcasting contains an error or type instability."""
+        )
     end
 
     # partially-evaluate the function, removing scalars.
     parevalf, passedsrcargstup = capturescalars(bc.f, bc.args)
     # check if the partially-evaluated function preserves zeros. if so, we'll only need to
     # apply it to the sparse input arguments, preserving the sparse structure.
-    if all(arg->isa(arg, AbstractSparseArray), passedsrcargstup)
+    if all(arg -> isa(arg, AbstractSparseArray), passedsrcargstup)
         fofzeros = parevalf(_zeros_eltypes(passedsrcargstup...)...)
         fpreszeros = _iszero(fofzeros)
     else
@@ -532,7 +556,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) 
+    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)
     # `size(bc, ::Int)` is missing
     # for AbstractGPUSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
     # but the only rows present in any sparse vector input are between 2 and 128, no need to
@@ -547,7 +571,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
         # either we have dense inputs, or the function isn't preserving zeros,
         # so use a dense output to broadcast into.
         val_array = sparse_arg.nzVal
-        output    = similar(val_array, Tv, size(bc))
+        output = similar(val_array, Tv, size(bc))
         # since we'll be iterating the sparse inputs, we need to pre-fill the dense output
         # with appropriate values (while setting the sparse inputs to zero). we do this by
         # re-using the dense broadcast implementation.
@@ -565,24 +589,24 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
         # this avoids a kernel launch and costly synchronization.
         if sparse_typ <: AbstractGPUSparseMatrixCSR
             offsets = rowPtr = sparse_arg.rowPtr
-            colVal  = similar(sparse_arg.colVal)
-            nzVal   = similar(sparse_arg.nzVal, Tv)
-            output  = _sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
+            colVal = similar(sparse_arg.colVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = _sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
         elseif sparse_typ <: AbstractGPUSparseMatrixCSC
             offsets = colPtr = sparse_arg.colPtr
-            rowVal  = similar(sparse_arg.rowVal)
-            nzVal   = similar(sparse_arg.nzVal, Tv)
-            output  = _sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
+            rowVal = similar(sparse_arg.rowVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = _sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
         end
     else
         # determine the number of non-zero elements per row so that we can create an
         # appropriately-structured output container
         offsets = if sparse_typ <: AbstractGPUSparseMatrixCSR
             ptr_array = sparse_arg.rowPtr
-            similar(ptr_array, Ti, rows+1)
+            similar(ptr_array, Ti, rows + 1)
         elseif sparse_typ <: AbstractGPUSparseMatrixCSC
             ptr_array = sparse_arg.colPtr
-            similar(ptr_array, Ti, cols+1)
+            similar(ptr_array, Ti, cols + 1)
         elseif sparse_typ <: AbstractGPUSparseVector
             ptr_array = sparse_arg.iPtr
             @allowscalar begin
@@ -596,7 +620,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
                 end
             end
             overall_first_row = min(arg_first_rows...)
-            overall_last_row  = max(arg_last_rows...)
+            overall_last_row = max(arg_last_rows...)
             similar(ptr_array, Pair{Ti, NTuple{length(bc.args), Ti}}, overall_last_row - overall_first_row + 1)
         end
         let
@@ -606,7 +630,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
                 (sparse_typ, offsets, bc.args...)
             end
             kernel = compute_offsets_kernel(get_backend(bc.args[first(sparse_args)]))
-            kernel(args...; ndrange=length(offsets))
+            kernel(args...; ndrange = length(offsets))
         end
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well as to get the total nnz count to allocate the sparse output array.
@@ -615,10 +639,10 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
             @allowscalar accumulate!(Base.add_sum, offsets, offsets)
             total_nnz = @allowscalar last(offsets[end]) - 1
         else
-            @allowscalar sort!(offsets; by=first)
-            total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
+            @allowscalar sort!(offsets; by = first)
+            total_nnz = mapreduce(x -> first(x) != typemax(first(x)), +, offsets)
         end
-        output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC}
+        output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}
             ixVal = similar(offsets, Ti, total_nnz)
             nzVal = similar(offsets, Tv, total_nnz)
             sparse_typ(offsets, ixVal, nzVal, size(bc))
@@ -626,8 +650,8 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
             val_array = bc.args[first(sparse_args)].nzVal
             similar(val_array, Tv, size(bc))
         elseif sparse_typ <: AbstractGPUSparseVector && fpreszeros
-            iPtr   = similar(offsets, Ti, total_nnz)
-            nzVal  = similar(offsets, Tv, total_nnz)
+            iPtr = similar(offsets, Ti, total_nnz)
+            nzVal = similar(offsets, Tv, total_nnz)
             _sparse_array_type(sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
         end
         if sparse_typ <: AbstractGPUSparseVector && !fpreszeros
@@ -644,12 +668,12 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
     end
     # perform the actual broadcast
     if output isa AbstractGPUSparseArray
-        args   = (bc.f, output, offsets, bc.args...)
+        args = (bc.f, output, offsets, bc.args...)
         kernel = sparse_to_sparse_broadcast_kernel(get_backend(bc.args[first(sparse_args)]))
         ndrange = output.nnz
     else
-        args   = sparse_typ <: AbstractGPUSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
-                                                         (sparse_typ, bc.f, output, bc.args...)
+        args = sparse_typ <: AbstractGPUSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+            (sparse_typ, bc.f, output, bc.args...)
         kernel = sparse_to_dense_broadcast_kernel(get_backend(bc.args[first(sparse_args)]))
         ndrange = sparse_typ <: AbstractGPUSparseMatrixCSC ? size(output, 2) : size(output, 1)
     end
diff --git a/test/runtests.jl b/test/runtests.jl
index f59f07a..0611af0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -48,12 +48,12 @@ include("setup.jl")     # make sure everything is precompiled
 # choose tests
 const tests = []
 const test_runners = Dict()
-for AT in (JLArray, Array), name in filter(n->n != "sparse", keys(TestSuite.tests))
+for AT in (JLArray, Array), name in filter(n -> n != "sparse", keys(TestSuite.tests))
     push!(tests, "$(AT)/$name")
-    test_runners["$(AT)/$name"] = ()->TestSuite.tests[name](AT)
+    test_runners["$(AT)/$name"] = () -> TestSuite.tests[name](AT)
 end
 
-for AT in ( JLSparseMatrixCSR, JLSparseMatrixCSC, JLSparseVector, SparseMatrixCSC, SparseVector), name in ["sparse"]
+for AT in (JLSparseMatrixCSR, JLSparseMatrixCSC, JLSparseVector, SparseMatrixCSC, SparseVector), name in ["sparse"]
     push!(tests, "$(AT)/$name")
     test_runners["$(AT)/$name"] = ()->TestSuite.tests[name](AT)
 end
diff --git a/test/testsuite/sparse.jl b/test/testsuite/sparse.jl
index d9783e1..f7b417b 100644
--- a/test/testsuite/sparse.jl
+++ b/test/testsuite/sparse.jl
@@ -1,4 +1,4 @@
-@testsuite "sparse" (AT, eltypes)->begin
+@testsuite "sparse" (AT, eltypes) -> begin
     if AT <: AbstractSparseVector
         broadcasting_vector(AT, eltypes)
     elseif AT <: AbstractSparseMatrix
@@ -13,68 +13,68 @@ function broadcasting_vector(AT, eltypes)
     dense_VT = GPUArrays._dense_vector_type(AT)
     for ET in eltypes
         @testset "SparseVector($ET)" begin
-            m  = 64
-            p  = 0.5
-            x  = sprand(ET, m, p)
+            m = 64
+            p = 0.5
+            x = sprand(ET, m, p)
             dx = AT(x)
 
             # zero-preserving
-            y  = x  .* ET(1)
+            y = x .* ET(1)
             dy = dx .* ET(1)
             @test dy isa AT{ET}
-            @test collect(SparseArrays.nonzeroinds(dy))  == collect(SparseArrays.nonzeroinds(dx))
-            @test collect(SparseArrays.nonzeroinds(dy))  == SparseArrays.nonzeroinds(y)
+            @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
+            @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
             @test collect(SparseArrays.nonzeros(dy)) == SparseArrays.nonzeros(y)
             @test y == SparseVector(dy)
 
             # not zero-preserving
-            y  = x  .+ ET(1)
+            y = x .+ ET(1)
             dy = dx .+ ET(1)
             @test dy isa dense_AT{ET}
             hy = Array(dy)
             @test Array(y) == hy
 
             # involving something dense
-            y  = x  .+ ones(ET, m)
+            y = x .+ ones(ET, m)
             dy = dx .+ dense_AT(ones(ET, m))
             @test dy isa dense_AT{ET}
             @test Array(y) == Array(dy)
 
             # sparse to sparse
             dx = AT(x)
-            y  = sprand(ET, m, p)
+            y = sprand(ET, m, p)
             dy = AT(y)
-            z  = x  .* y
+            z = x .* y
             dz = dx .* dy
             @test dz isa AT{ET}
             @test z == SparseVector(dz)
 
             # multiple inputs
-            y  = sprand(ET, m, p)
-            w  = sprand(ET, m, p)
+            y = sprand(ET, m, p)
+            w = sprand(ET, m, p)
             dy = AT(y)
             dx = AT(x)
             dw = AT(w)
-            z  = @. x  * y  * w
+            z = @. x * y * w
             dz = @. dx * dy * dw
             @test dz isa AT{ET}
             @test z == SparseVector(dz)
 
             y = sprand(ET, m, p)
             w = sprand(ET, m, p)
-            dense_arr   = rand(ET, m)
+            dense_arr = rand(ET, m)
             d_dense_arr = dense_AT(dense_arr)
             dy = AT(y)
             dw = AT(w)
-            z  = @. x  * y  * w  * dense_arr
+            z = @. x * y * w * dense_arr
             dz = @. dx * dy * dw * d_dense_arr
             @test dz isa dense_AT{ET}
             @test Array(z) == Array(dz)
-            
-            y  = sprand(ET, m, p)
+
+            y = sprand(ET, m, p)
             dy = AT(y)
             dx = AT(x)
-            z  = x  .* y  .* ET(2)
+            z = x .* y .* ET(2)
             dz = dx .* dy .* ET(2)
             @test dz isa AT{ET}
             @test z == SparseVector(dz)
@@ -83,38 +83,39 @@ function broadcasting_vector(AT, eltypes)
             ## non-zero-preserving
             dx = AT(x)
             dy = dx .+ 1
-            y  = x .+ 1
+            y = x .+ 1
             @test dy isa dense_AT{promote_type(ET, Int)}
             @test Array(y) == Array(dy)
             ## zero-preserving
             dy = dx .* 1
-            y  = x  .* 1
+            y = x .* 1
             @test dy isa AT{promote_type(ET, Int)}
-            @test collect(SparseArrays.nonzeroinds(dy))  == collect(SparseArrays.nonzeroinds(dx))
-            @test collect(SparseArrays.nonzeroinds(dy))  == SparseArrays.nonzeroinds(y)
+            @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
+            @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
             @test collect(SparseArrays.nonzeros(dy)) == SparseArrays.nonzeros(y)
             @test y == SparseVector(dy)
         end
     end
+    return
 end
 
 function broadcasting_matrix(AT, eltypes)
     dense_AT = GPUArrays._dense_array_type(AT)
     dense_VT = GPUArrays._dense_vector_type(AT)
     for ET in eltypes
-       @testset "SparseMatrix($ET)" begin
+        @testset "SparseMatrix($ET)" begin
             m, n = 5, 6
-            p   = 0.5
-            x   = sprand(ET, m, n, p)
-            dx  = AT(x)
+            p = 0.5
+            x = sprand(ET, m, n, p)
+            dx = AT(x)
             # zero-preserving
-            y  = x  .* ET(1)
+            y = x .* ET(1)
             dy = dx .* ET(1)
             @test dy isa AT{ET}
             @test y == SparseMatrixCSC(dy)
 
             # not zero-preserving
-            y  = x  .+ ET(1)
+            y = x .+ ET(1)
             dy = dx .+ ET(1)
             @test dy isa dense_AT{ET}
             hy = Array(dy)
@@ -122,26 +123,27 @@ function broadcasting_matrix(AT, eltypes)
             @test Array(y) == Array(dy)
 
             # involving something dense
-            y  = x  .* ones(ET, m, n)
+            y = x .* ones(ET, m, n)
             dy = dx .* dense_AT(ones(ET, m, n))
             @test dy isa dense_AT{ET}
             @test Array(y) == Array(dy)
-            
+
             # multiple inputs
-            y  = sprand(ET, m, n, p)
+            y = sprand(ET, m, n, p)
             dy = AT(y)
-            z  = x  .* y  .* ET(2)
+            z = x .* y .* ET(2)
             dz = dx .* dy .* ET(2)
             @test dz isa AT{ET}
             @test z == SparseMatrixCSC(dz)
 
             # multiple inputs
-            w  = sprand(ET, m, n, p)
+            w = sprand(ET, m, n, p)
             dw = AT(w)
-            z  = x  .* y  .* w
+            z = x .* y .* w
             dz = dx .* dy .* dw
             @test dz isa AT{ET}
             @test z == SparseMatrixCSC(dz)
         end
     end
+    return
 end

@kshyatt kshyatt changed the title [WIP] Sparse GPU array and broadcasting support Sparse GPU array and broadcasting support Oct 9, 2025
Comment on lines +18 to +23
struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
iPtr::Vi
nzVal::Vv
len::Int
nnz::Ti
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit inconsistent that we keep the host sparse object layout to the back-end, but define the device one concretely. I'm not sure if it's better to entirely move the definitions away from (or rather into) GPUArrays.jl though. I guess back-ends may want additional control over the object layout in order to facilitate vendor library interactions, but maybe we should then also leave the device-side version up to the back-end and only implement things here in terms of SparseArrays interfaces (rowvals, getcolptr, etc). Thoughts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was torn on this too. The advantage here is that libraries get a working device-side implementation "for free" -- they are able to implement their own (better) one and just give Adapt.jl information about how to move their host-side structs to it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants