Skip to content

Commit

Permalink
Pipeline FP Adder (#126)
Browse files Browse the repository at this point in the history
* pipeline the fp adder

* shorten fp tests, length time limit

* fix fpv construction issue for ofInts(), ofBigInts()
  • Loading branch information
desmonddak authored Nov 6, 2024
1 parent 7499025 commit 062ea0d
Show file tree
Hide file tree
Showing 13 changed files with 246 additions and 223 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/general.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
run-checks:
name: Run Checks
permissions: {}
timeout-minutes: 30
timeout-minutes: 60
runs-on: ${{ github.repository_owner == 'intel' && 'intel-ubuntu-latest' || 'ubuntu-latest' }}
steps:
- name: Checkout
Expand Down
2 changes: 1 addition & 1 deletion doc/components/floating_point.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ A very basic `FloatingPointAdderSimple` component is available which does not pe

Currently, the `FloatingPointAdderSimple` is close in accuracy (as it has no rounding) and is not optimized for circuit performance, but only provides the key functionalities of alignment, addition, and normalization. Still, this component is a starting point for more realistic floating-point components that leverage the logical `FloatingPoint` and literal `FloatingPointValue` type abstractions.

A second `FloatingPointAdderRound` component is available which does perform rounding. It is based on "Delay-Optimized Implementation of IEEE Floating-Point Addition", by Peter-Michael Seidel and Guy Even, using an R-path and an N-path to process far-apart exponents and use rounding and an N-path for exponents within 2 and subtraction, which is exact.
A second `FloatingPointAdderRound` component is available which does perform rounding. It is based on "Delay-Optimized Implementation of IEEE Floating-Point Addition", by Peter-Michael Seidel and Guy Even, using an R-path and an N-path to process far-apart exponents and use rounding and an N-path for exponents within 2 and subtraction, which is exact. If you pass in an optional clock, a pipestage will be added to help optimize frequency; an optional reset and enable are can control the pipestage.
144 changes: 116 additions & 28 deletions lib/src/arithmetic/floating_point/floating_point_adder_round.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,57 @@ import 'package:meta/meta.dart';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// Conditionally constructs a positive edge triggered flip condFlop on [clk].
///
/// It returns either [FlipFlop.q] if [clk] is valid or [d] if not.
///
/// When the optional [en] is provided, an additional input will be created for
/// condFlop. If optional [en] is high or not provided, output will vary as per
/// input[d]. For low [en], output remains frozen irrespective of input [d].
///
/// When the optional [reset] is provided, the condFlop will be reset
/// (active-high).
/// If no [resetValue] is provided, the reset value is always `0`. Otherwise,
/// it will reset to the provided [resetValue].
Logic condFlop(
Logic? clk,
Logic d, {
Logic? en,
Logic? reset,
dynamic resetValue,
}) =>
(clk == null)
? d
: flop(
clk,
d,
en: en,
reset: reset,
resetValue: resetValue,
);

/// An adder module for variable FloatingPoint type with rounding.
// This is the Seidel/Even adder, dual-path
// This is a Seidel/Even adder, dual-path implementation.
class FloatingPointAdderRound extends Module {
/// Must be greater than 0.
final int exponentWidth;

/// Must be greater than 0.
final int mantissaWidth;

/// Output [FloatingPoint] computed
/// The [clk]: if a valid clock signal is passed in, a pipestage is added to
/// the adder to help optimize frequency.
Logic? clk;

/// Optional [reset], used only if a [clk] is not null to reset the pipeline
/// flops.
Logic? reset;

/// Optional [enable], used only if a [clk] is not null to enable the pipeline
/// flops.
Logic? enable;

/// Output [FloatingPoint] representing the sum of two input [FloatingPoint]s
late final FloatingPoint sum =
FloatingPoint(exponentWidth: exponentWidth, mantissaWidth: mantissaWidth)
..gets(output('sum'));
Expand All @@ -48,6 +89,9 @@ class FloatingPointAdderRound extends Module {
/// functions.
FloatingPointAdderRound(FloatingPoint a, FloatingPoint b,
{Logic? subtract,
this.clk,
this.reset,
this.enable,
Adder Function(Logic, Logic) adderGen = ParallelPrefixAdder.new,
ParallelPrefix Function(List<Logic>, Logic Function(Logic, Logic))
ppTree = KoggeStone.new,
Expand All @@ -58,6 +102,15 @@ class FloatingPointAdderRound extends Module {
b.mantissa.width != mantissaWidth) {
throw RohdHclException('FloatingPoint widths must match');
}
if (clk != null) {
clk = addInput('clk', clk!);
}
if (reset != null) {
reset = addInput('reset', reset!);
}
if (enable != null) {
enable = addInput('enable', enable!);
}
a = a.clone()..gets(addInput('a', a, width: a.width));
b = b.clone()..gets(addInput('b', b, width: b.width));
addOutput('sum', width: _sum.width) <= _sum;
Expand Down Expand Up @@ -108,17 +161,35 @@ class FloatingPointAdderRound extends Module {
smallerAlignRPath.width - 1,
smallerAlignRPath.width - largeOperand.width);

/// R Pipestage here:
final aIsNormalLatched =
condFlop(clk, a.isNormal(), en: enable, reset: reset);
final bIsNormalLatched =
condFlop(clk, b.isNormal(), en: enable, reset: reset);
final effectiveSubtractionLatched =
condFlop(clk, effectiveSubtraction, en: enable, reset: reset);
final largeOperandLatched =
condFlop(clk, largeOperand, en: enable, reset: reset);
final smallerOperandRPathLatched =
condFlop(clk, smallerOperandRPath, en: enable, reset: reset);
final smallerAlignRPathLatched =
condFlop(clk, smallerAlignRPath, en: enable, reset: reset);
final largerExpLatched =
condFlop(clk, larger.exponent, en: enable, reset: reset);
final deltaLatched = condFlop(clk, delta, en: enable, reset: reset);

final carryRPath = Logic();
final significandAdderRPath = OnesComplementAdder(
largeOperand, smallerOperandRPath,
subtractIn: effectiveSubtraction,
largeOperandLatched, smallerOperandRPathLatched,
subtractIn: effectiveSubtractionLatched,
carryOut: carryRPath,
adderGen: adderGen);

final lowBitsRPath = smallerAlignRPath.slice(extendWidthRPath - 1, 0);
final lowBitsRPath =
smallerAlignRPathLatched.slice(extendWidthRPath - 1, 0);
final lowAdderRPath = OnesComplementAdder(
carryRPath.zeroExtend(extendWidthRPath),
mux(effectiveSubtraction, ~lowBitsRPath, lowBitsRPath),
mux(effectiveSubtractionLatched, ~lowBitsRPath, lowBitsRPath),
adderGen: adderGen);

final preStickyRPath =
Expand All @@ -135,8 +206,10 @@ class FloatingPointAdderRound extends Module {
final sumP1RPath =
(significandAdderRPath.sum + 1).slice(mantissaWidth + 1, 0);

final sumLeadZeroRPath = ~sumRPath[-1] & (a.isNormal() | b.isNormal());
final sumP1LeadZeroRPath = ~sumP1RPath[-1] & (a.isNormal() | b.isNormal());
final sumLeadZeroRPath =
~sumRPath[-1] & (aIsNormalLatched | bIsNormalLatched);
final sumP1LeadZeroRPath =
~sumP1RPath[-1] & (aIsNormalLatched | bIsNormalLatched);

final selectRPath = lowAdderRPath.sum[-1];
final shiftGRSRPath = [earlyGRSRPath[2], stickyBitRPath].swizzle();
Expand All @@ -162,25 +235,26 @@ class FloatingPointAdderRound extends Module {

final firstZeroRPath = mux(selectRPath, ~sumP1RPath[-1], ~sumRPath[-1]);

final exponentRPath = Logic(width: larger.exponent.width);
final exponentRPath = Logic(width: exponentWidth);
Combinational([
If.block([
// Subtract 1 from exponent
Iff(~incExpRPath & effectiveSubtraction & firstZeroRPath, [
exponentRPath < ParallelPrefixDecr(larger.exponent, ppGen: ppTree).out
Iff(~incExpRPath & effectiveSubtractionLatched & firstZeroRPath, [
exponentRPath <
ParallelPrefixDecr(largerExpLatched, ppGen: ppTree).out
]),
// Add 1 to exponent
ElseIf(
~effectiveSubtraction &
~effectiveSubtractionLatched &
(incExpRPath & firstZeroRPath | ~incExpRPath & ~firstZeroRPath),
[
exponentRPath <
ParallelPrefixIncr(larger.exponent, ppGen: ppTree).out
ParallelPrefixIncr(largerExpLatched, ppGen: ppTree).out
]),
// Add 2 to exponent
ElseIf(incExpRPath & effectiveSubtraction & ~firstZeroRPath,
[exponentRPath < larger.exponent << 1]),
Else([exponentRPath < larger.exponent])
ElseIf(incExpRPath & effectiveSubtractionLatched & ~firstZeroRPath,
[exponentRPath < largerExpLatched << 1]),
Else([exponentRPath < largerExpLatched])
])
]);

Expand Down Expand Up @@ -208,39 +282,53 @@ class FloatingPointAdderRound extends Module {
.zeroExtend(exponentWidth),
Const(15, width: exponentWidth));

// N pipestage here:
final significandNPathLatched =
condFlop(clk, significandNPath, en: enable, reset: reset);
final significandSubtractorNPathSignLatched = condFlop(
clk, significandSubtractorNPath.sign,
en: enable, reset: reset);
final leadOneNPathLatched =
condFlop(clk, leadOneNPath, en: enable, reset: reset);
final largerSignLatched =
condFlop(clk, larger.sign, en: enable, reset: reset);
final smallerSignLatched =
condFlop(clk, smaller.sign, en: enable, reset: reset);

final expCalcNPath = OnesComplementAdder(
larger.exponent, leadOneNPath.zeroExtend(larger.exponent.width),
subtractIn: effectiveSubtraction, adderGen: adderGen);
largerExpLatched, leadOneNPathLatched.zeroExtend(exponentWidth),
subtractIn: effectiveSubtractionLatched, adderGen: adderGen);

final preExpNPath = expCalcNPath.sum.slice(exponentWidth - 1, 0);

final posExpNPath = preExpNPath.or() & ~expCalcNPath.sign;

final exponentNPath = mux(posExpNPath, preExpNPath, zeroExp);

final preMinShiftNPath = ~leadOneNPath.or() | ~larger.exponent.or();
final preMinShiftNPath = ~leadOneNPathLatched.or() | ~largerExpLatched.or();

final minShiftNPath =
mux(posExpNPath | preMinShiftNPath, leadOneNPath, larger.exponent - 1);
final notSubnormalNPath = a.isNormal() | b.isNormal();
final minShiftNPath = mux(posExpNPath | preMinShiftNPath,
leadOneNPathLatched, largerExpLatched - 1);
final notSubnormalNPath = aIsNormalLatched | bIsNormalLatched;

final shiftedSignificandNPath =
(significandNPath << minShiftNPath).slice(mantissaWidth, 1);
(significandNPathLatched << minShiftNPath).slice(mantissaWidth, 1);

final finalSignificandNPath = mux(
notSubnormalNPath,
shiftedSignificandNPath,
significandNPath.slice(significandNPath.width - 1, 2));
significandNPathLatched.slice(significandNPathLatched.width - 1, 2));

final signNPath =
mux(significandSubtractorNPath.sign, smaller.sign, larger.sign);
final signNPath = mux(significandSubtractorNPathSignLatched,
smallerSignLatched, largerSignLatched);

final isR = delta.gte(Const(2, width: delta.width)) | ~effectiveSubtraction;
final isR = deltaLatched.gte(Const(2, width: delta.width)) |
~effectiveSubtractionLatched;
_sum <=
mux(
isR,
[
larger.sign,
largerSignLatched,
exponentRPath,
mantissaRPath.slice(mantissaRPath.width - 2, 1)
].swizzle(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ class FloatingPoint16Value extends FloatingPointValue {
/// [FloatingPoint16Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPoint16Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();
: super.ofBigInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// [FloatingPoint16Value] constructor from a set of [int]s of the binary
/// representation
FloatingPoint16Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();
: super.ofInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// Numeric conversion of a [FloatingPoint16Value] from a host double
factory FloatingPoint16Value.ofDouble(double inDouble) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ class FloatingPoint32Value extends FloatingPointValue {
/// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPoint32Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();
: super.ofBigInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// [FloatingPoint32Value] constructor from a set of [int]s of the binary
/// representation
FloatingPoint32Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();
: super.ofInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// Numeric conversion of a [FloatingPoint32Value] from a host double
factory FloatingPoint32Value.ofDouble(double inDouble) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ class FloatingPoint64Value extends FloatingPointValue {
/// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPoint64Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();
: super.ofBigInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// [FloatingPoint64Value] constructor from a set of [int]s of the binary
/// representation
FloatingPoint64Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();
: super.ofInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// Numeric conversion of a [FloatingPoint64Value] from a host double
factory FloatingPoint64Value.ofDouble(double inDouble) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,14 @@ class FloatingPoint8E5M2Value extends FloatingPointValue {
/// binary representation
FloatingPoint8E5M2Value.ofBigInts(super.exponent, super.mantissa,
{super.sign})
: super.ofBigInts();
: super.ofBigInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// [FloatingPoint8E5M2Value] constructor from a set of [int]s of the binary
/// representation
FloatingPoint8E5M2Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();
: super.ofInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// Numeric conversion of a [FloatingPoint8E5M2Value] from a host double
factory FloatingPoint8E5M2Value.ofDouble(double inDouble) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ class FloatingPointBF16Value extends FloatingPointValue {
/// [FloatingPointBF16Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPointBF16Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();
: super.ofBigInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// [FloatingPointBF16Value] constructor from a set of [int]s of the binary
/// representation
FloatingPointBF16Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();
: super.ofInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// Numeric conversion of a [FloatingPointBF16Value] from a host double
factory FloatingPointBF16Value.ofDouble(double inDouble) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ class FloatingPointTF32Value extends FloatingPointValue {
/// [FloatingPointTF32Value] constructor from a set of [BigInt]s of the binary
/// representation
FloatingPointTF32Value.ofBigInts(super.exponent, super.mantissa, {super.sign})
: super.ofBigInts();
: super.ofBigInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// [FloatingPointTF32Value] constructor from a set of [int]s of the binary
/// representation
FloatingPointTF32Value.ofInts(super.exponent, super.mantissa, {super.sign})
: super.ofInts();
: super.ofInts(
exponentWidth: exponentWidth, mantissaWidth: mantissaWidth);

/// Numeric conversion of a [FloatingPointTF32Value] from a host double
factory FloatingPointTF32Value.ofDouble(double inDouble) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,16 +552,19 @@ class FloatingPointValue implements Comparable<FloatingPointValue> {
(mantissa.width != other.mantissa.width)) {
throw Exception('FloatingPointValue widths must match for comparison');
}
final signCompare = sign.compareTo(other.sign);
if (signCompare != 0) {
return signCompare;
// } else {
final expCompare = exponent.compareTo(other.exponent);
final mantCompare = mantissa.compareTo(other.mantissa);
if (expCompare != 0) {
return expCompare;
} else if (mantCompare != 0) {
return mantCompare;
} else {
final expCompare = exponent.compareTo(other.exponent);
if (expCompare != 0) {
return expCompare;
} else {
return mantissa.compareTo(other.mantissa);
final signCompare = sign.compareTo(other.sign);
if ((signCompare != 0) && !(exponent.isZero && mantissa.isZero)) {
return signCompare;
}
return 0;
}
}

Expand All @@ -578,6 +581,11 @@ class FloatingPointValue implements Comparable<FloatingPointValue> {
(mantissa.width != other.mantissa.width)) {
return false;
}
// IEEE 754: -0 an +0 are considered equal
if ((exponent.isZero && mantissa.isZero) &&
(other.exponent.isZero && other.mantissa.isZero)) {
return true;
}

return (sign == other.sign) &
(exponent == other.exponent) &
Expand Down
Loading

0 comments on commit 062ea0d

Please sign in to comment.