Skip to content

Commit

Permalink
Modify shift to fixed 1, 4, 8, and 16 shifts only (#91)
Browse files Browse the repository at this point in the history
* hw: add shift modifications

* hw: update alu pe

* test: add hv shift
  • Loading branch information
rgantonio authored Oct 21, 2024
1 parent 3648c90 commit bc4a8ab
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
17 changes: 15 additions & 2 deletions rtl/encoder/hv_alu_pe.sv
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module hv_alu_pe #(
parameter int unsigned HVDimension = 512,
parameter int unsigned NumOps = 4,
parameter int unsigned NumOpsWidth = $clog2(NumOps),
parameter int unsigned MaxShiftAmt = 128,
parameter int unsigned MaxShiftAmt = 4,
parameter int unsigned ShiftWidth = $clog2(MaxShiftAmt)
)(
// Inputs
Expand All @@ -26,11 +26,24 @@ module hv_alu_pe #(

//---------------------------
// Logic for shifting
//
// Shifting modes:
// 0: 1 shift
// 1: 4 shift
// 2: 8 shift
// 3: 16 shift
//---------------------------
logic [HVDimension-1:0] circular_shift_res;

// Doing selections through wire slicing
// rather than actual shifts to optimize synthesis
always_comb begin
circular_shift_res = (A_i >> shift_amt_i) | (A_i << (HVDimension - shift_amt_i));
case (shift_amt_i)
default: circular_shift_res = {A_i[ 0], A_i[ HVDimension-1:1]};
1: circular_shift_res = {A_i[ 3:0], A_i[ HVDimension-1:4]};
2: circular_shift_res = {A_i[ 7:0], A_i[ HVDimension-1:8]};
3: circular_shift_res = {A_i[15:0], A_i[HVDimension-1:16]};
endcase
end

//---------------------------
Expand Down
27 changes: 22 additions & 5 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def gen_randint(max_val):

# For the ALU output
def hv_alu_out(hv_a, hv_b, shift_amt, hv_dim, op):
mask_val = 2**hv_dim - 1
2**hv_dim - 1

if op == 1:
result = hv_a
Expand All @@ -205,10 +205,11 @@ def hv_alu_out(hv_a, hv_b, shift_amt, hv_dim, op):
elif op == 3:
# Workaround because github CI fails
# At shifting more than 64 bits
if isinstance(hv_a, np.ndarray):
result = np.roll(hv_a, shift_amt)
else:
result = (hv_a >> shift_amt) | (hv_a << (hv_dim - shift_amt)) & mask_val
# if isinstance(hv_a, np.ndarray):
# result = np.roll(hv_a, shift_amt)
# else:
# result = (hv_a >> shift_amt) | (hv_a << (hv_dim - shift_amt)) & mask_val
result = shift_hv(hv_a, shift_amt, hv_dim)
else:
result = hv_a ^ hv_b
return result
Expand Down Expand Up @@ -245,6 +246,22 @@ def hvlist2num(hv_list):
return hv_num


# Shift modes based on shift values
def shift_hv(hv_a, shift_amt, hv_dim):
mask_val = 2**hv_dim - 1

if shift_amt == 0:
result = (hv_a >> 1) | (hv_a << (hv_dim - 1)) & mask_val
elif shift_amt == 1:
result = (hv_a >> 4) | (hv_a << (hv_dim - 4)) & mask_val
elif shift_amt == 2:
result = (hv_a >> 8) | (hv_a << (hv_dim - 8)) & mask_val
else:
result = (hv_a >> 16) | (hv_a << (hv_dim - 16)) & mask_val

return result


"""
Set of functions for the encoding module
"""
Expand Down

0 comments on commit bc4a8ab

Please sign in to comment.