Skip to content

Commit

Permalink
Fix 2.0-beta bugs (PaddlePaddle#183)
Browse files Browse the repository at this point in the history
* fix 2.0-beta bugs

* update pretreained path

* add extract_weight.py
  • Loading branch information
LielinJiang authored Feb 26, 2021
1 parent 60066eb commit 130bd7f
Show file tree
Hide file tree
Showing 19 changed files with 249 additions and 318 deletions.
31 changes: 16 additions & 15 deletions configs/animeganv2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ model:
gan_criterion:
name: GANLoss
gan_mode: lsgan
# use your trained path
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
Expand Down Expand Up @@ -47,21 +48,21 @@ dataset:
test:
name: SingleDataset
dataroot: data/animedataset/test/HR_photo
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transforms:
- name: ResizeToScale
size: [256, 256]
scale: 32
interpolation: bilinear
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: ResizeToScale
size: [256, 256]
scale: 32
interpolation: bilinear
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]

lr_scheduler:
name: LinearDecay
Expand Down
61 changes: 25 additions & 36 deletions configs/dcgan_mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,33 @@ model:

dataset:
train:
name: SingleDataset
dataroot: data/mnist/train
name: CommonVisionDataset
dataset_name: MNIST
num_workers: 0
batch_size: 128
preprocess:
- name: LoadImageFromFile
key: A
- name: Transfroms
input_keys: [A]
pipeline:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
return_label: False
transforms:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
test:
name: SingleDataset
dataroot: data/mnist/test
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
name: CommonVisionDataset
dataset_name: MNIST
num_workers: 0
batch_size: 128
return_label: False
transforms:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]

lr_scheduler:
name: LinearDecay
Expand Down
2 changes: 1 addition & 1 deletion configs/pix2pix_facades.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ dataset:
preprocess:
- name: LoadImageFromFile
key: pair
- name: Transforms
- name: Transforms
input_keys: [A, B]
pipeline:
- name: Resize
Expand Down
15 changes: 15 additions & 0 deletions docs/en_US/tutorials/styleganv2.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ train model
python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml
```

### Inference

When you finish training, you need to use ``tools/extract_weight.py`` to extract the corresponding weights.
```
python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output YOUR_WEIGHT_PATH.pdparams
```

Then use ``applications/tools/styleganv2.py`` to get results
```
python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256
```

Note: ``--size`` should be same with your config file.


## Results

Random Samples:
Expand Down
51 changes: 49 additions & 2 deletions docs/zh_CN/tutorials/styleganv2.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,56 @@ python -u tools/styleganv2.py \
- n_col: 采样的图片的列数
- cpu: 是否使用cpu推理,若不使用,请在命令中去除

### 训练(TODO)
### 训练

#### 准备数据集
你可以从[这里](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP)下载对应的数据集

为了方便,我们提供了[images256x256.tar](https://paddlegan.bj.bcebos.com/datasets/images256x256.tar)

目前的配置文件默认数据集的结构如下:
```
PaddleGAN
├── data
├── ffhq
├──images1024x1024
├── 00000.png
├── 00001.png
├── 00002.png
├── 00003.png
├── 00004.png
├──images256x256
├── 00000.png
├── 00001.png
├── 00002.png
├── 00003.png
├── 00004.png
├──custom_data
├── img0.png
├── img1.png
├── img2.png
├── img3.png
├── img4.png
...
```

启动训练
```
python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml
```

### 推理

训练结束后,需要使用 ``tools/extract_weight.py`` 来提取对应的权重给``applications/tools/styleganv2.py``来进行推理.
```
python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output stylegan_config_f.pdparams
```

```
python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256
```

未来还将添加训练脚本方便用户训练出更多类型的 StyleGAN V2 图像生成器。
注意: ``--size`` 这个参数要和配置文件中的参数保持一致.


## 生成结果展示
Expand Down
2 changes: 1 addition & 1 deletion ppgan/datasets/animeganv2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .image_folder import ImageFolder

from .builder import DATASETS
from .transforms.builder import build_transforms
from .preprocess.builder import build_transforms


@DATASETS.register()
Expand Down
2 changes: 1 addition & 1 deletion ppgan/datasets/common_vision_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .builder import DATASETS
from .base_dataset import BaseDataset
from .transforms.builder import build_transforms
from .preprocess.builder import build_transforms


@DATASETS.register()
Expand Down
12 changes: 12 additions & 0 deletions ppgan/datasets/preprocess/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,15 @@ def build_preprocess(cfg):

preproccess = Compose(preproccess)
return preproccess


def build_transforms(cfg):
transforms = []

for trans_cfg in cfg:
temp_trans_cfg = copy.deepcopy(trans_cfg)
name = temp_trans_cfg.pop('name')
transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))

transforms = Compose(transforms)
return transforms
71 changes: 71 additions & 0 deletions ppgan/datasets/preprocess/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,74 @@ def _apply_image(self, image):
image = image + normed_noise
image = np.clip(image, 0., 1.)
return image


@TRANSFORMS.register()
class Add(T.BaseTransform):
def __init__(self, value, keys=None):
"""Initialize Add Transform
Parameters:
value (List[int]) -- the [r,g,b] value will add to image by pixel wise.
"""
super().__init__(keys=keys)
self.value = value

def _get_params(self, inputs):
params = {}
params['value'] = self.value
return params

def _apply_image(self, image):
return np.clip(image + self.params['value'], 0, 255).astype('uint8')
# return custom_F.add(image, self.params['value'])


@TRANSFORMS.register()
class ResizeToScale(T.BaseTransform):
def __init__(self,
size: int,
scale: int,
interpolation='bilinear',
keys=None):
"""Initialize ResizeToScale Transform
Parameters:
size (List[int]) -- the minimum target size
scale (List[int]) -- the stride scale
interpolation (Optional[str]) -- interpolation method
"""
super().__init__(keys=keys)
if isinstance(size, int):
self.size = (size, size)
else:
self.size = size
self.scale = scale
self.interpolation = interpolation

def _get_params(self, inputs):
image = inputs[self.keys.index('image')]
hw = image.shape[:2]
params = {}
params['taget_size'] = self.reduce_to_scale(hw, self.size[::-1],
self.scale)
return params

@staticmethod
def reduce_to_scale(img_hw, min_hw, scale):
im_h, im_w = img_hw
if im_h <= min_hw[0]:
im_h = min_hw[0]
else:
x = im_h % scale
im_h = im_h - x

if im_w < min_hw[1]:
im_w = min_hw[1]
else:
y = im_w % scale
im_w = im_w - y
return (im_h, im_w)

def _apply_image(self, image):
return F.resize(image, self.params['taget_size'], self.interpolation)
1 change: 0 additions & 1 deletion ppgan/datasets/transforms/__init__.py

This file was deleted.

58 changes: 0 additions & 58 deletions ppgan/datasets/transforms/builder.py

This file was deleted.

Loading

0 comments on commit 130bd7f

Please sign in to comment.