Skip to content

Commit

Permalink
Merge pull request #34 from jargh/main
Browse files Browse the repository at this point in the history
Improvements to ARM small Karatsuba muls
  • Loading branch information
jargh authored Jun 10, 2022
2 parents 2530c31 + ac6a59a commit 5162347
Show file tree
Hide file tree
Showing 26 changed files with 1,992 additions and 1,127 deletions.
339 changes: 196 additions & 143 deletions arm/curve25519/bignum_mul_p25519.S
Original file line number Diff line number Diff line change
Expand Up @@ -29,149 +29,208 @@
.text
.balign 4

// ---------------------------------------------------------------------------
// Macro computing [c,b,a] := [b,a] + (x - y) * (w - z), adding with carry
// to the [b,a] components but leaving CF aligned with the c term, which is
// a sign bitmask for (x - y) * (w - z). Continued add-with-carry operations
// with [c,...,c] will continue the carry chain correctly starting from
// the c position if desired to add to a longer term of the form [...,b,a].
//
// c,h,l,t should all be different and t,h should not overlap w,z.
// ---------------------------------------------------------------------------

#define muldiffnadd(b,a,x,y,w,z) \
subs t, x, y; \
cneg t, t, cc; \
csetm c, cc; \
subs h, w, z; \
cneg h, h, cc; \
mul l, t, h; \
umulh h, t, h; \
cinv c, c, cc; \
adds xzr, c, #1; \
eor l, l, c; \
adcs a, a, l; \
eor h, h, c; \
adcs b, b, h
#define z x0
#define x x1
#define y x2

#define a0 x3
#define a1 x4
#define a2 x5
#define a3 x6
#define b0 x7
#define b1 x8
#define b2 x9
#define b3 x10

#define s0 x11
#define s1 x12
#define s2 x13
#define s3 x14
#define s4 x15

#define m x15
#define q x15

#define t0 x11
#define t1 x16
#define t2 x12
#define t3 x13
#define t4 x14
#define t5 x15

#define u0 x11
#define u1 x16
#define u2 x1
#define u3 x2
#define u4 x12
#define u5 x13
#define u6 x14
#define u7 x15

#define c x17
#define h x19
#define l x20
#define t x21
#define d x21
#define b0 x5
#define b1 x6

#define u0 x7
#define u1 x8
#define u2 x9
#define u3 x10
#define u4 x11
#define u5 x12
#define u6 x13
#define u7 x14

#define t x15

#define sgn x16
#define ysgn x17

// These are aliases to registers used elsewhere including input pointers.
// By the time they are used this does not conflict with other uses.

#define m0 y
#define m1 ysgn
#define m2 t
#define m3 x
#define u u2

// For the reduction stages, again aliasing other things but not the u's

#define c x3
#define h x4
#define l x5
#define d x6
#define q x17

S2N_BN_SYMBOL(bignum_mul_p25519):

// Save additional registers to use

stp x19, x20, [sp, #-16]!
stp x21, x22, [sp, #-16]!

// Load operands

ldp a0, a1, [x1]
ldp b0, b1, [x2]
ldp a2, a3, [x1, #16]
ldp b2, b3, [x2, #16]

// First accumulate all the "simple" products as [s4,s3,s2,s1,s0]

mul s0, a0, b0
mul s1, a1, b1
mul s2, a2, b2
mul s3, a3, b3

umulh m, a0, b0
adds s1, s1, m
umulh m, a1, b1
adcs s2, s2, m
umulh m, a2, b2
adcs s3, s3, m
umulh m, a3, b3
adc s4, m, xzr

// Multiply by B + 1 to get [t5;t4;t3;t2;t1;t0] where t0 == s0

adds t1, s1, s0
adcs t2, s2, s1
adcs t3, s3, s2
adcs t4, s4, s3
adc t5, xzr, s4

// Multiply by B^2 + 1 to get [u6;u5;u4;u3;u2;u1;-]. Note that
// u0 == t0 == s0 and u1 == t1

adds u2, t2, t0
adcs u3, t3, t1
adcs u4, t4, t2
adcs u5, t5, t3
adcs u6, xzr, t4
adc u7, xzr, t5

// Now add in all the "complicated" terms.

muldiffnadd(u6,u5, a2,a3, b3,b2)
adc u7, u7, c

muldiffnadd(u2,u1, a0,a1, b1,b0)
adcs u3, u3, c
adcs u4, u4, c
adcs u5, u5, c
adcs u6, u6, c
adc u7, u7, c

muldiffnadd(u5,u4, a1,a3, b3,b1)
adcs u6, u6, c
adc u7, u7, c

muldiffnadd(u3,u2, a0,a2, b2,b0)
adcs u4, u4, c
adcs u5, u5, c
adcs u6, u6, c
adc u7, u7, c

muldiffnadd(u4,u3, a0,a3, b3,b0)
adcs u5, u5, c
adcs u6, u6, c
adc u7, u7, c
muldiffnadd(u4,u3, a1,a2, b2,b1)
adcs u5, u5, c
adcs u6, u6, c
adc u7, u7, c
// Multiply the low halves using Karatsuba 2x2->4 to get [u3,u2,u1,u0]

ldp a0, a1, [x]
ldp b0, b1, [y]

mul u0, a0, b0
umulh u1, a0, b0
mul u2, a1, b1
umulh u3, a1, b1

subs a1, a1, a0
cneg a1, a1, cc
csetm sgn, cc

adds u2, u2, u1
adc u3, u3, xzr

subs a0, b0, b1
cneg a0, a0, cc
cinv sgn, sgn, cc

mul t, a1, a0
umulh a0, a1, a0

adds u1, u0, u2
adcs u2, u2, u3
adc u3, u3, xzr

adds xzr, sgn, #1
eor t, t, sgn
adcs u1, t, u1
eor a0, a0, sgn
adcs u2, a0, u2
adc u3, u3, sgn

// Multiply the high halves using Karatsuba 2x2->4 to get [u7,u6,u5,u4]

ldp a0, a1, [x, #16]
ldp b0, b1, [y, #16]

mul u4, a0, b0
umulh u5, a0, b0
mul u6, a1, b1
umulh u7, a1, b1

subs a1, a1, a0
cneg a1, a1, cc
csetm sgn, cc

adds u6, u6, u5
adc u7, u7, xzr

subs a0, b0, b1
cneg a0, a0, cc
cinv sgn, sgn, cc

mul t, a1, a0
umulh a0, a1, a0

adds u5, u4, u6
adcs u6, u6, u7
adc u7, u7, xzr

adds xzr, sgn, #1
eor t, t, sgn
adcs u5, t, u5
eor a0, a0, sgn
adcs u6, a0, u6
adc u7, u7, sgn

// Compute sgn,[a1,a0] = x_hi - x_lo
// and ysgn,[b1,b0] = y_lo - y_hi
// sign-magnitude differences

ldp a0, a1, [x, #16]
ldp t, sgn, [x]
subs a0, a0, t
sbcs a1, a1, sgn
csetm sgn, cc

ldp t, ysgn, [y]
subs b0, t, b0
sbcs b1, ysgn, b1
csetm ysgn, cc

eor a0, a0, sgn
subs a0, a0, sgn
eor a1, a1, sgn
sbc a1, a1, sgn

eor b0, b0, ysgn
subs b0, b0, ysgn
eor b1, b1, ysgn
sbc b1, b1, ysgn

// Save the correct sign for the sub-product

eor sgn, ysgn, sgn

// Add H' = H + L_top, still in [u7,u6,u5,u4]

adds u4, u4, u2
adcs u5, u5, u3
adcs u6, u6, xzr
adc u7, u7, xzr

// Now compute the mid-product as [m3,m2,m1,m0]

mul m0, a0, b0
umulh m1, a0, b0
mul m2, a1, b1
umulh m3, a1, b1

subs a1, a1, a0
cneg a1, a1, cc
csetm u, cc

adds m2, m2, m1
adc m3, m3, xzr

subs b1, b0, b1
cneg b1, b1, cc
cinv u, u, cc

mul b0, a1, b1
umulh b1, a1, b1

adds m1, m0, m2
adcs m2, m2, m3
adc m3, m3, xzr

adds xzr, u, #1
eor b0, b0, u
adcs m1, b0, m1
eor b1, b1, u
adcs m2, b1, m2
adc m3, m3, u

// Accumulate the positive mid-terms as [u7,u6,u5,u4,u3,u2]

adds u2, u4, u0
adcs u3, u5, u1
adcs u4, u6, u4
adcs u5, u7, u5
adcs u6, u6, xzr
adc u7, u7, xzr

// Add in the sign-adjusted complex term

adds xzr, sgn, #1
eor m0, m0, sgn
adcs u2, m0, u2
eor m1, m1, sgn
adcs u3, m1, u3
eor m2, m2, sgn
adcs u4, m2, u4
eor m3, m3, sgn
adcs u5, m3, u5
adcs u6, u6, sgn
adc u7, u7, sgn

// Now we have the full 8-digit product 2^256 * h + l where
// h = [u7,u6,u5,u4] and l = [u3,u2,u1,u0]
Expand Down Expand Up @@ -249,12 +308,6 @@ S2N_BN_SYMBOL(bignum_mul_p25519):

stp u0, u1, [x0]
stp u2, u3, [x0, #16]

// Restore regs and return

ldp x21, x22, [sp], #16
ldp x19, x20, [sp], #16

ret

#if defined(__linux__) && defined(__ELF__)
Expand Down
Loading

0 comments on commit 5162347

Please sign in to comment.