diff --git a/cv32e40p_fpu_manifest.flist b/cv32e40p_fpu_manifest.flist index 92ddce332..a4f4d6bf3 100644 --- a/cv32e40p_fpu_manifest.flist +++ b/cv32e40p_fpu_manifest.flist @@ -83,6 +83,13 @@ ${DESIGN_RTL_DIR}/vendor/pulp_platform_fpnew/src/fpnew_opgroup_fmt_slice.sv ${DESIGN_RTL_DIR}/vendor/pulp_platform_fpnew/src/fpnew_opgroup_multifmt_slice.sv ${DESIGN_RTL_DIR}/vendor/pulp_platform_fpnew/src/fpnew_opgroup_block.sv ${DESIGN_RTL_DIR}/vendor/pulp_platform_fpnew/src/fpnew_top.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/bf16fp32conv.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/fp32bf16conv.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/bf16_conv_top.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/bf16_maxmin.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/lzc.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/bf16_fma.sv +${DESIGN_RTL_DIR}/vendor/bf16_acc/bf16_accelerator_top.sv ${DESIGN_RTL_DIR}/cv32e40p_fp_wrapper.sv ${DESIGN_RTL_DIR}/cv32e40p_top.sv diff --git a/rtl/cv32e40p_decoder.sv b/rtl/cv32e40p_decoder.sv index d03027bae..449bdbf17 100644 --- a/rtl/cv32e40p_decoder.sv +++ b/rtl/cv32e40p_decoder.sv @@ -33,7 +33,7 @@ module cv32e40p_decoder parameter COREV_PULP = 1, // PULP ISA Extension (including PULP specific CSRs and hardware loop, excluding cv.elw) parameter COREV_CLUSTER = 0, // PULP ISA Extension cv.elw (need COREV_PULP = 1) parameter A_EXTENSION = 0, - parameter FPU = 0, + parameter FPU = 1, parameter FPU_ADDMUL_LAT = 0, parameter FPU_OTHERS_LAT = 0, parameter ZFINX = 0, @@ -1135,13 +1135,20 @@ module cv32e40p_decoder fpu_op = cv32e40p_fpu_pkg::F2F; fp_op_group = CONV; // bits [22:20] used, other bits must be 0 - if (instr_rdata_i[24:23]) illegal_insn_o = 1'b1; + // if (instr_rdata_i[24:23]) illegal_insn_o = 1'b1; // check source format unique case (instr_rdata_i[22:20]) // Only process instruction if corresponding extension is active (static) 3'b000: begin - if (!(C_RVF && (C_XF16 || C_XF16ALT || C_XF8))) illegal_insn_o = 1'b1; - fpu_src_fmt_o = cv32e40p_fpu_pkg::FP32; + if(~C_BF16) begin + if (!(C_RVF && (C_XF16 || C_XF16ALT || C_XF8))) illegal_insn_o = 1'b1; + fpu_src_fmt_o = cv32e40p_fpu_pkg::FP32; + end + else begin + fpu_src_fmt_o = cv32e40p_fpu_pkg::BF16; + // set dst format to FP32 + fpu_dst_fmt_o = cv32e40p_fpu_pkg::FP32; + end end 3'b001: begin if (~C_RVD) illegal_insn_o = 1'b1; @@ -1152,8 +1159,15 @@ module cv32e40p_decoder fpu_src_fmt_o = cv32e40p_fpu_pkg::FP16; end 3'b110: begin - if (~C_XF16ALT) illegal_insn_o = 1'b1; - fpu_src_fmt_o = cv32e40p_fpu_pkg::FP16ALT; + if(~C_BF16) begin + if (~C_XF16ALT) illegal_insn_o = 1'b1; + fpu_src_fmt_o = cv32e40p_fpu_pkg::FP16ALT; + end + else begin + fpu_src_fmt_o = cv32e40p_fpu_pkg::FP32; + // set dst format to FP32 + fpu_dst_fmt_o = cv32e40p_fpu_pkg::BF16; + end end 3'b011: begin if (~C_XF8) illegal_insn_o = 1'b1; @@ -1185,6 +1199,31 @@ module cv32e40p_decoder // set dst format to FP32 fpu_dst_fmt_o = cv32e40p_fpu_pkg::FP32; end + // bf16_min/max + 5'b01101, 5'b01111: begin + if(C_BF16)begin + unique case (instr_rdata_i[28]) + 1'b0: begin + fpu_op = cv32e40p_fpu_pkg::MINMAX; + fp_rnd_mode_o = 3'b000; // min + fp_op_group = NONCOMP; + check_fprm = 1'b0; // instruction encoded in rm + fpu_dst_fmt_o = cv32e40p_fpu_pkg::BF16; + fpu_src_fmt_o = cv32e40p_fpu_pkg::BF16; + end + 1'b1: begin + fpu_op = cv32e40p_fpu_pkg::MINMAX; + fp_rnd_mode_o = 3'b001; // max + fp_op_group = NONCOMP; + check_fprm = 1'b0; // instruction encoded in rm + fpu_dst_fmt_o = cv32e40p_fpu_pkg::BF16; + fpu_src_fmt_o = cv32e40p_fpu_pkg::BF16; + end + endcase + end + else + illegal_insn_o = 1; + end // feq/flt/fle.fmt - FP Comparisons 5'b10100: begin fpu_op = cv32e40p_fpu_pkg::CMP; diff --git a/rtl/include/cv32e40p_fpu_pkg.sv b/rtl/include/cv32e40p_fpu_pkg.sv index fe2415325..8cbecfb23 100644 --- a/rtl/include/cv32e40p_fpu_pkg.sv +++ b/rtl/include/cv32e40p_fpu_pkg.sv @@ -50,7 +50,7 @@ package cv32e40p_fpu_pkg; // *NOTE:* Add new formats only at the end of the enumeration for backwards compatibilty! - localparam int unsigned NUM_FP_FORMATS = 5; // change me to add formats + localparam int unsigned NUM_FP_FORMATS = 6; // change me to add formats localparam int unsigned FP_FORMAT_BITS = $clog2(NUM_FP_FORMATS); // FP formats @@ -59,7 +59,8 @@ package cv32e40p_fpu_pkg; FP64 = 'd1, FP16 = 'd2, FP8 = 'd3, - FP16ALT = 'd4 + FP16ALT = 'd4, + BF16 = 'd5 // add new formats here } fp_format_e; diff --git a/rtl/include/cv32e40p_pkg.sv b/rtl/include/cv32e40p_pkg.sv index 319e790b6..5d8d3ef32 100644 --- a/rtl/include/cv32e40p_pkg.sv +++ b/rtl/include/cv32e40p_pkg.sv @@ -45,7 +45,7 @@ package cv32e40p_pkg; parameter OPCODE_JAL = 7'h6f; parameter OPCODE_AUIPC = 7'h17; parameter OPCODE_LUI = 7'h37; - parameter OPCODE_OP_FP = 7'h53; + parameter OPCODE_OP_FP = 7'h53; // to be used for bf16_min, bf16_max, bf16_fp32_conv, and fp32_bf16_conv instr parameter OPCODE_OP_FMADD = 7'h43; parameter OPCODE_OP_FNMADD = 7'h4f; parameter OPCODE_OP_FMSUB = 7'h47; @@ -771,6 +771,7 @@ package cv32e40p_pkg; parameter bit C_XF16 = 1'b0; // Is half-precision float extension (Xf16) enabled parameter bit C_XF16ALT = 1'b0; // Is alternative half-precision float extension (Xf16alt) enabled parameter bit C_XF8 = 1'b0; // Is quarter-precision float extension (Xf8) enabled + parameter bit C_BF16 = 1'b1; // Is half-precision float extension (bf16) enabled parameter bit C_XFVEC = 1'b0; // Is vectorial float extension (Xfvec) enabled // Latency of FP operations: 0 = no pipe registers, 1 = 1 pipe register etc. @@ -790,6 +791,7 @@ package cv32e40p_pkg; C_RVF ? 32 : // F ext. C_XF16 ? 16 : // Xf16 ext. C_XF16ALT ? 16 : // Xf16alt ext. + C_BF16 ? 16 : // bf16 ext. C_XF8 ? 8 : // Xf8 ext. 0; // Unused in case of no FP diff --git a/rtl/vendor/bf16_acc/bf16_accelerator_top.sv b/rtl/vendor/bf16_acc/bf16_accelerator_top.sv new file mode 100644 index 000000000..ebeeb18bf --- /dev/null +++ b/rtl/vendor/bf16_acc/bf16_accelerator_top.sv @@ -0,0 +1,132 @@ +`timescale 1ns / 1ps +////////////////////////////////////////////////////////////////////////////////// +// Company: +// Engineer: +// +// Create Date: 12/06/2023 03:56:06 PM +// Design Name: +// Module Name: acc_top +// Project Name: +// Target Devices: +// Tool Versions: +// Description: +// +// Dependencies: +// +// Revision: +// Revision 0.01 - File Created +// Additional Comments: +// +////////////////////////////////////////////////////////////////////////////////// + + +module bf16_accelerator_top( + input logic clk, + input logic reset, + input logic enable, // Enable signal for the accelerator + input logic [31:0] operand_a, // First operand + input logic [15:0] operand_b, // Second operand + input logic [31:0] operand_c, // Third operand for FMA operations + input logic [3:0] operation, // Operation type + output logic [31:0] result, // Result of the operation + output logic [3:0] fpcsr, // Floating-point control and status register + output logic valid // Output valid signal +); + +// Internal enable signals for submodules +logic conv_enable, maxmin_enable, addmul_enable; + +// Internal result and FPCSR signals from submodules +logic [15:0] maxmin_result; +logic [31:0] conv_result, addmul_result; +logic [3:0] conv_fpcsr, maxmin_fpcsr, addmul_fpcsr; + + //Instantiate the conversion module +bf16_conversion bf16_fp32_conversion_inst ( + .clk(clk), + .reset(reset), + .enable(conv_enable), + .operation(operation), // Pass the operation code + .operand(operand_a), // Pass the operand + .result(conv_result), // Receive the result + .fpcsr(conv_fpcsr) // Receive the FPCSR status + ); + +// Instantiate the max/min module + bf16_minmax maxmin_module ( + .clk(clk), + .reset(reset), + .enable(maxmin_enable), + .operand_a(operand_a[15:0]), + .operand_b(operand_b[15:0]), + .operation(operation), + .result(maxmin_result), + .fpcsr(maxmin_fpcsr) +); + +// Instantiate the add/mul module +bf16_fma_op addmul_module ( + .clk(clk), + .reset(reset), + .en(addmul_enable), + .op_a(operand_a), + .op_b(operand_b), + .op_c(operand_c), + .oper(operation), + .result(addmul_result), + .fpcsr(addmul_fpcsr) +); + + + +// // Conversion Operations +// 4'b0000: conv_enable = 1; // BF16 to FP32 Conversion +// 4'b0001: conv_enable = 1; // FP32 to BF16 Conversion + +// // Max/Min Operations +// 4'b0010: maxmin_enable = 1; // Max +// 4'b0011: maxmin_enable = 1; // Min + +// // Add/Mul Operations +// 4'b0100: addmul_enable = 1; // Add +// 4'b0101: addmul_enable = 1; // Mul +// 4'b0110: addmul_enable = 1; // Sub +// 4'b0111: addmul_enable = 1; // Fused Multiply-Add (FMADD) +// 4'b1000: addmul_enable = 1; // Fused Multiply-Subtract (FMSUB) +// 4'b1001: addmul_enable = 1; // Fused Negative Multiply-Add (FMNADD) +// 4'b1010: addmul_enable = 1; // Fused Negative Multiply-Subtract (FMNSUB) + + + assign conv_enable = !operation[3] & !operation[2] & !operation[1]; + assign maxmin_enable = !operation[3] & !operation[2] & operation[1]; + assign addmul_enable = operation[3] | operation[2]; + assign result = ({32{conv_enable}} & conv_result) | ({32{maxmin_enable}} & maxmin_result) | ({32{addmul_enable}} & addmul_result); + assign fpcsr = ({32{conv_enable}} & conv_fpcsr) | ({32{maxmin_enable}} & maxmin_fpcsr) | ({32{addmul_enable}} & addmul_fpcsr); + assign valid = enable && (conv_enable || maxmin_enable || addmul_enable); +// Result and FPCSR aggregation +//always @(posedge clk) begin +// valid = enable && (conv_enable || maxmin_enable || addmul_enable); +// conv_enable = !operation[3] & !operation[2] & !operation[1]; +// maxmin_enable = !operation[3] & !operation[2] & operation[1]; +// addmul_enable = operation[3] | operation[2]; +// if (!reset) begin +// if (conv_enable) begin +// result = conv_result; +// fpcsr = conv_fpcsr; +// end else if (maxmin_enable) begin +// result = maxmin_result; +// fpcsr = maxmin_fpcsr; +// end else if (addmul_enable) begin +// result = addmul_result; +// fpcsr = addmul_fpcsr; +// end +// end else begin +// result = 32'h0; +// fpcsr = 4'h0; +// end +//end + + + +endmodule + diff --git a/rtl/vendor/bf16_acc/bf16_conv_top.sv b/rtl/vendor/bf16_acc/bf16_conv_top.sv new file mode 100644 index 000000000..09847e3d0 --- /dev/null +++ b/rtl/vendor/bf16_acc/bf16_conv_top.sv @@ -0,0 +1,108 @@ +`timescale 1ns / 1ps +////////////////////////////////////////////////////////////////////////////////// +// Company: +// Engineer: +// +// Create Date: 12/06/2023 06:19:26 PM +// Design Name: +// Module Name: conversion_bf16 +// Project Name: +// Target Devices: +// Tool Versions: +// Description: +// +// Dependencies: +// +// Revision: +// Revision 0.01 - File Created +// Additional Comments: +// +////////////////////////////////////////////////////////////////////////////////// + + +module bf16_conversion( + input logic clk, + input logic reset, + input logic enable, // Enable signal for conversion operations + input logic [3:0] operation, // 4-bit operation code + input logic [31:0] operand, // Universal operand for both BF16 and FP32 + output logic [31:0] result, // Result, either BF16 or FP32 + output logic [3:0] fpcsr // Floating Point Control and Status Register +); + + // Define operation codes for conversions + localparam BF16_TO_FP32_OP = 4'b0000; + localparam FP32_TO_BF16_OP = 4'b0001; + + wire bf16tofp32_en; + wire fp32tobf16_en; + + + + // Internal signals for the submodules + wire [15:0] bf16_result; + wire [31:0] fp32_result; + wire [3:0] bf16_fpcsr; + wire [3:0] fp32_fpcsr; + + assign bf16tofp32_en = enable & (operation == BF16_TO_FP32_OP); + assign fp32tobf16_en = enable & (operation == FP32_TO_BF16_OP); + + + // Instantiate bf16_to_fp32 module + bf16_to_fp32 bf16_to_fp32_inst ( + .clk(clk), + .reset(reset), + .instruction_enable(bf16tofp32_en), + .operand_a(operand[15:0]), + .result(fp32_result), + .fpcsr(bf16_fpcsr) + ); + + // Instantiate fp32_to_bf16 module + fp32_to_bf16 fp32_to_bf16_inst ( + .clk(clk), + .reset(reset), + .instruction_enable(fp32tobf16_en), + .operand_a(operand), + .result(bf16_result), + .fpcsr(fp32_fpcsr) + ); + + // Logic to select the appropriate output based on operation + + assign result = (operation == BF16_TO_FP32_OP) ? fp32_result : + (operation == FP32_TO_BF16_OP) ? {16'h0000, bf16_result} : + 32'h00000000; + + assign fpcsr = (operation == BF16_TO_FP32_OP) ? bf16_fpcsr : + (operation == FP32_TO_BF16_OP) ? fp32_fpcsr : + 4'b0000; + + + +// always @(posedge clk ) begin +// if(enable)begin + + +// case (operation) +// BF16_TO_FP32_OP: begin +// result = fp32_result; +// fpcsr = bf16_fpcsr; +// end +// FP32_TO_BF16_OP: begin +// result = {16'h0000, bf16_result}; // Zero-extend BF16 result to 32 bits +// fpcsr = fp32_fpcsr; +// end +// default: begin +// result = 32'h00000000; +// fpcsr = 4'b0000; +// end +// endcase +// end +// end + +endmodule + + + diff --git a/rtl/vendor/bf16_acc/bf16_fma.sv b/rtl/vendor/bf16_acc/bf16_fma.sv new file mode 100644 index 000000000..8d8b47671 --- /dev/null +++ b/rtl/vendor/bf16_acc/bf16_fma.sv @@ -0,0 +1,504 @@ +`timescale 1ns / 1ps +////////////////////////////////////////////////////////////////////////////////// +// Company: +// Engineer: +// +// Create Date: 01/22/2024 03:56:17 PM +// Design Name: +// Module Name: bf16_fma +// Project Name: +// Target Devices: +// Tool Versions: +// Description: +// +// Dependencies: +// +// Revision: +// Revision 0.01 - File Created +// Additional Comments: +// +////////////////////////////////////////////////////////////////////////////////// + + +`timescale 1ns / 1ps + + + +module bf16_fma_op( + input logic clk, + input logic reset, + input logic en, + input logic [15:0] op_a, + input logic [15:0] op_b, + input logic [15:0] op_c, + input logic [3:0] oper, // Operation code + output logic [15:0] result, + output logic [3:0] fpcsr +); + + // bfloat16 specifications + localparam EXP_BITS = 8; + localparam MAN_BITS = 7; + localparam TOTAL_MAN_BITS = 2 * MAN_BITS + 16 + 2; // Total bits for extended mantissa + localparam BIAS = 127; + + // Decompose operands + logic [EXP_BITS-1:0] exp_a, exp_b, exp_c; + logic [MAN_BITS:0] man_a, man_b; // Including implicit bit + logic [MAN_BITS-1:0] man_c; +// logic operand_a.sign, operand_b.sign, operand_c.sign; + logic effective_subtraction; + + + + // Product and Sum variables + logic [TOTAL_MAN_BITS-1:0] aligned_product_mantissa; // Extended product mantissa + logic [TOTAL_MAN_BITS-1:0] exp_aligned_product_mantissa; + logic [2*MAN_BITS+1:0] product_mantissa; + logic [TOTAL_MAN_BITS-1:0] aligned_addend_mantissa; + logic [TOTAL_MAN_BITS:0] sum_mantissa; // In case of overflow + logic [TOTAL_MAN_BITS+1:0] aligned_sum_mantissa; // For ground bit + logic [MAN_BITS:0] result_mantissa; // MSB for overflow + logic [EXP_BITS:0] product_exp, aligned_addend_exp, sum_exp; + logic product_sign, sum_sign; + logic [4:0] count; + logic [5:0] lzc_cnt; + logic [5:0] lzc_cnt_one; + logic [TOTAL_MAN_BITS+1:0] aligned_lzc_mantissa; + logic [EXP_BITS:0] sum_exp_lzc; + logic [TOTAL_MAN_BITS:0] current_mantissa; + logic [EXP_BITS:0] final_exp; + logic [EXP_BITS:0] final_final_exp; + logic final_sign; + logic [15:0] final_result_regular; + + //pipelined_variables + logic enable; + logic [3:0] operation; + logic [3:0] oper_one; + logic [TOTAL_MAN_BITS-1:0] aligned_product_mantissa_one; + logic [EXP_BITS:0] product_exp_one; + logic product_sign_one; + logic result_is_special_one; + logic [15:0] special_result_one; + logic invalid_one; + logic is_zero_c_one; + logic is_sub_c_one; + logic [TOTAL_MAN_BITS+1:0] aligned_sum_mantissa_one; + logic [EXP_BITS:0] sum_exp_one; + logic is_zero_c_two; + logic is_sub_c_two; + logic result_is_special_two; + logic [15:0] special_result_two; + logic [15:0] result_regular_one; + + + + + // Rounding variables + logic guard_bit, round_bit, sticky_bit; + reg [4:0] i; + + // Special case handling + logic is_nan_a, is_nan_b, is_nan_c; + logic is_sub_a, is_sub_b, is_sub_c; + logic is_inf_a, is_inf_b, is_inf_c; + logic is_zero_a, is_zero_b, is_zero_c; + + logic result_is_special; + logic [15:0] result_special; + logic [15:0] result_regular; + logic [15:0] result_o; + logic invalid; + logic overflow; + logic underflow; + logic inexact; + logic clk_one; + + + + //Type definition + typedef struct packed { + logic sign; + logic [EXP_BITS-1:0] exponent; + logic [MAN_BITS-1:0] mantissa; + } fp_t; + + + assign clk_one = clk & enable; + + fp_t operand_a, operand_b, operand_c; + fp_t oper_a, oper_b, oper_c; + //pipelined fp_t variables + fp_t operand_c_one; + + + always @(posedge clk_one or posedge reset) begin + + if(reset) begin + oper_a <= 0; + oper_b <= 0; + oper_c <= 0; + oper_one <= 0; + end + else begin + oper_a <= op_a; + oper_b <= op_b; + oper_c <= op_c; + oper_one <= oper; + enable <= en; +// +// operand_a = op_a; +// operand_b = op_b; +// operand_c = op_c; + + // exp_a = operand_a.exponent; + // exp_b = operand_b.exponent; + // exp_c = operand_c.exponent; + // man_a = {1'b1, operand_a.mantissa}; // Include implicit bit + // man_b = {1'b1, operand_b.mantissa}; + // man_c = operand_c.mantissa; + // i = 0; + // fpcsr = 0; + + // Adjust operands based on the operation + + + end + end + + always_comb begin + + operand_a = oper_a; + operand_b = oper_b; + operand_c = oper_c; + operation = oper_one; + + case (operation) + 4'b0111: ; // FMADD: Do nothing + 4'b1000: operand_c.sign = ~operand_c.sign; // FMSUB: Invert sign of operand C + 4'b1010: operand_a.sign = ~operand_a.sign; // FNMSUB: Invert sign of operand A + 4'b1001: begin // FNMADD: Invert sign of operands A and C + operand_a.sign = ~operand_a.sign; + operand_c.sign = ~operand_c.sign; + end + 4'b0100: begin // ADD: Set operand A to +1.0 + // exp_a = BIAS; + // man_a = {1'b1, {MAN_BITS{1'b0}}}; + // operand_a.sign = 1'b0; + operand_a = '{sign: 1'b0, exponent: BIAS, mantissa: '0}; + end + 4'b0110: begin // SUB: Set operand A to +1.0, invert sign of operand C + // exp_a = BIAS; + // man_a = {1'b1, {MAN_BITS{1'b0}}}; + // operand_a.sign = 1'b0; + operand_c.sign = ~operand_c.sign; + operand_a = '{sign: 1'b0, exponent: BIAS, mantissa: '0}; + end + 4'b0101: begin // MUL: Set operand C to +0.0 + // exp_c = 0; + // man_c = {MAN_BITS{1'b0}}; + // operand_c.sign = 1'b0; + operand_c = '{sign: 1'b0, exponent: '0, mantissa: '0}; + end + default: ; // Other operations: no change + endcase + + man_a = {1'b1, operand_a.mantissa}; // Include implicit bit + man_b = {1'b1, operand_b.mantissa}; + man_c = operand_c.mantissa; + + is_nan_a = operand_a.exponent == 8'hFF && man_a[6:0] != 0; + is_nan_b = operand_b.exponent == 8'hFF && man_b[6:0] != 0; + is_nan_c = operand_c.exponent == 8'hFF && man_c[6:0] != 0; + + is_inf_a = operand_a.exponent == 8'hFF && man_a[6:0] == 0; + is_inf_b = operand_b.exponent == 8'hFF && man_b[6:0] == 0; + is_inf_c = operand_c.exponent == 8'hFF && man_c[6:0] == 0; + + is_zero_a = operand_a.exponent == 0 && man_a[6:0] == 0; + is_zero_b = operand_b.exponent == 0 && man_b[6:0] == 0; + is_zero_c = operand_c.exponent == 0 && man_c[6:0] == 0; + + is_sub_a = operand_a.exponent == 0 && man_a[6:0] != 0; + is_sub_b = operand_b.exponent == 0 && man_b[6:0] != 0; + is_sub_c = operand_c.exponent == 0 && man_c[6:0] != 0; + + + end + + + // Detect special cases + always_comb begin + + //result_is_special = 1'b0; + invalid = 1'b0; + + // Determine if the result should be special (NaN or Infinity or irect jump to operand_c because of 0 operand value) + result_is_special = is_nan_a || is_nan_b || is_nan_c || + (is_inf_a && is_zero_b) || (is_inf_b && is_zero_a) || + ((is_inf_a || is_inf_b) && is_inf_c && effective_subtraction) || (is_inf_a || is_inf_b || is_inf_c) || (is_sub_a || is_sub_b || is_zero_a || is_zero_b) ; + if (is_nan_a || is_nan_b || is_nan_c || ((is_inf_a && is_zero_b) || (is_inf_b && is_zero_a) || ((is_inf_a || is_inf_b) && is_inf_c && effective_subtraction))) begin + result_special = 16'h7FC0; // Canonical qNaN for bfloat16 + invalid = 1; // NV flag for NaN or invalid operation + end else if (is_inf_a || is_inf_b || is_inf_c) begin + if (is_inf_a || is_inf_b) begin + // Result is infinity with the sign of the product + result_special = {operand_a.sign ^ operand_b.sign, 8'hFF, 7'b0}; // Infinity with sign of the product + end else if (is_inf_c) begin + // Result is infinity with the sign of the addend (= operand_c) + result_special = {operand_c.sign, 8'hFF, 7'b0}; // Infinity with sign of the addend + end + end + else if (is_sub_a || is_sub_b || is_zero_a || is_zero_b) + begin + result_special = operand_c; + end +// else begin +// result_is_special =1'b0; +// end + end + + + always_comb begin + // FMA computation logic + + if (operand_a == 16'h3f80) //operand_a is 1 + begin + aligned_product_mantissa = {man_b,{(TOTAL_MAN_BITS - MAN_BITS - 1){1'b0}}}; + product_exp = operand_b.exponent; + end + else if (operand_b == 16'h3f80) //operand_b is 1 + begin + aligned_product_mantissa = {man_a,{(TOTAL_MAN_BITS - MAN_BITS - 1){1'b0}}}; + product_exp = operand_a.exponent; + end + + else begin + // Calculate product of a and b with extended mantissa + product_mantissa = man_a * man_b; + product_exp = operand_a.exponent + operand_b.exponent - BIAS; + if (product_mantissa[2*MAN_BITS+1] == 1) begin + product_exp = product_exp + 1; + end + else begin + product_mantissa = product_mantissa << 1; + end + aligned_product_mantissa = {product_mantissa, 16'b0}; + end + //product_exp = exp_a + exp_b - BIAS; + product_sign = operand_a.sign ^ operand_b.sign; + + end + + always @(posedge clk_one) begin + + aligned_product_mantissa_one <= aligned_product_mantissa; + product_exp_one <= product_exp; + product_sign_one <= product_sign; + operand_c_one <= operand_c; + special_result_one <= result_special; + result_is_special_one <= result_is_special; + is_zero_c_one <= is_zero_c; + is_sub_c_one <= is_sub_c; + invalid_one <= invalid; + + + end + + + + always_comb begin + + if (is_zero_c_one || is_sub_c_one) begin + result_regular = {product_sign_one,product_exp_one[7:0],aligned_product_mantissa_one[TOTAL_MAN_BITS-1:TOTAL_MAN_BITS-7]}; + end + + + // Align addend (operand_c) with product + aligned_addend_exp = operand_c_one.exponent; + aligned_addend_mantissa = {1'b1, operand_c_one.mantissa, {(TOTAL_MAN_BITS - MAN_BITS - 1){1'b0}}}; // Extend addend mantissa + exp_aligned_product_mantissa = aligned_product_mantissa_one; + // Align addend exponent with product exponent + if (aligned_addend_exp < product_exp_one) begin + aligned_addend_mantissa = aligned_addend_mantissa >> (product_exp_one - aligned_addend_exp); + aligned_addend_exp = product_exp_one; + end else if (aligned_addend_exp > product_exp_one) begin + exp_aligned_product_mantissa = aligned_product_mantissa_one >> (aligned_addend_exp - product_exp_one); + //product_exp_one = aligned_addend_exp; + end + + // Determine if operation is effectively a subtraction + effective_subtraction = (product_sign_one != operand_c_one.sign); + + // Add/Subtract product and addend + if (effective_subtraction) begin + if (exp_aligned_product_mantissa >= aligned_addend_mantissa) begin + sum_mantissa = exp_aligned_product_mantissa - aligned_addend_mantissa; + sum_sign = product_sign_one; + end else begin + sum_mantissa = aligned_addend_mantissa - exp_aligned_product_mantissa; + sum_sign = operand_c_one.sign; + end + end else begin + sum_mantissa = exp_aligned_product_mantissa + aligned_addend_mantissa; + sum_sign = product_sign_one; + end + sum_exp = aligned_addend_exp; + aligned_sum_mantissa = {sum_mantissa, 1'b0}; + + // Normalize result + if (aligned_sum_mantissa[TOTAL_MAN_BITS+1] == 1'b1) begin + sum_exp = sum_exp + 1; + aligned_sum_mantissa = aligned_sum_mantissa >> 1; + end + else if (aligned_sum_mantissa == 0) begin + sum_exp = 0; + end + + end + + + always @(posedge clk_one) + begin + aligned_sum_mantissa_one <= aligned_sum_mantissa; + final_sign <= sum_sign; + //current_mantissa_one <= current_mantissa; + sum_exp_one <= sum_exp; + special_result_two <= special_result_one; + result_is_special_two <= result_is_special_one; + is_zero_c_two <= is_zero_c_one; + is_sub_c_two <= is_sub_c_one; + result_regular_one <= result_regular; + + + end + // assign aligned_lzc_mantissa = (!aligned_sum_mantissa_one[TOTAL_MAN_BITS]) ? (aligned_sum_mantissa_one << (lzc_cnt + 1)) : aligned_sum_mantissa_one; + //assign final_final_exp = (!aligned_sum_mantissa_one[TOTAL_MAN_BITS]) ? (sum_exp_one - (lzc_cnt + 1)) : sum_exp_one; + + clz #(TOTAL_MAN_BITS) u_clz( + .ref_vector(aligned_sum_mantissa_one[TOTAL_MAN_BITS-1:0]),.dout(lzc_cnt) + ); + + + + always @(lzc_cnt) begin + final_exp = sum_exp_one; + //current_mantissa = aligned_sum_mantissa[TOTAL_MAN_BITS:0]; + lzc_cnt_one = lzc_cnt + 1; + aligned_lzc_mantissa=aligned_sum_mantissa_one; + if (!aligned_sum_mantissa_one[TOTAL_MAN_BITS]) begin + aligned_lzc_mantissa = aligned_sum_mantissa_one << (lzc_cnt_one); + final_exp = sum_exp_one - (lzc_cnt_one); + end + + + + + //always_comb begin + + + //else begin +// while (i < TOTAL_MAN_BITS && sum_exp > 0 && !aligned_sum_mantissa[TOTAL_MAN_BITS]) begin +// aligned_sum_mantissa = aligned_sum_mantissa << 1; +// sum_exp = sum_exp - 1; +// i = i + 1; +// end + +// end + +// count = 0; + +// for (i = 0; i < TOTAL_MAN_BITS && sum_exp > 0 && !aligned_sum_mantissa[TOTAL_MAN_BITS]; i = i + 1) begin +// aligned_sum_mantissa = aligned_sum_mantissa << 1; +// sum_exp = sum_exp - 1; +// count = count + 1; +// end + + //final_exp = sum_exp_one; + final_result_regular = result_regular_one; + //end + //end + + + +// while (sum_mantissa[TOTAL_MAN_BITS - 1] == 0 && sum_exp > 0 ) begin +// sum_mantissa = sum_mantissa << 1; +// sum_exp = sum_exp - 1; +// end + + //always_comb begin + guard_bit = aligned_lzc_mantissa[TOTAL_MAN_BITS - 8]; + round_bit = aligned_lzc_mantissa[TOTAL_MAN_BITS - 9]; + sticky_bit = |aligned_lzc_mantissa[TOTAL_MAN_BITS - 10:0]; // OR all bits below round bit + result_mantissa = aligned_lzc_mantissa[TOTAL_MAN_BITS - 1:TOTAL_MAN_BITS - 7]; + if (guard_bit && (round_bit || sticky_bit)) begin + if (result_mantissa == 7'h7F) begin + final_exp = final_exp + 1; // Increment exponent due to mantissa overflow + result_mantissa = 0; // Reset mantissa to 0 because of overflow + end else begin + result_mantissa = result_mantissa + 1; // Increment mantissa + end + end +// // Round (Round to Nearest, ties to Even) +// round_bit = aligned_sum_mantissa[0]; +// sticky_bit = |aligned_sum_mantissa[TOTAL_MAN_BITS - 1:0]; // OR all bits below round bit +// result_mantissa = aligned_sum_mantissa[TOTAL_MAN_BITS - 1:TOTAL_MAN_BITS - 7]; +// if (round_bit && (aligned_sum_mantissa[MAN_BITS] || sticky_bit)) begin +// result_mantissa = result_mantissa + 1; +// if (result_mantissa[MAN_BITS]) begin +// sum_exp = sum_exp + 1; +// end +// end + + // Handle overflow and underflow + if (final_exp >= 2**EXP_BITS) begin + //fpcsr[2] + overflow = 1'b1; + final_result_regular = {final_sign, 8'hFF, 7'h00}; // Infinity + end else if (final_exp <= 0) begin + //fpcsr[1] + underflow = 1'b1; + inexact = 1'b1; + final_result_regular = {final_sign, 8'h00, 7'h00}; // Zero (subnormals flushed to zero) + end else begin + final_result_regular = {final_sign, final_exp[7:0], result_mantissa[6:0]}; + end + + // Set inexact flag if any of the lower bits were non-zero + //fpcsr[0] + inexact = guard_bit || round_bit || sticky_bit; + //invalid = 0; // No invalid operation in simple FMA + + end + + + + always_comb begin + if (result_is_special_two) begin + result_o = special_result_two; + end + else begin + result_o = final_result_regular; + end + end + + always @(posedge clk_one or posedge reset) begin + if(reset) begin + result <= 0; + fpcsr <= 0; + end + else begin + + result <= result_o; + fpcsr <= {invalid_one, overflow, underflow, inexact}; + + end + end + + + + + + +endmodule \ No newline at end of file diff --git a/rtl/vendor/bf16_acc/bf16_maxmin.sv b/rtl/vendor/bf16_acc/bf16_maxmin.sv new file mode 100644 index 000000000..fbc537331 --- /dev/null +++ b/rtl/vendor/bf16_acc/bf16_maxmin.sv @@ -0,0 +1,107 @@ +`timescale 1ns / 1ps +////////////////////////////////////////////////////////////////////////////////// +// Company: +// Engineer: +// +// Create Date: 12/02/2023 09:08:55 PM +// Design Name: +// Module Name: bf16minmax +// Project Name: +// Target Devices: +// Tool Versions: +// Description: +// +// Dependencies: +// +// Revision: +// Revision 0.01 - File Created +// Additional Comments: +// +////////////////////////////////////////////////////////////////////////////////// + +module bf16_minmax( + input logic clk, + input logic reset, + input logic enable, + input logic [15:0] operand_a, // BF16 operand A + input logic [15:0] operand_b, // BF16 operand B + input logic [3:0] operation, // Operation select: 0 for min, 1 for max + output logic [15:0] result, // BF16 result + output logic [3:0] fpcsr // Floating Point Control and Status Register +); + + // Decompose BF16 operands + logic operand_a_sign, operand_b_sign; + logic [7:0] operand_a_exp, operand_b_exp; + logic [6:0] operand_a_man, operand_b_man; + logic operand_a_nan, operand_b_nan; + logic numerical_comparison; + logic operand_a_smaller; + logic select_a; + + assign operand_a_sign = operand_a[15]; + assign operand_a_exp = operand_a[14:7]; + assign operand_a_man = operand_a[6:0]; + assign operand_b_sign = operand_b[15]; + assign operand_b_exp = operand_b[14:7]; + assign operand_b_man = operand_b[6:0]; + +// // Numerical (absolute) comparison +// assign numerical_comparison = operand_a < operand_b; + +// // Determine which operand is smaller or larger +// assign operand_a_smaller = numerical_comparison ^ (operand_a_sign || operand_b_sign); + +// // Operation: 0011 for min, 0010 for max +// assign select_a = (operation == 4'b0011) ? operand_a_smaller : !operand_a_smaller; + +// // Check for NaN +// assign operand_a_nan = (operand_a_exp == 8'hFF) && (operand_a_man != 0); +// assign operand_b_nan = (operand_b_exp == 8'hFF) && (operand_b_man != 0); + + always @(posedge clk ) begin + // Reset FPCSR flags + if (reset) begin + result = 16'b0; + fpcsr = 4'b0000; + end + else if (enable) begin + + // Check for NaN + operand_a_nan = (operand_a_exp == 8'hFF) && (operand_a_man != 0); + operand_b_nan = (operand_b_exp == 8'hFF) && (operand_b_man != 0); + + if (operand_a_nan && operand_b_nan) begin + // Both operands are NaN, return canonical NaN + result = 16'h7FC0; // Canonical NaN in BF16 + fpcsr[3] = 1; // Invalid flag set + end else if (operand_a_nan) begin + // Operand A is NaN, return B + result = operand_b; + fpcsr[3] = 1; // Invalid flag set + end else if (operand_b_nan) begin + // Operand B is NaN, return A + result = operand_a; + fpcsr[3] = 1; // Invalid flag set + end + end + numerical_comparison = operand_a < operand_b; + + // Determine which operand is smaller or larger + operand_a_smaller = numerical_comparison ^ (operand_a_sign || operand_b_sign); + + // Operation: 0011 for min, 0010 for max + select_a = (operation == 4'b0011) ? operand_a_smaller : !operand_a_smaller; + + + + if (select_a) begin + // Select A for min or max based on comparison + result = operand_a; + end else begin + // Select B for min or max based on comparison + result = operand_b; + end + + end +endmodule diff --git a/rtl/vendor/bf16_acc/bf16fp32conv.sv b/rtl/vendor/bf16_acc/bf16fp32conv.sv new file mode 100644 index 000000000..6c1a11944 --- /dev/null +++ b/rtl/vendor/bf16_acc/bf16fp32conv.sv @@ -0,0 +1,64 @@ +module bf16_to_fp32( + input logic clk, + input logic reset, + input logic instruction_enable, // Enable signal for specific instruction type + input logic [15:0] operand_a, // BF16 input + output logic [31:0] result, // FP32 output + output logic [3:0] fpcsr // Floating Point Control and Status Register +); + + // Internal variables + logic operand_a_sign; + logic [7:0] operand_a_exp; + logic [6:0] operand_a_man; + logic operand_a_inf, operand_a_zero, operand_a_nan; + + // Only execute logic if enabled for this instruction + always @(posedge clk) begin + if (reset) begin + result = 0; + fpcsr = 0; + end else begin + // Decompose operand + operand_a_sign = operand_a[15]; + operand_a_exp = operand_a[14:7]; + operand_a_man = operand_a[6:0]; + + // Special case flags + operand_a_inf = (operand_a_exp == 8'hFF) && (operand_a_man == 0); + operand_a_zero = (operand_a_exp == 0) && (operand_a_man == 0); + operand_a_nan = (operand_a_exp == 8'hFF) && (operand_a_man != 0); + + // Handle special cases and conversion + if (operand_a_inf) begin + result = {operand_a_sign, 8'hFF, 23'h000000}; // Infinity + end else if (operand_a_zero) begin + result = {operand_a_sign, 8'h00, 23'h000000}; // Zero + end else if (operand_a_nan) begin + result = {1'b0, 8'hFF, {1'b1, 22'h00000}}; // NaN + end else begin + result = convert_to_fp32(operand_a_sign, operand_a_exp, operand_a_man); + end + + // Update fpcsr + fpcsr[3] = operand_a_nan; // Invalid operation flag + fpcsr[2] = 0; // Overflow flag + fpcsr[1] = 0; // Underflow flag + fpcsr[0] = 0; // Inexact flag + end + end + + function automatic [31:0] convert_to_fp32( + input logic sign, + input logic [7:0] exp, + input logic [6:0] man + ); + logic [7:0] new_exp; + logic [22:0] new_man; + + new_exp = exp; // Directly use exponent from BF16 + new_man = {man, 16'h0000}; // Zero-extend mantissa from BF16 to FP32 + + convert_to_fp32 = {sign, new_exp, new_man}; // Assemble FP32 number + endfunction +endmodule \ No newline at end of file diff --git a/rtl/vendor/bf16_acc/fp32bf16conv.sv b/rtl/vendor/bf16_acc/fp32bf16conv.sv new file mode 100644 index 000000000..597594321 --- /dev/null +++ b/rtl/vendor/bf16_acc/fp32bf16conv.sv @@ -0,0 +1,97 @@ +module fp32_to_bf16( + input logic clk, + input logic reset, + input logic instruction_enable, // Enable signal for specific instruction type + input logic [31:0] operand_a, // FP32 input + output logic [15:0] result, // BF16 output + output logic [3:0] fpcsr // Floating Point Control and Status Register +); + + // Internal variables + logic operand_a_sign; + logic [8:0] operand_a_exp; + logic [22:0] operand_a_man; + logic operand_a_inf, operand_a_zero, operand_a_nan, operand_a_subnormal; + + // Only execute logic if enabled for this instruction + always @(posedge clk) begin + if (reset) begin + result = 0; + fpcsr = 0; + end else begin + // Decompose operand + operand_a_sign = operand_a[31]; + operand_a_exp = {1'b0, operand_a[30:23]}; + operand_a_man = operand_a[22:0]; + + // Special case flags + operand_a_inf = (operand_a_exp == 9'h0FF) && (operand_a_man == 0); + operand_a_zero = (operand_a_exp == 0) && (operand_a_man == 0); + operand_a_nan = (operand_a_exp == 9'h0FF) && (operand_a_man != 0); + operand_a_subnormal = (operand_a_exp == 0) && (operand_a_man != 0); + + // Handle special cases and conversion + if (operand_a_inf) begin + result = {operand_a_sign, 8'hFF, 7'h00}; // Infinity + //fpcsr[2] <= 1; // Set overflow flag + end else if (operand_a_zero) begin + result = {operand_a_sign, 8'h00, 7'h00}; // Zero + end else if (operand_a_nan) begin + result = {1'b0, 8'hFF, 7'hc0}; // NaN + end else begin + result = convert_to_bf16(operand_a_sign, operand_a_exp, operand_a_man); + end + + // Update fpcsr + fpcsr[3] = operand_a_nan; // Invalid operation flag + fpcsr[2] = 0; // Overflow flag + fpcsr[1] = 0; // Underflow flag + fpcsr[0] = 0; // Inexact flag + end + end + + function automatic [15:0] convert_to_bf16( + input logic sign, + input logic [9:0] exp, + input logic [22:0] man + ); + logic [9:0] new_exp; + logic [6:0] new_man; + logic rounding_bit, sticky_bit, guard_bit, round_up; + + // Alignment for BF16 mantissa (truncate 16 LSBs, keep guard bit) + guard_bit = man[15]; + rounding_bit = man[14]; + sticky_bit = |man[13:0]; // OR all truncated bits for sticky bit + + // Check for rounding using RNE + round_up = guard_bit & (sticky_bit | rounding_bit | new_man[0]); + + new_exp = exp; // Adjust exponent from FP32 to BF16 bias + new_man = man[22:16]; // Truncate mantissa to BF16 precision + + // Apply rounding + if (round_up) begin + // Check for overflow in mantissa before incrementing + if (new_man == 7'h7F) begin + new_exp = new_exp + 1; // Increment exponent due to mantissa overflow + new_man = 0; // Reset mantissa to 0 because of overflow + end else begin + new_man = new_man + 1; // Increment mantissa + end + end + + // Check for exponent overflow or underflow + if (new_exp >= 9'h0FF) begin + new_exp = 9'h0FF; // Cap at largest normal value + end else if (new_exp <= 0) begin + new_exp = 0; // Subnormals and zero + new_man = 0; + end + + // Set inexact if any LSBs are truncated or guard, round, sticky bits are set + fpcsr[0] <= guard_bit | rounding_bit | sticky_bit; + + convert_to_bf16 = {sign, new_exp[7:0], new_man[6:0]}; // Assemble BF16 number + endfunction +endmodule diff --git a/rtl/vendor/bf16_acc/lzc.sv b/rtl/vendor/bf16_acc/lzc.sv new file mode 100644 index 000000000..b4ff0c9eb --- /dev/null +++ b/rtl/vendor/bf16_acc/lzc.sv @@ -0,0 +1,52 @@ +`timescale 1ns / 1ps +////////////////////////////////////////////////////////////////////////////////// +// Company: +// Engineer: +// +// Create Date: 01/29/2024 04:52:37 PM +// Design Name: +// Module Name: lzc +// Project Name: +// Target Devices: +// Tool Versions: +// Description: +// +// Dependencies: +// +// Revision: +// Revision 0.01 - File Created +// Additional Comments: +// +////////////////////////////////////////////////////////////////////////////////// + + +module clz (ref_vector, dout); + parameter REF_VECTOR_WIDTH=32; + localparam DOUT_WIDTH = $clog2(REF_VECTOR_WIDTH)+1; + localparam DOUT_LR_WIDTH = DOUT_WIDTH-1; + input [REF_VECTOR_WIDTH-1:0] ref_vector; + + output [DOUT_WIDTH-1:0] dout; + + wire [DOUT_LR_WIDTH-1:0] dout_r; + wire [DOUT_LR_WIDTH-1:0] dout_l; + + wire [REF_VECTOR_WIDTH/2-1:0] ref_vector_r; + wire [REF_VECTOR_WIDTH/2-1:0] ref_vector_l; + + generate + if (REF_VECTOR_WIDTH == 2) + assign dout = (ref_vector == 2'b00) ? 'd2 : + (ref_vector == 2'b01) ? 'd1 : 0; + else begin + assign ref_vector_l = ref_vector[REF_VECTOR_WIDTH-1:REF_VECTOR_WIDTH/2]; + assign ref_vector_r = ref_vector[REF_VECTOR_WIDTH/2-1:0]; + clz #(REF_VECTOR_WIDTH/2) u_nv_clz_l(ref_vector_l, dout_l); + clz #(REF_VECTOR_WIDTH/2) u_nv_clz_r(ref_vector_r, dout_r); + assign dout = (~dout_l[DOUT_LR_WIDTH-1]) ? {dout_l [DOUT_LR_WIDTH-1] & dout_r [DOUT_LR_WIDTH-1], 1'b0 , dout_l[DOUT_LR_WIDTH-2:0]} : + {dout_l [DOUT_LR_WIDTH-1] & dout_r [DOUT_LR_WIDTH-1], ~dout_r[DOUT_LR_WIDTH-1], dout_r[DOUT_LR_WIDTH-2:0]}; + end + + endgenerate + +endmodule \ No newline at end of file