-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2245 from helinwang/master_server
implement task handling for master server's service
- Loading branch information
Showing
5 changed files
with
319 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"net/rpc" | ||
"os" | ||
"path/filepath" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
"github.com/namsral/flag" | ||
|
||
"github.com/PaddlePaddle/Paddle/paddle/go/master" | ||
"github.com/PaddlePaddle/Paddle/paddle/go/recordio" | ||
) | ||
|
||
func main() { | ||
port := flag.Int("port", 8080, "port of the master server.") | ||
dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.") | ||
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") | ||
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") | ||
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") | ||
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") | ||
flag.Parse() | ||
|
||
if *dataset == "" { | ||
panic("no dataset specified.") | ||
} | ||
|
||
if *faultTolerance { | ||
panic("fault tolernance not implemented.") | ||
} | ||
|
||
var chunks []master.Chunk | ||
var paths []string | ||
ss := strings.Split(*dataset, ",") | ||
fmt.Println(ss) | ||
for _, s := range ss { | ||
match, err := filepath.Glob(s) | ||
if err != nil { | ||
panic(err) | ||
} | ||
paths = append(paths, match...) | ||
} | ||
|
||
if len(paths) == 0 { | ||
panic("no valid datset specified.") | ||
} | ||
|
||
idx := 0 | ||
for _, path := range paths { | ||
f, err := os.Open(path) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
index, err := recordio.LoadIndex(f) | ||
if err != nil { | ||
panic(err) | ||
} | ||
f.Close() | ||
|
||
count := index.NumChunks() | ||
for i := 0; i < count; i++ { | ||
chunk := master.Chunk{ | ||
Idx: idx, | ||
Path: path, | ||
Index: *index.ChunkIndex(i), | ||
} | ||
chunks = append(chunks, chunk) | ||
} | ||
} | ||
|
||
s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) | ||
err := rpc.Register(s) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
rpc.HandleHTTP() | ||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
err = http.Serve(l, nil) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package master | ||
|
||
import ( | ||
"errors" | ||
"log" | ||
"sync" | ||
"time" | ||
|
||
"github.com/PaddlePaddle/Paddle/paddle/go/recordio" | ||
) | ||
|
||
// targetTaskCount is not referenced by any code visible in this file.
// NOTE(review): presumably the intended number of tasks per pass; confirm
// before relying on it.
const targetTaskCount = 300
|
||
// Sentinel errors returned by the service's methods.
var (
	// ErrNoMoreTask is returned by GetTask when the todo queue is empty.
	ErrNoMoreTask = errors.New("no more task for current pass")

	// ErrPendingTaskNotFound is returned by TaskFinished when the given
	// task ID is not in the pending set.
	ErrPendingTaskNotFound = errors.New("pending task not found")
)
|
||
// Service is the master server service. | ||
type Service struct { | ||
timeoutDur time.Duration | ||
timeoutMax int | ||
|
||
mu sync.Mutex | ||
taskQueues taskQueues | ||
} | ||
|
||
// Recover recovers service state from etcd. | ||
func Recover() (*Service, error) { | ||
// TODO(helin): recover from snapshot state from etcd. | ||
return nil, nil | ||
} | ||
|
||
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { | ||
id := 0 | ||
if chunksPerTask <= 0 { | ||
chunksPerTask = 1 | ||
} | ||
|
||
var result []taskEntry | ||
var cur taskEntry | ||
for i, c := range chunks { | ||
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { | ||
cur.Task.ID = id | ||
id++ | ||
result = append(result, cur) | ||
cur.Task.Chunks = nil | ||
} | ||
|
||
cur.Task.Chunks = append(cur.Task.Chunks, c) | ||
} | ||
|
||
if len(cur.Task.Chunks) > 0 { | ||
cur.Task.ID = id | ||
id++ | ||
result = append(result, cur) | ||
} | ||
|
||
return result | ||
} | ||
|
||
// NewService creates a new service. | ||
func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { | ||
s := &Service{} | ||
s.timeoutDur = timeoutDur | ||
s.timeoutMax = timeoutMax | ||
s.taskQueues = taskQueues{} | ||
s.taskQueues.Pending = make(map[int]taskEntry) | ||
s.taskQueues.Todo = partition(chunks, chunksPerTask) | ||
return s | ||
} | ||
|
||
// Chunk is a chunk of data consisted of several data instances. | ||
type Chunk struct { | ||
Idx int // index of the chunk within the file | ||
Path string | ||
Index recordio.Index // block index | ||
} | ||
|
||
// Task is the basic unit of data instances assigned to trainers. | ||
type Task struct { | ||
ID int | ||
Chunks []Chunk | ||
} | ||
|
||
type taskEntry struct { | ||
Epoch int | ||
NumTimeout int | ||
Task Task | ||
} | ||
|
||
type taskQueues struct { | ||
Todo []taskEntry | ||
Pending map[int]taskEntry // map from task ID to task entry | ||
Done []taskEntry | ||
Failed []Task | ||
} | ||
|
||
// *must* be called with s.mu being held. | ||
func (s *Service) snapshot() error { | ||
// TODO(helin): snapshot state on etcd. | ||
return nil | ||
} | ||
|
||
// GetTask gets a new task from the service. | ||
func (s *Service) GetTask(dummy int, task *Task) error { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
if len(s.taskQueues.Todo) == 0 { | ||
return ErrNoMoreTask | ||
} | ||
|
||
t := s.taskQueues.Todo[0] | ||
t.Epoch++ | ||
s.taskQueues.Todo = s.taskQueues.Todo[1:] | ||
s.taskQueues.Pending[t.Task.ID] = t | ||
err := s.snapshot() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() { | ||
return func() { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
t, ok := s.taskQueues.Pending[taskID] | ||
if !ok { | ||
return | ||
} | ||
|
||
if t.Epoch != epoch { | ||
// new epoch, task launched after the | ||
// schedule of this timeout check. | ||
return | ||
} | ||
|
||
defer func() { | ||
err := s.snapshot() | ||
if err != nil { | ||
log.Println(err) | ||
} | ||
}() | ||
|
||
delete(s.taskQueues.Pending, t.Task.ID) | ||
|
||
t.NumTimeout++ | ||
if t.NumTimeout > s.timeoutMax { | ||
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) | ||
return | ||
} | ||
|
||
s.taskQueues.Todo = append(s.taskQueues.Todo, t) | ||
} | ||
}(t.Task.ID, t.Epoch)) | ||
return nil | ||
} | ||
|
||
// TaskFinished tell the service that a task is finished. | ||
func (s *Service) TaskFinished(taskID int, dummy *int) error { | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
t, ok := s.taskQueues.Pending[taskID] | ||
if !ok { | ||
return ErrPendingTaskNotFound | ||
} | ||
|
||
// task finished, reset timeout | ||
t.NumTimeout = 0 | ||
s.taskQueues.Done = append(s.taskQueues.Done, t) | ||
delete(s.taskQueues.Pending, taskID) | ||
return s.snapshot() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package master | ||
|
||
import "testing" | ||
|
||
func TestPartitionCount(t *testing.T) { | ||
cs := make([]Chunk, 100) | ||
ts := partition(cs, 5) | ||
if len(ts) != 20 { | ||
t.Error(len(ts)) | ||
} | ||
|
||
cs = make([]Chunk, 101) | ||
ts = partition(cs, 5) | ||
if len(ts) != 21 { | ||
t.Error(len(ts)) | ||
} | ||
|
||
ts = partition(cs, 1) | ||
if len(ts) != 101 { | ||
t.Error(len(ts)) | ||
} | ||
|
||
ts = partition(cs, 0) | ||
if len(ts) != 101 { | ||
t.Error(len(ts)) | ||
} | ||
} | ||
|
||
func TestPartionIndex(t *testing.T) { | ||
cs := make([]Chunk, 100) | ||
ts := partition(cs, 20) | ||
for i := range ts { | ||
if ts[i].Task.ID != i { | ||
t.Error(ts[i], i) | ||
} | ||
} | ||
} |