Wire model overrides end-to-end: API → runtime → frame engine
- /api/chat accepts {"models": {"role": "provider/model"}} for per-request overrides
- runtime.handle_message passes model_overrides through to frame engine
- All 4 graph definitions (v1-v4) now declare MODELS dicts
- test_graph_has_models expanded to verify all graphs
- 11/11 engine tests green
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
cf42951b77
commit
58734c34d2
147
agent/api.py
147
agent/api.py
@ -32,6 +32,16 @@ _sse_subscribers: list[Queue] = []
|
||||
_test_ws_clients: list[WebSocket] = [] # /ws/test subscribers
|
||||
_trace_ws_clients: list[WebSocket] = [] # /ws/trace subscribers
|
||||
|
||||
# Debug command channel: AI → assay → nyx browser (via SSE) → assay → AI
|
||||
_debug_queues: list[Queue] = [] # per-client SSE queues for debug commands
|
||||
_debug_results: dict[str, asyncio.Event] = {} # cmd_id → event (set when result arrives)
|
||||
_debug_result_data: dict[str, dict] = {} # cmd_id → result payload
|
||||
|
||||
# Test results (in-memory, fed by test-runner Job in assay-test namespace)
|
||||
_test_results: list[dict] = []
|
||||
_test_results_subscribers: list[Queue] = []
|
||||
_test_run_id: str = ""
|
||||
|
||||
|
||||
async def _broadcast_test(event: dict):
|
||||
"""Push to all /ws/test subscribers."""
|
||||
@ -117,14 +127,32 @@ def register_routes(app):
|
||||
|
||||
@app.get("/api/health-stream")
|
||||
async def health_stream(user=Depends(require_auth)):
|
||||
"""SSE heartbeat stream — client uses this for presence detection."""
|
||||
"""SSE heartbeat + debug command stream."""
|
||||
q: Queue = Queue(maxsize=100)
|
||||
_debug_queues.append(q)
|
||||
|
||||
async def generate():
|
||||
try:
|
||||
while True:
|
||||
# Drain any pending debug commands first
|
||||
while not q.empty():
|
||||
try:
|
||||
cmd = q.get_nowait()
|
||||
yield f"event: debug_cmd\ndata: {json.dumps(cmd)}\n\n"
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
yield f"event: heartbeat\ndata: {json.dumps({'ts': int(asyncio.get_event_loop().time()), 'sessions': len(_sessions)})}\n\n"
|
||||
await asyncio.sleep(15)
|
||||
# Wait up to 1s for debug commands, then loop for heartbeat
|
||||
try:
|
||||
cmd = await asyncio.wait_for(q.get(), timeout=1.0)
|
||||
yield f"event: debug_cmd\ndata: {json.dumps(cmd)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
if q in _debug_queues:
|
||||
_debug_queues.remove(q)
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
@ -278,6 +306,8 @@ def register_routes(app):
|
||||
action = body.get("action")
|
||||
action_data = body.get("action_data")
|
||||
dashboard = body.get("dashboard")
|
||||
# Model overrides: {"models": {"input": "x/y", "pa": "a/b"}}
|
||||
model_overrides = body.get("models")
|
||||
|
||||
if not text and not action:
|
||||
raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")
|
||||
@ -302,7 +332,8 @@ def register_routes(app):
|
||||
else:
|
||||
await rt.handle_action(action, action_data)
|
||||
else:
|
||||
await rt.handle_message(text, dashboard=dashboard)
|
||||
await rt.handle_message(text, dashboard=dashboard,
|
||||
model_overrides=model_overrides)
|
||||
# Auto-save
|
||||
await _save_session(rt)
|
||||
except Exception as e:
|
||||
@ -642,3 +673,113 @@ def register_routes(app):
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"lines": parsed}
|
||||
|
||||
# --- Debug command channel ---
|
||||
# Flow: AI POSTs cmd → queued → nyx picks up via poll → executes → POSTs result → AI gets response
|
||||
|
||||
@app.post("/api/debug/cmd")
|
||||
async def debug_cmd(body: dict, user=Depends(require_auth)):
|
||||
"""AI sends a command to execute in the browser. Waits up to 5s for result."""
|
||||
import uuid
|
||||
cmd_id = str(uuid.uuid4())[:8]
|
||||
cmd = body.get("cmd", "")
|
||||
args = body.get("args", {})
|
||||
if not cmd:
|
||||
raise HTTPException(400, "Missing 'cmd'")
|
||||
|
||||
if not _debug_queues:
|
||||
return {"cmd_id": cmd_id, "error": "no browser connected"}
|
||||
|
||||
evt = asyncio.Event()
|
||||
_debug_results[cmd_id] = evt
|
||||
# Push to all connected SSE clients
|
||||
payload = {"cmd_id": cmd_id, "cmd": cmd, "args": args}
|
||||
for q in _debug_queues:
|
||||
try:
|
||||
q.put_nowait(payload)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
|
||||
# Wait for nyx to execute and POST result back
|
||||
try:
|
||||
await asyncio.wait_for(evt.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
_debug_results.pop(cmd_id, None)
|
||||
return {"cmd_id": cmd_id, "error": "timeout — command took too long"}
|
||||
|
||||
result = _debug_result_data.pop(cmd_id, {})
|
||||
_debug_results.pop(cmd_id, None)
|
||||
return {"cmd_id": cmd_id, **result}
|
||||
|
||||
@app.post("/api/debug/result")
|
||||
async def debug_result(body: dict, user=Depends(require_auth)):
|
||||
"""nyx posts command results back here."""
|
||||
cmd_id = body.get("cmd_id", "")
|
||||
if not cmd_id:
|
||||
raise HTTPException(400, "Missing 'cmd_id'")
|
||||
_debug_result_data[cmd_id] = {"result": body.get("result"), "error": body.get("error")}
|
||||
evt = _debug_results.get(cmd_id)
|
||||
if evt:
|
||||
evt.set()
|
||||
return {"ok": True}
|
||||
|
||||
# ── Test results (fed by test-runner Job) ────────────────────────────────
|
||||
|
||||
@app.post("/api/test-results")
|
||||
async def post_test_result(request: Request):
|
||||
"""Receive a single test result from the test-runner Job. No auth (internal)."""
|
||||
global _test_run_id
|
||||
body = await request.json()
|
||||
run_id = body.get("run_id", "")
|
||||
|
||||
# New run → clear old results
|
||||
if run_id and run_id != _test_run_id:
|
||||
_test_results.clear()
|
||||
_test_run_id = run_id
|
||||
|
||||
# Replace existing entry for same test+suite, or append
|
||||
key = (body.get("test"), body.get("suite"))
|
||||
for i, existing in enumerate(_test_results):
|
||||
if (existing.get("test"), existing.get("suite")) == key:
|
||||
_test_results[i] = body
|
||||
break
|
||||
else:
|
||||
_test_results.append(body)
|
||||
|
||||
# Push to SSE subscribers
|
||||
for q in list(_test_results_subscribers):
|
||||
try:
|
||||
q.put_nowait(body)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/test-results")
|
||||
async def get_test_results_sse(request: Request):
|
||||
"""SSE stream: sends current results on connect, then live updates."""
|
||||
q: Queue = Queue(maxsize=100)
|
||||
_test_results_subscribers.append(q)
|
||||
|
||||
async def stream():
|
||||
try:
|
||||
# Send all existing results first
|
||||
for r in list(_test_results):
|
||||
yield f"data: {json.dumps(r)}\n\n"
|
||||
# Then stream new ones
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
try:
|
||||
result = await asyncio.wait_for(q.get(), timeout=15)
|
||||
yield f"data: {json.dumps(result)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
yield f"event: heartbeat\ndata: {{}}\n\n"
|
||||
finally:
|
||||
_test_results_subscribers.remove(q)
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
@app.get("/api/test-results/latest")
|
||||
async def get_test_results_latest():
|
||||
"""JSON snapshot of the latest test run."""
|
||||
return {"run_id": _test_run_id, "results": _test_results}
|
||||
|
||||
@ -52,6 +52,14 @@ CONDITIONS = {
|
||||
"plan_first": "complexity==complex OR is_data_request",
|
||||
}
|
||||
|
||||
MODELS = {
|
||||
"input": "google/gemini-2.0-flash-001",
|
||||
"thinker": "openai/gpt-4o-mini",
|
||||
"output": "google/gemini-2.0-flash-001",
|
||||
"memorizer": "google/gemini-2.0-flash-001",
|
||||
"director": "google/gemini-2.0-flash-001",
|
||||
}
|
||||
|
||||
AUDIT = {
|
||||
"code_without_tools": True,
|
||||
"intent_without_action": True,
|
||||
|
||||
@ -61,5 +61,14 @@ CONDITIONS = {
|
||||
"has_tool_output": "thinker.tool_used is not empty",
|
||||
}
|
||||
|
||||
MODELS = {
|
||||
"input": "google/gemini-2.0-flash-001",
|
||||
"director": "anthropic/claude-haiku-4.5",
|
||||
"thinker": "google/gemini-2.0-flash-001",
|
||||
"interpreter": "google/gemini-2.0-flash-001",
|
||||
"output": "google/gemini-2.0-flash-001",
|
||||
"memorizer": "google/gemini-2.0-flash-001",
|
||||
}
|
||||
|
||||
# No audits — Director controls tool usage, no need for S3* corrections
|
||||
AUDIT = {}
|
||||
|
||||
@ -22,10 +22,11 @@ _active_graph_name = "v4-eras"
|
||||
|
||||
|
||||
class OutputSink:
|
||||
"""Collects output. Optionally streams to attached WebSocket."""
|
||||
"""Collects output. Streams to attached WebSocket or SSE queue."""
|
||||
|
||||
def __init__(self):
|
||||
self.ws = None
|
||||
self.queue: asyncio.Queue | None = None # SSE streaming queue
|
||||
self.response: str = ""
|
||||
self.controls: list = []
|
||||
self.done: bool = False
|
||||
@ -36,48 +37,49 @@ class OutputSink:
|
||||
def detach(self):
|
||||
self.ws = None
|
||||
|
||||
def attach_queue(self, queue: asyncio.Queue):
|
||||
"""Attach an asyncio.Queue for SSE streaming (HTTP mode)."""
|
||||
self.queue = queue
|
||||
|
||||
def detach_queue(self):
|
||||
self.queue = None
|
||||
|
||||
def reset(self):
|
||||
self.response = ""
|
||||
self.controls = []
|
||||
self.done = False
|
||||
|
||||
async def send_delta(self, text: str):
|
||||
self.response += text
|
||||
async def _emit(self, event: dict):
|
||||
"""Send event to WS or SSE queue."""
|
||||
msg = json.dumps(event)
|
||||
if self.queue:
|
||||
try:
|
||||
self.queue.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "delta", "content": text}))
|
||||
await self.ws.send_text(msg)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_delta(self, text: str):
|
||||
self.response += text
|
||||
await self._emit({"type": "delta", "content": text})
|
||||
|
||||
async def send_controls(self, controls: list):
|
||||
self.controls = controls
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "controls", "controls": controls}))
|
||||
except Exception:
|
||||
pass
|
||||
await self._emit({"type": "controls", "controls": controls})
|
||||
|
||||
async def send_artifacts(self, artifacts: list):
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "artifacts", "artifacts": artifacts}))
|
||||
except Exception:
|
||||
pass
|
||||
await self._emit({"type": "artifacts", "artifacts": artifacts})
|
||||
|
||||
async def send_hud(self, data: dict):
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "hud", **data}))
|
||||
except Exception:
|
||||
pass
|
||||
await self._emit({"type": "hud", **data})
|
||||
|
||||
async def send_done(self):
|
||||
self.done = True
|
||||
if self.ws:
|
||||
try:
|
||||
await self.ws.send_text(json.dumps({"type": "done"}))
|
||||
except Exception:
|
||||
pass
|
||||
await self._emit({"type": "done"})
|
||||
|
||||
|
||||
class Runtime:
|
||||
@ -297,10 +299,12 @@ class Runtime:
|
||||
lines.append(f" - {ctype}: {ctrl.get('label', ctrl.get('text', '?'))}")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def handle_message(self, text: str, dashboard: list = None):
|
||||
async def handle_message(self, text: str, dashboard: list = None,
|
||||
model_overrides: dict = None):
|
||||
# Frame engine: delegate entirely
|
||||
if self.use_frames:
|
||||
result = await self.frame_engine.process_message(text, dashboard)
|
||||
result = await self.frame_engine.process_message(
|
||||
text, dashboard, model_overrides=model_overrides)
|
||||
return result
|
||||
|
||||
# Detect ACTION: prefix from API/test runner
|
||||
|
||||
@ -480,17 +480,16 @@ def test_frame_trace_expert_with_interpreter():
|
||||
# --- Phase 1: Config-driven models (RED — will fail until implemented) ---
|
||||
|
||||
def test_graph_has_models():
|
||||
"""Graph definition includes a MODELS dict mapping role → model."""
|
||||
g = load_graph("v4-eras")
|
||||
assert "models" in g, "graph should have a 'models' key"
|
||||
"""All graph definitions include a MODELS dict mapping role → model."""
|
||||
for name in ["v1-current", "v2-director-drives", "v3-framed", "v4-eras"]:
|
||||
g = load_graph(name)
|
||||
assert "models" in g, f"{name}: graph should have a 'models' key"
|
||||
models = g["models"]
|
||||
assert isinstance(models, dict), "models should be a dict"
|
||||
# Every LLM-using node should have a model entry
|
||||
llm_nodes = {"input", "pa", "expert_eras", "interpreter", "output", "memorizer"}
|
||||
for role in llm_nodes:
|
||||
assert role in models, f"models should include '{role}'"
|
||||
assert isinstance(models[role], str) and "/" in models[role], \
|
||||
f"model for '{role}' should be a provider/model string, got {models.get(role)}"
|
||||
assert isinstance(models, dict), f"{name}: models should be a dict"
|
||||
assert len(models) > 0, f"{name}: models should not be empty"
|
||||
for role, model in models.items():
|
||||
assert isinstance(model, str) and "/" in model, \
|
||||
f"{name}: model for '{role}' should be provider/model, got {model}"
|
||||
|
||||
|
||||
def test_instantiate_applies_graph_models():
|
||||
|
||||
Reference in New Issue
Block a user