From 0b861350c2cf7a2741b258cb864ecb49d34f2bd0 Mon Sep 17 00:00:00 2001 From: Giuseppe Scrivano Date: Tue, 20 Feb 2024 14:46:35 +0100 Subject: [PATCH] machine: add sparse file writer Signed-off-by: Giuseppe Scrivano --- pkg/machine/compression/sparse_file_writer.go | 133 ++++++++++++++++++ .../compression/sparse_file_writer_test.go | 109 ++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 pkg/machine/compression/sparse_file_writer.go create mode 100644 pkg/machine/compression/sparse_file_writer_test.go diff --git a/pkg/machine/compression/sparse_file_writer.go b/pkg/machine/compression/sparse_file_writer.go new file mode 100644 index 0000000000..d8e071a3ca --- /dev/null +++ b/pkg/machine/compression/sparse_file_writer.go @@ -0,0 +1,133 @@ +package compression + +import ( + "bytes" + "errors" + "io" +) + +type state int + +const ( + zerosThreshold = 1024 + + stateData = iota + stateZeros +) + +type WriteSeekCloser interface { + io.Closer + io.WriteSeeker +} + +type sparseWriter struct { + state state + file WriteSeekCloser + zeros int64 + lastIsZero bool +} + +func NewSparseWriter(file WriteSeekCloser) *sparseWriter { + return &sparseWriter{ + file: file, + state: stateData, + zeros: 0, + lastIsZero: false, + } +} + +func (sw *sparseWriter) createHole() error { + zeros := sw.zeros + if zeros == 0 { + return nil + } + sw.zeros = 0 + sw.lastIsZero = true + _, err := sw.file.Seek(zeros, io.SeekCurrent) + return err +} + +func findFirstNotZero(b []byte) int { + for i, v := range b { + if v != 0 { + return i + } + } + return -1 +} + +// Write writes data to the file, creating holes for long sequences of zeros. +func (sw *sparseWriter) Write(data []byte) (int, error) { + written, current := 0, 0 + totalLen := len(data) + for current < len(data) { + switch sw.state { + case stateData: + nextZero := bytes.IndexByte(data[current:], 0) + if nextZero < 0 { + _, err := sw.file.Write(data[written:]) + sw.lastIsZero = false + return totalLen, err + } else { + current += nextZero + sw.state = stateZeros + } + case stateZeros: + nextNonZero := findFirstNotZero(data[current:]) + if nextNonZero < 0 { + // finish with a zero, flush any data and keep track of the zeros + if written != current { + if _, err := sw.file.Write(data[written:current]); err != nil { + return -1, err + } + sw.lastIsZero = false + } + sw.zeros += int64(len(data) - current) + return totalLen, nil + } + // do not bother with too short sequences + if sw.zeros == 0 && nextNonZero < zerosThreshold { + sw.state = stateData + current += nextNonZero + continue + } + if written != current { + if _, err := sw.file.Write(data[written:current]); err != nil { + return -1, err + } + sw.lastIsZero = false + } + sw.zeros += int64(nextNonZero) + current += nextNonZero + if err := sw.createHole(); err != nil { + return -1, err + } + written = current + } + } + return totalLen, nil +} + +// Close closes the SparseWriter's underlying file. +func (sw *sparseWriter) Close() error { + if sw.file == nil { + return errors.New("file is already closed") + } + if err := sw.createHole(); err != nil { + sw.file.Close() + return err + } + if sw.lastIsZero { + if _, err := sw.file.Seek(-1, io.SeekCurrent); err != nil { + sw.file.Close() + return err + } + if _, err := sw.file.Write([]byte{0}); err != nil { + sw.file.Close() + return err + } + } + err := sw.file.Close() + sw.file = nil + return err +} diff --git a/pkg/machine/compression/sparse_file_writer_test.go b/pkg/machine/compression/sparse_file_writer_test.go new file mode 100644 index 0000000000..70a7c75b61 --- /dev/null +++ b/pkg/machine/compression/sparse_file_writer_test.go @@ -0,0 +1,109 @@ +package compression + +import ( + "bytes" + "errors" + "io" + "testing" +) + +type memorySparseFile struct { + buffer bytes.Buffer + pos int64 +} + +func (m *memorySparseFile) Seek(offset int64, whence int) (int64, error) { + var newPos int64 + switch whence { + case io.SeekStart: + newPos = offset + case io.SeekCurrent: + newPos = m.pos + offset + case io.SeekEnd: + newPos = int64(m.buffer.Len()) + offset + default: + return 0, errors.New("unsupported seek whence") + } + + if newPos < 0 { + return 0, errors.New("negative position is not allowed") + } + + m.pos = newPos + return newPos, nil +} + +func (m *memorySparseFile) Write(b []byte) (n int, err error) { + if int64(m.buffer.Len()) < m.pos { + padding := make([]byte, m.pos-int64(m.buffer.Len())) + _, err := m.buffer.Write(padding) + if err != nil { + return 0, err + } + } + + m.buffer.Next(int(m.pos) - m.buffer.Len()) + + n, err = m.buffer.Write(b) + m.pos += int64(n) + return n, err +} + +func (m *memorySparseFile) Close() error { + return nil +} + +func testInputWithWriteLen(t *testing.T, input []byte, chunkSize int) { + m := &memorySparseFile{} + sparseWriter := NewSparseWriter(m) + + for i := 0; i < len(input); i += chunkSize { + end := i + chunkSize + if end > len(input) { + end = len(input) + } + _, err := sparseWriter.Write(input[i:end]) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + } + err := sparseWriter.Close() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !bytes.Equal(input, m.buffer.Bytes()) { + t.Fatalf("Incorrect output") + } +} + +func testInput(t *testing.T, inputBytes []byte) { + currentLen := 1 + for { + testInputWithWriteLen(t, inputBytes, currentLen) + currentLen <<= 1 + if currentLen > len(inputBytes) { + break + } + } +} + +func TestSparseWriter(t *testing.T) { + testInput(t, []byte("hello")) + testInput(t, append(make([]byte, 100), []byte("hello")...)) + testInput(t, []byte("")) + + // add "hello" at the beginning + largeInput := make([]byte, 1024*1024) + copy(largeInput, []byte("hello")) + testInput(t, largeInput) + + // add "hello" at the end + largeInput = make([]byte, 1024*1024) + copy(largeInput[1024*1024-5:], []byte("hello")) + testInput(t, largeInput) + + // add "hello" in the middle + largeInput = make([]byte, 1024*1024) + copy(largeInput[len(largeInput)/2:], []byte("hello")) + testInput(t, largeInput) +}