Commit 55a6986
optimize skiplayernorm (microsoft#20551)
SkipSimplifiedLayerNormalization used in phi3 comes down from 222 µs to 14 µs.
1 parent 737eb48 commit 55a6986
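The speedup comes from restructuring the shader: instead of a single invocation looping over a row's entire hidden dimension, a 64-invocation workgroup now cooperates on each row, accumulating partial sums into workgroup-shared arrays and folding them with a barrier-synchronized tree reduction. As a rough sketch (not part of the commit), the reduction loop the shader emits behaves like the following scalar TypeScript model, where the inner `ix` loop stands in for the 64 invocations that run in parallel in WGSL:

// Scalar model of the shader's shared-memory tree reduction (illustration only;
// in WGSL each `ix` below is a separate invocation and every round ends with
// a workgroupBarrier()).
const workgroupReduce = (shared: number[]): number => {
  let reduceSize = shared.length; // workgroupSize (64) in the shader
  for (let currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
    reduceSize = currSize + (reduceSize & 1); // carry an odd middle element into the next round
    for (let ix = 0; ix < currSize; ix++) {
      shared[ix] += shared[ix + reduceSize]; // fold the upper half onto the lower half
    }
  }
  return shared[0]; // the row's total ends up in element 0
};

// e.g. workgroupReduce([...Array(64).keys()]) === 2016 (the sum of 0..63)

The shader reduces sum_shared and sum_squared_shared in the same loop, so one pass yields both the mean and the quantity that feeds the variance.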

File tree

1 file changed: +48 -16 lines changed

js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts

Lines changed: 48 additions & 16 deletions

@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
 import {ShapeUtil} from '../../util';
 import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

-import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
+import {castToF32, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

 export interface SkipLayerNormAttributes {
   simplified: boolean;
@@ -58,7 +58,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
       throw new Error('Beta must have the same hidden size as input');
     }
   }
-
   if (inputs.length > 4) {
     const bias: TensorView = inputs[4];
     if (bias.dims.length !== 1) {
@@ -86,6 +85,7 @@ const createSkipLayerNormProgramInfo =
       const hasMeanOutput = isTraining && outputCount > 1;
       const hasInvStdDevOutput = isTraining && outputCount > 2;
       const hasInputSkipBiasSumOutput = outputCount > 3;
+      const workgroupSize = 64;

       const components = getMaxComponents(hiddenSize);

@@ -124,35 +124,61 @@ const createSkipLayerNormProgramInfo =
         variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
       }
       const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
+      const vecDataType = tensorTypeToWsglStorageType(DataType.float, components);
       return `

       ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
+      var<workgroup> sum_shared : array<${vecDataType}, ${workgroupSize}>;
+      var<workgroup> sum_squared_shared : array<${vecDataType}, ${workgroupSize}>;
+
+      ${shaderHelper.mainStart([
+        workgroupSize, 1, 1
+      ])}
+        let ix = local_id.x;
+        let iy = global_id.x / ${workgroupSize};

-      ${shaderHelper.mainStart()}
-        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')}
         let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
-        let offset = global_idx * hidden_size_vectorized;
-        var sum = ${fillVector('f32', components)};
-        var squareSum = ${fillVector('f32', components)};
-        for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
+        var stride = hidden_size_vectorized / ${workgroupSize};
+        let offset = ix * stride + iy * hidden_size_vectorized;
+        let offset1d = stride * ix;
+        if (ix == ${workgroupSize - 1}) {
+          stride = hidden_size_vectorized - stride * ix;
+        }
+        for (var i: u32 = 0; i < stride; i++) {
           let skip_value = skip[offset + i];
-          let bias_value = ${hasBiasInput ? 'bias[i]' : dataType + '(0.0)'};
+          let bias_value = ${hasBiasInput ? 'bias[offset1d + i]' : dataType + '(0.0)'};
           let input_value = x[offset + i];
           let value = input_value + skip_value + bias_value;
           ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
           output[offset + i] = value;
           let f32_value = ${castToF32(dataType, components, 'value')};
-          sum += f32_value;
-          squareSum += f32_value * f32_value;
+          sum_shared[ix] += f32_value;
+          sum_squared_shared[ix] += f32_value * f32_value;
         }
+        workgroupBarrier();
+
+        var reduce_size : u32 = ${workgroupSize};
+        for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
+          reduce_size = curr_size + (reduce_size & 1);
+          if (ix < curr_size) {
+            sum_shared[ix] += sum_shared[ix + reduce_size];
+            sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];
+          }
+          workgroupBarrier();
+        }
+
+        let sum = sum_shared[0];
+        let square_sum = sum_squared_shared[0];
         let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
-        let inv_std_dev = inverseSqrt(${sumVector('squareSum', components)} / f32(uniforms.hidden_size) ${
+        let inv_std_dev = inverseSqrt(${sumVector('square_sum', components)} / f32(uniforms.hidden_size) ${
       simplified ? '' : '- mean * mean'} + uniforms.epsilon);
         ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
         ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
-        for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
-          output[offset + i] = (output[offset + i] ${simplified ? '' : `- ${dataType}(mean)`}) * ${
-      dataType}(inv_std_dev) * gamma[i] ${hasBetaInput ? '+ beta[i]' : ''};
+
+        for (var i: u32 = 0; i < stride; i++) {
+          output[offset + i] = (output[offset + i] ${simplified ? '' : `- ${dataType}(mean)`}) *
+            ${dataType}(inv_std_dev) * gamma[offset1d + i]
+            ${hasBetaInput ? '+ beta[offset1d + i]' : ''};
         }
       }`;
     };
@@ -173,7 +199,13 @@ const createSkipLayerNormProgramInfo =
         inputDependencies: inputs.map((_input, _index) => 'type')
       },
       getShaderSource,
-      getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),
+      getRunData: () => ({
+        outputs,
+        dispatchGroup: {
+          x: Math.ceil(outputSize / hiddenSize),
+        },
+        programUniforms
+      }),
     };
   };
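For reading the new indexing: each of the 64 invocations takes a contiguous stride-sized slice of the vectorized row; offset1d is the slice's start within the row (used to index bias, gamma, and beta), offset is its absolute start in the tensor, and the last invocation absorbs the remainder when the row length is not a multiple of 64. A hypothetical TypeScript helper mirroring that arithmetic (names follow the shader; partitionRow itself is not part of the commit):

// Hypothetical mirror of the shader's per-invocation indexing (not in the commit).
const WORKGROUP_SIZE = 64;

const partitionRow = (hiddenSizeVectorized: number, ix: number, iy: number) => {
  let stride = Math.floor(hiddenSizeVectorized / WORKGROUP_SIZE); // slice per invocation
  const offset1d = stride * ix; // start of this invocation's slice within the row
  const offset = offset1d + iy * hiddenSizeVectorized; // absolute start in the tensor
  if (ix === WORKGROUP_SIZE - 1) {
    stride = hiddenSizeVectorized - stride * ix; // the last invocation takes the remainder
  }
  return {offset, offset1d, stride};
};

This split is also why getRunData stops dividing by 64: before, each invocation normalized one row, so a 64-wide workgroup covered 64 rows and the dispatch count was ceil(outputSize / hiddenSize / 64); now each workgroup handles exactly one row, so the dispatch count equals the row count, outputSize / hiddenSize.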