-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Expand file tree
/
Copy pathtest_interactive_interrupt.py
More file actions
203 lines (173 loc) Β· 7.5 KB
/
test_interactive_interrupt.py
File metadata and controls
203 lines (173 loc) Β· 7.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#!/usr/bin/env python3
"""Interactive interrupt test that mimics the exact CLI flow.
Starts an agent in a thread with a mock delegate_task that takes a while,
then simulates the user typing a message via _interrupt_queue.
Logs every step to stderr (which isn't affected by redirect_stdout)
so we can see exactly where the interrupt gets lost.
"""
import contextlib
import io
import json
import logging
import queue
import sys
import threading
import time
import os
# Send every log record to stderr: stdout may be redirected while the agent
# runs, and these step-by-step diagnostics must remain visible regardless.
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stderr,
    format="%(asctime)s [%(threadName)s] %(message)s",
)
log = logging.getLogger("interrupt_test")

# Put the parent of the directory containing this file at the front of
# sys.path so the project modules imported after this point can be resolved.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _REPO_ROOT)
from unittest.mock import MagicMock, patch
from run_agent import AIAgent, IterationBudget
from tools.interrupt import set_interrupt, is_interrupted
def make_slow_response(delay=2.0):
    """Build a fake ``chat.completions.create`` that blocks before answering.

    Args:
        delay: Seconds the returned callable sleeps before producing its
            response, standing in for a slow API round-trip.

    Returns:
        A callable that accepts (and ignores) arbitrary keyword arguments,
        sleeps for ``delay`` seconds, then returns a ``MagicMock`` shaped like
        a chat-completions response: a single choice whose message content is
        "Done with the task", with no tool calls, no refusal,
        finish_reason "stop", and fixed token usage
        (100 prompt / 10 completion / 110 total, no prompt_tokens_details).
    """
    def create(**kwargs):
        log.info("  Mock API call starting (will take %ss)...", delay)
        time.sleep(delay)
        log.info("  Mock API call completed")

        resp = MagicMock()
        choice = MagicMock()
        choice.message.content = "Done with the task"
        choice.message.tool_calls = None
        choice.message.refusal = None
        choice.finish_reason = "stop"
        resp.choices = [choice]

        resp.usage.prompt_tokens = 100
        resp.usage.completion_tokens = 10
        resp.usage.total_tokens = 110
        # Explicit None so attribute access doesn't yield an auto-created MagicMock.
        resp.usage.prompt_tokens_details = None
        return resp

    return create
def _build_parent_agent():
    """Return a hand-configured AIAgent to act as the delegating parent.

    ``AIAgent.__new__`` bypasses ``__init__``, so only the attributes assigned
    here exist on the instance. Presumably these are the ones the delegate path
    reads from its parent; verify against tools.delegate_tool if it changes.
    """
    parent = AIAgent.__new__(AIAgent)
    parent._interrupt_requested = False
    parent._interrupt_message = None
    parent._active_children = []
    parent._active_children_lock = threading.Lock()
    parent.quiet_mode = True
    parent.model = "test/model"
    parent.base_url = "http://localhost:1"
    parent.api_key = "test"
    parent.provider = "test"
    parent.api_mode = "chat_completions"
    parent.platform = "cli"
    parent.enabled_toolsets = ["terminal", "file"]
    parent.providers_allowed = None
    parent.providers_ignored = None
    parent.providers_order = None
    parent.provider_sort = None
    parent.max_tokens = None
    parent.reasoning_config = None
    parent.prefill_messages = None
    parent._session_db = None
    parent._delegate_depth = 0
    parent._delegate_spinner = None
    parent.tool_progress_callback = None
    parent.iteration_budget = IterationBudget(max_total=100)
    parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
    return parent


def _install_logged_interrupt(parent):
    """Shadow ``parent.interrupt`` with a wrapper that logs state around the real call.

    The wrapper logs the message and child count, delegates to the class-level
    ``AIAgent.interrupt``, then logs the parent's and each child's
    ``_interrupt_requested`` flag.
    """
    original_interrupt = AIAgent.interrupt

    def logged_interrupt(msg=None):
        log.info("parent.interrupt() called with: %r", msg)
        log.info("  _active_children count: %d", len(parent._active_children))
        original_interrupt(parent, msg)
        log.info("  After interrupt: _interrupt_requested=%s", parent._interrupt_requested)
        for i, child in enumerate(parent._active_children):
            log.info("  Child %d._interrupt_requested=%s", i, child._interrupt_requested)

    # Instance attribute shadows the class method for this parent only.
    parent.interrupt = logged_interrupt


def _agent_thread_func(parent, child_running, agent_result):
    """Thread body mirroring the agent_thread started by cli.py's chat().

    Runs one delegated child against a mocked OpenAI client whose completion
    call takes 3s. ``child_running`` is set once the child's ``__init__``
    finishes, and the child's result is stored in ``agent_result[0]``.
    """
    log.info("agent_thread starting")
    with patch("run_agent.OpenAI") as mock_openai:
        mock_client = MagicMock()
        mock_client.chat.completions.create = make_slow_response(delay=3.0)
        mock_client.close = MagicMock()
        mock_openai.return_value = mock_client

        from tools.delegate_tool import _run_single_child

        # Capture the real __init__ before patching so the wrapper can call it.
        original_init = AIAgent.__init__

        def patched_init(self_agent, *a, **kw):
            log.info("Child AIAgent.__init__ called")
            original_init(self_agent, *a, **kw)
            child_running.set()  # tell the main thread the child exists
            log.info(
                "Child started, parent._active_children = %d",
                len(parent._active_children),
            )

        with patch.object(AIAgent, "__init__", patched_init):
            result = _run_single_child(
                task_index=0,
                goal="Do a slow thing",
                context=None,
                toolsets=["terminal"],
                model="test/model",
                max_iterations=3,
                parent_agent=parent,
                task_count=1,
                override_provider="test",
                override_base_url="http://localhost:1",
                override_api_key="test",
                override_api_mode="chat_completions",
            )
    agent_result[0] = result
    log.info("agent_thread finished. Result status: %s", result.get("status"))


def _poll_interrupt_queue(interrupt_queue, agent_thread, parent):
    """Mimic chat()'s polling loop until a message is forwarded or the agent exits.

    The first truthy queued message is passed to ``parent.interrupt()`` and
    polling stops. Falsy messages are ignored, as in the original loop.
    """
    poll_count = 0
    while agent_thread.is_alive():
        try:
            interrupt_msg = interrupt_queue.get(timeout=0.1)
        except queue.Empty:
            poll_count += 1
            if poll_count % 20 == 0:  # 20 polls x 0.1s timeout ~= every 2s
                log.info("  Still polling (%d iterations)...", poll_count)
            continue
        if interrupt_msg:
            log.info("Got interrupt message from queue: %r", interrupt_msg)
            log.info("  Calling parent.interrupt()...")
            parent.interrupt(interrupt_msg)
            log.info("  parent.interrupt() returned. Breaking poll loop.")
            return


def _evaluate_result(result, elapsed):
    """Return 0 (PASS) if the child was interrupted promptly, else 1 (FAIL).

    The mocked API call sleeps 3s and the interrupt is sent roughly 1s into
    it. A join under 2s therefore means the child did not wait out the
    in-flight call.
    """
    if not result:
        print("FAIL: No result returned", file=sys.stderr)
        return 1
    status = result.get("status")
    log.info("Result status: %s", status)
    log.info("Result duration: %ss", result.get("duration_seconds"))
    if status == "interrupted" and elapsed < 2.0:
        print("PASS: Interrupt worked correctly!", file=sys.stderr)
        return 0
    print(f"FAIL: status={status}, elapsed={elapsed:.2f}s", file=sys.stderr)
    return 1


def _run_interrupt_scenario():
    """Drive the CLI-like flow: start the agent, send an interrupt, judge the outcome."""
    parent = _build_parent_agent()
    _install_logged_interrupt(parent)

    interrupt_queue = queue.Queue()
    child_running = threading.Event()
    agent_result = [None]  # single-slot box written by the agent thread

    # Start the agent thread (like chat() does).
    agent_thread = threading.Thread(
        target=_agent_thread_func,
        args=(parent, child_running, agent_result),
        name="agent_thread",
        daemon=True,
    )
    agent_thread.start()

    if not child_running.wait(timeout=10):
        print("FAIL: Child never started", file=sys.stderr)
        return 1

    # Give the child time to enter its main loop and start the API call.
    time.sleep(1.0)

    # Simulate the user typing a message (like handle_enter does).
    log.info("Simulating user typing 'Hey stop that'")
    interrupt_queue.put("Hey stop that")

    log.info("Starting interrupt queue polling (like chat())")
    _poll_interrupt_queue(interrupt_queue, agent_thread, parent)

    log.info("Waiting for agent_thread to join...")
    t0 = time.monotonic()
    agent_thread.join(timeout=10)
    elapsed = time.monotonic() - t0
    if agent_thread.is_alive():
        # Report a hung agent explicitly rather than as a missing result.
        print(f"FAIL: agent_thread still running after {elapsed:.2f}s", file=sys.stderr)
        return 1
    log.info("agent_thread joined after %.2fs", elapsed)

    return _evaluate_result(agent_result[0], elapsed)


def main() -> int:
    """Run the interactive interrupt scenario and return a process exit code.

    Returns:
        0 if the delegated child reported status "interrupted" and the agent
        thread joined in under 2s; 1 otherwise.

    The global interrupt flag is cleared before the run, and again on every
    exit path (including exceptions) so later code starts from a clean state.
    """
    set_interrupt(False)
    try:
        return _run_interrupt_scenario()
    finally:
        set_interrupt(False)
if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    raise SystemExit(main())