diff --git a/base58.go b/base58.go index e1b3467..193a87d 100644 --- a/base58.go +++ b/base58.go @@ -94,34 +94,21 @@ func FastBase58DecodingAlphabet(str string, alphabet *Alphabet) ([]byte, error) return nil, fmt.Errorf("zero length string") } - var ( - t, c uint64 - zmask uint32 - zcount int - - b58u = []rune(str) - b58sz = len(b58u) - - outisz = (b58sz + 3) >> 2 - binu = make([]byte, (b58sz+3)*3) - bytesleft = b58sz & 3 - - zero = rune(alphabet.encode[0]) - ) + zero := alphabet.encode[0] + b58sz := len(str) - if bytesleft > 0 { - zmask = 0xffffffff << uint32(bytesleft*8) - } else { - bytesleft = 4 + var zcount int + for i := 0; i < b58sz && str[i] == zero; i++ { + zcount++ } - var outi = make([]uint32, outisz) + var t, c uint64 - for i := 0; i < b58sz && b58u[i] == zero; i++ { - zcount++ - } + // the 32bit algo stretches the result up to 2 times + binu := make([]byte, 2*((b58sz*406/555)+1)) + outi := make([]uint32, (b58sz+3)/4) - for _, r := range b58u { + for _, r := range str { if r > 127 { return nil, fmt.Errorf("high-bit set on invalid digit") } @@ -131,39 +118,37 @@ func FastBase58DecodingAlphabet(str string, alphabet *Alphabet) ([]byte, error) c = uint64(alphabet.decode[r]) - for j := outisz - 1; j >= 0; j-- { + for j := len(outi) - 1; j >= 0; j-- { t = uint64(outi[j])*58 + c - c = (t >> 32) & 0x3f + c = t >> 32 outi[j] = uint32(t & 0xffffffff) } - - if c > 0 { - return nil, fmt.Errorf("output number too big (carry to the next int32)") - } - - if outi[0]&zmask != 0 { - return nil, fmt.Errorf("output number too big (last int32 filled too far)") - } } - var j, cnt int - for j, cnt = 0, 0; j < outisz; j++ { - for mask := byte(bytesleft-1) * 8; mask <= 0x18; mask, cnt = mask-8, cnt+1 { - binu[cnt] = byte(outi[j] >> mask) - } - if j == 0 { - bytesleft = 4 // because it could be less than 4 the first time through + // initial mask depends on b58sz, on further loops it always starts at 24 bits + mask := (uint(b58sz%4) * 8) + if mask == 0 { + mask = 32 + } + mask -= 8 + + outLen := 0 + for j := 0; j < len(outi); j++ { + for mask < 32 { // loop relies on uint overflow + binu[outLen] = byte(outi[j] >> mask) + mask -= 8 + outLen++ } + mask = 24 } - for n, v := range binu { - if v > 0 { - start := n - zcount - if start < 0 { - start = 0 - } - return binu[start:cnt], nil + // find the most significant byte post-decode, if any + for msb := zcount; msb < len(binu); msb++ { + if binu[msb] > 0 { + return binu[msb-zcount : outLen], nil } } - return binu[:cnt], nil + + // it's all zeroes + return binu[:outLen], nil }