我是靠谱客的博主 冷酷小霸王,这篇文章主要介绍detectron2 在训练过程中输出 validation loss(验证集的损失),现在分享给大家,希望可以做个参考。

写在前面的话

该问题在 GitHub的 detectron2 的 issues 上被提出,有人解决了(如下图所示)
提示一下,去 GitHub 上的 issues 搜索问题,尽量找【closed】标签的,这些基本都是有解决方法的问题。
这里只做个记录,仅供学习使用

参考GitHub链接:
How do I compute validation loss during training?

这个实现的很巧妙,直接把训练集的替换成验证集,用原本的训练集的计算loss的方法做计算

在这里插入图片描述

添加的包

复制代码
1
2
3
4
5
from detectron2.engine import HookBase from detectron2.data import build_detection_train_loader import detectron2.utils.comm as comm

功能函数

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class ValidationLoss(HookBase): def __init__(self, cfg, DATASETS_VAL_NAME):#多加一个DATASETS_VAL_NAME参数(小改动) super().__init__() self.cfg = cfg.clone() self.cfg.DATASETS.TRAIN = DATASETS_VAL_NAME## self._loader = iter(build_detection_train_loader(self.cfg)) def after_step(self): data = next(self._loader) with torch.no_grad(): loss_dict = self.trainer.model(data) losses = sum(loss_dict.values()) assert torch.isfinite(losses).all(), loss_dict loss_dict_reduced = {"val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items()} losses_reduced = sum(loss for loss in loss_dict_reduced.values()) if comm.is_main_process(): self.trainer.storage.put_scalars(total_val_loss=losses_reduced, **loss_dict_reduced)

使用方法

复制代码
1
2
3
4
5
6
7
8
9
10
11
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) trainer = Trainer(cfg) val_loss = ValidationLoss(cfg, cfg.DATASETS.VAL) ##多加的参数 trainer.register_hooks([val_loss]) # swap the order of PeriodicWriter and ValidationLoss trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1] trainer.resume_or_load(resume=False) trainer.train()

实现效果

total_val_loss
val_loss_cls
val_loss_box_reg

在这里插入图片描述

最后

以上就是冷酷小霸王最近收集整理的关于detectron2 在训练过程中输出 validation loss(验证集的损失)的全部内容,更多相关detectron2内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(64)

评论列表共有 0 条评论

立即
投稿
返回
顶部