From 4bd07f81a91b5548ded4610b0ce20b6e08108d07 Mon Sep 17 00:00:00 2001 From: Pierre Seznec Date: Fri, 16 Jun 2023 10:49:32 +0200 Subject: [PATCH] add u256 multiplication --- assembly/__tests__/u256.spec.ts | 71 +++++++++++++++++++++++++++++++++ assembly/globals.ts | 66 +++++++++++++++++++++++++++++- assembly/integer/u256.ts | 32 +++++++++++++++ 3 files changed, 167 insertions(+), 2 deletions(-) diff --git a/assembly/__tests__/u256.spec.ts b/assembly/__tests__/u256.spec.ts index 346b41c..9d33f6d 100644 --- a/assembly/__tests__/u256.spec.ts +++ b/assembly/__tests__/u256.spec.ts @@ -1,3 +1,4 @@ +import { u128 } from '../integer'; import { u256 } from '../integer/u256'; describe("String Conversion", () => { @@ -348,4 +349,74 @@ describe("Basic Operations", () => { var r = new u256(3, 4, 6, 8); expect(a - b).toStrictEqual(r); }); + + it("Should multiply two u256 numbers", () => { + var a = u256.from(3); + var b = u256.from(3); + var r = u256.from(9); + + expect(a * b).toStrictEqual(r); + expect(b * a).toStrictEqual(r); + }); + + it("Should multiply two u256 numbers", () => { + var a = u256.from(43545453452); + var b = u256.from(2353454354); + var r = new u256(10248516654965971928, 5); + + expect(a * b).toStrictEqual(r); + expect(b * a).toStrictEqual(r); + }); + + it("Should multiply two random u256 numbers", () => { + let i = 0; + while(i< 20) { + const randa = Math.trunc(Math.random() * 0xFFFF) + const a = u256.from(randa); + const randb = Math.trunc(Math.random() * 0xFFFF) + const b = u256.from(randb); + const r = new u256(randa * randb); + + expect(a * b).toStrictEqual(r); + expect(b * a).toStrictEqual(r); + i++; + } + }); + + + it("Should multiply two u256 numbers - 2", () => { + var a = u256.from(new u128(14083847773837265618, 6692605942)); + var b = u256.from(new u128(18444665141527514289, 5354084802)); + var r = new u256(5659639222556316466, 4474720309748468391, 17386035696907167262, 1); + + expect(a * b).toStrictEqual(r); + expect(b * a).toStrictEqual(r); + }); + + it("Should multiply u256 numbers by 1", () => { + var a = u256.Max; + var b = u256.One; + var r = a; + expect(a * b).toStrictEqual(r); + expect(b * a).toStrictEqual(r); + }); + + it("Should multiply u256 numbers by 0", () => { + var a = new u256(5656466, 447478468391, 17386907167262, 1); + var b = u256.Zero; + var r = b; + expect(a * b).toStrictEqual(r); + expect(b * a).toStrictEqual(r); + }); + + it("Should multiply two u256 numbers with overflow", () => { + var a = new u256(0, 0, 1); + expect(a * a).toStrictEqual(u256.Zero); + }); + + it("Should multiply two u256 numbers with overflow - 2", () => { + var a = new u256(1, 0, 1); + expect(a * a).toStrictEqual(new u256(1, 0, 2)); + }); + }); diff --git a/assembly/globals.ts b/assembly/globals.ts index d2cb68c..f0ade75 100644 --- a/assembly/globals.ts +++ b/assembly/globals.ts @@ -1,3 +1,4 @@ +import { u256 } from './integer'; import { u128 } from './integer/u128'; // used for returning quotient and reminder from __divmod128 @@ -7,6 +8,8 @@ import { u128 } from './integer/u128'; // used for returning low and high part of __mulq64, __multi3 etc @lazy export var __res128_hi: u64 = 0; +// used for returning 0 or 1 +@lazy export var __carry: u64 = 0; /** * Convert 128-bit unsigned integer to 64-bit float @@ -75,14 +78,73 @@ export function __umulq64(a: u64, b: u64): u64 { var uv = u * v; var w0 = uv & 0xFFFFFFFF; - uv = a * v + (uv >> 32); + uv = a * v + (uv >> 32); var w1 = uv >> 32; - uv = u * b + (uv & 0xFFFFFFFF); + uv = u * b + (uv & 0xFFFFFFFF); __res128_hi = a * b + w1 + (uv >> 32); return (uv << 32) | w0; } +// __umul64Hop computes (hi * 2^64 + lo) = z + (x * y) +// @ts-ignore: decorator +@inline +export function __umul64Hop(z: u64, x: u64, y: u64): u64 { + var lo = __umulq64(x, y); + lo = __uadd64(lo, z); + var hi = __res128_hi +__carry; + __res128_hi = hi; + return lo +} + +// __umul64Step computes (hi * 2^64 + lo) = z + (x * y) + carry. +// @ts-ignore: decorator +@inline +export function __umul64Step(z: u64, x: u64, y: u64, carry: u64): u64 { + var lo = __umulq64(x, y) + lo = __uadd64(lo, carry); + var hi = __uadd64(__res128_hi, 0, __carry); + lo = __uadd64(lo, z); + hi += __carry; + __res128_hi = hi; + return lo +} + +// __uadd64 returns the sum with carry of x, y and carry: sum = x + y + carry. +// The carry input must be 0 or 1; otherwise the behavior is undefined. +// The carryOut output is guaranteed to be 0 or 1. +// @ts-ignore: decorator +@inline +export function __uadd64(x: u64, y: u64, carry: u64 = 0): u64 { + var sum = x + y + carry + // // The sum will overflow if both top bits are set (x & y) or if one of them + // // is (x | y), and a carry from the lower place happened. If such a carry + // // happens, the top bit will be 1 + 0 + 1 = 0 (& ~sum). + __carry = ((x & y) | ((x | y) & ~sum)) >>> 63 + return sum; + +} + +// @ts-ignore: decorator +@global +export function __mul256(x0: u64, x1: u64, x2: u64, x3: u64, y0: u64, y1: u64, y2: u64, y3: u64): u256 { + var lo1 = __umulq64(x0, y0); + var res1 = __umul64Hop(__res128_hi, x1, y0); + var res2 = __umul64Hop(__res128_hi, x2, y0); + var res3 = x3 * y0 + __res128_hi; + + var lo2 = __umul64Hop(res1, x0, y1); + res2 = __umul64Step(res2, x1, y1, __res128_hi); + res3 += x2 * y1 + __res128_hi; + + var hi1 = __umul64Hop(res2, x0, y2); + res3 += x1 * y2 + __res128_hi + + var hi2 = __umul64Hop(res3, x0, y3); + + return new u256(lo1, lo2, hi1, hi2); +} + // @ts-ignore: decorator @global export function __multi3(al: u64, ah: u64, bl: u64, bh: u64): u64 { diff --git a/assembly/integer/u256.ts b/assembly/integer/u256.ts index 4c2a0ac..7126df4 100644 --- a/assembly/integer/u256.ts +++ b/assembly/integer/u256.ts @@ -1,6 +1,7 @@ import { i128 } from './i128'; import { u128 } from './u128'; import { u256toDecimalString } from "../utils"; +import { __mul256 } from '../globals'; @lazy const HEX_CHARS = '0123456789abcdef'; @@ -141,6 +142,31 @@ export class u256 { return new u256(value, mask, mask, mask); } + /** + * Create 256-bit unsigned integer from generic type T + * @param value + * @returns 256-bit unsigned integer + */ + @inline + static from(value: T): u256 { + if (value instanceof bool) return u256.fromU64(value); + else if (value instanceof i8) return u256.fromI64(value); + else if (value instanceof u8) return u256.fromU64(value); + else if (value instanceof i16) return u256.fromI64(value); + else if (value instanceof u16) return u256.fromU64(value); + else if (value instanceof i32) return u256.fromI64(value); + else if (value instanceof u32) return u256.fromU64(value); + else if (value instanceof i64) return u256.fromI64(value); + else if (value instanceof u64) return u256.fromU64(value); + else if (value instanceof f32) return u256.fromF64(value); + else if (value instanceof f64) return u256.fromF64(value); + else if (value instanceof u128) return u256.fromU128(value); + else if (value instanceof u256) return u256.fromU256(value); + else if (value instanceof u8[]) return u256.fromBytes(value); + else if (value instanceof Uint8Array) return u256.fromBytes(value); + else throw new TypeError("Unsupported generic type"); + } + // TODO // static fromString(str: string): u256 @@ -431,6 +457,12 @@ export class u256 { return !u256.lt(a, b); } + // mul: u256 x u256 = u256 + @inline @operator('*') + static mul(a: u256, b: u256): u256 { + return __mul256(a.lo1, a.lo2, a.hi1, a.hi2, b.lo1, b.lo2, b.hi1, b.hi2) + } + @inline static popcnt(value: u256): i32 { var count = popcnt(value.lo1);