Skip to content

Commit fdbe53f

Browse files
committed
clean examples
1 parent bb88c9d commit fdbe53f

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

examples/cast_kernel.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using WGPUCompute
2+
using Test
23

34
function cast_kernel(x::WgpuArray{T, N}, out::WgpuArray{S, N}) where {T, S, N}
45
xdim = workgroupDims.x
@@ -15,8 +16,15 @@ function cast(S::DataType, x::WgpuArray{T, N}) where {T, N}
1516
return y
1617
end
1718

18-
x = WgpuArray{Float32}(rand(Float32, 8, 8) .- 0.5f0)
19-
z = cast(UInt32, x)
19+
x = rand(Float32, 8, 8) .- 0.5f0
20+
21+
x_gpu = WgpuArray{Float32}(x)
22+
z_gpu = cast(UInt32, x_gpu)
23+
z_cpu = z_gpu |> collect
24+
25+
z = UInt32.(x .> 0.0)
26+
27+
@test z z_cpu
2028

2129
# TODO Bool cast is not working yet
2230
# y = cast(Bool, x)

examples/clamp_kernel.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Revise
22
using WGPUCompute
3+
using Test
34

45
function clamp_kernel(x::WgpuArray{T, N}, out::WgpuArray{T, N}, minval::T, maxval::T) where {T, N}
56
gId = xDims.x * globalId.y + globalId.x
@@ -17,3 +18,9 @@ end
1718
x = WgpuArray{Float32, 2}(rand(16, 16))
1819

1920
y = Base.clamp(x, 0.2f0, 0.5f0)
21+
y_cpu = y |> collect
22+
23+
@testset "Clamp minimum and maximum" begin
24+
@test minimum(y_cpu) == 0.2f0
25+
@test maximum(y_cpu) == 0.5f0
26+
end

examples/scan_kernel.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ end
4141

4242
function naive_prefix_scan(x::WgpuArray{T, N}) where {T, N}
4343
y = similar(x)
44-
wgsize = div(reduce(*, size(x)), 256)
44+
maxthreads = 256
45+
wgsize = div(reduce(*, size(x)), maxthreads)
4546
p = WgpuArray{T, N}(zeros(wgsize))
4647
@wgpukernel(
4748
launch=true,
48-
workgroupSizes = (256,),
49+
workgroupSizes = (maxthreads,),
4950
workgroupCount = (wgsize,),
5051
shmem = (),
5152
naive_prefix_scan_kernel(x, y, p)
@@ -54,15 +55,15 @@ function naive_prefix_scan(x::WgpuArray{T, N}) where {T, N}
5455
partials = WgpuArray{T, N}(pscan)
5556
@wgpukernel(
5657
launch=true,
57-
workgroupSizes = (256,),
58+
workgroupSizes = (maxthreads,),
5859
workgroupCount = (wgsize,),
5960
shmem = (),
6061
naive_prefix_partials_scatter_kernel(y, partials)
6162
)
6263
return y
6364
end
6465

65-
x = WgpuArray{Float32}(rand(Float32, 2^16))
66+
x = WgpuArray{Float32}(rand(Float32, 2^22))
6667
z = naive_prefix_scan(x,)
6768

6869
x_cpu = (x |> collect)

examples/tiled_matmul_kernel.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ Base.:*(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N} = tiled_matmul(x,
8484

8585
z = x*y
8686

87-
z_cpu = (x |> collect)*(y |> collect)
87+
x_cpu = (x |> collect);
88+
y_cpu = (y |> collect);
89+
90+
z_cpu = x_cpu*y_cpu
8891

8992
@test z_cpu (z |> collect)
9093

0 commit comments

Comments
 (0)