torchvision對象檢測介紹
Pytorch1.11版本以上支持Torchvision高版本支持以下對象檢測模型的遷移學習:
- Faster-RCNN - Mask-RCNN - FCOS - RetinaNet - SSD - KeyPointsRCNN其中基于COCO的預訓練模型mAP對應關系如下:
最近一段時間本人已經全部親測,都可以轉換為ONNX格式模型,都可以支持ONNXRUNTIME框架的Python版本與C++版本推理,本文以RetinaNet為例,演示了從模型下載到導出ONNX格式,然后基于ONNXRUNTIME推理的整個流程。
RetinaNet轉ONNX
把模型轉換為ONNX格式,Pytorch是原生支持的,只需要把通過torch.onnx.export接口,填上相關的參數,然后直接運行就可以生成ONNX模型文件。相關的轉換代碼如下:
model=tv.models.detection.retinanet_resnet50_fpn(pretrained=True) dummy_input=torch.randn(1,3,1333,800) model.eval() model(dummy_input) im=torch.zeros(1,3,1333,800).to("cpu") torch.onnx.export(model,im, "retinanet_resnet50_fpn.onnx", verbose=False, opset_version=11, training=torch.onnx.TrainingMode.EVAL, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={'input':{0:'batch',2:'height',3:'width'}} )運行時候控制臺會有一系列的警告輸出,但是絕對不影響模型轉換,影響不影響精度我還沒做個仔細的對比。 模型轉換之后,可以直接查看模型的輸入與輸出結構,圖示如下:
RetinaNet的ONNX格式推理
基于Python版本的ONNXRUNTIME完成推理演示,這個跟我之前寫過一篇文章Faster-RCNN的ONNX推理演示非常相似,大概是去年寫的,鏈接在這里: 代碼很簡單,只有三十幾行,Python就是方便使用,這里最需要注意的是輸入圖像的預處理必須是RGB格式,需要歸一化到0~1之間。對得到的三個輸出層分別解析,就可以獲取到坐標(boxes里面包含的實際坐標,無需轉換),推理部分的代碼如下:
importonnxruntimeasort importcv2ascv importnumpyasnp importtorchvision coco_names={'0':'background','1':'person','2':'bicycle','3':'car','4':'motorcycle','5':'airplane','6':'bus', '7':'train','8':'truck','9':'boat','10':'trafficlight','11':'firehydrant','13':'stopsign', '14':'parkingmeter','15':'bench','16':'bird','17':'cat','18':'dog','19':'horse','20':'sheep', '21':'cow','22':'elephant','23':'bear','24':'zebra','25':'giraffe','27':'backpack', '28':'umbrella','31':'handbag','32':'tie','33':'suitcase','34':'frisbee','35':'skis', '36':'snowboard','37':'sportsball','38':'kite','39':'baseballbat','40':'baseballglove', '41':'skateboard','42':'surfboard','43':'tennisracket','44':'bottle','46':'wineglass', '47':'cup','48':'fork','49':'knife','50':'spoon','51':'bowl','52':'banana','53':'apple', '54':'sandwich','55':'orange','56':'broccoli','57':'carrot','58':'hotdog','59':'pizza', '60':'donut','61':'cake','62':'chair','63':'couch','64':'pottedplant','65':'bed', '67':'diningtable','70':'toilet','72':'tv','73':'laptop','74':'mouse','75':'remote', '76':'keyboard','77':'cellphone','78':'microwave','79':'oven','80':'toaster','81':'sink', '82':'refrigerator','84':'book','85':'clock','86':'vase','87':'scissors','88':'teddybear', '89':'hairdrier','90':'toothbrush'} transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) sess_options=ort.SessionOptions() src=cv.imread("D:/images/mmc.png") cv.namedWindow("Retina-NetDetectionDemo",cv.WINDOW_AUTOSIZE) image=cv.cvtColor(src,cv.COLOR_BGR2RGB) blob=transform(image) c,h,w=blob.shape input_x=blob.view(1,c,h,w) defto_numpy(tensor): returntensor.detach().cpu().numpy()iftensor.requires_gradelsetensor.cpu().numpy() #computeONNXRuntimeoutputprediction ort_inputs={ort_session.get_inputs()[0].name:to_numpy(input_x)} ort_outs=ort_session.run(None,ort_inputs) #(N,4)dimensionalarraycontainingtheabsolutebounding-box boxes=ort_outs[0] scores=ort_outs[1] labels=ort_outs[2] print(boxes.shape,boxes.dtype,labels.shape,labels.dtype,scores.shape,scores.dtype) index=0 forx1,y1,x2,y2inboxes: ifscores[index]>0.65: cv.rectangle(src,(np.int32(x1),np.int32(y1)), (np.int32(x2),np.int32(y2)),(140,199,0),2,8,0) label_id=labels[index] label_txt=coco_names[str(label_id)] cv.putText(src,label_txt,(np.int32(x1),np.int32(y1)),cv.FONT_HERSHEY_SIMPLEX,0.75,(0,0,255),1) index+=1 cv.imshow("Retina-NetDetectionDemo",src) cv.imwrite("D:/mmc_result.png",src) cv.waitKey(0) cv.destroyAllWindows()
-
C++
+關注
關注
22文章
2108瀏覽量
73651 -
pytorch
+關注
關注
2文章
808瀏覽量
13226 -
訓練模型
+關注
關注
1文章
36瀏覽量
3821
原文標題:TorchVision對象檢測RetinaNet推理演示
文章出處:【微信號:CVSCHOOL,微信公眾號:OpenCV學堂】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論