From c65a9af4ed0245ba8b6c1b712606e4c764316ef2 Mon Sep 17 00:00:00 2001 From: magic_zhang <2460171714@qq.com> Date: Wed, 17 Mar 2021 13:06:13 +0800 Subject: [PATCH] Update loss.py fix bug --- atss_core/modeling/rpn/atss/loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/atss_core/modeling/rpn/atss/loss.py b/atss_core/modeling/rpn/atss/loss.py index 31e7925..0c69a40 100644 --- a/atss_core/modeling/rpn/atss/loss.py +++ b/atss_core/modeling/rpn/atss/loss.py @@ -294,10 +294,11 @@ def __call__(self, box_cls, box_regression, centerness, targets, anchors): reg_targets_flatten = reg_targets_flatten[pos_inds] anchors_flatten = anchors_flatten[pos_inds] centerness_flatten = centerness_flatten[pos_inds] - centerness_targets = self.compute_centerness_targets(reg_targets_flatten, anchors_flatten) - sum_centerness_targets_avg_per_gpu = reduce_sum(centerness_targets.sum()).item() / float(num_gpus) + if pos_inds.numel() > 0: + centerness_targets = self.compute_centerness_targets(reg_targets_flatten, anchors_flatten) + sum_centerness_targets_avg_per_gpu = reduce_sum(centerness_targets.sum()).item() / float(num_gpus) reg_loss = self.GIoULoss(box_regression_flatten, reg_targets_flatten, anchors_flatten, weight=centerness_targets) / sum_centerness_targets_avg_per_gpu centerness_loss = self.centerness_loss_func(centerness_flatten, centerness_targets) / num_pos_avg_per_gpu