@@ -7,6 +7,12 @@ using LinearAlgebra
77using GPUArrays
88using Adapt
99
10+ export WAtomic
11+
12+ struct WAtomic{T}
13+ el:: T
14+ end
15+
1016export WgpuArray
1117
1218# TODO MTL tracks cmdEncoder with task local storage. Thats neat.
@@ -53,6 +59,34 @@ function Base.unsafe_copyto!(gpuDevice, dst::WgpuArrayPtr{T}, src::WgpuArrayPtr{
5359 WGPUCore. submit (gpuDevice. queue, [WGPUCore. finish (cmdEncoder),])
5460end
5561
62+ # GPU -> GPU (atomic)
63+ function Base. unsafe_copyto! (gpuDevice, dst:: WgpuArrayPtr{WAtomic{T}} , src:: WgpuArrayPtr{T} , N:: Integer ) where T
64+ cmdEncoder = WGPUCore. createCommandEncoder (gpuDevice, " COMMAND ENCODER" )
65+ WGPUCore. copyBufferToBuffer (
66+ cmdEncoder,
67+ src. buffer,
68+ src. offset |> Int,
69+ dst. buffer,
70+ dst. offset |> Int,
71+ N* sizeof (T)
72+ )
73+ WGPUCore. submit (gpuDevice. queue, [WGPUCore. finish (cmdEncoder),])
74+ end
75+
76+ # GPU (atomic) -> GPU
77+ function Base. unsafe_copyto! (gpuDevice, dst:: WgpuArrayPtr{T} , src:: WgpuArrayPtr{WAtomic{T}} , N:: Integer ) where T
78+ cmdEncoder = WGPUCore. createCommandEncoder (gpuDevice, " COMMAND ENCODER" )
79+ WGPUCore. copyBufferToBuffer (
80+ cmdEncoder,
81+ src. buffer,
82+ src. offset |> Int,
83+ dst. buffer,
84+ dst. offset |> Int,
85+ N* sizeof (T)
86+ )
87+ WGPUCore. submit (gpuDevice. queue, [WGPUCore. finish (cmdEncoder),])
88+ end
89+
5690# GPU -> CPU
5791function Base. unsafe_copyto! (gpuDevice, dst:: Ptr{T} , src:: WgpuArrayPtr{T} , N:: Integer ) where T
5892 cmdEncoder = WGPUCore. createCommandEncoder (gpuDevice, " COMMAND ENCODER" )
@@ -186,6 +220,10 @@ WgpuArray{T}(::UndefInitializer, dims::NTuple{N, Integer}) where {T, N} =
186220WgpuArray {T} (:: UndefInitializer , dims:: Vararg{Integer, N} ) where {T, N} =
187221 WgpuArray {T, N} (undef, convert (Tuple{Vararg{Int}}, dims))
188222
223+ # atomic array support
224+ # WgpuArray{T}(::UndefInitializer, dims::NTuple{N, Integer}) where {T, N} =
225+ # WgpuArray{T, N}(undef, convert(Tuple{Vararg{Int}}, dims))
226+
189227# empty vector constructors
190228WgpuArray {T, 1} () where {T} = WgpuArray {T, 1} (undef, 0 )
191229
@@ -322,6 +360,38 @@ function Base.copyto!(dest::WgpuArray{T}, doffs::Integer, src::WgpuArray{T}, sof
322360 return dest
323361end
324362
363+ function Base. copyto! (dest:: WgpuArray{WAtomic{T}} , doffs:: Integer , src:: WgpuArray{T} , soffs:: Integer ,
364+ n:: Integer ) where T
365+ (n== 0 || sizeof (T) == 0 ) && return dest
366+ @boundscheck checkbounds (dest, doffs)
367+ @boundscheck checkbounds (dest, doffs+ n- 1 )
368+ @boundscheck checkbounds (src, soffs)
369+ @boundscheck checkbounds (src, soffs+ n- 1 )
370+ # TODO : which device to use here?
371+ if device (dest) == device (src)
372+ unsafe_copyto! (device (dest), dest, doffs, src, soffs, n)
373+ else
374+ error (" Copy between different devices not implemented" )
375+ end
376+ return dest
377+ end
378+
379+ function Base. copyto! (dest:: WgpuArray{T} , doffs:: Integer , src:: WgpuArray{WAtomic{T}} , soffs:: Integer ,
380+ n:: Integer ) where T
381+ (n== 0 || sizeof (T) == 0 ) && return dest
382+ @boundscheck checkbounds (dest, doffs)
383+ @boundscheck checkbounds (dest, doffs+ n- 1 )
384+ @boundscheck checkbounds (src, soffs)
385+ @boundscheck checkbounds (src, soffs+ n- 1 )
386+ # TODO : which device to use here?
387+ if device (dest) == device (src)
388+ unsafe_copyto! (device (dest), dest, doffs, src, soffs, n)
389+ else
390+ error (" Copy between different devices not implemented" )
391+ end
392+ return dest
393+ end
394+
325395Base. copyto! (dest:: WgpuArray{T} , src:: WgpuArray{T} ) where {T} =
326396 copyto! (dest, 1 , src, 1 , length (src))
327397
@@ -337,6 +407,32 @@ function Base.unsafe_copyto!(dev, dest::WgpuArray{T}, doffs, src::Array{T}, soff
337407 return dest
338408end
339409
410+ # copy WgpuArray{T} -> WgpuArray{Atomic{T}}
411+ function Base. unsafe_copyto! (dev, dest:: WgpuArray{WAtomic{T}} , doffs, src:: WgpuArray{T} , soffs, n) where T
412+ # these copies are implemented using pure memcpys, not API calls, so arent ordered.
413+ # synchronize()
414+
415+ GC. @preserve src dest unsafe_copyto! (dev, pointer (dest, doffs), pointer (src, soffs), n)
416+ if Base. isbitsunion (T)
417+ # copy selector bytes
418+ error (" Not implemented" )
419+ end
420+ return dest
421+ end
422+
423+ # copy WgpuArray{WAtomic{T}} -> WgpuArray{T}
424+ function Base. unsafe_copyto! (dev, dest:: WgpuArray{T} , doffs, src:: WgpuArray{WAtomic{T}} , soffs, n) where T
425+ # these copies are implemented using pure memcpys, not API calls, so arent ordered.
426+ # synchronize()
427+
428+ GC. @preserve src dest unsafe_copyto! (dev, pointer (dest, doffs), pointer (src, soffs), n)
429+ if Base. isbitsunion (T)
430+ # copy selector bytes
431+ error (" Not implemented" )
432+ end
433+ return dest
434+ end
435+
340436function Base. unsafe_copyto! (dev, dest:: Array{T} , doffs, src:: WgpuArray{T} , soffs, n) where T
341437 # these copies are implemented using pure memcpys, not API calls, so arent ordered.
342438 # synchronize()
0 commit comments