Skip to content

Commit 6182dfc

Browse files
committed
将prompt设置为基类的变量
1 parent 643e206 commit 6182dfc

File tree

2 files changed

+43
-45
lines changed

2 files changed

+43
-45
lines changed

webmainbench/metrics/base_content_splitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from abc import ABC, abstractmethodfrom typing import List, Dict, Anyimport osimport hashlibimport jsonfrom openai import OpenAIclass BaseContentSplitter(ABC): """抽象基类,用于从文本中提取特定类型的内容""" def __init__(self, config: Dict[str, Any] = None): """初始化提取器""" self.config = config or {} # 保留这行代码,用于控制是否使用LLM self.use_llm = self.config.get('use_llm', True) # 初始化OpenAI客户端(如果配置了LLM) if self.use_llm and self.config.get('llm_base_url') and self.config.get('llm_api_key'): self.client = OpenAI( base_url=self.config.get('llm_base_url', ""), api_key=self.config.get('llm_api_key', "") ) else: self.client = None self.cache_dir = self.config.get('cache_dir', os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.cache')) os.makedirs(self.cache_dir, exist_ok=True) @abstractmethod def extract(self, text: str, field_name: str = None) -> str: """提取特定类型的内容""" pass @abstractmethod def extract_basic(self, text: str) -> List[str]: """使用基本方法提取内容(通常是正则表达式)""" pass def should_use_llm(self, field_name: str) -> bool: """判断是否应该使用LLM进行增强提取""" if not self.use_llm: return False # 默认逻辑:对groundtruth内容不使用LLM,对其他内容使用 if field_name == "groundtruth_content": print(f"[DEBUG] 检测到groundtruth内容,不使用LLM") return False return True def enhance_with_llm(self, basic_results: List[str], cache_key: str = None) -> List[str]: """使用LLM增强基本提取结果""" if not basic_results: print(f"[DEBUG] 输入内容为空,跳过LLM增强") return [] # 生成缓存键 if cache_key is None: content_str = '\n'.join(basic_results) cache_key = hashlib.md5(content_str.encode('utf-8')).hexdigest() cache_file = os.path.join(self.cache_dir, f'{self.__class__.__name__.lower()}_cache_{cache_key}.json') # 检查缓存 if os.path.exists(cache_file): try: with open(cache_file, 'r', encoding='utf-8') as f: cached_result = json.load(f) print(f"[DEBUG] 从缓存加载LLM增强结果: {len(cached_result)} 个") return cached_result except Exception as e: print(f"[DEBUG] 缓存读取失败: {e}") # 实际的LLM增强逻辑 try: enhanced_results = self._llm_enhance(basic_results) # 保存缓存 try: with open(cache_file, 'w', encoding='utf-8') as f: json.dump(enhanced_results, f, ensure_ascii=False, indent=2) print(f"[DEBUG] LLM增强结果已缓存到: {cache_file}") except Exception as e: print(f"[DEBUG] 缓存保存失败: {e}") return enhanced_results except Exception as e: print(f"[DEBUG] LLM增强失败: {type(e).__name__}: {e}") return basic_results @abstractmethod def _llm_enhance(self, basic_results: List[str]) -> List[str]: """使用LLM增强基本提取结果的具体实现""" pass
1+
from abc import ABC, abstractmethodfrom typing import List, Dict, Anyimport osimport hashlibimport jsonfrom openai import OpenAIclass BaseContentSplitter(ABC): """抽象基类,用于从文本中提取特定类型的内容""" # 默认的LLM提示词模板 DEFAULT_LLM_PROMPT = """请处理以下内容: {content} """ def __init__(self, config: Dict[str, Any] = None): """初始化提取器""" self.config = config or {} # 保留这行代码,用于控制是否使用LLM self.use_llm = self.config.get('use_llm', True) # 初始化OpenAI客户端(如果配置了LLM) if self.use_llm and self.config.get('llm_base_url') and self.config.get('llm_api_key'): self.client = OpenAI( base_url=self.config.get('llm_base_url', ""), api_key=self.config.get('llm_api_key', "") ) else: self.client = None self.cache_dir = self.config.get('cache_dir', os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.cache')) os.makedirs(self.cache_dir, exist_ok=True) @abstractmethod def extract(self, text: str, field_name: str = None) -> str: """提取特定类型的内容""" pass @abstractmethod def extract_basic(self, text: str) -> List[str]: """使用基本方法提取内容(通常是正则表达式)""" pass def should_use_llm(self, field_name: str) -> bool: """判断是否应该使用LLM进行增强提取""" if not self.use_llm: return False # 默认逻辑:对groundtruth内容不使用LLM,对其他内容使用 if field_name == "groundtruth_content": print(f"[DEBUG] 检测到groundtruth内容,不使用LLM") return False return True def enhance_with_llm(self, basic_results: List[str], cache_key: str = None) -> List[str]: """使用LLM增强基本提取结果""" if not basic_results: print(f"[DEBUG] 输入内容为空,跳过LLM增强") return [] # 生成缓存键 if cache_key is None: content_str = '\n'.join(basic_results) cache_key = hashlib.md5(content_str.encode('utf-8')).hexdigest() cache_file = os.path.join(self.cache_dir, f'{self.__class__.__name__.lower()}_cache_{cache_key}.json') # 检查缓存 if os.path.exists(cache_file): try: with open(cache_file, 'r', encoding='utf-8') as f: cached_result = json.load(f) print(f"[DEBUG] 从缓存加载LLM增强结果: {len(cached_result)} 个") return cached_result except Exception as e: print(f"[DEBUG] 缓存读取失败: {e}") # 实际的LLM增强逻辑 try: enhanced_results = self._llm_enhance(basic_results) # 保存缓存 try: with open(cache_file, 'w', encoding='utf-8') as f: json.dump(enhanced_results, f, ensure_ascii=False, indent=2) print(f"[DEBUG] LLM增强结果已缓存到: {cache_file}") except Exception as e: print(f"[DEBUG] 缓存保存失败: {e}") return enhanced_results except Exception as e: print(f"[DEBUG] LLM增强失败: {type(e).__name__}: {e}") return basic_results @abstractmethod def _llm_enhance(self, basic_results: List[str]) -> List[str]: """使用LLM增强基本提取结果的具体实现""" pass

webmainbench/metrics/formula_extractor.py

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,54 @@
66
class FormulaSplitter(BaseContentSplitter):
77
"""从文本中提取数学公式"""
88

9+
DEFAULT_LLM_PROMPT = '''任务:请从以下正则表达式提取的内容中,识别并保留真正的LaTeX数学公式,剔除货币形式的内容。
10+
11+
### 识别规则
12+
**真正的数学公式**(保留):
13+
- 包含数学符号:+ - × ÷ = < > ≤ ≥ ± ∞ ∑ ∫ ∂ √ ^ _ { } 等
14+
- 包含希腊字母:α β γ δ θ λ μ π σ ω 等
15+
- 包含LaTeX命令:\\frac \\sum \\int \\sqrt \\alpha \\beta \\sin \\cos 等
16+
- 包含数学表达式:变量、函数、方程等
17+
18+
**货币形式内容**(剔除):
19+
- 仅包含数字、逗号、小数点的价格:如 1,150.00
20+
- 纯粹的金额数值:如 25.99、1,200、5.50
21+
- 不包含任何数学运算符或数学符号的数字
22+
23+
### 处理要求
24+
1. **严格区分**:只保留真正的数学公式,剔除所有货币价格
25+
2. **格式标准化**:统一公式格式,确保LaTeX语法正确
26+
3. **保持原意**:不修改数学公式内容
27+
28+
### 输出格式
29+
- 每个有效的数学公式独占一行
30+
- 只输出公式内容,不包含$符号或其他包装
31+
- 如果输入不是有效的数学公式(如货币),则输出<空>
32+
- 按原顺序输出保留的公式
33+
34+
### 示例 1 (剔除后有有效公式)
35+
输入:1,150.00 → 剔除(货币)
36+
输入:x^2 + y^2 = r^2 → 保留(数学公式)
37+
输入:25.99 → 剔除(货币)
38+
输入:\\frac{a}{b} + c → 保留(数学公式)
39+
40+
### 示例 2 (剔除后无有效公式)
41+
输入:1,150.00 → 剔除(货币)
42+
输入:25.99 → 剔除(货币)
43+
44+
输出:<空>
45+
46+
注意,输出结果中不要添加任何解释!。
47+
[输入内容列表开始]'''
48+
949
def extract(self, text: str, field_name: str = None) -> str:
1050
"""提取数学公式"""
1151
regex_formulas = self.extract_basic(text)
12-
1352
if self.should_use_llm(field_name):
1453
print(f"[DEBUG] 使用LLM增强公式提取")
1554
formula_parts = self.enhance_with_llm(regex_formulas)
1655
else:
1756
formula_parts = regex_formulas
18-
1957
return '\n'.join(formula_parts)
2058

2159
def extract_basic(self, text: str) -> List[str]:
@@ -44,51 +82,11 @@ def _llm_enhance(self, basic_results: List[str]) -> List[str]:
4482

4583
formulas_text = '\n'.join(basic_results)
4684

47-
CORRECTION_PROMPT = '''任务:请从以下正则表达式提取的内容中,识别并保留真正的LaTeX数学公式,剔除货币形式的内容。
48-
49-
### 识别规则
50-
**真正的数学公式**(保留):
51-
- 包含数学符号:+ - × ÷ = < > ≤ ≥ ± ∞ ∑ ∫ ∂ √ ^ _ { } 等
52-
- 包含希腊字母:α β γ δ θ λ μ π σ ω 等
53-
- 包含LaTeX命令:\\frac \\sum \\int \\sqrt \\alpha \\beta \\sin \\cos 等
54-
- 包含数学表达式:变量、函数、方程等
55-
56-
**货币形式内容**(剔除):
57-
- 仅包含数字、逗号、小数点的价格:如 1,150.00
58-
- 纯粹的金额数值:如 25.99、1,200、5.50
59-
- 不包含任何数学运算符或数学符号的数字
60-
61-
### 处理要求
62-
1. **严格区分**:只保留真正的数学公式,剔除所有货币价格
63-
2. **格式标准化**:统一公式格式,确保LaTeX语法正确
64-
3. **保持原意**:不修改数学公式内容
65-
66-
### 输出格式
67-
- 每个有效的数学公式独占一行
68-
- 只输出公式内容,不包含$符号或其他包装
69-
- 如果输入不是有效的数学公式(如货币),则输出<空>
70-
- 按原顺序输出保留的公式
71-
72-
### 示例 1 (剔除后有有效公式)
73-
输入:1,150.00 → 剔除(货币)
74-
输入:x^2 + y^2 = r^2 → 保留(数学公式)
75-
输入:25.99 → 剔除(货币)
76-
输入:\\frac{a}{b} + c → 保留(数学公式)
77-
78-
### 示例 2 (剔除后无有效公式)
79-
输入:1,150.00 → 剔除(货币)
80-
输入:25.99 → 剔除(货币)
81-
82-
输出:<空>
83-
84-
注意,输出结果中不要添加任何解释!。
85-
[输入内容列表开始]'''
86-
8785
response = self.client.chat.completions.create(
8886
model=self.config.get('llm_model', "deepseek-chat"),
8987
temperature=0,
9088
messages=[
91-
{"role": "user", "content": CORRECTION_PROMPT + f"\n{formulas_text}\n" + '''[输入内容列表结束]
89+
{"role": "user", "content": self.DEFAULT_LLM_PROMPT + f"\n{formulas_text}\n" + '''[输入内容列表结束]
9290
---
9391
请按要求识别并输出真正的数学公式,剔除货币形式的内容。
9492
---'''}
@@ -102,4 +100,4 @@ def _llm_enhance(self, basic_results: List[str]) -> List[str]:
102100
elif not result_text:
103101
return []
104102
else:
105-
return [line.strip() for line in result_text.split('\n') if line.strip()]
103+
return [line.strip() for line in result_text.split('\n') if line.strip()]

0 commit comments

Comments
 (0)