@@ -189,18 +189,23 @@ def get_match_best_text(self, results):
189189 return match_best_text , "" , max_key_val
190190
191191
192- class CreatePaddleOCRDatasets (object ):
193- def __init__ (self , root_path , save_path , dataset_type = None ):
192+ class create_paddle_ocr_datasets (object ):
193+ def __init__ (self , root_path , save_path , dataset_type = None , year = "" ):
194194 self .root_path = root_path
195+ self .year = year
196+ if year :
197+ self .root_path = os .path .join (root_path ,year )
198+ else :
199+ if os .path .exists (save_path ):
200+ try :
201+ shutil .rmtree (save_path )
202+ except :
203+ print ("删除文件夹失败,文件夹为:{}" .format (save_path ))
195204 self .save_path = save_path
196205 self .conta_check_model = ContaNumber ()
197206 self .dataset_type = dataset_type ## 数据集类型,如车牌数据集,箱号数据集
198207 label_list = self .get_label_text_path ()
199- if os .path .exists (save_path ):
200- try :
201- shutil .rmtree (save_path )
202- except :
203- print ("删除文件夹失败,文件夹为:{}" .format (save_path ))
208+
204209 for label_path in label_list :
205210 self .createOCRDatasets (label_path )
206211
@@ -304,13 +309,15 @@ def get_label_text_path(self):
304309 label_path_list .append (os .path .join (self .root_path , filename ))
305310 return label_path_list
306311
307- def createOCRDatasets (self , label_txt_path ):
312+
313+ def write_datasets (self ,label_txt_path ,year = "" ):
308314 save_h_path = CreateSavePath (os .path .join (self .save_path , "OCRH" ))
309315 save_v_path = CreateSavePath (os .path .join (self .save_path , "OCRV" ))
316+ privous_dir = GetPreviousDir (label_txt_path )
317+
310318 istrain = False
311319 if "train" in label_txt_path :
312320 istrain = True
313- privous_dir = GetPreviousDir (label_txt_path )
314321 all_image_width = 0
315322 all_image_height = 0
316323 all_image_count = 0
@@ -319,9 +326,14 @@ def createOCRDatasets(self, label_txt_path):
319326 index = 0
320327 processBar = ProgressBar (len (content_list ))
321328 for content_byte in content_list :
322- content = str (content_byte ,"utf-8" ).strip ()
323- save_h_detail_path = CreateSavePath (os .path .join (save_h_path , content .split ("/" )[0 ]))
324- save_v_detail_path = CreateSavePath (os .path .join (save_v_path , content .split ("/" )[0 ]))
329+ content = str (content_byte , "utf-8" ).strip ()
330+ if self .year :
331+ save_h_detail_path = CreateSavePath (os .path .join (save_h_path ,self .year ))
332+ save_v_detail_path = CreateSavePath (os .path .join (save_v_path ,self .year ))
333+ else :
334+ save_h_detail_path = CreateSavePath (os .path .join (save_h_path , content .split ("/" )[0 ]))
335+ save_v_detail_path = CreateSavePath (os .path .join (save_v_path , content .split ("/" )[0 ]))
336+
325337 save_h_detail_train_path = CreateSavePath (os .path .join (save_h_detail_path , "train" ))
326338 save_h_detail_test_path = CreateSavePath (os .path .join (save_h_detail_path , "test" ))
327339
@@ -351,7 +363,7 @@ def createOCRDatasets(self, label_txt_path):
351363 txt = self .verification_rules (txt_orignal )
352364 if txt :
353365 if istrain is False :
354- if h < w : ## 水平
366+ if h < w : ## 水平
355367 cv2 .imencode ('.jpg' , txt_img * 255 )[1 ].tofile (
356368 os .path .join (save_h_detail_test_path , image_name ))
357369 all_image_width = all_image_width + txt_img .shape [1 ]
@@ -399,7 +411,14 @@ def createOCRDatasets(self, label_txt_path):
399411 index = index + 1
400412 processBar .update ()
401413
402- print ("平均高度为:{},平均宽度为:{}" .format (all_image_height / all_image_count ,all_image_width / all_image_count ))
414+ print (
415+ "平均高度为:{},平均宽度为:{}" .format (all_image_height / all_image_count , all_image_width / all_image_count ))
416+
417+ def createOCRDatasets (self , label_txt_path ):
418+ self .write_datasets (label_txt_path )
419+
420+
421+
403422
404423 def createDatasets (self , root_path ):
405424 if os .path .exists (os .path .join (root_path , "rec_gt_train.txt" )) is True :
0 commit comments