-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
55 changed files
with
6,754 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
.nfs* | ||
*.pyc | ||
.dumbo.json | ||
.DS_Store | ||
.*.swp | ||
*.pth | ||
**/__pycache__/** | ||
.ipynb_checkpoints/ | ||
datasets/data/ | ||
experiment-* | ||
*.tmp | ||
*.pkl | ||
**/.mypy_cache/* | ||
.mypy_cache/* | ||
not_tracked_dir/ | ||
output/ | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2021 Hust Visual Learning Team | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import argparse | ||
import datetime | ||
import json | ||
import random | ||
import time | ||
from pathlib import Path | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import argparse | ||
import datetime | ||
import json | ||
import random | ||
import time | ||
from pathlib import Path | ||
import os | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from torch.utils.data import DataLoader, DistributedSampler | ||
|
||
import datasets | ||
import util.misc as utils | ||
from models import build_model as build_yolos_model | ||
from datasets import build_dataset, get_coco_api_from_dataset | ||
|
||
# from timm.scheduler import create_scheduler | ||
# from new_models import build_model | ||
from util.scheduler import create_scheduler | ||
from datasets.coco_eval import CocoEvaluator | ||
from util import box_ops | ||
import torch.nn.functional as F | ||
|
||
|
||
@torch.no_grad() | ||
def get_val_json(data_loader, base_ds, device, output_dir, args): | ||
jdict = [] | ||
for samples, targets in data_loader: | ||
# samples = samples.to(device) | ||
# import pdb;pdb.set_trace() | ||
targets = [{k: v for k, v in t.items()} for t in targets] | ||
for target in targets: | ||
labels = target["labels"].tolist() | ||
for label in labels: | ||
jdict.append({"category_id": int(label)}) | ||
|
||
output_json = os.path.join(output_dir, "coco-valsplit-cls-dist.json") | ||
with open(output_json, "w") as f: | ||
json.dump(jdict, f) | ||
|
||
# for target, output in zip(targets, results): | ||
# jdict | ||
print("%s done" % output_json) | ||
return | ||
|
||
|
||
def get_args_parser(): | ||
parser = argparse.ArgumentParser("Set YOLOS", add_help=False) | ||
parser.add_argument("--lr", default=1e-4, type=float) | ||
parser.add_argument("--lr_backbone", default=1e-5, type=float) | ||
parser.add_argument("--batch_size", default=2, type=int) | ||
parser.add_argument("--weight_decay", default=1e-4, type=float) | ||
parser.add_argument("--epochs", default=150, type=int) | ||
parser.add_argument("--eval_size", default=800, type=int) | ||
|
||
parser.add_argument( | ||
"--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm" | ||
) | ||
|
||
# scheduler | ||
# Learning rate schedule parameters | ||
parser.add_argument( | ||
"--sched", | ||
default="warmupcos", | ||
type=str, | ||
metavar="SCHEDULER", | ||
help='LR scheduler (default: "step", options:"step", "warmupcos"', | ||
) | ||
## step | ||
parser.add_argument("--lr_drop", default=100, type=int) | ||
## warmupcosine | ||
|
||
# parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', | ||
# help='learning rate noise on/off epoch percentages') | ||
# parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', | ||
# help='learning rate noise limit percent (default: 0.67)') | ||
# parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', | ||
# help='learning rate noise std-dev (default: 1.0)') | ||
parser.add_argument( | ||
"--warmup-lr", | ||
type=float, | ||
default=1e-6, | ||
metavar="LR", | ||
help="warmup learning rate (default: 1e-6)", | ||
) | ||
parser.add_argument( | ||
"--min-lr", | ||
type=float, | ||
default=1e-7, | ||
metavar="LR", | ||
help="lower lr bound for cyclic schedulers that hit 0 (1e-5)", | ||
) | ||
parser.add_argument( | ||
"--warmup-epochs", | ||
type=int, | ||
default=0, | ||
metavar="N", | ||
help="epochs to warmup LR, if scheduler supports", | ||
) | ||
parser.add_argument( | ||
"--decay-rate", | ||
"--dr", | ||
type=float, | ||
default=0.1, | ||
metavar="RATE", | ||
help="LR decay rate (default: 0.1)", | ||
) | ||
|
||
# * model setting | ||
parser.add_argument( | ||
"--det_token_num", | ||
default=100, | ||
type=int, | ||
help="Number of det token in the deit backbone", | ||
) | ||
parser.add_argument( | ||
"--backbone_name", | ||
default="tiny", | ||
type=str, | ||
help="Name of the deit backbone to use", | ||
) | ||
parser.add_argument( | ||
"--pre_trained", | ||
default="", | ||
help="set imagenet pretrained model path if not train yolos from scatch", | ||
) | ||
parser.add_argument( | ||
"--init_pe_size", nargs="+", type=int, help="init pe size (h,w)" | ||
) | ||
parser.add_argument("--mid_pe_size", nargs="+", type=int, help="mid pe size (h,w)") | ||
# * Matcher | ||
parser.add_argument( | ||
"--set_cost_class", | ||
default=1, | ||
type=float, | ||
help="Class coefficient in the matching cost", | ||
) | ||
parser.add_argument( | ||
"--set_cost_bbox", | ||
default=5, | ||
type=float, | ||
help="L1 box coefficient in the matching cost", | ||
) | ||
parser.add_argument( | ||
"--set_cost_giou", | ||
default=2, | ||
type=float, | ||
help="giou box coefficient in the matching cost", | ||
) | ||
# * Loss coefficients | ||
|
||
parser.add_argument("--dice_loss_coef", default=1, type=float) | ||
parser.add_argument("--bbox_loss_coef", default=5, type=float) | ||
parser.add_argument("--giou_loss_coef", default=2, type=float) | ||
parser.add_argument( | ||
"--eos_coef", | ||
default=0.1, | ||
type=float, | ||
help="Relative classification weight of the no-object class", | ||
) | ||
|
||
# dataset parameters | ||
parser.add_argument("--dataset_file", default="coco") | ||
parser.add_argument("--coco_path", type=str) | ||
parser.add_argument("--coco_panoptic_path", type=str) | ||
parser.add_argument("--remove_difficult", action="store_true") | ||
|
||
parser.add_argument( | ||
"--output_dir", default="", help="path where to save, empty for no saving" | ||
) | ||
parser.add_argument( | ||
"--device", default="cuda", help="device to use for training / testing" | ||
) | ||
parser.add_argument("--seed", default=42, type=int) | ||
parser.add_argument("--resume", default="", help="resume from checkpoint") | ||
parser.add_argument( | ||
"--start_epoch", default=0, type=int, metavar="N", help="start epoch" | ||
) | ||
parser.add_argument("--eval", action="store_true") | ||
parser.add_argument("--num_workers", default=2, type=int) | ||
|
||
# distributed training parameters | ||
parser.add_argument( | ||
"--world_size", default=1, type=int, help="number of distributed processes" | ||
) | ||
parser.add_argument( | ||
"--dist_url", default="env://", help="url used to set up distributed training" | ||
) | ||
return parser | ||
|
||
|
||
def main(args): | ||
utils.init_distributed_mode(args) | ||
# print("git:\n {}\n".format(utils.get_sha())) | ||
|
||
print(args) | ||
|
||
device = torch.device(args.device) | ||
|
||
# fix the seed for reproducibility | ||
seed = args.seed + utils.get_rank() | ||
torch.manual_seed(seed) | ||
np.random.seed(seed) | ||
random.seed(seed) | ||
# import pdb;pdb.set_trace() | ||
|
||
dataset_train = build_dataset(image_set="train", args=args) | ||
dataset_val = build_dataset(image_set="val", args=args) | ||
# import pdb;pdb.set_trace() | ||
if args.distributed: | ||
sampler_train = DistributedSampler(dataset_train) | ||
sampler_val = DistributedSampler(dataset_val, shuffle=False) | ||
else: | ||
sampler_train = torch.utils.data.RandomSampler(dataset_train) | ||
sampler_val = torch.utils.data.SequentialSampler(dataset_val) | ||
|
||
batch_sampler_train = torch.utils.data.BatchSampler( | ||
sampler_train, args.batch_size, drop_last=True | ||
) | ||
|
||
data_loader_train = DataLoader( | ||
dataset_train, | ||
batch_sampler=batch_sampler_train, | ||
collate_fn=utils.collate_fn, | ||
num_workers=args.num_workers, | ||
) | ||
data_loader_val = DataLoader( | ||
dataset_val, | ||
args.batch_size, | ||
sampler=sampler_val, | ||
drop_last=False, | ||
collate_fn=utils.collate_fn, | ||
num_workers=args.num_workers, | ||
) | ||
|
||
base_ds = get_coco_api_from_dataset(dataset_val) | ||
|
||
output_dir = Path(args.output_dir) | ||
|
||
get_val_json(data_loader_val, base_ds, device, args.output_dir, args) | ||
|
||
return | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser( | ||
"Get YOLOS pred json file", parents=[get_args_parser()] | ||
) | ||
args = parser.parse_args() | ||
if args.output_dir: | ||
Path(args.output_dir).mkdir(parents=True, exist_ok=True) | ||
main(args) |
Oops, something went wrong.