Skip to content

Commit bb88c9d

Browse files
committed
Update array.jl
1 parent 5a0d26b commit bb88c9d

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

src/array.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ using LinearAlgebra
77
using GPUArrays
88
using Adapt
99

10+
export WAtomic
11+
12+
"""
    WAtomic{T}

Immutable wrapper around a single value of type `T`, stored in the field `el`.

NOTE(review): presumably used as the element type of arrays whose elements are
accessed atomically on the GPU — confirm against the kernel/shader code.
"""
struct WAtomic{T}
    el::T
end
15+
1016
export WgpuArray
1117

1218
# TODO: MTL tracks cmdEncoder with task-local storage. That's neat.
@@ -53,6 +59,34 @@ function Base.unsafe_copyto!(gpuDevice, dst::WgpuArrayPtr{T}, src::WgpuArrayPtr{
5359
WGPUCore.submit(gpuDevice.queue, [WGPUCore.finish(cmdEncoder),])
5460
end
5561

62+
# GPU -> GPU (atomic)
63+
# Copy `N` elements from a plain `T` buffer into a `WAtomic{T}` buffer by
# recording a buffer-to-buffer copy and submitting it to the device queue.
# The byte count is `N * sizeof(T)`.
# NOTE(review): assumes `sizeof(WAtomic{T}) == sizeof(T)` — confirm.
function Base.unsafe_copyto!(gpuDevice, dst::WgpuArrayPtr{WAtomic{T}}, src::WgpuArrayPtr{T}, N::Integer) where T
    encoder = WGPUCore.createCommandEncoder(gpuDevice, "COMMAND ENCODER")
    srcOffset = Int(src.offset)
    dstOffset = Int(dst.offset)
    nbytes = N * sizeof(T)
    WGPUCore.copyBufferToBuffer(encoder, src.buffer, srcOffset, dst.buffer, dstOffset, nbytes)
    commands = [WGPUCore.finish(encoder)]
    return WGPUCore.submit(gpuDevice.queue, commands)
end
75+
76+
# GPU (atomic) -> GPU
77+
# Copy `N` elements from a `WAtomic{T}` buffer into a plain `T` buffer by
# recording a buffer-to-buffer copy and submitting it to the device queue.
# The byte count is `N * sizeof(T)`.
# NOTE(review): assumes `sizeof(WAtomic{T}) == sizeof(T)` — confirm.
function Base.unsafe_copyto!(gpuDevice, dst::WgpuArrayPtr{T}, src::WgpuArrayPtr{WAtomic{T}}, N::Integer) where T
    encoder = WGPUCore.createCommandEncoder(gpuDevice, "COMMAND ENCODER")
    srcOffset = Int(src.offset)
    dstOffset = Int(dst.offset)
    nbytes = N * sizeof(T)
    WGPUCore.copyBufferToBuffer(encoder, src.buffer, srcOffset, dst.buffer, dstOffset, nbytes)
    commands = [WGPUCore.finish(encoder)]
    return WGPUCore.submit(gpuDevice.queue, commands)
end
89+
5690
# GPU -> CPU
5791
function Base.unsafe_copyto!(gpuDevice, dst::Ptr{T}, src::WgpuArrayPtr{T}, N::Integer) where T
5892
cmdEncoder = WGPUCore.createCommandEncoder(gpuDevice, "COMMAND ENCODER")
@@ -186,6 +220,10 @@ WgpuArray{T}(::UndefInitializer, dims::NTuple{N, Integer}) where {T, N} =
186220
# Varargs form of the uninitialized constructor: the integer dimensions are
# converted to a tuple of `Int` and forwarded to the `WgpuArray{T, N}` constructor.
function WgpuArray{T}(::UndefInitializer, dims::Vararg{Integer, N}) where {T, N}
    shape = convert(Tuple{Vararg{Int}}, dims)
    return WgpuArray{T, N}(undef, shape)
end
188222

223+
# atomic array support
224+
# WgpuArray{T}(::UndefInitializer, dims::NTuple{N, Integer}) where {T, N} =
225+
# WgpuArray{T, N}(undef, convert(Tuple{Vararg{Int}}, dims))
226+
189227
# empty vector constructors
190228
# Zero-argument constructor: builds an uninitialized vector of length 0.
function WgpuArray{T, 1}() where {T}
    return WgpuArray{T, 1}(undef, 0)
end
191229

@@ -322,6 +360,38 @@ function Base.copyto!(dest::WgpuArray{T}, doffs::Integer, src::WgpuArray{T}, sof
322360
return dest
323361
end
324362

363+
# Copy `n` elements from `src` (starting at index `soffs`) into the
# `WAtomic{T}` array `dest` (starting at index `doffs`) and return `dest`.
#
# Returns `dest` untouched when `n == 0` or `T` has zero size.
# Throws `ArgumentError` when `n` is negative, a `BoundsError` (via
# `checkbounds`) when either range falls outside its array, and an error when
# the two arrays live on different devices.
function Base.copyto!(dest::WgpuArray{WAtomic{T}}, doffs::Integer, src::WgpuArray{T}, soffs::Integer,
                      n::Integer) where T
    (n == 0 || sizeof(T) == 0) && return dest
    # A negative count can slip past the two endpoint checks below (for example
    # doffs = 3, n = -1) and would reach the raw copy with a negative byte size.
    n < 0 && throw(ArgumentError("Number of elements to copy must be non-negative."))
    @boundscheck checkbounds(dest, doffs)
    @boundscheck checkbounds(dest, doffs + n - 1)
    @boundscheck checkbounds(src, soffs)
    @boundscheck checkbounds(src, soffs + n - 1)
    # TODO: which device to use here?
    if device(dest) == device(src)
        unsafe_copyto!(device(dest), dest, doffs, src, soffs, n)
    else
        error("Copy between different devices not implemented")
    end
    return dest
end
378+
379+
# Copy `n` elements from the `WAtomic{T}` array `src` (starting at index
# `soffs`) into `dest` (starting at index `doffs`) and return `dest`.
#
# Returns `dest` untouched when `n == 0` or `T` has zero size.
# Throws `ArgumentError` when `n` is negative, a `BoundsError` (via
# `checkbounds`) when either range falls outside its array, and an error when
# the two arrays live on different devices.
function Base.copyto!(dest::WgpuArray{T}, doffs::Integer, src::WgpuArray{WAtomic{T}}, soffs::Integer,
                      n::Integer) where T
    (n == 0 || sizeof(T) == 0) && return dest
    # A negative count can slip past the two endpoint checks below (for example
    # doffs = 3, n = -1) and would reach the raw copy with a negative byte size.
    n < 0 && throw(ArgumentError("Number of elements to copy must be non-negative."))
    @boundscheck checkbounds(dest, doffs)
    @boundscheck checkbounds(dest, doffs + n - 1)
    @boundscheck checkbounds(src, soffs)
    @boundscheck checkbounds(src, soffs + n - 1)
    # TODO: which device to use here?
    if device(dest) == device(src)
        unsafe_copyto!(device(dest), dest, doffs, src, soffs, n)
    else
        error("Copy between different devices not implemented")
    end
    return dest
end
394+
325395
# Whole-array copy: forwards to the five-argument `copyto!`, copying
# `length(src)` elements starting at index 1 of both arrays.
function Base.copyto!(dest::WgpuArray{T}, src::WgpuArray{T}) where {T}
    return copyto!(dest, 1, src, 1, length(src))
end
327397

@@ -337,6 +407,32 @@ function Base.unsafe_copyto!(dev, dest::WgpuArray{T}, doffs, src::Array{T}, soff
337407
return dest
338408
end
339409

410+
# copy WgpuArray{T} -> WgpuArray{WAtomic{T}}
411+
# Copy `n` elements from `src` (starting at `soffs`) into the `WAtomic{T}`
# array `dest` (starting at `doffs`) without bounds checking; returns `dest`.
# Throws an error when `T` is an isbits union, because copying the selector
# bytes is not implemented.
function Base.unsafe_copyto!(dev, dest::WgpuArray{WAtomic{T}}, doffs, src::WgpuArray{T}, soffs, n) where T
    # Reject the unsupported case before issuing the copy, so that `dest` is
    # not modified by a call that then fails.
    if Base.isbitsunion(T)
        # copy selector bytes
        error("Not implemented")
    end

    # these copies are implemented using pure memcpys, not API calls, so aren't ordered.
    # NOTE(review): the `synchronize()` call that would go here is disabled —
    # confirm whether ordering against earlier GPU work is required.
    # synchronize()

    GC.@preserve src dest unsafe_copyto!(dev, pointer(dest, doffs), pointer(src, soffs), n)
    return dest
end
422+
423+
# copy WgpuArray{WAtomic{T}} -> WgpuArray{T}
424+
# Copy `n` elements from the `WAtomic{T}` array `src` (starting at `soffs`)
# into `dest` (starting at `doffs`) without bounds checking; returns `dest`.
# Throws an error when `T` is an isbits union, because copying the selector
# bytes is not implemented.
function Base.unsafe_copyto!(dev, dest::WgpuArray{T}, doffs, src::WgpuArray{WAtomic{T}}, soffs, n) where T
    # Reject the unsupported case before issuing the copy, so that `dest` is
    # not modified by a call that then fails.
    if Base.isbitsunion(T)
        # copy selector bytes
        error("Not implemented")
    end

    # these copies are implemented using pure memcpys, not API calls, so aren't ordered.
    # NOTE(review): the `synchronize()` call that would go here is disabled —
    # confirm whether ordering against earlier GPU work is required.
    # synchronize()

    GC.@preserve src dest unsafe_copyto!(dev, pointer(dest, doffs), pointer(src, soffs), n)
    return dest
end
435+
340436
function Base.unsafe_copyto!(dev, dest::Array{T}, doffs, src::WgpuArray{T}, soffs, n) where T
341437
# these copies are implemented using pure memcpys, not API calls, so arent ordered.
342438
# synchronize()

0 commit comments

Comments
 (0)