From 2eba218896a16b360c01a686b0fca0c53235f3c3 Mon Sep 17 00:00:00 2001 From: Desmond Kirkpatrick Date: Mon, 14 Oct 2024 14:17:51 -0700 Subject: [PATCH 1/5] Fpadder (#106) * floating point adder with rounding * added floating_point_adder documentation, passed ParallelPrefix/Adder options into fpadder * added floating point adder configurator --- doc/components/floating_point.md | 6 +- lib/src/arithmetic/arithmetic.dart | 1 + .../arithmetic/evaluate_partial_product.dart | 1 - .../floating_point/floating_point.dart | 3 +- .../floating_point_adder_round.dart | 249 +++++++++++ ....dart => floating_point_adder_simple.dart} | 8 +- .../floating_point/floating_point_value.dart | 17 +- .../components/component_registry.dart | 1 + .../components/components.dart | 1 + .../config_floating_point_adder_round.dart | 71 ++++ .../floating_point_adder_round_test.dart | 393 ++++++++++++++++++ ... => floating_point_adder_simple_test.dart} | 34 +- 12 files changed, 756 insertions(+), 29 deletions(-) create mode 100644 lib/src/arithmetic/floating_point/floating_point_adder_round.dart rename lib/src/arithmetic/floating_point/{floating_point_adder.dart => floating_point_adder_simple.dart} (93%) create mode 100644 lib/src/component_config/components/config_floating_point_adder_round.dart create mode 100644 test/arithmetic/floating_point/floating_point_adder_round_test.dart rename test/arithmetic/floating_point/{floating_point_adder_test.dart => floating_point_adder_simple_test.dart} (92%) diff --git a/doc/components/floating_point.md b/doc/components/floating_point.md index 8f525226..4e8360a6 100644 --- a/doc/components/floating_point.md +++ b/doc/components/floating_point.md @@ -32,6 +32,8 @@ Again, like `FloatingPointValue`, `FloatingPoint64` and `FloatingPoint32` subcla ## FloatingPointAdder -A very basic `FloatingPointAdder` component is available which does not perform any rounding. It takes two `FloatingPoint` `LogicStructure`s and adds them, returning a normalized `FloatingPoint` on the output. An option on input is the type of `ParallelPrefixTree` used in the internal addition of the mantissas. +A very basic `FloatingPointAdderSimple` component is available which does not perform any rounding. It takes two `FloatingPoint` `LogicStructure`s and adds them, returning a normalized `FloatingPoint` on the output. An option on input is the type of `ParallelPrefixTree` used in the internal addition of the mantissas. -Currently, the `FloatingPointAdder` is close in accuracy (as it has no rounding) and is not optimized for circuit performance, but only provides the key functionalities of alignment, addition, and normalization. Still, this component is a starting point for more realistic floating-point components that leverage the logical `FloatingPoint` and literal `FloatingPointValue` type abstractions. +Currently, the `FloatingPointAdderSimple` is close in accuracy (as it has no rounding) and is not optimized for circuit performance, but only provides the key functionalities of alignment, addition, and normalization. Still, this component is a starting point for more realistic floating-point components that leverage the logical `FloatingPoint` and literal `FloatingPointValue` type abstractions. + +A second `FloatingPointAdderRound` component is available which does perform rounding. It is based on "Delay-Optimized Implementation of IEEE Floating-Point Addition", by Peter-Michael Seidel and Guy Even, using an R-path and an N-path to process far-apart exponents and use rounding and an N-path for exponents within 2 and subtraction, which is exact. diff --git a/lib/src/arithmetic/arithmetic.dart b/lib/src/arithmetic/arithmetic.dart index af28b3a3..6cfff87e 100644 --- a/lib/src/arithmetic/arithmetic.dart +++ b/lib/src/arithmetic/arithmetic.dart @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause export 'adder.dart'; +export 'arithmetic_utils.dart'; export 'carry_save_mutiplier.dart'; export 'compound_adder.dart'; export 'divider.dart'; diff --git a/lib/src/arithmetic/evaluate_partial_product.dart b/lib/src/arithmetic/evaluate_partial_product.dart index 970b35df..71dcc9cb 100644 --- a/lib/src/arithmetic/evaluate_partial_product.dart +++ b/lib/src/arithmetic/evaluate_partial_product.dart @@ -9,7 +9,6 @@ import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; -import 'package:rohd_hcl/src/arithmetic/arithmetic_utils.dart'; /// Debug routines for printing out partial product matrix during /// simulation with live logic values diff --git a/lib/src/arithmetic/floating_point/floating_point.dart b/lib/src/arithmetic/floating_point/floating_point.dart index 23156957..de0b3b40 100644 --- a/lib/src/arithmetic/floating_point/floating_point.dart +++ b/lib/src/arithmetic/floating_point/floating_point.dart @@ -1,6 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -export 'floating_point_adder.dart'; +export 'floating_point_adder_round.dart'; +export 'floating_point_adder_simple.dart'; export 'floating_point_logic.dart'; export 'floating_point_value.dart'; diff --git a/lib/src/arithmetic/floating_point/floating_point_adder_round.dart b/lib/src/arithmetic/floating_point/floating_point_adder_round.dart new file mode 100644 index 00000000..0b928f50 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_adder_round.dart @@ -0,0 +1,249 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_adder_round.dart +// A variable-width floating point adder with rounding +// +// 2024 August 30 +// Author: Desmond A Kirkpatrick + ( + toSwap.$1.clone()..gets(mux(swap, toSwap.$2, toSwap.$1)), + toSwap.$2.clone()..gets(mux(swap, toSwap.$1, toSwap.$2)) + ); + + /// Add two floating point numbers [a] and [b], returning result in [sum]. + /// [subtract] is an optional Logic input to do subtraction + /// [adderGen] is an adder generator to be used in the primary adder + /// functions. + /// [ppTree] is an ParallelPrefix generator for use in increment /decrement + /// functions. + FloatingPointAdderRound(FloatingPoint a, FloatingPoint b, + {Logic? subtract, + Adder Function(Logic, Logic) adderGen = ParallelPrefixAdder.new, + ParallelPrefix Function(List, Logic Function(Logic, Logic)) + ppTree = KoggeStone.new, + super.name = 'floating_point_adder'}) + : exponentWidth = a.exponent.width, + mantissaWidth = a.mantissa.width { + if (b.exponent.width != exponentWidth || + b.mantissa.width != mantissaWidth) { + throw RohdHclException('FloatingPoint widths must match'); + } + a = a.clone()..gets(addInput('a', a, width: a.width)); + b = b.clone()..gets(addInput('b', b, width: b.width)); + addOutput('sum', width: _sum.width) <= _sum; + + final exponentSubtractor = OnesComplementAdder(a.exponent, b.exponent, + subtract: true, adderGen: adderGen); + final signDelta = exponentSubtractor.sign; + + final delta = exponentSubtractor.sum; + + // Seidel: (sl, el, fl) = larger; (ss, es, fs) = smaller + final (larger, smaller) = _swap(signDelta, (a, b)); + + final fl = mux( + larger.isNormal(), + [larger.isNormal(), larger.mantissa].swizzle(), + [larger.mantissa, Const(0)].swizzle()); + final fs = mux( + smaller.isNormal(), + [smaller.isNormal(), smaller.mantissa].swizzle(), + [smaller.mantissa, Const(0)].swizzle()); + + // Seidel: S.EFF = effectiveSubtraction + final effectiveSubtraction = a.sign ^ b.sign ^ (subtract ?? Const(0)); + + // Seidel: flp larger preshift, normally in [2,4) + final sigWidth = fl.width + 1; + final largeShift = mux(effectiveSubtraction, fl.zeroExtend(sigWidth) << 1, + fl.zeroExtend(sigWidth)); + final smallShift = mux(effectiveSubtraction, fs.zeroExtend(sigWidth) << 1, + fs.zeroExtend(sigWidth)); + + final zeroExp = Const(0, width: exponentWidth); + + final largeOperand = largeShift; + // + // R Datapath: Far exponents or addition + // + final extendWidthRPath = + min(mantissaWidth + 3, pow(2, exponentWidth).toInt() - 3); + + final smallerFullRPath = + [smallShift, Const(0, width: extendWidthRPath)].swizzle(); + smallerFullRPath <= smallerFullRPath.withSet(extendWidthRPath, smallShift); + + final smallerAlignRPath = smallerFullRPath >>> exponentSubtractor.sum; + final smallerOperandRPath = smallerAlignRPath.slice( + smallerAlignRPath.width - 1, + smallerAlignRPath.width - largeOperand.width); + + final carryRPath = Logic(); + final significandAdderRPath = OnesComplementAdder( + largeOperand, smallerOperandRPath, + subtractIn: effectiveSubtraction, + carryOut: carryRPath, + adderGen: adderGen); + + final lowBitsRPath = smallerAlignRPath.slice(extendWidthRPath - 1, 0); + final lowAdderRPath = OnesComplementAdder( + carryRPath.zeroExtend(extendWidthRPath), + mux(effectiveSubtraction, ~lowBitsRPath, lowBitsRPath), + adderGen: adderGen); + + final preStickyRPath = + lowAdderRPath.sum.slice(lowAdderRPath.sum.width - 4, 0).or(); + final stickyBitRPath = lowAdderRPath.sum[-3] | preStickyRPath; + + final earlyGRSRPath = [ + lowAdderRPath.sum + .slice(lowAdderRPath.sum.width - 2, lowAdderRPath.sum.width - 3), + preStickyRPath + ].swizzle(); + + final sumRPath = significandAdderRPath.sum.slice(mantissaWidth + 1, 0); + final sumP1RPath = + (significandAdderRPath.sum + 1).slice(mantissaWidth + 1, 0); + + final sumLeadZeroRPath = ~sumRPath[-1] & (a.isNormal() | b.isNormal()); + final sumP1LeadZeroRPath = ~sumP1RPath[-1] & (a.isNormal() | b.isNormal()); + + final selectRPath = lowAdderRPath.sum[-1]; + final shiftGRSRPath = [earlyGRSRPath[2], stickyBitRPath].swizzle(); + final mergedSumRPath = mux( + sumLeadZeroRPath, + [sumRPath, earlyGRSRPath].swizzle().slice(sumRPath.width + 1, 0), + [sumRPath, shiftGRSRPath].swizzle()); + + final mergedSumP1RPath = mux( + sumP1LeadZeroRPath, + [sumP1RPath, earlyGRSRPath].swizzle().slice(sumRPath.width + 1, 0), + [sumP1RPath, shiftGRSRPath].swizzle()); + + final finalSumLGRSRPath = + mux(selectRPath, mergedSumP1RPath, mergedSumRPath); + // RNE: guard & (lsb | round | sticky) + final rndRPath = finalSumLGRSRPath[2] & + (finalSumLGRSRPath[3] | finalSumLGRSRPath[1] | finalSumLGRSRPath[0]); + + // Rounding from 1111 to 0000. + final incExpRPath = + rndRPath & sumLeadZeroRPath.eq(Const(1)) & sumP1LeadZeroRPath.eq(0); + + final firstZeroRPath = mux(selectRPath, ~sumP1RPath[-1], ~sumRPath[-1]); + + final exponentRPath = Logic(width: larger.exponent.width); + Combinational([ + If.block([ + // Subtract 1 from exponent + Iff(~incExpRPath & effectiveSubtraction & firstZeroRPath, [ + exponentRPath < ParallelPrefixDecr(larger.exponent, ppGen: ppTree).out + ]), + // Add 1 to exponent + ElseIf( + ~effectiveSubtraction & + (incExpRPath & firstZeroRPath | ~incExpRPath & ~firstZeroRPath), + [ + exponentRPath < + ParallelPrefixIncr(larger.exponent, ppGen: ppTree).out + ]), + // Add 2 to exponent + ElseIf(incExpRPath & effectiveSubtraction & ~firstZeroRPath, + [exponentRPath < larger.exponent << 1]), + Else([exponentRPath < larger.exponent]) + ]) + ]); + + final sumMantissaRPath = mux(selectRPath, sumP1RPath, sumRPath) + + rndRPath.zeroExtend(sumRPath.width); + final mantissaRPath = sumMantissaRPath << + mux(selectRPath, sumP1LeadZeroRPath, sumLeadZeroRPath); + + // + // N Datapath here: close exponents, subtraction + // + final smallOperandNPath = smallShift >>> (a.exponent[0] ^ b.exponent[0]); + + final significandSubtractorNPath = OnesComplementAdder( + largeOperand, smallOperandNPath, + subtractIn: effectiveSubtraction, adderGen: adderGen); + + final significandNPath = + significandSubtractorNPath.sum.slice(smallOperandNPath.width - 1, 0); + + final leadOneNPath = mux( + significandNPath.or(), + ParallelPrefixPriorityEncoder(significandNPath.reversed, ppGen: ppTree) + .out + .zeroExtend(exponentWidth), + Const(15, width: exponentWidth)); + + final expCalcNPath = OnesComplementAdder( + larger.exponent, leadOneNPath.zeroExtend(larger.exponent.width), + subtractIn: effectiveSubtraction, adderGen: adderGen); + + final preExpNPath = expCalcNPath.sum.slice(exponentWidth - 1, 0); + + final posExpNPath = preExpNPath.or() & ~expCalcNPath.sign; + + final exponentNPath = mux(posExpNPath, preExpNPath, zeroExp); + + final preMinShiftNPath = ~leadOneNPath.or() | ~larger.exponent.or(); + + final minShiftNPath = + mux(posExpNPath | preMinShiftNPath, leadOneNPath, larger.exponent - 1); + final notSubnormalNPath = a.isNormal() | b.isNormal(); + + final shiftedSignificandNPath = + (significandNPath << minShiftNPath).slice(mantissaWidth, 1); + + final finalSignificandNPath = mux( + notSubnormalNPath, + shiftedSignificandNPath, + significandNPath.slice(significandNPath.width - 1, 2)); + + final signNPath = + mux(significandSubtractorNPath.sign, smaller.sign, larger.sign); + + final isR = delta.gte(Const(2, width: delta.width)) | ~effectiveSubtraction; + _sum <= + mux( + isR, + [ + larger.sign, + exponentRPath, + mantissaRPath.slice(mantissaRPath.width - 2, 1) + ].swizzle(), + [signNPath, exponentNPath, finalSignificandNPath].swizzle()); + } +} diff --git a/lib/src/arithmetic/floating_point/floating_point_adder.dart b/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart similarity index 93% rename from lib/src/arithmetic/floating_point/floating_point_adder.dart rename to lib/src/arithmetic/floating_point/floating_point_adder_simple.dart index e6e3f1d2..ddb0338d 100644 --- a/lib/src/arithmetic/floating_point/floating_point_adder.dart +++ b/lib/src/arithmetic/floating_point/floating_point_adder_simple.dart @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // -// floating_point_adder.dart +// floating_point_adder_simple.dart // A very basic Floating-point adder component. // // 2024 August 30 @@ -12,7 +12,7 @@ import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; /// An adder module for FloatingPoint values -class FloatingPointAdder extends Module { +class FloatingPointAdderSimple extends Module { /// Must be greater than 0. final int exponentWidth; @@ -38,7 +38,7 @@ class FloatingPointAdder extends Module { ); /// Add two floating point numbers [a] and [b], returning result in [sum] - FloatingPointAdder(FloatingPoint a, FloatingPoint b, + FloatingPointAdderSimple(FloatingPoint a, FloatingPoint b, {ParallelPrefix Function(List, Logic Function(Logic, Logic)) ppGen = KoggeStone.new, super.name}) @@ -66,7 +66,6 @@ class FloatingPointAdder extends Module { // Align and add mantissas final expDiff = aExp - bExp; - // print('${expDiff.value.toInt()} exponent diff'); final adder = SignMagnitudeAdder( a.sign, [a.isNormal(), a.mantissa].swizzle(), @@ -102,6 +101,5 @@ class FloatingPointAdder extends Module { ]) ]) ]); - // print('final sum: ${_sum.value.bitString}'); } } diff --git a/lib/src/arithmetic/floating_point/floating_point_value.dart b/lib/src/arithmetic/floating_point/floating_point_value.dart index 59f8d74b..a3ca64d6 100644 --- a/lib/src/arithmetic/floating_point/floating_point_value.dart +++ b/lib/src/arithmetic/floating_point/floating_point_value.dart @@ -569,10 +569,21 @@ class FloatingPointValue implements Comparable { /// defined as having mantissa in the range [1,2) bool isNormal() => exponent != LogicValue.ofInt(0, exponent.width); + /// Return a string representation of FloatingPointValue. + /// if [integer] is true, return sign, exponent, mantissa as integers. + /// if [integer] is false, return sign, exponent, mantissa as ibinary strings. @override - String toString() => '${sign.toString(includeWidth: false)}' - ' ${exponent.toString(includeWidth: false)}' - ' ${mantissa.toString(includeWidth: false)}'; + String toString({bool integer = false}) { + if (integer) { + return '(${sign.toInt()}' + ' ${exponent.toInt()}' + ' ${mantissa.toInt()})'; + } else { + return '${sign.toString(includeWidth: false)}' + ' ${exponent.toString(includeWidth: false)}' + ' ${mantissa.toString(includeWidth: false)}'; + } + } // TODO(desmonddak): what about floating point representations >> 64 bits? FloatingPointValue _performOp( diff --git a/lib/src/component_config/components/component_registry.dart b/lib/src/component_config/components/component_registry.dart index 0f624073..93b846a4 100644 --- a/lib/src/component_config/components/component_registry.dart +++ b/lib/src/component_config/components/component_registry.dart @@ -25,6 +25,7 @@ List get componentRegistry => [ RegisterFileConfigurator(), EdgeDetectorConfigurator(), FindConfigurator(), + FloatingPointAdderRoundConfigurator(), ParallelPrefixAdderConfigurator(), CompressionTreeMultiplierConfigurator(), ExtremaConfigurator(), diff --git a/lib/src/component_config/components/components.dart b/lib/src/component_config/components/components.dart index 8c465296..f45ac6ef 100644 --- a/lib/src/component_config/components/components.dart +++ b/lib/src/component_config/components/components.dart @@ -9,6 +9,7 @@ export 'config_edge_detector.dart'; export 'config_extrema.dart'; export 'config_fifo.dart'; export 'config_find.dart'; +export 'config_floating_point_adder_round.dart'; export 'config_one_hot.dart'; export 'config_parallel_prefix_adder.dart'; export 'config_priority_arbiter.dart'; diff --git a/lib/src/component_config/components/config_floating_point_adder_round.dart b/lib/src/component_config/components/config_floating_point_adder_round.dart new file mode 100644 index 00000000..e3e18d5d --- /dev/null +++ b/lib/src/component_config/components/config_floating_point_adder_round.dart @@ -0,0 +1,71 @@ +// Copyright (C) 2023-2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// config_floating_point_adder.dart +// Configurator for a Floating-Point Adder. +// +// 2024 October 11 +// Author: Desmond Kirkpatrick + +import 'dart:collection'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A [Configurator] for [FloatingPointAdderRound]s. +class FloatingPointAdderRoundConfigurator extends Configurator { + /// Map from Type to Function for Adder generator + static Map adderGeneratorMap = { + Ripple: (a, b) => ParallelPrefixAdder(a, b, ppGen: Ripple.new), + Sklansky: (a, b) => ParallelPrefixAdder(a, b, ppGen: Sklansky.new), + KoggeStone: ParallelPrefixAdder.new, + BrentKung: (a, b) => ParallelPrefixAdder(a, b, ppGen: BrentKung.new) + }; + + /// Map from Type to Function for Parallel Prefix generator + static Map, Logic Function(Logic, Logic))> + treeGeneratorMap = { + Ripple: Ripple.new, + Sklansky: Sklansky.new, + KoggeStone: KoggeStone.new, + BrentKung: BrentKung.new + }; + + /// Controls the type of [ParallelPrefix] tree used in internal adders. + final adderTreeKnob = + ChoiceConfigKnob(adderGeneratorMap.keys.toList(), value: KoggeStone); + + /// Controls the type of [ParallelPrefix] tree used in the other functions. + final prefixTreeKnob = + ChoiceConfigKnob(treeGeneratorMap.keys.toList(), value: KoggeStone); + + /// Controls the width of the exponent. + final IntConfigKnob exponentWidthKnob = IntConfigKnob(value: 4); + + /// Controls the width of the mantissa. + final IntConfigKnob mantissaWidthKnob = IntConfigKnob(value: 5); + + @override + Module createModule() => FloatingPointAdderRound( + FloatingPoint( + exponentWidth: exponentWidthKnob.value, + mantissaWidth: mantissaWidthKnob.value, + ), + FloatingPoint( + exponentWidth: exponentWidthKnob.value, + mantissaWidth: mantissaWidthKnob.value), + adderGen: adderGeneratorMap[adderTreeKnob.value]!, + ppTree: treeGeneratorMap[prefixTreeKnob.value]!); + + @override + late final Map> knobs = UnmodifiableMapView({ + 'Prefix tree type': prefixTreeKnob, + 'Adder tree type': adderTreeKnob, + 'Exponent width': exponentWidthKnob, + 'Mantissa width': mantissaWidthKnob, + }); + + @override + final String name = 'Floating-Point Rounding Adder'; +} diff --git a/test/arithmetic/floating_point/floating_point_adder_round_test.dart b/test/arithmetic/floating_point/floating_point_adder_round_test.dart new file mode 100644 index 00000000..fe9eeb9c --- /dev/null +++ b/test/arithmetic/floating_point/floating_point_adder_round_test.dart @@ -0,0 +1,393 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_rnd_test.dart +// Tests of FloatingPointAdderRnd -- a rounding FP Adder. +// +// 2024 August 30 +// Author: Desmond A Kirkpatrick = 2) { + for (var ii = 0; ii <= largestMantissa; ii++) { + for (var jj = 0; jj <= largestMantissa; jj++) { + final fva = FloatingPointValue.ofInts(i, ii, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.ofInts(j, jj, + exponentWidth: eWidth, mantissaWidth: mWidth, sign: sign); + + fa.put(fva); + fb.put(fvb); + final expected = fva + fvb; + + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + } + } + } + } + } + } + }); + + test('FP: R path, full normal', () { + const eWidth = 3; + const mWidth = 5; + + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final adder = FloatingPointAdderRound(fa, fb); + + final largestExponent = FloatingPointValue.computeBias(eWidth) + + FloatingPointValue.computeMaxExponent(eWidth); + final largestMantissa = pow(2, mWidth).toInt() - 1; + for (final sign in [false, true]) { + for (var i = 1; i <= largestExponent; i++) { + for (var j = 1; j <= largestExponent; j++) { + if ((i - j).abs() >= 2) { + for (var ii = 0; ii <= largestMantissa; ii++) { + for (var jj = 0; jj <= largestMantissa; jj++) { + final fva = FloatingPointValue.ofInts(i, ii, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.ofInts(j, jj, + exponentWidth: eWidth, mantissaWidth: mWidth, sign: sign); + + fa.put(fva); + fb.put(fvb); + final expected = fva + fvb; + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + } + } + } + } + } + } + }); + + test('FP: R path, full all', () { + const eWidth = 3; + const mWidth = 5; + + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final adder = FloatingPointAdderRound(fa, fb); + + final largestExponent = FloatingPointValue.computeBias(eWidth) + + FloatingPointValue.computeMaxExponent(eWidth); + final largestMantissa = pow(2, mWidth).toInt() - 1; + for (final sign in [false, true]) { + for (var i = 0; i <= largestExponent; i++) { + for (var j = 0; j <= largestExponent; j++) { + if (!sign || (i - j).abs() >= 2) { + for (var ii = 0; ii <= largestMantissa; ii++) { + for (var jj = 0; jj <= largestMantissa; jj++) { + final fva = FloatingPointValue.ofInts(i, ii, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.ofInts(j, jj, + exponentWidth: eWidth, mantissaWidth: mWidth, sign: sign); + fa.put(fva); + fb.put(fvb); + final expected = fva + fvb; + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + } + } + } + } + } + } + }); + + test('FP: R path, full random', () { + const eWidth = 3; + const mWidth = 5; + + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final adder = FloatingPointAdderRound(fa, fb); + final value = Random(47); + + var cnt = 200; + while (cnt > 0) { + final fva = FloatingPointValue.random(value, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.random(value, + exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(fva); + fb.put(fvb); + if ((fva.exponent.toInt() - fvb.exponent.toInt()).abs() >= 2) { + cnt--; + final expected = fva + fvb; + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + } + } + }); + + test('FP: singleton merged path', () { + const eWidth = 3; + const mWidth = 5; + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final fva = FloatingPointValue.ofInts(14, 31, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.ofInts(13, 7, + exponentWidth: eWidth, mantissaWidth: mWidth, sign: true); + fa.put(fva); + fb.put(fvb); + + final expectedNoRound = FloatingPointValue.fromDoubleIter( + fva.toDouble() + fvb.toDouble(), + exponentWidth: eWidth, + mantissaWidth: mWidth); + + final FloatingPointValue expected; + final expectedRound = fva + fvb; + if (((fva.exponent.toInt() - fvb.exponent.toInt()).abs() < 2) & + (fva.sign.toInt() != fvb.sign.toInt())) { + expected = expectedNoRound; + } else { + expected = expectedRound; + } + final adder = FloatingPointAdderRound(fa, fb); + + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + }); + + test('FP: exhaustive', () { + const eWidth = 3; + const mWidth = 5; + + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final adder = FloatingPointAdderRound(fa, fb); + + final largestExponent = FloatingPointValue.computeBias(eWidth) + + FloatingPointValue.computeMaxExponent(eWidth); + final largestMantissa = pow(2, mWidth).toInt() - 1; + for (final sign in [false, true]) { + for (var i = 0; i <= largestExponent; i++) { + for (var j = 0; j <= largestExponent; j++) { + for (var ii = 0; ii <= largestMantissa; ii++) { + for (var jj = 0; jj <= largestMantissa; jj++) { + final fva = FloatingPointValue.ofInts(i, ii, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.ofInts(j, jj, + exponentWidth: eWidth, mantissaWidth: mWidth, sign: sign); + + fa.put(fva); + fb.put(fvb); + final expectedNoRound = FloatingPointValue.fromDoubleIter( + fva.toDouble() + fvb.toDouble(), + exponentWidth: eWidth, + mantissaWidth: mWidth); + + final FloatingPointValue expected; + final expectedRound = fva + fvb; + if (((fva.exponent.toInt() - fvb.exponent.toInt()).abs() < 2) & + (fva.sign.toInt() != fvb.sign.toInt())) { + expected = expectedNoRound; + } else { + expected = expectedRound; + } + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + } + } + } + } + } + }); + test('FP: full random', () { + const eWidth = 3; + const mWidth = 5; + + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final adder = FloatingPointAdderRound(fa, fb); + final value = Random(47); + + var cnt = 500; + while (cnt > 0) { + final fva = FloatingPointValue.random(value, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.random(value, + exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(fva); + fb.put(fvb); + final expectedNoRound = FloatingPointValue.fromDoubleIter( + fva.toDouble() + fvb.toDouble(), + exponentWidth: eWidth, + mantissaWidth: mWidth); + + final FloatingPointValue expected; + final expectedRound = fva + fvb; + if (((fva.exponent.toInt() - fvb.exponent.toInt()).abs() < 2) & + (fva.sign.toInt() != fvb.sign.toInt())) { + expected = expectedNoRound; + } else { + expected = expectedRound; + } + cnt--; + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + } + }); + test('FP: full random wide', () { + const eWidth = 11; + const mWidth = 52; + + final fa = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + final fb = FloatingPoint(exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(0); + fb.put(0); + final adder = FloatingPointAdderRound(fa, fb); + final value = Random(51); + + var cnt = 100; + while (cnt > 0) { + final fva = FloatingPointValue.random(value, + exponentWidth: eWidth, mantissaWidth: mWidth); + final fvb = FloatingPointValue.random(value, + exponentWidth: eWidth, mantissaWidth: mWidth); + fa.put(fva); + fb.put(fvb); + final expected = fva + fvb; + final computed = adder.sum.floatingPointValue; + expect(computed.isNaN(), equals(expected.isNaN())); + expect(computed, equals(expected)); + cnt--; + } + }); +} diff --git a/test/arithmetic/floating_point/floating_point_adder_test.dart b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart similarity index 92% rename from test/arithmetic/floating_point/floating_point_adder_test.dart rename to test/arithmetic/floating_point/floating_point_adder_simple_test.dart index 3a9f6907..4147274e 100644 --- a/test/arithmetic/floating_point/floating_point_adder_test.dart +++ b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart @@ -1,8 +1,8 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause // -// floating_point_test.dart -// Tests of Floating Point stuff +// floating_point_smple test.dart +// Tests of FloatingPointAdderSimple -- non-rounding FP adder // // 2024 April 1 // Authors: @@ -21,7 +21,7 @@ void main() { final fp2 = FloatingPoint32() ..put(FloatingPoint32Value.fromDouble(1.5).value); final out = FloatingPoint32Value.fromDouble(3.25 + 1.5); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -37,7 +37,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pow(2.0, -23).toDouble()).value); final out = FloatingPoint32Value.fromDouble(val + val); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -55,7 +55,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pair.$2).value); final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -71,7 +71,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(1.5).value); final out = FloatingPoint32Value.fromDouble(3.25 + 1.5); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -94,7 +94,7 @@ void main() { .value); final out = FloatingPoint32Value.fromDouble(val - val); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().abs().toStringAsPrecision(7); @@ -110,7 +110,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pow(2.5, -12).toDouble()).value); final out = FloatingPoint32Value.fromDouble(val + val); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -128,7 +128,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pair.$2).value); final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -149,7 +149,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pair.$2).value); final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -171,7 +171,7 @@ void main() { final out = FloatingPoint32Value.fromDouble( fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble()); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -196,7 +196,7 @@ void main() { fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); final out = FloatingPointValue.fromDoubleIter(outDouble, exponentWidth: ew, mantissaWidth: mw); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); expect(adder.sum.floatingPointValue.compareTo(out), 0); }); @@ -217,7 +217,7 @@ void main() { final out = FloatingPointValue.fromDouble(pair.$1 + pair.$2, exponentWidth: ew, mantissaWidth: mw); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); expect(adder.sum.floatingPointValue.compareTo(out), 0); }); @@ -237,7 +237,7 @@ void main() { final out = fp2.floatingPointValue + fp1.floatingPointValue; - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); // TODO(desmonddak): figure out how to handle -0.0, as this would fail expect(adder.sum.floatingPointValue.abs().compareTo(out), 0); }); @@ -252,7 +252,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pair.$2).value); final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -269,7 +269,7 @@ void main() { ..put(FloatingPoint32Value.fromDouble(pair.$2).value); final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); - final adder = FloatingPointAdder(fp1, fp2); + final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; final fpStr = fpSuper.toDouble().toStringAsPrecision(7); @@ -289,7 +289,7 @@ void main() { FloatingPointConstants.smallestPositiveNormal, eWidth, mWidth); fa.put(0); fb.put(0); - final adder = FloatingPointAdder(fa, fb); + final adder = FloatingPointAdderSimple(fa, fb); final value = Random(513); for (var i = 0; i < 50; i++) { final fva = FloatingPointValue.random(value, From 26674d004d38db3cf7f4fcf9f933a5c0de64c3f3 Mon Sep 17 00:00:00 2001 From: Desmond Kirkpatrick Date: Mon, 14 Oct 2024 14:18:39 -0700 Subject: [PATCH 2/5] Pp array (#107) * refactor PartialProductGenerator to expose PartialProductArray * updated evaluate_compressor to use listString utility * test for evaluate_compressor function * update compressor documentation --- doc/components/multiplier_components.md | 24 +-- lib/src/arithmetic/addend_compressor.dart | 4 +- lib/src/arithmetic/evaluate_compressor.dart | 61 ++++---- .../arithmetic/partial_product_generator.dart | 144 +++++++++--------- test/arithmetic/multiplier_test.dart | 21 +++ 5 files changed, 147 insertions(+), 107 deletions(-) diff --git a/doc/components/multiplier_components.md b/doc/components/multiplier_components.md index 9f4e144d..0db47ddf 100644 --- a/doc/components/multiplier_components.md +++ b/doc/components/multiplier_components.md @@ -124,7 +124,7 @@ Our `RadixEncoder` module is general, creating selection tables for arbitrary Bo The `PartialProductGenerator` class also provides for sign extension with multiple options including `SignExtension.none` which is no sign extension for help in debugging, as well as `SignExtension.compactRect` which is a compact form which works for rectangular products where the multiplicand and multiplier can be of different widths. -If customization is needed beyond sign extension options, routines are provided that allow for fixed customization of bit positions, or conditional (mux based on a Logic) form. +The `PartialProductGenerator` creates a set of addends in its base class `PartialProductArray` which is simply a `List>` to represent addends and a `rowShift[row]` to represent the shifts in the partial product matrix. If customization is needed beyond sign extension options, routines are provided that allow for fixed customization of bit positions or conditional (mux based on a Logic) form in the `PartialProductArray`. ```dart final ppg = PartialProductGenerator(a,b); @@ -167,7 +167,7 @@ You can also generate a Markdown form of the same matrix: Once you have a partial product matrix, you would like to add up the addends. Traditionally this is done using compression trees which instantiate 2:1 and 3:2 column compressors (or carry-save adders) to reduce the matrix to two addends. The final two addends are often added with an efficient final adder. -Our `ColumnCompressor` class uses a delay-driven approach to efficiently compress the rows of the partial product matrix. Its only argument is a `PartialProductGenerator`, and it creates a list of `ColumnQueue`s containing the final two addends stored by column after compression. An `extractRow`routine can be used to extract the columns. `ColumnCompressor` currently has an extension `EvaluateColumnCompressor` which can be used to print out the compression progress. Here is the legend for these printouts. +Our `ColumnCompressor` class uses a delay-driven approach to efficiently compress the rows of the partial product matrix. Its only argument is a `PartialProductArray` (base class of `PartialProductGenerator`), and it creates a list of `ColumnQueue`s containing the final two addends stored by column after compression. An `extractRow`routine can be used to extract the columns. `ColumnCompressor` currently has an extension `EvaluateColumnCompressor` which can be used to print out the compression progress. Here is the legend for these printouts. - `ppR,C` = partial product entry at row R, column C - `sR,C` = sum term coming last from row R, column C @@ -183,12 +183,14 @@ Compression Tree before: pp2,6 pp2,5 pp2,4 pp2,3 pp1,6 pp1,5 pp1,4 - 1 1 0 0 0 0 0 0 0 1 1 0 110000000110 (3078) - 1 1 0 0 0 1 1 1 0 0 001100011100 (796) - 0 0 0 0 0 1 0 0 000000001000 (8) - 0 1 0 0 0 0 000001000000 (64) - 1 1 1 1 000001111000 (120) - 0 1 1 000000110000 (48) Total=18 + 11 10 9 8 7 6 5 4 3 2 1 0 + 1 1 0 0 0 0 0 s s s S S = 3075 (-1021) + 1 1 0 0 0 0 0 0 0 1 = 769 (769) + 0 0 0 0 0 1 1 1 = 14 (14) + 1 i S 1 1 0 = 184 (184) + 0 0 1 1 = 24 (24) + 0 1 1 = 48 (48) +p 0 0 0 0 0 0 0 1 0 0 1 0 = 18 (18) ``` Compression Tree after compression: @@ -197,8 +199,10 @@ Compression Tree after compression: pp5,11 pp5,10 s0,9 s0,8 s0,7 c0,5 c0,4 c0,3 s0,3 s0,2 pp0,1 pp1,0 c0,9 c0,8 c0,7 c0,6 s0,6 s0,5 s0,4 s0,3 s0,2 s0,1 pp0,0 - 1 1 1 1 1 0 1 0 0 1 0 0 111110100100 (4004) - 0 0 0 0 1 1 0 1 1 1 0 000001101110 (110) Total=18 + 11 10 9 8 7 6 5 4 3 2 1 0 + 1 1 1 1 1 1 0 0 1 1 0 S = 4045 (-51) + 0 0 0 0 1 0 0 0 1 0 1 = 69 (69) +p 0 0 0 0 0 0 0 1 0 0 1 0 = 18 (18) ``` ## Final Adder diff --git a/lib/src/arithmetic/addend_compressor.dart b/lib/src/arithmetic/addend_compressor.dart index 2513a857..2c8d65a8 100644 --- a/lib/src/arithmetic/addend_compressor.dart +++ b/lib/src/arithmetic/addend_compressor.dart @@ -167,8 +167,8 @@ class ColumnCompressor { late final List columns; - /// The partial product generator to be compressed - final PartialProductGenerator pp; + /// The partial product array to be compressed + final PartialProductArray pp; /// Initialize a ColumnCompressor for a set of partial products ColumnCompressor(this.pp) { diff --git a/lib/src/arithmetic/evaluate_compressor.dart b/lib/src/arithmetic/evaluate_compressor.dart index b69c588d..71685e44 100644 --- a/lib/src/arithmetic/evaluate_compressor.dart +++ b/lib/src/arithmetic/evaluate_compressor.dart @@ -7,56 +7,63 @@ // 2024 June 04 // Author: Desmond Kirkpatrick -import 'dart:io'; import 'package:rohd/rohd.dart'; -import 'package:rohd_hcl/src/arithmetic/multiplier_lib.dart'; -import 'package:rohd_hcl/src/utils.dart'; +import 'package:rohd_hcl/src/arithmetic/arithmetic.dart'; /// Debug routines for printing out ColumnCompressor during /// simulation with live logic values extension EvaluateLiveColumnCompressor on ColumnCompressor { /// Evaluate the (un)compressed partial product array - /// logic=true will read the logic gate outputs at each level - /// printOut=true will print out the array in the StringBuffer - (BigInt, StringBuffer) evaluate({bool printOut = false, bool logic = false}) { + /// [logic] =true will read the logic gate outputs at each level + /// [printOut]=true will print out the array in the StringBuffer + /// [extraSpace] add spacing for readability + /// [header] add a header for the column position + /// [prefix] add a prefix count of spaces + (BigInt, StringBuffer) evaluate( + {bool printOut = false, + bool logic = false, + bool header = true, + int prefix = 1, + int extraSpace = 5}) { final ts = StringBuffer(); final rows = longestColumn(); final width = pp.maxWidth(); - var accum = BigInt.zero; + for (var row = 0; row < rows; row++) { - final rowBits = []; + final rowLogic = []; for (var col = columns.length - 1; col >= 0; col--) { final colList = columns[col].toList(); if (row < colList.length) { - final value = - logic ? colList[row].logic.value : (colList[row].evaluate()); - rowBits.add(value); - if (printOut) { - ts.write('\t${value.bitString}'); - } - } else if (printOut) { - ts.write('\t'); + rowLogic.insert(0, colList[row].logic); } } + final rowBits = [for (final c in rowLogic) c.value].reversed.toList(); + // ignore: cascade_invocations rowBits.addAll(List.filled(pp.rowShift[row], LogicValue.zero)); final rowBitsExtend = rowBits.length < width ? rowBits.swizzle().zeroExtend(width) : rowBits.swizzle(); final val = rowBitsExtend.toBigInt(); - accum += val; - if (printOut) { - ts.write('\t${rowBitsExtend.bitString} ($val)'); - if (row == rows - 1) { - ts.write(' Total=${accum.toSigned(width)}\n'); - stdout.write(ts); - } else { - ts.write('\n'); - } - } + ts + ..write(rowLogic.listString('', + header: header & (row == 0), + alignHigh: width, + prefix: prefix, + extraSpace: extraSpace, + intValue: true, + shift: pp.rowShift[row])) + ..write('\n'); } - return (accum.toSigned(width), ts); + + final sum = Logic(width: width); + // ignore: cascade_invocations + sum.put(accum.toSigned(width)); + ts.write(sum.elements + .listString('p', prefix: 1, extraSpace: extraSpace, intValue: true)); + + return (sum.value.toBigInt().toSigned(width), ts); } /// Return a string representing the compression tree in its current state diff --git a/lib/src/arithmetic/partial_product_generator.dart b/lib/src/arithmetic/partial_product_generator.dart index 62931086..719e397a 100644 --- a/lib/src/arithmetic/partial_product_generator.dart +++ b/lib/src/arithmetic/partial_product_generator.dart @@ -24,79 +24,35 @@ class SignBit extends Logic { } } -/// A [PartialProductGenerator] class that generates a set of partial products. -/// Essentially a set of -/// shifted rows of [Logic] addends generated by Booth recoding and -/// manipulated by sign extension, before being compressed -abstract class PartialProductGenerator { - /// Get the shift increment between neighboring product rows - int get shift => selector.shift; +/// A [PartialProductArray] is a class that holds a set of partial products +/// for manipulation by [PartialProductGenerator] and [ColumnCompressor]. +abstract class PartialProductArray { + /// Construct a basic List to hold an array of partial products + /// as well as a rowShift array to hold the row shifts. + PartialProductArray(); /// The actual shift in each row. This value will be modified by the /// sign extension routine used when folding in a sign bit from another /// row final rowShift = []; - /// rows of partial products - int get rows => partialProducts.length; - - /// The multiplicand term - Logic get multiplicand => selector.multiplicand; - - /// The multiplier term - Logic get multiplier => encoder.multiplier; - /// Partial Products output. Generated by selector and extended by sign /// extension routines late final List> partialProducts; - /// Encoder for the full multiply operand - late final MultiplierEncoder encoder; - - /// Selector for the multiplicand which uses the encoder to index into - /// multiples of the multiplicand and generate partial products - late final MultiplicandSelector selector; - - /// Operands are signed - final bool signed; - - /// Used to avoid sign extending more than once - bool isSignExtended = false; - - /// Construct a [PartialProductGenerator] -- the partial product matrix - PartialProductGenerator( - Logic multiplicand, Logic multiplier, RadixEncoder radixEncoder, - {required this.signed}) { - encoder = MultiplierEncoder(multiplier, radixEncoder, signed: signed); - selector = - MultiplicandSelector(radixEncoder.radix, multiplicand, signed: signed); - - if (multiplicand.width < selector.shift) { - throw RohdHclException('multiplicand width must be greater than ' - '${selector.shift}'); - } - if (multiplier.width < (selector.shift + (signed ? 1 : 0))) { - throw RohdHclException('multiplier width must be greater than ' - '${selector.shift + (signed ? 1 : 0)}'); - } - _build(); - signExtend(); - } - - /// Perform sign extension (defined in child classes) - @protected - void signExtend(); + /// rows of partial products + int get rows => partialProducts.length; - /// Setup the partial products array (partialProducts and rowShift) - void _build() { - partialProducts = >[]; - for (var row = 0; row < encoder.rows; row++) { - partialProducts.add(List.generate( - selector.width, (i) => selector.select(i, encoder.getEncoding(row)))); - } + /// Return the actual largest width of all rows + int maxWidth() { + var maxW = 0; for (var row = 0; row < rows; row++) { - rowShift.add(row * shift); + final entry = partialProducts[row]; + if (entry.length + rowShift[row] > maxW) { + maxW = entry.length + rowShift[row]; + } } + return maxW; } /// Return the Logic at the absolute position ([row], [col]). @@ -220,17 +176,69 @@ abstract class PartialProductGenerator { /// to the [list] of values void insertAbsoluteAll(int row, int col, List list) => partialProducts[row].insertAll(col - rowShift[row], list); +} - /// Return the actual largest width of all rows - int maxWidth() { - var maxW = 0; +/// A [PartialProductGenerator] class that generates a set of partial products. +/// Essentially a set of +/// shifted rows of [Logic] addends generated by Booth recoding and +/// manipulated by sign extension, before being compressed +abstract class PartialProductGenerator extends PartialProductArray { + /// Get the shift increment between neighboring product rows + int get shift => selector.shift; + + /// The multiplicand term + Logic get multiplicand => selector.multiplicand; + + /// The multiplier term + Logic get multiplier => encoder.multiplier; + + /// Encoder for the full multiply operand + late final MultiplierEncoder encoder; + + /// Selector for the multiplicand which uses the encoder to index into + /// multiples of the multiplicand and generate partial products + late final MultiplicandSelector selector; + + /// Operands are signed + final bool signed; + + /// Used to avoid sign extending more than once + bool isSignExtended = false; + + /// Construct a [PartialProductGenerator] -- the partial product matrix + PartialProductGenerator( + Logic multiplicand, Logic multiplier, RadixEncoder radixEncoder, + {required this.signed}) { + encoder = MultiplierEncoder(multiplier, radixEncoder, signed: signed); + selector = + MultiplicandSelector(radixEncoder.radix, multiplicand, signed: signed); + + if (multiplicand.width < selector.shift) { + throw RohdHclException('multiplicand width must be greater than ' + '${selector.shift}'); + } + if (multiplier.width < (selector.shift + (signed ? 1 : 0))) { + throw RohdHclException('multiplier width must be greater than ' + '${selector.shift + (signed ? 1 : 0)}'); + } + _build(); + signExtend(); + } + + /// Perform sign extension (defined in child classes) + @protected + void signExtend(); + + /// Setup the partial products array (partialProducts and rowShift) + void _build() { + partialProducts = >[]; + for (var row = 0; row < encoder.rows; row++) { + partialProducts.add(List.generate( + selector.width, (i) => selector.select(i, encoder.getEncoding(row)))); + } for (var row = 0; row < rows; row++) { - final entry = partialProducts[row]; - if (entry.length + rowShift[row] > maxW) { - maxW = entry.length + rowShift[row]; - } + rowShift.add(row * shift); } - return maxW; } } diff --git a/test/arithmetic/multiplier_test.dart b/test/arithmetic/multiplier_test.dart index 4485f397..07afce8d 100644 --- a/test/arithmetic/multiplier_test.dart +++ b/test/arithmetic/multiplier_test.dart @@ -10,6 +10,7 @@ import 'dart:math'; import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:rohd_hcl/src/arithmetic/evaluate_compressor.dart'; import 'package:test/test.dart'; // Inner test of a multipy accumulate unit @@ -279,5 +280,25 @@ void main() { expect(ppG0.getAbsolute(0, 4).value, equals(Const(0).value)); expect(ppG0.getAbsolute(1, 9).value, equals(Const(1).value)); expect(ppG0.getAbsolute(1, 10).value, equals(Const(0).value)); + + final cc = ColumnCompressor(ppG0); + const expectedRep = ''' + pp3,15 pp3,14 pp3,13 pp3,12 pp3,11 pp3,10 pp3,9 pp3,8 pp3,7 pp3,6 pp3,5 pp2,4 pp2,3 pp1,2 pp1,1 pp0,0 + pp2,13 pp2,12 pp1,11 pp2,10 pp2,9 pp2,8 pp2,7 pp2,6 pp2,5 pp0,4 pp0,3 pp0,2 pp0,1 + pp2,11 pp1,10 pp1,9 pp1,8 pp1,7 pp1,6 pp1,5 pp1,4 pp1,3 + pp0,10 pp0,9 pp0,8 pp0,7 pp0,6 pp0,5 +'''; + expect(cc.representation(), equals(expectedRep)); + + final (v, ts) = cc.evaluate(printOut: true); + + const expectedEval = ''' + 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 + 1 1 0 0 0 0 0 0 1 1 0 0 0 1 1 0 = 49350 (-16186) + 1 1 1 0 0 0 0 0 0 0 1 0 0 = 14344 (14344) + 0 0 1 0 0 0 0 1 1 = 536 (536) + i S S 1 1 0 = 960 (960) +p 1 1 1 1 1 1 1 0 1 0 1 0 0 1 1 0 = 65190 (-346)'''; + expect(ts.toString(), equals(expectedEval)); }); } From b04003be3c96279d00e19d60342f2fa6232f3825 Mon Sep 17 00:00:00 2001 From: Max Korbel Date: Fri, 18 Oct 2024 14:49:00 -0700 Subject: [PATCH 3/5] Clock gating (#96) --- doc/README.md | 2 + doc/components/clock_gate.png | Bin 0 -> 13536 bytes doc/components/clock_gating.md | 84 +++++++++ example/clock_gating_example.dart | 176 +++++++++++++++++++ example/example.dart | 14 +- lib/rohd_hcl.dart | 1 + lib/src/clock_gating.dart | 270 +++++++++++++++++++++++++++++ test/clock_gating_test.dart | 275 ++++++++++++++++++++++++++++++ test/example_test.dart | 21 +++ 9 files changed, 838 insertions(+), 5 deletions(-) create mode 100644 doc/components/clock_gate.png create mode 100644 doc/components/clock_gating.md create mode 100644 example/clock_gating_example.dart create mode 100644 lib/src/clock_gating.dart create mode 100644 test/clock_gating_test.dart create mode 100644 test/example_test.dart diff --git a/doc/README.md b/doc/README.md index 06534248..b3a342e3 100644 --- a/doc/README.md +++ b/doc/README.md @@ -73,6 +73,8 @@ Some in-development items will have opened issues, as well. Feel free to create - CRC - [Parity](./components/parity.md) - Interleaving +- Clocking + - [Clock gating](./components/clock_gating.md) - Data flow - Ready/Valid - Connect/Disconnect diff --git a/doc/components/clock_gate.png b/doc/components/clock_gate.png new file mode 100644 index 0000000000000000000000000000000000000000..16fe71f4a98ceb7f28708f366116af09066de9c6 GIT binary patch literal 13536 zcmd6OWmuH$`{v9r;tZ05v`7so0#Xu!3`2KF2?7FwbV^DN(v5V7bc(bnEhR18B}htl zxDUSXyZ_(2$FaK~c8@*w10R@q?&rSZzR&Zj2~t*+!o#7&0f9hxGScEIAP~3-_yfVP zflrXVxg8)7El5UOM9l@frOVkl(>dZ%7FHAN1?#Snc$X(yXJddpO!D~$S{N%OMaKGW9VI_Zax;O zA50Z=?$eps|7D7#zRgEp27!EdQLl2rfIYT|wLRCUw^00&SdNIFRlVsjyKzU_Jt` z;~8gaQ`_J33${B~3;s*Yd$mM!v&y4+UWH#gJ}_ANgn69heG`7Omu-kj<9&U8wDsar zb(vFf<>+Wj+i_#GByEJW z;9wCM2@jYrHZ+;=?BHVxOqv0-_u2kuc;KY67nxb8EGZotq64L237{pUB;2t`r6L-K zfD!Yg`8IP*41Gt2Wp1~mF`sD`HwC33J)STVM+}tz7OzIqTM*n9=XuP`e}5Qmd(Jst ze*@VxLtuAAjRYGD`&I~`At^VgSQJEnMK{>KS$hMEZVlJiL_=#Y0e?ae^S=W4&+(%` z6*Mhc)U@Efw_&llw1l!4mF|JO@zHy3sKwdVECCDeZY}Ct9>Om(74xzlD3YmnX_eRY zW%pMIEP8er4!x{7TFwracsettac5*rGaof*E#Y}f1axe<*w%~C&To{fl%VbgZ1J!) zq6&*}oRwBbu~4%Kkta$#uSFek&yn-Y6-B}bUVpEl5#MvNSw;4v$jW;*M1UmV5v*)% zh2nBfCbsv&radShEuqU_-JFAS64ebk!x@zlgVNuF}*mO8KCK*I<2cw(+LQ4yO% zWQ-^^-nob40DI(v_1Fokpeln3f}ib*sT4HNQ_jv%qlz?56tQwAZjv*-lWGY4(bo8Y z-4HBDU?zWooPu>{1Abe<`o;2hBV$Hy3hx(YqMQch4h2!6iX_aGRv6^L?fbA-TwvEp z6pe^!{Db5naVukc$JD5daFZ-Op{?UXiGtVW)m@$zpAi1u0{r33Gm9?fUeczUXSStQ zqc0VxG*)yE>tolg@<_UzY(6%cOzmIb-UF}j8Bb_^;iu)NAS|y42K5%>O(+KByEjin zSFKuh1y2-qt*0E#(3u(?(kW3BuI)bJV7ui;;P)WdP|Pz&KFxxD-<$u{fy zGqhITO_P%T_Hm=nTouK6_7<_TPPl@jv&#<#s)&@bgb@0aOJx=;qHx<)tqfyql^0mW z+-T_c6@jsjb-%{iXXQRUFMs?(k9JcBOKem#TA>Cnvb{OMyrgl%!caPq@M2>l%8JPt z>KbWoKi!G>Mv^K?8lav^E7aNI!XRgH_Pt=}#nAp?&V{#fwKiKyLFrZ%i@b;9!-j@O zVXfJyluPm~0uK+HZVTK{r)mKrX^b{qqt$0*lhtlx@F*i}09x47cff<*2OS4B$pfb2 z_U`^d3Fp`c_cXgxK#IAuxhgozi|>Fk!nzpH5P8qQ_vF%$PzXzG60DOjqu|}SIGPlU z$Y4$TWXb?u;ZmWPl7aNN?38|WK+MPDR*onjGH!ILd+_{n44JS&=*v44XamrY0S!J~ z10a^qn3%QnZsN$fd)HUkPP7s#aeyxv0e%eKVI7EP*9*Rnd)a*NJ9#XtC?=Sh95yTZfQnEde-R3IIVbC#=4(tvekH=UUc|BI=r-HE6~Ue%rj4T zrQY!zQh9iGFwQl?eP?%m4c|cNYbe-&UjOzS)ce_o7eSSBUS!^OB#cP3SwgBMKAxJ9 zD&tGig;=dbk-($l?+pU2ELxqzIbp}K>K&AjgS9HF!>-j(doY$iwr@Y`T;4e_U0h!s zT8rbLUK?!|3cW9Mo3|k5>@)P0vBQcd^Yf_4-IPvaO7%uiRZ$Vw+%pZD=H%YyFR0I( z<%88Msj9yN5c2~S=)7u~H}o7g06@Q8dBh+GiMcD6tR{O&D#DYf5sSAj=a6G+U);|j z7q8U#)bS0r>HXwSa*!1GouU$d0IGTNqs;@M_fdC9L=ul@W06L>QeqUd=F0qIThkl1 zy2K_3eIdLl9Sh5Hq5ZlA%_&_u(MA~}+S#$AA6>+|!qdO~&YNm=*#Qr3+;Kzb$Jl+4 zQC8A>{^J?1mT~YFi+P>#uz2=He0oy;3PS%B@py@ZN6PV`jPk|^DQ;ECt6UFRqWS%w z{Ubos@D;vub$Px%TB41^w%fM|GdadD;V@H#~ajm zc{Fu@oUYaVXfv_pWlC~#aulv!qbGF^+XIDLV1)kjrH^qu$fl~uhDE1r7$G6b^Fq^O zzn5=t5Zd(?cYEgZGh_c%teQ*UkWFuo;~$1X`F-&tg088Fu117bPWwDD({z8KHQ<#7 zOhw1AwW{xkZU0KWS**imuAyYM(W}wF;Io;S`MTinoHLeAJ2f5H_=(G>goo6GXbf#m z6`zq`tz_U)&f8adH0WBnaz%L(hN6XP6y*C=gC{|0o&;k4F!az^C>9^)%DAdQ4LPOP zzFZ7pfB#Ag*;FFw{N^+8;55DYsu%`39)R23Mi_|?>ueCsm1`3i$buEA*oARO=1M>| z!9qSS&Aj?Ex+jtV{&rg3<9s5CmR_bDo2s$Nk`IakjUC!F5ZLN|Q8f#L*4iS=20+K} z4D5Ulu0qJzzimP$6hl2&cROefIwk{VoVRDy)4i_VOx&#W3xzKqT`75;soH!}x;mfx z-Eh_+eD%@xz^E{J>0#E>*bwWTR)b&d7MP((p9nB_5L-I|3jR62RRRNDeaZC+`Xh5& zA$oeO=!9dL%#fzhb@%4w$0LK|XL=s$*bbW;Z43<_#pVaGhUm>b;oqnI>l+^$)1GL2 zd;k(=jcm4}L{M>~;kR2IBSEbnC!>*YTcxZ-Q$0~HBBeLvyrqo>M6%{ozn*>#1;OmP zBq2I)XAkO-O=}5Tnb>e!f~*7=nwNz0q{mqPEukhMD3=|5B6t)?Y%uN}1Nhr~LNX~r z-VEAY#_>z!zUe+VCV<=C_L@y5s zE_D-)1dX)`#*Rv33PQhzZ`Eo*zD6zzFqStLP{C~R7={Wi8rKI2FTJ{G!c@)-~2PzXaK&oc-fR_T3L;>)Ws6REKl`z_x#9Sg#L#c>_8v zm2Xo7<2A+-VGVBnHo!ck0=Z}=@KiQVYckm+;M=#47Nd=8?ADT2pacL|@@@JcVIT?N zYdSROydwS5^7|Au>SP|vQ6SYgqU-nN-0#>3Gb~9%FJ^Y8B)len62sGTf89x|As~>$ zw-TTZ`Lk5M#GuaUTfMg&cN6n1-`r`h#PDZ17BZ|}N;Ms)&m|!}&YR^QpJ<>v7bt0^ z37iz`8^j?xWA&Vh_M|MQ#i7gF4b_^}xBv`sU1$n|yazBGhgzYkI#!#nYNy=$?Aenp1VTTmwVXW*pxAXaopoqn z+dX@lkhlln#msHPvD=2!W5P^j`O>Zo;3Hyv8HWHdH@wfrj}aM_WGZtHi2_A11;TJ_ zFo{ighZUiHwHlnzX@TyCA&Z>hoiS)24T)y0&y?}JHbT1(I?nh$(bR>zsY2lGNpU6q z7D-;ndUZFhEz@3cwKEiOd)%+3tGe4YpT;hG8nWdvD%pld&g&?Sc9tp(u`wiWQXJyI z-4rz+99?&Siw`fmNlr=(CiW(W2F~*xw!az5)jWq~YaWz4B!xrCQKC(ETf@#6)_nNvUk3(Di0 zGzO*S-~H(7#O{&fR2+z<+^q9GGUDNBqWCc5-B$_{?|<@~Qp7&Y6QxeF>_fd!a!d-* z7vMb^+?;&6#+uyYdojR*$ndiCk+TcCH!AVFezi?)xwp6Xz9MbY0LtrNP--2>$=r7M z+ep76T#YlL?lo&`E=%4CL=P$c(D^*1O^$o3BUqp8BpZ2=)&z< z54?eUh7*g7O{%e7l+zX?vKuULj*{K+9G*69WBNU`K z*VOT7@`z&u2;2%8Di}WJ-t=AQGc3F{=2kt6cA-TT2gl4xW4eh_&zZt=*p!Z)g>Fnl zsf#_~Ipvv?Lp{plK8M+2eBCEf#V~j4NjXpu-_%$Bg%t`#ispJUn1ZlQ{H?^hg#ID} z-3tLZi+9@`ndG-*g@nFV$3aaQxGq0`HJ@H9DSR!l9y1ZjbqI&fr3Zs{bh(dzl!)e&EMmrR@$r#45&Tv2gB zNaITCuaB;6xiT(z-gqStKWm@)9)bJpjZ_`UgT-V$k64tac+a?+3e9!KfE3j@jz_E_+k5g%)qx`Bx8|Z6M~Ut2NaC)bUFuuJ`P~rvW|owTdBgz zlmb4ucRoya$9Y9w^D#DaN9f72nGy{E_n#Sfd94~4 z!MtOeJ!X>jb>&kep2Xt%PZnXklIl?~9U3q_){tuCng2;ny5`rn>M5{Dmj)zh^%Edg zsQ0{l{<@550|HL!pY}~%J#V}^WdCUZ_BtC%zkMaoik^k7e@~7`$e9YvBp6Dr<#}%d z+_UrImv1Zq8J`lBN*C%X>-X(jK)B=4Q*!a~BhDYd_-i3I#aCknG_eCg1n;9RT23yy93H`K;XV)I`^LZ`F`?wSM|`*tK=a_s;U!t)iGa16q!8zRq+J$OhL#R7gWEH?BL+0;NZ=lmGXF?tf$ z9jyJjC`qI`uls>07rf7dQf#a?Y91R_?mGSST;&@+IDB%*N^UP%4g0>CYH8i3g=Z(u zSW0PcB8bz^@`W&u2L@{T-cINo*`#SW^Qphx*-HroK67BDZz`pBHLn(HA=&{>gm*9~ zd?N;k0KC4n3BcAZtREB4+jyO8AC5I%RoR!Q>G8G2>=SsMh;4JS5uZ_s;RYR&+;|o(|S8!P0IZ z-m71?Y_7d$!!80?8v7SdjkT&Oe83Z`Qqq{HGSWY9K{n|C8`uwyDHYmn@I&Z7xpkCB zA%Z7fdlbhFF7?wYylT++J$BT^xb+ScpTYP1S#-?f!+f~8FX-@n;>s-TIgxa}UjER8 zyULe_22F}lMsy6`1ovd(1f&4lcY<0rV^O1-)dUjiXp_B#UvPG=fihf*{_c#>m(BX9 zs;z=&iA z&-gc!mM5-M%!qm3T0QlmOlT+7LhH=SBK`; z^XIzwv);wyt9%Cv7?{}RM9h13@Fu2NiC|1sg?U~@q<5O`fHtenkGl1)R4(_^f6KSm zwBt5c3F(|~xIscol_0~?I%D_XJ9X`D(X*ur7?^QwS}pJ6;)z+BtDT#E_ri-Om?m4I zOmZHtGz_@7gmp4X^v?0-N^twbC z;2Z%@zfDUb^jEM<6Vh*i7Z0Uqg$cDgvEMvtk-tW-pz7S!@%6gh+f-w5Ky(;(LQ+je zNyyR0^ik$2DZb1^N#HdtI+Eh;<+!7;n~o@##OLsBxBSNQ^>$VUaNTMYNLOByKD22P z>t&s%!fP>=&r%~oscyDw1bxpbZ5UNukKFLAu>X4 z!w1@Y8C(81^YgoLX$<22nPWB@Lq2EHpBBqkGdDdQ1FbzcqfWoEPP1Jf9|bx_wZIS9 zV?eu=o5x+y?i{p9QorT13%Iw)AegoU&iV8M}I5(XMf4Ud@`w6~6{#Wc{y2 zhzyP$HL_3;oP)ej7XWW(o)j}De`vx=?LZ+=m(UMs;{mSE2eV4Pbr!fS!~BFL2atB^ z=Q}YV>>)u~34G`Evqu1ktmA&++NPzG8{}!;Anot_kf+g)3O{fXDEevq) zAGe9i^>z5P>|M2HD(8jdugiRs6%04Y8J0*5_IcoZ&n$uc%eABWXjn#d!kYlkm~+Ah zB!YniiT>H^#|1AX7A?A}nRI=Yr7^EtWfjErVL${;C}2j-0B*k%vHz4BU&`+b%P%Jl zDWgQ}g-oHy7flam&&QHwkqV ztg(G`75&~l4!=Y2rEit7iKLTYCkkXJK;bAy3siB?oWj(|rUb)LF%0$?^D<+Zfa=%j zwc;h~ygQI5>RP59q%}#>aJVv5@w7-|*_2O*sYc>`V>O}(B@*0(5J8o%j;YBzP zoo)6P?72{sd|DXPRZ2%~jR>pXMe19hje@4EFFgvsTykj-kJY)StN!rzxx8PI8Gk&v z=R&$#?{?httACc@@OT2Rrx*ZL-Fe`uGreXMfd;OaM{US^C$WfhgdeYRz&=KRlB=bx zvos;HrsxD&{sxnP_a!LES57x7`uu{u>Ptf?kNPZA59iyY2R6hj-tZ;!8q-Ii~r_*OiD-9DZD90d+?=YQc3*n7V(Rj{Q#Lx7&p|N6Wd$6g~2^WQf` zQ;;?iC}eZd+hz#1^ID_heg2${vvR;0b#z!HqbK z?RX3yhHPS!O1Ap2wVfFl4bshSffLG|33eYG63g@~pIKw6$`vM8Fy8Hxhol61u%2Ga z$1y6cVgtLDFHixfH`YntKx8mQucz9d9irm*WvFiyoszHf>h8l{PnhclaDKMh8g+hkJ14-YKt zeU6orlPUu&oxq7IZ>dsktgT(S3{&a$hlK6R*mY1BJFU#6ViVLrEw;zR)(57nBo~e< zwXh8yRI5SA1jY%zq~?s`x4PYaIPjTFvda- z;u{+l(a=@=V_wDrY^cm5>3{7jkI8p^Or!o(FMp+Kr-T&PHVl>DFYRxUW*QBN*BVdqf5@>CJ}h2%t}?o&r6(MnEpG ziT^qZv+HhUl*TM*Zs~IUQ8EoU`kV$q9Kd(bF6$@|<6w{Wv%)Jyh{NSIy0;(#} zHG>4pzE7+vRl7`uuAY(nO(xltgvie_kQi}p8{L94LHNKSkZk{R9jE!f16Mm#QI&zZ>GMaq+rXqiX* zqp^jJz$Iz?=eAyZjqhJxllY%1WF85cYw##~QSUl(@o6NM>(z@!&n;qj)SLYkg(1cC zyQwhU*2FoWO$7`Ig&MtJPM&&R7COghLcpm@K)%|dBaklY+IuYonhJBj`%DD{ttAh+u{bYLX zgYTN1p%nZ5ao#WX`HwQErro(#13qOERNvLc)*R<*>H$9JPZP2bC*0Yd{UED2Pp|( zR$gQV9*y$fb7)i%yjQ7vvdM^^3-W0)xia4+k_hbS?F}>Y*<1VgYG(#tnU0H>?))Hu zYc-vEe$YBFRHoCzRi5JF$nE<>5a@RD^(?^R$8+CprV0PflfahIp?iji*MXrC_Xpw% z@|3X84oT^k1nvI%r*@4^958->kvS0?Y;^;2>0Pcr!vmMvDX#9b zSJ{hB`SrhmeEmR{vLl^*NV8Ji+Z!k1bhy}dOEvp~#AtoDi`L3>XI^9o-)K{9GS2%8 zn@{!AXH!)4SW`WE<{Y!5OVMD`I%;?S&L{3CWzjb>*|Z~+uoD9&Gy-uwmy?Ax2I5LO z;zT5UiCZl4-a0o3tayX4o0ps?I;hfYTU3g`U(AQKt zVwZqS|0cJYZN0{8j{oC3`2RJi;2@iTR{Agd&EUoSzW9)$SVr+CGyqHb3AHZq&|31A znfuazFg1jY9-VS=6(CusDzi^KjUZ%`ub4|DJCKhbf zGVe%39BvwH|H1v~e<4>D29#NxN;#r6f2x+>;%=j7S1eNzJCc~yclUgw7P|r|{0EDr z|Ao@#bl?Y$D2<_vk(d=NYK6U_^EMe>-11GtyfXEee%z0MYw$$K{0mm$#PbrmeI_G1 zGoMq;MJ@BBXP5_e1xhDayWrO}i1h!Lo1aT#wkmVeAoVHXg|-@@_(T~hhK@QFT2aK; zq#=X*Kk6V9qTknH?q|~NzeJ&rrXI75`{DaX2(f<&p@2jmzd)y)ByFoH7~kKx`b3d| z@_eF^UC>yQZuOt5PT2V9LD0WgB;DJ@N^oQXp_W-lwV;q=tlmIt8hW`v^!;llR6ZB`aov~K4Uqqf?;`$x;=2U^4c39~ zGqYPr|4o9s{O7`FgAv2W4aY)wK=> zjqkzOJ_>8m`VZ^vk3%&B{-ACW#_(VcIW)Jbe*No6tnMeXKXHKi>tO_P_x{79JyE zi?_PCdV#KYlBsPv440Ocf)G0?(fz&y&alRvB=}oay)P~X&uA|Cm|b_r zDgXoyI}$ljox|SqDY558%-~-Z%B#Ug4*!j4}^JlFcu-1GEb45q|^$=y`CC9s}CHr1((qv{C zCVh6U0ui@)1YeKs5gR(rtiK2LrF1SL5Qzd zsqRSh{UVRBGX!5~iH*#YgfLJxfA-AFSY9pul_yJi05jel(02^)IqVTfW zseNF5`q#^AJ%M+H*gO-2mxXk+ zB+61AWmP3FHKK&yDxb(3GF?{E`$H)2G{V5l% zH2$+f)r_(b?FPUTg4|92xm@aKu5aCiKY!7IXbYZ)mRQiyrGVQycAD@>LsIZ`(QVx_ zBmrWl^W)crmS_*Wo%f1 zqBPZvsH`sy4`sYeu*sb-q%ng#B|(LWEyd+Gl7A$%$}+os8@ei0%dyW_to+5Q%t}E> zny{eOrhNy^ucH0<#o0|II~Dsq(I>ZJR5Mf<%ERu;069)OwHQL5)xCSO{~Js)7tIc| z(L$Tn5E&N5x*f6O4dQ^VM@)TcS*e-H6K)BIcAGdB$tv9dz&Io3_jbv@qk&3{Nf#&0 z#Kfyq=I^eQR-quNbym*D8oI#wGFt)wq%G~n^DDw#BQNzgCyyLURWq!PHjW%u$H!eX z9CcIkol(z)IV@g@Ikrzy35Dq)yAQ!(Y5X8XMa6kS!fDGBi07L1wSBGWlu!wket;eKJCI;U3E>g#hnxc+G6! z)~tXN1A~AiFqReJa%gw0xVhSrdlZ0&!E<>tByV{Zc*Z~*+g1a`J_V^@$~kp(s_BeY zh9(y+D@SWVoqOSKarnv!KU)tAdF94i-YE&b4#P}Y8UY&h%{;~_|7=3_r#siC)fD-9<9bn+;A7F0C?E+N&geKrg4>&`zE4Kph2ScpigB#G=$oTbhuIBQK{H-)icaSL1>&bU&8eG#=4chsV@HMh`|CdxyA!o= zj1fs@f+uQrg{-X&UG)C->Zz(?vdeT4UncO4ETg#9oFYb?bR6_#UkVpQ-L}Iw<*H<- z8paXaip!o>{rTFAlpz*r_HgL%AaDIAI#ng(Y{BrTclsI+GJ=XPP;lkDs^*o108GZM zOX$TbdB-tpB)6?v=3ZEOt;%!0w#BJr*Y^SSG`%qKOqvof@=gq*g&gSUQJLMm#AQM& zVLvy{6UYLlzSbU{8=3gzQX3lw;@0}85~TS!%6`Vh#S=U%%?2_*oT|BN8T$4!2lS(Nx(*BoO!+pCDFo&J*uxB+R}sC5aDSmS$J0KS^+0 zIGXsZMp08Ra_?R2FS(xy?Q$lV=fE`45Ep6lv!U`d+%n50QzML@9LDK)_gcTVKv}a- zjX_e>|5VY%Ix7WAfs9sqd$0F?0KU5&TYHsschzEFhbEfuNCUIF&{G?RTQP59-|-T> zxIfeyB9J(SZAXM(&F8wl_9j)f<$<1Hg8V`ATpOm!a6^4%6*@@ojW1zu)*GUZq zA6~rAXX}{GWQ%mECxMIlI)VoWjW2X3_42$93wRa216e-h2!Hy&flwE90|zp!n(A>r z)FZBe^SL&RVZK+TqlhiKEO0DO{(2G0cNyrfzt&VyB>G8$AI$CSZG-t@oQ3Or?y1s>$n#P zE#PvD6tt^tz2^=927Pf`nQs97nU<#~E6FzB@{Tw3%rD#*=XbB?cih{Q%*XYsmN&*e zkhx+k9p$|}fkr;Qe}RFNn~v~U0zi*8p8HgXBhU1@1QId4#Ze=dy<5b>$_J7_6Q<{6 zcu4)z_j7NW1T}h=N4%gG5UE$enV8%5JpOFy6Y}HaN1@#J zIT4JW=jrmJd(TA#Od#Ul-}C)FvJ#nshMpBj_obFvlE6#0=+ z^?xA{RVtD@imn@#_{;hozz_yt{khWz11l)>lWDEFi|}*s;T0~y{OwiT0YsrFMBbD` zQl1VR4*J>(05vYGyTpu#H66lp*UK1~TzlC0CPd$b!Cesb>6gqO?s{=eee6~DQos+> zAzO4!dqEugFv7p-EfTa(A&AgN8OE060Gxz?1{Jiw#<@-Kb(A8*^Uh?C1!k3|9Q_xK z_JEeoz<*J+Eq1LtP9^2+QLI?n*iD543=s3+HR7ASM+JX~0H*akdVW>|bQp4H;|2o| zCcUM?Bh*3Npd9EFCsOoQgtCeN7=6uh7~sI|<{bHQH>u@Os#b;c-2rC(IB@fBXG>$4 zLtFZ+0R5trJ&q; zh#ktk&La@cL+2UBJ;y`nn=^d3A%K_BP~9aG1;PVZqFSLAY9ugpK+Nfujjx_{JW_U; zhXjbFOtAS$F3t9kO=d1nU9Tf)1Ihh*^;{+I$wJq1p?fTq8;4#$Z<`A9e!BAx01W-< zrvS&c#ZNk;Zakm4>7v|g<#mQ?N{kR)x@~%uhYN%koXEe%{~rL+qDW+t`7ko@+1D5^ z?k=Tcg>Mf_BGNamLzHVhsQ=Leu&tV>%i|jlBGx7;WiS>Ds74{gy!)mmHz5FtgyTk$ zQcf1Rx~(SKIG#5ou#dL|P_)ES|gA%oXtWGIb^rkg5;z-_r#Kjax)hoh6g93TIU z2k`rjv48KXDDltwT!E2nrbdTQu^k;o*vtukm!)+ZwpyqQXn~Q z?Pv^?9BYJM;NPs1yncJNJD#2FI1H!yAJaQoHu!woGa$xv4BqBte{vggtH*fC0_jR$$lhb8q-rMzLXMrbmapEWz zny@vnYr=8LMGQ-Ec@B2pO!EbO2)bn!YS>{LJmrEcS2}SWzQT@07 E4QBlfC;$Ke literal 0 HcmV?d00001 diff --git a/doc/components/clock_gating.md b/doc/components/clock_gating.md new file mode 100644 index 00000000..c24e0758 --- /dev/null +++ b/doc/components/clock_gating.md @@ -0,0 +1,84 @@ +# Clock Gating + +ROHD-HCL includes a generic clock gating component for enabling and disabling clocks to save power. The implementation supports multiple scenarios and use cases: + +- Easily control whether clock gating `isPresent` or not without modifying the implementation. +- Delay (or don't) controlled signals that are sampled in the gated clock domain, depending on your timing needs. +- Optionally use an override to force all clock gates to be enabled. +- Bring your own clock gating implementation and propagate the instantiation and any additionally required ports through an entire hierarchy without modifying any lower levels of the design. +- Automatically handle some tricky situations (e.g. keeping clocks enabled during reset for synchronous reset). + +![Diagram of the clock gating component](clock_gate.png) + +A very simple counter design is shown below with clock gating included via the component. + +```dart +class CounterWithSimpleClockGate extends Module { + Logic get count => output('count'); + + CounterWithSimpleClockGate({ + required Logic clk, + required Logic incr, + required Logic reset, + required ClockGateControlInterface cgIntf, + }) : super(name: 'clk_gated_counter') { + clk = addInput('clk', clk); + incr = addInput('incr', incr); + reset = addInput('reset', reset); + + // We clone the incoming interface, receiving all config information with it + cgIntf = ClockGateControlInterface.clone(cgIntf) + ..pairConnectIO(this, cgIntf, PairRole.consumer); + + // In this case, we want to enable the clock any time we're incrementing + final clkEnable = incr; + + // Build the actual clock gate component. + final clkGate = ClockGate( + clk, + enable: clkEnable, + reset: reset, + controlIntf: cgIntf, + ); + + final count = addOutput('count', width: 8); + count <= + flop( + // access the gated clock from the component + clkGate.gatedClk, + // by default, `controlled` signals are delayed by 1 cycle + count + clkGate.controlled(incr).zeroExtend(count.width), + reset: reset, + ); + } +} +``` + +Some important pieces to note here are: + +- The clock gate component is instantiated like any other component +- We pass it a `ClockGateControlInterface` which brings with it any potential custom control. When we punch ports for this design, we use the `clone` constructor, which carries said configuration information. +- We enable the clock any time `incr` is asserted to increment the counter. +- Use the gated clock on the downstream flop for the counter. +- Use a "controlled" version of `incr`, which by default is delayed by one cycle. + +The `ClockGateControlInterface` comes with an optional `enableOverride` which can force the clocks to always be enabled. It also contains a boolean `isPresent` which can control whether clock gating should be generated at all. Since configuration information is automatically carried down through the hierarchy, this means you *can turn on or off clock gating generation through an entire hierarchy without modifying your design*. + +Suppose now we wanted to add our own custom clock gating module implementation. This implementation may require some additional signals as well. When we pass a control interface we can provide some additional arguments to achieve this. For example: + +```dart +ClockGateControlInterface( + additionalPorts: [ + Port('anotherOverride'), + ], + gatedClockGenerator: (intf, clk, enable) => CustomClockGatingModule( + clk: clk, + en: enable, + anotherOverride: intf.port('anotherOverride'), + ).gatedClk, +); +``` + +Passing in an interface configured like this would mean that any consumers would automatically get the additional ports and new clock gating implementation. Our counter example could get this new method for clock gating and a new port without changing the design of the counter at all. + +An executable version of this example is available in `example/clock_gating_example.dart`. diff --git a/example/clock_gating_example.dart b/example/clock_gating_example.dart new file mode 100644 index 00000000..44ebeaca --- /dev/null +++ b/example/clock_gating_example.dart @@ -0,0 +1,176 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// clock_gating_example.dart +// Example of how to use clock gating. +// +// 2024 September 24 +// Author: Max Korbel + +// ignore_for_file: avoid_print + +import 'dart:async'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:rohd_vf/rohd_vf.dart'; + +/// A very simple counter that has clock gating internally. +class CounterWithSimpleClockGate extends Module { + Logic get count => output('count'); + + CounterWithSimpleClockGate({ + required Logic clk, + required Logic incr, + required Logic reset, + required ClockGateControlInterface cgIntf, + }) : super(name: 'clk_gated_counter') { + clk = addInput('clk', clk); + incr = addInput('incr', incr); + reset = addInput('reset', reset); + + // We clone the incoming interface, receiving all config information with it + cgIntf = ClockGateControlInterface.clone(cgIntf) + ..pairConnectIO(this, cgIntf, PairRole.consumer); + + // In this case, we want to enable the clock any time we're incrementing + final clkEnable = incr; + + // Build the actual clock gate component. + final clkGate = ClockGate( + clk, + enable: clkEnable, + reset: reset, + controlIntf: cgIntf, + delayControlledSignals: true, + ); + + final count = addOutput('count', width: 8); + count <= + flop( + // access the gated clock from the component + clkGate.gatedClk, + + // depending on configuration default, `controlled` signals are + // delayed by 1 cycle (in this case we enable it) + count + clkGate.controlled(incr).zeroExtend(count.width), + + reset: reset, + ); + } +} + +/// A reference to an external SystemVerilog clock-gating macro. +class CustomClockGateMacro extends Module with CustomSystemVerilog { + Logic get gatedClk => output('gatedClk'); + + CustomClockGateMacro({ + required Logic clk, + required Logic en, + required Logic override, + required Logic anotherOverride, + }) : super(name: 'custom_clock_gate_macro') { + // make sure ports match the SystemVerilog + clk = addInput('clk', clk); + en = addInput('en', en); + override = addInput('override', override); + anotherOverride = addInput('another_override', anotherOverride); + addOutput('gatedClk'); + + // simulation-only behavior + gatedClk <= clk & flop(~clk, en | override | anotherOverride); + } + + // define how to instantiate this custom SystemVerilog + @override + String instantiationVerilog(String instanceType, String instanceName, + Map inputs, Map outputs) => + '`CUSTOM_CLOCK_GATE(' + '${outputs['gatedClk']}, ' + '${inputs['clk']}, ' + '${inputs['en']}, ' + '${inputs['override']}, ' + '${inputs['another_override']}' + ')'; +} + +Future main({bool noPrint = false}) async { + // Build a custom version of the clock gating control interface which uses our + // custom macro. + final customClockGateControlIntf = ClockGateControlInterface( + hasEnableOverride: true, + additionalPorts: [ + // we add an additional override port, for example, which is passed + // automatically down the hierarchy + Port('anotherOverride'), + ], + gatedClockGenerator: (intf, clk, enable) => CustomClockGateMacro( + clk: clk, + en: enable, + override: intf.enableOverride!, + anotherOverride: intf.port('anotherOverride'), + ).gatedClk, + ); + + // Generate a simple clock. This will run along by itself as + // the Simulator goes. + final clk = SimpleClockGenerator(10).clk; + + // ... and some additional signals + final reset = Logic(); + final incr = Logic(); + + final counter = CounterWithSimpleClockGate( + clk: clk, + reset: reset, + incr: incr, + cgIntf: customClockGateControlIntf, + ); + + // build the module and attach a waveform viewer for debug + await counter.build(); + + // Let's see what this module looks like as SystemVerilog, so we can pass it + // to other tools. + final systemVerilogCode = counter.generateSynth(); + if (!noPrint) { + print(systemVerilogCode); + } + + // Now let's try simulating! + + // Attach a waveform dumper so we can see what happens. + if (!noPrint) { + WaveDumper(counter); + } + + // Start off with a disabled counter and asserting reset at the start. + incr.inject(0); + reset.inject(1); + + // leave overrides turned off + customClockGateControlIntf.enableOverride!.inject(0); + customClockGateControlIntf.port('anotherOverride').inject(0); + + Simulator.setMaxSimTime(1000); + unawaited(Simulator.run()); + + // wait a bit before dropping reset + await clk.waitCycles(3); + reset.inject(0); + + // wait a bit before raising incr + await clk.waitCycles(5); + incr.inject(1); + + // leave it high for a bit, then drop it + await clk.waitCycles(5); + incr.inject(0); + + // wait a little longer, then end the test + await clk.waitCycles(5); + await Simulator.endSimulation(); + + // Now we can review the waves to see how the gated clock does not toggle + // while gated! +} diff --git a/example/example.dart b/example/example.dart index ad4eabe0..51405e1a 100644 --- a/example/example.dart +++ b/example/example.dart @@ -12,7 +12,7 @@ import 'package:rohd/rohd.dart'; import 'package:rohd_hcl/rohd_hcl.dart'; -Future main() async { +Future main({bool noPrint = false}) async { // Build a module that rotates a 16-bit signal by an 8-bit signal, which // we guarantee will never see more than 10 as the rotate amount. final original = Logic(width: 16); @@ -23,11 +23,15 @@ Future main() async { // Do a quick little simulation with some inputs original.put(0x4321); rotateAmount.put(4); - print('Shifting ${original.value} by ${rotateAmount.value} ' - 'yields ${rotated.value}'); + if (!noPrint) { + print('Shifting ${original.value} by ${rotateAmount.value} ' + 'yields ${rotated.value}'); + } // Generate verilog for it and print it out await mod.build(); - print('Generating verilog...'); - print(mod.generateSynth()); + final sv = mod.generateSynth(); + if (!noPrint) { + print(sv); + } } diff --git a/lib/rohd_hcl.dart b/lib/rohd_hcl.dart index 8cf3540a..5b95a5fb 100644 --- a/lib/rohd_hcl.dart +++ b/lib/rohd_hcl.dart @@ -4,6 +4,7 @@ export 'src/arbiters/arbiters.dart'; export 'src/arithmetic/arithmetic.dart'; export 'src/binary_gray.dart'; +export 'src/clock_gating.dart'; export 'src/component_config/component_config.dart'; export 'src/count.dart'; export 'src/edge_detector.dart'; diff --git a/lib/src/clock_gating.dart b/lib/src/clock_gating.dart new file mode 100644 index 00000000..f056ca72 --- /dev/null +++ b/lib/src/clock_gating.dart @@ -0,0 +1,270 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// clock_gating.dart +// Clock gating. +// +// 2024 September 18 +// Author: Max Korbel + +import 'package:rohd/rohd.dart'; +// ignore: implementation_imports +import 'package:rohd/src/utilities/uniquifier.dart'; + +/// An [Interface] for controlling [ClockGate]s. +class ClockGateControlInterface extends PairInterface { + /// Whether an [enableOverride] is present on this interface. + final bool hasEnableOverride; + + /// If asserted, then clocks will be enabled regardless of the `enable` + /// signal. + /// + /// Presence is controlled by [hasEnableOverride]. + Logic? get enableOverride => tryPort('en_override'); + + /// Indicates whether clock gating logic [isPresent] or not. If it is not, + /// then no clock gating will occur and no clock gating logic will be + /// generated. + final bool isPresent; + + /// A default implementation for clock gating, effectively just an AND of the + /// clock and the enable signal, with an optional [enableOverride]. + static Logic defaultGenerateGatedClock( + ClockGateControlInterface intf, + Logic clk, + Logic enable, + ) => + clk & + flop( + ~clk, + [ + enable, + if (intf.hasEnableOverride) intf.enableOverride!, + ].swizzle().or()); + + /// A function that generates the gated clock signal based on the provided + /// `intf`, `clk`, and `enable` signals. + final Logic Function( + ClockGateControlInterface intf, + Logic clk, + Logic enable, + ) gatedClockGenerator; + + /// Constructs a [ClockGateControlInterface] with the provided arguments. + /// + /// If [isPresent] is false, then no clock gating will occur and no clock + /// gating logic will be generated. + /// + /// If [hasEnableOverride] is true, then an additional [enableOverride] port + /// will be generated. + /// + /// [additionalPorts] can optionally be added to this interface, which can be + /// useful in conjunction with a custom [gatedClockGenerator]. As the + /// interface is punched through hierarchies, any modules using this interface + /// will automatically include the [additionalPorts] and use the custom + /// [gatedClockGenerator] for clock gating logic. + ClockGateControlInterface({ + this.isPresent = true, + this.hasEnableOverride = false, + List? additionalPorts, + this.gatedClockGenerator = defaultGenerateGatedClock, + }) : super(portsFromProvider: [ + if (hasEnableOverride) Port('en_override'), + ...?additionalPorts, + ]); + + /// Creates a clone of [otherInterface] with the same configuration, including + /// any `additionalPorts` and `gatedClockGenerator` function. This should be + /// used to replicate interface configuration through hierarchies to carry + /// configuration information. + /// + /// If [isPresent] is provided, then it will override the [isPresent] value + /// from [otherInterface]. + /// + /// If a [gatedClockGenerator] is provided, then it will override the + /// [gatedClockGenerator] function from [otherInterface]. + ClockGateControlInterface.clone( + ClockGateControlInterface super.otherInterface, { + bool? isPresent, + Logic Function( + ClockGateControlInterface intf, + Logic clk, + Logic enable, + )? gatedClockGenerator, + }) : hasEnableOverride = otherInterface.hasEnableOverride, + isPresent = isPresent ?? otherInterface.isPresent, + gatedClockGenerator = + gatedClockGenerator ?? otherInterface.gatedClockGenerator, + super.clone(); +} + +/// A generic and configurable clock gating block. +class ClockGate extends Module { + /// An internal cache for controlled signals to avoid duplicating them. + final Map _controlledCache = {}; + + /// Returns a (potentially) delayed (by one cycle) version of [original] if + /// [delayControlledSignals] is true and the clock gating [isPresent]. This is + /// the signal that should be used as inputs to logic depending on the + /// [gatedClk]. + /// + /// If a [resetValue] is provided, then the signal will be reset to that value + /// when the clock gating is reset. + Logic controlled(Logic original, {dynamic resetValue}) { + if (!isPresent || !delayControlledSignals) { + return original; + } + + if (_controlledCache.containsKey(original)) { + return _controlledCache[original]!; + } else { + final o = super.addOutput( + _uniquifier.getUniqueName(initialName: '${original.name}_delayed')); + + _controlledCache[original] = o; + + o <= + flop( + _freeClk, + reset: _reset, + resetValue: resetValue, + super.addInput( + _uniquifier.getUniqueName(initialName: original.name), + original, + width: original.width, + ), + ); + + return o; + } + } + + /// A uniquifier for ports to ensure that they are unique as they punch via + /// [controlled]. + final _uniquifier = Uniquifier(); + + // override the addInput and addOutput functions for uniquification purposes + + @override + Logic addInput(String name, Logic x, {int width = 1}) { + _uniquifier.getUniqueName(initialName: name, reserved: true); + return super.addInput(name, x, width: width); + } + + @override + Logic addOutput(String name, {int width = 1}) { + _uniquifier.getUniqueName(initialName: name, reserved: true); + return super.addOutput(name, width: width); + } + + /// The gated clock output. + late final Logic gatedClk; + + /// Reset for all internal logic. + late final Logic? _reset; + + /// The enable signal provided as an input. + late final Logic _enable; + + /// The free clock signal provided as an input. + late final Logic _freeClk; + + /// The control interface for the clock gating, if provided. + late final ClockGateControlInterface? _controlIntf; + + /// Indicates whether the clock gating is present. If it is not present, then + /// the [gatedClk] is directly connected to the free clock and the + /// [controlled] signals are not delayed. + bool get isPresent => + _controlIntf?.isPresent ?? + // if no interface is provided, then _controlInterface is initialized with + // `isPresent` as true, so if this is null then there is no clock gating + false; + + /// Indicates whether the controlled signals are delayed by 1 cycle. If this + /// is false, or clock gating is not [isPresent], then the [controlled] + /// signals are not delayed. + final bool delayControlledSignals; + + /// Constructs a clock gating block where [enable] controls whether the + /// [gatedClk] is connected directly to the [freeClk] or is gated off. + /// + /// If [controlIntf] is provided, then the clock gating can be controlled + /// externally (for example whether the clock gating [isPresent] or using an + /// override signal to force clocks enabled). If [controlIntf] is not + /// provided, then the clock gating is always present. + /// + /// If [delayControlledSignals] is true, then any signals that are + /// [controlled] by the clock gating will be delayed by 1 cycle. This can be + /// helpful for timing purposes to avoid ungating the clock on the same cycle + /// as the signal is used. Using the [controlled] signals helps turn on or off + /// the delay across all applicable signals via a single switch: + /// [delayControlledSignals]. + /// + /// The [gatedClk] is automatically enabled during [reset] (if provided) so + /// that synchronous resets work properly, and the [enable] is extended for an + /// appropriate duration (if [delayControlledSignals]) to ensure proper + /// capture of data. + ClockGate( + Logic freeClk, { + required Logic enable, + Logic? reset, + ClockGateControlInterface? controlIntf, + this.delayControlledSignals = false, + super.name = 'clock_gate', + }) { + // if this clock gating is not intended to be present, then just do nothing + if (!(controlIntf?.isPresent ?? true)) { + _controlIntf = null; + gatedClk = freeClk; + return; + } + + _freeClk = addInput('freeClk', freeClk); + _enable = addInput('enable', enable); + + if (reset != null) { + _reset = addInput('reset', reset); + } else { + _reset = null; + } + + if (controlIntf == null) { + // if we are not provided an interface, make our own to use with default + // configuration + _controlIntf = ClockGateControlInterface(); + } else { + _controlIntf = ClockGateControlInterface.clone(controlIntf) + ..pairConnectIO(this, controlIntf, PairRole.consumer); + } + + gatedClk = addOutput('gatedClk'); + + _buildLogic(); + } + + /// Build the internal logic for handling enabling the gated clock. + void _buildLogic() { + var internalEnable = _enable; + + if (_reset != null) { + // we want to enable the clock during reset so that synchronous resets + // work properly + internalEnable |= _reset!; + } + + if (delayControlledSignals) { + // extra if there's a delay on the inputs relative to the enable + internalEnable |= flop( + _freeClk, + _enable, + reset: _reset, + resetValue: 1, // during reset, keep the clock enabled + ); + } + + gatedClk <= + _controlIntf! + .gatedClockGenerator(_controlIntf!, _freeClk, internalEnable); + } +} diff --git a/test/clock_gating_test.dart b/test/clock_gating_test.dart new file mode 100644 index 00000000..cae925c8 --- /dev/null +++ b/test/clock_gating_test.dart @@ -0,0 +1,275 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// clock_gating_test.dart +// Tests for clock gating. +// +// 2024 September 18 +// Author: Max Korbel + +import 'dart:async'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/src/clock_gating.dart'; +import 'package:rohd_vf/rohd_vf.dart'; +import 'package:test/test.dart'; + +class CustomClockGateMacro extends Module with CustomSystemVerilog { + Logic get gatedClk => output('gatedClk'); + + CustomClockGateMacro({ + required Logic clk, + required Logic en, + required Logic override, + required Logic anotherOverride, + }) : super(name: 'custom_clock_gate_macro') { + clk = addInput('clk', clk); + en = addInput('en', en); + override = addInput('override', override); + anotherOverride = addInput('another_override', anotherOverride); + + addOutput('gatedClk'); + + // simulation-only behavior + gatedClk <= clk & flop(~clk, en | override | anotherOverride); + } + + @override + String instantiationVerilog(String instanceType, String instanceName, + Map inputs, Map outputs) => + '`CUSTOM_CLOCK_GATE(' + '${outputs['gatedClk']}, ' + '${inputs['clk']}, ' + '${inputs['en']}, ' + '${inputs['override']}, ' + '${inputs['another_override']}' + ')'; +} + +class CustomClockGateControlInterface extends ClockGateControlInterface { + Logic get anotherOverride => port('anotherOverride'); + + CustomClockGateControlInterface({super.isPresent}) + : super( + hasEnableOverride: true, + additionalPorts: [ + Port('anotherOverride'), + ], + gatedClockGenerator: (intf, clk, enable) => CustomClockGateMacro( + clk: clk, + en: enable, + override: intf.enableOverride!, + anotherOverride: intf.port('anotherOverride'), + ).gatedClk); +} + +class CounterWithSimpleClockGate extends Module { + Logic get count => output('count'); + + /// A probe for clock gating. + late final ClockGate _clkGate; + + CounterWithSimpleClockGate(Logic clk, Logic incr, Logic reset, + {bool withDelay = true, ClockGateControlInterface? cgIntf}) + : super(name: 'clk_gated_counter') { + if (cgIntf != null) { + cgIntf = ClockGateControlInterface.clone(cgIntf) + ..pairConnectIO(this, cgIntf, PairRole.consumer); + } + + clk = addInput('clk', clk); + incr = addInput('incr', incr); + reset = addInput('reset', reset); + + final clkEnable = incr; + _clkGate = ClockGate( + clk, + enable: clkEnable, + reset: reset, + controlIntf: cgIntf, + delayControlledSignals: withDelay, + ); + + final count = addOutput('count', width: 8); + count <= + flop( + _clkGate.gatedClk, + count + _clkGate.controlled(incr).zeroExtend(count.width), + reset: reset, + ); + } +} + +void main() { + tearDown(() async { + await Simulator.reset(); + }); + + test('custom clock gating port', () async { + final cgIntf = CustomClockGateControlInterface(); + + cgIntf.enableOverride!.inject(0); + cgIntf.anotherOverride.inject(1); + + final clk = SimpleClockGenerator(10).clk; + final incr = Logic()..inject(0); + final reset = Logic(); + + final counter = CounterWithSimpleClockGate( + clk, + incr, + reset, + cgIntf: cgIntf, + ); + + await counter.build(); + + final sv = counter.generateSynth(); + expect(sv, contains('anotherOverride')); + expect(sv, contains('CUSTOM_CLOCK_GATE')); + + // ignore: invalid_use_of_protected_member + expect(counter.tryInput('anotherOverride'), isNotNull); + + Simulator.setMaxSimTime(500); + unawaited(Simulator.run()); + + reset.inject(1); + await clk.waitCycles(3); + reset.inject(0); + await clk.waitCycles(3); + + incr.inject(1); + await clk.waitCycles(5); + incr.inject(0); + await clk.waitCycles(5); + + expect(counter.count.value.toInt(), 5); + + cgIntf.anotherOverride.inject(1); + + await counter._clkGate.gatedClk.nextPosedge; + final t1 = Simulator.time; + await counter._clkGate.gatedClk.nextPosedge; + expect(Simulator.time, t1 + 10); + + cgIntf.anotherOverride.inject(0); + + unawaited(counter._clkGate.gatedClk.nextPosedge.then((_) { + fail('Expected a gated clock, no more toggles'); + })); + + await clk.waitCycles(5); + + await Simulator.endSimulation(); + }); + + group('basic clock gating', () { + final clockGatingTypes = { + 'none': () => null, + 'normal': ClockGateControlInterface.new, + 'normal not present': () => ClockGateControlInterface(isPresent: false), + 'override': () => ClockGateControlInterface(hasEnableOverride: true), + 'custom': CustomClockGateControlInterface.new, + }; + + for (final withDelay in [true, false]) { + for (final cgType in clockGatingTypes.entries) { + final hasOverride = cgType.value()?.hasEnableOverride ?? false; + for (final enOverride in [ + if (hasOverride) true, + false, + ]) { + test( + [ + cgType.key, + if (withDelay) 'with delay', + if (hasOverride) 'override: $enOverride', + ].join(', '), () async { + final cgIntf = cgType.value(); + + final overrideSignal = cgIntf is CustomClockGateControlInterface + ? (cgIntf.anotherOverride..inject(0)) + : cgIntf?.enableOverride; + cgIntf?.enableOverride?.inject(0); + + if (enOverride) { + overrideSignal?.inject(1); + } else { + overrideSignal?.inject(0); + } + + final clk = SimpleClockGenerator(10).clk; + final incr = Logic()..inject(0); + final reset = Logic(); + + final counter = CounterWithSimpleClockGate( + clk, + incr, + reset, + cgIntf: cgIntf, + withDelay: withDelay, + ); + + await counter.build(); + + // WaveDumper(counter); + + var clkToggleCount = 0; + counter._clkGate.gatedClk.posedge.listen((_) { + clkToggleCount++; + }); + + Simulator.setMaxSimTime(500); + unawaited(Simulator.run()); + + reset.inject(1); + await clk.waitCycles(3); + reset.inject(0); + await clk.waitCycles(3); + + incr.inject(1); + await clk.waitCycles(5); + incr.inject(0); + await clk.waitCycles(5); + + expect(counter.count.value.toInt(), 5); + + if (counter._clkGate.isPresent && !enOverride) { + if (counter._clkGate.delayControlledSignals) { + expect(clkToggleCount, lessThanOrEqualTo(7 + 4)); + } else { + expect(clkToggleCount, lessThanOrEqualTo(6 + 4)); + } + } else { + expect(clkToggleCount, greaterThanOrEqualTo(14)); + } + + if (hasOverride) { + if (cgIntf is CustomClockGateControlInterface) { + cgIntf.anotherOverride.inject(0); + } + + cgIntf!.enableOverride!.inject(1); + + await counter._clkGate.gatedClk.nextPosedge; + final t1 = Simulator.time; + await counter._clkGate.gatedClk.nextPosedge; + expect(Simulator.time, t1 + 10); + + cgIntf.enableOverride!.inject(0); + + unawaited(counter._clkGate.gatedClk.nextPosedge.then((_) { + fail('Expected a gated clock, no more toggles'); + })); + + await clk.waitCycles(5); + } + + await Simulator.endSimulation(); + }); + } + } + } + }); +} diff --git a/test/example_test.dart b/test/example_test.dart new file mode 100644 index 00000000..53b6ef20 --- /dev/null +++ b/test/example_test.dart @@ -0,0 +1,21 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// example_test.dart +// Tests that the examples run. +// +// 2024 September 24 +// Author: Max Korbel + +import 'package:test/test.dart'; + +import '../example/clock_gating_example.dart' as clock_gating_example; +import '../example/example.dart' as example; + +void main() { + test('examples run', () async { + for (final exMain in [example.main, clock_gating_example.main]) { + await exMain(noPrint: true); + } + }); +} From e7c90934da9389bc13ff4196b3a926ea791f6d62 Mon Sep 17 00:00:00 2001 From: soneryaldiz <56893713+soneryaldiz@users.noreply.github.com> Date: Wed, 23 Oct 2024 10:07:34 -0700 Subject: [PATCH 4/5] Fixed Point Value system (#99) --- .gitignore | 1 + doc/components/fixed_point.md | 7 + lib/src/arithmetic/arithmetic.dart | 1 + .../arithmetic/values/fixed_point_value.dart | 258 ++++++++++++++++++ .../values/fixed_point_value_test.dart | 247 +++++++++++++++++ 5 files changed, 514 insertions(+) create mode 100644 doc/components/fixed_point.md create mode 100644 lib/src/arithmetic/values/fixed_point_value.dart create mode 100644 test/arithmetic/values/fixed_point_value_test.dart diff --git a/.gitignore b/.gitignore index 04504631..26c5e66e 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ tmp* confapp/.vscode/* *tracker.json *tracker.log +*.sv # Exceptions !.vscode/extensions.json diff --git a/doc/components/fixed_point.md b/doc/components/fixed_point.md new file mode 100644 index 00000000..e5312c1e --- /dev/null +++ b/doc/components/fixed_point.md @@ -0,0 +1,7 @@ +# Fixed-Point Arithmetic + +Fixed-point binary representation of numbers is useful several applications including digital signal processing and embedded systems. As a first step towards enabling fixed-point components, we created a new value system `FixedPointValue` similar to `LogicValue`. + +## FixedPointValue + +A `FixedPointValue` represents a signed or unsigned fixed-point value following the Q notation (Qm.n format) as introduced by [Texas Instruments](https://www.ti.com/lit/ug/spru565b/spru565b.pdf). It comprises an optional sign, integer part and/or a fractional part. `FixedPointValue`s can be constructed from individual fields or from a Dart `double`, converted to Dart `double`, can be compared and can be operated on (+, -, *, /). diff --git a/lib/src/arithmetic/arithmetic.dart b/lib/src/arithmetic/arithmetic.dart index 6cfff87e..ca1237ae 100644 --- a/lib/src/arithmetic/arithmetic.dart +++ b/lib/src/arithmetic/arithmetic.dart @@ -13,3 +13,4 @@ export 'ones_complement_adder.dart'; export 'parallel_prefix_operations.dart'; export 'ripple_carry_adder.dart'; export 'sign_magnitude_adder.dart'; +export 'values/fixed_point_value.dart'; diff --git a/lib/src/arithmetic/values/fixed_point_value.dart b/lib/src/arithmetic/values/fixed_point_value.dart new file mode 100644 index 00000000..f4a22e4f --- /dev/null +++ b/lib/src/arithmetic/values/fixed_point_value.dart @@ -0,0 +1,258 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// fixed_point_value.dart +// Representation of fixed-point values. +// +// 2024 September 24 +// Authors: +// Soner Yaldiz + +import 'dart:math'; +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// An immutable representation of (un)signed fixed-point values following +/// Q notation (Qm.n format) as introduced by +/// (Texas Instruments)[https://www.ti.com/lit/ug/spru565b/spru565b.pdf]. +@immutable +class FixedPointValue implements Comparable { + /// The fixed point value bit storage in two's complement. + late final LogicValue value; + + /// [signed] indicates whether the representation is signed. + final bool signed; + + /// [m] is the number of bits reserved for the integer part. + final int m; + + /// [n] is the number of bits reserved for the fractional part. + final int n; + + /// Returns true if the number is negative. + bool isNegative() => signed & (value[-1] == LogicValue.one); + + /// Constructs [FixedPointValue] from sign, integer and fraction values. + FixedPointValue( + {required this.value, + required this.signed, + required this.m, + required this.n}) { + if (value == LogicValue.empty) { + throw RohdHclException('Zero width is not allowed.'); + } + final w = signed ? m + n + 1 : m + n; + if (w != value.width) { + throw RohdHclException('Width must be (sign) + m + n.'); + } + } + + /// Expands the bit width of integer and fractional parts. + LogicValue expandWidth({required bool sign, int m = 0, int n = 0}) { + if ((m < 0) | (n < 0)) { + throw RohdHclException('Input width must be non-negative.'); + } + if ((m > 0) & (m < this.m)) { + throw RohdHclException('Integer width is larger than input.'); + } + if ((n > 0) & (n < this.n)) { + throw RohdHclException('Fraction width is larger than input.'); + } + var newValue = value; + if (m >= this.m) { + if (signed) { + newValue = newValue.signExtend(newValue.width + m - this.m); + } else { + newValue = newValue.zeroExtend(newValue.width + m - this.m); + if (sign) { + newValue = newValue.zeroExtend(newValue.width + 1); + } + } + } + if (n > this.n) { + newValue = + newValue.reversed.zeroExtend(newValue.width + n - this.n).reversed; + } + return newValue; + } + + /// Returns a negative integer if `this` less than [other], + /// a positive integer if `this` greater than [other], + /// and zero if `this` and [other] are equal. + @override + int compareTo(Object other) { + if (other is! FixedPointValue) { + throw RohdHclException('Input must be of type FixedPointValue'); + } + if (!value.isValid | !other.value.isValid) { + throw RohdHclException('Inputs must be valid.'); + } + final s = signed | other.signed; + final m = max(this.m, other.m); + final n = max(this.n, other.n); + final val1 = expandWidth(sign: s, m: m, n: n); + final val2 = other.expandWidth(sign: s, m: m, n: n); + final comp = val1.compareTo(val2); + if (comp == 0) { + return comp; + } else if (!isNegative() & !other.isNegative()) { + return comp; + } else if (!isNegative() & other.isNegative()) { + return 1; + } else if (isNegative() & !other.isNegative()) { + return -1; + } else { + return -comp; + } + } + + /// Equal-to operation that returns a LogicValue. + LogicValue eq(FixedPointValue other) => + compareTo(other) == 0 ? LogicValue.one : LogicValue.zero; + + /// Not equal-to operation that returns a LogicValue. + LogicValue neq(FixedPointValue other) => + compareTo(other) != 0 ? LogicValue.one : LogicValue.zero; + + /// Less-than operation that returns a LogicValue. + LogicValue operator <(FixedPointValue other) => + compareTo(other) < 0 ? LogicValue.one : LogicValue.zero; + + /// Less-than operation that returns a LogicValue. + LogicValue operator <=(FixedPointValue other) => + compareTo(other) <= 0 ? LogicValue.one : LogicValue.zero; + + /// Less-than operation that returns a LogicValue. + LogicValue operator >(FixedPointValue other) => + compareTo(other) > 0 ? LogicValue.one : LogicValue.zero; + + /// Less-than operation that returns a LogicValue. + LogicValue operator >=(FixedPointValue other) => + compareTo(other) >= 0 ? LogicValue.one : LogicValue.zero; + + @override + int get hashCode => + value.hashCode ^ signed.hashCode ^ m.hashCode ^ n.hashCode; + + @override + bool operator ==(Object other) { + if (other is! FixedPointValue) { + return false; + } + return compareTo(other) == 0; + } + + /// Constructs [FixedPointValue] of a Dart [double] rounding away from zero. + factory FixedPointValue.ofDouble(double val, + {required bool signed, required int m, required int n}) { + if (!signed & (val < 0)) { + throw RohdHclException('Negative input not allowed with unsigned'); + } + final integerValue = (val * pow(2, n)).toInt(); + final w = signed ? 1 + m + n : m + n; + final v = LogicValue.ofInt(integerValue, w); + return FixedPointValue(value: v, signed: signed, m: m, n: n); + } + + /// Converts a fixed-point value to a Dart [double]. + double toDouble() { + if (m + n > 52) { + throw RohdHclException('Fixed-point value is too wide to convert.'); + } + if (!this.value.isValid) { + throw RohdHclException('Inputs must be valid.'); + } + BigInt number; + if (isNegative()) { + number = (~(this.value - 1)).toBigInt(); + } else { + number = this.value.toBigInt(); + } + final value = number.toDouble() / pow(2, n).toDouble(); + return isNegative() ? -value : value; + } + + /// Addition operation that returns a FixedPointValue. + /// The result is signed if one of the operands is signed. + /// The result integer has the max integer width of the operands plus one. + /// The result fraction has the max fractional width of the operands. + FixedPointValue operator +(FixedPointValue other) { + if (!value.isValid | !other.value.isValid) { + throw RohdHclException('Inputs must be valid.'); + } + final s = signed | other.signed; + final nr = max(n, other.n); + final mr = max(m, other.m) + 1; + final val1 = expandWidth(sign: s, m: mr, n: nr); + final val2 = other.expandWidth(sign: s, m: mr, n: nr); + return FixedPointValue(value: val1 + val2, signed: s, m: mr, n: nr); + } + + /// Subtraction operation that returns a FixedPointValue. + /// The result is always signed. + /// The result integer has the max integer width of the operands plus one. + /// The result fraction has the max fractional width of the operands. + FixedPointValue operator -(FixedPointValue other) { + if (!value.isValid | !other.value.isValid) { + throw RohdHclException('Inputs must be valid.'); + } + const s = true; + final nr = max(n, other.n); + final mr = max(m, other.m) + 1; + final val1 = expandWidth(sign: s, m: mr, n: nr); + final val2 = other.expandWidth(sign: s, m: mr, n: nr); + return FixedPointValue(value: val1 - val2, signed: s, m: mr, n: nr); + } + + /// Multiplication operation that returns a FixedPointValue. + /// The result is signed if one of the operands is signed. + /// The result fraction width is the sum of fraction widths of operands. + FixedPointValue operator *(FixedPointValue other) { + if (!value.isValid | !other.value.isValid) { + throw RohdHclException('Inputs must be valid.'); + } + final s = signed | other.signed; + final mr = s ? m + other.m + 1 : m + other.m; + final nr = n + other.n; + final tr = mr + nr; + final val1 = expandWidth(sign: s, m: tr - n); + final val2 = other.expandWidth(sign: s, m: tr - other.n); + return FixedPointValue(value: val1 * val2, signed: s, m: mr, n: nr); + } + + /// Division operation that returns a FixedPointValue. + /// The result is signed if one of the operands is signed. + /// The result integer width is the sum of dividend integer width and divisor + /// fraction width. The result fraction width is the sum of dividend fraction + /// width and divisor integer width. + FixedPointValue operator /(FixedPointValue other) { + if (!value.isValid | !other.value.isValid) { + throw RohdHclException('Inputs must be valid.'); + } + final s = signed | other.signed; + // extend integer width for max negative number + final m1 = s ? m + 1 : m; + final m2 = s ? other.m + 1 : other.m; + final mr = m1 + other.n; + final nr = n + m2; + final tr = mr + nr; + var val1 = expandWidth(sign: s, m: m1, n: tr - m1); + var val2 = other.expandWidth(sign: s, m: tr - other.n); + // Convert to positive as needed + if (s) { + if (val1[-1] == LogicValue.one) { + val1 = ~(val1 - 1); + } + if (val2[-1] == LogicValue.one) { + val2 = ~(val2 - 1); + } + } + var val = val1 / val2; + // Convert to negative as needed + if (isNegative() != other.isNegative()) { + val = (~val) + 1; + } + return FixedPointValue(value: val, signed: s, m: mr, n: nr); + } +} diff --git a/test/arithmetic/values/fixed_point_value_test.dart b/test/arithmetic/values/fixed_point_value_test.dart new file mode 100644 index 00000000..f19a3c25 --- /dev/null +++ b/test/arithmetic/values/fixed_point_value_test.dart @@ -0,0 +1,247 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// fixed_point_value_test.dart +// Tests of fixed-point value representation +// +// 2024 September 24 +// Authors: +// Soner Yaldiz + +import 'dart:math'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; +import 'package:test/test.dart'; + +void main() { + test('Constructor smoke', () { + final corners = [ + // value, signed, m, n, expected width + (LogicValue.ofInt(15, 8), true, 4, 3, 8), + (LogicValue.ofInt(15, 7), false, 4, 3, 7), + (LogicValue.filled(64, LogicValue.one), false, 0, 64, 64), + (LogicValue.filled(128, LogicValue.one), false, 128, 0, 128), + ]; + for (var c = 0; c < corners.length; c++) { + final fxp = FixedPointValue( + value: corners[c].$1, + signed: corners[c].$2, + m: corners[c].$3, + n: corners[c].$4); + expect(corners[c].$1, fxp.value); + expect(fxp.signed, corners[c].$2); + expect(fxp.value.width, corners[c].$5); + } + }); + + test('expandWidth', () { + final corners = [ + // value, signed, m, n, sign, m, n, result + ('01111111', true, 4, 3, true, 4, 3, '01111111'), + ('01111111', true, 4, 3, true, 6, 4, '00011111110'), + ('10000111', true, 4, 3, true, 6, 4, '11100001110'), + ('1111111', false, 4, 3, false, 6, 4, '0011111110'), + ('0111', true, 0, 3, true, 0, 3, '0111'), + ('0111', true, 0, 3, true, 0, 5, '011100'), + ('0111', true, 0, 3, true, 2, 3, '000111'), + ('1000', true, 0, 3, true, 2, 3, '111000'), + ('0111', true, 3, 0, true, 3, 0, '0111'), + ('0111', true, 3, 0, true, 4, 0, '00111'), + ('0111', true, 3, 0, true, 3, 1, '01110'), + ('1100', true, 3, 0, true, 4, 2, '1110000'), + ]; + for (var c = 0; c < corners.length; c++) { + final fxp = FixedPointValue( + value: LogicValue.ofString(corners[c].$1), + signed: corners[c].$2, + m: corners[c].$3, + n: corners[c].$4); + final value = fxp.expandWidth( + sign: corners[c].$5, m: corners[c].$6, n: corners[c].$7); + expect(value, LogicValue.ofString(corners[c].$8), + reason: value.bitString); + } + }); + + test('compareTo', () { + final corners = [ + // value, sign, m, n, value, sign, m, n, result + // pos pos + ('00111', true, 2, 2, '0001110', true, 3, 3, 0), + ('00111', true, 2, 2, '0000110', true, 3, 3, greaterThan(0)), + ('00111', true, 2, 2, '0010110', true, 3, 3, lessThan(0)), + ('0111', false, 2, 2, '0001110', true, 3, 3, 0), + ('0111', false, 2, 2, '0000110', true, 3, 3, greaterThan(0)), + ('0111', false, 2, 2, '0010110', true, 3, 3, lessThan(0)), + ('01111', true, 2, 2, '1000000', true, 3, 3, greaterThan(0)), + ('11000', true, 2, 2, '1111000', true, 3, 3, greaterThan(0)), + ('11110', true, 2, 2, '1111000', true, 3, 3, lessThan(0)), + ('10000', true, 2, 2, '0111000', true, 3, 3, lessThan(0)), + ]; + for (var c = 0; c < corners.length; c++) { + final fxp1 = FixedPointValue( + value: LogicValue.ofString(corners[c].$1), + signed: corners[c].$2, + m: corners[c].$3, + n: corners[c].$4); + final fxp2 = FixedPointValue( + value: LogicValue.ofString(corners[c].$5), + signed: corners[c].$6, + m: corners[c].$7, + n: corners[c].$8); + expect(fxp1.compareTo(fxp2), corners[c].$9); + } + }); + + test('ofDouble toDouble', () { + final corners = [ + // value, m, n, double + ('00000000', 4, 3, 0.0), + ('11111111', 7, 0, -1.0), + ('00011010', 4, 3, 3.25), + ('11110010', 4, 3, -1.75), + ('1000', 0, 3, -1.0), + ('10000', 1, 3, -2.0), + ('1100', 0, 3, -0.5), + ]; + for (var c = 0; c < corners.length; c++) { + final number = corners[c].$4; + final fxp = FixedPointValue.ofDouble(number, + signed: true, m: corners[c].$2, n: corners[c].$3); + expect(fxp.value.bitString, corners[c].$1); + expect(fxp.toDouble(), number); + } + corners + ..clear() + ..addAll([ + // value, m, n, double + ('00000000', 5, 3, 0.0), + ('00001001', 5, 3, 1.125), + ('11111111', 5, 3, 31.875), + ('11111111', 8, 0, pow(2, 8).toDouble() - 1), + ]); + for (var c = 0; c < corners.length; c++) { + final number = corners[c].$4; + final fxp = FixedPointValue.ofDouble(number, + signed: false, m: corners[c].$2, n: corners[c].$3); + expect(fxp.value.bitString, corners[c].$1); + expect(fxp.toDouble(), number); + } + // Exhaustive unsigned + for (var i = 0; i < pow(2, 4); i++) { + for (var m = 0; m < 5; m++) { + final n = 4 - m; + final fxp = FixedPointValue( + value: LogicValue.ofInt(i, 4), signed: false, m: m, n: n); + expect(fxp.value.width, 4); + expect(fxp.toDouble(), i / pow(2, n)); + } + } + }); + + test('Comparison operators', () { + expect( + FixedPointValue.ofDouble(14.432, signed: false, m: 4, n: 2) + .eq(FixedPointValue.ofDouble(14.432, signed: false, m: 4, n: 2)), + LogicValue.one); + expect( + FixedPointValue.ofDouble(14.432, signed: false, m: 4, n: 2) + .neq(FixedPointValue.ofDouble(14.432, signed: false, m: 4, n: 2)), + LogicValue.zero); + expect( + FixedPointValue.ofDouble(13.454, signed: false, m: 4, n: 2) > + (FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2)), + LogicValue.zero); + expect( + FixedPointValue.ofDouble(13.454, signed: false, m: 4, n: 2) >= + (FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2)), + LogicValue.zero); + expect( + FixedPointValue.ofDouble(13.454, signed: false, m: 4, n: 2) < + (FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2)), + LogicValue.one); + expect( + FixedPointValue.ofDouble(13.454, signed: false, m: 4, n: 2) <= + (FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2)), + LogicValue.one); + expect( + FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2) <= + (FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2)), + LogicValue.one); + expect( + FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2) >= + (FixedPointValue.ofDouble(14, signed: false, m: 4, n: 2)), + LogicValue.one); + }); + + test('Math', () { + const w = 4; + FixedPointValue fxp; + FixedPointValue fxp1; + FixedPointValue fxp2; + for (var i1 = 0; i1 < pow(2, w); i1++) { + for (var i2 = 1; i2 < pow(2, w); i2++) { + for (var m1 = 0; m1 < w; m1++) { + for (var m2 = 0; m2 < w; m2++) { + for (var s1 = 0; s1 < 2; s1++) { + for (var s2 = 0; s2 < 2; s2++) { + final n1 = s1 == 0 ? w - m1 - 1 : w - m1; + final n2 = s2 == 0 ? w - m2 - 1 : w - m2; + fxp1 = FixedPointValue( + value: LogicValue.ofInt(i1, w), + signed: s1 == 0, + m: m1, + n: n1); + fxp2 = FixedPointValue( + value: LogicValue.ofInt(i2, w), + signed: s2 == 0, + m: m2, + n: n2); + + // add + fxp = fxp1 + fxp2; + expect(fxp.toDouble(), fxp1.toDouble() + fxp2.toDouble(), + reason: '+'); + expect(fxp.n, max(n1, n2)); + expect(fxp.m, max(m1, m2) + 1); + + // subtract + fxp = fxp1 - fxp2; + expect(fxp.toDouble(), fxp1.toDouble() - fxp2.toDouble(), + reason: '-'); + expect(fxp.n, max(n1, n2)); + expect(fxp.m, max(m1, m2) + 1); + + // multiply + fxp = fxp1 * fxp2; + expect(fxp.toDouble(), fxp1.toDouble() * fxp2.toDouble(), + reason: '${fxp1.toDouble()}*${fxp2.toDouble()}'); + expect(fxp.n, n1 + n2); + expect(fxp.m, s1 + s2 == 2 ? m1 + m2 : m1 + m2 + 1); + + // divide + fxp = fxp1 / fxp2; + final q = s1 + s2 == 2 ? n1 + m2 : n1 + m2 + 1; + double expectedValue; + if (i1 == 0) { + expectedValue = 0; + } else { + expectedValue = + ((fxp1.toDouble() / fxp2.toDouble()).abs() * pow(2, q)) + .floor() / + pow(2, q); + if (fxp1.toDouble() / fxp2.toDouble() < 0) { + expectedValue = -expectedValue; + } + } + expect(fxp.toDouble(), expectedValue, + reason: + '${fxp1.toDouble()}/${fxp2.toDouble()} = $expectedValue'); + } + } + } + } + } + } + }); +} From 8bf5d234bce211c1c8a6355256ccffc6b4212be3 Mon Sep 17 00:00:00 2001 From: Max Korbel Date: Thu, 24 Oct 2024 11:11:50 -0700 Subject: [PATCH 5/5] Refactoring of `FloatingPointValue` constructors for ease of use and extension (#110) Co-authored-by: Desmond A. Kirkpatrick --- .../floating_point/floating_point.dart | 2 +- .../floating_point/floating_point_logic.dart | 48 +- .../floating_point_16_value.dart | 82 +++ .../floating_point_32_value.dart | 89 +++ .../floating_point_64_value.dart | 81 +++ .../floating_point_8_value.dart | 159 +++++ .../floating_point_bf16_value.dart | 82 +++ .../floating_point_tf32_value.dart | 81 +++ .../floating_point_value.dart | 572 +++++------------- .../floating_point_values.dart | 10 + .../floating_point_adder_round_test.dart | 10 +- .../floating_point_adder_simple_test.dart | 68 +-- .../floating_point_value_test.dart | 101 ++-- 13 files changed, 889 insertions(+), 496 deletions(-) create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_16_value.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart rename lib/src/arithmetic/floating_point/{ => floating_point_values}/floating_point_value.dart (53%) create mode 100644 lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart diff --git a/lib/src/arithmetic/floating_point/floating_point.dart b/lib/src/arithmetic/floating_point/floating_point.dart index de0b3b40..d7cce900 100644 --- a/lib/src/arithmetic/floating_point/floating_point.dart +++ b/lib/src/arithmetic/floating_point/floating_point.dart @@ -4,4 +4,4 @@ export 'floating_point_adder_round.dart'; export 'floating_point_adder_simple.dart'; export 'floating_point_logic.dart'; -export 'floating_point_value.dart'; +export 'floating_point_values/floating_point_values.dart'; diff --git a/lib/src/arithmetic/floating_point/floating_point_logic.dart b/lib/src/arithmetic/floating_point/floating_point_logic.dart index fcdca310..1d9d4f6c 100644 --- a/lib/src/arithmetic/floating_point/floating_point_logic.dart +++ b/lib/src/arithmetic/floating_point/floating_point_logic.dart @@ -11,8 +11,7 @@ // import 'package:rohd/rohd.dart'; -import 'package:rohd_hcl/src/arithmetic/floating_point/floating_point_value.dart'; -import 'package:rohd_hcl/src/exceptions.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; /// Flexible floating point logic representation class FloatingPoint extends LogicStructure { @@ -84,19 +83,38 @@ class FloatingPoint64 extends FloatingPoint { mantissaWidth: FloatingPoint64Value.mantissaWidth); } -/// Eight-bit floating point representation for deep learning -class FloatingPoint8 extends FloatingPoint { - /// Calculate mantissa width and sanitize - static int _calculateMantissaWidth(int exponentWidth) { - final mantissaWidth = 7 - exponentWidth; - if (!FloatingPoint8Value.isLegal(exponentWidth, mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } else { - return mantissaWidth; - } - } +/// Eight-bit floating point representation for deep learning: E4M3 +class FloatingPoint8E4M3 extends FloatingPoint { + /// Construct an 8-bit floating point number + FloatingPoint8E4M3() + : super( + mantissaWidth: FloatingPoint8E4M3Value.mantissaWidth, + exponentWidth: FloatingPoint8E4M3Value.exponentWidth); +} +/// Eight-bit floating point representation for deep learning: E5M2 +class FloatingPoint8E5M2 extends FloatingPoint { /// Construct an 8-bit floating point number - FloatingPoint8({required super.exponentWidth}) - : super(mantissaWidth: _calculateMantissaWidth(exponentWidth)); + FloatingPoint8E5M2() + : super( + mantissaWidth: FloatingPoint8E5M2Value.mantissaWidth, + exponentWidth: FloatingPoint8E5M2Value.exponentWidth); +} + +/// Sixteen-bit BF16 floating point representation +class FloatingPointBF16 extends FloatingPoint { + /// Construct a BF16 16-bit floating point number + FloatingPointBF16() + : super( + mantissaWidth: FloatingPointBF16Value.mantissaWidth, + exponentWidth: FloatingPointBF16Value.exponentWidth); +} + +/// Sixteen-bit floating point representation +class FloatingPoint16 extends FloatingPoint { + /// Construct a 16-bit floating point number + FloatingPoint16() + : super( + mantissaWidth: FloatingPoint16Value.mantissaWidth, + exponentWidth: FloatingPoint16Value.exponentWidth); } diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_16_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_16_value.dart new file mode 100644 index 00000000..b027cf97 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_16_value.dart @@ -0,0 +1,82 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_fp16_value.dart +// Implementation of FP16 Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A representation of an FP16 floating-point value. +class FloatingPoint16Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 5; + + /// The mantissa width + static const int mantissaWidth = 10; + + @override + @protected + int get constrainedExponentWidth => exponentWidth; + + @override + @protected + int get constrainedMantissaWidth => mantissaWidth; + + /// Constructor for a single precision floating point value + FloatingPoint16Value( + {required super.sign, required super.exponent, required super.mantissa}); + + /// Return the [FloatingPoint16Value] representing the constant specified + factory FloatingPoint16Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPoint16Value.ofLogicValue( + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + .value); + + /// [FloatingPoint16Value] constructor from string representation of + /// individual bitfields + FloatingPoint16Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPoint16Value] constructor from spaced string representation of + /// individual bitfields + FloatingPoint16Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPoint16Value] constructor from a single string representing + /// space-separated bitfields + FloatingPoint16Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPoint16Value] constructor from a set of [BigInt]s of the binary + /// representation + FloatingPoint16Value.ofBigInts(super.exponent, super.mantissa, {super.sign}) + : super.ofBigInts(); + + /// [FloatingPoint16Value] constructor from a set of [int]s of the binary + /// representation + FloatingPoint16Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPoint16Value] from a host double + factory FloatingPoint16Value.ofDouble(double inDouble) { + final fpv = FloatingPointValue.ofDouble(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + return FloatingPoint16Value.ofLogicValue(fpv.value); + } + + /// Construct a [FloatingPoint16Value] from a Logic word + factory FloatingPoint16Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPoint16Value.new, exponentWidth, mantissaWidth, val); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart new file mode 100644 index 00000000..b0a1a32e --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_32_value.dart @@ -0,0 +1,89 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_32_value.dart +// Implementation of 32-bit Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'dart:typed_data'; +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A representation of a single-precision floating-point value. +class FloatingPoint32Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 8; + + /// The mantissa width + static const int mantissaWidth = 23; + + @override + @protected + int get constrainedExponentWidth => exponentWidth; + + @override + @protected + int get constrainedMantissaWidth => mantissaWidth; + + /// Constructor for a single precision floating point value + FloatingPoint32Value( + {required super.sign, required super.exponent, required super.mantissa}); + + /// Return the [FloatingPoint32Value] representing the constant specified + factory FloatingPoint32Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPoint32Value.ofLogicValue( + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + .value); + + /// [FloatingPoint32Value] constructor from string representation of + /// individual bitfields + FloatingPoint32Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPoint32Value] constructor from spaced string representation of + /// individual bitfields + FloatingPoint32Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPoint32Value] constructor from a single string representing + /// space-separated bitfields + FloatingPoint32Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary + /// representation + FloatingPoint32Value.ofBigInts(super.exponent, super.mantissa, {super.sign}) + : super.ofBigInts(); + + /// [FloatingPoint32Value] constructor from a set of [int]s of the binary + /// representation + FloatingPoint32Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPoint32Value] from a host double + factory FloatingPoint32Value.ofDouble(double inDouble) { + final byteData = ByteData(4)..setFloat32(0, inDouble); + final accum = byteData.buffer + .asUint8List() + .map((b) => LogicValue.ofInt(b, 32)) + .reduce((accum, v) => (accum << 8) | v); + + return FloatingPoint32Value( + sign: accum[-1], + exponent: accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth), + mantissa: accum.slice(mantissaWidth - 1, 0)); + } + + /// Construct a [FloatingPoint32Value] from a Logic word + factory FloatingPoint32Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPoint32Value.new, exponentWidth, mantissaWidth, val); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart new file mode 100644 index 00000000..05a59e83 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_64_value.dart @@ -0,0 +1,81 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_64_value.dart +// Implementation of 64-bit Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'dart:typed_data'; + +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A representation of a double-precision floating-point value. +class FloatingPoint64Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 11; + + /// The mantissa width + static const int mantissaWidth = 52; + + /// Constructor for a double precision floating point value + FloatingPoint64Value( + {required super.sign, required super.mantissa, required super.exponent}); + + /// Return the [FloatingPoint64Value] representing the constant specified + factory FloatingPoint64Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPoint64Value.ofLogicValue( + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + .value); + + /// [FloatingPoint64Value] constructor from string representation of + /// individual bitfields + FloatingPoint64Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPoint64Value] constructor from spaced string representation of + /// individual bitfields + FloatingPoint64Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPoint64Value] constructor from a single string representing + /// space-separated bitfields + FloatingPoint64Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary + /// representation + FloatingPoint64Value.ofBigInts(super.exponent, super.mantissa, {super.sign}) + : super.ofBigInts(); + + /// [FloatingPoint64Value] constructor from a set of [int]s of the binary + /// representation + FloatingPoint64Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPoint64Value] from a host double + factory FloatingPoint64Value.ofDouble(double inDouble) { + final byteData = ByteData(8)..setFloat64(0, inDouble); + final accum = byteData.buffer + .asUint8List() + .map((b) => LogicValue.ofInt(b, 64)) + .reduce((accum, v) => (accum << 8) | v); + + return FloatingPoint64Value( + sign: accum[-1], + exponent: accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth), + mantissa: accum.slice(mantissaWidth - 1, 0)); + } + + /// Construct a [FloatingPoint32Value] from a Logic word + factory FloatingPoint64Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPoint64Value.new, exponentWidth, mantissaWidth, val); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart new file mode 100644 index 00000000..594f4149 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_8_value.dart @@ -0,0 +1,159 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_8_value.dart +// Implementation of 8-bit Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'dart:math'; +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// The E4M3 representation of a 8-bit floating point value as defined in +/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433). +class FloatingPoint8E4M3Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 4; + + /// The mantissa width + static const int mantissaWidth = 3; + + @override + @protected + int get constrainedExponentWidth => exponentWidth; + + @override + @protected + int get constrainedMantissaWidth => mantissaWidth; + + /// The maximum value representable by the E4M3 format + static double get maxValue => 448.toDouble(); + + /// The minimum value representable by the E4M3 format + static double get minValue => pow(2, -9).toDouble(); + + /// Constructor for a double precision floating point value + FloatingPoint8E4M3Value( + {required super.sign, required super.exponent, required super.mantissa}); + + /// [FloatingPoint8E4M3Value] constructor from string representation of + /// individual bitfields + FloatingPoint8E4M3Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPoint8E4M3Value] constructor from spaced string representation of + /// individual bitfields + FloatingPoint8E4M3Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPoint8E4M3Value] constructor from a single string representing + /// space-separated bitfields + FloatingPoint8E4M3Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPoint8E4M3Value] constructor from a set of [BigInt]s of the + /// binary representation + FloatingPoint8E4M3Value.ofBigInts(super.exponent, super.mantissa, + {super.sign}) + : super.ofBigInts(); + + /// [FloatingPoint8E4M3Value] constructor from a set of [int]s of the binary + /// representation + FloatingPoint8E4M3Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPoint8E4M3Value] from a host double + factory FloatingPoint8E4M3Value.ofDouble(double inDouble) { + if ((inDouble.abs() > maxValue) | + ((inDouble != 0) & (inDouble.abs() < minValue))) { + throw RohdHclException('Number exceeds E4M3 range'); + } + final fpv = FloatingPointValue.ofDouble(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + return FloatingPoint8E4M3Value( + sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa); + } + + /// Construct a [FloatingPoint8E4M3Value] from a Logic word + factory FloatingPoint8E4M3Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPoint8E4M3Value.new, exponentWidth, mantissaWidth, val); +} + +/// The E5M2 representation of a 8-bit floating point value as defined in +/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433). +class FloatingPoint8E5M2Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 5; + + /// The mantissa width + static const int mantissaWidth = 2; + + @override + @protected + int get constrainedExponentWidth => exponentWidth; + + @override + @protected + int get constrainedMantissaWidth => mantissaWidth; + + /// The maximum value representable by the E5M2 format + static double get maxValue => 57344.toDouble(); + + /// The minimum value representable by the E5M2 format + static double get minValue => pow(2, -16).toDouble(); + + /// Constructor for a double precision floating point value + FloatingPoint8E5M2Value( + {required super.sign, required super.exponent, required super.mantissa}); + + /// [FloatingPoint8E5M2Value] constructor from string representation of + /// individual bitfields + FloatingPoint8E5M2Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPoint8E5M2Value] constructor from spaced string representation of + /// individual bitfields + FloatingPoint8E5M2Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPoint8E5M2Value] constructor from a single string representing + /// space-separated bitfields + FloatingPoint8E5M2Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPoint8E5M2Value] constructor from a set of [BigInt]s of the + /// binary representation + FloatingPoint8E5M2Value.ofBigInts(super.exponent, super.mantissa, + {super.sign}) + : super.ofBigInts(); + + /// [FloatingPoint8E5M2Value] constructor from a set of [int]s of the binary + /// representation + FloatingPoint8E5M2Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPoint8E5M2Value] from a host double + factory FloatingPoint8E5M2Value.ofDouble(double inDouble) { + if ((inDouble.abs() > maxValue) | + ((inDouble != 0) & (inDouble.abs() < minValue))) { + throw RohdHclException('Number exceeds E5M2 range'); + } + final fpv = FloatingPointValue.ofDouble(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + return FloatingPoint8E5M2Value( + sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa); + } + + /// Construct a [FloatingPoint8E5M2Value] from a Logic word + factory FloatingPoint8E5M2Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPoint8E5M2Value.new, exponentWidth, mantissaWidth, val); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart new file mode 100644 index 00000000..fdf0a6af --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_bf16_value.dart @@ -0,0 +1,82 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_bf16_value.dart +// Implementation of BF16 Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A representation of a BF16 floating-point value. +class FloatingPointBF16Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 8; + + /// The mantissa width + static const int mantissaWidth = 7; + + @override + @protected + int get constrainedExponentWidth => exponentWidth; + + @override + @protected + int get constrainedMantissaWidth => mantissaWidth; + + /// Constructor for a single precision floating point value + FloatingPointBF16Value( + {required super.sign, required super.exponent, required super.mantissa}); + + /// Return the [FloatingPointBF16Value] representing the constant specified + factory FloatingPointBF16Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPointBF16Value.ofLogicValue( + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + .value); + + /// [FloatingPointBF16Value] constructor from string representation of + /// individual bitfields + FloatingPointBF16Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPointBF16Value] constructor from spaced string representation of + /// individual bitfields + FloatingPointBF16Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPointBF16Value] constructor from a single string representing + /// space-separated bitfields + FloatingPointBF16Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPointBF16Value] constructor from a set of [BigInt]s of the binary + /// representation + FloatingPointBF16Value.ofBigInts(super.exponent, super.mantissa, {super.sign}) + : super.ofBigInts(); + + /// [FloatingPointBF16Value] constructor from a set of [int]s of the binary + /// representation + FloatingPointBF16Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPointBF16Value] from a host double + factory FloatingPointBF16Value.ofDouble(double inDouble) { + final fpv = FloatingPointValue.ofDouble(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + + return FloatingPointBF16Value.ofLogicValue(fpv.value); + } + + /// Construct a [FloatingPointBF16Value] from a Logic word + factory FloatingPointBF16Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPointBF16Value.new, exponentWidth, mantissaWidth, val); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart new file mode 100644 index 00000000..5f024f05 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_tf32_value.dart @@ -0,0 +1,81 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +// floating_point_tf32_value.dart +// Implementation of TF32 Floating-Point value representations. +// +// 2024 October 15 +// Authors: +// Max Korbel +// Desmond A Kirkpatrick + +import 'package:meta/meta.dart'; +import 'package:rohd/rohd.dart'; +import 'package:rohd_hcl/rohd_hcl.dart'; + +/// A representation of a TF32 floating-point value. +class FloatingPointTF32Value extends FloatingPointValue { + /// The exponent width + static const int exponentWidth = 8; + + /// The mantissa width + static const int mantissaWidth = 10; + + @override + @protected + int get constrainedExponentWidth => exponentWidth; + + @override + @protected + int get constrainedMantissaWidth => mantissaWidth; + + /// Constructor for a single precision floating point value + FloatingPointTF32Value( + {required super.sign, required super.exponent, required super.mantissa}); + + /// Return the [FloatingPointTF32Value] representing the constant specified + factory FloatingPointTF32Value.getFloatingPointConstant( + FloatingPointConstants constantFloatingPoint) => + FloatingPointTF32Value.ofLogicValue( + FloatingPointValue.getFloatingPointConstant( + constantFloatingPoint, exponentWidth, mantissaWidth) + .value); + + /// [FloatingPointTF32Value] constructor from string representation of + /// individual bitfields + FloatingPointTF32Value.ofBinaryStrings( + super.sign, super.exponent, super.mantissa) + : super.ofBinaryStrings(); + + /// [FloatingPointTF32Value] constructor from spaced string representation of + /// individual bitfields + FloatingPointTF32Value.ofSpacedBinaryString(super.fp) + : super.ofSpacedBinaryString(); + + /// [FloatingPointTF32Value] constructor from a single string representing + /// space-separated bitfields + FloatingPointTF32Value.ofString(String fp, {super.radix}) + : super.ofString(fp, exponentWidth, mantissaWidth); + + /// [FloatingPointTF32Value] constructor from a set of [BigInt]s of the binary + /// representation + FloatingPointTF32Value.ofBigInts(super.exponent, super.mantissa, {super.sign}) + : super.ofBigInts(); + + /// [FloatingPointTF32Value] constructor from a set of [int]s of the binary + /// representation + FloatingPointTF32Value.ofInts(super.exponent, super.mantissa, {super.sign}) + : super.ofInts(); + + /// Numeric conversion of a [FloatingPointTF32Value] from a host double + factory FloatingPointTF32Value.ofDouble(double inDouble) { + final fpv = FloatingPointValue.ofDouble(inDouble, + exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); + return FloatingPointTF32Value.ofLogicValue(fpv.value); + } + + /// Construct a [FloatingPointTF32Value] from a Logic word + factory FloatingPointTF32Value.ofLogicValue(LogicValue val) => + FloatingPointValue.buildOfLogicValue( + FloatingPointTF32Value.new, exponentWidth, mantissaWidth, val); +} diff --git a/lib/src/arithmetic/floating_point/floating_point_value.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_value.dart similarity index 53% rename from lib/src/arithmetic/floating_point/floating_point_value.dart rename to lib/src/arithmetic/floating_point/floating_point_values/floating_point_value.dart index a3ca64d6..999a17ea 100644 --- a/lib/src/arithmetic/floating_point/floating_point_value.dart +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_value.dart @@ -10,7 +10,6 @@ // Desmond A Kirkpatrick { final int _maxExp; final int _minExp; - /// Factory (static) constructor of a [FloatingPointValue] from - /// sign, mantissa and exponent - factory FloatingPointValue( + /// A mapping from a `({exponentWidth, mantissaWidth})` record to a + /// constructor for a specific FloatingPointValue subtype. This map is used by + /// the [FloatingPointValue.withMappedSubtype] constructor to select the + /// appropriate constructor for a given set of widths. + /// + /// By default, this is populated with available subtypes from ROHD-HCL, but + /// it can be overridden or extended based on the user's needs. + static Map< + ({int exponentWidth, int mantissaWidth}), + FloatingPointValue Function( + {required LogicValue sign, + required LogicValue exponent, + required LogicValue mantissa})> subtypeConstructorMap = { + ( + exponentWidth: FloatingPoint32Value.exponentWidth, + mantissaWidth: FloatingPoint32Value.mantissaWidth + ): FloatingPoint32Value.new, + ( + exponentWidth: FloatingPoint64Value.exponentWidth, + mantissaWidth: FloatingPoint64Value.mantissaWidth + ): FloatingPoint64Value.new, + (exponentWidth: 4, mantissaWidth: 3): FloatingPoint8E4M3Value.new, + (exponentWidth: 5, mantissaWidth: 2): FloatingPoint8E5M2Value.new, + (exponentWidth: 5, mantissaWidth: 10): FloatingPoint16Value.new, + (exponentWidth: 8, mantissaWidth: 7): FloatingPointBF16Value.new, + (exponentWidth: 8, mantissaWidth: 10): FloatingPointTF32Value.new, + }; + + /// Constructor for a [FloatingPointValue] with a sign, exponent, and + /// mantissa. + @protected + FloatingPointValue( + {required this.sign, required this.exponent, required this.mantissa}) + : value = [sign, exponent, mantissa].swizzle(), + _bias = computeBias(exponent.width), + _minExp = computeMinExponent(exponent.width), + _maxExp = computeMaxExponent(exponent.width) { + if (sign.width != 1) { + throw RohdHclException('FloatingPointValue: sign width must be 1'); + } + if (constrainedMantissaWidth != null && + mantissa.width != constrainedMantissaWidth) { + throw RohdHclException('FloatingPointValue: mantissa width must be ' + '$constrainedMantissaWidth'); + } + if (constrainedExponentWidth != null && + exponent.width != constrainedExponentWidth) { + throw RohdHclException('FloatingPointValue: exponent width must be ' + '$constrainedExponentWidth'); + } + } + + /// Constructs a [FloatingPointValue] with a sign, exponent, and mantissa + /// using one of the builders provided from [subtypeConstructorMap] if + /// available, otherwise using the default constructor. + factory FloatingPointValue.withMappedSubtype( {required LogicValue sign, required LogicValue exponent, required LogicValue mantissa}) { - if (exponent.width == FloatingPoint32Value.exponentWidth && - mantissa.width == FloatingPoint32Value.mantissaWidth) { - return FloatingPoint32Value( - sign: sign, mantissa: mantissa, exponent: exponent); - } else if (exponent.width == FloatingPoint64Value._exponentWidth && - mantissa.width == FloatingPoint64Value._mantissaWidth) { - return FloatingPoint64Value( - sign: sign, mantissa: mantissa, exponent: exponent); - } else { - return FloatingPointValue.withConstraints( - sign: sign, mantissa: mantissa, exponent: exponent); - } - } + final key = (exponentWidth: exponent.width, mantissaWidth: mantissa.width); - /// [FloatingPointValue] constructor from a binary string representation of - /// individual bitfields - factory FloatingPointValue.ofBinaryStrings( - String sign, String exponent, String mantissa) { - if (sign.length != 1) { - throw RohdHclException('Sign string must be of length 1'); + if (subtypeConstructorMap.containsKey(key)) { + return subtypeConstructorMap[key]!( + sign: sign, exponent: exponent, mantissa: mantissa); } return FloatingPointValue( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); + sign: sign, exponent: exponent, mantissa: mantissa); } + /// Converts this [FloatingPointValue] to a [FloatingPointValue] with the same + /// sign, exponent, and mantissa using the constructor provided in + /// [subtypeConstructorMap] if available, otherwise using the default + /// constructor. + FloatingPointValue toMappedSubtype() => FloatingPointValue.withMappedSubtype( + sign: sign, exponent: exponent, mantissa: mantissa); + + /// [constrainedMantissaWidth] is the hard-coded mantissa width of the + /// sub-class of this floating-point value + @protected + int? get constrainedMantissaWidth => null; + + /// [constrainedExponentWidth] is the hard-coded exponent width of the + /// sub-class of this floating-point value + @protected + int? get constrainedExponentWidth => null; + + /// [FloatingPointValue] constructor from a binary string representation of + /// individual bitfields + FloatingPointValue.ofBinaryStrings( + String sign, String exponent, String mantissa) + : this( + sign: LogicValue.of(sign), + exponent: LogicValue.of(exponent), + mantissa: LogicValue.of(mantissa)); + /// [FloatingPointValue] constructor from a single binary string representing /// space-separated bitfields - factory FloatingPointValue.ofSeparatedBinaryStrings(String fp) { - final s = fp.split(' '); - if (s.length != 3) { - throw RohdHclException('FloatingPointValue requires three strings ' - 'to initialize'); - } - return FloatingPointValue.ofBinaryStrings(s[0], s[1], s[2]); - } + FloatingPointValue.ofSpacedBinaryString(String fp) + : this.ofBinaryStrings( + fp.split(' ')[0], fp.split(' ')[1], fp.split(' ')[2]); /// [FloatingPointValue] constructor from a radix-encoded string /// representation and the size of the exponent and mantissa - factory FloatingPointValue.ofString( - String fp, int exponentWidth, int mantissaWidth, - {int radix = 2}) { + FloatingPointValue.ofString(String fp, int exponentWidth, int mantissaWidth, + {int radix = 2}) + : this.ofBinaryStrings( + _extractBinaryStrings(fp, exponentWidth, mantissaWidth, radix).sign, + _extractBinaryStrings(fp, exponentWidth, mantissaWidth, radix) + .exponent, + _extractBinaryStrings(fp, exponentWidth, mantissaWidth, radix) + .mantissa); + + /// Helper function for extracting binary strings from a longer + /// binary string and the known exponent and mantissa widths. + static ({String sign, String exponent, String mantissa}) + _extractBinaryStrings( + String fp, int exponentWidth, int mantissaWidth, int radix) { final binaryFp = LogicValue.ofBigInt( BigInt.parse(fp, radix: radix), exponentWidth + mantissaWidth + 1) .bitString; - final (sign, exponent, mantissa) = ( - binaryFp.substring(0, 1), - binaryFp.substring(1, 1 + exponentWidth), - binaryFp.substring(1 + exponentWidth, 1 + exponentWidth + mantissaWidth) + return ( + sign: binaryFp.substring(0, 1), + exponent: binaryFp.substring(1, 1 + exponentWidth), + mantissa: binaryFp.substring( + 1 + exponentWidth, 1 + exponentWidth + mantissaWidth) ); - return FloatingPointValue.ofBinaryStrings(sign, exponent, mantissa); } + // TODO(desmonddak): toRadixString() would be useful, not limited to binary + /// [FloatingPointValue] constructor from a set of [BigInt]s of the binary /// representation and the size of the exponent and mantissa - factory FloatingPointValue.ofBigInts(BigInt exponent, BigInt mantissa, - {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - - return FloatingPointValue( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } + FloatingPointValue.ofBigInts(BigInt exponent, BigInt mantissa, + {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) + : this( + sign: LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + exponent: LogicValue.ofBigInt(exponent, exponentWidth), + mantissa: LogicValue.ofBigInt(mantissa, mantissaWidth)); /// [FloatingPointValue] constructor from a set of [int]s of the binary /// representation and the size of the exponent and mantissa - factory FloatingPointValue.ofInts(int exponent, int mantissa, - {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), - LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth) - ); - - return FloatingPointValue( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// Constructor enabling subclasses. - FloatingPointValue.withConstraints( - {required this.sign, - required this.exponent, - required this.mantissa, - int? mantissaWidth, - int? exponentWidth}) - : value = [sign, exponent, mantissa].swizzle(), - _bias = computeBias(exponent.width), - _minExp = computeMinExponent(exponent.width), - _maxExp = computeMaxExponent(exponent.width) { - if (sign.width != 1) { - throw RohdHclException('FloatingPointValue: sign width must be 1'); - } - if (mantissaWidth != null && mantissa.width != mantissaWidth) { - throw RohdHclException( - 'FloatingPointValue: mantissa width must be $mantissaWidth'); - } - if (exponentWidth != null && exponent.width != exponentWidth) { - throw RohdHclException( - 'FloatingPointValue: exponent width must be $exponentWidth'); + FloatingPointValue.ofInts(int exponent, int mantissa, + {int exponentWidth = 0, int mantissaWidth = 0, bool sign = false}) + : this( + sign: LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), + exponent: LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), + mantissa: + LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth)); + + /// Construct a [FloatingPointValue] from a [LogicValue] + factory FloatingPointValue.ofLogicValue( + int exponentWidth, int mantissaWidth, LogicValue val) => + buildOfLogicValue( + FloatingPointValue.new, exponentWidth, mantissaWidth, val); + + /// A helper function for [FloatingPointValue.ofLogicValue] and base classes + /// which performs some width checks and slicing. + @protected + static T buildOfLogicValue( + T Function( + {required LogicValue sign, + required LogicValue exponent, + required LogicValue mantissa}) + constructor, + int exponentWidth, + int mantissaWidth, + LogicValue val, + ) { + final expectedWidth = 1 + exponentWidth + mantissaWidth; + if (val.width != expectedWidth) { + throw RohdHclException('Width of $val must be $expectedWidth'); } - } - /// Construct a [FloatingPointValue] from a Logic word - factory FloatingPointValue.fromLogic( - int exponentWidth, int mantissaWidth, LogicValue val) { - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); - final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - return FloatingPointValue( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); + return constructor( + sign: val[-1], + exponent: val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth), + mantissa: val.slice(mantissaWidth - 1, 0)); } /// Return the [FloatingPointValue] representing the constant specified @@ -313,18 +363,25 @@ class FloatingPointValue implements Comparable { } /// Convert from double using its native binary representation - factory FloatingPointValue.fromDouble(double inDouble, + factory FloatingPointValue.ofDouble(double inDouble, {required int exponentWidth, required int mantissaWidth, FloatingPointRoundingMode roundingMode = FloatingPointRoundingMode.roundNearestEven}) { if ((exponentWidth == 8) && (mantissaWidth == 23)) { - return FloatingPoint32Value.fromDouble(inDouble); + // TODO(desmonddak): handle rounding mode for 32 bit? + return FloatingPoint32Value.ofDouble(inDouble); } else if ((exponentWidth == 11) && (mantissaWidth == 52)) { - return FloatingPoint64Value.fromDouble(inDouble); + return FloatingPoint64Value.ofDouble(inDouble); } - final fp64 = FloatingPoint64Value.fromDouble(inDouble); + if (roundingMode != FloatingPointRoundingMode.roundNearestEven && + roundingMode != FloatingPointRoundingMode.truncate) { + throw UnimplementedError( + 'Only roundNearestEven or truncate is supported for this width'); + } + + final fp64 = FloatingPoint64Value.ofDouble(inDouble); final exponent64 = fp64.exponent; var expVal = (exponent64.toInt() - fp64.bias) + @@ -392,12 +449,13 @@ class FloatingPointValue implements Comparable { /// Convert a floating point number into a [FloatingPointValue] /// representation. This form performs NO ROUNDING. - factory FloatingPointValue.fromDoubleIter(double inDouble, + @internal + factory FloatingPointValue.ofDoubleUnrounded(double inDouble, {required int exponentWidth, required int mantissaWidth}) { if ((exponentWidth == 8) && (mantissaWidth == 23)) { - return FloatingPoint32Value.fromDouble(inDouble); + return FloatingPoint32Value.ofDouble(inDouble); } else if ((exponentWidth == 11) && (mantissaWidth == 52)) { - return FloatingPoint64Value.fromDouble(inDouble); + return FloatingPoint64Value.ofDouble(inDouble); } var doubleVal = inDouble; @@ -526,9 +584,9 @@ class FloatingPointValue implements Comparable { (mantissa == other.mantissa); } + // TODO(desmonddak): figure out the difference with Infinity /// Return true if the represented floating point number is considered /// NaN or 'Not a Number' due to overflow - // TODO(desmonddak): figure out the difference with Infinity bool isNaN() { if ((exponent.width == 4) & (mantissa.width == 3)) { // FP8 E4M3 does not support infinities @@ -595,7 +653,7 @@ class FloatingPointValue implements Comparable { 'multiplicand must have the same mantissa and exponent widths'); } - return FloatingPointValue.fromDouble(op(toDouble(), other.toDouble()), + return FloatingPointValue.ofDouble(op(toDouble(), other.toDouble()), mantissaWidth: mantissa.width, exponentWidth: exponent.width); } @@ -625,297 +683,3 @@ class FloatingPointValue implements Comparable { FloatingPointValue abs() => FloatingPointValue( sign: LogicValue.zero, exponent: exponent, mantissa: mantissa); } - -/// A representation of a single precision floating point value -class FloatingPoint32Value extends FloatingPointValue { - /// The exponent width - static const int exponentWidth = 8; - - /// The mantissa width - static const int mantissaWidth = 23; - - /// Constructor for a single precision floating point value - FloatingPoint32Value( - {required super.sign, required super.exponent, required super.mantissa}) - : super.withConstraints( - mantissaWidth: mantissaWidth, exponentWidth: exponentWidth); - - /// Return the [FloatingPoint32Value] representing the constant specified - factory FloatingPoint32Value.getFloatingPointConstant( - FloatingPointConstants constantFloatingPoint) => - FloatingPointValue.getFloatingPointConstant( - constantFloatingPoint, exponentWidth, mantissaWidth) - as FloatingPoint32Value; - - /// [FloatingPoint32Value] constructor from string representation of - /// individual bitfields - factory FloatingPoint32Value.ofStrings( - String sign, String exponent, String mantissa) => - FloatingPoint32Value( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - - /// [FloatingPoint32Value] constructor from a single string representing - /// space-separated bitfields - factory FloatingPoint32Value.ofString(String fp) { - final s = fp.split(' '); - assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); - return FloatingPoint32Value.ofStrings(s[0], s[1], s[2]); - } - - /// [FloatingPoint32Value] constructor from a set of [BigInt]s of the binary - /// representation - factory FloatingPoint32Value.ofBigInts(BigInt exponent, BigInt mantissa, - {bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - - return FloatingPoint32Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// [FloatingPoint32Value] constructor from a set of [int]s of the binary - /// representation - factory FloatingPoint32Value.ofInts(int exponent, int mantissa, - {bool sign = false}) { - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(BigInt.from(exponent), exponentWidth), - LogicValue.ofBigInt(BigInt.from(mantissa), mantissaWidth) - ); - - return FloatingPoint32Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// Numeric conversion of a [FloatingPoint32Value] from a host double - factory FloatingPoint32Value.fromDouble(double inDouble) { - final byteData = ByteData(4) - ..setFloat32(0, inDouble) - ..buffer.asUint8List(); - final bytes = byteData.buffer.asUint8List(); - final lv = bytes.map((b) => LogicValue.ofInt(b, 32)); - - final accum = lv.reduce((accum, v) => (accum << 8) | v); - - final sign = accum[-1]; - final exponent = - accum.slice(exponentWidth + mantissaWidth - 1, mantissaWidth); - final mantissa = accum.slice(mantissaWidth - 1, 0); - - return FloatingPoint32Value( - sign: sign, exponent: exponent, mantissa: mantissa); - } - - /// Construct a [FloatingPoint32Value] from a Logic word - factory FloatingPoint32Value.fromLogic(LogicValue val) { - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth); - final mantissa = val.slice(mantissaWidth - 1, 0); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - exponent, - mantissa - ); - return FloatingPoint32Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } -} - -/// A representation of a double precision floating point value -class FloatingPoint64Value extends FloatingPointValue { - static const int _exponentWidth = 11; - static const int _mantissaWidth = 52; - - /// return the exponent width - static int get exponentWidth => _exponentWidth; - - /// return the mantissa width - static int get mantissaWidth => _mantissaWidth; - - /// Constructor for a double precision floating point value - FloatingPoint64Value( - {required super.sign, required super.mantissa, required super.exponent}) - : super.withConstraints( - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - - /// Return the [FloatingPoint64Value] representing the constant specified - factory FloatingPoint64Value.getFloatingPointConstant( - FloatingPointConstants constantFloatingPoint) => - FloatingPointValue.getFloatingPointConstant( - constantFloatingPoint, _exponentWidth, _mantissaWidth) - as FloatingPoint64Value; - - /// [FloatingPoint64Value] constructor from string representation of - /// individual bitfields - factory FloatingPoint64Value.ofStrings( - String sign, String exponent, String mantissa) => - FloatingPoint64Value( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - - /// [FloatingPoint64Value] constructor from a single string representing - /// space-separated bitfields - factory FloatingPoint64Value.ofString(String fp) { - final s = fp.split(' '); - assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); - return FloatingPoint64Value.ofStrings(s[0], s[1], s[2]); - } - - /// [FloatingPoint64Value] constructor from a set of [BigInt]s of the binary - /// representation - factory FloatingPoint64Value.ofBigInts(BigInt exponent, BigInt mantissa, - {bool sign = false}) => - FloatingPointValue.ofBigInts(exponent, mantissa, - sign: sign, - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth) as FloatingPoint64Value; - - /// [FloatingPoint64Value] constructor from a set of [int]s of the binary - /// representation - factory FloatingPoint64Value.ofInts(int exponent, int mantissa, - {bool sign = false}) => - FloatingPointValue.ofInts(exponent, mantissa, - sign: sign, - exponentWidth: exponentWidth, - mantissaWidth: mantissaWidth) as FloatingPoint64Value; - - /// Numeric conversion of a [FloatingPoint64Value] from a host double - factory FloatingPoint64Value.fromDouble(double inDouble) { - final byteData = ByteData(8) - ..setFloat64(0, inDouble) - ..buffer.asUint8List(); - final bytes = byteData.buffer.asUint8List(); - final lv = bytes.map((b) => LogicValue.ofInt(b, 64)); - - final accum = lv.reduce((accum, v) => (accum << 8) | v); - - final sign = accum[-1]; - final exponent = - accum.slice(_exponentWidth + _mantissaWidth - 1, _mantissaWidth); - final mantissa = accum.slice(_mantissaWidth - 1, 0); - - return FloatingPoint64Value( - sign: sign, mantissa: mantissa, exponent: exponent); - } - - /// Construct a [FloatingPoint32Value] from a Logic word - factory FloatingPoint64Value.fromLogic(LogicValue val) { - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); - final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - return FloatingPoint64Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } -} - -/// A representation of a 8-bit floating point value as defined in -/// [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433). -class FloatingPoint8Value extends FloatingPointValue { - /// The exponent width - late final int exponentWidth; - - /// The mantissa width - late final int mantissaWidth; - - static double get _e4m3max => 448.toDouble(); - static double get _e5m2max => 57344.toDouble(); - static double get _e4m3min => pow(2, -9).toDouble(); - static double get _e5m2min => pow(2, -16).toDouble(); - - /// Return if the exponent and mantissa widths match E4M3 or E5M2 - static bool isLegal(int exponentWidth, int mantissaWidth) { - if (((exponentWidth == 4) & (mantissaWidth == 3)) | - ((exponentWidth == 5) & (mantissaWidth == 2))) { - return true; - } else { - return false; - } - } - - /// Constructor for a double precision floating point value - FloatingPoint8Value( - {required super.sign, required super.mantissa, required super.exponent}) - : super.withConstraints() { - exponentWidth = exponent.width; - mantissaWidth = mantissa.width; - if (!isLegal(exponentWidth, mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } - } - - /// [FloatingPoint8Value] constructor from string representation of - /// individual bitfields - factory FloatingPoint8Value.ofStrings( - String sign, String exponent, String mantissa) => - FloatingPoint8Value( - sign: LogicValue.of(sign), - exponent: LogicValue.of(exponent), - mantissa: LogicValue.of(mantissa)); - - /// [FloatingPoint8Value] constructor from a single string representing - /// space-separated bitfields - factory FloatingPoint8Value.ofString(String fp) { - final s = fp.split(' '); - assert(s.length == 3, 'Wrong FloatingPointValue string length ${s.length}'); - return FloatingPoint8Value.ofStrings(s[0], s[1], s[2]); - } - - /// Construct a [FloatingPoint8Value] from a Logic word - factory FloatingPoint8Value.fromLogic(LogicValue val, int exponentWidth) { - if (val.width != 8) { - throw RohdHclException('Width must be 8'); - } - - final mantissaWidth = 7 - exponentWidth; - if (!isLegal(exponentWidth, mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } - - final sign = (val[-1] == LogicValue.one); - final exponent = - val.slice(exponentWidth + mantissaWidth - 1, mantissaWidth).toBigInt(); - final mantissa = val.slice(mantissaWidth - 1, 0).toBigInt(); - final (signLv, exponentLv, mantissaLv) = ( - LogicValue.ofBigInt(sign ? BigInt.one : BigInt.zero, 1), - LogicValue.ofBigInt(exponent, exponentWidth), - LogicValue.ofBigInt(mantissa, mantissaWidth) - ); - return FloatingPoint8Value( - sign: signLv, exponent: exponentLv, mantissa: mantissaLv); - } - - /// Numeric conversion of a [FloatingPoint8Value] from a host double - factory FloatingPoint8Value.fromDouble(double inDouble, - {required int exponentWidth}) { - final mantissaWidth = 7 - exponentWidth; - if (!isLegal(exponentWidth, mantissaWidth)) { - throw RohdHclException('FloatingPoint8 must follow E4M3 or E5M2'); - } - if (exponentWidth == 4) { - if ((inDouble > _e4m3max) | (inDouble < _e4m3min)) { - throw RohdHclException('Number exceeds E4M3 range'); - } - } else if (exponentWidth == 5) { - if ((inDouble > _e5m2max) | (inDouble < _e5m2min)) { - throw RohdHclException('Number exceeds E5M2 range'); - } - } - final fpv = FloatingPointValue.fromDouble(inDouble, - exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); - return FloatingPoint8Value( - sign: fpv.sign, exponent: fpv.exponent, mantissa: fpv.mantissa); - } -} diff --git a/lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart new file mode 100644 index 00000000..145a0f74 --- /dev/null +++ b/lib/src/arithmetic/floating_point/floating_point_values/floating_point_values.dart @@ -0,0 +1,10 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +export 'floating_point_16_value.dart'; +export 'floating_point_32_value.dart'; +export 'floating_point_64_value.dart'; +export 'floating_point_8_value.dart'; +export 'floating_point_bf16_value.dart'; +export 'floating_point_tf32_value.dart'; +export 'floating_point_value.dart'; diff --git a/test/arithmetic/floating_point/floating_point_adder_round_test.dart b/test/arithmetic/floating_point/floating_point_adder_round_test.dart index fe9eeb9c..20eaab59 100644 --- a/test/arithmetic/floating_point/floating_point_adder_round_test.dart +++ b/test/arithmetic/floating_point/floating_point_adder_round_test.dart @@ -25,7 +25,7 @@ void main() { fa.put(fva); fb.put(fvb); - final expectedNoRound = FloatingPointValue.fromDoubleIter( + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fva.toDouble() + fvb.toDouble(), exponentWidth: eWidth, mantissaWidth: mWidth); @@ -65,7 +65,7 @@ void main() { fa.put(fva); fb.put(fvb); // No rounding - final expected = FloatingPointValue.fromDoubleIter( + final expected = FloatingPointValue.ofDoubleUnrounded( fva.toDouble() + fvb.toDouble(), exponentWidth: eWidth, mantissaWidth: mWidth); @@ -259,7 +259,7 @@ void main() { fa.put(fva); fb.put(fvb); - final expectedNoRound = FloatingPointValue.fromDoubleIter( + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fva.toDouble() + fvb.toDouble(), exponentWidth: eWidth, mantissaWidth: mWidth); @@ -304,7 +304,7 @@ void main() { fa.put(fva); fb.put(fvb); - final expectedNoRound = FloatingPointValue.fromDoubleIter( + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fva.toDouble() + fvb.toDouble(), exponentWidth: eWidth, mantissaWidth: mWidth); @@ -345,7 +345,7 @@ void main() { exponentWidth: eWidth, mantissaWidth: mWidth); fa.put(fva); fb.put(fvb); - final expectedNoRound = FloatingPointValue.fromDoubleIter( + final expectedNoRound = FloatingPointValue.ofDoubleUnrounded( fva.toDouble() + fvb.toDouble(), exponentWidth: eWidth, mantissaWidth: mWidth); diff --git a/test/arithmetic/floating_point/floating_point_adder_simple_test.dart b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart index 4147274e..67ce71ef 100644 --- a/test/arithmetic/floating_point/floating_point_adder_simple_test.dart +++ b/test/arithmetic/floating_point/floating_point_adder_simple_test.dart @@ -17,10 +17,10 @@ import 'package:test/test.dart'; void main() { test('FP: basic adder test', () { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(3.25).value); + ..put(FloatingPoint32Value.ofDouble(3.25).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(1.5).value); - final out = FloatingPoint32Value.fromDouble(3.25 + 1.5); + ..put(FloatingPoint32Value.ofDouble(1.5).value); + final out = FloatingPoint32Value.ofDouble(3.25 + 1.5); final adder = FloatingPointAdderSimple(fp1, fp2); final fpSuper = adder.sum.floatingPointValue; @@ -32,10 +32,10 @@ void main() { test('FP: small numbers adder test', () { final val = pow(2.0, -23).toDouble(); final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pow(2.0, -23).toDouble()).value); + ..put(FloatingPoint32Value.ofDouble(pow(2.0, -23).toDouble()).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pow(2.0, -23).toDouble()).value); - final out = FloatingPoint32Value.fromDouble(val + val); + ..put(FloatingPoint32Value.ofDouble(pow(2.0, -23).toDouble()).value); + final out = FloatingPoint32Value.ofDouble(val + val); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -50,10 +50,10 @@ void main() { for (final pair in input) { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$1).value); + ..put(FloatingPoint32Value.ofDouble(pair.$1).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$2).value); - final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); + ..put(FloatingPoint32Value.ofDouble(pair.$2).value); + final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -66,10 +66,10 @@ void main() { test('FP: basic adder test', () { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(3.25).value); + ..put(FloatingPoint32Value.ofDouble(3.25).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(1.5).value); - final out = FloatingPoint32Value.fromDouble(3.25 + 1.5); + ..put(FloatingPoint32Value.ofDouble(1.5).value); + final out = FloatingPoint32Value.ofDouble(3.25 + 1.5); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -92,7 +92,7 @@ void main() { FloatingPointConstants.smallestPositiveSubnormal) .negate() .value); - final out = FloatingPoint32Value.fromDouble(val - val); + final out = FloatingPoint32Value.ofDouble(val - val); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -105,10 +105,10 @@ void main() { test('FP: adder carry numbers test', () { final val = pow(2.5, -12).toDouble(); final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pow(2.5, -12).toDouble()).value); + ..put(FloatingPoint32Value.ofDouble(pow(2.5, -12).toDouble()).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pow(2.5, -12).toDouble()).value); - final out = FloatingPoint32Value.fromDouble(val + val); + ..put(FloatingPoint32Value.ofDouble(pow(2.5, -12).toDouble()).value); + final out = FloatingPoint32Value.ofDouble(val + val); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -123,10 +123,10 @@ void main() { for (final pair in input) { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$1).value); + ..put(FloatingPoint32Value.ofDouble(pair.$1).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$2).value); - final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); + ..put(FloatingPoint32Value.ofDouble(pair.$2).value); + final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -144,10 +144,10 @@ void main() { for (final pair in input) { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$1).value); + ..put(FloatingPoint32Value.ofDouble(pair.$1).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$2).value); - final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); + ..put(FloatingPoint32Value.ofDouble(pair.$2).value); + final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -169,7 +169,7 @@ void main() { .negate() .value); - final out = FloatingPoint32Value.fromDouble( + final out = FloatingPoint32Value.ofDouble( fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble()); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -194,7 +194,7 @@ void main() { final outDouble = fp1.floatingPointValue.toDouble() + fp2.floatingPointValue.toDouble(); - final out = FloatingPointValue.fromDoubleIter(outDouble, + final out = FloatingPointValue.ofDoubleUnrounded(outDouble, exponentWidth: ew, mantissaWidth: mw); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -207,15 +207,15 @@ void main() { const mw = 5; final fp1 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.fromDouble(pair.$1, + ..put(FloatingPointValue.ofDouble(pair.$1, exponentWidth: ew, mantissaWidth: mw) .value); final fp2 = FloatingPoint(exponentWidth: ew, mantissaWidth: mw) - ..put(FloatingPointValue.fromDouble(pair.$2, + ..put(FloatingPointValue.ofDouble(pair.$2, exponentWidth: ew, mantissaWidth: mw) .value); - final out = FloatingPointValue.fromDouble(pair.$1 + pair.$2, + final out = FloatingPointValue.ofDouble(pair.$1 + pair.$2, exponentWidth: ew, mantissaWidth: mw); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -247,10 +247,10 @@ void main() { for (final pair in input) { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$1).value); + ..put(FloatingPoint32Value.ofDouble(pair.$1).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$2).value); - final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); + ..put(FloatingPoint32Value.ofDouble(pair.$2).value); + final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -264,10 +264,10 @@ void main() { const pair = (9.0, -3.75); { final fp1 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$1).value); + ..put(FloatingPoint32Value.ofDouble(pair.$1).value); final fp2 = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(pair.$2).value); - final out = FloatingPoint32Value.fromDouble(pair.$1 + pair.$2); + ..put(FloatingPoint32Value.ofDouble(pair.$2).value); + final out = FloatingPoint32Value.ofDouble(pair.$1 + pair.$2); final adder = FloatingPointAdderSimple(fp1, fp2); @@ -299,7 +299,7 @@ void main() { fa.put(fva); fb.put(fvb); // fromDoubleIter does not round like '+' would - final expected = FloatingPointValue.fromDoubleIter( + final expected = FloatingPointValue.ofDoubleUnrounded( fva.toDouble() + fvb.toDouble(), exponentWidth: fpv.exponent.width, mantissaWidth: fpv.mantissa.width); diff --git a/test/arithmetic/floating_point/floating_point_value_test.dart b/test/arithmetic/floating_point/floating_point_value_test.dart index 252d3326..83b23634 100644 --- a/test/arithmetic/floating_point/floating_point_value_test.dart +++ b/test/arithmetic/floating_point/floating_point_value_test.dart @@ -28,7 +28,7 @@ void main() { final mantStr = mantissa.bitString; final fp = FloatingPointValue.ofBinaryStrings(signStr, expStr, mantStr); final dbl = fp.toDouble(); - final fp2 = FloatingPointValue.fromDouble(dbl, + final fp2 = FloatingPointValue.ofDouble(dbl, exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); if (fp != fp2) { if (fp.isNaN() != fp2.isNaN()) { @@ -50,12 +50,13 @@ void main() { final mantStr = (mantissa << i).bitString; final fp = FloatingPointValue.ofBinaryStrings(signStr, expStr, mantStr); expect(fp.toString(), '$signStr $expStr $mantStr'); - final fp2 = FloatingPointValue.fromDouble(fp.toDouble(), + final fp2 = FloatingPointValue.ofDouble(fp.toDouble(), exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); expect(fp2, equals(fp)); } } }); + test('FPV: indirect subnormal conversion no rounding', () { const signStr = '0'; for (var exponentWidth = 2; exponentWidth < 12; exponentWidth++) { @@ -67,13 +68,14 @@ void main() { final fp = FloatingPointValue.ofBinaryStrings(signStr, expStr, mantStr); expect(fp.toString(), '$signStr $expStr $mantStr'); - final fp2 = FloatingPointValue.fromDoubleIter(fp.toDouble(), + final fp2 = FloatingPointValue.ofDoubleUnrounded(fp.toDouble(), exponentWidth: exponentWidth, mantissaWidth: mantissaWidth); expect(fp2, equals(fp)); } } } }); + test('FPV: round trip 32', () { final values = [ FloatingPoint32Value.getFloatingPointConstant( @@ -91,10 +93,11 @@ void main() { FloatingPointConstants.largestNormal) ]; for (final fp in values) { - final fp2 = FloatingPoint32Value.fromDouble(fp.toDouble()); + final fp2 = FloatingPoint32Value.ofDouble(fp.toDouble()); expect(fp2, equals(fp)); } }); + test('FPV: round trip 64', () { final values = [ FloatingPoint64Value.getFloatingPointConstant( @@ -112,24 +115,26 @@ void main() { FloatingPointConstants.largestNormal) ]; for (final fp in values) { - final fp2 = FloatingPoint64Value.fromDouble(fp.toDouble()); + final fp2 = FloatingPoint64Value.ofDouble(fp.toDouble()); expect(fp2, equals(fp)); } }); + test('FloatingPointValue string conversion', () { const str = '0 10000001 01000100000000000000000'; // 5.0625 - final fp = FloatingPoint32Value.ofString(str); + final fp = FloatingPoint32Value.ofSpacedBinaryString(str); expect(fp.toString(), str); expect(fp.toDouble(), 5.0625); }); + test('FPV: simple 32', () { final values = [0.15625, 12.375, -1.0, 0.25, 0.375]; for (final val in values) { - final fp = FloatingPoint32Value.fromDouble(val); + final fp = FloatingPoint32Value.ofDouble(val); assert(val == fp.toDouble(), 'mismatch'); expect(fp.toDouble(), val); - final fpSuper = FloatingPointValue.fromDouble(val, - exponentWidth: 8, mantissaWidth: 23); + final fpSuper = + FloatingPointValue.ofDouble(val, exponentWidth: 8, mantissaWidth: 23); assert(val == fpSuper.toDouble(), 'mismatch'); expect(fpSuper.toDouble(), val); } @@ -138,10 +143,10 @@ void main() { test('FPV: simple 64', () { final values = [0.15625, 12.375, -1.0, 0.25, 0.375]; for (final val in values) { - final fp = FloatingPoint64Value.fromDouble(val); + final fp = FloatingPoint64Value.ofDouble(val); assert(val == fp.toDouble(), 'mismatch'); expect(fp.toDouble(), val); - final fpSuper = FloatingPointValue.fromDouble(val, + final fpSuper = FloatingPointValue.ofDouble(val, exponentWidth: 11, mantissaWidth: 52); assert(val == fpSuper.toDouble(), 'mismatch'); expect(fpSuper.toDouble(), val); @@ -159,18 +164,17 @@ void main() { for (var c = 0; c < corners.length; c++) { final val = corners[c][1] as double; final str = corners[c][0] as String; - final fp = FloatingPointValue.fromDouble(val, - exponentWidth: 4, mantissaWidth: 3); + final fp = + FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 3); expect(val, fp.toDouble()); expect(str, fp.toString()); - final fp8 = FloatingPointValue.fromDouble(val, - exponentWidth: 4, mantissaWidth: 3); + final fp8 = FloatingPoint8E4M3Value.ofDouble(val); expect(val, fp8.toDouble()); expect(str, fp8.toString()); } }); - test('FP8: E5M2', () { + test('FPV8: E5M2', () { final corners = [ ['0 00000 00', 0.toDouble()], ['0 11110 11', 57344.toDouble()], @@ -181,81 +185,104 @@ void main() { for (var c = 0; c < corners.length; c++) { final val = corners[c][1] as double; final str = corners[c][0] as String; - final fp = FloatingPointValue.fromDouble(val, - exponentWidth: 5, mantissaWidth: 2); + final fp = + FloatingPointValue.ofDouble(val, exponentWidth: 5, mantissaWidth: 2); expect(val, fp.toDouble()); expect(str, fp.toString()); - final fp8 = FloatingPointValue.fromDouble(val, - exponentWidth: 5, mantissaWidth: 2); + final fp8 = FloatingPoint8E5M2Value.ofDouble(val); expect(val, fp8.toDouble()); expect(str, fp8.toString()); } }); test('FPV: setting and getting from a signal', () { - final fp = FloatingPoint32() - ..put(FloatingPoint32Value.fromDouble(1.5).value); + final fp = FloatingPoint32()..put(FloatingPoint32Value.ofDouble(1.5).value); expect(fp.floatingPointValue.toDouble(), 1.5); final fp2 = FloatingPoint64() - ..put(FloatingPoint64Value.fromDouble(1.5).value); + ..put(FloatingPoint64Value.ofDouble(1.5).value); expect(fp2.floatingPointValue.toDouble(), 1.5); - final fp8e4m3 = FloatingPoint8(exponentWidth: 4) - ..put(FloatingPoint8Value.fromDouble(1.5, exponentWidth: 4).value); + final fp8e4m3 = FloatingPoint8E4M3() + ..put(FloatingPoint8E4M3Value.ofDouble(1.5).value); expect(fp8e4m3.floatingPointValue.toDouble(), 1.5); - final fp8e5m2 = FloatingPoint8(exponentWidth: 5) - ..put(FloatingPoint8Value.fromDouble(1.5, exponentWidth: 5).value); + final fp8e5m2 = FloatingPoint8E5M2() + ..put(FloatingPoint8E5M2Value.ofDouble(1.5).value); expect(fp8e5m2.floatingPointValue.toDouble(), 1.5); }); test('FPV: round nearest even Guard and Sticky', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0000100000000000000000000000000000000000000000000001'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1000', '0001'); final val = fp64.toDouble(); final fpConvert = - FloatingPointValue.fromDouble(val, exponentWidth: 4, mantissaWidth: 4); + FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 4); expect(fpConvert, equals(fpRound)); }); + test('FPV: round nearest even Guard and Round', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0000110000000000000000000000000000000000000000000000'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1000', '0001'); final val = fp64.toDouble(); final fpConvert = - FloatingPointValue.fromDouble(val, exponentWidth: 4, mantissaWidth: 4); + FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 4); expect(fpConvert, equals(fpRound)); }); + test('FPV: rounding nearest even increment', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0001100000000000000000000000000000000000000000000000'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1000', '0010'); final val = fp64.toDouble(); final fpConvert = - FloatingPointValue.fromDouble(val, exponentWidth: 4, mantissaWidth: 4); + FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 4); expect(fpConvert, equals(fpRound)); }); + test('FPV: rounding nearest even increment carry into exponent', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '1111100000000000000000000000000000000000000000000000'); final fpRound = FloatingPointValue.ofBinaryStrings('0', '1001', '0000'); final val = fp64.toDouble(); final fpConvert = - FloatingPointValue.fromDouble(val, exponentWidth: 4, mantissaWidth: 4); + FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 4); expect(fpConvert, equals(fpRound)); }); + test('FPV: rounding nearest even truncate', () { - final fp64 = FloatingPoint64Value.ofStrings('0', '10000000000', + final fp64 = FloatingPoint64Value.ofBinaryStrings('0', '10000000000', '0010100000000000000000000000000000000000000000000000'); final fpTrunc = FloatingPointValue.ofBinaryStrings('0', '1000', '0010'); final val = fp64.toDouble(); final fpConvert = - FloatingPointValue.fromDouble(val, exponentWidth: 4, mantissaWidth: 4); + FloatingPointValue.ofDouble(val, exponentWidth: 4, mantissaWidth: 4); expect(fpConvert, equals(fpTrunc)); }); + + test('mapped subtype constructor', () { + final fp = FloatingPointValue.withMappedSubtype( + sign: LogicValue.zero, + exponent: LogicValue.ofString('10101'), + mantissa: LogicValue.ofString('10'), + ); + + expect(fp, isA()); + }); + + test('mapped subtype conversion', () { + final fp = FloatingPointValue( + sign: LogicValue.zero, + exponent: LogicValue.ofString('10101'), + mantissa: LogicValue.ofString('10'), + ); + + expect(fp, isNot(isA())); + expect(fp.toMappedSubtype(), isA()); + }); }