Skip to content

Commit dde20f0

Browse files
committed
1. gt使用正则匹配
2. predicate调用api
1 parent 0a6ce17 commit dde20f0

4 files changed

Lines changed: 56 additions & 41 deletions

File tree

examples/multi_extractor_compare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def all_extractor_comparison():
88
print("\n=== 多抽取器对比演示 ===\n")
99

1010
# 创建数据集
11-
dataset_path = Path("data/sample_dataset.jsonl")
11+
dataset_path = Path("../data/test_math.jsonl")
1212
dataset = DataLoader.load_jsonl(dataset_path)
1313

1414
# 创建webkit抽取器

webmainbench/data/saver.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,16 @@ def save_dataset_with_extraction(results: Union["EvaluationResult", Dict[str, An
303303

304304
# 解析预测值(predicted)
305305
predicted_content = extraction_result.get('extracted_content', '')
306-
predicted_parts = BaseMetric._extract_from_markdown(predicted_content) # 关键:解析预测内容
306+
# TODO: 这里可以根据需要选择不同的解析方法
307+
predicted_parts = BaseMetric._extract_from_markdown(predicted_content, field_name="llm_webkit_md") # 关键:解析预测内容
307308
for part_type in ['code', 'formula', 'table', 'text']:
308309
sample_dict[f'{current_extractor_name}_predicted_{part_type}'] = predicted_parts.get(part_type, '')
309310

310311
# 解析真实值(groundtruth)- 只需要解析一次
311312
if extractor_names: # 只有当存在extractor时才解析
312313
groundtruth_content = sample_dict.get('groundtruth_content', '')
313-
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content) # 关键:解析真实内容
314+
# TODO: 这里可以根据需要选择不同的解析方法
315+
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content, field_name="groundtruth_content") # 关键:解析真实内容
314316
for part_type in ['code', 'formula', 'table', 'text']:
315317
# 使用第一个extractor的名字作为前缀,或者使用通用前缀
316318
prefix = extractor_names[0] if len(extractor_names) == 1 else 'groundtruth'

webmainbench/metrics/base.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import traceback
99
import re
1010
from bs4 import BeautifulSoup
11+
import os
12+
import hashlib
1113

1214
@dataclass
1315
class MetricResult:
@@ -121,16 +123,16 @@ def batch_calculate(self, predicted_list: List[Any],
121123
result = self.calculate(pred, gt, **kwargs)
122124
results.append(result)
123125
return results
124-
126+
125127
@staticmethod
126-
def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[str, str]:
128+
def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_name: str = None) -> Dict[str, str]:
127129
"""
128130
统一的内容分割方法,将文本分为代码、公式、表格和剩余文本4个部分。
129-
131+
130132
Args:
131133
text: 原始markdown文本
132134
content_list: 结构化内容列表(来自llm-webkit等)
133-
135+
field_name: 当前处理的字段名称,传递给_extract_from_markdown
134136
Returns:
135137
Dict with keys: 'code', 'formula', 'table', 'text'
136138
"""
@@ -139,9 +141,9 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None) -> Dict[
139141
extracted_content = BaseMetric._extract_from_content_list(content_list)
140142
if any(extracted_content.values()):
141143
return extracted_content
142-
143-
# 从markdown文本中提取
144-
return BaseMetric._extract_from_markdown(text or "")
144+
145+
# 从markdown文本中提取,传递字段名称
146+
return BaseMetric._extract_from_markdown(text or "", field_name=field_name)
145147

146148
@staticmethod
147149
def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]:
@@ -193,12 +195,12 @@ def _recursive_extract(items):
193195
'text': '\n'.join(extracted['text'])
194196
}
195197

196-
@staticmethod
197-
def _extract_from_markdown(text: str) -> Dict[str, str]:
198+
@staticmethod
199+
def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
198200
"""从markdown文本中提取各种类型的内容"""
199201
if not text:
200202
return {'code': '', 'formula': '', 'table': '', 'text': ''}
201-
203+
202204
# 收集所有需要移除的内容片段
203205
extracted_segments = []
204206
code_parts = []
@@ -291,34 +293,44 @@ def _extract_from_markdown(text: str) -> Dict[str, str]:
291293
if code_content.strip():
292294
code_parts.append(code_content)
293295

294-
# 提取公式
296+
# 提取公式 - 根据字段类型决定使用API还是正则
295297
formula_parts = []
296-
# 统一的公式提取模式
297-
latex_patterns = [
298-
# r'(?<!\\)\$\$([^$]+)\$\$(?!\\)', # Display math (not escaped)
299-
# r'(?<!\\)\$([^$\n]+)\$(?![\\\$])', # Inline math (not escaped)
300-
# r'\\begin\{equation\*?\}(.*?)\\end\{equation\*?\}', # Equation environment
301-
# r'\\begin\{align\*?\}(.*?)\\end\{align\*?\}', # Align environment
302-
# r'\\begin\{gather\*?\}(.*?)\\end\{gather\*?\}', # Gather environment
303-
# r'\\begin\{eqnarray\*?\}(.*?)\\end\{eqnarray\*?\}', # Eqnarray environment
304-
# r'\\begin\{multline\*?\}(.*?)\\end\{multline\*?\}', # Multline environment
305-
# r'\\begin\{split\}(.*?)\\end\{split\}', # Split environment
306-
# r'(?<!\\)\$\$([^$]+)\$\$(?!\\)',
307-
# r'(?<!\\)\$([^$\n\w][^$\n]*[^$\n\w])\$(?![\\\$])',
308-
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$,确保 $ 没有被转义
309-
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\],确保 \ 没有被转义
310-
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$,确保 $ 没有被转义
311-
# r'(?<!\\)\$(.*?)(?<!\\)\$(?!\d)', # 第二个$后面不能是数字
312-
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\),确保 \ 没有被转义
313-
]
314-
315-
for pattern in latex_patterns:
316-
for match in re.finditer(pattern, text, re.DOTALL):
317-
formula_full = match.group(0) # 完整匹配(包含$符号)
318-
formula_content = match.group(1) # 只是公式内容
319-
extracted_segments.append(formula_full)
320-
if formula_content.strip():
321-
formula_parts.append(formula_content.strip())
298+
299+
# 如果是groundtruth_content,使用正则提取公式
300+
if field_name == "llm_webkit_md":
301+
print(f"[DEBUG] 检测到groundtruth内容,使用正则提取公式")
302+
# 统一的公式提取模式
303+
latex_patterns = [
304+
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
305+
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
306+
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
307+
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
308+
]
309+
310+
for pattern in latex_patterns:
311+
for match in re.finditer(pattern, text, re.DOTALL):
312+
formula_full = match.group(0)
313+
formula_content = match.group(1)
314+
extracted_segments.append(formula_full)
315+
if formula_content.strip():
316+
formula_parts.append(formula_content.strip())
317+
else:
318+
# 其他内容使用API提取公式
319+
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
320+
os.makedirs(cache_dir, exist_ok=True)
321+
322+
# 使用文本哈希作为缓存文件名
323+
text_hash = hashlib.md5(text.encode('utf-8')).hexdigest()
324+
cache_file = os.path.join(cache_dir, f'formula_cache_{text_hash}.json')
325+
326+
# 使用LLM API提取公式
327+
try:
328+
from .formula_extractor import extract_formulas_with_llm
329+
formula_parts = extract_formulas_with_llm(text, cache_file)
330+
print(f"[DEBUG] 公式提取成功,提取到 {len(formula_parts)} 个公式")
331+
except Exception as e:
332+
print(f"[DEBUG] 公式提取失败: {type(e).__name__}: {e}")
333+
formula_parts = []
322334

323335
# 提取表格
324336
table_parts = []
@@ -468,4 +480,4 @@ def __str__(self) -> str:
468480
return f"{self.__class__.__name__}(name='{self.name}')"
469481

470482
def __repr__(self) -> str:
471-
return self.__str__()
483+
return self.__str__()

webmainbench/metrics/formula_extractor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# webmainbench/metrics/formula_extractor.pyimport jsonimport osfrom openai import OpenAIdef extract_formulas_with_llm(text, cache_file=None): """使用LLM API提取文本中的数学公式""" # 预检查:如果$符号数量<2,直接返回空列表 dollar_count = text.count('$') if dollar_count < 2: print(f"[DEBUG] 输入文本$符号数量为{dollar_count},小于2,跳过API调用") return [] # 检查缓存 if cache_file and os.path.exists(cache_file): try: with open(cache_file, 'r', encoding='utf-8') as f: cached_result = json.load(f) print(f"[DEBUG] 从缓存加载公式: {len(cached_result)} 个") return cached_result except Exception as e: print(f"[DEBUG] 缓存读取失败: {e}") # API配置 client = OpenAI( base_url="http://35.220.164.252:3888/v1/", api_key="sk-PZgDr7sZdt77805Cg8s5ZB9QnGMGke61ovYnHYcHKIYVGHNA" ) PROMPT = '''任务:请作为信息抽取专家,精确提取所提供 Markdown 文本中的所有数学公式,并按要求输出。### 公式格式说明Markdown 中数学公式通常包括以下两类:- **行内公式(Inline)**:由单个美元符号 `$...$` 包裹,例如:`$E = mc^2$`- **行间公式(Block)**: - 双美元符号包裹:`$$...$$`,例如:`$$\sum_{i=1}^n i = \frac{n(n+1)}{2}$$` - 数学代码块:以 ```` ```math ```` 开头和结尾的代码块### 提取要求1. **精准提取**:仅提取正确标记的公式部分,排除普通文本、代码(除非是数学代码块)、注释或无关内容。2. **保持原貌**:提取的公式必须与原文完全一致,不得修改、简化或转译。3. **LaTeX 公式**:若识别到 LaTeX 格式的公式(包括 LaTeX 环境或命令),也应原样提取。4. **区分货币与公式**:避免将美元货币金额(如 `$3.99`)误提取为公式,需结合上下文判断是否为数学表达式。### 输出格式- 提取所有识别到的公式,按出现顺序逐行输出原始字符串。- 每个公式以独立行形式呈现,不附加任何额外信息。- 若无公式,则不返回任何内容。[输入文本开始]''' try: print(f"[DEBUG] 开始调用 OpenAI API...") response = client.chat.completions.create( model="deepseek-chat", temperature=0, messages=[ {"role": "user", "content": PROMPT + f"{text}" + '''[输入文本结束]---请根据上述要求,仅输出提取后的公式内容或空字符串。---请注意:- 绝对不要对公式内容做任何修改或解释。- 确保不遗漏任何符合要求的公式,也不添加非公式文本。Optimized for: clarity, precision, context-awareness, and strict formatting compliance.'''} ] ) result_text = response.choices[0].message.content.strip() print(f"[DEBUG] API 返回原始结果: {repr(result_text)}") # 解析返回的公式 if not result_text: formulas = [] else: formulas = [line.strip() for line in result_text.split('\n') if line.strip()] print(f"[DEBUG] 解析后的公式列表: {formulas}") # 保存缓存 if cache_file: try: os.makedirs(os.path.dirname(cache_file), exist_ok=True) with open(cache_file, 'w', encoding='utf-8') as f: json.dump(formulas, f, ensure_ascii=False, indent=2) print(f"[DEBUG] 结果已缓存到: {cache_file}") except Exception as e: print(f"[DEBUG] 缓存保存失败: {e}") return formulas except Exception as e: print(f"[DEBUG] API 调用异常: {type(e).__name__}: {e}") raise e

0 commit comments

Comments
 (0)