Skip to content

Commit d0bac82

Browse files
authored
[js/webgpu] fix bcast in where (microsoft#19009)
1 parent 5349770 commit d0bac82

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

js/web/lib/wasm/jsep/webgpu/ops/where.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
7676
const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
7777
let outputShape = dimsA;
7878
let outputSize = ShapeUtil.size(dimsA);
79-
const vecSize = Math.ceil(outputSize / 4);
8079
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
8180

8281
if (isBroadcast) {
@@ -88,6 +87,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
8887
outputSize = ShapeUtil.size(outputShape);
8988
}
9089

90+
const vecSize = Math.ceil(outputSize / 4);
91+
9192
return {
9293
name: 'Where',
9394
shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},

0 commit comments

Comments
 (0)