Skip to content

Commit

Permalink
Optimize the time cost to go into training (alibaba#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
rickycao-qy authored Aug 10, 2020
1 parent 2d3b416 commit fd4041d
Show file tree
Hide file tree
Showing 40 changed files with 513 additions and 370 deletions.
46 changes: 11 additions & 35 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

93 changes: 88 additions & 5 deletions packages/core/src/types/data/common.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import events from 'events';

import { generateId } from '../../utils/public';
import { Statistic } from '../other';

/**
Expand Down Expand Up @@ -44,11 +47,91 @@ export interface Sample {
/**
* The data loader to used to load dataset.
*/
export interface DataLoader {
len: () => Promise<number>;
getItem: (id: number) => Promise<Sample>;
next?: () => Promise<Sample>;
nextBatch?: (batchSize: number) => Promise<Array<Sample>>;
export abstract class DataLoader {
  private event = new events.EventEmitter();
  // Index of the next sample a consumer will read via next()/nextBatch().
  private fetchIndex = 0;
  // Unique event name so multiple loaders can share one process.
  private id = generateId();
  // Index one past the last sample the data-process plugin has finished.
  // -1 means no processing pipeline is attached, so every sample is readable.
  public processIndex = -1;

  /**
   * Data-access plugin developers must implement these three members, which
   * report the dataset length and get/set the sample at a specific index.
   * Note: abstract members cannot carry the `async` modifier (TS1243); the
   * Promise-returning contract is expressed through the return type alone.
   */
  abstract len(): Promise<number>;
  abstract getItem(id: number): Promise<Sample>;
  abstract setItem(id: number, sample: Sample): Promise<void>;

  /**
   * Wake up any reader currently waiting for `processIndex` to advance.
   * The processing side calls this after updating `processIndex`.
   */
  notifyProcess(): void {
    this.event.emit(this.id);
  }

  /**
   * @returns the current fetch cursor (exposed for tests/diagnostics).
   */
  getFetchIndex(): number {
    return this.fetchIndex;
  }

  /**
   * Iterate over the dataset: get the next single sample.
   * Override Forbidden
   */
  async next(): Promise<Sample> {
    // reset index of data fetched to beginning when it reaches the end
    if (this.fetchIndex >= await this.len()) {
      this.fetchIndex = 0;
    }

    // if the data fetched has already been processed, return it
    if (this.fetchIndex < this.processIndex || this.processIndex === -1) {
      return this.getItem(this.fetchIndex++);
    }

    // if data fetched is not yet processed, wait until processing catches up
    return new Promise((resolve) => {
      this.event.on(this.id, async () => {
        if (this.fetchIndex < this.processIndex) {
          const data = await this.getItem(this.fetchIndex++);
          this.event.removeAllListeners(this.id);
          resolve(data);
        }
      });
    });
  }

  /**
   * Iterate over the dataset: get the next batch of samples.
   * Override Forbidden
   * NOTE(review): unlike next(), this does not advance fetchIndex; the
   * existing tests rely on that — confirm before changing.
   */
  async nextBatch(batchSize: number): Promise<Sample[]> {
    const dataLen = await this.len();

    if (this.fetchIndex >= dataLen) {
      this.fetchIndex = 0;
    }

    // Clamp the batch to the remaining samples. The previous clamp used
    // `dataLen - this.fetchIndex - 1`, which always dropped the final sample.
    if (this.fetchIndex + batchSize > dataLen) {
      batchSize = dataLen - this.fetchIndex;
    }

    // fast path: the whole batch is already processed
    if (this.fetchIndex + batchSize < this.processIndex) {
      const result = [];
      for (let i = this.fetchIndex; i < this.fetchIndex + batchSize; i++) {
        result.push(this.getItem(i));
      }
      return Promise.all(result);
    }

    // slow path: wait for the processor to get ahead of the requested batch
    return new Promise((resolve) => {
      this.event.on(this.id, async () => {
        if (this.fetchIndex + batchSize < this.processIndex) {
          const result = [];
          for (let i = this.fetchIndex; i < this.fetchIndex + batchSize; i++) {
            result.push(this.getItem(i));
          }
          this.event.removeAllListeners(this.id);
          resolve(await Promise.all(result));
        }
      });
    });
  }
}

/**
Expand Down
48 changes: 48 additions & 0 deletions packages/core/src/types/data/common_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { DataLoader, Sample } from './common';

/**
 * Minimal in-memory loader used by the tests: ten synthetic samples
 * generated on the fly, nothing persisted.
 */
class LocalDataLoader extends DataLoader {
  async len(): Promise<number> {
    return 10;
  }

  async getItem(): Promise<Sample> {
    return {
      data: Math.random(),
      // was `Math.random` (the function object itself); call it so the
      // label is an actual number like `data`
      label: Math.random()
    };
  }

  // Declared async so it genuinely returns a resolved Promise<void>;
  // the previous `return null` violates the signature under strictNullChecks.
  async setItem(): Promise<void> {
    // intentionally a no-op — samples are synthesized, not stored
  }
}

describe('test dataloader', () => {
  it('should return data immediately when data is ready', async () => {
    // processIndex ahead of the fetch cursor: next()/nextBatch() resolve
    // without waiting for a notifyProcess() event.
    const dataLoader = new LocalDataLoader();
    dataLoader.processIndex = 5;
    const nextData = await dataLoader.next();
    const nextDataBatch = await dataLoader.nextBatch(2);
    // toBeDefined: `.not.toBeNull()` passes vacuously on `undefined`
    expect(nextData?.data).toBeDefined();
    expect(nextData?.label).toBeDefined();
    expect(nextDataBatch?.length).toBe(2);
  });

  it('should wait until data is processed', async () => {
    // nextBatch(3) cannot complete while processIndex is 1; it must block
    // until the simulated processor advances and fires notifyProcess().
    const dataLoader = new LocalDataLoader();
    dataLoader.processIndex = 1;
    setTimeout(() => {
      dataLoader.processIndex = 5;
      dataLoader.notifyProcess();
    }, 1000);
    const nextDataBatch = await dataLoader.nextBatch(3);
    expect(nextDataBatch?.length).toBe(3);
  }, 3000);

  it('should read data from beginning when it reaches end', async () => {
    // nextBatch() does not advance the fetch cursor, so it stays at 0.
    const dataLoader = new LocalDataLoader();
    dataLoader.processIndex = 5;
    await dataLoader.nextBatch(4);
    expect(dataLoader.getFetchIndex()).toBe(0);
  });
});
4 changes: 2 additions & 2 deletions packages/core/src/types/data/csv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ export interface CsvMetadata extends Metadata {
};
}

export interface CsvDataLoader extends DataLoader {
getItem: (id: number) => Promise<CsvSample>;
/**
 * Data loader for CSV datasets; narrows `getItem` to return `CsvSample`.
 * Concrete data-access plugins must also implement the `len` and `setItem`
 * members inherited from `DataLoader`.
 */
export abstract class CsvDataLoader extends DataLoader {
  abstract getItem(id: number): Promise<CsvSample>;
}

export interface CsvDataset extends UniDataset {
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/types/data/image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ export interface ImageSample {
label: ImageLabel;
}

export interface ImageDataLoader extends DataLoader {
getItem: (id: number) => Promise<ImageSample>;
/**
 * Data loader for image datasets; narrows `getItem` to return `ImageSample`.
 * Concrete data-access plugins must also implement the `len` and `setItem`
 * members inherited from `DataLoader`.
 */
export abstract class ImageDataLoader extends DataLoader {
  abstract getItem(id: number): Promise<ImageSample>;
}

export interface ImageDataset extends UniDataset {
Expand Down
27 changes: 26 additions & 1 deletion packages/core/src/types/plugins.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,32 @@ export interface DataAccessType extends PipcookPlugin {
* @param args The arguments from pipeline config file.
*/
export interface DataProcessType extends PipcookPlugin {
(data: Sample, metadata: Metadata, args: ArgsType): Promise<void>;
(data: Sample, metadata: Metadata, args: ArgsType): Promise<Sample>;
}

/**
* Similar to `DataProcessType`, but this type targets on the whole dataset rather than a sample.
* This plugin will be convenient when it comes to process data that requires information from the whole dataset. I.E. corpus construction, average data among the dataset ...
*
* @example
*
* ```js
* const getCorpus = async (dataset: UniDataset, metadata: Metadata, args?: ArgsType): Promise<void> => {
* const corpus: Set<string> = new Set();
* for (const data of dataset) {
* for (const word of (data.data.split(" "))) {
* corpus.add(word);
* }
* }
* metadata.corpus = corpus;
* }
* ```
*
* @param dataset The dataset of which you loaded in `dataCollect` & `dataAccess`
* @param args The arguments from pipeline config file.
*/
export interface DatasetProcessType extends PipcookPlugin {
  // Called once with the full dataset (mutated in place) and the pipeline
  // config arguments; resolves when dataset-wide processing is done.
  (dataset: UniDataset, args: ArgsType): Promise<void>;
}

/**
Expand Down
22 changes: 13 additions & 9 deletions packages/costa/src/client/entry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,21 @@ async function emitStart(message: PluginMessage): Promise<void> {
if (pkg.pipcook.category === 'dataProcess') {
// in "dataProcess" plugin, we need to do process them in one by one.
const [ dataset, args ] = pluginArgs.map(deserializeArg) as [ UniDataset, any ];
const loaders = [ dataset.trainLoader, dataset.validationLoader, dataset.testLoader ]
[ dataset.trainLoader, dataset.validationLoader, dataset.testLoader ]
.filter((loader: DataLoader) => loader != null)
.map(async (loader: DataLoader) => {
const len = await loader.len();
// FIXME(Yorkie): in parallel?
for (let i = 0; i < len; i++) {
const sample = await loader.getItem(i);
await fn(sample, dataset.metadata, args);
}
.forEach(async (loader: DataLoader) => {
process.nextTick(async () => {
const len = await loader.len();
loader.processIndex = 0;
for (let i = 0; i < len; i++) {
let sample = await loader.getItem(i);
sample = await fn(sample, dataset.metadata, args);
await loader.setItem(i, sample);
loader.processIndex = i + 1;
loader.notifyProcess();
}
});
});
await Promise.all(loaders);
recv(PluginOperator.WRITE);
return;
}
Expand Down
15 changes: 11 additions & 4 deletions packages/plugins/data-access/coco-data-access/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import {
ImageDataLoader,
ImageLabel,
convertPascal2CocoFileOutput,
shuffle
shuffle,
ImageSample
} from '@pipcook/pipcook-core';
import glob from 'glob-promise';
import * as path from 'path';
Expand All @@ -23,23 +24,29 @@ interface DataPair {
label: ImageLabel;
}

class DataLoader implements ImageDataLoader {
class DataLoader extends ImageDataLoader {
dataPairs!: DataPair[];
constructor(dataPairs: DataPair[]) {
super();
shuffle(dataPairs);
this.dataPairs = dataPairs;
}

async len() {
async len(): Promise<number> {
return this.dataPairs.length;
}

async getItem(id: number) {
async getItem(id: number): Promise<ImageSample> {
return {
data: this.dataPairs[id].image,
label: this.dataPairs[id].label
};
}

async setItem(id: number, sample: ImageSample): Promise<void> {
this.dataPairs[id].image = sample.data;
this.dataPairs[id].label = sample.label;
}
}

/**
Expand Down
11 changes: 8 additions & 3 deletions packages/plugins/data-access/csv-data-access/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import * as fs from 'fs';
import * as path from 'path';
import parse from 'csv-parse/lib/sync';

class DataLoader implements CsvDataLoader {
class DataLoader extends CsvDataLoader {
records!: CsvSample[];

constructor(csvPath: string, labelColumn: string) {
super();
const records = parse(fs.readFileSync(csvPath), {
columns: true
});
Expand All @@ -25,13 +26,17 @@ class DataLoader implements CsvDataLoader {
});
}

async len() {
async len(): Promise<number> {
return this.records.length;
}

async getItem(id: number) {
async getItem(id: number): Promise<CsvSample> {
return this.records[id];
}

async setItem(id: number, sample: CsvSample): Promise<void> {
this.records[id] = sample;
}
}

/**
Expand Down
Loading

0 comments on commit fd4041d

Please sign in to comment.