diff --git a/ext4/dmverity/dmverity.go b/ext4/dmverity/dmverity.go index 8aaaf10272..f952fbde87 100644 --- a/ext4/dmverity/dmverity.go +++ b/ext4/dmverity/dmverity.go @@ -22,6 +22,8 @@ const ( MerkleTreeBufioSize = 1024 * 1024 // 1MB // RecommendedVHDSizeGB is the recommended size in GB for VHDs, which is not a hard limit. RecommendedVHDSizeGB = 128 * 1024 * 1024 * 1024 + // VeritySignature is a value written to dm-verity super-block. + VeritySignature = "verity" ) var salt = bytes.Repeat([]byte{0}, 32) @@ -30,6 +32,7 @@ var ( ErrSuperBlockReadFailure = errors.New("failed to read dm-verity super block") ErrSuperBlockParseFailure = errors.New("failed to parse dm-verity super block") ErrRootHashReadFailure = errors.New("failed to read dm-verity root hash") + ErrNotVeritySuperBlock = errors.New("invalid dm-verity super-block signature") ) type dmveritySuperblock struct { @@ -133,7 +136,7 @@ func NewDMVeritySuperblock(size uint64) *dmveritySuperblock { SaltSize: uint16(len(salt)), } - copy(superblock.Signature[:], "verity") + copy(superblock.Signature[:], VeritySignature) copy(superblock.Algorithm[:], "sha256") copy(superblock.Salt[:], salt) @@ -173,7 +176,7 @@ func ReadDMVerityInfo(vhdPath string, offsetInBytes int64) (*VerityInfo, error) block := make([]byte, blockSize) if s, err := vhd.Read(block); err != nil || s != blockSize { if err != nil { - return nil, errors.Wrapf(ErrSuperBlockReadFailure, "%s", err) + return nil, errors.Wrapf(err, "%s", ErrSuperBlockReadFailure) } return nil, errors.Wrapf(ErrSuperBlockReadFailure, "unexpected bytes read: expected=%d, actual=%d", blockSize, s) } @@ -181,13 +184,15 @@ func ReadDMVerityInfo(vhdPath string, offsetInBytes int64) (*VerityInfo, error) dmvSB := &dmveritySuperblock{} b := bytes.NewBuffer(block) if err := binary.Read(b, binary.LittleEndian, dmvSB); err != nil { - return nil, errors.Wrapf(ErrSuperBlockParseFailure, "%s", err) + return nil, errors.Wrapf(err, "%s", ErrSuperBlockParseFailure) + } + if string(bytes.Trim(dmvSB.Signature[:], "\x00")[:]) != VeritySignature { + return nil, ErrNotVeritySuperBlock } - // read the merkle tree root if s, err := vhd.Read(block); err != nil || s != blockSize { if err != nil { - return nil, errors.Wrapf(ErrRootHashReadFailure, "%s", err) + return nil, errors.Wrapf(err, "%s", ErrRootHashReadFailure) } return nil, errors.Wrapf(ErrRootHashReadFailure, "unexpected bytes read: expected=%d, actual=%d", blockSize, s) } diff --git a/ext4/dmverity/dmverity_test.go b/ext4/dmverity/dmverity_test.go new file mode 100644 index 0000000000..54c2b0dd51 --- /dev/null +++ b/ext4/dmverity/dmverity_test.go @@ -0,0 +1,102 @@ +package dmverity + +import ( + "bytes" + "encoding/binary" + "io" + "io/ioutil" + "math/rand" + "os" + "strings" + "testing" + "unsafe" + + "github.com/pkg/errors" +) + +func tempFileWithContentLength(t *testing.T, length int) *os.File { + tmpFile, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("failed to create temp file") + } + defer tmpFile.Close() + + content := make([]byte, length) + if _, err := rand.Read(content); err != nil { + t.Fatalf("failed to write random bytes to buffer") + } + if _, err := tmpFile.Write(content); err != nil { + t.Fatalf("failed to write random bytes to temp file") + } + return tmpFile +} + +func writeDMVeritySuperBlock(filename string) (*os.File, error) { + out, err := os.OpenFile(filename, os.O_RDWR, 0777) + if err != nil { + return nil, err + } + defer out.Close() + fsSize, err := out.Seek(0, io.SeekEnd) + if err != nil { + return nil, err + } + sb := NewDMVeritySuperblock(uint64(fsSize)) + if err := binary.Write(out, binary.LittleEndian, sb); err != nil { + return nil, err + } + sbSize := int(unsafe.Sizeof(*sb)) + padding := bytes.Repeat([]byte{0}, blockSize-(sbSize%blockSize)) + if _, err = out.Write(padding); err != nil { + return nil, err + } + return out, nil +} + +func TestInvalidReadEOF(t *testing.T) { + tmpFile := tempFileWithContentLength(t, blockSize) + _, err := ReadDMVerityInfo(tmpFile.Name(), blockSize) + if err == nil { + t.Fatalf("no error returned") + } + if errors.Cause(err) != io.EOF { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestInvalidReadNotEnoughBytes(t *testing.T) { + tmpFile := tempFileWithContentLength(t, blockSize+blockSize/2) + _, err := ReadDMVerityInfo(tmpFile.Name(), blockSize) + if err == nil { + t.Fatalf("no error returned") + } + if errors.Cause(err) != ErrSuperBlockReadFailure || !strings.Contains(err.Error(), "unexpected bytes read") { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestNotVeritySuperBlock(t *testing.T) { + tmpFile := tempFileWithContentLength(t, 2*blockSize) + _, err := ReadDMVerityInfo(tmpFile.Name(), blockSize) + if err == nil { + t.Fatalf("no error returned") + } + if err != ErrNotVeritySuperBlock { + t.Fatalf("expected %q, got %q", ErrNotVeritySuperBlock, err) + } +} + +func TestNoMerkleTree(t *testing.T) { + tmpFile := tempFileWithContentLength(t, blockSize) + targetFile, err := writeDMVeritySuperBlock(tmpFile.Name()) + if err != nil { + t.Fatalf("failed to write dm-verity super-block: %s", err) + } + _, err = ReadDMVerityInfo(targetFile.Name(), blockSize) + if err == nil { + t.Fatalf("no error returned") + } + if errors.Cause(err) != io.EOF || !strings.Contains(err.Error(), "failed to read dm-verity root hash") { + t.Fatalf("expected %q, got %q", io.EOF, err) + } +} diff --git a/internal/guest/storage/devicemapper/targets.go b/internal/guest/storage/devicemapper/targets.go index 798f5795b2..0d1422a1cf 100644 --- a/internal/guest/storage/devicemapper/targets.go +++ b/internal/guest/storage/devicemapper/targets.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/errors" "go.opencensus.io/trace" + "github.com/Microsoft/hcsshim/ext4/dmverity" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/oc" ) @@ -62,7 +63,7 @@ func CreateVerityTarget(ctx context.Context, devPath, devName string, verityInfo verityTarget := Target{ SectorStart: 0, LengthInBlocks: dmBlocks, - Type: "verity", + Type: dmverity.VeritySignature, Params: fmt.Sprintf("%d %s %s %s", verityInfo.Version, devices, blkInfo, hashes), } diff --git a/test/vendor/github.com/Microsoft/hcsshim/ext4/dmverity/dmverity.go b/test/vendor/github.com/Microsoft/hcsshim/ext4/dmverity/dmverity.go index 8aaaf10272..f952fbde87 100644 --- a/test/vendor/github.com/Microsoft/hcsshim/ext4/dmverity/dmverity.go +++ b/test/vendor/github.com/Microsoft/hcsshim/ext4/dmverity/dmverity.go @@ -22,6 +22,8 @@ const ( MerkleTreeBufioSize = 1024 * 1024 // 1MB // RecommendedVHDSizeGB is the recommended size in GB for VHDs, which is not a hard limit. RecommendedVHDSizeGB = 128 * 1024 * 1024 * 1024 + // VeritySignature is a value written to dm-verity super-block. + VeritySignature = "verity" ) var salt = bytes.Repeat([]byte{0}, 32) @@ -30,6 +32,7 @@ var ( ErrSuperBlockReadFailure = errors.New("failed to read dm-verity super block") ErrSuperBlockParseFailure = errors.New("failed to parse dm-verity super block") ErrRootHashReadFailure = errors.New("failed to read dm-verity root hash") + ErrNotVeritySuperBlock = errors.New("invalid dm-verity super-block signature") ) type dmveritySuperblock struct { @@ -133,7 +136,7 @@ func NewDMVeritySuperblock(size uint64) *dmveritySuperblock { SaltSize: uint16(len(salt)), } - copy(superblock.Signature[:], "verity") + copy(superblock.Signature[:], VeritySignature) copy(superblock.Algorithm[:], "sha256") copy(superblock.Salt[:], salt) @@ -173,7 +176,7 @@ func ReadDMVerityInfo(vhdPath string, offsetInBytes int64) (*VerityInfo, error) block := make([]byte, blockSize) if s, err := vhd.Read(block); err != nil || s != blockSize { if err != nil { - return nil, errors.Wrapf(ErrSuperBlockReadFailure, "%s", err) + return nil, errors.Wrapf(err, "%s", ErrSuperBlockReadFailure) } return nil, errors.Wrapf(ErrSuperBlockReadFailure, "unexpected bytes read: expected=%d, actual=%d", blockSize, s) } @@ -181,13 +184,15 @@ func ReadDMVerityInfo(vhdPath string, offsetInBytes int64) (*VerityInfo, error) dmvSB := &dmveritySuperblock{} b := bytes.NewBuffer(block) if err := binary.Read(b, binary.LittleEndian, dmvSB); err != nil { - return nil, errors.Wrapf(ErrSuperBlockParseFailure, "%s", err) + return nil, errors.Wrapf(err, "%s", ErrSuperBlockParseFailure) + } + if string(bytes.Trim(dmvSB.Signature[:], "\x00")[:]) != VeritySignature { + return nil, ErrNotVeritySuperBlock } - // read the merkle tree root if s, err := vhd.Read(block); err != nil || s != blockSize { if err != nil { - return nil, errors.Wrapf(ErrRootHashReadFailure, "%s", err) + return nil, errors.Wrapf(err, "%s", ErrRootHashReadFailure) } return nil, errors.Wrapf(ErrRootHashReadFailure, "unexpected bytes read: expected=%d, actual=%d", blockSize, s) }