Skip to content

Commit 5a0d26b

Browse files
committed
prefix scan example
1 parent cb60a48 commit 5a0d26b

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

examples/scan_kernel.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using Revise
2+
using WGPUCompute
3+
using Test
4+
5+
empty!(task_local_storage())
6+
7+
function naive_prefix_scan_kernel(x::WgpuArray{T, N}, out::WgpuArray{T, N}, partials::WgpuArray{T, N}) where {T, N}
8+
gId = xDims.x * globalId.y + globalId.x
9+
W = Float32(xDims.x * xDims.y)
10+
steps = UInt32(ceil(log2(W)))
11+
out[gId] = x[gId]
12+
base = 2.0f0
13+
for itr in 0:steps
14+
v = 0.0f0
15+
exponent = Float32(itr)
16+
baseexp = pow(base, exponent)
17+
stride = UInt32(baseexp)
18+
if localId.x >= stride
19+
v = out[gId - stride]
20+
end
21+
synchronize()
22+
if localId.x >= stride
23+
out[gId] += v
24+
end
25+
synchronize()
26+
end
27+
28+
if localId.x == workgroupDims.x - 1
29+
partials[workgroupId.x] = out[gId]
30+
end
31+
end
32+
33+
function naive_prefix_partials_scatter_kernel(y::WgpuArray{T, N}, p::WgpuArray{T, N}) where {T, N}
34+
gId = yDims.x * globalId.y + globalId.x
35+
y[gId] += p[workgroupId.x - 1]
36+
end
37+
38+
function prefix_scan_heuristics(x::WgpuArray{T, N}) where {T, N}
39+
div(reduce(*, size(x)), 256)
40+
end
41+
42+
function naive_prefix_scan(x::WgpuArray{T, N}) where {T, N}
43+
y = similar(x)
44+
wgsize = div(reduce(*, size(x)), 256)
45+
p = WgpuArray{T, N}(zeros(wgsize))
46+
@wgpukernel(
47+
launch=true,
48+
workgroupSizes = (256,),
49+
workgroupCount = (wgsize,),
50+
shmem = (),
51+
naive_prefix_scan_kernel(x, y, p)
52+
)
53+
pscan = cumsum(p |> collect)
54+
partials = WgpuArray{T, N}(pscan)
55+
@wgpukernel(
56+
launch=true,
57+
workgroupSizes = (256,),
58+
workgroupCount = (wgsize,),
59+
shmem = (),
60+
naive_prefix_partials_scatter_kernel(y, partials)
61+
)
62+
return y
63+
end
64+
65+
x = WgpuArray{Float32}(rand(Float32, 2^16))
66+
z = naive_prefix_scan(x,)
67+
68+
x_cpu = (x |> collect)
69+
cumcpu = cumsum(x_cpu, dims=1)
70+
cumgpu = (z |> collect)
71+
72+
@test all(x-> x < 10-6, cumcpu .- cumgpu)
73+

0 commit comments

Comments
 (0)