@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
66import { ShapeUtil } from '../../util' ;
77import { ComputeContext , ProgramInfo , ProgramUniform } from '../types' ;
88
9- import { castToF32 , fillVector , getMaxComponents , inputVariable , outputVariable , ShaderHelper , sumVector , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
9+ import { castToF32 , getMaxComponents , inputVariable , outputVariable , ShaderHelper , sumVector , tensorTypeToWsglStorageType , UniformsArrayType } from './common' ;
1010
1111export interface SkipLayerNormAttributes {
1212 simplified : boolean ;
@@ -58,7 +58,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
5858 throw new Error ( 'Beta must have the same hidden size as input' ) ;
5959 }
6060 }
61-
6261 if ( inputs . length > 4 ) {
6362 const bias : TensorView = inputs [ 4 ] ;
6463 if ( bias . dims . length !== 1 ) {
@@ -86,6 +85,7 @@ const createSkipLayerNormProgramInfo =
8685 const hasMeanOutput = isTraining && outputCount > 1 ;
8786 const hasInvStdDevOutput = isTraining && outputCount > 2 ;
8887 const hasInputSkipBiasSumOutput = outputCount > 3 ;
88+ const workgroupSize = 64 ;
8989
9090 const components = getMaxComponents ( hiddenSize ) ;
9191
@@ -124,35 +124,61 @@ const createSkipLayerNormProgramInfo =
124124 variables . push ( outputVariable ( 'input_skip_bias_sum' , inputs [ 0 ] . dataType , outputShape , components ) ) ;
125125 }
126126 const dataType = tensorTypeToWsglStorageType ( inputs [ 0 ] . dataType ) ;
127+ const vecDataType = tensorTypeToWsglStorageType ( DataType . float , components ) ;
127128 return `
128129
129130 ${ shaderHelper . registerUniforms ( uniformsArray ) . declareVariables ( ...variables ) }
131+ var<workgroup> sum_shared : array<${ vecDataType } , ${ workgroupSize } >;
132+ var<workgroup> sum_squared_shared : array<${ vecDataType } , ${ workgroupSize } >;
133+
134+ ${ shaderHelper . mainStart ( [
135+ workgroupSize , 1 , 1
136+ ] ) }
137+ let ix = local_id.x;
138+ let iy = global_id.x / ${ workgroupSize } ;
130139
131- ${ shaderHelper . mainStart ( ) }
132- ${ shaderHelper . guardAgainstOutOfBoundsWorkgroupSizes ( 'uniforms.output_size / uniforms.hidden_size' ) }
133140 let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
134- let offset = global_idx * hidden_size_vectorized;
135- var sum = ${ fillVector ( 'f32' , components ) } ;
136- var squareSum = ${ fillVector ( 'f32' , components ) } ;
137- for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
141+ var stride = hidden_size_vectorized / ${ workgroupSize } ;
142+ let offset = ix * stride + iy * hidden_size_vectorized;
143+ let offset1d = stride * ix;
144+ if (ix == ${ workgroupSize - 1 } ) {
145+ stride = hidden_size_vectorized - stride * ix;
146+ }
147+ for (var i: u32 = 0; i < stride; i++) {
138148 let skip_value = skip[offset + i];
139- let bias_value = ${ hasBiasInput ? 'bias[i]' : dataType + '(0.0)' } ;
149+ let bias_value = ${ hasBiasInput ? 'bias[offset1d + i]' : dataType + '(0.0)' } ;
140150 let input_value = x[offset + i];
141151 let value = input_value + skip_value + bias_value;
142152 ${ hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : '' }
143153 output[offset + i] = value;
144154 let f32_value = ${ castToF32 ( dataType , components , 'value' ) } ;
145- sum += f32_value;
146- squareSum += f32_value * f32_value;
155+ sum_shared[ix] += f32_value;
156+ sum_squared_shared[ix] += f32_value * f32_value;
147157 }
158+ workgroupBarrier();
159+
160+ var reduce_size : u32 = ${ workgroupSize } ;
161+ for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
162+ reduce_size = curr_size + (reduce_size & 1);
163+ if (ix < curr_size) {
164+ sum_shared[ix] += sum_shared[ix + reduce_size];
165+ sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];
166+ }
167+ workgroupBarrier();
168+ }
169+
170+ let sum = sum_shared[0];
171+ let square_sum = sum_squared_shared[0];
148172 let mean = ${ sumVector ( 'sum' , components ) } / f32(uniforms.hidden_size);
149- let inv_std_dev = inverseSqrt(${ sumVector ( 'squareSum ' , components ) } / f32(uniforms.hidden_size) ${
173+ let inv_std_dev = inverseSqrt(${ sumVector ( 'square_sum ' , components ) } / f32(uniforms.hidden_size) ${
150174 simplified ? '' : '- mean * mean' } + uniforms.epsilon);
151175 ${ hasMeanOutput ? 'mean_output[global_idx] = mean;' : '' }
152176 ${ hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : '' }
153- for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
154- output[offset + i] = (output[offset + i] ${ simplified ? '' : `- ${ dataType } (mean)` } ) * ${
155- dataType } (inv_std_dev) * gamma[i] ${ hasBetaInput ? '+ beta[i]' : '' } ;
177+
178+ for (var i: u32 = 0; i < stride; i++) {
179+ output[offset + i] = (output[offset + i] ${ simplified ? '' : `- ${ dataType } (mean)` } ) *
180+ ${ dataType } (inv_std_dev) * gamma[offset1d + i]
181+ ${ hasBetaInput ? '+ beta[offset1d + i]' : '' } ;
156182 }
157183 }` ;
158184 } ;
@@ -173,7 +199,13 @@ const createSkipLayerNormProgramInfo =
173199 inputDependencies : inputs . map ( ( _input , _index ) => 'type' )
174200 } ,
175201 getShaderSource,
176- getRunData : ( ) => ( { outputs, dispatchGroup : { x : Math . ceil ( outputSize / hiddenSize / 64 ) } , programUniforms} ) ,
202+ getRunData : ( ) => ( {
203+ outputs,
204+ dispatchGroup : {
205+ x : Math . ceil ( outputSize / hiddenSize ) ,
206+ } ,
207+ programUniforms
208+ } ) ,
177209 } ;
178210 } ;
179211
0 commit comments