[Object Detection][PyTorch] Bounding Box Operations

While reading the source code, I found that torchvision provides many bounding box operations, including:

  1. NMS
  2. Removing small bounding boxes
  3. Clipping bounding box coordinates so that boxes lie inside the image
  4. Computing bounding box areas
  5. IoU computation

Online documentation: torchvision.ops

Source code: torchvision/ops/boxes.py

Current torchvision version: 0.6.0a0+82fd1c8

NMS

I previously implemented a NumPy version of NMS in [Object Detection] Non-Maximum Suppression. Now we can simply use the implementation provided by PyTorch.

The goal of NMS is to remove redundant detection boxes. PyTorch provides two functions for this:

def nms(boxes, scores, iou_threshold):
def batched_nms(boxes, scores, idxs, iou_threshold):

The nms function performs single-class NMS, while batched_nms performs multi-class NMS: each class has its redundant detection boxes removed independently. It works as follows (see the sketch after this list):

  1. Find the largest coordinate value among all detection boxes
  2. Using that maximum value, add a per-class offset to each class's boxes so that boxes of different classes are fully separated
  3. Run the nms function on the offset boxes
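
A minimal usage sketch (the boxes, scores, and class indices below are made up for illustration):

import torch
from torchvision.ops import nms, batched_nms

# two heavily overlapping boxes plus one separate box, in (x1, y1, x2, y2) format
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 12., 52., 52.],
                      [100., 100., 150., 150.]])
scores = torch.tensor([0.9, 0.8, 0.7])

# single-class NMS: the second box is suppressed by the first
print(nms(boxes, scores, iou_threshold=0.5))  # tensor([0, 2])

# multi-class NMS: boxes of different classes never suppress each other
idxs = torch.tensor([0, 1, 0])  # class index of each box
print(batched_nms(boxes, scores, idxs, iou_threshold=0.5))  # tensor([0, 1, 2])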

Removing Small Bounding Boxes

Remove detection boxes that have at least one side shorter than a threshold:

def remove_small_boxes(boxes, min_size):
    # type: (Tensor, float)
    """
    Remove boxes which contains at least one side smaller than min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (float): minimum size

    Returns:
        keep (Tensor[K]): indices of the boxes that have both sides
            larger than min_size
    """
    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    keep = (ws >= min_size) & (hs >= min_size)
    keep = keep.nonzero().squeeze(1)
    return keep
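
A quick usage sketch (box values are made up for illustration):

import torch
from torchvision.ops import remove_small_boxes

boxes = torch.tensor([[0., 0., 100., 100.],   # 100 x 100
                      [10., 10., 15., 80.]])  # 5 x 70: width is below the threshold
keep = remove_small_boxes(boxes, min_size=10)
print(keep)  # tensor([0])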

Clipping Bounding Box Coordinates to Keep Boxes Inside the Image

This operation is useful in the image preprocessing stage: when an image is randomly cropped, the bounding boxes whose centers lie inside the cropped region must have their coordinates clipped to the crop as well.

def clip_boxes_to_image(boxes, size):
    # type: (Tensor, Tuple[int, int])
    """
    Clip boxes so that they lie inside an image of size `size`.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    dim = boxes.dim()
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]
    height, width = size

    if torchvision._is_tracing():
        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    else:
        boxes_x = boxes_x.clamp(min=0, max=width)
        boxes_y = boxes_y.clamp(min=0, max=height)

    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)
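
For example (a made-up box that sticks out of a 200 x 300, i.e. height x width, image):

import torch
from torchvision.ops import clip_boxes_to_image

boxes = torch.tensor([[-20., 50., 350., 180.]])
clipped = clip_boxes_to_image(boxes, size=(200, 300))
print(clipped)  # tensor([[  0.,  50., 300., 180.]])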

Computing Bounding Box Areas

This serves as a building block for IoU computation:

def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by its
    (x1, y1, x2, y2) coordinates.

    Arguments:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
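
For example (made-up boxes):

import torch
from torchvision.ops import boxes

bs = torch.tensor([[0., 0., 10., 20.], [5., 5., 6., 7.]])
print(boxes.box_area(bs))  # tensor([200., 2.])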

IoU Computation

I previously implemented IoU computation in both NumPy and PyTorch.

Now the IoU computation is built into PyTorch:

# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1, boxes2):
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    iou = inter / (area1[:, None] + area2 - inter)
    return iou

A quick test:

import numpy as np
import torch
from torchvision.ops import boxes

if __name__ == '__main__':
    pred_boxes = np.array([[30, 280, 200, 300], [20, 300, 100, 360]])
    target_boxes = np.array([[10, 10, 50, 50], [40, 270, 100, 380], [450, 300, 500, 500]])

    # ious has shape [N1, N2], where
    # N1 is the number of pred_boxes
    # N2 is the number of target_boxes
    ious = boxes.box_iou(torch.from_numpy(pred_boxes).float(), torch.from_numpy(target_boxes).float())

    print('ious: ', ious)
    # Output:
    # ious:  tensor([[0.0000, 0.1364, 0.0000],
    #                [0.0000, 0.4615, 0.0000]])