Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement task handling for master server's service #2245

Merged
merged 4 commits into from
May 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions doc/design/cluster_train/master_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ A dataset is a list of files in *RecordIO* format. A RecordIO file consists of c

## Task Queue

As mentioned in [distributed training design doc](./README.md), a *task* is a data shard that the master server assigns to the trainer process to train on. A task consists of one or multiple *blocks* from one or multiple files. The master server maintains *task queues* to track the training progress.
As mentioned in [distributed training design doc](./README.md), a *task* is a data shard that the master server assigns to the trainer process to train on. A task consists of one or multiple *chunks* from one or multiple files. The master server maintains *task queues* to track the training progress.

### Task Queue Creation

Expand All @@ -21,23 +21,23 @@ As mentioned in [distributed training design doc](./README.md), a *task* is a da
func (m *RPCServer) ReportDataset(Paths []string, dummy *int) error {
}
```
1. The master server will scan through each RecordIO file to generate the *block index* and know how many blocks does each file have. A block can be referenced by the file path and the index of the block within the file. The block index is in memory data structure that enables fast access to each block, and the index of the block with the file is an integer start from 0, representing the n-th block within the file.
1. The master server will scan through each RecordIO file to generate the *chunk index* and know how many chunks does each file have. A chunk can be referenced by the file path and the index of the chunk within the file. The chunk index is in memory data structure that enables fast access to each chunk, and the index of the chunk with the file is an integer start from 0, representing the n-th chunk within the file.

The definition of the block is:
The definition of the chunk is:
```go
type Block struct {
Idx int // index of the block within the file
type Chunk struct {
Idx int // index of the chunk within the file
Path string
Index recordio.Index // block index
Index recordio.Index // chunk index
}
```
1. Blocks are grouped into tasks, and tasks are filled into the todo queue. The pending queue and the done queue are initialized with no element.
1. Chunks are grouped into tasks, and tasks are filled into the todo queue. The pending queue and the done queue are initialized with no element.

The definition of the task is:
```go
type Task struct {
Index int
Blocks []Block
Chunks []Chunk
}
```

Expand Down
93 changes: 93 additions & 0 deletions paddle/go/cmd/master/master.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package main

import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/namsral/flag"

"github.com/PaddlePaddle/Paddle/paddle/go/master"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

func main() {
port := flag.Int("port", 8080, "port of the master server.")
dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
flag.Parse()

if *dataset == "" {
panic("no dataset specified.")
}

if *faultTolerance {
panic("fault tolernance not implemented.")
}

var chunks []master.Chunk
var paths []string
ss := strings.Split(*dataset, ",")
fmt.Println(ss)
for _, s := range ss {
match, err := filepath.Glob(s)
if err != nil {
panic(err)
}
paths = append(paths, match...)
}

if len(paths) == 0 {
panic("no valid datset specified.")
}

idx := 0
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
panic(err)
}

index, err := recordio.LoadIndex(f)
if err != nil {
panic(err)
}
f.Close()

count := index.NumChunks()
for i := 0; i < count; i++ {
chunk := master.Chunk{
Idx: idx,
Path: path,
Index: *index.ChunkIndex(i),
}
chunks = append(chunks, chunk)
}
}

s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
err := rpc.Register(s)
if err != nil {
panic(err)
}

rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
panic(err)
}

err = http.Serve(l, nil)
if err != nil {
panic(err)
}
}
5 changes: 3 additions & 2 deletions paddle/go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package main

import (
"flag"
"net"
"net/http"
"net/rpc"
"strconv"

"github.com/namsral/flag"

"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

func main() {
port := flag.Int("p", 0, "port of the pserver")
port := flag.Int("port", 0, "port of the pserver")
flag.Parse()

s := pserver.NewService()
Expand Down
178 changes: 178 additions & 0 deletions paddle/go/master/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package master

import (
"errors"
"log"
"sync"
"time"

"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

const (
targetTaskCount = 300
)

// errors
var (
ErrNoMoreTask = errors.New("no more task for current pass")
ErrPendingTaskNotFound = errors.New("pending task not found")
)

// Service is the master server service.
type Service struct {
timeoutDur time.Duration
timeoutMax int

mu sync.Mutex
taskQueues taskQueues
}

// Recover recovers service state from etcd.
func Recover() (*Service, error) {
// TODO(helin): recover from snapshot state from etcd.
return nil, nil
}

func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0
if chunksPerTask <= 0 {
chunksPerTask = 1
}

var result []taskEntry
var cur taskEntry
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
id++
result = append(result, cur)
cur.Task.Chunks = nil
}

cur.Task.Chunks = append(cur.Task.Chunks, c)
}

if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
id++
result = append(result, cur)
}

return result
}

// NewService creates a new service.
func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
s := &Service{}
s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax
s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry)
s.taskQueues.Todo = partition(chunks, chunksPerTask)
return s
}

// Chunk is a chunk of data consisted of several data instances.
type Chunk struct {
Idx int // index of the chunk within the file
Path string
Index recordio.Index // block index
}

// Task is the basic unit of data instances assigned to trainers.
type Task struct {
ID int
Chunks []Chunk
}

type taskEntry struct {
Epoch int
NumTimeout int
Task Task
}

type taskQueues struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []Task
}

// *must* be called with s.mu being held.
func (s *Service) snapshot() error {
// TODO(helin): snapshot state on etcd.
return nil
}

// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
s.mu.Lock()
defer s.mu.Unlock()

if len(s.taskQueues.Todo) == 0 {
return ErrNoMoreTask
}

t := s.taskQueues.Todo[0]
t.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t
err := s.snapshot()
if err != nil {
return err
}

time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()

t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}

if t.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check.
return
}

defer func() {
err := s.snapshot()
if err != nil {
log.Println(err)
}
}()

delete(s.taskQueues.Pending, t.Task.ID)

t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}

s.taskQueues.Todo = append(s.taskQueues.Todo, t)
}
}(t.Task.ID, t.Epoch))
return nil
}

// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock()
defer s.mu.Unlock()

t, ok := s.taskQueues.Pending[taskID]
if !ok {
return ErrPendingTaskNotFound
}

// task finished, reset timeout
t.NumTimeout = 0
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)
return s.snapshot()
}
37 changes: 37 additions & 0 deletions paddle/go/master/service_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package master

import "testing"

func TestPartitionCount(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 5)
if len(ts) != 20 {
t.Error(len(ts))
}

cs = make([]Chunk, 101)
ts = partition(cs, 5)
if len(ts) != 21 {
t.Error(len(ts))
}

ts = partition(cs, 1)
if len(ts) != 101 {
t.Error(len(ts))
}

ts = partition(cs, 0)
if len(ts) != 101 {
t.Error(len(ts))
}
}

func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.ID != i {
t.Error(ts[i], i)
}
}
}