Skip to content

Commit 129c079

Browse files
authored
Merge pull request #14 from e06084/main
feat: update llm-webkit extract
2 parents 07b095c + 782c8ff commit 129c079

5 files changed

Lines changed: 57 additions & 122 deletions
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
extractor,dataset,total_samples,success_rate,overall,code_edit,formula_edit,table_TEDS,table_edit,text_edit
2-
llm-webkit,sample_dataset,4,0.75,0.8667,1.0,1.0,1.0,1.0,0.3333
2+
llm-webkit,sample_dataset,4,0.5,0.9,1.0,1.0,1.0,1.0,0.5

results/sample_dataset_llm-webkit_evaluation_results.json

Lines changed: 19 additions & 107 deletions
Large diffs are not rendered by default.

results/sample_dataset_with_llm-webkit_extraction.jsonl

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

webmainbench/data/saver.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def save_evaluation_results(results: Union["EvaluationResult", Dict[str, Any]],
9898
else:
9999
results_dict = results
100100

101+
# 移除extracted_content和extracted_content_list字段以减少文件大小
102+
results_dict = DataSaver._remove_content_fields(results_dict)
103+
101104
if format.lower() == "json":
102105
with open(file_path, 'w', encoding='utf-8') as f:
103106
json.dump(results_dict, f, indent=2, ensure_ascii=False)
@@ -265,6 +268,30 @@ def _save_jsonl_list(data_list: List[Dict[str, Any]], file_path: Union[str, Path
265268
json.dump(item, f, ensure_ascii=False)
266269
f.write('\n')
267270

271+
@staticmethod
def _remove_content_fields(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *data* without the bulky extraction fields.

    Every ``extracted_content`` and ``extracted_content_list`` key is
    dropped from *data* and from any dict nested inside it (through dicts
    and lists at any depth) so that saved result files stay small.

    The input is never mutated. Dicts and lists are rebuilt as plain
    ``dict``/``list`` objects; every other value is deep-copied so the
    result shares no mutable state with the input.

    Args:
        data: Result structure to clean, typically a nested dict.

    Returns:
        A new structure equal to *data* minus the two content fields.
    """
    import copy

    dropped_keys = ("extracted_content", "extracted_content_list")

    def strip(node):
        # Filter while copying: the dropped values are usually the largest
        # part of the payload, so they are skipped rather than deep-copied
        # first and deleted afterwards.
        if isinstance(node, dict):
            return {
                key: strip(value)
                for key, value in node.items()
                if key not in dropped_keys
            }
        if isinstance(node, list):
            return [strip(item) for item in node]
        # Leaves (and containers other than dict/list, e.g. tuples) are
        # copied as-is and not searched.
        return copy.deepcopy(node)

    return strip(data)
294+
268295
@staticmethod
269296
def append_intermediate_results(results: List[Dict[str, Any]],
270297
file_path: Union[str, Path]) -> None:

webmainbench/extractors/llm_webkit_extractor.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -370,20 +370,16 @@ def _load_vllm_model(self):
370370
trust_remote_code=True
371371
)
372372

373-
# vLLM配置
373+
# vLLM配置 - 参考ray_test_qa.py的简化配置
374374
model_kwargs = {
375375
"model": self.inference_config.model_path,
376376
"trust_remote_code": True,
377377
"dtype": self.inference_config.dtype,
378378
"tensor_parallel_size": self.inference_config.tensor_parallel_size,
379-
"max_model_len": self.inference_config.max_tokens,
380-
"max_num_batched_tokens": max(self.inference_config.max_tokens, 8192),
381-
"gpu_memory_utilization": self.inference_config.gpu_memory_utilization,
382-
"enforce_eager": self.inference_config.enforce_eager,
383-
"disable_custom_all_reduce": True,
384-
"load_format": "auto",
385379
}
386380

381+
print(f"🔧 vLLM配置: {model_kwargs}")
382+
387383
self.model = LLM(**model_kwargs)
388384

389385
# 初始化token状态管理器
@@ -397,8 +393,8 @@ def _load_vllm_model(self):
397393
print("✅ vLLM模型加载成功!")
398394

399395
except Exception as e:
400-
print(f"⚠️ vLLM加载失败,回退到transformers: {e}")
401-
self._load_transformers_model()
396+
print(f"vLLM加载失败: {e}")
397+
raise RuntimeError(f"vLLM模型加载失败: {e}")
402398

403399
def _create_prompt(self, simplified_html: str) -> str:
404400
"""创建分类提示."""
@@ -463,7 +459,7 @@ def _generate_with_transformers(self, prompt: str) -> str:
463459

464460
except Exception as e:
465461
print(f"⚠️ transformers生成失败: {e}")
466-
return "{}"
462+
raise RuntimeError(f"transformers生成失败: {e}")
467463

468464
def _extract_json_from_text(self, text: str) -> str:
469465
"""从生成的文本中提取JSON"""

0 commit comments

Comments
 (0)