Skip to content

Commit daea117

Browse files
committed
修改命名,防止与extractor混淆 (rename `*Extractor` content helpers to `*Splitter` to avoid confusion with the page-level extractors)
1 parent 96fd8a3 commit daea117

8 files changed

Lines changed: 24 additions & 24 deletions

File tree

examples/multi_extractor_compare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def all_extractor_comparison():
88
print("\n=== 多抽取器对比演示 ===\n")
99

1010
# 创建数据集
11-
dataset_path = Path("../data/sample_dataset.jsonl")
11+
dataset_path = Path("../data/test_math.jsonl")
1212
dataset = DataLoader.load_jsonl(dataset_path)
1313

1414
# 创建webkit抽取器

webmainbench/metrics/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from .teds_metrics import TEDSMetric, StructureTEDSMetric
1212
from .calculator import MetricCalculator
1313
from .mainhtml_calculator import MainHTMLMetricCalculator
14-
from .base_extractor import ContentExtractor
15-
from .formula_extractor import FormulaExtractor
16-
from .code_extractor import CodeExtractor
17-
from .table_extractor import TableExtractor
14+
from .base_content_splitter import BaseContentSplitter
15+
from .formula_extractor import FormulaSplitter
16+
from .code_extractor import CodeSplitter
17+
from .table_extractor import TableSplitter
1818

1919
__all__ = [
2020
"BaseMetric",
@@ -31,8 +31,8 @@
3131
"TextEditMetric",
3232
"MetricCalculator",
3333
"MainHTMLMetricCalculator",
34-
'ContentExtractor',
35-
'FormulaExtractor',
36-
'CodeExtractor',
37-
'TableExtractor',
34+
'BaseContentSplitter',
35+
'FormulaSplitter',
36+
'CodeSplitter',
37+
'TableSplitter',
3838
]

webmainbench/metrics/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,20 +198,20 @@ def _extract_from_markdown(text: str, field_name: str = None) -> Dict[str, str]:
198198

199199
# 创建提取器配置
200200
config = {
201-
'llm_base_url': '',
202-
'llm_api_key': '',
203-
'llm_model': '',
204-
'use_llm': False # 使用时改为True
201+
'llm_base_url': 'http://35.220.164.252:3888/v1/',
202+
'llm_api_key': '<REDACTED — a live API key was committed in this change; rotate the credential immediately and load it from an environment variable instead of source>',
203+
'llm_model': 'deepseek-chat',
204+
'use_llm': True # 使用时改为True
205205
}
206206

207207
# 直接创建具体的提取器实例
208-
from .code_extractor import CodeExtractor
209-
from .formula_extractor import FormulaExtractor
210-
from .table_extractor import TableExtractor
208+
from .code_extractor import CodeSplitter
209+
from .formula_extractor import FormulaSplitter
210+
from .table_extractor import TableSplitter
211211

212-
code_extractor = CodeExtractor(config)
213-
formula_extractor = FormulaExtractor(config)
214-
table_extractor = TableExtractor(config)
212+
code_extractor = CodeSplitter(config)
213+
formula_extractor = FormulaSplitter(config)
214+
table_extractor = TableSplitter(config)
215215

216216
# 提取各类内容
217217
code_content = code_extractor.extract(text, field_name)

webmainbench/metrics/base_content_splitter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import hashlib
import json


class BaseContentSplitter(ABC):
    """Abstract base class for splitting one content type (code, formulas,
    tables, ...) out of a text document.

    NOTE(review): reconstructed from a whitespace-collapsed scrape; statement
    boundaries were re-derived from Python syntax. Runtime strings are kept
    byte-for-byte.
    """

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the splitter.

        Args:
            config: optional settings; recognised keys are ``use_llm``
                (bool, default ``True``) and ``cache_dir`` (str, default
                ``<package parent>/.cache``).
        """
        self.config = config or {}
        # Master switch: controls whether LLM enhancement is attempted at all.
        self.use_llm = self.config.get('use_llm', True)
        self.cache_dir = self.config.get(
            'cache_dir',
            os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                '.cache'))
        os.makedirs(self.cache_dir, exist_ok=True)

    @abstractmethod
    def extract(self, text: str, field_name: str = None) -> str:
        """Extract the target content type from *text* as a single string."""
        pass

    @abstractmethod
    def extract_basic(self, text: str) -> List[str]:
        """Extract content with the basic method (usually regular expressions)."""
        pass

    def should_use_llm(self, field_name: str) -> bool:
        """Decide whether LLM-enhanced extraction should run for this field."""
        if not self.use_llm:
            return False
        # Ground-truth content is never LLM-enhanced, so the reference
        # answer stays exactly as annotated.
        if field_name == "groundtruth_content":
            print(f"[DEBUG] 检测到groundtruth内容,不使用LLM")
            return False
        return True

    def enhance_with_llm(self, basic_results: List[str], cache_key: str = None) -> List[str]:
        """Run LLM enhancement over *basic_results*, with a JSON file cache.

        Falls back to the unmodified *basic_results* when enhancement fails;
        cache read/write errors are non-fatal.
        """
        if not basic_results:
            print(f"[DEBUG] 输入内容为空,跳过LLM增强")
            return []
        # Derive a stable cache key from the content when none is supplied.
        if cache_key is None:
            content_str = '\n'.join(basic_results)
            cache_key = hashlib.md5(content_str.encode('utf-8')).hexdigest()
        cache_file = os.path.join(
            self.cache_dir,
            f'{self.__class__.__name__.lower()}_cache_{cache_key}.json')
        # Serve from cache when a previous run stored a result.
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cached_result = json.load(f)
                print(f"[DEBUG] 从缓存加载LLM增强结果: {len(cached_result)} 个")
                return cached_result
            except Exception as e:
                # Corrupt cache entry: log and recompute below.
                print(f"[DEBUG] 缓存读取失败: {e}")
        try:
            enhanced_results = self._llm_enhance(basic_results)
            # Best-effort cache write; a failure only loses the cache.
            try:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(enhanced_results, f, ensure_ascii=False, indent=2)
                print(f"[DEBUG] LLM增强结果已缓存到: {cache_file}")
            except Exception as e:
                print(f"[DEBUG] 缓存保存失败: {e}")
            return enhanced_results
        except Exception as e:
            # Enhancement is optional: on any error return the basic results.
            print(f"[DEBUG] LLM增强失败: {type(e).__name__}: {e}")
            return basic_results

    @abstractmethod
    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """Concrete LLM enhancement; implemented by subclasses."""
        pass

webmainbench/metrics/base_extractor.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

webmainbench/metrics/code_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# NOTE(review): this span is the DELETED side of the diff — the pre-rename
# CodeExtractor(ContentExtractor) from the removed base_extractor module.
# Its logic is identical to the renamed CodeSplitter; kept verbatim as
# diff history (content was collapsed onto one line by the page scrape).
# webmainbench/metrics/extractors/code_extractor.pyimport refrom typing import List, Dict, Anyfrom .base_extractor import ContentExtractorclass CodeExtractor(ContentExtractor): """从文本中提取代码块""" def extract(self, text: str, field_name: str = None) -> str: """提取代码块""" code_blocks = self.extract_basic(text) if self.should_use_llm(field_name): code_parts = self.enhance_with_llm(code_blocks) else: code_parts = code_blocks return '\n'.join(code_parts) def extract_basic(self, text: str) -> List[str]: """使用正则表达式提取代码块""" code_parts = [] # 处理三个反引号包裹的代码块 backtick_pattern = r'(```[\s\S]*?```)' for match in re.finditer(backtick_pattern, text): code_segment = match.group(0) if code_segment.startswith('```'): lines = code_segment.split('\n') content_lines = lines[1:-1] code_content = '\n'.join(content_lines) if code_content: code_parts.append(code_content) # 处理缩进代码块 - 定义缺失的模式 indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)' for match in re.finditer(indent_pattern, text, re.MULTILINE): code_segment = match.group(1) # 验证:确保所有行都是缩进的 lines = code_segment.split('\n') all_indented = all( line.startswith(' ') or line.startswith('\t') or not line.strip() for line in lines if line.strip() ) if not all_indented: continue # 进一步验证代码特征 non_empty_lines = [line.strip() for line in lines if line.strip()] if len(non_empty_lines) < 2: continue # 检查是否有明显的非代码特征 has_list_features = any( re.match(r'^[-•*]\s', line) or re.match(r'^\d+\.\s', line) or re.search(r'\$[\d,]', line) or re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE) for line in non_empty_lines ) if has_list_features: continue # 清理代码段 cleaned_lines = [] for line in code_segment.split('\n'): if line.strip(): if line.startswith(' '): cleaned_lines.append(line[4:]) elif line.startswith('\t'): cleaned_lines.append(line[1:]) else: cleaned_lines.append(line) code_content = '\n'.join(cleaned_lines) if code_content.strip(): code_parts.append(code_content) return code_parts def _llm_enhance(self, 
basic_results: List[str]) -> List[str]: """使用LLM增强代码提取结果(未实现)""" print(f"[DEBUG] 代码LLM增强功能尚未实现,返回原始结果") return basic_results
1+
# webmainbench/metrics/extractors/code_extractor.py
import re
from typing import List, Dict, Any
from .base_content_splitter import BaseContentSplitter


class CodeSplitter(BaseContentSplitter):
    """Extract code blocks from text.

    NOTE(review): reconstructed from a whitespace-collapsed scrape;
    statement boundaries were re-derived from Python syntax. Runtime
    strings are kept byte-for-byte.
    """

    def extract(self, text: str, field_name: str = None) -> str:
        """Extract code blocks, optionally LLM-enhanced, joined by newlines."""
        code_blocks = self.extract_basic(text)
        if self.should_use_llm(field_name):
            code_parts = self.enhance_with_llm(code_blocks)
        else:
            code_parts = code_blocks
        return '\n'.join(code_parts)

    def extract_basic(self, text: str) -> List[str]:
        """Extract code blocks with regular expressions."""
        code_parts = []
        # Fenced blocks delimited by triple backticks.
        backtick_pattern = r'(```[\s\S]*?```)'
        for match in re.finditer(backtick_pattern, text):
            code_segment = match.group(0)
            if code_segment.startswith('```'):
                lines = code_segment.split('\n')
                # Drop the opening ```lang line and the closing ``` line.
                content_lines = lines[1:-1]
                code_content = '\n'.join(content_lines)
                if code_content:
                    code_parts.append(code_content)
        # Indented code blocks: >= 2 consecutive lines indented by >= 4
        # spaces or tabs, framed by blank lines.
        indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)'
        for match in re.finditer(indent_pattern, text, re.MULTILINE):
            code_segment = match.group(1)
            # Validate: every non-blank line must be indented.
            # NOTE(review): literal indent widths were collapsed by the
            # scrape; 4 spaces assumed (consistent with line[4:] below) —
            # confirm against the repository file.
            lines = code_segment.split('\n')
            all_indented = all(
                line.startswith('    ') or line.startswith('\t') or not line.strip()
                for line in lines if line.strip()
            )
            if not all_indented:
                continue
            # Require at least two non-empty lines to look like code.
            non_empty_lines = [line.strip() for line in lines if line.strip()]
            if len(non_empty_lines) < 2:
                continue
            # Reject segments with obvious non-code features
            # (bullet/numbered lists, currency amounts, financial prose).
            has_list_features = any(
                re.match(r'^[-•*]\s', line)
                or re.match(r'^\d+\.\s', line)
                or re.search(r'\$[\d,]', line)
                or re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE)
                for line in non_empty_lines
            )
            if has_list_features:
                continue
            # Strip one level of indentation (4 spaces or one tab);
            # blank lines are dropped.
            cleaned_lines = []
            for line in code_segment.split('\n'):
                if line.strip():
                    if line.startswith('    '):
                        cleaned_lines.append(line[4:])
                    elif line.startswith('\t'):
                        cleaned_lines.append(line[1:])
                    else:
                        cleaned_lines.append(line)
            code_content = '\n'.join(cleaned_lines)
            if code_content.strip():
                code_parts.append(code_content)
        return code_parts

    def _llm_enhance(self, basic_results: List[str]) -> List[str]:
        """LLM enhancement for code (not implemented); returns input unchanged."""
        print(f"[DEBUG] 代码LLM增强功能尚未实现,返回原始结果")
        return basic_results

webmainbench/metrics/formula_extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from typing import List
33
from openai import OpenAI
44

5-
from .base_extractor import ContentExtractor
5+
from .base_content_splitter import BaseContentSplitter
66

7-
class FormulaExtractor(ContentExtractor):
7+
class FormulaSplitter(BaseContentSplitter):
88
"""从文本中提取数学公式"""
99

1010
def extract(self, text: str, field_name: str = None) -> str:

0 commit comments

Comments
 (0)