-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix pserver checkpoint #5102
Fix pserver checkpoint #5102
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,12 +17,11 @@ package pserver | |
import ( | ||
"bufio" | ||
"bytes" | ||
"crypto/md5" | ||
"encoding/gob" | ||
"encoding/hex" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"hash/crc32" | ||
"io/ioutil" | ||
"os" | ||
"path" | ||
|
@@ -40,7 +39,7 @@ type ElementType int | |
|
||
// ErrCheckpointNotFound indicates that the pserver checkpoint could | ||
// not be found. | ||
var ErrCheckpointNotFound = errors.New("checkpoint not found") | ||
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd") | ||
|
||
// RPC error message. | ||
const ( | ||
|
@@ -76,7 +75,7 @@ type ParameterWithConfig struct { | |
type checkpointMeta struct { | ||
UUID string `json:"uuid"` | ||
Path string `json:"path"` | ||
MD5 string `json:"md5"` | ||
CRC32 uint32 `json:"crc32"` | ||
Timestamp int64 `json:"timestamp"` | ||
} | ||
|
||
|
@@ -92,7 +91,7 @@ type Service struct { | |
idx int | ||
checkpointInterval time.Duration | ||
checkpointPath string | ||
client *EtcdClient | ||
client KVStore | ||
|
||
mu sync.Mutex | ||
optMap map[string]*optimizer | ||
|
@@ -104,7 +103,12 @@ type parameterCheckpoint struct { | |
State []byte | ||
} | ||
|
||
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { | ||
type KVStore interface { | ||
GetKey(key string, timeout time.Duration) ([]byte, error) | ||
PutKey(key string, value []byte, timeout time.Duration, withLease bool) error | ||
} | ||
|
||
func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) { | ||
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second) | ||
if err != nil { | ||
return | ||
|
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { | |
} | ||
|
||
// LoadCheckpoint loads checkpoint from file. | ||
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { | ||
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) { | ||
log.Info("Loading checkpoint", "pserver index", idx) | ||
defer traceTime(time.Now(), "load checkpoint") | ||
|
||
|
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { | |
return nil, err | ||
} | ||
|
||
// TODO(helin): change MD5 to CRC since CRC is better for file | ||
// checksum in our use case (emphasize speed over security). | ||
h := md5.New() | ||
md5 := hex.EncodeToString(h.Sum(content)) | ||
if md5 != cpMeta.MD5 { | ||
crc32 := crc32.ChecksumIEEE(content) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious why md5 will cause the error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same question... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @typhoonzero @Yancey1989 @dzhwinter The problem was I think we used the md5 package incorrectly that generated a long string, causes etcd write error with "message too large". To show to error: package main
import (
"crypto/md5"
"encoding/hex"
"fmt"
)
func main() {
h := md5.New()
md5 := hex.EncodeToString(h.Sum([]byte("hello this is some string")))
// Output: 68656c6c6f207468697320697320736f6d6520737472696e67d41d8cd98f00b204e9800998ecf8427e
fmt.Println(md5)
} The output is not a typical MD5 string, rather a very long one. I think the correct way to get the MD5 string is here: package main
import (
"crypto/md5"
"fmt"
)
func main() {
data := []byte("These pretzels are making me thirsty.")
sum := fmt.Sprintf("%x", md5.Sum(data))
// Output: b0804ec967f48520697662a204f5fe72
fmt.Printf(sum)
} The reason to switch to CRC32 is because it's faster, better for checksum, MD5 is slower, better for defending cracking. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one more question. Why we change MD5 to CRC, concern the speed? |
||
if crc32 != cpMeta.CRC32 { | ||
return nil, errors.New(WrongChecksum) | ||
} | ||
|
||
|
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { | |
if err = dec.Decode(&cp); err != nil { | ||
return nil, err | ||
} | ||
|
||
return cp, nil | ||
} | ||
|
||
// NewService creates a new service, will bypass etcd registration if no | ||
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. | ||
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { | ||
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) { | ||
s := &Service{ | ||
idx: idx, | ||
checkpointInterval: interval, | ||
|
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient | |
} | ||
s.optMap[p.Param.Name] = newOptimizer(p, item.State) | ||
} | ||
close(s.initialized) | ||
} | ||
return s, nil | ||
} | ||
|
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error { | |
for range t { | ||
err := s.checkpoint() | ||
if err != nil { | ||
log.Error("finish init params error", log.Ctx{"error": err}) | ||
log.Error("checkpoint error", log.Ctx{"error": err}) | ||
} | ||
} | ||
}() | ||
|
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { | |
parameter.Name = name | ||
parameter.ElementType = opt.elementType | ||
parameter.Content = opt.GetWeights() | ||
|
||
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType) | ||
return nil | ||
} | ||
|
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) { | |
|
||
oldMeta, err := loadMeta(s.client, s.idx) | ||
if err == ErrCheckpointNotFound { | ||
log.Info("Do not have existing checkpoint.") | ||
log.Info("old meta not found, skip removing old meta") | ||
err = nil | ||
} else if err == nil { | ||
log.Info("removing old meta") | ||
if oldMeta.Path != "" { | ||
rmErr := os.Remove(oldMeta.Path) | ||
if rmErr != nil { | ||
// log error, but still treat checkpoint as | ||
// successful. | ||
log.Error("remove old meta file error", log.Ctx{"error": rmErr}) | ||
} | ||
} | ||
} | ||
|
||
if err != nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we would add some log here, sorry this is out of this PR code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The caller will handle it, here is one caller code: err := s.checkpoint()
if err != nil {
log.Error("checkpoint error", log.Ctx{"error": err})
} I think in general the outer most caller should handle the error (either log or do something else), because it has the most information. If everyone prints log, it will be duplicating. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I got it, thanks :) |
||
return | ||
} | ||
|
||
h := md5.New() | ||
md5 := hex.EncodeToString(h.Sum(buf.Bytes())) | ||
crc32 := crc32.ChecksumIEEE(buf.Bytes()) | ||
cpMeta := checkpointMeta{ | ||
UUID: id, | ||
Timestamp: time.Now().UnixNano(), | ||
MD5: md5, | ||
CRC32: crc32, | ||
Path: p, | ||
} | ||
|
||
|
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) { | |
return | ||
} | ||
|
||
if oldMeta.Path != "" { | ||
rmErr := os.Remove(oldMeta.Path) | ||
if rmErr != nil { | ||
// log error, but still treat checkpoint as | ||
// successful. | ||
log.Error("remove old meta file error", log.Ctx{"error": rmErr}) | ||
} | ||
} | ||
|
||
return | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package pserver | ||
|
||
import ( | ||
"bytes" | ||
"encoding/binary" | ||
"fmt" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
const testDir = "./test_data" | ||
|
||
type myKV struct { | ||
m map[string][]byte | ||
} | ||
|
||
func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) { | ||
if m.m == nil { | ||
m.m = make(map[string][]byte) | ||
} | ||
return m.m[key], nil | ||
} | ||
|
||
func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error { | ||
if m.m == nil { | ||
m.m = make(map[string][]byte) | ||
} | ||
m.m[key] = value | ||
return nil | ||
} | ||
|
||
func TestCheckpoint(t *testing.T) { | ||
kv := &myKV{} | ||
s, err := NewService(0, time.Hour, testDir, kv, nil) | ||
assert.Nil(t, err) | ||
err = s.checkpoint() | ||
assert.Nil(t, err) | ||
_, err = LoadCheckpoint(kv, 0) | ||
assert.Nil(t, err) | ||
} | ||
|
||
func float32ToByte(f float32) []byte { | ||
var buf bytes.Buffer | ||
err := binary.Write(&buf, binary.LittleEndian, f) | ||
if err != nil { | ||
fmt.Println("binary.Write failed:", err) | ||
} | ||
return buf.Bytes() | ||
} | ||
|
||
func TestCheckpointWithData(t *testing.T) { | ||
kv := &myKV{} | ||
s, err := NewService(0, time.Hour, testDir, kv, nil) | ||
assert.Nil(t, err) | ||
|
||
var content []byte | ||
for i := 0; i < 50000; i++ { | ||
content = append(content, float32ToByte(float32(i))...) | ||
} | ||
|
||
p1 := Parameter{Name: "p1", ElementType: 1, Content: content} | ||
err = s.InitParam(ParameterWithConfig{Param: p1}, nil) | ||
assert.Nil(t, err) | ||
|
||
err = s.FinishInitParams(0, nil) | ||
assert.Nil(t, err) | ||
|
||
var p2 Parameter | ||
err = s.GetParam(p1.Name, &p2) | ||
assert.Nil(t, err) | ||
assert.Equal(t, p1, p2) | ||
|
||
err = s.checkpoint() | ||
assert.Nil(t, err) | ||
cp, err := LoadCheckpoint(kv, 0) | ||
assert.Nil(t, err) | ||
s1, err := NewService(0, time.Hour, testDir, kv, cp) | ||
assert.Nil(t, err) | ||
|
||
var p3 Parameter | ||
err = s1.GetParam(p1.Name, &p3) | ||
assert.Nil(t, err) | ||
assert.Equal(t, p1, p3) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Else output error log?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! Done.