Skip to content

Commit 6e14873

Browse files
committed
* 使用脚本制作数据集
1 parent 091214d commit 6e14873

File tree

3 files changed

+102
-8
lines changed

3 files changed

+102
-8
lines changed

dataset_tools/jade_create_object_dection_datasets.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414
import xml.etree.ElementTree as ET
1515
from dataset_tools.jade_voc_datasets import GetXmlClassesNames
1616

17-
def CreateSavePath(save_path):
    """Ensure that *save_path* exists on disk and return it.

    Creates the directory (including intermediate parents) when it is
    missing; when it already exists this is a no-op.
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    return save_path
2317

2418
def ProcessXml(xml_path):
2519
# Read the XML annotation file.
@@ -108,6 +102,85 @@ def CreateYearsDatasets(dir,year=None,save_path=None,rate=0.95):
108102
CreateLabelList(save_path)
109103

110104

105+
# Convert a multi-year VOC dataset tree into a Darknet/YOLO dataset.
def CreateYearsDarknetVocDatasets(dir, year=None, save_path=None, rate=0.95, VOC_CLASSES=None):
    """Convert VOC datasets (one sub-directory per year) to Darknet format.

    :param dir: root directory holding one VOC dataset per year sub-directory.
    :param year: a specific year sub-directory to convert; ``None`` converts all.
    :param save_path: output root for the Darknet dataset (created if missing).
    :param rate: train/test split ratio forwarded to the per-year converter.
    :param VOC_CLASSES: ordered class-name list; also written to ``classes.txt``.
    """
    years = os.listdir(dir)
    # Progress bar length mirrors the original: all entries when converting
    # every year, exactly 1 when a single year was requested.
    progressBar1 = ProgressBar(len(years) if year is None else 1)
    os.makedirs(save_path, exist_ok=True)
    # Both the "all years" and "single year" cases share the same per-year
    # logic, so iterate one unified list instead of duplicating the branch.
    for year_name in (years if year is None else [year]):
        year_dir = os.path.join(dir, year_name)
        if os.path.isdir(year_dir):
            if os.path.exists(os.path.join(year_dir, DIRECTORY_IMAGES)) and \
                    os.path.exists(os.path.join(year_dir, DIRECTORY_ANNOTATIONS)):
                CreateDarknetVocDatasets(year_dir, save_path, rate, VOC_CLASSES)
                progressBar1.update()
    # Write the class list once so Darknet tooling can map ids back to names.
    # newline="\n" keeps the file byte-identical to the old binary-mode write
    # on every platform.
    with open(os.path.join(save_path, "classes.txt"), "w", encoding="utf-8", newline="\n") as f:
        for class_name in VOC_CLASSES:
            f.write(class_name + "\n")
132+
133+
134+
def convert_voc_to_yolo(xml_dir, output_dir, classes):
    """Convert one VOC XML annotation file into a YOLO-format label file.

    :param xml_dir: path to the VOC XML annotation file.
    :param output_dir: path of the YOLO ``.txt`` label file to write.
    :param classes: ordered class-name list; a box's class id is its index here.
    :raises ValueError: if an object's class name is not present in *classes*.
    """
    tree = ET.parse(xml_dir)
    root = tree.getroot()
    # float() tolerates coordinates exported as e.g. "500.0"; int() would raise.
    img_w = float(root.find('size/width').text)
    img_h = float(root.find('size/height').text)

    with open(output_dir, 'w') as f:
        for obj in root.findall('object'):
            cls_id = classes.index(obj.find('name').text)
            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)
            # YOLO format: class id, normalized box center, normalized size.
            x_center = (xmin + xmax) / 2 / img_w
            y_center = (ymin + ymax) / 2 / img_h
            width = (xmax - xmin) / img_w
            height = (ymax - ymin) / img_h
            f.write(f"{cls_id} {x_center} {y_center} {width} {height}\n")
150+
151+
def CreateDarknetVocDataset(dir, save_path, image_files, dataset_type, remove_label="None", VOC_CLASSES=None):
    """Copy images and write YOLO label files for one train/test split.

    :param dir: VOC dataset directory containing the image and annotation sub-dirs.
    :param save_path: Darknet root; ``images/<split>`` and ``labels/<split>`` go under it.
    :param image_files: image file names (relative to the images dir) in this split.
    :param dataset_type: split name, e.g. ``"train"`` or ``"test"``.
    :param remove_label: class name whose presence excludes an image (sentinel ``"None"``).
    :param VOC_CLASSES: ordered class-name list used for YOLO class ids.
    """
    save_image_path = CreateSavePath(os.path.join(save_path, "images", dataset_type))
    save_label_path = CreateSavePath(os.path.join(save_path, "labels", dataset_type))
    for image_file in image_files:
        image_path = os.path.join(dir, DIRECTORY_IMAGES, image_file)
        # Skip zero-byte (corrupt/truncated) images; stat the file instead of
        # reading the whole image into memory just to measure its length.
        if os.path.getsize(image_path) == 0:
            continue
        # splitext handles any extension length (.jpg, .jpeg, ...), unlike [:-4].
        stem = os.path.splitext(image_file)[0]
        xml_path = os.path.join(dir, DIRECTORY_ANNOTATIONS, stem + ".xml")
        class_name_list = GetXmlClassesNames(xml_path)
        if len(class_name_list) > 0 and remove_label not in class_name_list:
            shutil.copy(image_path, save_image_path)
            convert_voc_to_yolo(xml_path, os.path.join(save_label_path, stem + ".txt"), VOC_CLASSES)
        else:
            # Log annotations that are empty or contain the excluded label.
            print(xml_path)
165+
166+
167+
168+
169+
def CreateDarknetVocDatasets(dir, save_path, rate, VOC_CLASSES):
    """Split one VOC dataset into train/test and convert both to Darknet format.

    :param dir: VOC dataset directory (contains the image and annotation sub-dirs).
    :param save_path: output root for the Darknet-format dataset.
    :param rate: fraction of the images assigned to the training split.
    :param VOC_CLASSES: ordered class-name list for YOLO class ids.
    """
    image_files = os.listdir(os.path.join(dir, DIRECTORY_IMAGES))
    train_image_files = random.sample(image_files, int(len(image_files) * rate))
    # Membership test against a set keeps the complement O(n) instead of the
    # O(n^2) list scan `file not in train_image_files`.
    train_set = set(train_image_files)
    test_image_files = [file for file in image_files if file not in train_set]
    CreateDarknetVocDataset(dir, save_path, train_image_files, "train", VOC_CLASSES=VOC_CLASSES)
    CreateDarknetVocDataset(dir, save_path, test_image_files, "test", VOC_CLASSES=VOC_CLASSES)
178+
179+
180+
181+
182+
183+
111184
##制作VOC数据集
112185
def CreateVOCDataset(dir, datasetname,save_path=None,rate=0.95):
113186
"""

main.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,28 @@
99
from dataset_tools.jade_create_paddle_text_detection_datasets import *
from dataset_tools.jade_create_paddle_ocr_datasets import *
from dataset_tools.jade_create_object_dection_datasets import CreateYearsDatasets, CreateDarknetVocDatasets
12+
13+
def test_create_paddle_years_datasets(args):
    """Build Paddle-style yearly datasets from the parsed CLI arguments."""
    CreateYearsDatasets(args.input_dataset_dir, None, rate=0.95, save_path=args.save_dataset_dir)
1216
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="制作数据集脚本")
    parser.add_argument("--dataset_type", default='paddle', help="制作数据集的类型")
    parser.add_argument("--input_dataset_dir", default='test', help="数据集的地址")
    parser.add_argument("--save_dataset_dir", default='test/output_seals_01', help="保存数据集的地址")
    # nargs='+' produces a list when values are given, so the default must be a
    # list as well — a default of "" would iterate character-by-character.
    parser.add_argument("--voc_labels", nargs='+', default=[], help="类别")
    args = parser.parse_args()
    print(list(args.voc_labels))
    if args.dataset_type == "paddle_detection":
        CreateYearsDatasets(args.input_dataset_dir, None, save_path=args.save_dataset_dir, rate=0.95)
    elif args.dataset_type == "yolo_detection":
        # NOTE(review): this passes the input dir straight to the per-dataset
        # converter; if input_dataset_dir contains per-year sub-directories,
        # CreateYearsDarknetVocDatasets may be intended here — confirm.
        CreateDarknetVocDatasets(args.input_dataset_dir, save_path=args.save_dataset_dir, rate=0.95, VOC_CLASSES=args.voc_labels)
1330
#removeNolabelDatasets(r"F:\数据集\关键点检测数据集\定制版箱号关键点数据集\2022-03-09")
1431
#create_text_detection_datasets(r"F:\数据集\关键点检测数据集\定制版箱号关键点数据集",r"E:\Data\字符检测识别数据集\定制版箱号关键点数据集",0.95)
1532
#CreatePaddleOCRDatasets(root_path="E:\Data\字符检测识别数据集\镇江大港厂内车牌关键点检测数据集", save_path="E:\Data\OCR\镇江大港厂内车牌识别数据集",dataset_type="镇江厂内车牌数据集")
1633
#removeNolabelVocDatasets(r"E:\Data\VOC数据集\集装箱残损检测数据集")
1734
#CreateYearsDatasets(r"E:\Data\VOC数据集\集装箱残损检测数据集")
18-
create_text_detection_datasets(r"F:\数据集\关键点检测数据集\箱号关键点数据集",r'E:\Data\字符检测识别数据集\箱号关键点数据集')
1935
#CreatePaddleOCRDatasets(r'F:\数据集\VOC数据集\箱门检测数据集\ContainVOC', save_path="E:\Data\OCR\箱号识别数据集",dataset_type="箱号数据集")
2036
#CreateYearsDatasets("F:\数据集\VOC数据集\验残集装箱检测数据集",0.95)

test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def testCreateYearsDatasets():
1515
"""
1616
CreateYearsDatasets(r"F:\数据集\VOC数据集\验残集装箱检测数据集", save_path=r"E:\Data\VOC数据集\验残集装箱检测数据集")
1717

18+
def testCreateYearsDarknetVocDatasets():
    """Convert the container-door VOC dataset tree to Darknet/YOLO format."""
    # VOC_CLASSES = ["container"]
    class_names = ["FRONTEND", "DOOREND", "UPEND", "slide", "bromine_tank"]
    CreateYearsDarknetVocDatasets(
        r"F:\数据集\VOC数据集\箱门检测数据集\ContainVOC",
        save_path=r"E:\Data\VOC数据集\箱门检测数据集\ContainerVOCDarknet",
        VOC_CLASSES=class_names,
    )
1823

1924
if __name__ == '__main__':
    # Entry point: run the Darknet conversion test.
    testCreateYearsDarknetVocDatasets()

0 commit comments

Comments
 (0)