Skip to content

Commit 9b5d698

Browse files
committed
update 添加VOC map 计算方法
1 parent 41398b3 commit 9b5d698

1 file changed

Lines changed: 298 additions & 0 deletions

File tree

dataset_tools/cal_voc_map.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# @File : cal_map.py
4+
# @Author : jade
5+
# @Date : 2022/9/9 10:08
6+
# @Email : jadehh@1ive.com
7+
# @Software : Samples
8+
# @Desc :
9+
import os
10+
from detector import Detector
11+
from dataset_tools.jade_voc_datasets import GetXmlClassesNames,ProcessXml
12+
from opencv_tools.jade_visualize import *
13+
from opencv_tools.jade_opencv_process import *
14+
from jade import *
15+
import copy
16+
import matplotlib.pyplot as plt
17+
min_overlap = 0.5
18+
font_size = 24
19+
def cal_iou(bbgt,bb,gt_label,pd_label,class_name):
20+
ovmax = min_overlap
21+
ov = 0
22+
wrong_class_status = False
23+
bi = [max(bb[0], bbgt[0]), max(bb[1], bbgt[1]), min(bb[2], bbgt[2]), min(bb[3], bbgt[3])]
24+
iw = bi[2] - bi[0] + 1
25+
ih = bi[3] - bi[1] + 1
26+
if iw > 0 and ih > 0:
27+
# compute overlap (IoU) = area of intersection / area of union
28+
ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
29+
+ 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
30+
ov = iw * ih / ua
31+
if ov > ovmax:
32+
if pd_label == gt_label == class_name:
33+
pass
34+
else:
35+
ov = 0
36+
wrong_class_status = True
37+
else:
38+
ov = 0
39+
40+
return ov,wrong_class_status
41+
42+
def voc_ap(rec, prec):
43+
"""
44+
--- Official matlab code VOC2012---
45+
mrec=[0 ; rec ; 1];
46+
mpre=[0 ; prec ; 0];
47+
for i=numel(mpre)-1:-1:1
48+
mpre(i)=max(mpre(i),mpre(i+1));
49+
end
50+
i=find(mrec(2:end)~=mrec(1:end-1))+1;
51+
ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
52+
"""
53+
rec.insert(0, 0.0) # insert 0.0 at begining of list
54+
rec.append(1.0) # insert 1.0 at end of list
55+
mrec = rec[:]
56+
prec.insert(0, 0.0) # insert 0.0 at begining of list
57+
prec.append(0.0) # insert 0.0 at end of list
58+
mpre = prec[:]
59+
"""
60+
This part makes the precision monotonically decreasing
61+
(goes from the end to the beginning)
62+
matlab: for i=numel(mpre)-1:-1:1
63+
mpre(i)=max(mpre(i),mpre(i+1));
64+
"""
65+
# matlab indexes start in 1 but python in 0, so I have to do:
66+
# range(start=(len(mpre) - 2), end=0, step=-1)
67+
# also the python function range excludes the end, resulting in:
68+
# range(start=(len(mpre) - 2), end=-1, step=-1)
69+
for i in range(len(mpre)-2, -1, -1):
70+
mpre[i] = max(mpre[i], mpre[i+1])
71+
"""
72+
This part creates a list of indexes where the recall changes
73+
matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
74+
"""
75+
i_list = []
76+
for i in range(1, len(mrec)):
77+
if mrec[i] != mrec[i-1]:
78+
i_list.append(i) # if it was matlab would be i + 1
79+
"""
80+
The Average Precision (AP) is the area under the curve
81+
(numerical integration)
82+
matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
83+
"""
84+
ap = 0.0
85+
for i in i_list:
86+
ap += ((mrec[i]-mrec[i-1])*mpre[i])
87+
return ap, mrec, mpre
88+
89+
def match_boxes(gt_preds,pd_preds,class_name):
90+
gt_boxes = gt_preds["boxes"]
91+
gt_labels = gt_preds["labels"]
92+
93+
pd_boxes = pd_preds["boxes"].tolist()
94+
pd_labels = pd_preds["labels"]
95+
ov_iou_list = []
96+
match_class_status_list = []
97+
for (i,gt_box) in enumerate(gt_boxes):
98+
for (j,pd_box) in enumerate(pd_boxes):
99+
ov_iou,match_class_status = cal_iou(gt_box,pd_box,gt_labels[i],pd_labels[j],class_name)
100+
if ov_iou > 0:
101+
ov_iou_list.append(ov_iou)
102+
match_class_status_list.append(match_class_status)
103+
gt_boxes.remove(gt_box)
104+
pd_boxes.remove(pd_box)
105+
return ov_iou_list,match_class_status_list
106+
from opencv_tools import ReadChinesePath
107+
def cal_map(root_path,detector,is_test=True,show_animation=True):
108+
sum_AP = 0.0
109+
n_classes = 0
110+
label_txt_path = os.path.join(root_path,"label_list.txt")
111+
if is_test:
112+
data_txt_path = os.path.join(root_path,"test.txt")
113+
else:
114+
data_txt_path = os.path.join(root_path,"train.txt")
115+
class_name_list = []
116+
gt_counter_per_class = {}
117+
no_label_samples_list = []
118+
with open("output.txt", "w") as output_file:
119+
with open(label_txt_path, "rb") as f:
120+
for class_name_byte in f.readlines():
121+
nd = 0
122+
n_classes = n_classes + 1
123+
with open(data_txt_path, "rb") as f2:
124+
for data_byte in f2.readlines():
125+
class_name = str(class_name_byte, encoding="utf-8").strip()
126+
data_list = str(data_byte, encoding="utf-8").strip().split(" ")
127+
image_path = os.path.join(root_path, data_list[0])
128+
anno_path = os.path.join(root_path, data_list[1])
129+
anno_class_names = GetXmlClassesNames(anno_path)
130+
if anno_class_names:
131+
if class_name in anno_class_names:
132+
nd = nd + 1
133+
gt_counter_per_class[str(class_name_byte, encoding="utf-8").strip()] = nd
134+
tp = [0] * nd # creates an array of zeros of size nd
135+
fp = [0] * nd
136+
print("正在计算{},检测准确率".format(str(class_name_byte, encoding="utf-8").strip()))
137+
index = 0
138+
with open(data_txt_path, "rb") as f2:
139+
for data_byte in f2.readlines():
140+
class_name = str(class_name_byte, encoding="utf-8").strip()
141+
data_list = str(data_byte, encoding="utf-8").strip().split(" ")
142+
image_path = os.path.join(root_path, data_list[0])
143+
anno_path = os.path.join(root_path, data_list[1])
144+
anno_class_names = GetXmlClassesNames(anno_path)
145+
img = ReadChinesePath(image_path)
146+
gt_preds = {}
147+
gt_preds_cp = {}
148+
over_lay_list = []
149+
if anno_class_names:
150+
if class_name in anno_class_names:
151+
results = detector.predict(img, class_type=class_name)
152+
imagename, shape, bboxes, labels_text, labels, difficult, truncated = ProcessXml(
153+
anno_path, is_rate=False)
154+
gt_preds["boxes"] = bboxes
155+
gt_preds["labels"] = labels_text
156+
gt_preds_cp["boxes"] = copy.copy(bboxes)
157+
gt_preds_cp["labels"] = labels_text
158+
over_lay_list, match_class_status_list = match_boxes(gt_preds, results, class_name)
159+
"""
160+
Draw image to show animation
161+
"""
162+
if show_animation:
163+
bottom_border = 60
164+
height, widht = img.shape[:2]
165+
# colors (OpenCV works with BGR)
166+
white = (255, 255, 255)
167+
light_blue = (255, 200, 100)
168+
green = (0, 255, 0)
169+
light_red = (30, 30, 255)
170+
# 1st line
171+
margin = 10
172+
v_pos = int(height - margin - (bottom_border / 2.0))
173+
text = "Image: " + GetLastDir(image_path) + " "
174+
img = Add_Chinese_Label(img, text, (margin, v_pos - 48), color=white,
175+
font_size=font_size)
176+
text = "Class Name : " + class_name + " "
177+
img = Add_Chinese_Label(img, text, (margin, v_pos - 24), light_blue,
178+
font_size=font_size)
179+
if len(over_lay_list) > 0:
180+
color = light_red
181+
for ovmax in over_lay_list:
182+
width = font_size * len(text)
183+
text = "IoU: {0:.2f}% ".format(ovmax * 100) + "> {0:.2f}% ".format(
184+
min_overlap * 100)
185+
img = Add_Chinese_Label(img, text, (margin + width, v_pos - 24), color,
186+
font_size=font_size)
187+
for score in results["scores"]:
188+
text = "confidence: {0:.2f}% ".format(float(score) * 100)
189+
img = Add_Chinese_Label(img, text, (margin, v_pos), white, font_size=font_size)
190+
color = light_red
191+
text = "Result: Match"
192+
if len(over_lay_list) != len(gt_preds_cp["boxes"]):
193+
fp[index] = 1
194+
else:
195+
tp[index] = 1
196+
for match in match_class_status_list:
197+
if match:
198+
text = "Result: Not Match"
199+
tp[index] = 0
200+
img = visualize(img, results)
201+
for (label, box) in zip(gt_preds_cp["labels"], gt_preds_cp["boxes"]):
202+
if label == class_name:
203+
bb = [int(i) for i in box]
204+
img = cv2.rectangle(img, (bb[0], bb[1]), (bb[2], bb[3]), color, 2)
205+
206+
line_width = len(text) * font_size
207+
img = Add_Chinese_Label(img, text, (margin + line_width, v_pos), color)
208+
# show image
209+
cv2.namedWindow("Animation", 0)
210+
cv2.imshow("Animation", img)
211+
cv2.waitKey(20) # show for 20 ms
212+
index = index + 1
213+
else:
214+
##不是此类别
215+
pass
216+
# print("不是此类别")
217+
else:
218+
if anno_path not in no_label_samples_list:
219+
no_label_samples_list.append(anno_path)
220+
##没有目标区域
221+
results = detector.predict(img)
222+
print("没有目标区域,目标检测结果为:{},标准的结果为:{},anno path = {}".format(results, anno_class_names,
223+
anno_path))
224+
img = visualize(img, results)
225+
cv2.namedWindow("ERROR Detection", 0)
226+
cv2.imshow("ERROR Detection", img)
227+
cv2.waitKey(1)
228+
print(tp, fp)
229+
# compute precision/recall
230+
cumsum = 0
231+
for idx, val in enumerate(fp):
232+
fp[idx] += cumsum
233+
cumsum += val
234+
cumsum = 0
235+
for idx, val in enumerate(tp):
236+
tp[idx] += cumsum
237+
cumsum += val
238+
# print(tp)
239+
rec = tp[:]
240+
for idx, val in enumerate(tp):
241+
rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
242+
# print(rec)
243+
prec = tp[:]
244+
for idx, val in enumerate(tp):
245+
prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
246+
# print(prec)
247+
248+
ap, mrec, mprec = voc_ap(rec[:], prec[:])
249+
250+
sum_AP += ap
251+
text = "{0:.2f}%".format(
252+
ap * 100) + " = " + class_name + " AP " # class_name + " AP = {0:.2f}%".format(ap*100)
253+
"""
254+
Write to output.txt
255+
"""
256+
rounded_prec = ['%.2f' % elem for elem in prec]
257+
rounded_rec = ['%.2f' % elem for elem in rec]
258+
output_file.write(
259+
text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
260+
print(text)
261+
262+
plt.plot(rec, prec, '-o')
263+
# add a new penultimate point to the list (mrec[-2], 0.0)
264+
# since the last line segment (and respective area) do not affect the AP value
265+
area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
266+
area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
267+
plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
268+
# set window title
269+
fig = plt.gcf() # gcf - get current figure
270+
fig.canvas.set_window_title('AP ' + class_name)
271+
# set plot title
272+
plt.title('class: ' + text)
273+
# plt.suptitle('This is a somewhat long figure title', fontsize=16)
274+
# set axis titles
275+
plt.xlabel('Recall')
276+
plt.ylabel('Precision')
277+
# optional - set axes
278+
axes = plt.gca() # gca - get current axes
279+
axes.set_xlim([0.0, 1.0])
280+
axes.set_ylim([0.0, 1.05]) # .05 to give some extra space
281+
# Alternative option -> wait for button to be pressed
282+
# while not plt.waitforbuttonpress(): pass # wait for key display
283+
# Alternative option -> normal display
284+
# plt.show()
285+
# save the plot
286+
CreateSavePath("classes")
287+
fig.savefig("classes/" + class_name + ".png")
288+
plt.cla() # clear axes for next plot
289+
290+
output_file.write("\n# mAP of all classes\n")
291+
mAP = sum_AP / n_classes
292+
text = "mAP = {0:.2f}%".format(mAP * 100)
293+
print(text)
294+
295+
296+
if __name__ == '__main__':
297+
detector = Detector(r"H:\PycharmProjects\Github\ContainerOCR\models\ContaDetModels\2021-12-03")
298+
cal_map(r"E:\Data\VOC数据集\箱门检测数据集\ContainVOC",detector,is_test=False)

0 commit comments

Comments
 (0)