From b1f3d10b328f91b271bc797264e8ec00220118a2 Mon Sep 17 00:00:00 2001
From: Evgeniy
Date: Sat, 6 May 2023 12:00:14 +0300
Subject: [PATCH 1/3] implement region layer as a copy of yolo layer

---
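Note: RegionLayer.forward() returns the network output unchanged while training
and decodes it into (boxes, confs) for inference, mirroring YoloLayer. A minimal
sketch of the per-cell decode that region_forward_dynamic() vectorises (the
values below are made-up examples; the sigmoid/exp/grid/anchor steps follow the
code in this patch):

    import math

    sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))

    # raw outputs tx, ty, tw, th for one anchor in one grid cell
    tx, ty, tw, th = 0.2, -0.1, 0.05, 0.3
    pw, ph = 1.3221, 1.73145      # example anchor, in grid-cell units
    cx, cy, W, H = 3, 4, 13, 13   # cell coordinates and grid size

    bx = (sigmoid(tx) + cx) / W   # centre x, normalized to [0, 1]
    by = (sigmoid(ty) + cy) / H   # centre y, normalized to [0, 1]
    bw = pw * math.exp(tw) / W    # width, normalized to [0, 1]
    bh = ph * math.exp(th) / H    # height, normalized to [0, 1]

    # corner form, as returned in the boxes tensor
    x1, y1 = bx - bw * 0.5, by - bh * 0.5
    x2, y2 = x1 + bw, y1 + bh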
 demo_darknet2onnx.py    |   2 +-
 tool/darknet2pytorch.py |  30 +++----
 tool/region_layer.py    | 180 ++++++++++++++++++++++++++++++++++++++++
 tool/utils.py           |   4 +-
 4 files changed, 196 insertions(+), 20 deletions(-)
 create mode 100644 tool/region_layer.py

diff --git a/demo_darknet2onnx.py b/demo_darknet2onnx.py
index 3bdccf8f..09badbc7 100644
--- a/demo_darknet2onnx.py
+++ b/demo_darknet2onnx.py
@@ -46,7 +46,7 @@ def detect(session, image_src, namesfile):
 
     outputs = session.run(None, {input_name: img_in})
 
-    boxes = post_processing(img_in, 0.4, 0.6, outputs)
+    boxes = post_processing(img_in, 0.25, 0.45, outputs)
 
     class_names = load_class_names(namesfile)
     plot_boxes_cv2(image_src, boxes[0], savename='predictions_onnx.jpg', class_names=class_names)

diff --git a/tool/darknet2pytorch.py b/tool/darknet2pytorch.py
index adf9d8f1..6eb2a0c2 100644
--- a/tool/darknet2pytorch.py
+++ b/tool/darknet2pytorch.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from tool.region_loss import RegionLoss
+from tool.region_layer import RegionLayer
 from tool.yolo_layer import YoloLayer
 from tool.config import *
 from tool.torch_utils import *
@@ -208,12 +208,8 @@ def forward(self, x):
                     x = x1 * x2
                 outputs[ind] = x
             elif block['type'] == 'region':
-                continue
-                if self.loss:
-                    self.loss = self.loss + self.models[ind](x)
-                else:
-                    self.loss = self.models[ind](x)
-                outputs[ind] = None
+                boxes = self.models[ind](x)
+                out_boxes.append(boxes)
             elif block['type'] == 'yolo':
                 # if self.training:
                 #     pass
@@ -392,19 +388,19 @@ def create_network(self, blocks):
                 out_strides.append(prev_stride)
                 models.append(model)
             elif block['type'] == 'region':
-                loss = RegionLoss()
+                region = RegionLayer()
                 anchors = block['anchors'].split(',')
-                loss.anchors = [float(i) for i in anchors]
-                loss.num_classes = int(block['classes'])
-                loss.num_anchors = int(block['num'])
-                loss.anchor_step = len(loss.anchors) // loss.num_anchors
-                loss.object_scale = float(block['object_scale'])
-                loss.noobject_scale = float(block['noobject_scale'])
-                loss.class_scale = float(block['class_scale'])
-                loss.coord_scale = float(block['coord_scale'])
+                region.anchors = [float(i) for i in anchors]
+                region.num_classes = int(block['classes'])
+                region.num_anchors = int(block['num'])
+                region.anchor_step = len(region.anchors) // region.num_anchors
+                region.object_scale = float(block['object_scale'])
+                region.noobject_scale = float(block['noobject_scale'])
+                region.class_scale = float(block['class_scale'])
+                region.coord_scale = float(block['coord_scale'])
                 out_filters.append(prev_filters)
                 out_strides.append(prev_stride)
-                models.append(loss)
+                models.append(region)
             elif block['type'] == 'yolo':
                 yolo_layer = YoloLayer()
                 anchors = block['anchors'].split(',')

diff --git a/tool/region_layer.py b/tool/region_layer.py
new file mode 100644
index 00000000..24f59752
--- /dev/null
+++ b/tool/region_layer.py
@@ -0,0 +1,180 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from tool.torch_utils import *
+
+
+def region_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1,
+                           validation=False):
+    # Output would be invalid if it does not satisfy this assert
+    # assert (output.size(1) == (5 + num_classes) * num_anchors)
+
+    # print(output.size())
+
+    # Slice the second dimension (channel) of output into:
+    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
+    # And then into
+    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
+    # batch = output.size(0)
+    # H = output.size(2)
+    # W = output.size(3)
+
+    bxy_list = []
+    bwh_list = []
+    det_confs_list = []
+    cls_confs_list = []
+
+    for i in range(num_anchors):
+        begin = i * (5 + num_classes)
+        end = (i + 1) * (5 + num_classes)
+
+        bxy_list.append(output[:, begin : begin + 2])
+        bwh_list.append(output[:, begin + 2 : begin + 4])
+        det_confs_list.append(output[:, begin + 4 : begin + 5])
+        cls_confs_list.append(output[:, begin + 5 : end])
+
+    # Shape: [batch, num_anchors * 2, H, W]
+    bxy = torch.cat(bxy_list, dim=1)
+    # Shape: [batch, num_anchors * 2, H, W]
+    bwh = torch.cat(bwh_list, dim=1)
+
+    # Shape: [batch, num_anchors, H, W]
+    det_confs = torch.cat(det_confs_list, dim=1)
+    # Shape: [batch, num_anchors * H * W]
+    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))
+
+    # Shape: [batch, num_anchors * num_classes, H, W]
+    cls_confs = torch.cat(cls_confs_list, dim=1)
+    # Shape: [batch, num_anchors, num_classes, H * W]
+    cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes, output.size(2) * output.size(3))
+    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
+    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)
+
+    # Apply sigmoid() and exp() to the slices
+    bxy = torch.sigmoid(bxy)  # * scale_x_y - 0.5 * (scale_x_y - 1)
+    bwh = torch.exp(bwh)
+    det_confs = torch.sigmoid(det_confs)
+    cls_confs = torch.sigmoid(cls_confs)
+
+    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
+    grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(3) - 1, output.size(3)), axis=0).repeat(output.size(2), 0), axis=0), axis=0)
+    grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(2) - 1, output.size(2)), axis=1).repeat(output.size(3), 1), axis=0), axis=0)
+    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
+    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
+
+    anchor_w = []
+    anchor_h = []
+    for i in range(num_anchors):
+        anchor_w.append(anchors[i * 2])
+        anchor_h.append(anchors[i * 2 + 1])
+
+    device = None
+    cuda_check = output.is_cuda
+    if cuda_check:
+        device = output.get_device()
+
+    bx_list = []
+    by_list = []
+    bw_list = []
+    bh_list = []
+
+    # Apply C-x, C-y, P-w, P-h
+    for i in range(num_anchors):
+        ii = i * 2
+        # Shape: [batch, 1, H, W]
+        bx = bxy[:, ii : ii + 1] + torch.tensor(grid_x, device=device, dtype=torch.float32)  # grid_x.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(grid_y, device=device, dtype=torch.float32)  # grid_y.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        bw = bwh[:, ii : ii + 1] * anchor_w[i]
+        # Shape: [batch, 1, H, W]
+        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
+
+        bx_list.append(bx)
+        by_list.append(by)
+        bw_list.append(bw)
+        bh_list.append(bh)
+
+    ########################################
+    #   Figure out bboxes from slices     #
+    ########################################
+
+    # Shape: [batch, num_anchors, H, W]
+    bx = torch.cat(bx_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    by = torch.cat(by_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bw = torch.cat(bw_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bh = torch.cat(bh_list, dim=1)
+
+    # normalize coordinates to [0, 1]
+    # bx /= output.size(3)
+    # by /= output.size(2)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    bx_bw = torch.cat((bx, bw), dim=1)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    by_bh = torch.cat((by, bh), dim=1)
+
+    # normalize coordinates to [0, 1]
+    bx_bw /= output.size(3)
+    by_bh /= output.size(2)
+
+    # Shape: [batch, num_anchors * H * W, 1]
+    bx = bx_bw[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    by = by_bh[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bw = bx_bw[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bh = by_bh[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)
+    # boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)
+    # boxes = boxes.repeat(1, 1, num_classes, 1)
+
+    # boxes:     [batch, num_anchors * H * W, 1, 4]
+    # cls_confs: [batch, num_anchors * H * W, num_classes]
+    # det_confs: [batch, num_anchors * H * W]
+
+    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    # print(det_confs)
+    # print(cls_confs)
+    confs = det_confs
+    # confs = cls_confs * det_confs
+
+    # boxes: [batch, num_anchors * H * W, 1, 4]
+    # confs: [batch, num_anchors * H * W, num_classes]
+
+    return boxes, confs
+
+
+class RegionLayer(nn.Module):
+    ''' Region layer (YOLOv2 detection layer)
+    model_out: during inference, is post-processing done inside or outside the model
+    True: outside
+    '''
+    def __init__(self, num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False):
+        super(RegionLayer, self).__init__()
+        self.num_classes = num_classes
+        self.anchors = anchors
+        self.num_anchors = num_anchors
+        self.anchor_step = len(anchors) // num_anchors
+        self.coord_scale = 1
+        self.noobject_scale = 1
+        self.object_scale = 5
+        self.class_scale = 1
+        self.thresh = 0.6
+        self.stride = stride
+        self.seen = 0
+
+        self.model_out = model_out
+
+    def forward(self, output, target=None):
+        if self.training:
+            return output
+
+        return region_forward_dynamic(output, self.thresh, self.num_classes, self.anchors, self.num_anchors)
+

diff --git a/tool/utils.py b/tool/utils.py
index a42e6264..94e38d06 100644
--- a/tool/utils.py
+++ b/tool/utils.py
@@ -137,8 +137,8 @@ def get_color(c, x, max_val):
         t_size = cv2.getTextSize(msg, 0, 0.7, thickness=bbox_thick // 2)[0]
         c1, c2 = (x1,y1), (x2, y2)
         c3 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)
-        cv2.rectangle(img, (x1,y1), (np.float32(c3[0]), np.float32(c3[1])), rgb, -1)
-        img = cv2.putText(img, msg, (c1[0], np.float32(c1[1] - 2)), cv2.FONT_HERSHEY_SIMPLEX,0.7, (0,0,0), bbox_thick//2,lineType=cv2.LINE_AA)
+        cv2.rectangle(img, (x1,y1), (int(c3[0]), int(c3[1])), rgb, -1)
+        img = cv2.putText(img, msg, (c1[0], int(c1[1] - 2)), cv2.FONT_HERSHEY_SIMPLEX,0.7, (0,0,0), bbox_thick//2,lineType=cv2.LINE_AA)
         img = cv2.rectangle(img, (x1, y1), (x2, y2), rgb, bbox_thick)
 
     if savename:

From facf74dd80f482f262543c7d9954ed048d5e17eb Mon Sep 17 00:00:00 2001
From: Evgeniy
Date: Sat, 6 May 2023 12:57:26 +0300
Subject: [PATCH 2/3] restored original thresholds

---
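Note: patch 1 accidentally changed the demo thresholds; this restores them.
The two positional arguments are the confidence cutoff and the NMS IoU cutoff,
in that order. A minimal reading of the call as used in the demo (names follow
the surrounding diff):

    # post_processing(image, conf_thresh, nms_thresh, model_output):
    # keep detections with confidence >= 0.4, then suppress overlapping
    # boxes whose IoU exceeds 0.6 during non-maximum suppression
    boxes = post_processing(img_in, 0.4, 0.6, outputs)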
 demo_darknet2onnx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/demo_darknet2onnx.py b/demo_darknet2onnx.py
index 09badbc7..3bdccf8f 100644
--- a/demo_darknet2onnx.py
+++ b/demo_darknet2onnx.py
@@ -46,7 +46,7 @@ def detect(session, image_src, namesfile):
 
     outputs = session.run(None, {input_name: img_in})
 
-    boxes = post_processing(img_in, 0.25, 0.45, outputs)
+    boxes = post_processing(img_in, 0.4, 0.6, outputs)
 
     class_names = load_class_names(namesfile)
     plot_boxes_cv2(image_src, boxes[0], savename='predictions_onnx.jpg', class_names=class_names)

From b7ba4b23de043007d745a8417553f3e1e8f81a66 Mon Sep 17 00:00:00 2001
From: Evgeniy
Date: Sat, 6 May 2023 13:19:36 +0300
Subject: [PATCH 3/3] reuse yolo layer instead of duplicating the code

---
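Note: the only behavioural difference that survives from the region copy is how
the final confidence is computed; everything else reuses yolo_forward_dynamic()
unchanged. A sketch of the distinction, with tensor names as in yolo_layer.py:

    # yolo path (multiply_confs=True): per-class scores
    confs = cls_confs * det_confs   # [batch, num_anchors * H * W, num_classes]
    # region path (multiply_confs=False): objectness only
    confs = det_confs               # [batch, num_anchors * H * W, 1]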
 tool/darknet2pytorch.py |  11 ++-
 tool/region_layer.py    | 180 ----------------------------------------
 tool/yolo_layer.py      |  13 +--
 3 files changed, 13 insertions(+), 191 deletions(-)
 delete mode 100644 tool/region_layer.py

diff --git a/tool/darknet2pytorch.py b/tool/darknet2pytorch.py
index 6eb2a0c2..c0da249b 100644
--- a/tool/darknet2pytorch.py
+++ b/tool/darknet2pytorch.py
@@ -1,7 +1,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from tool.region_layer import RegionLayer
 from tool.yolo_layer import YoloLayer
 from tool.config import *
 from tool.torch_utils import *
@@ -388,16 +387,16 @@ def create_network(self, blocks):
                 out_strides.append(prev_stride)
                 models.append(model)
             elif block['type'] == 'region':
-                region = RegionLayer()
+                region = YoloLayer()
                 anchors = block['anchors'].split(',')
                 region.anchors = [float(i) for i in anchors]
                 region.num_classes = int(block['classes'])
                 region.num_anchors = int(block['num'])
                 region.anchor_step = len(region.anchors) // region.num_anchors
-                region.object_scale = float(block['object_scale'])
-                region.noobject_scale = float(block['noobject_scale'])
-                region.class_scale = float(block['class_scale'])
-                region.coord_scale = float(block['coord_scale'])
+                region.scale_x_y = 1.0  # there is no such value in the region config
+                region.anchor_mask = [int(i) for i in range(len(anchors) // 2)]  # region has no anchor masks
+                region.stride = 1  # not implemented for region
+                region.multiply_confs = False  # do not multiply detection and class confidences
                 out_filters.append(prev_filters)
                 out_strides.append(prev_stride)
                 models.append(region)

diff --git a/tool/region_layer.py b/tool/region_layer.py
deleted file mode 100644
index 24f59752..00000000
--- a/tool/region_layer.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import torch.nn as nn
-import torch.nn.functional as F
-from tool.torch_utils import *
-
-
-def region_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1,
-                           validation=False):
-    # Output would be invalid if it does not satisfy this assert
-    # assert (output.size(1) == (5 + num_classes) * num_anchors)
-
-    # print(output.size())
-
-    # Slice the second dimension (channel) of output into:
-    # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
-    # And then into
-    # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
-    # batch = output.size(0)
-    # H = output.size(2)
-    # W = output.size(3)
-
-    bxy_list = []
-    bwh_list = []
-    det_confs_list = []
-    cls_confs_list = []
-
-    for i in range(num_anchors):
-        begin = i * (5 + num_classes)
-        end = (i + 1) * (5 + num_classes)
-
-        bxy_list.append(output[:, begin : begin + 2])
-        bwh_list.append(output[:, begin + 2 : begin + 4])
-        det_confs_list.append(output[:, begin + 4 : begin + 5])
-        cls_confs_list.append(output[:, begin + 5 : end])
-
-    # Shape: [batch, num_anchors * 2, H, W]
-    bxy = torch.cat(bxy_list, dim=1)
-    # Shape: [batch, num_anchors * 2, H, W]
-    bwh = torch.cat(bwh_list, dim=1)
-
-    # Shape: [batch, num_anchors, H, W]
-    det_confs = torch.cat(det_confs_list, dim=1)
-    # Shape: [batch, num_anchors * H * W]
-    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))
-
-    # Shape: [batch, num_anchors * num_classes, H, W]
-    cls_confs = torch.cat(cls_confs_list, dim=1)
-    # Shape: [batch, num_anchors, num_classes, H * W]
-    cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes, output.size(2) * output.size(3))
-    # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
-    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)
-
-    # Apply sigmoid() and exp() to the slices
-    bxy = torch.sigmoid(bxy)  # * scale_x_y - 0.5 * (scale_x_y - 1)
-    bwh = torch.exp(bwh)
-    det_confs = torch.sigmoid(det_confs)
-    cls_confs = torch.sigmoid(cls_confs)
-
-    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
-    grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(3) - 1, output.size(3)), axis=0).repeat(output.size(2), 0), axis=0), axis=0)
-    grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(2) - 1, output.size(2)), axis=1).repeat(output.size(3), 1), axis=0), axis=0)
-    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
-    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
-
-    anchor_w = []
-    anchor_h = []
-    for i in range(num_anchors):
-        anchor_w.append(anchors[i * 2])
-        anchor_h.append(anchors[i * 2 + 1])
-
-    device = None
-    cuda_check = output.is_cuda
-    if cuda_check:
-        device = output.get_device()
-
-    bx_list = []
-    by_list = []
-    bw_list = []
-    bh_list = []
-
-    # Apply C-x, C-y, P-w, P-h
-    for i in range(num_anchors):
-        ii = i * 2
-        # Shape: [batch, 1, H, W]
-        bx = bxy[:, ii : ii + 1] + torch.tensor(grid_x, device=device, dtype=torch.float32)  # grid_x.to(device=device, dtype=torch.float32)
-        # Shape: [batch, 1, H, W]
-        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(grid_y, device=device, dtype=torch.float32)  # grid_y.to(device=device, dtype=torch.float32)
-        # Shape: [batch, 1, H, W]
-        bw = bwh[:, ii : ii + 1] * anchor_w[i]
-        # Shape: [batch, 1, H, W]
-        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]
-
-        bx_list.append(bx)
-        by_list.append(by)
-        bw_list.append(bw)
-        bh_list.append(bh)
-
-    ########################################
-    #   Figure out bboxes from slices     #
-    ########################################
-
-    # Shape: [batch, num_anchors, H, W]
-    bx = torch.cat(bx_list, dim=1)
-    # Shape: [batch, num_anchors, H, W]
-    by = torch.cat(by_list, dim=1)
-    # Shape: [batch, num_anchors, H, W]
-    bw = torch.cat(bw_list, dim=1)
-    # Shape: [batch, num_anchors, H, W]
-    bh = torch.cat(bh_list, dim=1)
-
-    # normalize coordinates to [0, 1]
-    # bx /= output.size(3)
-    # by /= output.size(2)
-    # Shape: [batch, 2 * num_anchors, H, W]
-    bx_bw = torch.cat((bx, bw), dim=1)
-    # Shape: [batch, 2 * num_anchors, H, W]
-    by_bh = torch.cat((by, bh), dim=1)
-
-    # normalize coordinates to [0, 1]
-    bx_bw /= output.size(3)
-    by_bh /= output.size(2)
-
-    # Shape: [batch, num_anchors * H * W, 1]
-    bx = bx_bw[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    by = by_bh[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    bw = bx_bw[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    bh = by_bh[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-
-    bx1 = bx - bw * 0.5
-    by1 = by - bh * 0.5
-    bx2 = bx1 + bw
-    by2 = by1 + bh
-
-    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
-    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)
-    # boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)
-    # boxes = boxes.repeat(1, 1, num_classes, 1)
-
-    # boxes:     [batch, num_anchors * H * W, 1, 4]
-    # cls_confs: [batch, num_anchors * H * W, num_classes]
-    # det_confs: [batch, num_anchors * H * W]
-
-    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    # print(det_confs)
-    # print(cls_confs)
-    confs = det_confs
-    # confs = cls_confs * det_confs
-
-    # boxes: [batch, num_anchors * H * W, 1, 4]
-    # confs: [batch, num_anchors * H * W, num_classes]
-
-    return boxes, confs
-
-
-class RegionLayer(nn.Module):
-    ''' Region layer (YOLOv2 detection layer)
-    model_out: during inference, is post-processing done inside or outside the model
-    True: outside
-    '''
-    def __init__(self, num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False):
-        super(RegionLayer, self).__init__()
-        self.num_classes = num_classes
-        self.anchors = anchors
-        self.num_anchors = num_anchors
-        self.anchor_step = len(anchors) // num_anchors
-        self.coord_scale = 1
-        self.noobject_scale = 1
-        self.object_scale = 5
-        self.class_scale = 1
-        self.thresh = 0.6
-        self.stride = stride
-        self.seen = 0
-
-        self.model_out = model_out
-
-    def forward(self, output, target=None):
-        if self.training:
-            return output
-
-        return region_forward_dynamic(output, self.thresh, self.num_classes, self.anchors, self.num_anchors)
-

diff --git a/tool/yolo_layer.py b/tool/yolo_layer.py
index c3c904a5..fffc4505 100644
--- a/tool/yolo_layer.py
+++ b/tool/yolo_layer.py
@@ -146,7 +146,7 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
 
 
 def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
-                         validation=False):
+                         validation=False, multiply_confs=True):
     # Output would be invalid if it does not satisfy this assert
     # assert (output.size(1) == (5 + num_classes) * num_anchors)
 
@@ -280,7 +280,10 @@ def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors,
     # det_confs: [batch, num_anchors * H * W]
 
     det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
-    confs = cls_confs * det_confs
+    if multiply_confs:
+        confs = cls_confs * det_confs
+    else:
+        confs = det_confs
 
     # boxes: [batch, num_anchors * H * W, 1, 4]
     # confs: [batch, num_anchors * H * W, num_classes]
@@ -292,7 +295,7 @@ class YoloLayer(nn.Module):
     model_out: while inference,is post-processing inside or outside the model
     true:outside
     '''
-    def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False):
+    def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False, multiply_confs=True):
         super(YoloLayer, self).__init__()
         self.anchor_mask = anchor_mask
         self.num_classes = num_classes
@@ -307,6 +310,7 @@ def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, str
         self.stride = stride
         self.seen = 0
         self.scale_x_y = 1
+        self.multiply_confs = multiply_confs
 
         self.model_out = model_out
 
@@ -318,5 +322,4 @@ def forward(self, output, target=None):
             masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
         masked_anchors = [anchor / self.stride for anchor in masked_anchors]
 
-        return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask),scale_x_y=self.scale_x_y)
-
+        return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask), scale_x_y=self.scale_x_y, multiply_confs=self.multiply_confs)
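
Note: with this series applied, a YOLOv2 [region] block is handled by YoloLayer
configured as in create_network() above. A minimal standalone sketch of that
configuration (the anchors and class count below are illustrative assumptions,
not taken from any particular .cfg):

    from tool.yolo_layer import YoloLayer

    region = YoloLayer()
    # example YOLOv2 VOC-style anchors, in grid-cell units
    region.anchors = [1.3221, 1.73145, 3.19275, 4.00944, 5.05587,
                      8.09892, 9.47112, 4.84053, 11.2364, 10.0071]
    region.num_classes = 20        # example class count
    region.num_anchors = 5
    region.anchor_step = len(region.anchors) // region.num_anchors
    region.scale_x_y = 1.0         # region configs have no scale_x_y
    region.anchor_mask = list(range(region.num_anchors))  # region has no anchor masks
    region.stride = 1              # leaves the grid-unit anchors unscaled
    region.multiply_confs = False  # objectness only, matching the removed region_layer.py

    # in eval() mode, region(feature_map) returns (boxes, confs)
    # ready for post_processing()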