Skip to content

Commit

Permalink
Added label_smooth_eps=0.1 for [net] layer for Label Smoothing for Cl…
Browse files Browse the repository at this point in the history
…assifier
  • Loading branch information
AlexeyAB committed Dec 8, 2019
1 parent 318919e commit 2a873f3
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 10 deletions.
2 changes: 2 additions & 0 deletions include/darknet.h
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ typedef struct network {
int flip; // horizontal flip 50% probability augmentaiont for classifier training (default = 1)
int blur;
int mixup;
float label_smooth_eps;
int letter_box;
float angle;
float aspect;
Expand Down Expand Up @@ -813,6 +814,7 @@ typedef struct load_args {
int flip;
int blur;
int mixup;
float label_smooth_eps;
float angle;
float aspect;
float saturation;
Expand Down
1 change: 1 addition & 0 deletions src/classifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
args.hue = net.hue;
args.size = net.w > net.h ? net.w : net.h;

args.label_smooth_eps = net.label_smooth_eps;
args.mixup = net.mixup;
if (dont_show && show_imgs) show_imgs = 2;
args.show_imgs = show_imgs;
Expand Down
44 changes: 35 additions & 9 deletions src/data.c
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,32 @@ void fill_truth(char *path, char **labels, int k, float *truth)
}
}

void fill_truth_smooth(char *path, char **labels, int k, float *truth, float label_smooth_eps)
{
int i;
memset(truth, 0, k * sizeof(float));
int count = 0;
for (i = 0; i < k; ++i) {
if (strstr(path, labels[i])) {
truth[i] = (1 - label_smooth_eps);
++count;
}
else {
truth[i] = label_smooth_eps / (k - 1);
}
}
if (count != 1) {
printf("Too many or too few labels: %d, %s\n", count, path);
count = 0;
for (i = 0; i < k; ++i) {
if (strstr(path, labels[i])) {
printf("\t label %d: %s \n", count, labels[i]);
count++;
}
}
}
}

void fill_hierarchy(float *truth, int k, tree *hierarchy)
{
int j;
Expand Down Expand Up @@ -548,12 +574,12 @@ void fill_hierarchy(float *truth, int k, tree *hierarchy)
}
}

matrix load_labels_paths(char **paths, int n, char **labels, int k, tree *hierarchy)
matrix load_labels_paths(char **paths, int n, char **labels, int k, tree *hierarchy, float label_smooth_eps)
{
matrix y = make_matrix(n, k);
int i;
for(i = 0; i < n && labels; ++i){
fill_truth(paths[i], labels, k, y.vals[i]);
fill_truth_smooth(paths[i], labels, k, y.vals[i], label_smooth_eps);
if(hierarchy){
fill_hierarchy(y.vals[i], k, hierarchy);
}
Expand Down Expand Up @@ -1336,7 +1362,7 @@ void *load_thread(void *ptr)
if (a.type == OLD_CLASSIFICATION_DATA){
*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
} else if (a.type == CLASSIFICATION_DATA){
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.flip, a.min, a.max, a.w, a.h, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.mixup, a.blur, a.show_imgs);
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.flip, a.min, a.max, a.w, a.h, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.mixup, a.blur, a.show_imgs, a.label_smooth_eps);
} else if (a.type == SUPER_DATA){
*a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
} else if (a.type == WRITING_DATA){
Expand Down Expand Up @@ -1432,7 +1458,7 @@ data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int
data d = {0};
d.shallow = 0;
d.X = load_image_paths(paths, n, w, h);
d.y = load_labels_paths(paths, n, labels, k, 0);
d.y = load_labels_paths(paths, n, labels, k, 0, 0);
if(m) free(paths);
return d;
}
Expand Down Expand Up @@ -1481,21 +1507,21 @@ data load_data_super(char **paths, int n, int m, int w, int h, int scale)
return d;
}

data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs)
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs, float label_smooth_eps)
{
char **paths_stored = paths;
if(m) paths = get_random_paths(paths, n, m);
data d = {0};
d.shallow = 0;
d.X = load_image_augment_paths(paths, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d.y = load_labels_paths(paths, n, labels, k, hierarchy);
d.y = load_labels_paths(paths, n, labels, k, hierarchy, label_smooth_eps);

if (mixup && rand_int(0, 1)) {
char **paths_mix = get_random_paths(paths_stored, n, m);
data d2 = { 0 };
d2.shallow = 0;
d2.X = load_image_augment_paths(paths_mix, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d2.y = load_labels_paths(paths_mix, n, labels, k, hierarchy);
d2.y = load_labels_paths(paths_mix, n, labels, k, hierarchy, label_smooth_eps);
free(paths_mix);

data d3 = { 0 };
Expand All @@ -1505,12 +1531,12 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h
if (mixup >= 3) {
char **paths_mix3 = get_random_paths(paths_stored, n, m);
d3.X = load_image_augment_paths(paths_mix3, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy);
d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy, label_smooth_eps);
free(paths_mix3);

char **paths_mix4 = get_random_paths(paths_stored, n, m);
d4.X = load_image_augment_paths(paths_mix4, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure);
d4.y = load_labels_paths(paths_mix4, n, labels, k, hierarchy);
d4.y = load_labels_paths(paths_mix4, n, labels, k, hierarchy, label_smooth_eps);
free(paths_mix4);
}

Expand Down
3 changes: 2 additions & 1 deletion src/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo
data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure);
matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure);
data load_data_super(char **paths, int n, int m, int w, int h, int scale);
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs);
data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs, float label_smooth_eps);
data load_go(char *filename);

box_label *read_boxes(char *filename, int *n);
Expand All @@ -116,6 +116,7 @@ data *split_data(data d, int part, int total);
data concat_data(data d1, data d2);
data concat_datas(data *d, int n);
void fill_truth(char *path, char **labels, int k, float *truth);
void fill_truth_smooth(char *path, char **labels, int k, float *truth, float label_smooth_eps);
#ifdef __cplusplus
}

Expand Down
1 change: 1 addition & 0 deletions src/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ void parse_net_options(list *options, network *net)
else if (cutmix) net->mixup = 2;
else if (mosaic) net->mixup = 3;
net->letter_box = option_find_int_quiet(options, "letter_box", 0);
net->label_smooth_eps = option_find_float_quiet(options, "label_smooth_eps", 0.0f);

net->angle = option_find_float_quiet(options, "angle", 0);
net->aspect = option_find_float_quiet(options, "aspect", 1);
Expand Down

0 comments on commit 2a873f3

Please sign in to comment.