88import traceback
99import re
1010from bs4 import BeautifulSoup
11+ import os
12+ import hashlib
1113
1214@dataclass
1315class MetricResult :
@@ -121,16 +123,16 @@ def batch_calculate(self, predicted_list: List[Any],
121123 result = self .calculate (pred , gt , ** kwargs )
122124 results .append (result )
123125 return results
124-
126+
125127 @staticmethod
126- def split_content (text : str , content_list : List [Dict [str , Any ]] = None ) -> Dict [str , str ]:
128+ def split_content (text : str , content_list : List [Dict [str , Any ]] = None , field_name : str = None ) -> Dict [str , str ]:
127129 """
128130 统一的内容分割方法,将文本分为代码、公式、表格和剩余文本4个部分。
129-
131+
130132 Args:
131133 text: 原始markdown文本
132134 content_list: 结构化内容列表(来自llm-webkit等)
133-
135+ field_name: 当前处理的字段名称,传递给_extract_from_markdown
134136 Returns:
135137 Dict with keys: 'code', 'formula', 'table', 'text'
136138 """
@@ -139,9 +141,9 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[
139141 extracted_content = BaseMetric ._extract_from_content_list (content_list )
140142 if any (extracted_content .values ()):
141143 return extracted_content
142-
143- # 从markdown文本中提取
144- return BaseMetric ._extract_from_markdown (text or "" )
144+
145+ # 从markdown文本中提取,传递字段名称
146+ return BaseMetric ._extract_from_markdown (text or "" , field_name = field_name )
145147
146148 @staticmethod
147149 def _extract_from_content_list (content_list : List [Dict [str , Any ]]) -> Dict [str , str ]:
@@ -193,12 +195,12 @@ def _recursive_extract(items):
193195 'text' : '\n ' .join (extracted ['text' ])
194196 }
195197
196- @staticmethod
197- def _extract_from_markdown (text : str ) -> Dict [str , str ]:
198+ @staticmethod
199+ def _extract_from_markdown (text : str , field_name : str = None ) -> Dict [str , str ]:
198200 """从markdown文本中提取各种类型的内容"""
199201 if not text :
200202 return {'code' : '' , 'formula' : '' , 'table' : '' , 'text' : '' }
201-
203+
202204 # 收集所有需要移除的内容片段
203205 extracted_segments = []
204206 code_parts = []
@@ -291,34 +293,44 @@ def _extract_from_markdown(text: str) -> Dict[str, str]:
291293 if code_content .strip ():
292294 code_parts .append (code_content )
293295
294- # 提取公式
296+ # 提取公式 - 根据字段类型决定使用API还是正则
295297 formula_parts = []
296- # 统一的公式提取模式
297- latex_patterns = [
298- # r'(?<!\\)\$\$([^$]+)\$\$(?!\\)', # Display math (not escaped)
299- # r'(?<!\\)\$([^$\n]+)\$(?![\\\$])', # Inline math (not escaped)
300- # r'\\begin\{equation\*?\}(.*?)\\end\{equation\*?\}', # Equation environment
301- # r'\\begin\{align\*?\}(.*?)\\end\{align\*?\}', # Align environment
302- # r'\\begin\{gather\*?\}(.*?)\\end\{gather\*?\}', # Gather environment
303- # r'\\begin\{eqnarray\*?\}(.*?)\\end\{eqnarray\*?\}', # Eqnarray environment
304- # r'\\begin\{multline\*?\}(.*?)\\end\{multline\*?\}', # Multline environment
305- # r'\\begin\{split\}(.*?)\\end\{split\}', # Split environment
306- # r'(?<!\\)\$\$([^$]+)\$\$(?!\\)',
307- # r'(?<!\\)\$([^$\n\w][^$\n]*[^$\n\w])\$(?![\\\$])',
308- r'(?<!\\)\$\$(.*?)(?<!\\)\$\$' , # 行间 $$...$$,确保 $ 没有被转义
309- r'(?<!\\)\\\[(.*?)(?<!\\)\\\]' , # 行间 \[...\],确保 \ 没有被转义
310- r'(?<!\\)\$(.*?)(?<!\\)\$' , # 行内 $...$,确保 $ 没有被转义
311- # r'(?<!\\)\$(.*?)(?<!\\)\$(?!\d)', # 第二个$后面不能是数字
312- r'(?<!\\)\\\((.*?)(?<!\\)\\\)' , # 行内 \(...\),确保 \ 没有被转义
313- ]
314-
315- for pattern in latex_patterns :
316- for match in re .finditer (pattern , text , re .DOTALL ):
317- formula_full = match .group (0 ) # 完整匹配(包含$符号)
318- formula_content = match .group (1 ) # 只是公式内容
319- extracted_segments .append (formula_full )
320- if formula_content .strip ():
321- formula_parts .append (formula_content .strip ())
298+
299+ # 如果是groundtruth_content,使用正则提取公式
300+ if field_name == "llm_webkit_md" :
301+ print (f"[DEBUG] 检测到groundtruth内容,使用正则提取公式" )
302+ # 统一的公式提取模式
303+ latex_patterns = [
304+ r'(?<!\\)\$\$(.*?)(?<!\\)\$\$' , # 行间 $$...$$
305+ r'(?<!\\)\\\[(.*?)(?<!\\)\\\]' , # 行间 \[...\]
306+ r'(?<!\\)\$(.*?)(?<!\\)\$' , # 行内 $...$
307+ r'(?<!\\)\\\((.*?)(?<!\\)\\\)' , # 行内 \(...\)
308+ ]
309+
310+ for pattern in latex_patterns :
311+ for match in re .finditer (pattern , text , re .DOTALL ):
312+ formula_full = match .group (0 )
313+ formula_content = match .group (1 )
314+ extracted_segments .append (formula_full )
315+ if formula_content .strip ():
316+ formula_parts .append (formula_content .strip ())
317+ else :
318+ # 其他内容使用API提取公式
319+ cache_dir = os .path .join (os .path .dirname (os .path .abspath (__file__ )), '.cache' )
320+ os .makedirs (cache_dir , exist_ok = True )
321+
322+ # 使用文本哈希作为缓存文件名
323+ text_hash = hashlib .md5 (text .encode ('utf-8' )).hexdigest ()
324+ cache_file = os .path .join (cache_dir , f'formula_cache_{ text_hash } .json' )
325+
326+ # 使用LLM API提取公式
327+ try :
328+ from .formula_extractor import extract_formulas_with_llm
329+ formula_parts = extract_formulas_with_llm (text , cache_file )
330+ print (f"[DEBUG] 公式提取成功,提取到 { len (formula_parts )} 个公式" )
331+ except Exception as e :
332+ print (f"[DEBUG] 公式提取失败: { type (e ).__name__ } : { e } " )
333+ formula_parts = []
322334
323335 # 提取表格
324336 table_parts = []
@@ -468,4 +480,4 @@ def __str__(self) -> str:
468480 return f"{ self .__class__ .__name__ } (name='{ self .name } ')"
469481
470482 def __repr__ (self ) -> str :
471- return self .__str__ ()
483+ return self .__str__ ()
0 commit comments