Skip to content

Commit

Permalink
Add imdb split info (PaddlePaddle#312)
Browse files Browse the repository at this point in the history
* add split info

* add more docstring for imdb

Co-authored-by: Zeyu Chen <[email protected]>
  • Loading branch information
joey12300 and ZeyuChen authored Apr 27, 2021
1 parent 0274828 commit e73a02d
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 135 deletions.
1 change: 1 addition & 0 deletions paddlenlp/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
from .wmt14ende import *
from .couplet import *
from .yahoo_answer_100k import *
from .imdb import *
11 changes: 10 additions & 1 deletion paddlenlp/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@

class Imdb(DatasetBuilder):
"""
Subsets of IMDb data are available for access to customers for personal and non-commercial use.
Each dataset is contained in a gzipped, tab-separated-values (TSV) formatted file in the UTF-8 character set.
The first line in each file contains headers that describe what is in each column.
Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.
"""
URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
META_INFO = collections.namedtuple('META_INFO', ('data_dir', 'md5'))
SPLITS = {
'train': META_INFO(os.path.join('aclImdb', 'train'), None),
'test': META_INFO(os.path.join('aclImdb', 'test'), None),
}

def _get_data(self, mode, **kwargs):
    """Return the local directory of the requested split, fetching it if absent.

    Args:
        mode: Split name; must be a key of ``SPLITS`` ('train' or 'test').
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The split directory, located under ``DATA_HOME/<class name>``.

    Raises:
        KeyError: If ``mode`` is not a key of ``SPLITS``.
    """
    default_root = os.path.join(DATA_HOME, self.__class__.__name__)
    # Take the relative directory from SPLITS so the on-disk layout is
    # declared in exactly one place; the md5 slot of META_INFO is ignored.
    split_dir, _ = self.SPLITS[mode]
    data_dir = os.path.join(default_root, split_dir)
    if not os.path.exists(data_dir):
        # Only the side effect of get_path_from_url is needed (it is handed
        # the archive URL, the target root and the archive MD5); its return
        # value is not used.
        get_path_from_url(self.URL, default_root, self.MD5)
    return data_dir
Expand Down
Empty file.
101 changes: 0 additions & 101 deletions tests/dataset/experimental/test_imdb.py

This file was deleted.

91 changes: 58 additions & 33 deletions tests/dataset/test_imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,90 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import unittest
from paddlenlp.datasets import Imdb
from paddlenlp.datasets import load_dataset

from common_test import CpuCommonTest
import util
import unittest


def get_examples(mode='train'):
    """Return the reference ``(text, label)`` pair used to spot-check a split.

    Args:
        mode: Split name, ``'train'`` or ``'test'``.

    Returns:
        A 2-tuple holding the expected review text and its expected label.

    Raises:
        KeyError: If ``mode`` is not one of the known split names.
    """
    train_sample = (
        'I loved this movie since I was 7 and I saw it on the opening day '
        'It was so touching and beautiful I strongly recommend seeing for '
        'all Its a movie to watch with your family by farbr br My MPAA rating '
        'PG13 for thematic elements prolonged scenes of disastor nuditysexuality '
        'and some language',
        1,
    )
    test_sample = (
        'Felix in Hollywood is a great film The version I viewed was very well '
        'restored which is sometimes a problem with these silent era animated films '
        'It has some of Hollywoods most famous stars making cameo animated '
        'appearances A must for any silent film or animation enthusiast',
        1,
    )
    samples_by_split = {'train': train_sample, 'test': test_sample}
    return samples_by_split[mode]


# Test case for the IMDB 'train' split: checks the dataset length (25000)
# and the text and label of one sample.
# NOTE(review): these lines appear to interleave the removed and added sides
# of the diff, with the +/- markers and indentation lost (two test-method
# headers, and both `Imdb(**self.config)` and `load_dataset(**self.config)`).
# Do not read them as a single runnable class body.
class TestImdbTrainSet(CpuCommonTest):
def setUp(self):
self.config['mode'] = 'train'
np.random.seed(102)

def test_training_set(self):
expected_text, expected_label = (
'Its a good movie maybe I like it because it was filmed here '
'in PR The actors did a good performance and not only did the '
'girls be girlish but they were good in fighting so it was awsome '
'The guy is cute too so its a good match if you want to the guy '
'or the girls', 1)
expected_len = 25000
self.config['path_or_read_func'] = 'imdb'
self.config['splits'] = 'train'

train_ds = Imdb(**self.config)
def test_train_set(self):
expected_len = 25000
expected_text, expected_label = get_examples(self.config['splits'])
train_ds = load_dataset(**self.config)
self.check_output_equal(len(train_ds), expected_len)
# Two access styles are visible for the sample: positional fields at index 14
# and 'text'/'label' keys at index 36. Presumably these are the old and new
# sides of the change respectively -- confirm against the actual file.
self.check_output_equal(expected_text, train_ds[14][0])
self.check_output_equal(expected_label, train_ds[14][1])
self.check_output_equal(expected_text, train_ds[36]['text'])
self.check_output_equal(expected_label, train_ds[36]['label'])


# Test case for the IMDB 'test' split: checks the dataset length (25000)
# and the text and label of one sample.
# NOTE(review): these lines appear to interleave the removed and added sides
# of the diff, with the +/- markers and indentation lost (both
# `Imdb(**self.config)` and `load_dataset(**self.config)` are present).
# Do not read them as a single runnable class body.
class TestImdbTestSet(CpuCommonTest):
def setUp(self):
self.config['mode'] = 'test'
np.random.seed(102)
self.config['path_or_read_func'] = 'imdb'
self.config['splits'] = 'test'

def test_test_set(self):
expected_text, expected_label = (
'This is one of the great ones It works so beautifully that '
'you hardly notice the miscasting of then 37 year old Dana '
'Andrews as the drugstore soda jerk who goes to war and comes '
'back four years later when he would have been at most 25 But '
'then who else should have played him', 1)
expected_len = 25000

test_ds = Imdb(**self.config)
# The expected pair is looked up by the split name stored in self.config['splits'].
expected_text, expected_label = get_examples(self.config['splits'])
test_ds = load_dataset(**self.config)
self.check_output_equal(len(test_ds), expected_len)
# Two access styles are visible for the sample: positional fields at index 2
# and 'text'/'label' keys at index 23. Presumably these are the old and new
# sides of the change respectively -- confirm against the actual file.
self.check_output_equal(expected_text, test_ds[2][0])
self.check_output_equal(expected_label, test_ds[2][1])
self.check_output_equal(expected_text, test_ds[23]['text'])
self.check_output_equal(expected_label, test_ds[23]['label'])


class TestImdbTrainTestSet(CpuCommonTest):
    """Load the 'train' and 'test' IMDB splits in one call and spot-check both."""

    def setUp(self):
        # Request both splits of the 'imdb' dataset at once.
        self.config['path_or_read_func'] = 'imdb'
        self.config['splits'] = ['train', 'test']

    def test_train_set(self):
        # Verifies the number of returned datasets, the size of each one, and
        # one known sample (text and label) per split.
        samples_per_split = 25000
        train_text, train_label = get_examples('train')
        test_text, test_label = get_examples('test')
        datasets = load_dataset(**self.config)

        self.check_output_equal(len(datasets), 2)
        for split_pos in (0, 1):
            self.check_output_equal(len(datasets[split_pos]), samples_per_split)

        # (position in the returned sequence, sample index, expected text,
        # expected label) -- checked in the order train, then test.
        spot_checks = (
            (0, 36, train_text, train_label),
            (1, 23, test_text, test_label),
        )
        for split_pos, sample_idx, text, label in spot_checks:
            self.check_output_equal(text, datasets[split_pos][sample_idx]['text'])
            self.check_output_equal(label, datasets[split_pos][sample_idx]['label'])


# Error-path test case: the test method is decorated with util.assert_raises
# (presumably it asserts that the wrapped call raises -- confirm in util).
# NOTE(review): two class headers and two test-method headers appear back to
# back here, which looks like the removed and added sides of the diff with the
# +/- markers and indentation lost. Do not read this as one runnable class.
class TestImdbWrongMode(CpuCommonTest):
class TestImdbNoSplitDataFiles(CpuCommonTest):
def setUp(self):
# Valid modes are 'train' and 'test'; a wrong mode would raise an error.
self.config['mode'] = 'wrong'
self.config['path_or_read_func'] = 'imdb'

@util.assert_raises
def test_wrong_set(self):
Imdb(**self.config)
def test_no_split_datafiles(self):
# Only 'path_or_read_func' is set for this call in the visible lines that
# mention load_dataset; no 'splits' entry is assigned in this setUp.
load_dataset(**self.config)


if __name__ == "__main__":
Expand Down

0 comments on commit e73a02d

Please sign in to comment.