Wire model overrides end-to-end: API → runtime → frame engine
- /api/chat accepts {"models": {"role": "provider/model"}} for per-request overrides
- runtime.handle_message passes model_overrides through to frame engine
- All 4 graph definitions (v1-v4) now declare MODELS dicts
- test_graph_has_models expanded to verify all graphs
- 11/11 engine tests green
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
cf42951b77
commit
58734c34d2
147
agent/api.py
147
agent/api.py
@ -32,6 +32,16 @@ _sse_subscribers: list[Queue] = []
|
|||||||
_test_ws_clients: list[WebSocket] = [] # /ws/test subscribers
|
_test_ws_clients: list[WebSocket] = [] # /ws/test subscribers
|
||||||
_trace_ws_clients: list[WebSocket] = [] # /ws/trace subscribers
|
_trace_ws_clients: list[WebSocket] = [] # /ws/trace subscribers
|
||||||
|
|
||||||
|
# Debug command channel: AI → assay → nyx browser (via SSE) → assay → AI
|
||||||
|
_debug_queues: list[Queue] = [] # per-client SSE queues for debug commands
|
||||||
|
_debug_results: dict[str, asyncio.Event] = {} # cmd_id → event (set when result arrives)
|
||||||
|
_debug_result_data: dict[str, dict] = {} # cmd_id → result payload
|
||||||
|
|
||||||
|
# Test results (in-memory, fed by test-runner Job in assay-test namespace)
|
||||||
|
_test_results: list[dict] = []
|
||||||
|
_test_results_subscribers: list[Queue] = []
|
||||||
|
_test_run_id: str = ""
|
||||||
|
|
||||||
|
|
||||||
async def _broadcast_test(event: dict):
|
async def _broadcast_test(event: dict):
|
||||||
"""Push to all /ws/test subscribers."""
|
"""Push to all /ws/test subscribers."""
|
||||||
@ -117,14 +127,32 @@ def register_routes(app):
|
|||||||
|
|
||||||
@app.get("/api/health-stream")
|
@app.get("/api/health-stream")
|
||||||
async def health_stream(user=Depends(require_auth)):
|
async def health_stream(user=Depends(require_auth)):
|
||||||
"""SSE heartbeat stream — client uses this for presence detection."""
|
"""SSE heartbeat + debug command stream."""
|
||||||
|
q: Queue = Queue(maxsize=100)
|
||||||
|
_debug_queues.append(q)
|
||||||
|
|
||||||
async def generate():
|
async def generate():
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
# Drain any pending debug commands first
|
||||||
|
while not q.empty():
|
||||||
|
try:
|
||||||
|
cmd = q.get_nowait()
|
||||||
|
yield f"event: debug_cmd\ndata: {json.dumps(cmd)}\n\n"
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
yield f"event: heartbeat\ndata: {json.dumps({'ts': int(asyncio.get_event_loop().time()), 'sessions': len(_sessions)})}\n\n"
|
yield f"event: heartbeat\ndata: {json.dumps({'ts': int(asyncio.get_event_loop().time()), 'sessions': len(_sessions)})}\n\n"
|
||||||
await asyncio.sleep(15)
|
# Wait up to 1s for debug commands, then loop for heartbeat
|
||||||
|
try:
|
||||||
|
cmd = await asyncio.wait_for(q.get(), timeout=1.0)
|
||||||
|
yield f"event: debug_cmd\ndata: {json.dumps(cmd)}\n\n"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
|
if q in _debug_queues:
|
||||||
|
_debug_queues.remove(q)
|
||||||
|
|
||||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||||
@ -278,6 +306,8 @@ def register_routes(app):
|
|||||||
action = body.get("action")
|
action = body.get("action")
|
||||||
action_data = body.get("action_data")
|
action_data = body.get("action_data")
|
||||||
dashboard = body.get("dashboard")
|
dashboard = body.get("dashboard")
|
||||||
|
# Model overrides: {"models": {"input": "x/y", "pa": "a/b"}}
|
||||||
|
model_overrides = body.get("models")
|
||||||
|
|
||||||
if not text and not action:
|
if not text and not action:
|
||||||
raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")
|
raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")
|
||||||
@ -302,7 +332,8 @@ def register_routes(app):
|
|||||||
else:
|
else:
|
||||||
await rt.handle_action(action, action_data)
|
await rt.handle_action(action, action_data)
|
||||||
else:
|
else:
|
||||||
await rt.handle_message(text, dashboard=dashboard)
|
await rt.handle_message(text, dashboard=dashboard,
|
||||||
|
model_overrides=model_overrides)
|
||||||
# Auto-save
|
# Auto-save
|
||||||
await _save_session(rt)
|
await _save_session(rt)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -642,3 +673,113 @@ def register_routes(app):
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
return {"lines": parsed}
|
return {"lines": parsed}
|
||||||
|
|
||||||
|
# --- Debug command channel ---
|
||||||
|
# Flow: AI POSTs cmd → queued → nyx receives it via SSE (/api/health-stream) → executes → POSTs result → AI gets response
|
||||||
|
|
||||||
|
@app.post("/api/debug/cmd")
|
||||||
|
async def debug_cmd(body: dict, user=Depends(require_auth)):
|
||||||
|
"""AI sends a command to execute in the browser. Waits up to 5s for result."""
|
||||||
|
import uuid
|
||||||
|
cmd_id = str(uuid.uuid4())[:8]
|
||||||
|
cmd = body.get("cmd", "")
|
||||||
|
args = body.get("args", {})
|
||||||
|
if not cmd:
|
||||||
|
raise HTTPException(400, "Missing 'cmd'")
|
||||||
|
|
||||||
|
if not _debug_queues:
|
||||||
|
return {"cmd_id": cmd_id, "error": "no browser connected"}
|
||||||
|
|
||||||
|
evt = asyncio.Event()
|
||||||
|
_debug_results[cmd_id] = evt
|
||||||
|
# Push to all connected SSE clients
|
||||||
|
payload = {"cmd_id": cmd_id, "cmd": cmd, "args": args}
|
||||||
|
for q in _debug_queues:
|
||||||
|
try:
|
||||||
|
q.put_nowait(payload)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait for nyx to execute and POST result back
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(evt.wait(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_debug_results.pop(cmd_id, None)
|
||||||
|
return {"cmd_id": cmd_id, "error": "timeout — command took too long"}
|
||||||
|
|
||||||
|
result = _debug_result_data.pop(cmd_id, {})
|
||||||
|
_debug_results.pop(cmd_id, None)
|
||||||
|
return {"cmd_id": cmd_id, **result}
|
||||||
|
|
||||||
|
@app.post("/api/debug/result")
|
||||||
|
async def debug_result(body: dict, user=Depends(require_auth)):
|
||||||
|
"""nyx posts command results back here."""
|
||||||
|
cmd_id = body.get("cmd_id", "")
|
||||||
|
if not cmd_id:
|
||||||
|
raise HTTPException(400, "Missing 'cmd_id'")
|
||||||
|
_debug_result_data[cmd_id] = {"result": body.get("result"), "error": body.get("error")}
|
||||||
|
evt = _debug_results.get(cmd_id)
|
||||||
|
if evt:
|
||||||
|
evt.set()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
# ── Test results (fed by test-runner Job) ────────────────────────────────
|
||||||
|
|
||||||
|
@app.post("/api/test-results")
|
||||||
|
async def post_test_result(request: Request):
|
||||||
|
"""Receive a single test result from the test-runner Job. No auth (internal)."""
|
||||||
|
global _test_run_id
|
||||||
|
body = await request.json()
|
||||||
|
run_id = body.get("run_id", "")
|
||||||
|
|
||||||
|
# New run → clear old results
|
||||||
|
if run_id and run_id != _test_run_id:
|
||||||
|
_test_results.clear()
|
||||||
|
_test_run_id = run_id
|
||||||
|
|
||||||
|
# Replace existing entry for same test+suite, or append
|
||||||
|
key = (body.get("test"), body.get("suite"))
|
||||||
|
for i, existing in enumerate(_test_results):
|
||||||
|
if (existing.get("test"), existing.get("suite")) == key:
|
||||||
|
_test_results[i] = body
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
_test_results.append(body)
|
||||||
|
|
||||||
|
# Push to SSE subscribers
|
||||||
|
for q in list(_test_results_subscribers):
|
||||||
|
try:
|
||||||
|
q.put_nowait(body)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.get("/api/test-results")
|
||||||
|
async def get_test_results_sse(request: Request):
|
||||||
|
"""SSE stream: sends current results on connect, then live updates."""
|
||||||
|
q: Queue = Queue(maxsize=100)
|
||||||
|
_test_results_subscribers.append(q)
|
||||||
|
|
||||||
|
async def stream():
|
||||||
|
try:
|
||||||
|
# Send all existing results first
|
||||||
|
for r in list(_test_results):
|
||||||
|
yield f"data: {json.dumps(r)}\n\n"
|
||||||
|
# Then stream new ones
|
||||||
|
while True:
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(q.get(), timeout=15)
|
||||||
|
yield f"data: {json.dumps(result)}\n\n"
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
yield f"event: heartbeat\ndata: {{}}\n\n"
|
||||||
|
finally:
|
||||||
|
_test_results_subscribers.remove(q)
|
||||||
|
|
||||||
|
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
@app.get("/api/test-results/latest")
|
||||||
|
async def get_test_results_latest():
|
||||||
|
"""JSON snapshot of the latest test run."""
|
||||||
|
return {"run_id": _test_run_id, "results": _test_results}
|
||||||
|
|||||||
@ -52,6 +52,14 @@ CONDITIONS = {
|
|||||||
"plan_first": "complexity==complex OR is_data_request",
|
"plan_first": "complexity==complex OR is_data_request",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"input": "google/gemini-2.0-flash-001",
|
||||||
|
"thinker": "openai/gpt-4o-mini",
|
||||||
|
"output": "google/gemini-2.0-flash-001",
|
||||||
|
"memorizer": "google/gemini-2.0-flash-001",
|
||||||
|
"director": "google/gemini-2.0-flash-001",
|
||||||
|
}
|
||||||
|
|
||||||
AUDIT = {
|
AUDIT = {
|
||||||
"code_without_tools": True,
|
"code_without_tools": True,
|
||||||
"intent_without_action": True,
|
"intent_without_action": True,
|
||||||
|
|||||||
@ -61,5 +61,14 @@ CONDITIONS = {
|
|||||||
"has_tool_output": "thinker.tool_used is not empty",
|
"has_tool_output": "thinker.tool_used is not empty",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MODELS = {
|
||||||
|
"input": "google/gemini-2.0-flash-001",
|
||||||
|
"director": "anthropic/claude-haiku-4.5",
|
||||||
|
"thinker": "google/gemini-2.0-flash-001",
|
||||||
|
"interpreter": "google/gemini-2.0-flash-001",
|
||||||
|
"output": "google/gemini-2.0-flash-001",
|
||||||
|
"memorizer": "google/gemini-2.0-flash-001",
|
||||||
|
}
|
||||||
|
|
||||||
# No audits — Director controls tool usage, no need for S3* corrections
|
# No audits — Director controls tool usage, no need for S3* corrections
|
||||||
AUDIT = {}
|
AUDIT = {}
|
||||||
|
|||||||
@ -22,10 +22,11 @@ _active_graph_name = "v4-eras"
|
|||||||
|
|
||||||
|
|
||||||
class OutputSink:
|
class OutputSink:
|
||||||
"""Collects output. Optionally streams to attached WebSocket."""
|
"""Collects output. Streams to attached WebSocket or SSE queue."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.ws = None
|
self.ws = None
|
||||||
|
self.queue: asyncio.Queue | None = None # SSE streaming queue
|
||||||
self.response: str = ""
|
self.response: str = ""
|
||||||
self.controls: list = []
|
self.controls: list = []
|
||||||
self.done: bool = False
|
self.done: bool = False
|
||||||
@ -36,48 +37,49 @@ class OutputSink:
|
|||||||
def detach(self):
|
def detach(self):
|
||||||
self.ws = None
|
self.ws = None
|
||||||
|
|
||||||
|
def attach_queue(self, queue: asyncio.Queue):
|
||||||
|
"""Attach an asyncio.Queue for SSE streaming (HTTP mode)."""
|
||||||
|
self.queue = queue
|
||||||
|
|
||||||
|
def detach_queue(self):
|
||||||
|
self.queue = None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.response = ""
|
self.response = ""
|
||||||
self.controls = []
|
self.controls = []
|
||||||
self.done = False
|
self.done = False
|
||||||
|
|
||||||
async def send_delta(self, text: str):
|
async def _emit(self, event: dict):
|
||||||
self.response += text
|
"""Send event to WS or SSE queue."""
|
||||||
|
msg = json.dumps(event)
|
||||||
|
if self.queue:
|
||||||
|
try:
|
||||||
|
self.queue.put_nowait(event)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
if self.ws:
|
if self.ws:
|
||||||
try:
|
try:
|
||||||
await self.ws.send_text(json.dumps({"type": "delta", "content": text}))
|
await self.ws.send_text(msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def send_delta(self, text: str):
|
||||||
|
self.response += text
|
||||||
|
await self._emit({"type": "delta", "content": text})
|
||||||
|
|
||||||
async def send_controls(self, controls: list):
|
async def send_controls(self, controls: list):
|
||||||
self.controls = controls
|
self.controls = controls
|
||||||
if self.ws:
|
await self._emit({"type": "controls", "controls": controls})
|
||||||
try:
|
|
||||||
await self.ws.send_text(json.dumps({"type": "controls", "controls": controls}))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_artifacts(self, artifacts: list):
|
async def send_artifacts(self, artifacts: list):
|
||||||
if self.ws:
|
await self._emit({"type": "artifacts", "artifacts": artifacts})
|
||||||
try:
|
|
||||||
await self.ws.send_text(json.dumps({"type": "artifacts", "artifacts": artifacts}))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_hud(self, data: dict):
|
async def send_hud(self, data: dict):
|
||||||
if self.ws:
|
await self._emit({"type": "hud", **data})
|
||||||
try:
|
|
||||||
await self.ws.send_text(json.dumps({"type": "hud", **data}))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_done(self):
|
async def send_done(self):
|
||||||
self.done = True
|
self.done = True
|
||||||
if self.ws:
|
await self._emit({"type": "done"})
|
||||||
try:
|
|
||||||
await self.ws.send_text(json.dumps({"type": "done"}))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
@ -297,10 +299,12 @@ class Runtime:
|
|||||||
lines.append(f" - {ctype}: {ctrl.get('label', ctrl.get('text', '?'))}")
|
lines.append(f" - {ctype}: {ctrl.get('label', ctrl.get('text', '?'))}")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
async def handle_message(self, text: str, dashboard: list = None):
|
async def handle_message(self, text: str, dashboard: list = None,
|
||||||
|
model_overrides: dict = None):
|
||||||
# Frame engine: delegate entirely
|
# Frame engine: delegate entirely
|
||||||
if self.use_frames:
|
if self.use_frames:
|
||||||
result = await self.frame_engine.process_message(text, dashboard)
|
result = await self.frame_engine.process_message(
|
||||||
|
text, dashboard, model_overrides=model_overrides)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Detect ACTION: prefix from API/test runner
|
# Detect ACTION: prefix from API/test runner
|
||||||
|
|||||||
@ -480,17 +480,16 @@ def test_frame_trace_expert_with_interpreter():
|
|||||||
# --- Phase 1: Config-driven models (RED — will fail until implemented) ---
|
# --- Phase 1: Config-driven models (RED — will fail until implemented) ---
|
||||||
|
|
||||||
def test_graph_has_models():
|
def test_graph_has_models():
|
||||||
"""Graph definition includes a MODELS dict mapping role → model."""
|
"""All graph definitions include a MODELS dict mapping role → model."""
|
||||||
g = load_graph("v4-eras")
|
for name in ["v1-current", "v2-director-drives", "v3-framed", "v4-eras"]:
|
||||||
assert "models" in g, "graph should have a 'models' key"
|
g = load_graph(name)
|
||||||
models = g["models"]
|
assert "models" in g, f"{name}: graph should have a 'models' key"
|
||||||
assert isinstance(models, dict), "models should be a dict"
|
models = g["models"]
|
||||||
# Every LLM-using node should have a model entry
|
assert isinstance(models, dict), f"{name}: models should be a dict"
|
||||||
llm_nodes = {"input", "pa", "expert_eras", "interpreter", "output", "memorizer"}
|
assert len(models) > 0, f"{name}: models should not be empty"
|
||||||
for role in llm_nodes:
|
for role, model in models.items():
|
||||||
assert role in models, f"models should include '{role}'"
|
assert isinstance(model, str) and "/" in model, \
|
||||||
assert isinstance(models[role], str) and "/" in models[role], \
|
f"{name}: model for '{role}' should be provider/model, got {model}"
|
||||||
f"model for '{role}' should be a provider/model string, got {models.get(role)}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_instantiate_applies_graph_models():
|
def test_instantiate_applies_graph_models():
|
||||||
|
|||||||
Reference in New Issue
Block a user