@@ -39,37 +39,45 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
3939 MetricResult with TEDS score
4040 """
4141 try :
42- # Convert inputs to HTML format
42+ # 新增:检查 table_edit 的计算结果
43+ table_edit_result = kwargs .get ('table_edit_result' )
44+ if table_edit_result is None :
45+ return MetricResult .create_error_result (
46+ self .name , "Missing table_edit result in kwargs"
47+ )
48+ if not table_edit_result .success :
49+ # 若 table_edit 失败,TEDS 直接返回失败
50+ return MetricResult .create_error_result (
51+ self .name ,
52+ f"Skipped due to table_edit failure: { table_edit_result .details .get ('error' , 'unknown reason' )} "
53+ )
54+
55+ # 原有逻辑:转换为HTML并解析树结构
4356 pred_html = self ._normalize_to_html (predicted )
4457 gt_html = self ._normalize_to_html (groundtruth )
45-
46- # Parse HTML to tree structures
58+
4759 pred_tree = self ._parse_html_table (pred_html )
4860 gt_tree = self ._parse_html_table (gt_html )
49-
61+
62+ # 后续逻辑保持不变...
5063 if pred_tree is None and gt_tree is None :
51- # Both are empty/invalid tables
5264 return MetricResult (
5365 metric_name = self .name ,
5466 score = 1.0 ,
5567 details = {"note" : "Both tables are empty or invalid" }
5668 )
57-
69+
5870 if pred_tree is None or gt_tree is None :
59- # One is empty/invalid
6071 return MetricResult (
6172 metric_name = self .name ,
6273 score = 0.0 ,
6374 details = {"note" : "One table is empty or invalid" }
6475 )
65-
66- # Calculate tree edit distance
76+
6777 edit_distance = self ._tree_edit_distance (pred_tree , gt_tree )
68-
69- # Calculate TEDS score
7078 max_nodes = max (self ._count_nodes (pred_tree ), self ._count_nodes (gt_tree ))
7179 teds_score = 1.0 - (edit_distance / max_nodes ) if max_nodes > 0 else 1.0
72-
80+
7381 details = {
7482 "edit_distance" : edit_distance ,
7583 "predicted_nodes" : self ._count_nodes (pred_tree ),
@@ -78,17 +86,68 @@ def _calculate_score(self, predicted: Any, groundtruth: Any, **kwargs) -> Metric
7886 "structure_only" : self .structure_only ,
7987 "algorithm" : "TEDS"
8088 }
81-
89+
8290 return MetricResult (
8391 metric_name = self .name ,
84- score = max (0.0 , min (1.0 , teds_score )), # Clamp to [0, 1]
92+ score = max (0.0 , min (1.0 , teds_score )),
8593 details = details
8694 )
87-
95+
8896 except Exception as e :
8997 return MetricResult .create_error_result (
9098 self .name , f"TEDS calculation failed: { str (e )} "
9199 )
100+ # try:
101+ # # Convert inputs to HTML format
102+ # pred_html = self._normalize_to_html(predicted)
103+ # gt_html = self._normalize_to_html(groundtruth)
104+ #
105+ # # Parse HTML to tree structures
106+ # pred_tree = self._parse_html_table(pred_html)
107+ # gt_tree = self._parse_html_table(gt_html)
108+ #
109+ # if pred_tree is None and gt_tree is None:
110+ # # Both are empty/invalid tables
111+ # return MetricResult(
112+ # metric_name=self.name,
113+ # score=1.0,
114+ # details={"note": "Both tables are empty or invalid"}
115+ # )
116+ #
117+ # if pred_tree is None or gt_tree is None:
118+ # # One is empty/invalid
119+ # return MetricResult(
120+ # metric_name=self.name,
121+ # score=0.0,
122+ # details={"note": "One table is empty or invalid"}
123+ # )
124+ #
125+ # # Calculate tree edit distance
126+ # edit_distance = self._tree_edit_distance(pred_tree, gt_tree)
127+ #
128+ # # Calculate TEDS score
129+ # max_nodes = max(self._count_nodes(pred_tree), self._count_nodes(gt_tree))
130+ # teds_score = 1.0 - (edit_distance / max_nodes) if max_nodes > 0 else 1.0
131+ #
132+ # details = {
133+ # "edit_distance": edit_distance,
134+ # "predicted_nodes": self._count_nodes(pred_tree),
135+ # "groundtruth_nodes": self._count_nodes(gt_tree),
136+ # "max_nodes": max_nodes,
137+ # "structure_only": self.structure_only,
138+ # "algorithm": "TEDS"
139+ # }
140+ #
141+ # return MetricResult(
142+ # metric_name=self.name,
143+ # score=max(0.0, min(1.0, teds_score)), # Clamp to [0, 1]
144+ # details=details
145+ # )
146+ #
147+ # except Exception as e:
148+ # return MetricResult.create_error_result(
149+ # self.name, f"TEDS calculation failed: {str(e)}"
150+ # )
92151
93152 def _normalize_to_html (self , table_data : Any ) -> str :
94153 """Convert various table formats to HTML."""
0 commit comments