Skip to content

Commit

Permalink
Rewrite ReadBitsInt{Be,Le}(), change param type from uint8 to int
Browse files Browse the repository at this point in the history
See kaitai-io/kaitai_struct#949

Changing the parameter type is potentially a breaking change, because
calls like ReadBitsIntBe(uint8(15)) will no longer work. But this
will not be a problem for KSC-generated code, it could only break
manually-written (or manually-patched) code.

`ReadBitsInt(n uint8)` is intentionally left as it was, since it exists
only to maintain backward compatibility and will be removed at some point.
  • Loading branch information
generalmimon committed Mar 1, 2022
1 parent dfb7177 commit a5c5c1e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 31 deletions.
62 changes: 33 additions & 29 deletions kaitai/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Stream struct {
buf [8]byte

// Number of bits remaining in "bits" for sequential calls to ReadBitsInt
bitsLeft uint8
bitsLeft int
bits uint64
}

Expand Down Expand Up @@ -297,13 +297,17 @@ func (k *Stream) AlignToByte() {
}

// ReadBitsIntBe reads n-bit integer in big-endian byte order and returns it as uint64.
func (k *Stream) ReadBitsIntBe(n uint8) (res uint64, err error) {
bitsNeeded := int(n) - int(k.bitsLeft)
func (k *Stream) ReadBitsIntBe(n int) (res uint64, err error) {
res = 0

bitsNeeded := n - k.bitsLeft
k.bitsLeft = -bitsNeeded & 7 // `-bitsNeeded mod 8`

if bitsNeeded > 0 {
// 1 bit => 1 byte
// 8 bits => 1 byte
// 9 bits => 2 bytes
bytesNeeded := ((bitsNeeded - 1) / 8) + 1
bytesNeeded := ((bitsNeeded - 1) / 8) + 1 // `ceil(bitsNeeded / 8)`
if bytesNeeded > 8 {
return res, fmt.Errorf("ReadBitsIntBe(%d): more than 8 bytes requested", n)
}
Expand All @@ -312,20 +316,17 @@ func (k *Stream) ReadBitsIntBe(n uint8) (res uint64, err error) {
return res, err
}
for i := 0; i < bytesNeeded; i++ {
k.bits <<= 8
k.bits |= uint64(k.buf[i])
k.bitsLeft += 8
res = res<<8 | uint64(k.buf[i])
}

newBits := res
res = res>>k.bitsLeft | k.bits<<bitsNeeded
k.bits = newBits // will be masked at the end of the function
} else {
res = k.bits >> -bitsNeeded // shift unneeded bits out
}

// raw mask with required number of 1s, starting from lowest bit
var mask uint64 = (1 << n) - 1
// shift "bits" to align the highest bits with the mask & derive the result
shiftBits := k.bitsLeft - n
res = (k.bits >> shiftBits) & mask
// clear top bits that we've just read => AND with 1s
k.bitsLeft -= n
mask = (1 << k.bitsLeft) - 1
var mask uint64 = (1 << k.bitsLeft) - 1 // `bitsLeft` is in range 0..7
k.bits &= mask

return res, err
Expand All @@ -335,18 +336,19 @@ func (k *Stream) ReadBitsIntBe(n uint8) (res uint64, err error) {
//
// Deprecated: Use ReadBitsIntBe instead.
func (k *Stream) ReadBitsInt(n uint8) (res uint64, err error) {
return k.ReadBitsIntBe(n)
return k.ReadBitsIntBe(int(n))
}

// ReadBitsIntLe reads n-bit integer in little-endian byte order and returns it as uint64.
func (k *Stream) ReadBitsIntLe(n uint8) (res uint64, err error) {
bitsNeeded := int(n) - int(k.bitsLeft)
var bitsLeft uint64 = uint64(k.bitsLeft)
func (k *Stream) ReadBitsIntLe(n int) (res uint64, err error) {
res = 0
bitsNeeded := n - k.bitsLeft

if bitsNeeded > 0 {
// 1 bit => 1 byte
// 8 bits => 1 byte
// 9 bits => 2 bytes
bytesNeeded := ((bitsNeeded - 1) / 8) + 1
bytesNeeded := ((bitsNeeded - 1) / 8) + 1 // `ceil(bitsNeeded / 8)`
if bytesNeeded > 8 {
return res, fmt.Errorf("ReadBitsIntLe(%d): more than 8 bytes requested", n)
}
Expand All @@ -355,19 +357,21 @@ func (k *Stream) ReadBitsIntLe(n uint8) (res uint64, err error) {
return res, err
}
for i := 0; i < bytesNeeded; i++ {
k.bits |= uint64(k.buf[i]) << bitsLeft
bitsLeft += 8
res |= uint64(k.buf[i]) << (i * 8)
}

newBits := res >> bitsNeeded
res = res<<k.bitsLeft | k.bits
k.bits = newBits
} else {
res = k.bits
k.bits >>= n
}

// raw mask with required number of 1s, starting from lowest bit
var mask uint64 = (1 << n) - 1
// derive reading result
res = k.bits & mask
// remove bottom bits that we've just read by shifting
k.bits >>= n
k.bitsLeft = uint8(bitsLeft) - n
k.bitsLeft = -bitsNeeded & 7 // `-bitsNeeded mod 8`

var mask uint64 = (1 << n) - 1 // unlike some other languages, no problem with this in Go
res &= mask
return res, err
}

Expand Down
4 changes: 2 additions & 2 deletions kaitai/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ func TestStream_AlignToByte(t *testing.T) {

func TestStream_ReadBitsIntBe(t *testing.T) {
type args struct {
totalBitsNeeded uint8
totalBitsNeeded int
}
tests := []struct {
name string
Expand Down Expand Up @@ -771,7 +771,7 @@ func TestStream_ReadBitsArray(t *testing.T) {

func TestStream_ReadBitsIntLe(t *testing.T) {
type args struct {
n uint8
n int
}
tests := []struct {
name string
Expand Down

0 comments on commit a5c5c1e

Please sign in to comment.