Skip to content

8321010: RISC-V: C2 RoundVF #17745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/hotspot/cpu/riscv/assembler_riscv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,7 @@ enum VectorMask {
INSN(viota_m, 0b1010111, 0b010, 0b10000, 0b010100);

// Vector Single-Width Floating-Point/Integer Type-Convert Instructions
INSN(vfcvt_x_f_v, 0b1010111, 0b001, 0b00001, 0b010010);
INSN(vfcvt_f_x_v, 0b1010111, 0b001, 0b00011, 0b010010);
INSN(vfcvt_rtz_x_f_v, 0b1010111, 0b001, 0b00111, 0b010010);

Expand Down
68 changes: 68 additions & 0 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2325,6 +2325,74 @@ void C2_MacroAssembler::expand_bits_l_v(Register dst, Register src, Register mas
expand_bits_v(dst, src, mask, /* is_long */ true);
}

// j.l.Math.round(float)
// Returns the closest int to the argument, with ties rounding to positive infinity.
// We need to handle 3 special cases defined by java api spec:
// NaN,
// float >= Integer.MAX_VALUE,
// float <= Integer.MIN_VALUE.
void C2_MacroAssembler::java_round_float_v(VectorRegister dst, VectorRegister src, FloatRegister ftmp,
BasicType bt, uint vector_length) {
// In riscv, there is no straight corresponding rounding mode to satisfy the behaviour defined,
// in java api spec, i.e. any rounding mode can not handle some corner cases, e.g.
// RNE is the closest one, but it ties to "even", which means 1.5/2.5 both will be converted
// to 2, instead of 2 and 3 respectively.
// RUP does not work either, although java api requires "rounding to positive infinity",
// but both 1.3/1.8 will be converted to 2, instead of 1 and 2 respectively.
//
// The optimal solution for non-NaN cases is:
// src+0.5 => dst, with rdn rounding mode,
// convert dst from float to int, with rnd rounding mode.
// and, this solution works as expected for float >= Integer.MAX_VALUE and float <= Integer.MIN_VALUE.
//
// But, we still need to handle NaN explicilty with vector mask instructions.
//
// Check MacroAssembler::java_round_float and C2_MacroAssembler::vector_round_sve in aarch64 for more details.

csrwi(CSR_FRM, C2_MacroAssembler::rdn);
vsetvli_helper(bt, vector_length);

// don't rearrage the instructions sequence order without performance testing.
// check MacroAssembler::java_round_float in riscv64 for more details.
mv(t0, jint_cast(0.5f));
fmv_w_x(ftmp, t0);

// replacing vfclass with feq as performance optimization
vmfeq_vv(v0, src, src);
// set dst = 0 in cases of NaN
vmv_v_x(dst, zr);

// dst = (src + 0.5) rounded down towards negative infinity
vfadd_vf(dst, src, ftmp, Assembler::v0_t);
vfcvt_x_f_v(dst, dst, Assembler::v0_t); // in RoundingMode::rdn

csrwi(CSR_FRM, C2_MacroAssembler::rne);
}

// java.lang.Math.round(double a)
// Returns the closest long to the argument, with ties rounding to positive infinity.
void C2_MacroAssembler::java_round_double_v(VectorRegister dst, VectorRegister src, FloatRegister ftmp,
BasicType bt, uint vector_length) {
// check C2_MacroAssembler::java_round_float_v above for more details.

csrwi(CSR_FRM, C2_MacroAssembler::rdn);
vsetvli_helper(bt, vector_length);

mv(t0, julong_cast(0.5));
fmv_d_x(ftmp, t0);

// replacing vfclass with feq as performance optimization
vmfeq_vv(v0, src, src);
// set dst = 0 in cases of NaN
vmv_v_x(dst, zr);

// dst = (src + 0.5) rounded down towards negative infinity
vfadd_vf(dst, src, ftmp, Assembler::v0_t);
vfcvt_x_f_v(dst, dst, Assembler::v0_t); // in RoundingMode::rdn

csrwi(CSR_FRM, C2_MacroAssembler::rne);
}

void C2_MacroAssembler::element_compare(Register a1, Register a2, Register result, Register cnt, Register tmp1, Register tmp2,
VectorRegister vr1, VectorRegister vr2, VectorRegister vrs, bool islatin, Label &DONE,
Assembler::LMUL lmul) {
Expand Down
3 changes: 3 additions & 0 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@
void expand_bits_i_v(Register dst, Register src, Register mask);
void expand_bits_l_v(Register dst, Register src, Register mask);

void java_round_float_v(VectorRegister dst, VectorRegister src, FloatRegister ftmp, BasicType bt, uint vector_length);
void java_round_double_v(VectorRegister dst, VectorRegister src, FloatRegister ftmp, BasicType bt, uint vector_length);

void float16_to_float_v(VectorRegister dst, VectorRegister src, uint vector_length);
void float_to_float16_v(VectorRegister dst, VectorRegister src, VectorRegister vtmp, Register tmp, uint vector_length);

Expand Down
12 changes: 12 additions & 0 deletions src/hotspot/cpu/riscv/riscv.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,18 @@ bool Matcher::match_rule_supported(int opcode) {
case Op_EncodeISOArray:
return UseRVV;

// Current test shows that, it brings performance gain when MaxVectorSize >= 32, but brings
// regression when MaxVectorSize == 16. So only enable the intrinsic when MaxVectorSize >= 32.
case Op_RoundVF:
return UseRVV && MaxVectorSize >= 32;

// For double, current test shows that even with MaxVectorSize == 32, there is still some regression.
// Although there is no hardware to verify it for now, from the trend of performance data on hardwares
// (with vlenb == 16 and 32 respectively), it's promising to bring better performance rather than
// regression for double when MaxVectorSize == 64+. So only enable the intrinsic when MaxVectorSize >= 64.
case Op_RoundVD:
return UseRVV && MaxVectorSize >= 64;

case Op_PopCountI:
case Op_PopCountL:
return UsePopCountInstruction;
Expand Down
28 changes: 28 additions & 0 deletions src/hotspot/cpu/riscv/riscv_v.ad
Original file line number Diff line number Diff line change
Expand Up @@ -4715,6 +4715,34 @@ instruct vsignum_reg(vReg dst, vReg zero, vReg one, vRegMask_V0 v0) %{
ins_pipe(pipe_slow);
%}

// ---------------- Round float/double Vector Operations ----------------

instruct vround_f(vReg dst, vReg src, fRegF tmp, vRegMask_V0 v0) %{
match(Set dst (RoundVF src));
effect(TEMP_DEF dst, TEMP tmp, TEMP v0);
format %{ "java_round_float_v $dst, $src\t" %}
ins_encode %{
BasicType bt = Matcher::vector_element_basic_type(this);
uint vector_length = Matcher::vector_length(this);
__ java_round_float_v(as_VectorRegister($dst$$reg), as_VectorRegister($src$$reg),
as_FloatRegister($tmp$$reg), bt, vector_length);
%}
ins_pipe(pipe_slow);
%}

instruct vround_d(vReg dst, vReg src, fRegD tmp, vRegMask_V0 v0) %{
match(Set dst (RoundVD src));
effect(TEMP_DEF dst, TEMP tmp, TEMP v0);
format %{ "java_round_double_v $dst, $src\t" %}
ins_encode %{
BasicType bt = Matcher::vector_element_basic_type(this);
uint vector_length = Matcher::vector_length(this);
__ java_round_double_v(as_VectorRegister($dst$$reg), as_VectorRegister($src$$reg),
as_FloatRegister($tmp$$reg), bt, vector_length);
%}
ins_pipe(pipe_slow);
%}

// -------------------------------- Reverse Bytes Vector Operations ------------------------

instruct vreverse_bytes_masked(vReg dst_src, vRegMask_V0 v0) %{
Expand Down
84 changes: 84 additions & 0 deletions test/hotspot/jtreg/compiler/floatingpoint/TestRoundFloatAll.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, Rivos Inc. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

/**
* @test
* @bug 8321010
* @summary Test intrinsic for Math.round(float) in full 32 bits range
*
* @library /test/lib /
* @modules java.base/jdk.internal.math
* @requires os.arch == "riscv64"
* @run main/othervm -XX:-TieredCompilation -Xbatch -XX:CompileThresholdScaling=0.3 -XX:-UseSuperWord
* -XX:CompileCommand=compileonly,compiler.floatingpoint.TestRoundFloatAll::test*
* compiler.floatingpoint.TestRoundFloatAll
*/

package compiler.floatingpoint;

import static compiler.lib.golden.GoldenRound.golden_round;

public class TestRoundFloatAll {

public static void main(String args[]) {
test();
}

// return true when test fails
static boolean test(int n, float f) {
int actual = Math.round(f);
int expected = golden_round(f);
if (actual != expected) {
System.err.println("round error, input: " + f + ", res: " + actual + "expected: " + expected + ", input hex: " + n);
return true;
}
return false;
}

static void test() {
final int ITERS = 11000;
boolean fail = false;

// Warmup
System.out.println("Warmup");
for (int i=0; i<ITERS; i++) {
float f = Float.intBitsToFloat(i);
fail |= test(i, f);
}
if (fail) {
throw new RuntimeException("Warmup failed");
}

// Test and verify results
System.out.println("Verification");
int testInt = 0;
do {
float testFloat = Float.intBitsToFloat(testInt);
fail |= test(testInt, testFloat);
} while (++testInt != 0);
if (fail) {
throw new RuntimeException("Test failed");
}
}
}
95 changes: 95 additions & 0 deletions test/hotspot/jtreg/compiler/lib/golden/GoldenRound.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, Rivos Inc. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

package compiler.lib.golden;

import jdk.internal.math.DoubleConsts;
import jdk.internal.math.FloatConsts;

public class GoldenRound {
public static int golden_round(float a) {
// below code is copied from java.base/share/classes/java/lang/Math.java
// public static int round(float a) { ... }

int intBits = Float.floatToRawIntBits(a);
int biasedExp = (intBits & FloatConsts.EXP_BIT_MASK)
>> (FloatConsts.SIGNIFICAND_WIDTH - 1);
int shift = (FloatConsts.SIGNIFICAND_WIDTH - 2
+ FloatConsts.EXP_BIAS) - biasedExp;
if ((shift & -32) == 0) { // shift >= 0 && shift < 32
// a is a finite number such that pow(2,-32) <= ulp(a) < 1
int r = ((intBits & FloatConsts.SIGNIF_BIT_MASK)
| (FloatConsts.SIGNIF_BIT_MASK + 1));
if (intBits < 0) {
r = -r;
}
// In the comments below each Java expression evaluates to the value
// the corresponding mathematical expression:
// (r) evaluates to a / ulp(a)
// (r >> shift) evaluates to floor(a * 2)
// ((r >> shift) + 1) evaluates to floor((a + 1/2) * 2)
// (((r >> shift) + 1) >> 1) evaluates to floor(a + 1/2)
return ((r >> shift) + 1) >> 1;
} else {
// a is either
// - a finite number with abs(a) < exp(2,FloatConsts.SIGNIFICAND_WIDTH-32) < 1/2
// - a finite number with ulp(a) >= 1 and hence a is a mathematical integer
// - an infinity or NaN
return (int) a;
}
}


public static long golden_round(double a) {
// below code is copied from java.base/share/classes/java/lang/Math.java
// public static int round(double a) { ... }

long longBits = Double.doubleToRawLongBits(a);
long biasedExp = (longBits & DoubleConsts.EXP_BIT_MASK)
>> (DoubleConsts.SIGNIFICAND_WIDTH - 1);
long shift = (DoubleConsts.SIGNIFICAND_WIDTH - 2
+ DoubleConsts.EXP_BIAS) - biasedExp;
if ((shift & -64) == 0) { // shift >= 0 && shift < 64
// a is a finite number such that pow(2,-64) <= ulp(a) < 1
long r = ((longBits & DoubleConsts.SIGNIF_BIT_MASK)
| (DoubleConsts.SIGNIF_BIT_MASK + 1));
if (longBits < 0) {
r = -r;
}
// In the comments below each Java expression evaluates to the value
// the corresponding mathematical expression:
// (r) evaluates to a / ulp(a)
// (r >> shift) evaluates to floor(a * 2)
// ((r >> shift) + 1) evaluates to floor((a + 1/2) * 2)
// (((r >> shift) + 1) >> 1) evaluates to floor(a + 1/2)
return ((r >> shift) + 1) >> 1;
} else {
// a is either
// - a finite number with abs(a) < exp(2,DoubleConsts.SIGNIFICAND_WIDTH-64) < 1/2
// - a finite number with ulp(a) >= 1 and hence a is a mathematical integer
// - an infinity or NaN
return (long) a;
}
}
}
Loading