Skip to content

Commit 598e3c6

Browse files
committed
Update: when the split rate is 1, do not split the dataset
1 parent 2abdb8d commit 598e3c6

1 file changed

Lines changed: 19 additions & 12 deletions

File tree

dataset_tools/jade_create_paddle_text_detection_datasets.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -267,21 +267,28 @@ def SplitDataSets(image_path_list, ContaNumber_list,split_rate):
267267
repeat_list = []
268268
repeat_image_path_list = []
269269
train_image_file_list = []
270-
for i in range(len(image_path_list)):
271-
if ContaNumber_list[i] not in repeat_list:
272-
repeat_list.append(ContaNumber_list[i])
270+
if split_rate == 1:
271+
for i in range(len(image_path_list)):
273272
train_image_file_list.append(image_path_list[i])
274-
else:
275-
repeat_image_path_list.append(image_path_list[i])
276-
if split_rate > len(train_image_file_list) / len(image_path_list): ##应该从repeat里面分出一部分给train_image_file_list
277-
extra_count = ( int((split_rate - len(train_image_file_list) / len(image_path_list)) * len(image_path_list)))
278-
extra_image_path_list = random.sample(repeat_image_path_list,extra_count)
279-
train_image_file_list.extend(extra_image_path_list)
280-
test_image_files = [file for file in repeat_image_path_list if file not in extra_image_path_list]
273+
return train_image_file_list, train_image_file_list
281274
else:
282-
test_image_files = random.sample(image_path_list,int((1-split_rate)*len(image_path_list)))
275+
for i in range(len(image_path_list)):
276+
if ContaNumber_list[i] not in repeat_list:
277+
repeat_list.append(ContaNumber_list[i])
278+
train_image_file_list.append(image_path_list[i])
279+
else:
280+
repeat_image_path_list.append(image_path_list[i])
281+
if split_rate > len(train_image_file_list) / len(image_path_list): ##应该从repeat里面分出一部分给train_image_file_list
282+
extra_count = (int((split_rate - len(train_image_file_list) / len(image_path_list)) * len(image_path_list)))
283+
extra_image_path_list = random.sample(repeat_image_path_list, extra_count)
284+
train_image_file_list.extend(extra_image_path_list)
285+
test_image_files = [file for file in repeat_image_path_list if file not in extra_image_path_list]
286+
else:
287+
test_image_files = random.sample(image_path_list, int((1 - split_rate) * len(image_path_list)))
288+
289+
return train_image_file_list, test_image_files
290+
283291

284-
return train_image_file_list , test_image_files
285292

286293
def CreateTextDetDatasets(root_path, save_root_path, split_rate=0.9):
287294
##  

0 commit comments

Comments
 (0)