|
1 | 1 | from astrbot.api import sp, star |
2 | 2 | from astrbot.api.event import AstrMessageEvent, MessageEventResult |
| 3 | +from astrbot.core import logger |
3 | 4 | from astrbot.core.agent.runners.deerflow.constants import ( |
| 5 | + DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, |
4 | 6 | DEERFLOW_PROVIDER_TYPE, |
5 | 7 | DEERFLOW_THREAD_ID_KEY, |
6 | 8 | ) |
| 9 | +from astrbot.core.agent.runners.deerflow.deerflow_api_client import DeerFlowAPIClient |
7 | 10 | from astrbot.core.utils.active_event_registry import active_event_registry |
8 | 11 |
|
9 | 12 | from .utils.rst_scene import RstScene |
|
17 | 20 | THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys()) |
18 | 21 |
|
19 | 22 |
|
| 23 | +async def _cleanup_deerflow_thread_if_present( |
| 24 | + context: star.Context, |
| 25 | + umo: str, |
| 26 | +) -> None: |
| 27 | + try: |
| 28 | + thread_id = await sp.get_async( |
| 29 | + scope="umo", |
| 30 | + scope_id=umo, |
| 31 | + key=DEERFLOW_THREAD_ID_KEY, |
| 32 | + default="", |
| 33 | + ) |
| 34 | + if not thread_id: |
| 35 | + return |
| 36 | + |
| 37 | + cfg = context.get_config(umo=umo) |
| 38 | + provider_id = cfg["provider_settings"].get( |
| 39 | + DEERFLOW_AGENT_RUNNER_PROVIDER_ID_KEY, |
| 40 | + "", |
| 41 | + ) |
| 42 | + if not provider_id: |
| 43 | + return |
| 44 | + |
| 45 | + merged_provider_config = context.provider_manager.get_provider_config_by_id( |
| 46 | + provider_id, |
| 47 | + merged=True, |
| 48 | + ) |
| 49 | + if not merged_provider_config: |
| 50 | + logger.warning( |
| 51 | + "Failed to resolve DeerFlow provider config for remote thread cleanup: provider_id=%s", |
| 52 | + provider_id, |
| 53 | + ) |
| 54 | + return |
| 55 | + |
| 56 | + client = DeerFlowAPIClient( |
| 57 | + api_base=merged_provider_config.get( |
| 58 | + "deerflow_api_base", |
| 59 | + "http://127.0.0.1:2026", |
| 60 | + ), |
| 61 | + api_key=merged_provider_config.get("deerflow_api_key", ""), |
| 62 | + auth_header=merged_provider_config.get("deerflow_auth_header", ""), |
| 63 | + proxy=merged_provider_config.get("proxy", ""), |
| 64 | + ) |
| 65 | + try: |
| 66 | + await client.delete_thread(thread_id) |
| 67 | + finally: |
| 68 | + try: |
| 69 | + await client.close() |
| 70 | + except Exception as e: |
| 71 | + logger.warning( |
| 72 | + "Failed to close DeerFlow API client after thread cleanup: %s", |
| 73 | + e, |
| 74 | + ) |
| 75 | + except Exception as e: |
| 76 | + logger.warning( |
| 77 | + "Failed to clean up DeerFlow thread for session %s: %s", |
| 78 | + umo, |
| 79 | + e, |
| 80 | + ) |
| 81 | + |
| 82 | + |
| 83 | +async def _clear_third_party_agent_runner_state( |
| 84 | + context: star.Context, |
| 85 | + umo: str, |
| 86 | + agent_runner_type: str, |
| 87 | +) -> None: |
| 88 | + session_key = THIRD_PARTY_AGENT_RUNNER_KEY.get(agent_runner_type) |
| 89 | + if not session_key: |
| 90 | + return |
| 91 | + |
| 92 | + if agent_runner_type == DEERFLOW_PROVIDER_TYPE: |
| 93 | + await _cleanup_deerflow_thread_if_present(context, umo) |
| 94 | + |
| 95 | + await sp.remove_async( |
| 96 | + scope="umo", |
| 97 | + scope_id=umo, |
| 98 | + key=session_key, |
| 99 | + ) |
| 100 | + |
| 101 | + |
20 | 102 | class ConversationCommands: |
21 | 103 | def __init__(self, context: star.Context) -> None: |
22 | 104 | self.context = context |
@@ -65,10 +147,10 @@ async def reset(self, message: AstrMessageEvent) -> None: |
65 | 147 | agent_runner_type = cfg["provider_settings"]["agent_runner_type"] |
66 | 148 | if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: |
67 | 149 | active_event_registry.stop_all(umo, exclude=message) |
68 | | - await sp.remove_async( |
69 | | - scope="umo", |
70 | | - scope_id=umo, |
71 | | - key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], |
| 150 | + await _clear_third_party_agent_runner_state( |
| 151 | + self.context, |
| 152 | + umo, |
| 153 | + agent_runner_type, |
72 | 154 | ) |
73 | 155 | message.set_result( |
74 | 156 | MessageEventResult().message("✅ Conversation reset successfully.") |
@@ -139,10 +221,10 @@ async def new_conv(self, message: AstrMessageEvent) -> None: |
139 | 221 | agent_runner_type = cfg["provider_settings"]["agent_runner_type"] |
140 | 222 | if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: |
141 | 223 | active_event_registry.stop_all(message.unified_msg_origin, exclude=message) |
142 | | - await sp.remove_async( |
143 | | - scope="umo", |
144 | | - scope_id=message.unified_msg_origin, |
145 | | - key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], |
| 224 | + await _clear_third_party_agent_runner_state( |
| 225 | + self.context, |
| 226 | + message.unified_msg_origin, |
| 227 | + agent_runner_type, |
146 | 228 | ) |
147 | 229 | message.set_result( |
148 | 230 | MessageEventResult().message("✅ New conversation created.") |
|
0 commit comments