11export naive_matmul_kernel, matmul
22
3+ """
4+ matmul_heuristics(x, y)
5+ This function computes workgroup size and workgroup count heuristics for a given input.
6+ It is used by `matmul` when launching `naive_matmul_kernel`.
7+ """
38function matmul_heuristics (x, y)
49 aSize = size (x)
510 bSize = size (y)
@@ -9,6 +14,12 @@ function matmul_heuristics(x, y)
914 return (outSize, outSize, (1 , 1 ))
1015end
1116
17+ """
18+ naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
19+ This is a naive matrix multiplication kernel. It is not supposed to be used as a regular
20+ Julia function; it needs to be passed to `@wgpukernel` to undergo transformation into `WGSL`-compatible
21+ shader code.
22+ """
1223function naive_matmul_kernel (x:: WgpuArray{T, N} , y:: WgpuArray{T, N} , out:: WgpuArray{T, N} ) where {T, N}
1324 gIdx = globalId. x
1425 gIdy = globalId. y
@@ -23,14 +34,24 @@ function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
2334 out[gId] = sum
2435end
2536
"""
    matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}

End-user wrapper that multiplies `x` and `y` using the naive `naive_matmul_kernel`.

The output size, workgroup size and workgroup count are taken from
`matmul_heuristics(x, y)`. An uninitialised `WgpuArray` with the element type and
dimensionality of `x` is allocated for the result, the kernel is launched through
`@wgpukernel`, and the result array is returned.
"""
function matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
    outDims, wgSize, wgCount = matmul_heuristics(x, y)
    out = WgpuArray{eltype(x), ndims(x)}(undef, outDims)
    # The macro receives this invocation as syntax, so its form is left untouched.
    @wgpukernel launch=true workgroupSizes=wgSize workgroupCount=wgCount shmem=() naive_matmul_kernel(x, y, out)
    return out
end
3248
33-
49+ """
50+ tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
51+ This is a compute kernel which carries out tiled matrix multiplication of input `WgpuArray`s. It is
52+ not supposed to be used as a regular Julia function; instead, it needs to be passed to the `@wgpukernel` macro
53+ inside a wrapper function.
54+ """
3455function tiled_matmul_kernel (x:: WgpuArray{T, N} , y:: WgpuArray{T, N} , out:: WgpuArray{T, N} ) where {T, N}
3556 # set out matrix to zero
3657 gId = xDims. x* globalId. y + globalId. x
@@ -61,18 +82,29 @@ function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
6182
6283 out[gId] = sum
6384end
64- # For now valid only for square matrices of size powers of 2 and base size 16.
85+
"""
    tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}

Compute the launch configuration for a tiled matrix multiplication of `x` and `y`.

Return the tuple `(outSize, wgSize, wgCount)` where

- `outSize` is `(first(size(x)), last(size(y)))`,
- `wgSize` is fixed at `(16, 16)`,
- `wgCount` is each entry of `outSize` divided by 16, rounded up.

Two `@assert` checks guard the inputs: the last dimension of `x` must equal the first
dimension of `y`, and both arrays must share an element type.

NOTE(review): presumably this configuration is meant for launching
`tiled_matmul_kernel` — confirm against the caller.
"""
function tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
    aSize = size(x)
    bSize = size(y)
    @assert last(aSize) == first(bSize)
    @assert eltype(x) == eltype(y)
    # For now valid only for square matrices of size powers of 2 and base size 16.
    tile = 16
    outSize = (first(aSize), last(bSize))
    wgSize = (tile, tile) # This can be fixed for now
    # `cld` is ceiling division, i.e. `div(a, b, RoundUp)`.
    wgCount = cld.(outSize, tile)
    return (outSize, wgSize, wgCount)
end
75102
103+ """
104+ tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
105+ This is the user-facing matrix multiplication function, which carries out tiled matrix multiplication of
106+ its input `WgpuArray` arguments.
107+ """
76108function tiled_matmul (x:: WgpuArray{T, N} , y:: WgpuArray{T, N} ) where {T, N}
77109 (outSize, wgSize, wgCount) = tiled_matmul_heuristics (x, y)
78110 out = WgpuArray {eltype(x), ndims(x)} (undef, outSize)
0 commit comments