Is validation loss computed and output? #310
Comments
The decode head loss and aux head loss are being calculated on the validation dataset. Maybe you can average them to calculate the total loss.
@rubeea So, do you mean that the validation loss is implicitly computed in the code, but it is not output anywhere (tensorboard, standard output, log.json, etc.)? Should I modify the code a little so that I can get the value of the validation loss? The two kinds of losses you mentioned (decode.loss_seg and aux.loss_seg) appear on tensorboard in the train tab, but I cannot find them in the validation tab. Only evaluation scores (aAcc, mAcc, and mIoU) appear in the validation tab. Possibly I am doing something stupid or misunderstanding something. I am still confused, but it seems that I should first learn the meaning of the losses and their implementation in the code. Thank you again.
Hi, actually you are right: those are indeed the training data losses, while the metrics are being computed on the validation dataset. Kindly report the solution here if you find a workaround. Thanks :)
Hi, as per the mmsegmentation docs (https://mmsegmentation.readthedocs.io/en/latest/tutorials/customize_runtime.html), the validation loss can be calculated by setting the workflow to [('train', 1), ('val', 1)] instead of just [('train', 1)]. However, I get a dataloader error when I attempt to set the workflow to [('train', 1), ('val', 1)]. Do you meet a similar error as well? If yes, any idea on how it can be resolved?
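For reference, the change suggested by that tutorial is a one-line edit to the config's workflow (a minimal sketch based on the linked docs; the numbers are the per-stage epoch counts):

```python
# Default: training only.
# workflow = [('train', 1)]

# Alternate one epoch of training with one pass over the validation set,
# so that losses are also computed on the validation data.
workflow = [('train', 1), ('val', 1)]
```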
Hello @rubeea, sorry for the late reply. I have been crazily busy this week. Thank you for your comment. Yes, I have changed the workflow to include val. As for the original problem in this issue, that is, outputting the validation loss to tensorboard: to show a loss in tensorboard, we need the key 'log_vars' in the output dictionary. This key exists in the train output (in def train_step), but not in the val output. That is why the val loss is not shown in tensorboard, I suppose.
I slightly changed the names by adding the prefix 'val_' to the keys; otherwise I think the val loss would not be distinguished from the train loss in tensorboard. In my case, this workaround worked and the val loss is shown on tensorboard. (One unsatisfactory point is that the val loss is shown in the 'train' tab... This is ugly but is not a problem practically.)
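The exact patch is not quoted in the thread, but from the description it presumably makes val_step return 'log_vars' the same way train_step does. A minimal sketch of such a change to mmseg/models/segmentors/base.py (assuming the usual BaseSegmentor layout; details may differ between versions):

```python
def val_step(self, data_batch, **kwargs):
    """Validation step that also reports losses so logger hooks can plot them."""
    # Run the forward pass in loss mode, just like train_step does.
    losses = self(**data_batch)
    loss, log_vars = self._parse_losses(losses)

    # Prefix the keys with 'val_' so the values are distinguishable from
    # the training losses in tensorboard.
    log_vars = {f'val_{name}': value for name, value in log_vars.items()}

    return dict(
        loss=loss,
        log_vars=log_vars,
        num_samples=len(data_batch['img_metas']))
```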
I have not yet fully understood how to handle the validation loss. But a workaround to show the validation loss was found as above, so let me close this issue.
Thank you for your comments and help. I'll definitely post if I find a way to display the losses in a separate validation tab.
Hi @tetsu-kikuchi, did you encounter this error when using the workflow as workflow = [('train', 1), ('val', 1)]? If not, can you kindly share your config file here so that I can understand what mistake I am making?

Error (7 frames):
/usr/local/lib/python3.6/dist-packages/mmcv/runner/iter_based_runner.py in run(self, data_loaders, workflow, max_iters, **kwargs)
/usr/local/lib/python3.6/dist-packages/mmcv/runner/iter_based_runner.py in val(self, data_loader, **kwargs)
/usr/local/lib/python3.6/dist-packages/mmcv/parallel/data_parallel.py in val_step(self, *inputs, **kwargs)
/content/pldu_mmsegmentation/mmseg/models/segmentors/base.py in val_step(self, data_batch, **kwargs)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
/usr/local/lib/python3.6/dist-packages/mmcv/runner/fp16_utils.py in new_func(*args, **kwargs)
/content/pldu_mmsegmentation/mmseg/models/segmentors/base.py in forward(self, img, img_metas, return_loss, **kwargs)
TypeError: forward_train() missing 1 required positional argument: 'gt_semantic_seg'

Thanks in advance for your help.
Hi @rubeea, here is the command I used:

python tools/train.py configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_80k_cityscapes.py --load-from checkpoints/deeplabv3plus_r50-d8_512x1024_80k_cityscapes_20200606_114049-f9fb496d.pth

This config inherits from the following base configs:
configs/_base_/models/deeplabv3plus_r50-d8.py
configs/_base_/schedules/schedule_80k.py
configs/_base_/default_runtime.py
configs/_base_/datasets/cityscapes.py
Note that I have customized some of the code. By the way, your error message indicates that 'gt_semantic_seg' is not properly read.
Hey @tetsu-kikuchi
I am building the datasets for the workflow = [('train'), ('val')] and using train_segmentor in mmseg/apis/train.py for training, as follows:
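A rough sketch of the dataset-building code in the stock tools/train.py for a two-stage workflow (an assumption based on the standard mmsegmentation training script, not necessarily the poster's exact code):

```python
import copy

from mmseg.datasets import build_dataset

# Always build the training dataset.
datasets = [build_dataset(cfg.data.train)]

# If the workflow also contains a 'val' stage, build a second dataset from
# the validation split.  Note that it reuses the *training* pipeline, which
# is what makes loss computation possible on the validation data.
if len(cfg.workflow) == 2:
    val_dataset = copy.deepcopy(cfg.data.val)
    val_dataset.pipeline = cfg.data.train.pipeline
    datasets.append(build_dataset(val_dataset))
```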
Moreover, if validate=True in the above function and the workflow is set to workflow=[('train')] only, I believe the statistics (loss, accuracy, etc.) are calculated on the validation dataset and not the train dataset. Is that correct? If so, I don't think there is a need to change the workflow to [('train'), ('val')]. Thanks in advance.
Hi @rubeea, please see these lines of the training script:
https://github.com/open-mmlab/mmsegmentation/blob/master/tools/train.py#L137
https://github.com/open-mmlab/mmsegmentation/blob/master/tools/train.py#L152
(To be precise, I downloaded the code last December and use it with slight customizations of my own.)
Sorry, I rarely set workflow=[('train')] only, so I do not know this point well.
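For context on the validate=True question above: in the stock mmseg/apis/train.py, validate=True only registers an evaluation hook, which runs inference and computes metrics (aAcc/mAcc/mIoU) but never calls the loss functions. A rough sketch of that branch (an assumption about the stock code; it may differ between versions):

```python
from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset

# Inside train_segmentor(model, dataset, cfg, distributed=False, validate=False, ...):
if validate:
    # The val split is built with its own test-mode pipeline here.
    val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
    val_dataloader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    eval_cfg = cfg.get('evaluation', {})
    # Metrics only: no validation loss is produced by this hook.
    eval_hook = DistEvalHook if distributed else EvalHook
    runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
```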
Hi @tetsu-kikuchi, thanks for your response. As suspected, you are using the train pipeline for the validation dataset as well.
Thank you for your great work. I'd like to ask you a small question.
While I can find evaluation scores such as mIoU, I cannot find validation loss anywhere (on tensorboard, standard output, log.json etc.).
I used the following config.
I set
where 1 epoch = 300 iterations.
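Purely as an illustration of such a setup (hypothetical values, not the author's actual config), a configuration where evaluation runs once per 300-iteration epoch and tensorboard logging is enabled typically looks like this:

```python
# Hypothetical example; the issue author's real settings are not shown above.
evaluation = dict(interval=300, metric='mIoU')   # evaluate once per epoch
checkpoint_config = dict(interval=300)           # checkpoint once per epoch
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
    ])
```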
Thanks for any help.