We use Python files as configs. You can find all the provided configs under `$MMAction2/configs`.
We follow the style below to name config files. Contributors are advised to follow the same style.
{model}_[model setting]_{backbone}_[misc]_{data setting}_[gpu x batch_per_gpu]_{schedule}_{dataset}_{modality}
`{xxx}` is a required field and `[yyy]` is optional.
- `{model}`: model type, e.g. `tsn`, `i3d`, etc.
- `[model setting]`: specific setting for some models.
- `{backbone}`: backbone type, e.g. `r50` (ResNet-50), etc.
- `[misc]`: miscellaneous setting/plugins of model, e.g. `dense`, `320p`, `video`, etc.
- `{data setting}`: frame sample setting in `{clip_len}x{frame_interval}x{num_clips}` format.
- `[gpu x batch_per_gpu]`: GPUs and samples per GPU.
- `{schedule}`: training schedule, e.g. `20e` means 20 epochs.
- `{dataset}`: dataset name, e.g. `kinetics400`, `mmit`, etc.
- `{modality}`: frame modality, e.g. `rgb`, `flow`, etc.
Please refer to the corresponding pages for config file structure for different tasks.
We incorporate modular design into our config system, which is convenient to conduct various experiments.
- An Example of BMN
To help users get a basic idea of a complete config structure and the modules in an action localization system, we make brief comments on the config of BMN as follows. For more detailed usage and the alternatives for each parameter in each module, please refer to the API documentation.
# Annotated example config for the BMN localizer with the ActivityNet feature
# dataset. Every setting is a plain Python value; the trailing comment on each
# line explains what the setting controls.

# model settings
model = dict(  # Config of the model
    type='BMN',  # Type of the localizer
    temporal_dim=100,  # Total frames selected for each video
    boundary_ratio=0.5,  # Ratio for determining video boundaries
    num_samples=32,  # Number of samples for each proposal
    num_samples_per_bin=3,  # Number of bin samples for each sample
    feat_dim=400,  # Dimension of feature
    soft_nms_alpha=0.4,  # Soft NMS alpha
    soft_nms_low_threshold=0.5,  # Soft NMS low threshold
    soft_nms_high_threshold=0.9,  # Soft NMS high threshold
    post_process_top_k=100)  # Top k proposals in post process

# model training and testing settings
train_cfg = None  # Config of training hyperparameters for BMN
test_cfg = dict(average_clips='score')  # Config for testing hyperparameters for BMN

# dataset settings
dataset_type = 'ActivityNetDataset'  # Type of dataset for training, validation and testing
data_root = 'data/activitynet_feature_cuhk/csv_mean_100/'  # Root path to data for training
data_root_val = 'data/activitynet_feature_cuhk/csv_mean_100/'  # Root path to data for validation and testing
ann_file_train = 'data/ActivityNet/anet_anno_train.json'  # Path to the annotation file for training
ann_file_val = 'data/ActivityNet/anet_anno_val.json'  # Path to the annotation file for validation
ann_file_test = 'data/ActivityNet/anet_anno_test.json'  # Path to the annotation file for testing

train_pipeline = [  # List of training pipeline steps
    dict(type='LoadLocalizationFeature'),  # Load localization feature pipeline
    dict(type='GenerateLocalizationLabels'),  # Generate localization labels pipeline
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the localizer
        keys=['raw_feature', 'gt_bbox'],  # Keys of input
        meta_name='video_meta',  # Meta name
        meta_keys=['video_name']),  # Meta keys of input
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['raw_feature']),  # Keys to be converted from image to tensor
    dict(  # Config of ToDataContainer
        type='ToDataContainer',  # Pipeline to convert the data to DataContainer
        fields=[dict(key='gt_bbox', stack=False, cpu_only=True)])  # Required fields to be converted with keys and attributes
]
val_pipeline = [  # List of validation pipeline steps
    dict(type='LoadLocalizationFeature'),  # Load localization feature pipeline
    dict(type='GenerateLocalizationLabels'),  # Generate localization labels pipeline
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the localizer
        keys=['raw_feature', 'gt_bbox'],  # Keys of input
        meta_name='video_meta',  # Meta name
        meta_keys=[
            'video_name', 'duration_second', 'duration_frame', 'annotations',
            'feature_frame'
        ]),  # Meta keys of input
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['raw_feature']),  # Keys to be converted from image to tensor
    dict(  # Config of ToDataContainer
        type='ToDataContainer',  # Pipeline to convert the data to DataContainer
        fields=[dict(key='gt_bbox', stack=False, cpu_only=True)])  # Required fields to be converted with keys and attributes
]
test_pipeline = [  # List of testing pipeline steps
    dict(type='LoadLocalizationFeature'),  # Load localization feature pipeline
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the localizer
        keys=['raw_feature'],  # Keys of input
        meta_name='video_meta',  # Meta name
        meta_keys=[
            'video_name', 'duration_second', 'duration_frame', 'annotations',
            'feature_frame'
        ]),  # Meta keys of input
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['raw_feature']),  # Keys to be converted from image to tensor
]
data = dict(  # Config of data
    videos_per_gpu=8,  # Batch size of each single GPU
    workers_per_gpu=8,  # Workers to pre-fetch data for each single GPU
    train_dataloader=dict(  # Additional config of train dataloader
        drop_last=True),  # Whether to drop out the last batch of data in training
    val_dataloader=dict(  # Additional config of validation dataloader
        videos_per_gpu=1),  # Batch size of each single GPU during evaluation
    test_dataloader=dict(  # Additional config of test dataloader
        videos_per_gpu=2),  # Batch size of each single GPU during testing
    test=dict(  # Testing dataset config
        type=dataset_type,
        ann_file=ann_file_test,
        pipeline=test_pipeline,
        data_prefix=data_root_val),
    val=dict(  # Validation dataset config
        type=dataset_type,
        ann_file=ann_file_val,
        pipeline=val_pipeline,
        data_prefix=data_root_val),
    train=dict(  # Training dataset config
        type=dataset_type,
        ann_file=ann_file_train,
        pipeline=train_pipeline,
        data_prefix=data_root))

# optimizer
optimizer = dict(
    # Config used to build optimizer, support (1). All the optimizers in PyTorch
    # whose arguments are also the same as those in PyTorch. (2). Custom optimizers
    # which are built on `constructor`, referring to "tutorials/5_new_modules.md"
    # for implementation.
    type='Adam',  # Type of optimizer, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py#L13 for more details
    lr=0.001,  # Learning rate, see detail usages of the parameters in the documentation of PyTorch
    weight_decay=0.0001)  # Weight decay of Adam
optimizer_config = dict(  # Config used to build the optimizer hook
    grad_clip=None)  # Most of the methods do not use gradient clip

# learning policy
lr_config = dict(  # Learning rate scheduler config used to register LrUpdater hook
    policy='step',  # Policy of scheduler, also support CosineAnnealing, Cyclic, etc. Refer to details of supported LrUpdater from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9
    step=7)  # Steps to decay the learning rate

total_epochs = 9  # Total epochs to train the model
checkpoint_config = dict(  # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation
    interval=1)  # Interval to save checkpoint
evaluation = dict(  # Config of evaluation during training
    interval=1,  # Interval to perform evaluation
    metrics=['AR@AN'])  # Metrics to be performed
log_config = dict(  # Config to register logger hook
    interval=50,  # Interval to print the log
    hooks=[  # Hooks to be implemented during training
        dict(type='TextLoggerHook'),  # The logger used to record the training process
        # dict(type='TensorboardLoggerHook'),  # The Tensorboard logger is also supported
    ])

# runtime settings
dist_params = dict(backend='nccl')  # Parameters to setup distributed training, the port can also be set
log_level = 'INFO'  # The level of logging
work_dir = './work_dirs/bmn_400x100_2x8_9e_activitynet_feature/'  # Directory to save the model checkpoints and logs for the current experiments
load_from = None  # Load models as a pre-trained model from a given path. This will not resume training
resume_from = None  # Resume checkpoints from a given path, the training will be resumed from the epoch when the checkpoint is saved
workflow = [('train', 1)]  # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once
output_config = dict(  # Config of localization output
    out=f'{work_dir}/results.json',  # Path to output file
    output_format='json')  # File format of output file
We incorporate modular design into our config system, which is convenient to conduct various experiments.
- An Example of TSN
To help users get a basic idea of a complete config structure and the modules in an action recognition system, we make brief comments on the config of TSN as follows. For more detailed usage and the alternatives for each parameter in each module, please refer to the API documentation.
# Annotated example config for the TSN recognizer (ResNet-50 backbone) with
# the Kinetics-400 raw-frame dataset. Every setting is a plain Python value;
# the trailing comment on each line explains what the setting controls.

# model settings
model = dict(  # Config of the model
    type='Recognizer2D',  # Type of the recognizer
    backbone=dict(  # Dict for backbone
        type='ResNet',  # Name of the backbone
        pretrained='torchvision://resnet50',  # The url/site of the pretrained model
        depth=50,  # Depth of ResNet model
        norm_eval=False),  # Whether to set BN layers to eval mode when training
    cls_head=dict(  # Dict for classification head
        type='TSNHead',  # Name of classification head
        num_classes=400,  # Number of classes to be classified.
        in_channels=2048,  # The input channels of classification head.
        spatial_type='avg',  # Type of pooling in spatial dimension
        consensus=dict(type='AvgConsensus', dim=1),  # Config of consensus module
        dropout_ratio=0.4,  # Probability in dropout layer
        init_std=0.01))  # Std value for linear layer initialization

# model training and testing settings
train_cfg = None  # Config of training hyperparameters for TSN
test_cfg = dict(average_clips=None)  # Config for testing hyperparameters for TSN. Here we define clip averaging method in it

# dataset settings
dataset_type = 'RawframeDataset'  # Type of dataset for training, validation and testing
data_root = 'data/kinetics400/rawframes_train/'  # Root path to data for training
data_root_val = 'data/kinetics400/rawframes_val/'  # Root path to data for validation and testing
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'  # Path to the annotation file for training
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'  # Path to the annotation file for validation
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'  # Path to the annotation file for testing
img_norm_cfg = dict(  # Config of image normalization used in data pipeline
    mean=[123.675, 116.28, 103.53],  # Mean values of different channels to normalize
    std=[58.395, 57.12, 57.375],  # Std values of different channels to normalize
    to_bgr=False)  # Whether to convert channels from RGB to BGR

train_pipeline = [  # List of training pipeline steps
    dict(  # Config of SampleFrames
        type='SampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=1,  # Frames of each sampled output clip
        frame_interval=1,  # Temporal interval of adjacent sampled frames
        num_clips=3),  # Number of clips to be sampled
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of Resize
        type='Resize',  # Resize pipeline
        scale=(-1, 256)),  # The scale to resize images
    dict(  # Config of MultiScaleCrop
        type='MultiScaleCrop',  # Multi scale crop pipeline, cropping images with a list of randomly selected scales
        input_size=224,  # Input size of the network
        scales=(1, 0.875, 0.75, 0.66),  # Scales of width and height to be selected
        random_crop=False,  # Whether to randomly sample cropping bbox
        max_wh_scale_gap=1),  # Maximum gap of w and h scale levels
    dict(  # Config of Resize
        type='Resize',  # Resize pipeline
        scale=(224, 224),  # The scale to resize images
        keep_ratio=False),  # Whether to resize with changing the aspect ratio
    dict(  # Config of Flip
        type='Flip',  # Flip pipeline
        flip_ratio=0.5),  # Probability of implementing flip
    dict(  # Config of Normalize
        type='Normalize',  # Normalize pipeline
        **img_norm_cfg),  # Config of image normalization
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCHW'),  # Final image shape format
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the recognizer
        keys=['imgs', 'label'],  # Keys of input
        meta_keys=[]),  # Meta keys of input
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['imgs', 'label'])  # Keys to be converted from image to tensor
]
val_pipeline = [  # List of validation pipeline steps
    dict(  # Config of SampleFrames
        type='SampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=1,  # Frames of each sampled output clip
        frame_interval=1,  # Temporal interval of adjacent sampled frames
        num_clips=3,  # Number of clips to be sampled
        test_mode=True),  # Whether to set test mode in sampling
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of Resize
        type='Resize',  # Resize pipeline
        scale=(-1, 256)),  # The scale to resize images
    dict(  # Config of CenterCrop
        type='CenterCrop',  # Center crop pipeline, cropping the center area from images
        crop_size=224),  # The size to crop images
    dict(  # Config of Flip
        type='Flip',  # Flip pipeline
        flip_ratio=0),  # Probability of implementing flip
    dict(  # Config of Normalize
        type='Normalize',  # Normalize pipeline
        **img_norm_cfg),  # Config of image normalization
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCHW'),  # Final image shape format
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the recognizer
        keys=['imgs', 'label'],  # Keys of input
        meta_keys=[]),  # Meta keys of input
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['imgs'])  # Keys to be converted from image to tensor
]
test_pipeline = [  # List of testing pipeline steps
    dict(  # Config of SampleFrames
        type='SampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=1,  # Frames of each sampled output clip
        frame_interval=1,  # Temporal interval of adjacent sampled frames
        num_clips=25,  # Number of clips to be sampled
        test_mode=True),  # Whether to set test mode in sampling
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of Resize
        type='Resize',  # Resize pipeline
        scale=(-1, 256)),  # The scale to resize images
    dict(  # Config of TenCrop
        type='TenCrop',  # Ten-crop pipeline, cropping images into ten views (see the TenCrop API documentation)
        crop_size=224),  # The size to crop images
    dict(  # Config of Flip
        type='Flip',  # Flip pipeline
        flip_ratio=0),  # Probability of implementing flip
    dict(  # Config of Normalize
        type='Normalize',  # Normalize pipeline
        **img_norm_cfg),  # Config of image normalization
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCHW'),  # Final image shape format
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the recognizer
        keys=['imgs', 'label'],  # Keys of input
        meta_keys=[]),  # Meta keys of input
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['imgs'])  # Keys to be converted from image to tensor
]
data = dict(  # Config of data
    videos_per_gpu=32,  # Batch size of each single GPU
    workers_per_gpu=4,  # Workers to pre-fetch data for each single GPU
    train_dataloader=dict(  # Additional config of train dataloader
        drop_last=True),  # Whether to drop out the last batch of data in training
    val_dataloader=dict(  # Additional config of validation dataloader
        videos_per_gpu=1),  # Batch size of each single GPU during evaluation
    test_dataloader=dict(  # Additional config of test dataloader
        videos_per_gpu=2),  # Batch size of each single GPU during testing
    train=dict(  # Training dataset config
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(  # Validation dataset config
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(  # Testing dataset config
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))

# optimizer
optimizer = dict(
    # Config used to build optimizer, support (1). All the optimizers in PyTorch
    # whose arguments are also the same as those in PyTorch. (2). Custom optimizers
    # which are built on `constructor`, referring to "tutorials/5_new_modules.md"
    # for implementation.
    type='SGD',  # Type of optimizer, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py#L13 for more details
    lr=0.01,  # Learning rate, see detail usages of the parameters in the documentation of PyTorch
    momentum=0.9,  # Momentum
    weight_decay=0.0001)  # Weight decay of SGD
optimizer_config = dict(  # Config used to build the optimizer hook
    grad_clip=dict(max_norm=40, norm_type=2))  # Use gradient clip

# learning policy
lr_config = dict(  # Learning rate scheduler config used to register LrUpdater hook
    policy='step',  # Policy of scheduler, also support CosineAnnealing, Cyclic, etc. Refer to details of supported LrUpdater from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9
    step=[40, 80])  # Steps to decay the learning rate

total_epochs = 100  # Total epochs to train the model
checkpoint_config = dict(  # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation
    interval=5)  # Interval to save checkpoint
evaluation = dict(  # Config of evaluation during training
    interval=5,  # Interval to perform evaluation
    metrics=['top_k_accuracy', 'mean_class_accuracy'],  # Metrics to be performed
    topk=(1, 5))  # K value for `top_k_accuracy` metric
log_config = dict(  # Config to register logger hook
    interval=20,  # Interval to print the log
    hooks=[  # Hooks to be implemented during training
        dict(type='TextLoggerHook'),  # The logger used to record the training process
        # dict(type='TensorboardLoggerHook'),  # The Tensorboard logger is also supported
    ])

# runtime settings
dist_params = dict(backend='nccl')  # Parameters to setup distributed training, the port can also be set
log_level = 'INFO'  # The level of logging
work_dir = './work_dirs/tsn_r50_1x1x3_100e_kinetics400_rgb/'  # Directory to save the model checkpoints and logs for the current experiments
load_from = None  # Load models as a pre-trained model from a given path. This will not resume training
resume_from = None  # Resume checkpoints from a given path, the training will be resumed from the epoch when the checkpoint is saved
workflow = [('train', 1)]  # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once
We incorporate modular design into our config system, which is convenient to conduct various experiments.
To help users get a basic idea of a complete config structure and the modules in a spatio-temporal action detection system, we make brief comments on the config of FastRCNN as follows. For more detailed usage and the alternatives for each parameter in each module, please refer to the API documentation.
# Annotated example config for the FastRCNN detector (ResNet3dSlowOnly
# backbone) with the AVA dataset. Every setting is a plain Python value; the
# trailing comment on each line explains what the setting controls.

# model setting
model = dict(  # Config of the model
    type='FastRCNN',  # Type of the detector
    backbone=dict(  # Dict for backbone
        type='ResNet3dSlowOnly',  # Name of the backbone
        depth=50,  # Depth of ResNet model
        pretrained=None,  # The url/site of the pretrained model
        pretrained2d=False,  # If the pretrained model is 2D
        lateral=False,  # If the backbone is with lateral connections
        num_stages=4,  # Stages of ResNet model
        conv1_kernel=(1, 7, 7),  # Conv1 kernel size
        conv1_stride_t=1,  # Conv1 temporal stride
        pool1_stride_t=1,  # Pool1 temporal stride
        spatial_strides=(1, 2, 2, 1)),  # The spatial stride for each ResNet stage
    roi_head=dict(  # Dict for roi_head
        type='AVARoIHead',  # Name of the roi_head
        bbox_roi_extractor=dict(  # Dict for bbox_roi_extractor
            type='SingleRoIExtractor3D',  # Name of the bbox_roi_extractor
            roi_layer_type='RoIAlign',  # Type of the RoI op
            output_size=8,  # Output feature size of the RoI op
            with_temporal_pool=True),  # If temporal dim is pooled
        bbox_head=dict(  # Dict for bbox_head
            type='BBoxHeadAVA',  # Name of the bbox_head
            in_channels=2048,  # Number of channels of the input feature
            num_classes=81,  # Number of action classes + 1
            multilabel=True,  # If the dataset is multilabel
            dropout_ratio=0.5)))  # The dropout ratio used

# model training and testing settings
train_cfg = dict(  # Training config of FastRCNN
    rcnn=dict(  # Dict for rcnn training config
        assigner=dict(  # Dict for assigner
            type='MaxIoUAssignerAVA',  # Name of the assigner
            pos_iou_thr=0.9,  # IoU threshold for positive examples, > pos_iou_thr -> positive
            neg_iou_thr=0.9,  # IoU threshold for negative examples, < neg_iou_thr -> negative
            min_pos_iou=0.9),  # Minimum acceptable IoU for positive examples
        sampler=dict(  # Dict for sampler
            type='RandomSampler',  # Name of the sampler
            num=32,  # Batch Size of the sampler
            pos_fraction=1,  # Positive bbox fraction of the sampler
            neg_pos_ub=-1,  # Upper bound of the ratio of num negative to num positive
            add_gt_as_proposals=True),  # Add gt bboxes as proposals
        pos_weight=1.0,  # Loss weight of positive examples
        debug=False))  # Debug mode
test_cfg = dict(  # Testing config of FastRCNN
    rcnn=dict(  # Dict for rcnn testing config
        action_thr=0.00))  # The threshold of an action

# dataset settings
dataset_type = 'AVADataset'  # Type of dataset for training, validation and testing
data_root = 'data/ava/rawframes'  # Root path to data
anno_root = 'data/ava/annotations'  # Root path to annotations

ann_file_train = f'{anno_root}/ava_train_v2.1.csv'  # Path to the annotation file for training
ann_file_val = f'{anno_root}/ava_val_v2.1.csv'  # Path to the annotation file for validation

exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'  # Path to the exclude annotation file for training
exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'  # Path to the exclude annotation file for validation

label_file = f'{anno_root}/ava_action_list_v2.1_for_activitynet_2018.pbtxt'  # Path to the label file

proposal_file_train = f'{anno_root}/ava_dense_proposals_train.FAIR.recall_93.9.pkl'  # Path to the human detection proposals for training examples
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'  # Path to the human detection proposals for validation examples

img_norm_cfg = dict(  # Config of image normalization used in data pipeline
    mean=[123.675, 116.28, 103.53],  # Mean values of different channels to normalize
    std=[58.395, 57.12, 57.375],  # Std values of different channels to normalize
    to_bgr=False)  # Whether to convert channels from RGB to BGR

train_pipeline = [  # List of training pipeline steps
    dict(  # Config of AVASampleFrames
        type='AVASampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=4,  # Frames of each sampled output clip
        frame_interval=16),  # Temporal interval of adjacent sampled frames
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of RandomRescale
        type='RandomRescale',  # Randomly rescale the shortedge by a given range
        scale_range=(256, 320)),  # The shortedge size range of RandomRescale
    dict(  # Config of RandomCrop
        type='RandomCrop',  # Randomly crop a patch with the given size
        size=256),  # The size of the cropped patch
    dict(  # Config of Flip
        type='Flip',  # Flip pipeline
        flip_ratio=0.5),  # Probability of implementing flip
    dict(  # Config of Normalize
        type='Normalize',  # Normalize pipeline
        **img_norm_cfg),  # Config of image normalization
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCTHW',  # Final image shape format
        collapse=True),  # Collapse the dim N if N == 1
    dict(  # Config of Rename
        type='Rename',  # Rename keys
        mapping=dict(imgs='img')),  # The old name to new name mapping
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),  # Keys to be converted from image to tensor
    dict(  # Config of ToDataContainer
        type='ToDataContainer',  # Convert other types to DataContainer type pipeline
        fields=[  # Fields to convert to DataContainer
            dict(  # Dict of fields
                key=['proposals', 'gt_bboxes', 'gt_labels'],  # Keys to convert to DataContainer
                stack=False)]),  # Whether to stack these tensors
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the detector
        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],  # Keys of input
        meta_keys=['scores', 'entity_ids']),  # Meta keys of input
]

val_pipeline = [  # List of validation pipeline steps
    dict(  # Config of AVASampleFrames
        type='AVASampleFrames',  # Sample frames pipeline, sampling frames from video
        clip_len=4,  # Frames of each sampled output clip
        frame_interval=16),  # Temporal interval of adjacent sampled frames
    dict(  # Config of RawFrameDecode
        type='RawFrameDecode'),  # Load and decode Frames pipeline, picking raw frames with given indices
    dict(  # Config of Resize
        type='Resize',  # Resize pipeline
        scale=(-1, 256)),  # The scale to resize images
    dict(  # Config of Normalize
        type='Normalize',  # Normalize pipeline
        **img_norm_cfg),  # Config of image normalization
    dict(  # Config of FormatShape
        type='FormatShape',  # Format shape pipeline, Format final image shape to the given input_format
        input_format='NCTHW',  # Final image shape format
        collapse=True),  # Collapse the dim N if N == 1
    dict(  # Config of Rename
        type='Rename',  # Rename keys
        mapping=dict(imgs='img')),  # The old name to new name mapping
    dict(  # Config of ToTensor
        type='ToTensor',  # Convert other types to tensor type pipeline
        keys=['img', 'proposals']),  # Keys to be converted from image to tensor
    dict(  # Config of ToDataContainer
        type='ToDataContainer',  # Convert other types to DataContainer type pipeline
        fields=[  # Fields to convert to DataContainer
            dict(  # Dict of fields
                key=['proposals'],  # Keys to convert to DataContainer
                stack=False)]),  # Whether to stack these tensors
    dict(  # Config of Collect
        type='Collect',  # Collect pipeline that decides which keys in the data should be passed to the detector
        keys=['img', 'proposals'],  # Keys of input
        meta_keys=['scores', 'entity_ids'],  # Meta keys of input
        nested=True)  # Whether to wrap the data in a nested list
]

data = dict(  # Config of data
    videos_per_gpu=16,  # Batch size of each single GPU
    workers_per_gpu=4,  # Workers to pre-fetch data for each single GPU
    val_dataloader=dict(  # Additional config of validation dataloader
        videos_per_gpu=1),  # Batch size of each single GPU during evaluation
    test_dataloader=dict(  # Additional config of testing dataloader
        videos_per_gpu=1),  # Batch size of each single GPU during testing
    train=dict(  # Training dataset config
        type=dataset_type,
        ann_file=ann_file_train,
        exclude_file=exclude_file_train,
        pipeline=train_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_train,
        person_det_score_thr=0.9,
        data_prefix=data_root),
    val=dict(  # Validation dataset config
        type=dataset_type,
        ann_file=ann_file_val,
        exclude_file=exclude_file_val,
        pipeline=val_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_val,
        person_det_score_thr=0.9,
        data_prefix=data_root))
data['test'] = data['val']  # Set test_dataset as val_dataset

# optimizer
optimizer = dict(
    # Config used to build optimizer, support (1). All the optimizers in PyTorch
    # whose arguments are also the same as those in PyTorch. (2). Custom optimizers
    # which are built on `constructor`, referring to "tutorials/5_new_modules.md"
    # for implementation.
    type='SGD',  # Type of optimizer, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py#L13 for more details
    lr=0.2,  # Learning rate, see detail usages of the parameters in the documentation of PyTorch (for 8gpu)
    momentum=0.9,  # Momentum
    weight_decay=0.00001)  # Weight decay of SGD

optimizer_config = dict(  # Config used to build the optimizer hook
    grad_clip=dict(max_norm=40, norm_type=2))  # Use gradient clip

lr_config = dict(  # Learning rate scheduler config used to register LrUpdater hook
    policy='step',  # Policy of scheduler, also support CosineAnnealing, Cyclic, etc. Refer to details of supported LrUpdater from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9
    # NOTE(review): both steps are larger than total_epochs = 20 below, so the
    # decay presumably never triggers with these values - confirm against the
    # released config.
    step=[40, 80],  # Steps to decay the learning rate
    warmup='linear',  # Warmup strategy
    warmup_by_epoch=True,  # Warmup_iters indicates iter num or epoch num
    warmup_iters=5,  # Number of iters or epochs for warmup
    warmup_ratio=0.1)  # The initial learning rate is warmup_ratio * lr

total_epochs = 20  # Total epochs to train the model
checkpoint_config = dict(  # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation
    interval=1)  # Interval to save checkpoint
workflow = [('train', 1)]  # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once
evaluation = dict(  # Config of evaluation during training
    interval=1)  # Interval to perform evaluation
log_config = dict(  # Config to register logger hook
    interval=20,  # Interval to print the log
    hooks=[  # Hooks to be implemented during training
        dict(type='TextLoggerHook'),  # The logger used to record the training process
    ])

# runtime settings
dist_params = dict(backend='nccl')  # Parameters to setup distributed training, the port can also be set
log_level = 'INFO'  # The level of logging
work_dir = ('./work_dirs/ava/'  # Directory to save the model checkpoints and logs for the current experiments
            'slowonly_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
load_from = ('https://download.openmmlab.com/mmaction/recognition/slowonly/'  # Load models as a pre-trained model from a given path. This will not resume training
             'slowonly_r50_4x16x1_256e_kinetics400_rgb/'
             'slowonly_r50_4x16x1_256e_kinetics400_rgb_20200704-a69556c6.pth')
resume_from = None  # Resume checkpoints from a given path, the training will be resumed from the epoch when the checkpoint is saved
Some intermediate variables are used in the config files, like `train_pipeline`/`val_pipeline`/`test_pipeline`, `ann_file_train`/`ann_file_val`/`ann_file_test`, `img_norm_cfg`, etc.
For example, we would like to first define `train_pipeline`/`val_pipeline`/`test_pipeline` and pass them into `data`.
Thus, `train_pipeline`/`val_pipeline`/`test_pipeline` are intermediate variables.
We also define `ann_file_train`/`ann_file_val`/`ann_file_test` and `data_root`/`data_root_val` to provide the data pipeline with some basic information.
In addition, we use `img_norm_cfg` as an intermediate variable to construct data augmentation components.
...
# Example of intermediate variables in a config: the paths, the normalization
# settings and the three pipelines are defined once up front and then passed
# into `data` below.
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_rawframes.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_rawframes.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_rawframes.txt'

# Shared by the Normalize step of all three pipelines.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)

train_pipeline = [
    dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, 0.8),
        random_crop=False,
        max_wh_scale_gap=0),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs', 'label'])
]
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=32,
        frame_interval=2,
        num_clips=1,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=32,
        frame_interval=2,
        num_clips=10,
        test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='Flip', flip_ratio=0),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
    dict(type='ToTensor', keys=['imgs'])
]

# The intermediate variables defined above are consumed here.
data = dict(
    videos_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=data_root_val,
        pipeline=val_pipeline),
    test=dict(
        type=dataset_type,
        # Use the dedicated test annotation variable (same path as
        # ann_file_val in this example) so ann_file_test is not left unused.
        ann_file=ann_file_test,
        data_prefix=data_root_val,
        pipeline=test_pipeline))