Skip to content

Commit 50958b0

Browse files
committed
Refactor. Reduce fraction size
1 parent 73f3f11 commit 50958b0

File tree

4 files changed

+103
-82
lines changed

4 files changed

+103
-82
lines changed

src/main/scala/PositExtractor.scala

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,28 @@ class PositExtractor(val totalBits: Int, val es: Int) extends Module with HasHar
99
val out = Output(new unpackedPosit(totalBits, es))
1010
})
1111

12-
io.out.sign := io.in(totalBits - 1)
13-
io.out.isZero := ~io.in.orR()
14-
io.out.isNaR := io.in(totalBits - 1) & (~io.in(totalBits - 2, 0).orR())
15-
io.out.stickyBit := false.B
12+
val sign = io.in(totalBits - 1)
13+
val absIn = Mux(sign, ~io.in + 1.U, io.in).asUInt()
1614

17-
val num = Mux(io.out.sign, ~io.in + 1.U, io.in).asUInt()
18-
19-
val regExpFrac = num(totalBits - 2, 0)
15+
val regExpFrac = absIn(totalBits - 2, 0)
2016
val regMaps =
2117
Array.range(0, totalBits - 2).reverse.map(index => {
2218
(!(regExpFrac(index + 1) === regExpFrac(index))) -> (totalBits - (index + 2)).U
2319
})
24-
val regimeLength =
20+
val regimeCount =
2521
Cat(0.U(1.W), MuxCase((totalBits - 1).U, regMaps))
2622
val regime =
27-
Mux(num(totalBits - 2), regimeLength - 1.U, ~regimeLength + 1.U)
23+
Mux(absIn(totalBits - 2), regimeCount - 1.U, ~regimeCount + 1.U)
2824

29-
val expFrac = num << (regimeLength + 2.U)
25+
val expFrac = absIn << (regimeCount + 2.U)
3026
val extractedExponent =
3127
if (es > 0) expFrac(totalBits - 1, totalBits - es)
3228
else 0.U
3329
val frac = expFrac << es
3430

35-
io.out.exponent := ((regime << es) | extractedExponent).asSInt
36-
io.out.fraction := Cat(1.U, frac(totalBits - 1, totalBits - maxFractionBits))
31+
io.out.sign := sign
32+
io.out.isZero := isZero(io.in)
33+
io.out.isNaR := isNaR(io.in)
34+
io.out.exponent := ((regime << es) | extractedExponent).asSInt
35+
io.out.fraction := Cat(1.U, frac(totalBits - 1, totalBits - maxFractionBits))
3736
}

src/main/scala/PositFMA.scala

Lines changed: 53 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,77 +16,83 @@ class PositFMA(val totalBits: Int, val es: Int) extends Module with HasHardPosit
1616
val out = Output(UInt(totalBits.W))
1717
})
1818

19-
private val num1Extractor = Module(new PositExtractor(totalBits, es))
20-
private val num2Extractor = Module(new PositExtractor(totalBits, es))
21-
private val num3Extractor = Module(new PositExtractor(totalBits, es))
19+
val num1Extractor = Module(new PositExtractor(totalBits, es))
20+
val num2Extractor = Module(new PositExtractor(totalBits, es))
21+
val num3Extractor = Module(new PositExtractor(totalBits, es))
2222

2323
num2Extractor.io.in := io.num2
2424
num1Extractor.io.in := io.num1
2525
num3Extractor.io.in := io.num3
2626

27-
private val num1 = num1Extractor.io.out
28-
private val num2 = num2Extractor.io.out
29-
private val num3 = num3Extractor.io.out
27+
val num1 = num1Extractor.io.out
28+
val num2 = num2Extractor.io.out
29+
val num3 = num3Extractor.io.out
3030

31-
io.isNaR := num1.isNaR || num2.isNaR || num3.isNaR
32-
io.isZero := (num1.isZero || num2.isZero) && num3.isZero
31+
val productSign = num1.sign ^ num2.sign ^ io.negate
32+
val addendSign = num3.sign ^ io.negate ^ io.sub
3333

34-
private val productSign = num1.sign ^ num2.sign ^ io.negate
35-
private val addendSign = num3.sign ^ io.negate ^ io.sub
34+
val productExponent = num1.exponent + num2.exponent
35+
val productFraction =
36+
WireInit(UInt(maxProductFractionBits.W), num1.fraction * num2.fraction)
3637

37-
private val productExponent = num1.exponent + num2.exponent
38-
private val productFraction = WireInit(UInt(maxProductFractionBits.W), num1.fraction * num2.fraction)
38+
val prodOverflow = productFraction(maxProductFractionBits - 1)
39+
val normProductFraction = (productFraction >> prodOverflow.asUInt()).asUInt()
40+
val normProductExponent = productExponent + Mux(prodOverflow, 1.S, 0.S)
41+
val prodStickyBit = Mux(prodOverflow, productFraction(0), false.B)
3942

40-
private val prodOverflow = productFraction(maxProductFractionBits - 1)
41-
private val normProductFraction = (productFraction >> prodOverflow.asUInt()).asUInt()
42-
private val normProductExponent = productExponent + Mux(prodOverflow, 1.S, 0.S)
43-
private val prodStickyBit = Mux(prodOverflow, productFraction(0), false.B)
43+
val addendFraction = (num3.fraction << maxFractionBits).asUInt
44+
val addendExponent = num3.exponent
4445

45-
private val addendFraction = (num3.fraction << maxFractionBits).asUInt
46-
private val addendExponent = num3.exponent
46+
val isAddendLargerThanProduct =
47+
(addendExponent > normProductExponent) |
48+
(addendExponent === normProductExponent &&
49+
(addendFraction > normProductFraction))
4750

48-
private val isAddendLargerThanProduct = (addendExponent > normProductExponent) | (addendExponent === normProductExponent && (addendFraction > normProductFraction))
51+
val largeExp = Mux(isAddendLargerThanProduct, addendExponent, normProductExponent)
52+
val largeFrac = Mux(isAddendLargerThanProduct, addendFraction, normProductFraction)
53+
val largeSign = Mux(isAddendLargerThanProduct, addendSign, productSign)
4954

50-
private val largerExponent = Mux(isAddendLargerThanProduct, addendExponent, normProductExponent)
51-
private val largerFraction = Mux(isAddendLargerThanProduct, addendFraction, normProductFraction)
52-
private val largerSign = Mux(isAddendLargerThanProduct, addendSign, productSign)
55+
val smallExp = Mux(isAddendLargerThanProduct, normProductExponent, addendExponent)
56+
val smallFrac = Mux(isAddendLargerThanProduct, normProductFraction, addendFraction)
57+
val smallSign = Mux(isAddendLargerThanProduct, productSign, addendSign)
5358

54-
private val smallerExponent = Mux(isAddendLargerThanProduct, normProductExponent, addendExponent)
55-
private val smallerFraction = Mux(isAddendLargerThanProduct, normProductFraction, addendFraction)
56-
private val smallerSign = Mux(isAddendLargerThanProduct, productSign, addendSign)
59+
val expDiff = (largeExp - smallExp).asUInt()
60+
val shiftedSmallFrac =
61+
Mux(expDiff < maxProductFractionBits.U, smallFrac >> expDiff, 0.U)
62+
val smallFracStickyBit = (smallFrac & ((1.U << expDiff) - 1.U)).orR()
5763

58-
private val exponentDifference = (largerExponent - smallerExponent).asUInt()
59-
private val shiftedSmallerFraction = (smallerFraction >> exponentDifference).asUInt()
60-
private val smallFractionStickyBit = (smallerFraction & ((1.U << exponentDifference) - 1.U)).orR()
64+
val isAddition = ~(largeSign ^ smallSign)
65+
val signedSmallerFraction =
66+
Mux(isAddition, shiftedSmallFrac, ~shiftedSmallFrac + 1.U)
67+
val fmaFraction =
68+
WireInit(UInt(maxProductFractionBits.W), largeFrac +& signedSmallerFraction)
6169

62-
private val isAddition = ~(largerSign ^ smallerSign)
63-
private val signedSmallerFraction = Mux(isAddition, shiftedSmallerFraction, ~shiftedSmallerFraction + 1.U)
64-
private val fmaFraction = WireInit(UInt(maxProductFractionBits.W), largerFraction +& signedSmallerFraction)
70+
val sumOverflow = fmaFraction(maxProductFractionBits - 1)
71+
val adjFmaFraction =
72+
Mux(isAddition, fmaFraction >> sumOverflow.asUInt(), fmaFraction(maxProductFractionBits - 2, 0))
73+
val adjFmaExponent = largeExp + Mux(isAddition & sumOverflow, 1.S, 0.S)
74+
val sumStickyBit = Mux(isAddition & sumOverflow, fmaFraction(0), false.B)
6575

66-
private val sumOverflow = fmaFraction(maxProductFractionBits - 1)
67-
private val adjFmaFraction = Mux(isAddition, fmaFraction >> sumOverflow.asUInt(), fmaFraction(maxProductFractionBits - 2, 0))
68-
private val adjFmaExponent = largerExponent + Mux(isAddition & sumOverflow, 1.S, 0.S)
69-
private val sumStickyBit = Mux(isAddition & sumOverflow, fmaFraction(0), false.B)
70-
71-
private val normalizationFactor = MuxCase(0.S, Array.range(0, maxProductFractionBits - 2).map(index => {
76+
val normalizationFactor = MuxCase(0.S, Array.range(0, maxProductFractionBits - 2).map(index => {
7277
(adjFmaFraction(maxProductFractionBits - 2, maxProductFractionBits - index - 2) === 1.U) -> index.S
7378
}))
7479

75-
private val normFmaExponent = adjFmaExponent - normalizationFactor
76-
private val normFmaFraction = adjFmaFraction << normalizationFactor.asUInt()
80+
val normFmaExponent = adjFmaExponent - normalizationFactor
81+
val normFmaFraction = adjFmaFraction << normalizationFactor.asUInt()
7782

78-
private val result = Wire(new unpackedPosit(totalBits, es))
79-
result.isNaR := num1.isNaR || num2.isNaR || num3.isNaR
80-
result.isZero := (num1.isZero || num2.isZero) && num3.isZero
81-
result.sign := largerSign
83+
val result = Wire(new unpackedPosit(totalBits, es))
84+
result.isNaR := num1.isNaR || num2.isNaR || num3.isNaR
85+
result.isZero := (num1.isZero || num2.isZero) && num3.isZero
86+
result.sign := largeSign
8287
result.exponent := normFmaExponent
8388
result.fraction := (normFmaFraction >> maxFractionBits).asUInt()
84-
result.stickyBit := prodStickyBit | sumStickyBit | smallFractionStickyBit | normFmaFraction(maxFractionBits - 1, 0).orR()
8589

86-
private val positGenerator = Module(new PositGenerator(totalBits, es))
90+
val positGenerator = Module(new PositGenerator(totalBits, es))
8791
positGenerator.io.in <> result
92+
positGenerator.io.trailingBits := normFmaFraction(maxFractionBits - 1, maxFractionBits - trailingBitCount)
93+
positGenerator.io.stickyBit := prodStickyBit | sumStickyBit | smallFracStickyBit | normFmaFraction(maxFractionBits - trailingBitCount - 1, 0).orR()
8894

89-
io.isNaR := result.isNaR || (positGenerator.io.out === NaR)
95+
io.isNaR := result.isNaR || (positGenerator.io.out === NaR)
9096
io.isZero := result.isZero || (positGenerator.io.out === 0.U)
91-
io.out := Mux(result.isNaR, NaR, positGenerator.io.out)
97+
io.out := Mux(result.isNaR, NaR, positGenerator.io.out)
9298
}

src/main/scala/PositGenerator.scala

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,53 @@ class PositGenerator(val totalBits: Int, val es: Int) extends Module with HasHar
77

88
val io = IO(new Bundle {
99
val in = Input(new unpackedPosit(totalBits, es))
10+
val trailingBits = Input(UInt(trailingBitCount.W))
11+
val stickyBit = Input(Bool())
1012
val out = Output(UInt(totalBits.W))
1113
})
1214

13-
private val exponentOffset = PriorityMux(Array.range(0, maxFractionBits).map(index => {
15+
val exponentOffset = PriorityMux(Array.range(0, maxFractionBits).map(index => { //TODO Remove normalization check
1416
(io.in.fraction(maxFractionBits, maxFractionBits - index) === 1.U) -> index.S
1517
}))
16-
private val normalisedExponent = io.in.exponent - exponentOffset
17-
private val normalisedFraction = (io.in.fraction << exponentOffset.asUInt()) (maxFractionBits - 1, 0)
18-
private val negExponent = normalisedExponent < 0.S
1918

20-
private val positRegime = Mux(negExponent, -(normalisedExponent >> es), normalisedExponent >> es).asUInt()
21-
private val positExponent = normalisedExponent(if (es > 0) es - 1 else 0, 0)
22-
private val positOffset = positRegime + es.U + Mux(negExponent, 2.U, 3.U)
19+
val normalisedExponent = io.in.exponent - exponentOffset
20+
val normalisedFraction =
21+
(io.in.fraction << exponentOffset.asUInt()) (maxFractionBits - 1, 0)
22+
val negExp = normalisedExponent < 0.S
2323

24-
private val regimeBits = Mux(negExponent, 1.U << (positRegime + 1.U) >> (positRegime + 1.U), (1.U << positRegime + 2.U).asUInt() - 2.U)
25-
private val regimeWithExponentBits = if (es > 0) Cat(regimeBits, positExponent) else regimeBits
24+
val positRegime =
25+
Mux(negExp, -(normalisedExponent >> es), normalisedExponent >> es).asUInt()
26+
val positExponent = normalisedExponent(if (es > 0) es - 1 else 0, 0)
27+
val positOffset =
28+
positRegime - (negExp & positRegime =/= (totalBits - 1).U) + trailingBitCount.U
29+
30+
val regimeBits =
31+
Mux(negExp, 1.U << (positRegime + 1.U) >> (positRegime + 1.U),
32+
(1.U << positRegime + 2.U).asUInt() - 2.U)
33+
val regimeWithExponentBits =
34+
if (es > 0) Cat(regimeBits, positExponent)
35+
else regimeBits
2636

2737
//u => un ; T => Trimmed ; R => Rounded ; S => Signed
28-
private val uT_uS_posit = Cat(regimeWithExponentBits, normalisedFraction)
29-
private val uR_uS_posit = (uT_uS_posit >> positOffset) (totalBits - 2, 0)
38+
val uT_uS_posit = Cat(regimeWithExponentBits, normalisedFraction, io.trailingBits)
39+
val uR_uS_posit = (uT_uS_posit >> positOffset) (totalBits - 2, 0)
40+
41+
val trailingBits = (uT_uS_posit & ((1.U << positOffset) - 1.U)).asUInt()
42+
val gr = (trailingBits >> (positOffset - 2.U)) (1, 0)
43+
val stickyBit =
44+
io.stickyBit | (trailingBits & ((1.U << (positOffset - 2.U)) - 1.U)).orR()
45+
val roundingBit =
46+
Mux(uR_uS_posit.andR(), false.B,
47+
gr(1) & ~(~uR_uS_posit(0) & gr(1) & ~gr(0) & ~stickyBit))
48+
val R_uS_posit = uR_uS_posit + roundingBit
3049

31-
private val trailingBits = (uT_uS_posit & ((1.U << positOffset) - 1.U)).asUInt()
32-
private val lastBit = uR_uS_posit(0)
33-
private val afterBit = (trailingBits >> (positOffset - 1.U)) (0)
34-
private val stickyBit = io.in.stickyBit | (trailingBits & ((1.U << (positOffset - 1.U)) - 1.U)).orR()
35-
private val roundingBit = Mux(uR_uS_posit.andR(), false.B, (lastBit & afterBit) | (afterBit & stickyBit))
50+
//Underflow Correction
51+
val uFC_R_uS_posit =
52+
Cat(0.U(1.W), R_uS_posit | (R_uS_posit === 0.U))
3653

37-
private val R_uS_posit = uR_uS_posit + roundingBit
38-
private val uFC_R_uS_posit = Cat(0.U(1.W), R_uS_posit | (R_uS_posit === 0.U))
39-
private val R_S_posit = Mux(io.in.sign, ~uFC_R_uS_posit + 1.U, uFC_R_uS_posit)
54+
val R_S_posit =
55+
Mux(io.in.sign, ~uFC_R_uS_posit + 1.U, uFC_R_uS_posit)
4056

4157
io.out := Mux(io.in.isNaR, NaR,
42-
Mux((io.in.fraction === 0.U) | io.in.isZero, 0.U, R_S_posit))
58+
Mux((io.in.fraction === 0.U) | io.in.isZero, zero, R_S_posit))
4359
}

src/main/scala/common.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ class unpackedPosit(val totalBits: Int, val es: Int) extends Bundle with HasHard
1818

1919
val sign = Bool()
2020
val exponent = SInt(maxExponentBits.W)
21-
val fraction = UInt(maxFractionBitsWithHiddenBit.W)
21+
val fraction = UInt(maxFractionBitsWithHiddenBit.W) //TODO Transfer only fraction bits without hidden bit
2222
val isZero = Bool()
2323
val isNaR = Bool()
24-
val stickyBit = Bool()
2524

2625
override def cloneType =
2726
new unpackedPosit(totalBits, es).asInstanceOf[this.type]
@@ -33,7 +32,7 @@ trait HasHardPositParams {
3332

3433
def maxExponentBits: Int = log2Ceil(totalBits) + es + 2
3534

36-
def maxFractionBits: Int = totalBits//if (es + 2 >= totalBits) 0 else totalBits - 3 - es
35+
def maxFractionBits: Int = if (es + 2 >= totalBits) 0 else totalBits - 3 - es
3736

3837
def maxFractionBitsWithHiddenBit: Int = maxFractionBits + 1
3938

@@ -47,4 +46,5 @@ trait HasHardPositParams {
4746

4847
def isZero(num: UInt): Bool = ~num.orR()
4948

49+
def trailingBitCount = 2
5050
}

0 commit comments

Comments
 (0)