Skip to content

Commit ec34a56

Browse files
committed
预测公式先经过正则,再用LLM修正
1 parent dde20f0 commit ec34a56

4 files changed

Lines changed: 162 additions & 38 deletions

File tree

examples/multi_extractor_compare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def all_extractor_comparison():
88
print("\n=== 多抽取器对比演示 ===\n")
99

1010
# 创建数据集
11-
dataset_path = Path("../data/test_math.jsonl")
11+
dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl")
1212
dataset = DataLoader.load_jsonl(dataset_path)
1313

1414
# 创建webkit抽取器

webmainbench/data/saver.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,13 @@ def save_dataset_with_extraction(results: Union["EvaluationResult", Dict[str, An
303303

304304
# 解析预测值(predicted)
305305
predicted_content = extraction_result.get('extracted_content', '')
306-
# TODO: 这里可以根据需要选择不同的解析方法
307306
predicted_parts = BaseMetric._extract_from_markdown(predicted_content, field_name="llm_webkit_md") # 关键:解析预测内容
308307
for part_type in ['code', 'formula', 'table', 'text']:
309308
sample_dict[f'{current_extractor_name}_predicted_{part_type}'] = predicted_parts.get(part_type, '')
310309

311310
# 解析真实值(groundtruth)- 只需要解析一次
312311
if extractor_names: # 只有当存在extractor时才解析
313312
groundtruth_content = sample_dict.get('groundtruth_content', '')
314-
# TODO: 这里可以根据需要选择不同的解析方法
315313
groundtruth_parts = BaseMetric._extract_from_markdown(groundtruth_content, field_name="groundtruth_content") # 关键:解析真实内容
316314
for part_type in ['code', 'formula', 'table', 'text']:
317315
# 使用第一个extractor的名字作为前缀,或者使用通用前缀

webmainbench/metrics/base.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -293,43 +293,55 @@ def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
293293
if code_content.strip():
294294
code_parts.append(code_content)
295295

296-
# 提取公式 - 根据字段类型决定使用API还是正则
296+
# 提取公式 - 新的两步处理逻辑
297297
formula_parts = []
298298

299-
# 如果是groundtruth_content,使用正则提取公式
300-
if field_name == "llm_webkit_md":
301-
print(f"[DEBUG] 检测到groundtruth内容,使用正则提取公式")
302-
# 统一的公式提取模式
303-
latex_patterns = [
304-
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
305-
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
306-
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
307-
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
308-
]
309-
310-
for pattern in latex_patterns:
311-
for match in re.finditer(pattern, text, re.DOTALL):
312-
formula_full = match.group(0)
313-
formula_content = match.group(1)
314-
extracted_segments.append(formula_full)
315-
if formula_content.strip():
316-
formula_parts.append(formula_content.strip())
299+
# 第一步:先用正则提取公式
300+
regex_formulas = []
301+
latex_patterns = [
302+
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
303+
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
304+
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
305+
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
306+
]
307+
308+
for pattern in latex_patterns:
309+
for match in re.finditer(pattern, text, re.DOTALL):
310+
formula_full = match.group(0)
311+
formula_content = match.group(1)
312+
extracted_segments.append(formula_full)
313+
if formula_content.strip():
314+
regex_formulas.append(formula_content.strip())
315+
316+
# 第二步:根据字段类型决定是否需要API修正
317+
if field_name == "groundtruth_content":
318+
print(f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式")
319+
formula_parts = regex_formulas
317320
else:
318-
# 其他内容使用API提取公式
319-
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
320-
os.makedirs(cache_dir, exist_ok=True)
321-
322-
# 使用文本哈希作为缓存文件名
323-
text_hash = hashlib.md5(text.encode('utf-8')).hexdigest()
324-
cache_file = os.path.join(cache_dir, f'formula_cache_{text_hash}.json')
325-
326-
# 使用LLM API提取公式
327-
try:
328-
from .formula_extractor import extract_formulas_with_llm
329-
formula_parts = extract_formulas_with_llm(text, cache_file)
330-
print(f"[DEBUG] 公式提取成功,提取到 {len(formula_parts)} 个公式")
331-
except Exception as e:
332-
print(f"[DEBUG] 公式提取失败: {type(e).__name__}: {e}")
321+
print(f"[DEBUG] 检测到llm_webkit_md内容,使用正则+API修正模式")
322+
# 对于llm_webkit_md,将正则结果传递给API进行修正
323+
if regex_formulas:
324+
# 将正则提取的公式作为输入传递给API
325+
regex_formulas_text = '\n'.join(regex_formulas)
326+
print(f"[DEBUG] 正则提取到 {len(regex_formulas)} 个公式,准备API修正")
327+
328+
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
329+
os.makedirs(cache_dir, exist_ok=True)
330+
331+
# 使用正则结果的哈希作为缓存文件名
332+
text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest()
333+
cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json')
334+
335+
try:
336+
from .formula_extractor import correct_formulas_with_llm
337+
corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file)
338+
formula_parts = corrected_formulas
339+
print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式")
340+
except Exception as e:
341+
print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果")
342+
formula_parts = regex_formulas
343+
else:
344+
print(f"[DEBUG] 正则未提取到公式,跳过API修正")
333345
formula_parts = []
334346

335347
# 提取表格
Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,115 @@
1-
# webmainbench/metrics/formula_extractor.pyimport jsonimport osfrom openai import OpenAIdef extract_formulas_with_llm(text, cache_file=None): """使用LLM API提取文本中的数学公式""" # 预检查:如果$符号数量<2,直接返回空列表 dollar_count = text.count('$') if dollar_count < 2: print(f"[DEBUG] 输入文本$符号数量为{dollar_count},小于2,跳过API调用") return [] # 检查缓存 if cache_file and os.path.exists(cache_file): try: with open(cache_file, 'r', encoding='utf-8') as f: cached_result = json.load(f) print(f"[DEBUG] 从缓存加载公式: {len(cached_result)} 个") return cached_result except Exception as e: print(f"[DEBUG] 缓存读取失败: {e}") # API配置 client = OpenAI( base_url="http://35.220.164.252:3888/v1/", api_key="sk-PZgDr7sZdt77805Cg8s5ZB9QnGMGke61ovYnHYcHKIYVGHNA" ) PROMPT = '''任务:请作为信息抽取专家,精确提取所提供 Markdown 文本中的所有数学公式,并按要求输出。### 公式格式说明Markdown 中数学公式通常包括以下两类:- **行内公式(Inline)**:由单个美元符号 `$...$` 包裹,例如:`$E = mc^2$`- **行间公式(Block)**: - 双美元符号包裹:`$$...$$`,例如:`$$\sum_{i=1}^n i = \frac{n(n+1)}{2}$$` - 数学代码块:以 ```` ```math ```` 开头和结尾的代码块### 提取要求1. **精准提取**:仅提取正确标记的公式部分,排除普通文本、代码(除非是数学代码块)、注释或无关内容。2. **保持原貌**:提取的公式必须与原文完全一致,不得修改、简化或转译。3. **LaTeX 公式**:若识别到 LaTeX 格式的公式(包括 LaTeX 环境或命令),也应原样提取。4. **区分货币与公式**:避免将美元货币金额(如 `$3.99`)误提取为公式,需结合上下文判断是否为数学表达式。### 输出格式- 提取所有识别到的公式,按出现顺序逐行输出原始字符串。- 每个公式以独立行形式呈现,不附加任何额外信息。- 若无公式,则不返回任何内容。[输入文本开始]''' try: print(f"[DEBUG] 开始调用 OpenAI API...") response = client.chat.completions.create( model="deepseek-chat", temperature=0, messages=[ {"role": "user", "content": PROMPT + f"{text}" + '''[输入文本结束]---请根据上述要求,仅输出提取后的公式内容或空字符串。---请注意:- 绝对不要对公式内容做任何修改或解释。- 确保不遗漏任何符合要求的公式,也不添加非公式文本。Optimized for: clarity, precision, context-awareness, and strict formatting compliance.'''} ] ) result_text = response.choices[0].message.content.strip() print(f"[DEBUG] API 返回原始结果: {repr(result_text)}") # 解析返回的公式 if not result_text: formulas = [] else: formulas = [line.strip() for line in result_text.split('\n') if line.strip()] print(f"[DEBUG] 解析后的公式列表: {formulas}") # 保存缓存 if cache_file: try: os.makedirs(os.path.dirname(cache_file), exist_ok=True) with open(cache_file, 'w', encoding='utf-8') as f: json.dump(formulas, f, ensure_ascii=False, indent=2) print(f"[DEBUG] 结果已缓存到: {cache_file}") except Exception as e: print(f"[DEBUG] 缓存保存失败: {e}") return formulas except Exception as e: print(f"[DEBUG] API 调用异常: {type(e).__name__}: {e}") raise e
1+
# webmainbench/metrics/formula_extractor.py
2+
import json
3+
import os
4+
from openai import OpenAI
5+
6+
def correct_formulas_with_llm(regex_formulas, cache_file=None):
7+
"""使用LLM API修正正则提取的公式"""
8+
9+
if not regex_formulas:
10+
print(f"[DEBUG] 输入公式列表为空,跳过API修正")
11+
return []
12+
13+
# 检查缓存
14+
if cache_file and os.path.exists(cache_file):
15+
try:
16+
with open(cache_file, 'r', encoding='utf-8') as f:
17+
cached_result = json.load(f)
18+
print(f"[DEBUG] 从缓存加载修正结果: {len(cached_result)} 个")
19+
return cached_result
20+
except Exception as e:
21+
print(f"[DEBUG] 缓存读取失败: {e}")
22+
23+
# API配置
24+
client = OpenAI(
25+
base_url="",
26+
api_key=""
27+
)
28+
29+
# 将正则提取的公式转换为文本
30+
formulas_text = '\n'.join(regex_formulas)
31+
32+
CORRECTION_PROMPT = '''任务:请从以下正则表达式提取的内容中,识别并保留真正的LaTeX数学公式,剔除货币形式的内容。
33+
34+
### 识别规则
35+
**真正的数学公式**(保留):
36+
- 包含数学符号:+ - × ÷ = < > ≤ ≥ ± ∞ ∑ ∫ ∂ √ ^ _ { } 等
37+
- 包含希腊字母:α β γ δ θ λ μ π σ ω 等
38+
- 包含LaTeX命令:\\frac \\sum \\int \\sqrt \\alpha \\beta \\sin \\cos 等
39+
- 包含数学表达式:变量、函数、方程等
40+
41+
**货币形式内容**(剔除):
42+
- 仅包含数字、逗号、小数点的价格:如 1,150.00
43+
- 纯粹的金额数值:如 25.99、1,200、5.50
44+
- 不包含任何数学运算符或数学符号的数字
45+
46+
### 处理要求
47+
1. **严格区分**:只保留真正的数学公式,剔除所有货币价格
48+
2. **格式标准化**:统一公式格式,确保LaTeX语法正确
49+
3. **保持原意**:不修改数学公式内容
50+
51+
### 输出格式
52+
- 每个有效的数学公式独占一行
53+
- 只输出公式内容,不包含$符号或其他包装
54+
- 如果输入不是有效的数学公式(如货币),则输出<空>
55+
- 按原顺序输出保留的公式
56+
57+
### 示例 1 (剔除后有有效公式)
58+
输入:1,150.00 → 剔除(货币)
59+
输入:x^2 + y^2 = r^2 → 保留(数学公式)
60+
输入:25.99 → 剔除(货币)
61+
输入:\\frac{a}{b} + c → 保留(数学公式)
62+
63+
### 示例 2 (剔除后无有效公式)
64+
输入:1,150.00 → 剔除(货币)
65+
输入:25.99 → 剔除(货币)
66+
67+
输出:<空>
68+
69+
注意,输出结果中不要添加任何解释!。
70+
[输入内容列表开始]'''
71+
72+
try:
73+
print(f"[DEBUG] 开始调用 OpenAI API 进行公式修正...")
74+
response = client.chat.completions.create(
75+
model="deepseek-chat",
76+
temperature=0,
77+
messages=[
78+
{"role": "user", "content": CORRECTION_PROMPT + f"\n{formulas_text}\n" + '''[输入内容列表结束]
79+
---
80+
请按要求识别并输出真正的数学公式,剔除货币形式的内容。
81+
---'''}
82+
]
83+
)
84+
85+
result_text = response.choices[0].message.content.strip()
86+
print(f"[DEBUG] API 返回修正结果: {repr(result_text)}")
87+
88+
# 检测返回内容是否包含"空"字 - 如果包含则整个结果为空
89+
if '空' in result_text:
90+
print(f"[DEBUG] 检测到API返回包含'空'字,将整个结果设置为空列表")
91+
corrected_formulas = []
92+
elif not result_text:
93+
corrected_formulas = []
94+
else:
95+
# 正常解析返回的公式列表
96+
corrected_formulas = [line.strip() for line in result_text.split('\n') if line.strip()]
97+
98+
print(f"[DEBUG] 修正后的公式列表: {corrected_formulas}")
99+
100+
# 保存缓存
101+
if cache_file:
102+
try:
103+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
104+
with open(cache_file, 'w', encoding='utf-8') as f:
105+
json.dump(corrected_formulas, f, ensure_ascii=False, indent=2)
106+
print(f"[DEBUG] 修正结果已缓存到: {cache_file}")
107+
except Exception as e:
108+
print(f"[DEBUG] 缓存保存失败: {e}")
109+
110+
return corrected_formulas
111+
112+
except Exception as e:
113+
print(f"[DEBUG] API 修正异常: {type(e).__name__}: {e}")
114+
print(f"[DEBUG] 回退到原始正则结果")
115+
return regex_formulas

0 commit comments

Comments
 (0)