-
Notifications
You must be signed in to change notification settings - Fork 87
Sparse GPU array and broadcasting support #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
844f20c
to
9a74b4d
Compare
I had to |
You mean GPUArrays.jl itself? I wouldn't expect those to be defined in KA.jl (maybe AK.jl, but with different signatures). |
Now with all tests uncommented and testing |
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl
index 5ada976..3d772fe 100644
--- a/lib/JLArrays/src/JLArrays.jl
+++ b/lib/JLArrays/src/JLArrays.jl
@@ -124,29 +124,33 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T
len::Int
nnz::Ti
- function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
- len::Integer) where {Tv, Ti <: Integer}
- new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
+ function JLSparseVector{Tv, Ti}(
+ iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
+ len::Integer
+ ) where {Tv, Ti <: Integer}
+ return new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
end
end
SparseArrays.SparseVector(x::JLSparseVector) = SparseVector(length(x), Array(x.iPtr), Array(x.nzVal))
-SparseArrays.nnz(x::JLSparseVector) = x.nnz
-SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
-SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
+SparseArrays.nnz(x::JLSparseVector) = x.nnz
+SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
+SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
colPtr::JLArray{Ti, 1}
rowVal::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
- function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
- nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
- new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
+ function JLSparseMatrixCSC{Tv, Ti}(
+ colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
+ nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+ ) where {Tv, Ti <: Integer}
+ return new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
end
end
-function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
end
SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))
@@ -155,28 +159,30 @@ JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A
function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
r1 = Int(@inbounds A.colPtr[j])
- r2 = Int(@inbounds A.colPtr[j+1]-1)
+ r2 = Int(@inbounds A.colPtr[j + 1] - 1)
(r1 > r2) && return zero(Tv)
r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
- ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
+ return ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
end
mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
rowPtr::JLArray{Ti, 1}
colVal::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
- function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
- nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
- new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
+ function JLSparseMatrixCSR{Tv, Ti}(
+ rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
+ nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+ ) where {Tv, Ti <: Integer}
+ return new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
end
end
-function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
end
-function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
+function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
return SparseMatrixCSC(transpose(x_transpose))
end
@@ -196,12 +202,12 @@ GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector
GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
-GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray
-GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray
-GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray
-GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
-GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray
-GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
+GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray
+GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
# conversion of untyped data to a typed Array
function typed_data(x::JLArray{T}) where {T}
@@ -304,11 +310,11 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
- iPtr = JLVector{Ti}(undef, length(xs.nzind))
+ iPtr = JLVector{Ti}(undef, length(xs.nzind))
nzVal = JLVector{Tv}(undef, length(xs.nzval))
copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
- return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs),)
+ return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
end
Base.length(x::JLSparseVector) = x.len
Base.size(x::JLSparseVector) = (x.len,)
@@ -316,10 +322,10 @@ Base.size(x::JLSparseVector) = (x.len,)
function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
colPtr = JLVector{Ti}(undef, length(xs.colptr))
rowVal = JLVector{Ti}(undef, length(xs.rowval))
- nzVal = JLVector{Tv}(undef, length(xs.nzval))
+ nzVal = JLVector{Tv}(undef, length(xs.nzval))
copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
- copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
+ copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
end
Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
@@ -329,10 +335,10 @@ function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
csr_xs = SparseMatrixCSC(transpose(xs))
rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
- nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
+ nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
- copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
+ copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
end
Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
@@ -479,17 +485,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
R
end
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} =
+ GPUSparseDeviceMatrixCSC{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} =
+ GPUSparseDeviceMatrixCSR{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv, Ti}) where {Tv, Ti} =
+ GPUSparseDeviceVector{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
## KernelAbstractions interface
KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
-KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()
+KernelAbstractions.get_backend(a::JLA) where {JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector}} = JLBackend()
function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
diff --git a/src/device/sparse.jl b/src/device/sparse.jl
index 77abe0d..c9290cf 100644
--- a/src/device/sparse.jl
+++ b/src/device/sparse.jl
@@ -10,12 +10,12 @@ using SparseArrays
# core types
export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR,
- GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
+ GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end
-struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
+struct GPUSparseDeviceVector{Tv, Ti, Vi, Vv} <: AbstractSparseVector{Tv, Ti}
iPtr::Vi
nzVal::Vv
len::Int
@@ -28,19 +28,19 @@ SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz
SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr
SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal
-struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
+struct GPUSparseDeviceMatrixCSC{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
colPtr::Vi
rowVal::Vi
nzVal::Vv
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
end
SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal
SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr
-SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
+SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col + 1] - 1)
-struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCSR{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
rowPtr::Vi
colVal::Vi
nzVal::Vv
@@ -48,21 +48,21 @@ struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv
nnz::Ti
end
-struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixBSR{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
rowPtr::Vi
colVal::Vi
nzVal::Vv
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
blockDim::Ti
dir::Char
nnz::Ti
end
-struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCOO{Tv, Ti, Vi, Vv} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
rowInd::Vi
colInd::Vi
nzVal::Vv
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
end
@@ -79,9 +79,9 @@ struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M} <: AbstractSparseArray{Tv,
nnz::Ti
end
-function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N}
+function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N, <:Integer}) where {Tv, Ti <: Integer, M, Vi <: AbstractDeviceArray{<:Integer, M}, Vv <: AbstractDeviceArray{Tv, M}, N}
@assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
- GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
+ return GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
end
Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims)
@@ -94,42 +94,42 @@ SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector)
println(io, "$(length(A))-element device sparse vector at:")
println(io, " iPtr: $(A.iPtr)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR)
println(io, "$(length(A))-element device sparse matrix CSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC)
println(io, "$(length(A))-element device sparse matrix CSC at:")
println(io, " colPtr: $(A.colPtr)")
println(io, " rowVal: $(A.rowVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR)
println(io, "$(length(A))-element device sparse matrix BSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO)
println(io, "$(length(A))-element device sparse matrix COO at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colInd: $(A.colInd)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR)
println(io, "$(length(A))-element device sparse array CSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
# COV_EXCL_STOP
diff --git a/src/host/sparse.jl b/src/host/sparse.jl
index ff82030..6f1e77f 100644
--- a/src/host/sparse.jl
+++ b/src/host/sparse.jl
@@ -7,20 +7,20 @@ abstract type AbstractGPUSparseMatrixCSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, T
abstract type AbstractGPUSparseMatrixCOO{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
abstract type AbstractGPUSparseMatrixBSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
-const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector,AbstractGPUSparseMatrix}
+const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector, AbstractGPUSparseMatrix}
Base.convert(T::Type{<:AbstractGPUSparseArray}, m::AbstractArray) = m isa T ? m : T(m)
-_dense_array_type(sa::SparseVector) = SparseVector
+_dense_array_type(sa::SparseVector) = SparseVector
_dense_array_type(::Type{SparseVector}) = SparseVector
_sparse_array_type(sa::SparseVector) = SparseVector
_dense_vector_type(sa::AbstractSparseArray) = Vector
-_dense_vector_type(sa::AbstractArray) = Vector
+_dense_vector_type(sa::AbstractArray) = Vector
_dense_vector_type(::Type{<:AbstractSparseArray}) = Vector
-_dense_vector_type(::Type{<:AbstractArray}) = Vector
-_dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
+_dense_vector_type(::Type{<:AbstractArray}) = Vector
+_dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
_dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC
-_sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
+_sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
function _sparse_array_type(sa::AbstractGPUSparseArray) end
function _dense_array_type(sa::AbstractGPUSparseArray) end
@@ -32,7 +32,7 @@ struct GPUSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct GPUSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseVector}) = GPUSparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseMatrix}) = GPUSparseMatStyle()
-const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
+const SPVM = Union{GPUSparseVecStyle, GPUSparseMatStyle}
# GPUSparseVecStyle handles 0-1 dimensions, GPUSparseMatStyle 0-2 dimensions.
# GPUSparseVecStyle promotes to GPUSparseMatStyle for 2 dimensions.
@@ -40,11 +40,11 @@ const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
GPUSparseVecStyle(::Val{0}) = GPUSparseVecStyle()
GPUSparseVecStyle(::Val{1}) = GPUSparseVecStyle()
GPUSparseVecStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseVecStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
GPUSparseMatStyle(::Val{0}) = GPUSparseMatStyle()
GPUSparseMatStyle(::Val{1}) = GPUSparseMatStyle()
GPUSparseMatStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseMatStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{1}) = GPUSparseVecStyle()
Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{2}) = GPUSparseMatStyle()
@@ -71,37 +71,37 @@ end
# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
- capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+ capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} =
# This definition is identical to the one above and necessary only for
# avoiding method ambiguity.
- capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+ capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{AbstractGPUSparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
- capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
-@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
- capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
+ capturescalars((a1, args...) -> f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
+@inline capturescalars(f, mixedargs::Tuple{Union{Ref, AbstractArray{<:Any, 0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
+ capturescalars((args...) -> f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
scalararg(::Number) = true
scalararg(::Any) = false
-scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true
+scalarwrappedarg(::Union{AbstractArray{<:Any, 0}, Ref}) = true
scalarwrappedarg(::Any) = false
@inline function _capturescalars()
return (), () -> ()
end
@inline function _capturescalars(arg, mixedargs...)
- let (rest, f) = _capturescalars(mixedargs...)
+ return let (rest, f) = _capturescalars(mixedargs...)
if scalararg(arg)
- return rest, @inline function(tail...)
- (arg, f(tail...)...)
+ return rest, @inline function (tail...)
+ return (arg, f(tail...)...)
end # add back scalararg after (in makeargs)
elseif scalarwrappedarg(arg)
- return rest, @inline function(tail...)
- (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
+ return rest, @inline function (tail...)
+ return (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
end # unwrap and add back scalararg after (in makeargs)
else
- return (arg, rest...), @inline function(head, tail...)
- (head, f(tail...)...)
+ return (arg, rest...), @inline function (head, tail...)
+ return (head, f(tail...)...)
end # pass-through to broadcast
end
end
@@ -137,13 +137,13 @@ access the elements themselves.
For convenience, this iterator can be passed non-sparse arguments as well, which will be
ignored (with the returned `col`/`ptr` values set to 0).
"""
-struct CSRIterator{Ti,N,ATs}
+struct CSRIterator{Ti, N, ATs}
row::Ti
col_ends::NTuple{N, Ti}
args::ATs
end
-function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
+function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti, N}
# check that `row` is valid for all arguments
@boundscheck begin
ntuple(Val(N)) do i
@@ -155,16 +155,16 @@ function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
col_ends = ntuple(Val(N)) do i
arg = @inbounds args[i]
if arg isa GPUSparseDeviceMatrixCSR
- @inbounds(arg.rowPtr[row+1])
+ @inbounds(arg.rowPtr[row + 1])
else
zero(Ti)
end
end
- CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
+ return CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
end
-@inline function Base.iterate(iter::CSRIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSRIterator{Ti, N}, state = nothing) where {Ti, N}
# helper function to get the column of a sparse array at a specific pointer
@inline function get_col(i, ptr)
arg = @inbounds iter.args[i]
@@ -174,13 +174,14 @@ end
return @inbounds arg.colVal[ptr] % Ti
end
end
- typemax(Ti)
+ return typemax(Ti)
end
# initialize the state
# - ptr: the current index into the colVal/nzVal arrays
# - col: the current column index (cached so that we don't have to re-read each time)
- state = something(state,
+ state = something(
+ state,
ntuple(Val(N)) do i
arg = @inbounds iter.args[i]
if arg isa GPUSparseDeviceMatrixCSR
@@ -222,13 +223,13 @@ end
return (cur_col, ptrs), new_state
end
-struct CSCIterator{Ti,N,ATs}
+struct CSCIterator{Ti, N, ATs}
col::Ti
row_ends::NTuple{N, Ti}
args::ATs
end
-function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
+function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti, N}
# check that `col` is valid for all arguments
@boundscheck begin
ntuple(Val(N)) do i
@@ -240,17 +241,17 @@ function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
row_ends = ntuple(Val(N)) do i
arg = @inbounds args[i]
x = if arg isa GPUSparseDeviceMatrixCSC
- @inbounds(arg.colPtr[col+1])
+ @inbounds(arg.colPtr[col + 1])
else
zero(Ti)
end
x
end
- CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
+ return CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
end
-@inline function Base.iterate(iter::CSCIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSCIterator{Ti, N}, state = nothing) where {Ti, N}
# helper function to get the column of a sparse array at a specific pointer
@inline function get_col(i, ptr)
arg = @inbounds iter.args[i]
@@ -260,13 +261,14 @@ end
return @inbounds arg.rowVal[ptr] % Ti
end
end
- typemax(Ti)
+ return typemax(Ti)
end
# initialize the state
# - ptr: the current index into the rowVal/nzVal arrays
# - row: the current row index (cached so that we don't have to re-read each time)
- state = something(state,
+ state = something(
+ state,
ntuple(Val(N)) do i
arg = @inbounds iter.args[i]
if arg isa GPUSparseDeviceMatrixCSC
@@ -309,8 +311,8 @@ end
end
# helpers to index a sparse or dense array
-function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR,GPUSparseDeviceMatrixCSC}, I, ptr)
- if ptr == 0
+function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR, GPUSparseDeviceMatrixCSC}, I, ptr)
+ return if ptr == 0
zero(eltype(arg))
else
@inbounds arg.nzVal[ptr]
@@ -338,9 +340,11 @@ function _has_row(A::GPUSparseDeviceVector, offsets, row, ::Bool)
return 0
end
-@kernel function compute_offsets_kernel(::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
- fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
- args...) where {Ti, N}
+@kernel function compute_offsets_kernel(
+ ::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
+ fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+ args...
+ ) where {Ti, N}
my_ix = @index(Global, Linear)
row = my_ix + first_row - one(eltype(my_ix))
if row ≤ last_row
@@ -359,11 +363,13 @@ end
end
# kernel to count the number of non-zeros in a row, to determine the row offsets
-@kernel function compute_offsets_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
- offsets::AbstractVector{Ti}, args...) where Ti
+@kernel function compute_offsets_kernel(
+ T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
+ offsets::AbstractVector{Ti}, args...
+ ) where {Ti}
# every thread processes an entire row
leading_dim = @index(Global, Linear)
- if leading_dim ≤ length(offsets)-1
+ if leading_dim ≤ length(offsets) - 1
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
# count the nonzero leading_dims of all inputs
@@ -378,19 +384,21 @@ end
if leading_dim == 1
offsets[1] = 1
end
- offsets[leading_dim+1] = accum
+ offsets[leading_dim + 1] = accum
end
end
end
-@kernel function sparse_to_sparse_broadcast_kernel(f::F, output::GPUSparseDeviceVector{Tv,Ti},
- offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
- args...) where {Tv, Ti, N, F}
+@kernel function sparse_to_sparse_broadcast_kernel(
+ f::F, output::GPUSparseDeviceVector{Tv, Ti},
+ offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+ args...
+ ) where {Tv, Ti, N, F}
row_ix = @index(Global, Linear)
if row_ix ≤ output.nnz
row_and_ptrs = @inbounds offsets[row_ix]
- row = @inbounds row_and_ptrs[1]
- arg_ptrs = @inbounds row_and_ptrs[2]
+ row = @inbounds row_and_ptrs[1]
+ arg_ptrs = @inbounds row_and_ptrs[2]
vals = ntuple(Val(N)) do i
@inline
arg = @inbounds args[i]
@@ -400,14 +408,20 @@ end
_getindex(arg, row, ptr)
end
output_val = f(vals...)
- @inbounds output.iPtr[row_ix] = row
+ @inbounds output.iPtr[row_ix] = row
@inbounds output.nzVal[row_ix] = output_val
end
end
-@kernel function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{<:AbstractArray,Nothing},
- args...) where {Ti, T<:Union{GPUSparseDeviceMatrixCSR{<:Any,Ti},
- GPUSparseDeviceMatrixCSC{<:Any,Ti}}}
+@kernel function sparse_to_sparse_broadcast_kernel(
+ f, output::T, offsets::Union{<:AbstractArray, Nothing},
+ args...
+ ) where {
+ Ti, T <: Union{
+ GPUSparseDeviceMatrixCSR{<:Any, Ti},
+ GPUSparseDeviceMatrixCSC{<:Any, Ti},
+ },
+ }
# every thread processes an entire row
leading_dim = @index(Global, Linear)
leading_dim_size = output isa GPUSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -415,19 +429,19 @@ end
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
- output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
+ output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
output_ivals = output isa GPUSparseDeviceMatrixCSR ? output.colVal : output.rowVal
# fetch the row offset, and write it to the output
@inbounds begin
output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
if leading_dim == leading_dim_size
- output_ptrs[leading_dim+one(eltype(leading_dim))] = offsets[leading_dim+one(eltype(leading_dim))]
+ output_ptrs[leading_dim + one(eltype(leading_dim))] = offsets[leading_dim + one(eltype(leading_dim))]
end
end
# set the values for this row
for (sub_leading_dim, ptrs) in iter
- index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
+ index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
index_second = output isa GPUSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim
I = CartesianIndex(index_first, index_second)
vals = ntuple(Val(length(args))) do i
@@ -442,9 +456,15 @@ end
end
end
end
-@kernel function sparse_to_dense_broadcast_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR{Tv, Ti},
- AbstractGPUSparseMatrixCSC{Tv, Ti}}},
- f, output::AbstractDeviceArray, args...) where {Tv, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+ T::Type{
+ <:Union{
+ AbstractGPUSparseMatrixCSR{Tv, Ti},
+ AbstractGPUSparseMatrixCSC{Tv, Ti},
+ },
+ },
+ f, output::AbstractDeviceArray, args...
+ ) where {Tv, Ti}
# every thread processes an entire row
leading_dim = @index(Global, Linear)
leading_dim_size = T <: AbstractGPUSparseMatrixCSR ? size(output, 1) : size(output, 2)
@@ -453,7 +473,7 @@ end
# set the values for this row
for (sub_leading_dim, ptrs) in iter
- index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
+ index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
index_second = T <: AbstractGPUSparseMatrixCSR ? sub_leading_dim : leading_dim
I = CartesianIndex(index_first, index_second)
vals = ntuple(Val(length(args))) do i
@@ -467,16 +487,18 @@ end
end
end
-@kernel function sparse_to_dense_broadcast_kernel(::Type{<:AbstractGPUSparseVector}, f::F,
- output::AbstractDeviceArray{Tv},
- offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
- args...) where {Tv, F, N, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+ ::Type{<:AbstractGPUSparseVector}, f::F,
+ output::AbstractDeviceArray{Tv},
+ offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+ args...
+ ) where {Tv, F, N, Ti}
# every thread processes an entire row
row_ix = @index(Global, Linear)
if row_ix ≤ length(output)
row_and_ptrs = @inbounds offsets[row_ix]
- row = @inbounds row_and_ptrs[1]
- arg_ptrs = @inbounds row_and_ptrs[2]
+ row = @inbounds row_and_ptrs[1]
+ arg_ptrs = @inbounds row_and_ptrs[2]
vals = ntuple(Val(length(args))) do i
@inline
arg = @inbounds args[i]
@@ -491,39 +513,41 @@ end
end
## COV_EXCL_STOP
-function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatStyle}})
+function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle, GPUSparseMatStyle}})
# find the sparse inputs
bc = Broadcast.flatten(bc)
sparse_args = findall(bc.args) do arg
arg isa AbstractGPUSparseArray
end
- sparse_types = unique(map(i->nameof(typeof(bc.args[i])), sparse_args))
+ sparse_types = unique(map(i -> nameof(typeof(bc.args[i])), sparse_args))
if length(sparse_types) > 1
error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
end
sparse_typ = typeof(bc.args[first(sparse_args)])
- sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC,AbstractGPUSparseVector} ||
+ sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC, AbstractGPUSparseVector} ||
error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
Ti = if sparse_typ <: AbstractGPUSparseMatrixCSR
- reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
+ reduce(promote_type, map(i -> eltype(bc.args[i].rowPtr), sparse_args))
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
- reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+ reduce(promote_type, map(i -> eltype(bc.args[i].colPtr), sparse_args))
elseif sparse_typ <: AbstractGPUSparseVector
- reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
+ reduce(promote_type, map(i -> eltype(bc.args[i].iPtr), sparse_args))
end
# determine the output type
Tv = Broadcast.combine_eltypes(bc.f, eltype.(bc.args))
if !Base.isconcretetype(Tv)
- error("""GPU sparse broadcast resulted in non-concrete element type $Tv.
- This probably means that the function you are broadcasting contains an error or type instability.""")
+ error(
+ """GPU sparse broadcast resulted in non-concrete element type $Tv.
+ This probably means that the function you are broadcasting contains an error or type instability."""
+ )
end
# partially-evaluate the function, removing scalars.
parevalf, passedsrcargstup = capturescalars(bc.f, bc.args)
# check if the partially-evaluated function preserves zeros. if so, we'll only need to
# apply it to the sparse input arguments, preserving the sparse structure.
- if all(arg->isa(arg, AbstractSparseArray), passedsrcargstup)
+ if all(arg -> isa(arg, AbstractSparseArray), passedsrcargstup)
fofzeros = parevalf(_zeros_eltypes(passedsrcargstup...)...)
fpreszeros = _iszero(fofzeros)
else
@@ -532,7 +556,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
# the kernels below parallelize across rows or cols, not elements, so it's unlikely
# we'll launch many threads. to maximize utilization, parallelize across blocks first.
- rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)
+ rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)
# `size(bc, ::Int)` is missing
# for AbstractGPUSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
# but the only rows present in any sparse vector input are between 2 and 128, no need to
@@ -547,7 +571,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
# either we have dense inputs, or the function isn't preserving zeros,
# so use a dense output to broadcast into.
val_array = sparse_arg.nzVal
- output = similar(val_array, Tv, size(bc))
+ output = similar(val_array, Tv, size(bc))
# since we'll be iterating the sparse inputs, we need to pre-fill the dense output
# with appropriate values (while setting the sparse inputs to zero). we do this by
# re-using the dense broadcast implementation.
@@ -565,24 +589,24 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
# this avoids a kernel launch and costly synchronization.
if sparse_typ <: AbstractGPUSparseMatrixCSR
offsets = rowPtr = sparse_arg.rowPtr
- colVal = similar(sparse_arg.colVal)
- nzVal = similar(sparse_arg.nzVal, Tv)
- output = _sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
+ colVal = similar(sparse_arg.colVal)
+ nzVal = similar(sparse_arg.nzVal, Tv)
+ output = _sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
offsets = colPtr = sparse_arg.colPtr
- rowVal = similar(sparse_arg.rowVal)
- nzVal = similar(sparse_arg.nzVal, Tv)
- output = _sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
+ rowVal = similar(sparse_arg.rowVal)
+ nzVal = similar(sparse_arg.nzVal, Tv)
+ output = _sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
end
else
# determine the number of non-zero elements per row so that we can create an
# appropriately-structured output container
offsets = if sparse_typ <: AbstractGPUSparseMatrixCSR
ptr_array = sparse_arg.rowPtr
- similar(ptr_array, Ti, rows+1)
+ similar(ptr_array, Ti, rows + 1)
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
ptr_array = sparse_arg.colPtr
- similar(ptr_array, Ti, cols+1)
+ similar(ptr_array, Ti, cols + 1)
elseif sparse_typ <: AbstractGPUSparseVector
ptr_array = sparse_arg.iPtr
@allowscalar begin
@@ -596,7 +620,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
end
end
overall_first_row = min(arg_first_rows...)
- overall_last_row = max(arg_last_rows...)
+ overall_last_row = max(arg_last_rows...)
similar(ptr_array, Pair{Ti, NTuple{length(bc.args), Ti}}, overall_last_row - overall_first_row + 1)
end
let
@@ -606,7 +630,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
(sparse_typ, offsets, bc.args...)
end
kernel = compute_offsets_kernel(get_backend(bc.args[first(sparse_args)]))
- kernel(args...; ndrange=length(offsets))
+ kernel(args...; ndrange = length(offsets))
end
# accumulate these values so that we can use them directly as row pointer offsets,
# as well as to get the total nnz count to allocate the sparse output array.
@@ -615,10 +639,10 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
@allowscalar accumulate!(Base.add_sum, offsets, offsets)
total_nnz = @allowscalar last(offsets[end]) - 1
else
- @allowscalar sort!(offsets; by=first)
- total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
+ @allowscalar sort!(offsets; by = first)
+ total_nnz = mapreduce(x -> first(x) != typemax(first(x)), +, offsets)
end
- output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC}
+ output = if sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}
ixVal = similar(offsets, Ti, total_nnz)
nzVal = similar(offsets, Tv, total_nnz)
sparse_typ(offsets, ixVal, nzVal, size(bc))
@@ -626,8 +650,8 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
val_array = bc.args[first(sparse_args)].nzVal
similar(val_array, Tv, size(bc))
elseif sparse_typ <: AbstractGPUSparseVector && fpreszeros
- iPtr = similar(offsets, Ti, total_nnz)
- nzVal = similar(offsets, Tv, total_nnz)
+ iPtr = similar(offsets, Ti, total_nnz)
+ nzVal = similar(offsets, Tv, total_nnz)
_sparse_array_type(sparse_arg){Tv, Ti}(iPtr, nzVal, rows)
end
if sparse_typ <: AbstractGPUSparseVector && !fpreszeros
@@ -644,12 +668,12 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
end
# perform the actual broadcast
if output isa AbstractGPUSparseArray
- args = (bc.f, output, offsets, bc.args...)
+ args = (bc.f, output, offsets, bc.args...)
kernel = sparse_to_sparse_broadcast_kernel(get_backend(bc.args[first(sparse_args)]))
ndrange = output.nnz
else
- args = sparse_typ <: AbstractGPUSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
- (sparse_typ, bc.f, output, bc.args...)
+ args = sparse_typ <: AbstractGPUSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
+ (sparse_typ, bc.f, output, bc.args...)
kernel = sparse_to_dense_broadcast_kernel(get_backend(bc.args[first(sparse_args)]))
ndrange = sparse_typ <: AbstractGPUSparseMatrixCSC ? size(output, 2) : size(output, 1)
end
diff --git a/test/runtests.jl b/test/runtests.jl
index f59f07a..0611af0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -48,12 +48,12 @@ include("setup.jl") # make sure everything is precompiled
# choose tests
const tests = []
const test_runners = Dict()
-for AT in (JLArray, Array), name in filter(n->n != "sparse", keys(TestSuite.tests))
+for AT in (JLArray, Array), name in filter(n -> n != "sparse", keys(TestSuite.tests))
push!(tests, "$(AT)/$name")
- test_runners["$(AT)/$name"] = ()->TestSuite.tests[name](AT)
+ test_runners["$(AT)/$name"] = () -> TestSuite.tests[name](AT)
end
-for AT in ( JLSparseMatrixCSR, JLSparseMatrixCSC, JLSparseVector, SparseMatrixCSC, SparseVector), name in ["sparse"]
+for AT in (JLSparseMatrixCSR, JLSparseMatrixCSC, JLSparseVector, SparseMatrixCSC, SparseVector), name in ["sparse"]
push!(tests, "$(AT)/$name")
test_runners["$(AT)/$name"] = ()->TestSuite.tests[name](AT)
end
diff --git a/test/testsuite/sparse.jl b/test/testsuite/sparse.jl
index d9783e1..f7b417b 100644
--- a/test/testsuite/sparse.jl
+++ b/test/testsuite/sparse.jl
@@ -1,4 +1,4 @@
-@testsuite "sparse" (AT, eltypes)->begin
+@testsuite "sparse" (AT, eltypes) -> begin
if AT <: AbstractSparseVector
broadcasting_vector(AT, eltypes)
elseif AT <: AbstractSparseMatrix
@@ -13,68 +13,68 @@ function broadcasting_vector(AT, eltypes)
dense_VT = GPUArrays._dense_vector_type(AT)
for ET in eltypes
@testset "SparseVector($ET)" begin
- m = 64
- p = 0.5
- x = sprand(ET, m, p)
+ m = 64
+ p = 0.5
+ x = sprand(ET, m, p)
dx = AT(x)
# zero-preserving
- y = x .* ET(1)
+ y = x .* ET(1)
dy = dx .* ET(1)
@test dy isa AT{ET}
- @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
- @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
+ @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
+ @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
@test collect(SparseArrays.nonzeros(dy)) == SparseArrays.nonzeros(y)
@test y == SparseVector(dy)
# not zero-preserving
- y = x .+ ET(1)
+ y = x .+ ET(1)
dy = dx .+ ET(1)
@test dy isa dense_AT{ET}
hy = Array(dy)
@test Array(y) == hy
# involving something dense
- y = x .+ ones(ET, m)
+ y = x .+ ones(ET, m)
dy = dx .+ dense_AT(ones(ET, m))
@test dy isa dense_AT{ET}
@test Array(y) == Array(dy)
# sparse to sparse
dx = AT(x)
- y = sprand(ET, m, p)
+ y = sprand(ET, m, p)
dy = AT(y)
- z = x .* y
+ z = x .* y
dz = dx .* dy
@test dz isa AT{ET}
@test z == SparseVector(dz)
# multiple inputs
- y = sprand(ET, m, p)
- w = sprand(ET, m, p)
+ y = sprand(ET, m, p)
+ w = sprand(ET, m, p)
dy = AT(y)
dx = AT(x)
dw = AT(w)
- z = @. x * y * w
+ z = @. x * y * w
dz = @. dx * dy * dw
@test dz isa AT{ET}
@test z == SparseVector(dz)
y = sprand(ET, m, p)
w = sprand(ET, m, p)
- dense_arr = rand(ET, m)
+ dense_arr = rand(ET, m)
d_dense_arr = dense_AT(dense_arr)
dy = AT(y)
dw = AT(w)
- z = @. x * y * w * dense_arr
+ z = @. x * y * w * dense_arr
dz = @. dx * dy * dw * d_dense_arr
@test dz isa dense_AT{ET}
@test Array(z) == Array(dz)
-
- y = sprand(ET, m, p)
+
+ y = sprand(ET, m, p)
dy = AT(y)
dx = AT(x)
- z = x .* y .* ET(2)
+ z = x .* y .* ET(2)
dz = dx .* dy .* ET(2)
@test dz isa AT{ET}
@test z == SparseVector(dz)
@@ -83,38 +83,39 @@ function broadcasting_vector(AT, eltypes)
## non-zero-preserving
dx = AT(x)
dy = dx .+ 1
- y = x .+ 1
+ y = x .+ 1
@test dy isa dense_AT{promote_type(ET, Int)}
@test Array(y) == Array(dy)
## zero-preserving
dy = dx .* 1
- y = x .* 1
+ y = x .* 1
@test dy isa AT{promote_type(ET, Int)}
- @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
- @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
+ @test collect(SparseArrays.nonzeroinds(dy)) == collect(SparseArrays.nonzeroinds(dx))
+ @test collect(SparseArrays.nonzeroinds(dy)) == SparseArrays.nonzeroinds(y)
@test collect(SparseArrays.nonzeros(dy)) == SparseArrays.nonzeros(y)
@test y == SparseVector(dy)
end
end
+ return
end
function broadcasting_matrix(AT, eltypes)
dense_AT = GPUArrays._dense_array_type(AT)
dense_VT = GPUArrays._dense_vector_type(AT)
for ET in eltypes
- @testset "SparseMatrix($ET)" begin
+ @testset "SparseMatrix($ET)" begin
m, n = 5, 6
- p = 0.5
- x = sprand(ET, m, n, p)
- dx = AT(x)
+ p = 0.5
+ x = sprand(ET, m, n, p)
+ dx = AT(x)
# zero-preserving
- y = x .* ET(1)
+ y = x .* ET(1)
dy = dx .* ET(1)
@test dy isa AT{ET}
@test y == SparseMatrixCSC(dy)
# not zero-preserving
- y = x .+ ET(1)
+ y = x .+ ET(1)
dy = dx .+ ET(1)
@test dy isa dense_AT{ET}
hy = Array(dy)
@@ -122,26 +123,27 @@ function broadcasting_matrix(AT, eltypes)
@test Array(y) == Array(dy)
# involving something dense
- y = x .* ones(ET, m, n)
+ y = x .* ones(ET, m, n)
dy = dx .* dense_AT(ones(ET, m, n))
@test dy isa dense_AT{ET}
@test Array(y) == Array(dy)
-
+
# multiple inputs
- y = sprand(ET, m, n, p)
+ y = sprand(ET, m, n, p)
dy = AT(y)
- z = x .* y .* ET(2)
+ z = x .* y .* ET(2)
dz = dx .* dy .* ET(2)
@test dz isa AT{ET}
@test z == SparseMatrixCSC(dz)
# multiple inputs
- w = sprand(ET, m, n, p)
+ w = sprand(ET, m, n, p)
dw = AT(w)
- z = x .* y .* w
+ z = x .* y .* w
dz = dx .* dy .* dw
@test dz isa AT{ET}
@test z == SparseMatrixCSC(dz)
end
end
+ return
end |
struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti} | ||
iPtr::Vi | ||
nzVal::Vv | ||
len::Int | ||
nnz::Ti | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit inconsistent that we keep the host sparse object layout to the back-end, but define the device one concretely. I'm not sure if it's better to entirely move the definitions away from (or rather into) GPUArrays.jl though. I guess back-ends may want additional control over the object layout in order to facilitate vendor library interactions, but maybe we should then also leave the device-side version up to the back-end and only implement things here in terms of SparseArrays interfaces (rowvals
, getcolptr
, etc). Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I was torn on this too. The advantage here is that libraries get a working device-side implementation "for free" -- they are able to implement their own (better) one and just give Adapt.jl
information about how to move their host-side structs to it.
This ports the (duplicated) sparse broadcasting support from CUDA.jl and AMDGPU.jl to GPUArrays.jl. It should allow all the "child" GPU libraries to use one unified set of on-device sparse types and broadcasting kernels. I implemented the appropriate types in
JLArrays
and tests there are passing. If this merges, we should be able to strip out most of the sparse broadcasting code from downstream packages.