add code for classification
cheerss committed Aug 1, 2021
1 parent 2853ab5 commit f20000d
Showing 21 changed files with 2,760 additions and 0 deletions.
239 changes: 239 additions & 0 deletions config.py
@@ -0,0 +1,239 @@
import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache data in memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for (sometimes) more efficient transfer to GPU
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'cross-scale'
# Model name
_C.MODEL.NAME = 'tiny_patch4_group7_224'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Pre-trained checkpoint to initialize the model from
_C.MODEL.FROM_PRETRAIN = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1

# CrossFormer parameters
_C.MODEL.CROS = CN()
_C.MODEL.CROS.PATCH_SIZE = [4, 8, 16, 32]
_C.MODEL.CROS.MERGE_SIZE = [[2, 4], [2, 4], [2, 4]]
_C.MODEL.CROS.IN_CHANS = 3
_C.MODEL.CROS.EMBED_DIM = 48
_C.MODEL.CROS.DEPTHS = [2, 2, 6, 2]
_C.MODEL.CROS.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.CROS.GROUP_SIZE = [7, 7, 7, 7]
_C.MODEL.CROS.MLP_RATIO = 4.
_C.MODEL.CROS.QKV_BIAS = True
_C.MODEL.CROS.QK_SCALE = None
_C.MODEL.CROS.APE = False
_C.MODEL.CROS.PATCH_NORM = True

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-7
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# AutoAugment / RandAugment policy string in timm format, e.g. 'v0', 'original', or 'rand-m9-mstd0.5-inc1'
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level ('O0', 'O1', 'O2'); if 'O0', no amp is used
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1000
# Frequency to log info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# Local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.num_workers >= 0:
        config.DATA.NUM_WORKERS = args.num_workers
    if args.throughput:
        config.THROUGHPUT_MODE = True

    # if args.patch_size:
    #     config.MODEL.CROS.PATCH_SIZE = args.patch_size

    # Unlike the guarded options above, these fields are always taken from
    # args, so the launcher must define the corresponding flags with defaults.
    config.MODEL.CROS.MLP_RATIO = args.mlp_ratio
    # config.MODEL.MERGE_SIZE_AFTER = [args.merge_size_after1, args.merge_size_after2, args.merge_size_after3, []]
    config.DATA.DATASET = args.data_set
    config.TRAIN.WARMUP_EPOCHS = args.warmup_epochs
    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank
    # output folder: <OUTPUT>/<MODEL.NAME>/<TAG>
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
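For context, here is a minimal sketch of how this config module would be consumed from a training entry point. The launcher script is not part of this commit, so the `parse_option` name and the exact flag spellings below are assumptions inferred from the attributes `update_config` reads off `args`; only `--cfg` is strictly required, and the four unconditionally merged options carry defaults mirroring the `_C` values above.

# Hypothetical launcher sketch -- not part of this commit. Flag names are
# inferred from the attributes that update_config() reads from `args`.
import argparse

from config import get_config


def parse_option():
    parser = argparse.ArgumentParser('CrossFormer training and evaluation script')
    parser.add_argument('--cfg', type=str, required=True, metavar='FILE',
                        help='path to a YAML config file')
    parser.add_argument('--opts', nargs='+', default=None,
                        help="extra 'KEY VALUE' pairs merged via merge_from_list")
    parser.add_argument('--batch-size', type=int, default=None)
    parser.add_argument('--data-path', type=str, default=None)
    parser.add_argument('--zip', action='store_true')
    parser.add_argument('--cache-mode', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--accumulation-steps', type=int, default=None)
    parser.add_argument('--use-checkpoint', action='store_true')
    parser.add_argument('--amp-opt-level', type=str, default=None)
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--tag', type=str, default=None)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--throughput', action='store_true')
    # -1 leaves DATA.NUM_WORKERS at its config value (update_config checks >= 0)
    parser.add_argument('--num-workers', type=int, default=-1)
    # These four are merged unconditionally by update_config, so they must
    # always carry defaults (here mirroring the _C defaults above).
    parser.add_argument('--mlp-ratio', type=float, default=4.)
    parser.add_argument('--data-set', type=str, default='imagenet')
    parser.add_argument('--warmup-epochs', type=int, default=20)
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    return args, get_config(args)


if __name__ == '__main__':
    args, config = parse_option()
    print(config.MODEL.NAME, config.DATA.BATCH_SIZE)

A hypothetical invocation would be `python main.py --cfg configs/tiny_patch4_group7_224.yaml --data-path /path/to/imagenet --batch-size 64 --opts TRAIN.EPOCHS 100` (yacs coerces the string values passed via --opts to the types of the matching defaults).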
11 changes: 11 additions & 0 deletions configs/base_patch4_group7_224.yaml
@@ -0,0 +1,11 @@
MODEL:
  TYPE: cross-scale
  NAME: cros_base_patch4_group7_224
  DROP_PATH_RATE: 0.3
  CROS:
    EMBED_DIM: 96
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 3, 6, 12, 24 ]
    GROUP_SIZE: [ 7, 7, 7, 7 ]
    PATCH_SIZE: [ 4, 8, 16, 32 ]
    MERGE_SIZE: [ [ 2, 4 ], [ 2, 4 ], [ 2, 4 ] ]
11 changes: 11 additions & 0 deletions configs/large_patch4_group7_224.yaml
@@ -0,0 +1,11 @@
MODEL:
  TYPE: cross-scale
  NAME: cros_large_patch4_group7_224
  DROP_PATH_RATE: 0.5
  CROS:
    EMBED_DIM: 128
    DEPTHS: [ 2, 2, 18, 2 ]
    NUM_HEADS: [ 4, 8, 16, 32 ]
    GROUP_SIZE: [ 7, 7, 7, 7 ]
    PATCH_SIZE: [ 4, 8, 16, 32 ]
    MERGE_SIZE: [ [ 2, 4 ], [ 2, 4 ], [ 2, 4 ] ]
11 changes: 11 additions & 0 deletions configs/small_patch4_group7_224.yaml
@@ -0,0 +1,11 @@
MODEL:
  TYPE: cross-scale
  NAME: cros_small_patch4_group7_224
  DROP_PATH_RATE: 0.2
  CROS:
    EMBED_DIM: 96
    DEPTHS: [ 2, 2, 6, 2 ]
    NUM_HEADS: [ 3, 6, 12, 24 ]
    GROUP_SIZE: [ 7, 7, 7, 7 ]
    PATCH_SIZE: [ 4, 8, 16, 32 ]
    MERGE_SIZE: [ [ 2, 4 ], [ 2, 4 ], [ 2, 4 ] ]
11 changes: 11 additions & 0 deletions configs/tiny_patch4_group7_224.yaml
@@ -0,0 +1,11 @@
MODEL:
  TYPE: cross-scale
  NAME: cros_tiny_patch4_group7_224
  DROP_PATH_RATE: 0.1
  CROS:
    EMBED_DIM: 64
    DEPTHS: [ 1, 1, 8, 6 ]
    NUM_HEADS: [ 2, 4, 8, 16 ]
    GROUP_SIZE: [ 7, 7, 7, 7 ]
    PATCH_SIZE: [ 4, 8, 16, 32 ]
    MERGE_SIZE: [ [ 2, 4 ], [ 2, 4 ], [ 2, 4 ] ]
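
A note on the `BASE` field handled by `_update_config_from_file` in config.py: it lets a YAML file inherit another config, resolved relative to its own directory, merged first, and applied recursively. None of the four configs above uses it, but a hypothetical derived config would look like this (file name and values invented for illustration):

# configs/base_patch4_group7_384.yaml -- hypothetical, not part of this commit
BASE: [ 'base_patch4_group7_224.yaml' ]  # merged first, relative to this file's directory
DATA:
  IMG_SIZE: 384
MODEL:
  DROP_PATH_RATE: 0.5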
1 change: 1 addition & 0 deletions data/__init__.py
@@ -0,0 +1 @@
from .build import build_loader
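
The data/build.py this imports from is among the remaining changed files not shown in this excerpt, so how build_loader is meant to be called can only be guessed at. A hedged sketch, assuming the repo follows the Swin Transformer convention its config module clearly borrows from:

# Assumed call site -- the return signature is an inference from the
# Swin-style convention, not confirmed by this diff, since data/build.py
# is not shown here.
from data import build_loader

(dataset_train, dataset_val,
 data_loader_train, data_loader_val, mixup_fn) = build_loader(config)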