Skip to content

Commit 96fd8a3

Browse files
committed
feat: refactor _extract_from_markdown with LLM-enhanced table/formula/code extraction
1 parent e4194fc commit 96fd8a3

7 files changed

Lines changed: 123 additions & 315 deletions

File tree

examples/multi_extractor_compare.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ def all_extractor_comparison():
88
print("\n=== 多抽取器对比演示 ===\n")
99

1010
# 创建数据集
11-
dataset_path = Path("../data/WebMainBench_llm-webkit_v1_WebMainBench_7887_within_formula.jsonl")
11+
dataset_path = Path("../data/sample_dataset.jsonl")
1212
dataset = DataLoader.load_jsonl(dataset_path)
13-
13+
1414
# 创建webkit抽取器
1515
config = {
1616
"use_preprocessed_html": True, # 🔑 关键配置:启用预处理HTML模式
1717
"preprocessed_html_field": "llm_webkit_html" # 指定预处理HTML字段名
1818
}
19+
1920
webkit_extractor = ExtractorFactory.create("llm-webkit", config=config)
2021
# 创建magic-extractor抽取器
2122
magic_extractor = ExtractorFactory.create("magic-html")

webmainbench/metrics/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from .teds_metrics import TEDSMetric, StructureTEDSMetric
1212
from .calculator import MetricCalculator
1313
from .mainhtml_calculator import MainHTMLMetricCalculator
14+
from .base_extractor import ContentExtractor
15+
from .formula_extractor import FormulaExtractor
16+
from .code_extractor import CodeExtractor
17+
from .table_extractor import TableExtractor
1418

1519
__all__ = [
1620
"BaseMetric",
@@ -27,4 +31,8 @@
2731
"TextEditMetric",
2832
"MetricCalculator",
2933
"MainHTMLMetricCalculator",
34+
'ContentExtractor',
35+
'FormulaExtractor',
36+
'CodeExtractor',
37+
'TableExtractor',
3038
]

webmainbench/metrics/base.py

Lines changed: 26 additions & 222 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44

55
from abc import ABC, abstractmethod
66
from dataclasses import dataclass
7-
from typing import Dict, Any, List, Optional, Union
8-
import traceback
9-
import re
10-
from bs4 import BeautifulSoup
11-
import os
12-
import hashlib
7+
from typing import Dict, Any, List, Optional
138

149
@dataclass
1510
class MetricResult:
@@ -144,7 +139,7 @@ def split_content(text: str, content_list: List[Dict[str, Any]] = None, field_na
144139

145140
# 从markdown文本中提取,传递字段名称
146141
return BaseMetric._extract_from_markdown(text or "", field_name=field_name)
147-
142+
148143
@staticmethod
149144
def _extract_from_content_list(content_list: List[Dict[str, Any]]) -> Dict[str, str]:
150145
"""从content_list中递归提取各种类型的内容"""
@@ -194,233 +189,42 @@ def _recursive_extract(items):
194189
'table': '\n'.join(extracted['table']),
195190
'text': '\n'.join(extracted['text'])
196191
}
197-
192+
198193
@staticmethod
199194
def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
200195
"""从markdown文本中提取各种类型的内容"""
201196
if not text:
202197
return {'code': '', 'formula': '', 'table': '', 'text': ''}
203198

204-
# 收集所有需要移除的内容片段
205-
extracted_segments = []
206-
code_parts = []
207-
# # 同匹配行间代码块 ```...```
208-
# pattern = r'(```[\s\S]*?```)'
209-
# for match in re.finditer(pattern, text):
210-
# code_segment = match.group(0)
211-
# extracted_segments.append(code_segment)
212-
#
213-
# if code_segment.startswith('```'):
214-
# # 处理代码块(保留内部缩进)
215-
# lines = code_segment.split('\n')
216-
# # 移除首尾的```标记
217-
# content_lines = lines[1:-1]
218-
# # 保留原始缩进,只拼接内容
219-
# code_content = '\n'.join(content_lines)
220-
# else:
221-
# # 处理行内代码(只去除外层`和前后空格)
222-
# code_content = code_segment[1:-1].strip()
223-
#
224-
# if code_content: # 只添加非空内容
225-
# code_parts.append(code_content)
226-
227-
# 1. 首先处理三个反引号包裹的代码块(优先级最高)
228-
backtick_pattern = r'(```[\s\S]*?```)'
229-
for match in re.finditer(backtick_pattern, text):
230-
code_segment = match.group(0)
231-
232-
if code_segment.startswith('```'):
233-
# 处理代码块
234-
lines = code_segment.split('\n')
235-
# 移除首尾的```标记
236-
content_lines = lines[1:-1]
237-
code_content = '\n'.join(content_lines)
238-
else:
239-
# 处理行内代码
240-
code_content = code_segment[1:-1].strip()
241-
242-
if code_content:
243-
code_parts.append(code_content)
244-
245-
# 2. 处理缩进代码块 - 使用更精确的匹配
246-
# 匹配模式:前面有空行 + 连续的多行缩进内容 + 后面有空行
247-
# 关键:要求所有匹配的行都是缩进的
248-
indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)'
249-
250-
for match in re.finditer(indent_pattern, text, re.MULTILINE):
251-
code_segment = match.group(1)
252-
253-
# 验证:确保所有行都是缩进的(避免混合缩进和非缩进行)
254-
lines = code_segment.split('\n')
255-
all_indented = all(
256-
line.startswith(' ') or line.startswith('\t') or not line.strip()
257-
for line in lines
258-
if line.strip() # 空行不算
259-
)
260-
261-
if not all_indented:
262-
continue # 跳过包含非缩进行的块
263-
264-
# 进一步验证代码特征
265-
non_empty_lines = [line.strip() for line in lines if line.strip()]
266-
if len(non_empty_lines) < 2: # 至少2行非空内容
267-
continue
268-
269-
# 检查是否有明显的非代码特征
270-
has_list_features = any(
271-
re.match(r'^[-•*]\s', line) or
272-
re.match(r'^\d+\.\s', line) or
273-
re.search(r'\$[\d,]', line) or
274-
re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE)
275-
for line in non_empty_lines
276-
)
277-
278-
if has_list_features:
279-
continue # 跳过列表内容
280-
281-
# 清理代码段
282-
cleaned_lines = []
283-
for line in code_segment.split('\n'):
284-
if line.strip():
285-
if line.startswith(' '):
286-
cleaned_lines.append(line[4:])
287-
elif line.startswith('\t'):
288-
cleaned_lines.append(line[1:])
289-
else:
290-
cleaned_lines.append(line)
291-
292-
code_content = '\n'.join(cleaned_lines)
293-
if code_content.strip():
294-
code_parts.append(code_content)
295-
296-
# 提取公式 - 新的两步处理逻辑
297-
formula_parts = []
298-
299-
# 第一步:先用正则提取公式
300-
regex_formulas = []
301-
latex_patterns = [
302-
r'(?<!\\)\$\$(.*?)(?<!\\)\$\$', # 行间 $$...$$
303-
r'(?<!\\)\\\[(.*?)(?<!\\)\\\]', # 行间 \[...\]
304-
r'(?<!\\)\$(.*?)(?<!\\)\$', # 行内 $...$
305-
r'(?<!\\)\\\((.*?)(?<!\\)\\\)', # 行内 \(...\)
306-
]
307-
308-
for pattern in latex_patterns:
309-
for match in re.finditer(pattern, text, re.DOTALL):
310-
formula_full = match.group(0)
311-
formula_content = match.group(1)
312-
extracted_segments.append(formula_full)
313-
if formula_content.strip():
314-
regex_formulas.append(formula_content.strip())
315-
316-
# 第二步:根据字段类型决定是否需要API修正
317-
if field_name == "groundtruth_content":
318-
print(f"[DEBUG] 检测到groundtruth内容,仅使用正则提取公式")
319-
formula_parts = regex_formulas
320-
else:
321-
print(f"[DEBUG] 检测到llm_webkit_md内容,使用正则+API修正模式")
322-
# 对于llm_webkit_md,将正则结果传递给API进行修正
323-
if regex_formulas:
324-
# 将正则提取的公式作为输入传递给API
325-
regex_formulas_text = '\n'.join(regex_formulas)
326-
print(f"[DEBUG] 正则提取到 {len(regex_formulas)} 个公式,准备API修正")
327-
328-
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.cache')
329-
os.makedirs(cache_dir, exist_ok=True)
330-
331-
# 使用正则结果的哈希作为缓存文件名
332-
text_hash = hashlib.md5(regex_formulas_text.encode('utf-8')).hexdigest()
333-
cache_file = os.path.join(cache_dir, f'formula_correction_cache_{text_hash}.json')
334-
335-
try:
336-
from .formula_extractor import correct_formulas_with_llm
337-
corrected_formulas = correct_formulas_with_llm(regex_formulas, cache_file)
338-
formula_parts = corrected_formulas
339-
print(f"[DEBUG] API修正成功,最终得到 {len(formula_parts)} 个公式")
340-
except Exception as e:
341-
print(f"[DEBUG] API修正失败: {type(e).__name__}: {e},使用正则结果")
342-
formula_parts = regex_formulas
343-
else:
344-
print(f"[DEBUG] 正则未提取到公式,跳过API修正")
345-
formula_parts = []
346-
347-
# 提取表格
348-
table_parts = []
349-
350-
# ===== 1. 提取 HTML 表格 =====
351-
# 用 BeautifulSoup 替代正则,防止嵌套或匹配不全
352-
soup = BeautifulSoup(text, "html.parser")
353-
for table in soup.find_all("table"):
354-
# 判断当前表格的父级是否是表格内的标签(<td>、<tr>、<tbody>等)
355-
parent_is_table_related = table.find_parent(["td", "tr", "tbody", "table"]) is not None
356-
if not parent_is_table_related: # 父级不是表格相关标签 → 是外层表格
357-
html_table = str(table)
358-
extracted_segments.append(html_table)
359-
table_parts.append(html_table)
360-
361-
# ===== 2. 提取 Markdown 表格 =====
362-
lines = text.split('\n')
363-
table_lines = []
364-
in_markdown_table = False
365-
found_separator = False # 是否已找到分隔行
366-
367-
def is_md_table_line(line):
368-
"""判断是否可能是 Markdown 表格行"""
369-
if line.count("|") < 1: # 至少三个竖线
370-
return False
371-
return True
372-
373-
def is_md_separator_line(line):
374-
"""判断是否为 Markdown 分隔行"""
375-
parts = [p.strip() for p in line.split("|")]
376-
# 检查是否所有部分都是分隔符格式
377-
for p in parts:
378-
if p and not re.match(r"^:?\-{3,}:?$", p):
379-
return False
380-
return True
199+
# 创建提取器配置
200+
config = {
201+
'llm_base_url': '',
202+
'llm_api_key': '',
203+
'llm_model': '',
204+
'use_llm': False # 使用时改为True
205+
}
381206

382-
def save_table():
383-
"""保存当前表格并清空缓存"""
384-
nonlocal table_lines
385-
# 只有当表格行数大于等于2,且第二行是分隔行时才保存
386-
if len(table_lines) >= 2 and is_md_separator_line(table_lines[1]):
387-
md_table = '\n'.join(table_lines)
388-
extracted_segments.append(md_table)
389-
table_parts.append(md_table)
207+
# 直接创建具体的提取器实例
208+
from .code_extractor import CodeExtractor
209+
from .formula_extractor import FormulaExtractor
210+
from .table_extractor import TableExtractor
390211

391-
for line in lines:
392-
if is_md_table_line(line):
393-
table_lines.append(line)
394-
in_markdown_table = True
395-
if is_md_separator_line(line):
396-
found_separator = True
397-
else:
398-
if in_markdown_table:
399-
save_table()
400-
table_lines = []
401-
in_markdown_table = False
402-
found_separator = False
212+
code_extractor = CodeExtractor(config)
213+
formula_extractor = FormulaExtractor(config)
214+
table_extractor = TableExtractor(config)
403215

404-
# 处理文档末尾的 Markdown 表格
405-
if in_markdown_table:
406-
save_table()
216+
# 提取各类内容
217+
code_content = code_extractor.extract(text, field_name)
218+
formula_content = formula_extractor.extract(text, field_name)
219+
table_content = table_extractor.extract(text, field_name)
407220

408-
# 提取剩余文本(移除所有已提取的内容片段)
409-
clean_text = text
410-
for segment in extracted_segments:
411-
clean_text = clean_text.replace(segment, '', 1)
412-
413-
# 清理多余的空行
414-
clean_text = re.sub(r'\n\s*\n', '\n\n', clean_text)
415-
clean_text = clean_text.strip()
416-
417221
return {
418-
'code': '\n'.join(code_parts),
419-
'formula': '\n'.join(formula_parts),
420-
'table': '\n'.join(table_parts),
421-
'text': text # 原始全部文本
222+
'code': code_content,
223+
'formula': formula_content,
224+
'table': table_content,
225+
'text': text # 保留原始全部文本
422226
}
423-
227+
424228
def aggregate_results(self, results: List[MetricResult]) -> MetricResult:
425229
"""
426230
Aggregate multiple metric results.

webmainbench/metrics/base_extractor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import hashlib
import json


class ContentExtractor(ABC):
    """Abstract base class for extracting one type of content from text.

    Subclasses implement a cheap regex-based pass (``extract_basic``) and an
    optional LLM-based refinement (``_llm_enhance``); this base class provides
    the use/skip decision and on-disk caching of LLM results.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the extractor.

        Args:
            config: Optional settings. Recognized keys:
                - ``use_llm`` (bool): enable LLM enhancement (default True).
                - ``cache_dir`` (str): directory for cached LLM results;
                  defaults to a ``.cache`` directory one level above this file.
        """
        self.config = config or {}
        # Master switch controlling whether LLM enhancement is attempted at all.
        self.use_llm = self.config.get('use_llm', True)
        self.cache_dir = self.config.get(
            'cache_dir',
            os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.cache'),
        )
        os.makedirs(self.cache_dir, exist_ok=True)

    @abstractmethod
    def extract(self, text: str, field_name: str = None) -> str:
        """Extract this extractor's content type from *text* and return it as a string."""
        pass

    @abstractmethod
    def extract_basic(self, text: str) -> List[str]:
        """Extract content with the basic method (usually regular expressions)."""
        pass

    def should_use_llm(self, field_name: str) -> bool:
        """Decide whether LLM enhancement should be applied for *field_name*."""
        if not self.use_llm:
            return False
        # Ground-truth content is taken verbatim; only extracted/model content
        # goes through LLM correction.
        if field_name == "groundtruth_content":
            print(f"[DEBUG] 检测到groundtruth内容,不使用LLM")
            return False
        return True

    def enhance_with_llm(self, basic_results: List[str], cache_key: str = None) -> List[str]:
        """Enhance basic extraction results with an LLM, caching results on disk.

        Args:
            basic_results: Raw results from ``extract_basic``.
            cache_key: Optional explicit cache key; defaults to the MD5 of the
                joined input strings.

        Returns:
            The enhanced results, or ``basic_results`` unchanged if
            enhancement fails. Cache read/write failures are non-fatal.
        """
        if not basic_results:
            print(f"[DEBUG] 输入内容为空,跳过LLM增强")
            return []

        # Derive a deterministic cache key from the inputs when none is given.
        if cache_key is None:
            content_str = '\n'.join(basic_results)
            cache_key = hashlib.md5(content_str.encode('utf-8')).hexdigest()
        cache_file = os.path.join(
            self.cache_dir, f'{self.__class__.__name__.lower()}_cache_{cache_key}.json'
        )

        # Serve from cache when a previous run already enhanced this input.
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cached_result = json.load(f)
                print(f"[DEBUG] 从缓存加载LLM增强结果: {len(cached_result)} 个")
                return cached_result
            except Exception as e:
                # A corrupt cache entry falls through to recomputation.
                print(f"[DEBUG] 缓存读取失败: {e}")

        try:
            enhanced_results = self._llm_enhance(basic_results)
            # Best-effort cache write; failure to persist is not an error.
            try:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(enhanced_results, f, ensure_ascii=False, indent=2)
                print(f"[DEBUG] LLM增强结果已缓存到: {cache_file}")
            except Exception as e:
                print(f"[DEBUG] 缓存保存失败: {e}")
            return enhanced_results
        except Exception as e:
            # Enhancement is optional: fall back to the regex results.
            print(f"[DEBUG] LLM增强失败: {type(e).__name__}: {e}")
            return basic_results

    @abstractmethod
    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """Concrete LLM enhancement implementation provided by subclasses."""
        pass

webmainbench/metrics/code_extractor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# webmainbench/metrics/code_extractor.py
import re
from typing import List, Dict, Any

from .base_extractor import ContentExtractor


class CodeExtractor(ContentExtractor):
    """Extract code blocks (fenced and indented) from markdown text."""

    def extract(self, text: str, field_name: str = None) -> str:
        """Extract code blocks from *text*, optionally LLM-enhanced, joined by newlines.

        Args:
            text: Markdown text to scan.
            field_name: Source field name; ``groundtruth_content`` skips LLM
                enhancement (see ``ContentExtractor.should_use_llm``).

        Returns:
            All extracted code blocks joined with ``'\n'``.
        """
        code_blocks = self.extract_basic(text)
        if self.should_use_llm(field_name):
            code_parts = self.enhance_with_llm(code_blocks)
        else:
            code_parts = code_blocks
        return '\n'.join(code_parts)

    def extract_basic(self, text: str) -> List[str]:
        """Extract fenced (```) and indented code blocks via regular expressions."""
        code_parts = []

        # 1. Fenced code blocks delimited by triple backticks. Every regex match
        #    necessarily starts with ``` so no further dispatch is needed.
        backtick_pattern = r'(```[\s\S]*?```)'
        for match in re.finditer(backtick_pattern, text):
            code_segment = match.group(0)
            lines = code_segment.split('\n')
            # Drop the opening ```lang and closing ``` fence lines.
            code_content = '\n'.join(lines[1:-1])
            if code_content:
                code_parts.append(code_content)

        # 2. Indented code blocks: a blank line, then >= 2 consecutive lines
        #    indented by 4+ spaces or tabs, then a blank line or end of text.
        indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)'
        for match in re.finditer(indent_pattern, text, re.MULTILINE):
            code_segment = match.group(1)

            # Validate: every non-empty line must be indented (reject blocks
            # that mix indented and flush-left lines).
            lines = code_segment.split('\n')
            all_indented = all(
                line.startswith('    ') or line.startswith('\t') or not line.strip()
                for line in lines
                if line.strip()
            )
            if not all_indented:
                continue

            # Require at least two non-empty lines to look like real code.
            non_empty_lines = [line.strip() for line in lines if line.strip()]
            if len(non_empty_lines) < 2:
                continue

            # Reject blocks with obvious non-code features (bullet/numbered
            # lists, currency figures, written-out magnitudes).
            has_list_features = any(
                re.match(r'^[-•*]\s', line) or
                re.match(r'^\d+\.\s', line) or
                re.search(r'\$[\d,]', line) or
                re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE)
                for line in non_empty_lines
            )
            if has_list_features:
                continue

            # Strip one level of indentation (4 spaces or one tab) per line.
            cleaned_lines = []
            for line in code_segment.split('\n'):
                if line.strip():
                    if line.startswith('    '):
                        cleaned_lines.append(line[4:])
                    elif line.startswith('\t'):
                        cleaned_lines.append(line[1:])
                    else:
                        cleaned_lines.append(line)

            code_content = '\n'.join(cleaned_lines)
            if code_content.strip():
                code_parts.append(code_content)

        return code_parts

    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """LLM enhancement for code is not implemented yet; pass inputs through."""
        print(f"[DEBUG] 代码LLM增强功能尚未实现,返回原始结果")
        return basic_results

0 commit comments

Comments
 (0)