Commit 5e39ad3

committed
Optimized the 6 computed metrics: samples with a zero-length prediction or zero-length ground truth are now ignored, so only valid data enters the calculation.
1 parent 931e6dd commit 5e39ad3

6 files changed

Lines changed: 242 additions & 78 deletions
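
In short, a sample now contributes to a metric's average only when both the prediction and the ground truth are non-empty. A minimal sketch of that rule (the helper name `is_valid_pair` is hypothetical, for illustration only, not part of the codebase):

    # Hypothetical helper illustrating the validity rule from the commit message.
    def is_valid_pair(predicted: str, groundtruth: str) -> bool:
        # Samples with an empty prediction or empty ground truth are skipped.
        return len(predicted or "") > 0 and len(groundtruth or "") > 0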


examples/basic_usage.py

Lines changed: 5 additions & 4 deletions
@@ -796,13 +796,15 @@ def demo_multi_extraction():
     from webmainbench import DataLoader, DataSaver, Evaluator, ExtractorFactory
     from pathlib import Path
     import time
+
+
     # Set up logging
     setup_logging(level="INFO")

     # Configure file paths
     data_dir = Path("../data")
-    dataset_path = data_dir / "sample_dataset.jsonl"
-    # dataset_path = "/home/lulindong/Pycharm_projects/cc/test.jsonl"
+    # dataset_path = data_dir / "sample_dataset.jsonl"
+    dataset_path = "/home/lulindong/Pycharm_projects/cc/WebMainBench_llm-webkit_v1_WebMainBench_dataset_merge_with_llm_webkit.jsonl"

     print(f"📂 Dataset file: {dataset_path}")

@@ -815,7 +817,6 @@ def demo_multi_extraction():
             "list_bullets": True,
             "preserve_formatting": True
         }},
-
         {"name": "trafilatura", "config": {}},
         {"name": "magic-html", "config": {}},
     ]
@@ -902,7 +903,7 @@ def demo_multi_extraction():
             all_results.append(result)

             # Save the dataset enriched with the current extractor's content
-            enriched_dataset_path = results_dir / f"{dataset.name}_with_{extractor.name}_extraction.jsonl"
+            enriched_dataset_path = results_dir / f"{dataset.name}_{extractor.name}_extraction_infer.jsonl"
             DataSaver.save_dataset_with_extraction(
                 results=result,
                 dataset=dataset,

webmainbench/evaluator/evaluator.py

Lines changed: 65 additions & 25 deletions
@@ -361,32 +361,72 @@ def _evaluate_sample(self, sample: DataSample, extractor: BaseExtractor) -> Dict

     def _aggregate_metrics(self, sample_results: List[Dict[str, Any]]) -> Dict[str, float]:
         """Aggregate metrics across all samples."""
-        # Collect metric results by metric name
-        metric_groups = {}
-
-        for sample_result in sample_results:
-            if not sample_result.get('extraction_success', True):
-                continue
-
-            metrics = sample_result.get('metrics', {})
-            for metric_name, metric_data in metrics.items():
-                if metric_data.get('success', False):
-                    if metric_name not in metric_groups:
-                        metric_groups[metric_name] = []
-                    metric_groups[metric_name].append(metric_data['score'])
-
-        # Calculate aggregated scores
-        aggregated_metrics = {}
-        for metric_name, scores in metric_groups.items():
-            if scores:
-                aggregated_metrics[metric_name] = sum(scores) / len(scores)
+        # # Collect metric results by metric name
+        # metric_groups = {}
+        #
+        # for sample_result in sample_results:
+        #     if not sample_result.get('extraction_success', True):
+        #         continue
+        #
+        #     metrics = sample_result.get('metrics', {})
+        #     for metric_name, metric_data in metrics.items():
+        #         if metric_data.get('success', False):
+        #             if metric_name not in metric_groups:
+        #                 metric_groups[metric_name] = []
+        #             metric_groups[metric_name].append(metric_data['score'])
+        #
+        # # Calculate aggregated scores
+        # aggregated_metrics = {}
+        # for metric_name, scores in metric_groups.items():
+        #     if scores:
+        #         aggregated_metrics[metric_name] = sum(scores) / len(scores)
+        #     else:
+        #         aggregated_metrics[metric_name] = 0.0
+        #
+        # # overall score is already calculated by MetricCalculator
+        # # No need to override it here
+        #
+        # return aggregated_metrics
+        """
+        Aggregate the metrics across all samples and compute global averages (each metric aggregated separately).
+        """
+        if not sample_results:
+            return {}
+
+        # Initialize the running total and valid-sample count for each metric
+        metric_totals = {
+            "text_edit": 0.0,
+            "code_edit": 0.0,
+            "table_edit": 0.0,
+            "table_TEDS": 0.0,
+            "formula_edit": 0.0,
+            "overall": 0.0  # the global overall is computed separately
+        }
+        metric_counts = {k: 0 for k in metric_totals.keys()}  # number of valid samples per metric
+
+        # Accumulate the metric scores over all samples
+        for sample in sample_results:
+            metrics = sample.get("metrics", {})
+            for metric_name in metric_totals.keys():
+                if metric_name in metrics and metrics[metric_name].get("success", False):
+                    metric_totals[metric_name] += metrics[metric_name]["score"]
+                    metric_counts[metric_name] += 1
+
+        # Average each metric (the global overall is the mean of the 5 individual metrics)
+        overall_metrics = {}
+        for metric_name in metric_totals.keys():
+            if metric_counts[metric_name] > 0:
+                overall_metrics[metric_name] = metric_totals[metric_name] / metric_counts[metric_name]
             else:
-                aggregated_metrics[metric_name] = 0.0
-
-        # overall score is already calculated by MetricCalculator
-        # No need to override it here
-
-        return aggregated_metrics
+                overall_metrics[metric_name] = 0.0  # default to 0 when there are no valid samples
+
+        # Special-case the global overall: always the mean of the 5 individual metrics
+        # (the sample-level overall is excluded; only the 5 core metrics are used)
+        core_metrics = ["text_edit", "code_edit", "table_edit", "table_TEDS", "formula_edit"]
+        core_scores = [overall_metrics[metric] for metric in core_metrics]
+        overall_metrics["overall"] = sum(core_scores) / len(core_metrics)
+
+        return overall_metrics

     def _calculate_category_metrics(self, sample_results: List[Dict[str, Any]],
                                     samples: List[DataSample]) -> Optional[Dict[str, Dict[str, float]]]:

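For orientation, here is the shape of the data `_aggregate_metrics` consumes and what it now returns, as a minimal sketch with invented scores (the structure mirrors the per-sample `metrics` dict used above):

    # Invented input: two samples, each with per-metric results.
    sample_results = [
        {"metrics": {"text_edit": {"success": True, "score": 0.9},
                     "table_edit": {"success": False, "score": 0.0}}},
        {"metrics": {"text_edit": {"success": True, "score": 0.7},
                     "table_edit": {"success": True, "score": 0.8}}},
    ]
    # text_edit: (0.9 + 0.7) / 2 = 0.8 over its 2 valid samples
    # table_edit: 0.8 over its 1 valid sample (the failed one is ignored)
    # overall: mean of the 5 core metrics, with no-valid-sample metrics counted as 0.0
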
webmainbench/extractors/trafilatura_extractor.py

Lines changed: 4 additions & 1 deletion
@@ -20,6 +20,8 @@ class TrafilaturaInferenceConfig:
     # More trafilatura-supported parameters can be added as needed
     include_images: bool = False
     include_links: bool = False
+    # New: output format to produce (txt/markdown/json/xml, etc.)
+    output_format: str = "markdown"  # defaults to markdown


 @extractor("trafilatura")
@@ -65,7 +67,8 @@ def _extract_content(self, html: str, url: str = None) -> ExtractionResult:
             include_comments=self.inference_config.include_comments,
             include_tables=self.inference_config.include_tables,
             include_images=self.inference_config.include_images,
-            include_links=self.inference_config.include_links
+            include_links=self.inference_config.include_links,
+            output_format=self.inference_config.output_format  # pass the output format through
         )

         # Build the content_list (simple paragraph splitting)

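The new `output_format` field maps directly onto trafilatura's own `extract()` parameter. A minimal standalone sketch of the pass-through (assumes a recent trafilatura release, since markdown output is not available in older versions):

    import trafilatura

    html = "<html><body><article><h1>Title</h1><p>Body text.</p></article></body></html>"
    text = trafilatura.extract(
        html,
        include_links=False,
        include_images=False,
        output_format="markdown",  # "txt", "json", "xml", ... are also accepted
    )
    print(text)
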
webmainbench/metrics/calculator.py

Lines changed: 68 additions & 28 deletions
@@ -70,36 +70,76 @@ def calculate_all(self, predicted_content: str,
         Returns:
             Dictionary mapping metric names to MetricResult instances
         """
-        results = {}
-
-        for metric_name, metric in self.metrics.items():
-            try:
-                if metric_name in ["edit_distance", "bleu", "rouge"]:
-                    # Text-based metrics
-                    result = metric.calculate(predicted_content, groundtruth_content, **kwargs)
-                elif metric_name in ["code_edit", "formula_edit",
-                                     "table_edit", "table_TEDS", "text_edit"]:
-                    # New content-type metrics that need the content_list passed through
-                    result = metric.calculate(
-                        predicted_content,
-                        groundtruth_content,
-                        predicted_content_list=predicted_content_list,
-                        groundtruth_content_list=groundtruth_content_list,
-                        **kwargs
-                    )
-                else:
-                    # Generic calculation
-                    result = metric.calculate(predicted_content, groundtruth_content, **kwargs)
-
-                results[metric_name] = result
-
-            except Exception as e:
-                # Create error result for failed metrics
-                results[metric_name] = MetricResult.create_error_result(
-                    metric_name, f"Metric calculation failed: {str(e)}"
+        # results = {}
+        #
+        # for metric_name, metric in self.metrics.items():
+        #     try:
+        #         if metric_name in ["edit_distance", "bleu", "rouge"]:
+        #             # Text-based metrics
+        #             result = metric.calculate(predicted_content, groundtruth_content, **kwargs)
+        #         elif metric_name in ["code_edit", "formula_edit",
+        #                              "table_edit", "table_TEDS", "text_edit"]:
+        #             # New content-type metrics that need the content_list passed through
+        #             result = metric.calculate(
+        #                 predicted_content,
+        #                 groundtruth_content,
+        #                 predicted_content_list=predicted_content_list,
+        #                 groundtruth_content_list=groundtruth_content_list,
+        #                 **kwargs
+        #             )
+        #         else:
+        #             # Generic calculation
+        #             result = metric.calculate(predicted_content, groundtruth_content, **kwargs)
+        #
+        #         results[metric_name] = result
+        #
+        #     except Exception as e:
+        #         # Create error result for failed metrics
+        #         results[metric_name] = MetricResult.create_error_result(
+        #             metric_name, f"Metric calculation failed: {str(e)}"
+        #         )
+
+        results: Dict[str, MetricResult] = {}
+
+        # 1. Compute the non-table metrics first (no dependencies)
+        for metric_name in list(self.metrics.keys()):
+            if metric_name in ["table_edit", "table_TEDS"]:
+                continue  # table-related metrics are handled separately
+
+            metric = self.metrics[metric_name]
+            result = metric.calculate(
+                predicted=predicted_content,
+                groundtruth=groundtruth_content,
+                predicted_content_list=predicted_content_list,
+                groundtruth_content_list=groundtruth_content_list, **kwargs
+            )
+            results[metric_name] = result
+
+        # 2. Handle the table-related metrics (which have a dependency)
+        # 2.1 Compute table_edit
+        if "table_edit" in self.metrics:
+            table_edit_result = self.metrics["table_edit"].calculate(
+                predicted=predicted_content,
+                groundtruth=groundtruth_content,
+                predicted_content_list=predicted_content_list,
+                groundtruth_content_list=groundtruth_content_list,
+                **kwargs
+            )
+            results["table_edit"] = table_edit_result
+
+        # 2.2 Compute table_TEDS (depends on the table_edit result)
+        if "table_TEDS" in self.metrics:
+            teds_result = self.metrics["table_TEDS"].calculate(
+                predicted=predicted_content,
+                groundtruth=groundtruth_content,
+                predicted_content_list=predicted_content_list,
+                groundtruth_content_list=groundtruth_content_list,
+                table_edit_result=table_edit_result,  # forward the dependency result
+                **kwargs
                 )
+            results["table_TEDS"] = teds_result

-        # Add overall score as average of all metrics
+        # 3. Compute the overall score (average of all successful metrics)
         successful_scores = []
         failed_metrics = []

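Seen from the caller's side, the ordering above means `table_TEDS` is only attempted once `table_edit` has produced its `MetricResult`. A minimal usage sketch (the calculator instance and the input variables are assumed to exist; names are illustrative):

    results = calculator.calculate_all(
        predicted_content=pred_md,                 # extracted markdown (assumed)
        groundtruth_content=gt_md,                 # reference markdown (assumed)
        predicted_content_list=pred_content_list,  # structured blocks, if available
        groundtruth_content_list=gt_content_list,
    )
    teds = results["table_TEDS"]
    if teds.success:
        print(f"TEDS: {teds.score:.3f}")
    else:
        # e.g. "Skipped due to table_edit failure: ..."
        print(teds.details)
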
webmainbench/metrics/teds_metrics.py

Lines changed: 74 additions & 15 deletions
@@ -39,37 +39,45 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
             MetricResult with TEDS score
         """
         try:
-            # Convert inputs to HTML format
+            # New: check the result of the table_edit computation
+            table_edit_result = kwargs.get('table_edit_result')
+            if table_edit_result is None:
+                return MetricResult.create_error_result(
+                    self.name, "Missing table_edit result in kwargs"
+                )
+            if not table_edit_result.success:
+                # If table_edit failed, TEDS returns a failure immediately
+                return MetricResult.create_error_result(
+                    self.name,
+                    f"Skipped due to table_edit failure: {table_edit_result.details.get('error', 'unknown reason')}"
+                )
+
+            # Original logic: convert to HTML and parse the tree structures
             pred_html = self._normalize_to_html(predicted)
             gt_html = self._normalize_to_html(groundtruth)
-
-            # Parse HTML to tree structures
+
             pred_tree = self._parse_html_table(pred_html)
             gt_tree = self._parse_html_table(gt_html)
-
+
+            # The logic below is unchanged...
             if pred_tree is None and gt_tree is None:
-                # Both are empty/invalid tables
                 return MetricResult(
                     metric_name=self.name,
                     score=1.0,
                     details={"note": "Both tables are empty or invalid"}
                 )
-
+
             if pred_tree is None or gt_tree is None:
-                # One is empty/invalid
                 return MetricResult(
                     metric_name=self.name,
                     score=0.0,
                     details={"note": "One table is empty or invalid"}
                 )
-
-            # Calculate tree edit distance
+
             edit_distance = self._tree_edit_distance(pred_tree, gt_tree)
-
-            # Calculate TEDS score
             max_nodes = max(self._count_nodes(pred_tree), self._count_nodes(gt_tree))
             teds_score = 1.0 - (edit_distance / max_nodes) if max_nodes > 0 else 1.0
-
+
             details = {
                 "edit_distance": edit_distance,
                 "predicted_nodes": self._count_nodes(pred_tree),
@@ -78,17 +86,68 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
                 "structure_only": self.structure_only,
                 "algorithm": "TEDS"
             }
-
+
             return MetricResult(
                 metric_name=self.name,
-                score=max(0.0, min(1.0, teds_score)),  # Clamp to [0, 1]
+                score=max(0.0, min(1.0, teds_score)),
                 details=details
             )
-
+
         except Exception as e:
             return MetricResult.create_error_result(
                 self.name, f"TEDS calculation failed: {str(e)}"
             )
+        # try:
+        #     # Convert inputs to HTML format
+        #     pred_html = self._normalize_to_html(predicted)
+        #     gt_html = self._normalize_to_html(groundtruth)
+        #
+        #     # Parse HTML to tree structures
+        #     pred_tree = self._parse_html_table(pred_html)
+        #     gt_tree = self._parse_html_table(gt_html)
+        #
+        #     if pred_tree is None and gt_tree is None:
+        #         # Both are empty/invalid tables
+        #         return MetricResult(
+        #             metric_name=self.name,
+        #             score=1.0,
+        #             details={"note": "Both tables are empty or invalid"}
+        #         )
+        #
+        #     if pred_tree is None or gt_tree is None:
+        #         # One is empty/invalid
+        #         return MetricResult(
+        #             metric_name=self.name,
+        #             score=0.0,
+        #             details={"note": "One table is empty or invalid"}
+        #         )
+        #
+        #     # Calculate tree edit distance
+        #     edit_distance = self._tree_edit_distance(pred_tree, gt_tree)
+        #
+        #     # Calculate TEDS score
+        #     max_nodes = max(self._count_nodes(pred_tree), self._count_nodes(gt_tree))
+        #     teds_score = 1.0 - (edit_distance / max_nodes) if max_nodes > 0 else 1.0
+        #
+        #     details = {
+        #         "edit_distance": edit_distance,
+        #         "predicted_nodes": self._count_nodes(pred_tree),
+        #         "groundtruth_nodes": self._count_nodes(gt_tree),
+        #         "max_nodes": max_nodes,
+        #         "structure_only": self.structure_only,
+        #         "algorithm": "TEDS"
+        #     }
+        #
+        #     return MetricResult(
+        #         metric_name=self.name,
+        #         score=max(0.0, min(1.0, teds_score)),  # Clamp to [0, 1]
+        #         details=details
+        #     )
+        #
+        # except Exception as e:
+        #     return MetricResult.create_error_result(
+        #         self.name, f"TEDS calculation failed: {str(e)}"
+        #     )

     def _normalize_to_html(self, table_data: Any) -> str:
         """Convert various table formats to HTML."""

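As a worked instance of the scoring line in the diff above (numbers invented for illustration):

    # Invented numbers, just to trace the TEDS formula above.
    edit_distance = 3
    max_nodes = max(10, 12)   # node counts of the predicted and ground-truth trees
    teds_score = 1.0 - (edit_distance / max_nodes) if max_nodes > 0 else 1.0
    print(max(0.0, min(1.0, teds_score)))  # 0.75, clamped to [0, 1]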