Skip to content

Commit b34a46f

Browse files
authored
Merge pull request #41 from pekopoke/dev
fix code match
2 parents f5eed23 + 3514e04 commit b34a46f

2 files changed

Lines changed: 101 additions & 32 deletions

File tree

tests/test_code_extraction.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,6 @@ def test_code_block(self):
6868
'ijkl'
6969
""")
7070
self.assertEqual(result['code'], expected_code.strip())
71-
72-
# 验证清理后的文本
73-
expected_text = """
74-
I have the following string:
75-
How can I get the last four characters and store them in a string using Python?
76-
Like this:
77-
"""
78-
self.assertEqual(result['text'], text)
7971
self.assertEqual(result['formula'], '')
8072

8173
# def test_code_with_leading_trailing_spaces(self):
@@ -92,5 +84,28 @@ def test_code_block(self):
9284
# self.assertEqual(result['code'], '') # 不应该匹配多行行内代码
9385
# self.assertEqual(result['text'], text) # 原样保留
9486

87+
def test_indent_code_block(self):
88+
"""测试代码块"""
89+
text = """
90+
I have the following string: `"aaaabbbb"`
91+
How can I get the last four characters and store them in a string using Python?
92+
Like this:
93+
94+
print("hello world")
95+
print("hi")
96+
97+
"""
98+
99+
result = BaseMetric._extract_from_markdown(text)
100+
101+
# 验证提取的代码
102+
expected_code = ("""
103+
print("hello world")
104+
print("hi")
105+
""")
106+
self.assertEqual(result['code'], expected_code.strip())
107+
self.assertEqual(result['formula'], '')
108+
109+
95110
if __name__ == '__main__':
96111
unittest.main()

webmainbench/metrics/base.py

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -202,41 +202,95 @@ def _extract_from_markdown(text: str) -> Dict[str, str]:
202202
# 收集所有需要移除的内容片段
203203
extracted_segments = []
204204
code_parts = []
205-
# 同时匹配行内代码 `...` 和代码块 ```...```
206-
pattern = r'(```[\s\S]*?```)'
207-
for match in re.finditer(pattern, text):
205+
# # 同匹配行间代码块 ```...```
206+
# pattern = r'(```[\s\S]*?```)'
207+
# for match in re.finditer(pattern, text):
208+
# code_segment = match.group(0)
209+
# extracted_segments.append(code_segment)
210+
#
211+
# if code_segment.startswith('```'):
212+
# # 处理代码块(保留内部缩进)
213+
# lines = code_segment.split('\n')
214+
# # 移除首尾的```标记
215+
# content_lines = lines[1:-1]
216+
# # 保留原始缩进,只拼接内容
217+
# code_content = '\n'.join(content_lines)
218+
# else:
219+
# # 处理行内代码(只去除外层`和前后空格)
220+
# code_content = code_segment[1:-1].strip()
221+
#
222+
# if code_content: # 只添加非空内容
223+
# code_parts.append(code_content)
224+
225+
# 1. 首先处理三个反引号包裹的代码块(优先级最高)
226+
backtick_pattern = r'(```[\s\S]*?```)'
227+
for match in re.finditer(backtick_pattern, text):
208228
code_segment = match.group(0)
209-
extracted_segments.append(code_segment)
210229

211230
if code_segment.startswith('```'):
212-
# 处理代码块(保留内部缩进)
231+
# 处理代码块
213232
lines = code_segment.split('\n')
214233
# 移除首尾的```标记
215234
content_lines = lines[1:-1]
216-
# 保留原始缩进,只拼接内容
217235
code_content = '\n'.join(content_lines)
218236
else:
219-
# 处理行内代码(只去除外层`和前后空格)
237+
# 处理行内代码
220238
code_content = code_segment[1:-1].strip()
221239

222-
if code_content: # 只添加非空内容
240+
if code_content:
223241
code_parts.append(code_content)
224-
225-
# # 提取代码
226-
# code_parts = []
227-
# # 代码块 ```code```
228-
# for match in re.finditer(r'```[\s\S]*?```', text):
229-
# code_block = match.group(0)
230-
# extracted_segments.append(code_block)
231-
# code_parts.append(code_block.strip('`').strip())
232-
#
233-
# # 行内代码 `code`
234-
# for match in re.finditer(r'`([^`]+)`', text):
235-
# inline_code_full = match.group(0) # 包含反引号的完整匹配
236-
# inline_code_content = match.group(1) # 只是内容
237-
# extracted_segments.append(inline_code_full)
238-
# code_parts.append(inline_code_content)
239-
242+
243+
# 2. 处理缩进代码块 - 使用更精确的匹配
244+
# 匹配模式:前面有空行 + 连续的多行缩进内容 + 后面有空行
245+
# 关键:要求所有匹配的行都是缩进的
246+
indent_pattern = r'(?:\n\s*\n)((?:(?: {4,}|\t+)[^\n]*(?:\n|$)){2,})(?=\n\s*\n|$)'
247+
248+
for match in re.finditer(indent_pattern, text, re.MULTILINE):
249+
code_segment = match.group(1)
250+
251+
# 验证:确保所有行都是缩进的(避免混合缩进和非缩进行)
252+
lines = code_segment.split('\n')
253+
all_indented = all(
254+
line.startswith(' ') or line.startswith('\t') or not line.strip()
255+
for line in lines
256+
if line.strip() # 空行不算
257+
)
258+
259+
if not all_indented:
260+
continue # 跳过包含非缩进行的块
261+
262+
# 进一步验证代码特征
263+
non_empty_lines = [line.strip() for line in lines if line.strip()]
264+
if len(non_empty_lines) < 2: # 至少2行非空内容
265+
continue
266+
267+
# 检查是否有明显的非代码特征
268+
has_list_features = any(
269+
re.match(r'^[-•*]\s', line) or
270+
re.match(r'^\d+\.\s', line) or
271+
re.search(r'\$[\d,]', line) or
272+
re.search(r'\b(million|billion|thousand)\b', line, re.IGNORECASE)
273+
for line in non_empty_lines
274+
)
275+
276+
if has_list_features:
277+
continue # 跳过列表内容
278+
279+
# 清理代码段
280+
cleaned_lines = []
281+
for line in code_segment.split('\n'):
282+
if line.strip():
283+
if line.startswith(' '):
284+
cleaned_lines.append(line[4:])
285+
elif line.startswith('\t'):
286+
cleaned_lines.append(line[1:])
287+
else:
288+
cleaned_lines.append(line)
289+
290+
code_content = '\n'.join(cleaned_lines)
291+
if code_content.strip():
292+
code_parts.append(code_content)
293+
240294
# 提取公式
241295
formula_parts = []
242296
# 统一的公式提取模式

0 commit comments

Comments
 (0)