极市导读
本文主要围绕以 YOLO 为基线的关键点检测器，介绍了该框架的演变，并给出了在 onnxruntime 推理框架下实现 YOLOv7-Pose 的具体代码及相应解释。
>> 极市七夕粉丝福利活动：搞科研的日子是 364 天，但七夕只有一天！
# Keypoint (pose) branch of the YOLO-Pose loss. This is an excerpt from inside
# the loss computation: `ps`, `tkpt`, `tbox`, `i`, `sigmas`, `lkpt` and `lkptv`
# come from the enclosing code (presumably a loop over detection layers indexed
# by `i`), which is not shown here.
if self.kpt_label:
    # Direct kpt prediction: keypoint channels start at column 6 with a stride
    # of 3 -> (x, y, score) triples per keypoint.
    pkpt_x = ps[:, 6::3] * 2. - 0.5
    pkpt_y = ps[:, 7::3] * 2. - 0.5
    pkpt_score = ps[:, 8::3]
    # mask: a keypoint counts as labelled when its target x coordinate
    # (even columns of tkpt) is non-zero.
    kpt_mask = (tkpt[i][:, 0::2] != 0)
    # Keypoint-score loss against the mask; self.BCEcls is defined elsewhere
    # (presumably a BCE-with-logits criterion).
    lkptv += self.BCEcls(pkpt_score, kpt_mask.float())
    # OKS-based loss.
    # d: squared Euclidean distance between predicted and target keypoints.
    d = (pkpt_x - tkpt[i][:, 0::2]) ** 2 + (pkpt_y - tkpt[i][:, 1::2]) ** 2
    # s: product of the last two columns of the target box (presumably w*h,
    # i.e. the box area -- confirm tbox layout).
    s = torch.prod(tbox[i][:, -2:], dim=1, keepdim=True)
    # Scale by (all keypoints / labelled keypoints) so the `.mean()` below
    # averages only over labelled keypoints. The denominator is clamped to 1:
    # with no labelled keypoint it would be 0, and inf * (all-zero masked
    # term) turns the whole loss into NaN.
    kpt_loss_factor = kpt_mask.numel() / torch.sum(kpt_mask != 0).clamp(min=1)
    # 1 - exp(-d / (s * 4 * sigma^2)) per keypoint, masked to labelled ones;
    # the 1e-9 keeps the division finite when s is 0.
    lkpt += kpt_loss_factor * ((1 - torch.exp(-d / (s * (4 * sigmas ** 2) + 1e-9))) * kpt_mask).mean()
# Load the PyTorch pose checkpoint and the sample image, then run the
# PyTorch pose demo script.
import subprocess
import sys

import cv2
import torch

# NOTE: torch.load unpickles arbitrary Python objects -- only load trusted
# checkpoints. The variable name `weigths` (sic) is kept as in the original.
weigths = torch.load('weights/yolov7-w6-pose.pt')
image = cv2.imread('sample/pose.jpeg')
# Plain-Python equivalent of the notebook shell escape `!python pose.py`:
# run the script with the current interpreter and raise if it fails.
subprocess.run([sys.executable, 'pose.py'], check=True)
# Original code: prepare every module of `model` before export.
for k, m in model.named_modules():
    m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
    if isinstance(m, models.common.Conv):  # assign export-friendly activations
        # Replace torch's built-in activations with the `Hardswish` / `SiLU`
        # classes in scope here (defined elsewhere; presumably export-friendly
        # reimplementations -- confirm).
        if isinstance(m.act, nn.Hardswish):
            m.act = Hardswish()
        elif isinstance(m.act, nn.SiLU):
            m.act = SiLU()
# Set the last layer's `export` flag (the inverse of opt.grid).
model.model[-1].export = not opt.grid  # set Detect() layer grid export
# Modified code: the same loop, plus a branch that swaps the forward method of
# the keypoint head before export.
for k, m in model.named_modules():
    m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
    if isinstance(m, models.common.Conv):  # assign export-friendly activations
        if isinstance(m.act, nn.Hardswish):
            m.act = Hardswish()
        elif isinstance(m.act, nn.SiLU):
            m.act = SiLU()
    elif isinstance(m, models.yolo.IKeypoint):
        # Switch the detection head here: rebind the IKeypoint head's forward
        # to `forward_keypoint` (not shown here; presumably added to IKeypoint
        # in models/yolo.py -- confirm it exists in your checkout).
        m.forward = m.forward_keypoint  # assign forward (optional)
# Set the last layer's `export` flag (the inverse of opt.grid).
model.model[-1].export = not opt.grid  # set Detect() layer grid export
# Export the pose checkpoint to ONNX with a 960 input size. In YOLOv7's
# export.py, --simplify is a store_true flag and takes no value; a trailing
# "True" is rejected by argparse as an unrecognized argument.
python export.py --weights 'weights/yolov7-w6-pose.pt' --img-size 960 --simplify
# Run the exported YOLOv7-Pose ONNX model on one image with onnxruntime,
# decode detections + keypoints with NMS, and plot the skeletons.
# Assumes weights/yolov7-w6-pose.onnx has a static 1x3x960x960 input
# (exported without --dynamic) -- adjust the letterbox call if not.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import torch
from torchvision import transforms

from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts

device = torch.device("cpu")

image = cv2.imread('sample/pose.jpeg')
if image is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise FileNotFoundError('sample/pose.jpeg')
# auto=False pads all the way to 960x960. With auto=True letterbox only pads to
# the next multiple of the stride, so non-square images would yield e.g. a
# 960x576 tensor that the fixed-size ONNX input rejects.
image = letterbox(image, 960, stride=64, auto=False)[0]
image_ = image.copy()  # letterboxed BGR uint8 copy, reused for drawing below
# cv2 loads BGR while YOLOv7 is trained on RGB input, so convert first.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# HWC uint8 -> CHW float in [0, 1], then add the batch dimension.
image = transforms.ToTensor()(image).unsqueeze(0)
print(image.shape)

# Explicit providers: onnxruntime >= 1.9 requires them on builds with several
# execution providers (e.g. the GPU package).
sess = onnxruntime.InferenceSession(
    'weights/yolov7-w6-pose.onnx',
    providers=['CPUExecutionProvider'],
)
out = sess.run(['output'], {'images': image.numpy()})[0]
out = torch.from_numpy(out)

# 0.25 / 0.65 are the confidence / IoU thresholds; a single class with 17
# keypoints per detection.
output = non_max_suppression_kpt(out, 0.25, 0.65, nc=1, nkpt=17, kpt_label=True)
output = output_to_keypoint(output)

# Draw on a fresh, contiguous RGB uint8 image (what matplotlib expects and
# what cv2 drawing functions require), built directly from the letterboxed
# pixels instead of round-tripping through the float tensor.
nimg = cv2.cvtColor(image_, cv2.COLOR_BGR2RGB)
for idx in range(output.shape[0]):
    # Columns 7: of each row hold that detection's flattened keypoints.
    plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)

# In Jupyter, `%matplotlib inline` renders the figure inline.
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(nimg)
# Save before show(): show() can close/clear the figure, after which savefig
# would write an empty canvas.
plt.savefig("tmp")
plt.show()
公众号后台回复“ECCV2022”获取论文分类资源下载~
“点击阅读原文进入CV社区，收获更多技术干货”