Skip to content

[GR-64794] Vector API: Implement Vector::compress/expand #11434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,7 @@ private enum EVEXFeatureAssertion {
AVX512F_DQ_512(null, null, EnumSet.of(AVX512F, AVX512DQ)),
AVX512F_512(null, null, EnumSet.of(AVX512F)),
AVX512_VBMI_VL(EnumSet.of(CPUFeature.AVX512_VBMI, CPUFeature.AVX512VL), EnumSet.of(CPUFeature.AVX512_VBMI, CPUFeature.AVX512VL), EnumSet.of(CPUFeature.AVX512_VBMI)),
AVX512_VBMI2_VL(EnumSet.of(CPUFeature.AVX512_VBMI2, CPUFeature.AVX512VL), EnumSet.of(CPUFeature.AVX512_VBMI2, CPUFeature.AVX512VL), EnumSet.of(CPUFeature.AVX512_VBMI2)),
CLMUL_AVX512F_VL(EnumSet.of(CPUFeature.AVX512VL, CPUFeature.CLMUL), EnumSet.of(CPUFeature.AVX512VL, CPUFeature.CLMUL), EnumSet.of(CPUFeature.AVX512F, CPUFeature.CLMUL)),
GFNI_AVX512F_VL(EnumSet.of(CPUFeature.AVX512VL, CPUFeature.GFNI), EnumSet.of(CPUFeature.AVX512VL, CPUFeature.GFNI), EnumSet.of(CPUFeature.AVX512F, CPUFeature.GFNI));

Expand Down Expand Up @@ -1338,6 +1339,7 @@ private enum VEXOpAssertion {
AVX512F_512ONLY(null, EVEXFeatureAssertion.AVX512F_512, XMM, XMM, XMM),
AVX512DQ_512ONLY(null, EVEXFeatureAssertion.AVX512F_DQ_512, XMM, XMM, XMM),
AVX512_VBMI_VL(null, EVEXFeatureAssertion.AVX512_VBMI_VL, XMM, XMM, XMM),
AVX512_VBMI2_VL(null, EVEXFeatureAssertion.AVX512_VBMI2_VL, XMM, XMM, XMM),

AVX512F_CPU_OR_MASK(VEXFeatureAssertion.AVX512F_L0, null, CPU_OR_MASK, null, CPU_OR_MASK),
AVX512BW_CPU_OR_MASK(VEXFeatureAssertion.AVX512BW_L0, null, CPU_OR_MASK, null, CPU_OR_MASK),
Expand Down Expand Up @@ -1398,6 +1400,7 @@ public enum VectorFeatureAssertion {
AVX512F_VL(VEXOpAssertion.AVX512F_VL),
AVX512BW_VL(VEXOpAssertion.AVX512BW_VL),
AVX512DQ_VL(VEXOpAssertion.AVX512DQ_VL),
AVX512_VBMI2_VL(VEXOpAssertion.AVX512_VBMI2_VL),
FMA(VEXOpAssertion.FMA);

private final VEXOpAssertion opAssertion;
Expand Down Expand Up @@ -1722,6 +1725,10 @@ public static class VexRMOp extends VexRROp {
public static final VexRMOp EVPABSQ = new VexRMOp("EVPABSQ", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.WIG, 0x1F, VEXOpAssertion.AVX512F_VL, EVEXTuple.FVM, VEXPrefixConfig.W1, true);
public static final VexRMOp EVCVTPH2PS = new VexRMOp("EVCVTPH2PS", VCVTPH2PS);

public static final VexRMOp EVPEXPANDB = new VexRMOp("EVPEXPANDB", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x62, VEXOpAssertion.AVX512_VBMI2_VL, EVEXTuple.T1S_8BIT, VEXPrefixConfig.W0, true);
public static final VexRMOp EVPEXPANDW = new VexRMOp("EVPEXPANDW", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x62, VEXOpAssertion.AVX512_VBMI2_VL, EVEXTuple.T1S_16BIT, VEXPrefixConfig.W1, true);
public static final VexRMOp EVPEXPANDD = new VexRMOp("EVPEXPANDD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x89, VEXOpAssertion.AVX512F_VL, EVEXTuple.T1S_32BIT, VEXPrefixConfig.W0, true);
public static final VexRMOp EVPEXPANDQ = new VexRMOp("EVPEXPANDQ", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x89, VEXOpAssertion.AVX512F_VL, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1, true);
// @formatter:on

protected VexRMOp(String opcode, int pp, int mmmmm, int w, int op, VEXOpAssertion assertion) {
Expand Down Expand Up @@ -2063,7 +2070,10 @@ public void emit(AMD64Assembler asm, AVXSize size, Register dst, AMD64Address sr
*/
public static final class VexMROp extends VexRROp {
// @formatter:off
public static final VexMROp EVPCOMPRESSB = new VexMROp("EVPCOMPRESSB", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x63, VEXOpAssertion.AVX512_VBMI2_VL, EVEXTuple.T1S_8BIT, VEXPrefixConfig.W0, true);
public static final VexMROp EVPCOMPRESSW = new VexMROp("EVPCOMPRESSW", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x63, VEXOpAssertion.AVX512_VBMI2_VL, EVEXTuple.T1S_16BIT, VEXPrefixConfig.W1, true);
public static final VexMROp EVPCOMPRESSD = new VexMROp("EVPCOMPRESSD", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x8B, VEXOpAssertion.AVX512F_VL, EVEXTuple.T1S_32BIT, VEXPrefixConfig.W0, true);
public static final VexMROp EVPCOMPRESSQ = new VexMROp("EVPCOMPRESSQ", VEXPrefixConfig.P_66, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W1, 0x8B, VEXOpAssertion.AVX512F_VL, EVEXTuple.T1S_64BIT, VEXPrefixConfig.W1, true);

public static final VexMROp EVPMOVWB = new VexMROp("EVPMOVWB", VEXPrefixConfig.P_F3, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x30, VEXOpAssertion.AVX512BW_VL, EVEXTuple.HVM, VEXPrefixConfig.W0, true);
public static final VexMROp EVPMOVDB = new VexMROp("EVPMOVDB", VEXPrefixConfig.P_F3, VEXPrefixConfig.M_0F38, VEXPrefixConfig.W0, 0x31, VEXOpAssertion.AVX512F_VL, EVEXTuple.QVM, VEXPrefixConfig.W0, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ public UnimplementedGraalIntrinsics(Architecture arch) {
// scalar operations
"jdk/internal/vm/vector/Float16Math.fma(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljdk/internal/vm/vector/Float16Math$TernaryOperator;)Ljava/lang/Object;",
"jdk/internal/vm/vector/Float16Math.sqrt(Ljava/lang/Class;Ljava/lang/Object;Ljava/util/function/UnaryOperator;)Ljava/lang/Object;",
"jdk/internal/vm/vector/VectorSupport.compressExpandOp(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$CompressExpandOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;",
"jdk/internal/vm/vector/VectorSupport.indexVector(Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;ILjdk/internal/vm/vector/VectorSupport$VectorSpecies;Ljdk/internal/vm/vector/VectorSupport$IndexOperation;)Ljdk/internal/vm/vector/VectorSupport$Vector;",
// JDK-8353786: Migrate Vector API math library support to FFM API
"jdk/internal/vm/vector/VectorSupport.libraryBinaryOp(JLjava/lang/Class;Ljava/lang/Class;ILjava/lang/String;Ljdk/internal/vm/vector/VectorSupport$VectorPayload;Ljdk/internal/vm/vector/VectorSupport$VectorPayload;Ljdk/internal/vm/vector/VectorSupport$BinaryOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;",
Expand Down Expand Up @@ -187,6 +186,7 @@ public UnimplementedGraalIntrinsics(Architecture arch) {
"jdk/internal/vm/vector/VectorSupport.blend(Ljava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorBlendOp;)Ljdk/internal/vm/vector/VectorSupport$Vector;",
"jdk/internal/vm/vector/VectorSupport.broadcastInt(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;ILjdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorBroadcastIntOp;)Ljdk/internal/vm/vector/VectorSupport$Vector;",
"jdk/internal/vm/vector/VectorSupport.compare(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$VectorCompareOp;)Ljdk/internal/vm/vector/VectorSupport$VectorMask;",
"jdk/internal/vm/vector/VectorSupport.compressExpandOp(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljdk/internal/vm/vector/VectorSupport$CompressExpandOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;",
"jdk/internal/vm/vector/VectorSupport.convert(ILjava/lang/Class;Ljava/lang/Class;ILjava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorPayload;Ljdk/internal/vm/vector/VectorSupport$VectorSpecies;Ljdk/internal/vm/vector/VectorSupport$VectorConvertOp;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;",
"jdk/internal/vm/vector/VectorSupport.extract(Ljava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$VectorPayload;ILjdk/internal/vm/vector/VectorSupport$VecExtractOp;)J",
"jdk/internal/vm/vector/VectorSupport.fromBitsCoerced(Ljava/lang/Class;Ljava/lang/Class;IJILjdk/internal/vm/vector/VectorSupport$VectorSpecies;Ljdk/internal/vm/vector/VectorSupport$FromBitsCoercedOperation;)Ljdk/internal/vm/vector/VectorSupport$VectorPayload;",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.graal.compiler.lir.amd64.vector;

import static jdk.graal.compiler.asm.amd64.AMD64BaseAssembler.EVEXPrefixConfig.B0;
import static jdk.graal.compiler.asm.amd64.AMD64BaseAssembler.EVEXPrefixConfig.Z1;
import static jdk.vm.ci.code.ValueUtil.asRegister;

import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexMROp;
import jdk.graal.compiler.asm.amd64.AMD64Assembler.VexRMOp;
import jdk.graal.compiler.asm.amd64.AMD64MacroAssembler;
import jdk.graal.compiler.asm.amd64.AVXKind;
import jdk.graal.compiler.debug.GraalError;
import jdk.graal.compiler.lir.LIRInstructionClass;
import jdk.graal.compiler.lir.amd64.AMD64LIRInstruction;
import jdk.graal.compiler.lir.asm.CompilationResultBuilder;
import jdk.vm.ci.amd64.AMD64Kind;
import jdk.vm.ci.meta.AllocatableValue;

/**
* This class implements the LIR nodes for AVX512 instructions {@code vpcompress} and
* {@code vpexpand}.
*/
public class AVX512CompressExpand {
/**
* The LIR node for the instruction {@code vpcompress}.
*/
public static final class CompressOp extends AMD64LIRInstruction {
public static final LIRInstructionClass<CompressOp> TYPE = LIRInstructionClass.create(CompressOp.class);

@Def protected AllocatableValue result;
@Use protected AllocatableValue source;
@Use protected AllocatableValue mask;

public CompressOp(AllocatableValue result, AllocatableValue source, AllocatableValue mask) {
super(TYPE);
this.result = result;
this.source = source;
this.mask = mask;
}

@Override
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
AMD64Kind eKind = ((AMD64Kind) result.getPlatformKind()).getScalar();
AVXKind.AVXSize avxSize = AVXKind.getRegisterSize(result);
VexMROp op = switch (eKind) {
case BYTE -> VexMROp.EVPCOMPRESSB;
case WORD -> VexMROp.EVPCOMPRESSW;
case DWORD, SINGLE -> VexMROp.EVPCOMPRESSD;
case QWORD, DOUBLE -> VexMROp.EVPCOMPRESSQ;
default -> throw GraalError.shouldNotReachHereUnexpectedValue(eKind);
};
op.emit(masm, avxSize, asRegister(result), asRegister(source), asRegister(mask), Z1, B0);
}
}

/**
* The LIR node for the instruction {@code vpexpand}.
*/
public static final class ExpandOp extends AMD64LIRInstruction {
public static final LIRInstructionClass<ExpandOp> TYPE = LIRInstructionClass.create(ExpandOp.class);

@Def protected AllocatableValue result;
@Use protected AllocatableValue source;
@Use protected AllocatableValue mask;

public ExpandOp(AllocatableValue result, AllocatableValue source, AllocatableValue mask) {
super(TYPE);
this.result = result;
this.source = source;
this.mask = mask;
}

@Override
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
AMD64Kind eKind = ((AMD64Kind) result.getPlatformKind()).getScalar();
AVXKind.AVXSize avxSize = AVXKind.getRegisterSize(result);
VexRMOp op = switch (eKind) {
case BYTE -> VexRMOp.EVPEXPANDB;
case WORD -> VexRMOp.EVPEXPANDW;
case DWORD, SINGLE -> VexRMOp.EVPEXPANDD;
case QWORD, DOUBLE -> VexRMOp.EVPEXPANDQ;
default -> throw GraalError.shouldNotReachHereUnexpectedValue(eKind);
};
op.emit(masm, avxSize, asRegister(result), asRegister(source), asRegister(mask), Z1, B0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,15 @@ protected int getSupportedVectorLogicLengthHelper(LogicNode logicNode, int maxLe
*/
public abstract int getSupportedVectorBlendLength(Stamp elementStamp, int maxLength);

/**
* Get the maximum supported vector length for a vector compress/expand based on a mask.
*
* @param elementStamp the stamp of the elements to be blended
* @param maxLength the maximum length to return
* @return the number of elements that can be compressed/expanded by a single instruction
*/
public abstract int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength);

/**
* Determine the minimum alignment in bytes that is guaranteed for objects.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ public int getSupportedVectorBlendLength(Stamp elementStamp, int maxLength) {
return getSupportedVectorLength(elementStamp, maxLength);
}

@Override
public int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength) {
return 1;
}

@Override
public int getObjectAlignment() {
return objectAlignment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,17 @@ private boolean isImpossibleLongToDoubleConversion(Stamp result, Stamp input) {
return !supportUnsignedLongToDouble((IntegerStamp) input) && !supportSignedLongToDouble((IntegerStamp) input);
}

@Override
public int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength) {
if (!hasMinimumVectorizationRequirements(maxLength)) {
return 1;
}

AVXSize avxSize = compressExpandOps.getSupportedAVXSize(elementStamp, maxLength);
int supportedLength = getSupportedVectorLength(elementStamp, maxLength, avxSize);
return Math.min(supportedLength, maxLength);
}

@Override
public int getObjectAlignment() {
return objectAlignment;
Expand Down Expand Up @@ -1187,6 +1198,41 @@ public AVXSize getSupportedAVXSize(Stamp stamp, int maxLength) {
}
}

private final AMD64SupportedCompressExpandVectorInstructionsTable compressExpandOps = new AMD64SupportedCompressExpandVectorInstructionsTable(this);

private static final class AMD64VectorCompressExpandInstructionsMap extends AMD64SimpleVectorInstructionsTable.AMD64SimpleVectorInstructionsMap {
@SuppressWarnings("unchecked")
AMD64VectorCompressExpandInstructionsMap() {
super(
entry(IntegerStamp.class,
op(BYTE_BITS, VectorFeatureAssertion.AVX512_VBMI2_VL),
op(WORD_BITS, VectorFeatureAssertion.AVX512_VBMI2_VL),
op(DWORD_BITS, VectorFeatureAssertion.AVX512F_VL),
op(QWORD_BITS, VectorFeatureAssertion.AVX512F_VL)),

entry(FloatStamp.class,
op(SINGLE_BITS, VectorFeatureAssertion.AVX512F_VL),
op(DOUBLE_BITS, VectorFeatureAssertion.AVX512F_VL)));
}
}

private static final class AMD64SupportedCompressExpandVectorInstructionsTable extends AMD64SimpleVectorInstructionsTable {

private static final AMD64VectorCompressExpandInstructionsMap COMPRESS_EXPAND_INSTRUCTIONS_MAP = new AMD64VectorCompressExpandInstructionsMap();

private AMD64SupportedCompressExpandVectorInstructionsTable(VectorAMD64 vectorAMD64) {
super(vectorAMD64, COMPRESS_EXPAND_INSTRUCTIONS_MAP);
}

public AVXSize getSupportedAVXSize(Stamp stamp, int maxLength) {
if (stamp instanceof AbstractObjectStamp) {
// For compress/expand, treat pointers like integers of the appropriate size.
return getEntry(IntegerStamp.class, oopBits((AbstractObjectStamp) stamp), maxLength);
}
return getEntry(stamp.getClass(), PrimitiveStamp.getBits(stamp), maxLength);
}
}

private static class VectorSimpleOperation {
private final int bits;
protected final VectorFeatureAssertion requiredFeatures;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,8 @@ public interface VectorLIRGeneratorTool extends ArithmeticLIRGeneratorTool {
Value emitMoveOpMaskToInteger(LIRKind resultKind, Value mask, int maskLen);

Value emitMoveIntegerToOpMask(LIRKind resultKind, Value mask);

Value emitVectorCompress(LIRKind resultKind, Value source, Value mask);

Value emitVectorExpand(LIRKind resultKind, Value source, Value mask);
}
Original file line number Diff line number Diff line change
Expand Up @@ -872,4 +872,14 @@ public Value emitMoveOpMaskToInteger(LIRKind resultKind, Value mask, int maskLen
public Value emitMoveIntegerToOpMask(LIRKind resultKind, Value mask) {
throw new UnsupportedOperationException();
}

@Override
public Variable emitVectorCompress(LIRKind resultKind, Value source, Value mask) {
throw new UnsupportedOperationException();
}

@Override
public Variable emitVectorExpand(LIRKind resultKind, Value source, Value mask) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorMove;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorShuffle;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorUnary;
import jdk.graal.compiler.lir.amd64.vector.AVX512CompressExpand;
import jdk.graal.compiler.lir.amd64.vector.AVX512MaskedOp;
import jdk.graal.compiler.nodes.ValueNode;
import jdk.graal.compiler.nodes.calc.AbsNode;
Expand Down Expand Up @@ -2019,4 +2020,18 @@ public static AMD64Assembler.VexOp getMaskedOpcode(AMD64 arch, MaskedOpMetaData
return null;
}
}

@Override
public Variable emitVectorCompress(LIRKind resultKind, Value source, Value mask) {
Variable result = getLIRGen().newVariable(resultKind);
getLIRGen().append(new AVX512CompressExpand.CompressOp(result, asAllocatable(source), asAllocatable(mask)));
return result;
}

@Override
public Variable emitVectorExpand(LIRKind resultKind, Value source, Value mask) {
Variable result = getLIRGen().newVariable(resultKind);
getLIRGen().append(new AVX512CompressExpand.ExpandOp(result, asAllocatable(source), asAllocatable(mask)));
return result;
}
}
Loading