Skip to content

Commit

Permalink
add code 9.6 in d2lzh
Browse files Browse the repository at this point in the history
  • Loading branch information
ShusenTang committed Oct 31, 2019
1 parent 73ccd30 commit b4c1391
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions code/d2lzh_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,58 @@ def MultiBoxDetection_one(c_p, l_p, anc, nms_threshold = 0.5):



# ################################# 9.6 ############################
class PikachuDetDataset(torch.utils.data.Dataset):
"""皮卡丘检测数据集类"""
def __init__(self, data_dir, part, image_size=(256, 256)):
assert part in ["train", "val"]
self.image_size = image_size
self.image_dir = os.path.join(data_dir, part, "images")

with open(os.path.join(data_dir, part, "label.json")) as f:
self.label = json.load(f)

self.transform = torchvision.transforms.Compose([
# 将 PIL 图片转换成位于[0.0, 1.0]的floatTensor, shape (C x H x W)
torchvision.transforms.ToTensor()])

def __len__(self):
return len(self.label)

def __getitem__(self, index):
image_path = str(index + 1) + ".png"

cls = self.label[image_path]["class"]
label = np.array([cls] + self.label[image_path]["loc"],
dtype="float32")[None, :]

PIL_img = Image.open(os.path.join(self.image_dir, image_path)
).convert('RGB').resize(self.image_size)
img = self.transform(PIL_img)

sample = {
"label": label, # shape: (1, 5) [class, xmin, ymin, xmax, ymax]
"image": img # shape: (3, *image_size)
}

return sample

def load_data_pikachu(batch_size, edge_size=256, data_dir = '../../data/pikachu'):
"""edge_size:输出图像的宽和高"""
image_size = (edge_size, edge_size)
train_dataset = PikachuDetDataset(data_dir, 'train', image_size)
val_dataset = PikachuDetDataset(data_dir, 'val', image_size)


train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=4)

val_iter = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
shuffle=False, num_workers=4)
return train_iter, val_iter




# ############################# 10.7 ##########################
def read_imdb(folder='train', data_root="/S1/CSCL/tangss/Datasets/aclImdb"):
Expand Down

0 comments on commit b4c1391

Please sign in to comment.