1+ #!/usr/bin/env python
2+ # -*- coding: utf-8 -*-
3+ # @File : cal_map.py
4+ # @Author : jade
5+ # @Date : 2022/9/9 10:08
6+ # @Email : jadehh@1ive.com
7+ # @Software : Samples
8+ # @Desc :
9+ import os
10+ from detector import Detector
11+ from dataset_tools .jade_voc_datasets import GetXmlClassesNames ,ProcessXml
12+ from opencv_tools .jade_visualize import *
13+ from opencv_tools .jade_opencv_process import *
14+ from jade import *
15+ import copy
16+ import matplotlib .pyplot as plt
17+ min_overlap = 0.5
18+ font_size = 24
19+ def cal_iou (bbgt ,bb ,gt_label ,pd_label ,class_name ):
20+ ovmax = min_overlap
21+ ov = 0
22+ wrong_class_status = False
23+ bi = [max (bb [0 ], bbgt [0 ]), max (bb [1 ], bbgt [1 ]), min (bb [2 ], bbgt [2 ]), min (bb [3 ], bbgt [3 ])]
24+ iw = bi [2 ] - bi [0 ] + 1
25+ ih = bi [3 ] - bi [1 ] + 1
26+ if iw > 0 and ih > 0 :
27+ # compute overlap (IoU) = area of intersection / area of union
28+ ua = (bb [2 ] - bb [0 ] + 1 ) * (bb [3 ] - bb [1 ] + 1 ) + (bbgt [2 ] - bbgt [0 ]
29+ + 1 ) * (bbgt [3 ] - bbgt [1 ] + 1 ) - iw * ih
30+ ov = iw * ih / ua
31+ if ov > ovmax :
32+ if pd_label == gt_label == class_name :
33+ pass
34+ else :
35+ ov = 0
36+ wrong_class_status = True
37+ else :
38+ ov = 0
39+
40+ return ov ,wrong_class_status
41+
42+ def voc_ap (rec , prec ):
43+ """
44+ --- Official matlab code VOC2012---
45+ mrec=[0 ; rec ; 1];
46+ mpre=[0 ; prec ; 0];
47+ for i=numel(mpre)-1:-1:1
48+ mpre(i)=max(mpre(i),mpre(i+1));
49+ end
50+ i=find(mrec(2:end)~=mrec(1:end-1))+1;
51+ ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
52+ """
53+ rec .insert (0 , 0.0 ) # insert 0.0 at begining of list
54+ rec .append (1.0 ) # insert 1.0 at end of list
55+ mrec = rec [:]
56+ prec .insert (0 , 0.0 ) # insert 0.0 at begining of list
57+ prec .append (0.0 ) # insert 0.0 at end of list
58+ mpre = prec [:]
59+ """
60+ This part makes the precision monotonically decreasing
61+ (goes from the end to the beginning)
62+ matlab: for i=numel(mpre)-1:-1:1
63+ mpre(i)=max(mpre(i),mpre(i+1));
64+ """
65+ # matlab indexes start in 1 but python in 0, so I have to do:
66+ # range(start=(len(mpre) - 2), end=0, step=-1)
67+ # also the python function range excludes the end, resulting in:
68+ # range(start=(len(mpre) - 2), end=-1, step=-1)
69+ for i in range (len (mpre )- 2 , - 1 , - 1 ):
70+ mpre [i ] = max (mpre [i ], mpre [i + 1 ])
71+ """
72+ This part creates a list of indexes where the recall changes
73+ matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
74+ """
75+ i_list = []
76+ for i in range (1 , len (mrec )):
77+ if mrec [i ] != mrec [i - 1 ]:
78+ i_list .append (i ) # if it was matlab would be i + 1
79+ """
80+ The Average Precision (AP) is the area under the curve
81+ (numerical integration)
82+ matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
83+ """
84+ ap = 0.0
85+ for i in i_list :
86+ ap += ((mrec [i ]- mrec [i - 1 ])* mpre [i ])
87+ return ap , mrec , mpre
88+
89+ def match_boxes (gt_preds ,pd_preds ,class_name ):
90+ gt_boxes = gt_preds ["boxes" ]
91+ gt_labels = gt_preds ["labels" ]
92+
93+ pd_boxes = pd_preds ["boxes" ].tolist ()
94+ pd_labels = pd_preds ["labels" ]
95+ ov_iou_list = []
96+ match_class_status_list = []
97+ for (i ,gt_box ) in enumerate (gt_boxes ):
98+ for (j ,pd_box ) in enumerate (pd_boxes ):
99+ ov_iou ,match_class_status = cal_iou (gt_box ,pd_box ,gt_labels [i ],pd_labels [j ],class_name )
100+ if ov_iou > 0 :
101+ ov_iou_list .append (ov_iou )
102+ match_class_status_list .append (match_class_status )
103+ gt_boxes .remove (gt_box )
104+ pd_boxes .remove (pd_box )
105+ return ov_iou_list ,match_class_status_list
106+ from opencv_tools import ReadChinesePath
107+ def cal_map (root_path ,detector ,is_test = True ,show_animation = True ):
108+ sum_AP = 0.0
109+ n_classes = 0
110+ label_txt_path = os .path .join (root_path ,"label_list.txt" )
111+ if is_test :
112+ data_txt_path = os .path .join (root_path ,"test.txt" )
113+ else :
114+ data_txt_path = os .path .join (root_path ,"train.txt" )
115+ class_name_list = []
116+ gt_counter_per_class = {}
117+ no_label_samples_list = []
118+ with open ("output.txt" , "w" ) as output_file :
119+ with open (label_txt_path , "rb" ) as f :
120+ for class_name_byte in f .readlines ():
121+ nd = 0
122+ n_classes = n_classes + 1
123+ with open (data_txt_path , "rb" ) as f2 :
124+ for data_byte in f2 .readlines ():
125+ class_name = str (class_name_byte , encoding = "utf-8" ).strip ()
126+ data_list = str (data_byte , encoding = "utf-8" ).strip ().split (" " )
127+ image_path = os .path .join (root_path , data_list [0 ])
128+ anno_path = os .path .join (root_path , data_list [1 ])
129+ anno_class_names = GetXmlClassesNames (anno_path )
130+ if anno_class_names :
131+ if class_name in anno_class_names :
132+ nd = nd + 1
133+ gt_counter_per_class [str (class_name_byte , encoding = "utf-8" ).strip ()] = nd
134+ tp = [0 ] * nd # creates an array of zeros of size nd
135+ fp = [0 ] * nd
136+ print ("正在计算{},检测准确率" .format (str (class_name_byte , encoding = "utf-8" ).strip ()))
137+ index = 0
138+ with open (data_txt_path , "rb" ) as f2 :
139+ for data_byte in f2 .readlines ():
140+ class_name = str (class_name_byte , encoding = "utf-8" ).strip ()
141+ data_list = str (data_byte , encoding = "utf-8" ).strip ().split (" " )
142+ image_path = os .path .join (root_path , data_list [0 ])
143+ anno_path = os .path .join (root_path , data_list [1 ])
144+ anno_class_names = GetXmlClassesNames (anno_path )
145+ img = ReadChinesePath (image_path )
146+ gt_preds = {}
147+ gt_preds_cp = {}
148+ over_lay_list = []
149+ if anno_class_names :
150+ if class_name in anno_class_names :
151+ results = detector .predict (img , class_type = class_name )
152+ imagename , shape , bboxes , labels_text , labels , difficult , truncated = ProcessXml (
153+ anno_path , is_rate = False )
154+ gt_preds ["boxes" ] = bboxes
155+ gt_preds ["labels" ] = labels_text
156+ gt_preds_cp ["boxes" ] = copy .copy (bboxes )
157+ gt_preds_cp ["labels" ] = labels_text
158+ over_lay_list , match_class_status_list = match_boxes (gt_preds , results , class_name )
159+ """
160+ Draw image to show animation
161+ """
162+ if show_animation :
163+ bottom_border = 60
164+ height , widht = img .shape [:2 ]
165+ # colors (OpenCV works with BGR)
166+ white = (255 , 255 , 255 )
167+ light_blue = (255 , 200 , 100 )
168+ green = (0 , 255 , 0 )
169+ light_red = (30 , 30 , 255 )
170+ # 1st line
171+ margin = 10
172+ v_pos = int (height - margin - (bottom_border / 2.0 ))
173+ text = "Image: " + GetLastDir (image_path ) + " "
174+ img = Add_Chinese_Label (img , text , (margin , v_pos - 48 ), color = white ,
175+ font_size = font_size )
176+ text = "Class Name : " + class_name + " "
177+ img = Add_Chinese_Label (img , text , (margin , v_pos - 24 ), light_blue ,
178+ font_size = font_size )
179+ if len (over_lay_list ) > 0 :
180+ color = light_red
181+ for ovmax in over_lay_list :
182+ width = font_size * len (text )
183+ text = "IoU: {0:.2f}% " .format (ovmax * 100 ) + "> {0:.2f}% " .format (
184+ min_overlap * 100 )
185+ img = Add_Chinese_Label (img , text , (margin + width , v_pos - 24 ), color ,
186+ font_size = font_size )
187+ for score in results ["scores" ]:
188+ text = "confidence: {0:.2f}% " .format (float (score ) * 100 )
189+ img = Add_Chinese_Label (img , text , (margin , v_pos ), white , font_size = font_size )
190+ color = light_red
191+ text = "Result: Match"
192+ if len (over_lay_list ) != len (gt_preds_cp ["boxes" ]):
193+ fp [index ] = 1
194+ else :
195+ tp [index ] = 1
196+ for match in match_class_status_list :
197+ if match :
198+ text = "Result: Not Match"
199+ tp [index ] = 0
200+ img = visualize (img , results )
201+ for (label , box ) in zip (gt_preds_cp ["labels" ], gt_preds_cp ["boxes" ]):
202+ if label == class_name :
203+ bb = [int (i ) for i in box ]
204+ img = cv2 .rectangle (img , (bb [0 ], bb [1 ]), (bb [2 ], bb [3 ]), color , 2 )
205+
206+ line_width = len (text ) * font_size
207+ img = Add_Chinese_Label (img , text , (margin + line_width , v_pos ), color )
208+ # show image
209+ cv2 .namedWindow ("Animation" , 0 )
210+ cv2 .imshow ("Animation" , img )
211+ cv2 .waitKey (20 ) # show for 20 ms
212+ index = index + 1
213+ else :
214+ ##不是此类别
215+ pass
216+ # print("不是此类别")
217+ else :
218+ if anno_path not in no_label_samples_list :
219+ no_label_samples_list .append (anno_path )
220+ ##没有目标区域
221+ results = detector .predict (img )
222+ print ("没有目标区域,目标检测结果为:{},标准的结果为:{},anno path = {}" .format (results , anno_class_names ,
223+ anno_path ))
224+ img = visualize (img , results )
225+ cv2 .namedWindow ("ERROR Detection" , 0 )
226+ cv2 .imshow ("ERROR Detection" , img )
227+ cv2 .waitKey (1 )
228+ print (tp , fp )
229+ # compute precision/recall
230+ cumsum = 0
231+ for idx , val in enumerate (fp ):
232+ fp [idx ] += cumsum
233+ cumsum += val
234+ cumsum = 0
235+ for idx , val in enumerate (tp ):
236+ tp [idx ] += cumsum
237+ cumsum += val
238+ # print(tp)
239+ rec = tp [:]
240+ for idx , val in enumerate (tp ):
241+ rec [idx ] = float (tp [idx ]) / gt_counter_per_class [class_name ]
242+ # print(rec)
243+ prec = tp [:]
244+ for idx , val in enumerate (tp ):
245+ prec [idx ] = float (tp [idx ]) / (fp [idx ] + tp [idx ])
246+ # print(prec)
247+
248+ ap , mrec , mprec = voc_ap (rec [:], prec [:])
249+
250+ sum_AP += ap
251+ text = "{0:.2f}%" .format (
252+ ap * 100 ) + " = " + class_name + " AP " # class_name + " AP = {0:.2f}%".format(ap*100)
253+ """
254+ Write to output.txt
255+ """
256+ rounded_prec = ['%.2f' % elem for elem in prec ]
257+ rounded_rec = ['%.2f' % elem for elem in rec ]
258+ output_file .write (
259+ text + "\n Precision: " + str (rounded_prec ) + "\n Recall :" + str (rounded_rec ) + "\n \n " )
260+ print (text )
261+
262+ plt .plot (rec , prec , '-o' )
263+ # add a new penultimate point to the list (mrec[-2], 0.0)
264+ # since the last line segment (and respective area) do not affect the AP value
265+ area_under_curve_x = mrec [:- 1 ] + [mrec [- 2 ]] + [mrec [- 1 ]]
266+ area_under_curve_y = mprec [:- 1 ] + [0.0 ] + [mprec [- 1 ]]
267+ plt .fill_between (area_under_curve_x , 0 , area_under_curve_y , alpha = 0.2 , edgecolor = 'r' )
268+ # set window title
269+ fig = plt .gcf () # gcf - get current figure
270+ fig .canvas .set_window_title ('AP ' + class_name )
271+ # set plot title
272+ plt .title ('class: ' + text )
273+ # plt.suptitle('This is a somewhat long figure title', fontsize=16)
274+ # set axis titles
275+ plt .xlabel ('Recall' )
276+ plt .ylabel ('Precision' )
277+ # optional - set axes
278+ axes = plt .gca () # gca - get current axes
279+ axes .set_xlim ([0.0 , 1.0 ])
280+ axes .set_ylim ([0.0 , 1.05 ]) # .05 to give some extra space
281+ # Alternative option -> wait for button to be pressed
282+ # while not plt.waitforbuttonpress(): pass # wait for key display
283+ # Alternative option -> normal display
284+ # plt.show()
285+ # save the plot
286+ CreateSavePath ("classes" )
287+ fig .savefig ("classes/" + class_name + ".png" )
288+ plt .cla () # clear axes for next plot
289+
290+ output_file .write ("\n # mAP of all classes\n " )
291+ mAP = sum_AP / n_classes
292+ text = "mAP = {0:.2f}%" .format (mAP * 100 )
293+ print (text )
294+
295+
296+ if __name__ == '__main__' :
297+ detector = Detector (r"H:\PycharmProjects\Github\ContainerOCR\models\ContaDetModels\2021-12-03" )
298+ cal_map (r"E:\Data\VOC数据集\箱门检测数据集\ContainVOC" ,detector ,is_test = False )
0 commit comments