Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pserver checkpoint #5102

Merged
merged 2 commits into from
Oct 26, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Info("Could not find the pserver checkpoint.")
log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func main() {
candy.Must(err)

go func() {
log.Info("starting pserver", log.Ctx{"port": *port})
log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
Expand Down
6 changes: 5 additions & 1 deletion go/pserver/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}

var cptr (*C.uchar)
if len(c) > 0 {
cptr = (*C.uchar)(&c[0])
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Else output error log?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! Done.

o.config = c
o.opt = C.paddle_create_optimizer(
(*C.uchar)(&c[0]),
cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
Expand Down
58 changes: 31 additions & 27 deletions go/pserver/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
Expand All @@ -40,7 +39,7 @@ type ElementType int

// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")

// RPC error message.
const (
Expand Down Expand Up @@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
MD5 string `json:"md5"`
CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}

Expand All @@ -92,7 +91,7 @@ type Service struct {
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
client KVStore

mu sync.Mutex
optMap map[string]*optimizer
Expand All @@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte
}

func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
type KVStore interface {
GetKey(key string, timeout time.Duration) ([]byte, error)
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}

func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
Expand All @@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}

// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")

Expand All @@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err
}

// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
crc32 := crc32.ChecksumIEEE(content)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why md5 will cause the error?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same question...

Copy link
Contributor Author

@helinwang helinwang Oct 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@typhoonzero @Yancey1989 @dzhwinter

The problem was I think we used the md5 package incorrectly that generated a long string, causes etcd write error with "message too large". To show to error:

package main

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
)

func main() {
	h := md5.New()
	md5 := hex.EncodeToString(h.Sum([]byte("hello this is some string")))
	// Output: 68656c6c6f207468697320697320736f6d6520737472696e67d41d8cd98f00b204e9800998ecf8427e
	fmt.Println(md5)
}

The output is not a typical MD5 string, rather a very long one.

I think the correct way to get the MD5 string is here:

package main

import (
	"crypto/md5"
	"fmt"
)

func main() {
	data := []byte("These pretzels are making me thirsty.")
	sum := fmt.Sprintf("%x", md5.Sum(data))
	// Output: b0804ec967f48520697662a204f5fe72
	fmt.Printf(sum)
}

The reason to switch to CRC32 is because it's faster, better for checksum, MD5 is slower, better for defending cracking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more question. Why we change MD5 to CRC, concern the speed?

if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}

Expand All @@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil {
return nil, err
}

return cp, nil
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
Expand All @@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
close(s.initialized)
}
return s, nil
}
Expand Down Expand Up @@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t {
err := s.checkpoint()
if err != nil {
log.Error("finish init params error", log.Ctx{"error": err})
log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
Expand Down Expand Up @@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()

log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
Expand Down Expand Up @@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {

oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Info("Do not have existing checkpoint.")
log.Info("old meta not found, skip removing old meta")
err = nil
} else if err == nil {
log.Info("removing old meta")
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}
}

if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we would add some log here, sorry this is out of this PR code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The caller will handle it, here is one caller code:

err := s.checkpoint()
  if err != nil {
    log.Error("checkpoint error", log.Ctx{"error": err})
  }

I think in general the outer most caller should handle the error (either log or do something else), because it has the most information. If everyone prints log, it will be duplicating.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it, thanks :)

return
}

h := md5.New()
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
MD5: md5,
CRC32: crc32,
Path: p,
}

Expand All @@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}

if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Error("remove old meta file error", log.Ctx{"error": rmErr})
}
}

return
}
86 changes: 86 additions & 0 deletions go/pserver/service_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package pserver

import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

const testDir = "./test_data"

type myKV struct {
m map[string][]byte
}

func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
if m.m == nil {
m.m = make(map[string][]byte)
}
return m.m[key], nil
}

func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
if m.m == nil {
m.m = make(map[string][]byte)
}
m.m[key] = value
return nil
}

func TestCheckpoint(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)
err = s.checkpoint()
assert.Nil(t, err)
_, err = LoadCheckpoint(kv, 0)
assert.Nil(t, err)
}

func float32ToByte(f float32) []byte {
var buf bytes.Buffer
err := binary.Write(&buf, binary.LittleEndian, f)
if err != nil {
fmt.Println("binary.Write failed:", err)
}
return buf.Bytes()
}

func TestCheckpointWithData(t *testing.T) {
kv := &myKV{}
s, err := NewService(0, time.Hour, testDir, kv, nil)
assert.Nil(t, err)

var content []byte
for i := 0; i < 50000; i++ {
content = append(content, float32ToByte(float32(i))...)
}

p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
assert.Nil(t, err)

err = s.FinishInitParams(0, nil)
assert.Nil(t, err)

var p2 Parameter
err = s.GetParam(p1.Name, &p2)
assert.Nil(t, err)
assert.Equal(t, p1, p2)

err = s.checkpoint()
assert.Nil(t, err)
cp, err := LoadCheckpoint(kv, 0)
assert.Nil(t, err)
s1, err := NewService(0, time.Hour, testDir, kv, cp)
assert.Nil(t, err)

var p3 Parameter
err = s1.GetParam(p1.Name, &p3)
assert.Nil(t, err)
assert.Equal(t, p1, p3)
}
4 changes: 0 additions & 4 deletions go/pserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {

wg.Wait()
}

func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}