Skip to content

Commit e130cff

Browse files
authored
Merge pull request #43 from arhik/main
Update reduce_kernel.jl
2 parents 3c40f58 + 8ab7778 commit e130cff

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

examples/reduce_kernel.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,33 @@ using Revise
22
using WGPUCompute
33
using Test
44

5-
function naive_reduce_kernel(x::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
6-
gId = xDims.x*globalId.y + globalId.x
7-
W = Float32(xDims.x*xDims.y)
8-
steps = UInt32(ceil(log2(W)))
9-
out[gId] = x[gId]
10-
base=2.0
11-
for itr in 0:steps
12-
exponent = Float32(itr)
13-
stride = UInt32(pow(base, exponent))
14-
if gId%(2*stride) == 0
5+
empty!(task_local_storage())
6+
7+
function naive_reduce_kernel(x::WgpuArray{T,N}, out::WgpuArray{T,N}) where {T,N}
8+
gId = xDims.x * globalId.y + globalId.x
9+
W = Float32(xDims.x * xDims.y)
10+
steps = UInt32(ceil(log2(W)))
11+
out[gId] = x[gId]
12+
base = 2.0f0
13+
for itr in 0:steps
14+
if gId%2 == 0
15+
exponent = Float32(itr)
16+
stride = UInt32(pow(base, exponent))
1517
out[gId] += out[gId + stride]
16-
end
17-
synchronize()
18+
end
1819
end
1920
end
2021

21-
function naive_reduce(x::WgpuArray{T, N}) where {T, N}
22-
y = WgpuArray{T}(undef, size(x))
23-
@wgpukernel(
24-
launch=true,
25-
workgroupSizes=(4, 4),
26-
workgroupCount=(2, 2),
27-
shmem=(:shmem=>(Float32, (4, 4)),),
28-
naive_reduce_kernel(x, y)
29-
)
30-
return (y |> collect)[1]
22+
function naive_reduce(x::WgpuArray{T,N}) where {T,N}
23+
y = WgpuArray{T}(undef, size(x))
24+
@wgpukernel(
25+
launch = true,
26+
workgroupSizes = (8, 8),
27+
workgroupCount = (1, 1),
28+
shmem = (),
29+
naive_reduce_kernel(x, y)
30+
)
31+
return (y |> collect)
3132
end
3233

3334
x = WgpuArray{Float32}(rand(Float32, 8, 8))
@@ -36,7 +37,6 @@ z = naive_reduce(x)
3637
x_cpu = (x |> collect)
3738

3839
sum_cpu = sum(x_cpu)
39-
sum_gpu = (z |> collect)[1]
40+
sum_gpu = (z|>collect)[1]
4041

4142
@test sum_cpu sum_gpu
42-

0 commit comments

Comments
 (0)