- /api/chat accepts {"models": {"role": "provider/model"}} for per-request overrides
- runtime.handle_message passes model_overrides through to frame engine
- All 4 graph definitions (v1-v4) now declare MODELS dicts
- test_graph_has_models expanded to verify all graphs
- 11/11 engine tests green
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
786 lines
30 KiB
Python
786 lines
30 KiB
Python
"""API endpoints, SSE, polling."""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from asyncio import Queue
|
|
from pathlib import Path
|
|
|
|
from fastapi import Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
|
|
from starlette.responses import StreamingResponse
|
|
|
|
import httpx
|
|
|
|
from .auth import AUTH_ENABLED, ZITADEL_ISSUER, _validate_token, require_auth
|
|
from .runtime import Runtime, TRACE_FILE
|
|
from . import db_sessions
|
|
|
|
# Module logger (logger name "runtime").
log = logging.getLogger("runtime")

# --- Session registry -------------------------------------------------------
# In-memory active sessions: session_id -> Runtime.
_sessions: dict[str, Runtime] = {}
# NOTE(review): not referenced anywhere in the visible part of this file.
MAX_ACTIVE_SESSIONS = 50

# Legacy single-session handle, kept for backward compatibility with the
# single-session MCP/test endpoints.
_active_runtime: Runtime | None = None

# --- Fan-out channels -------------------------------------------------------
_sse_subscribers: list[Queue] = []  # one queue per /api/events client
_test_ws_clients: list[WebSocket] = []  # /ws/test subscribers
_trace_ws_clients: list[WebSocket] = []  # /ws/trace subscribers

# --- Debug command channel: AI → assay → nyx browser (via SSE) → assay → AI ---
_debug_queues: list[Queue] = []  # per-client SSE queues for debug commands
_debug_results: dict[str, asyncio.Event] = {}  # cmd_id → event (set when result arrives)
_debug_result_data: dict[str, dict] = {}  # cmd_id → result payload

# --- Test results (in-memory, fed by test-runner Job in assay-test namespace) ---
_test_results: list[dict] = []
_test_results_subscribers: list[Queue] = []
_test_run_id: str = ""
|
|
|
|
|
|
async def _broadcast_test(event: dict):
    """Push an event to every /ws/test subscriber.

    The event is serialised once with ``json.dumps`` and sent as text to each
    connected socket. A socket whose send raises is logged and dropped from
    ``_test_ws_clients``.

    Args:
        event: JSON-serialisable payload to broadcast.
    """
    msg = json.dumps(event)
    log.info(f"[ws/test] broadcasting to {len(_test_ws_clients)} clients")
    # Iterate over a snapshot: every send awaits, and the /ws/test handler may
    # add or remove sockets while this coroutine is suspended.
    for ws in list(_test_ws_clients):
        try:
            await ws.send_text(msg)
        except Exception as e:
            log.error(f"[ws/test] send failed: {e}")
            # The socket's own handler may already have removed it, so guard
            # the removal instead of risking a ValueError.
            if ws in _test_ws_clients:
                _test_ws_clients.remove(ws)
|
|
|
|
|
|
async def _broadcast_trace(event: dict):
    """Push an event to every /ws/trace subscriber.

    The event is serialised once with ``json.dumps`` and sent as text to each
    connected socket. A socket whose send raises is silently dropped from
    ``_trace_ws_clients``.

    Args:
        event: JSON-serialisable payload to broadcast.
    """
    msg = json.dumps(event)
    # Iterate over a snapshot: every send awaits, and the /ws/trace handler
    # may add or remove sockets while this coroutine is suspended.
    for ws in list(_trace_ws_clients):
        try:
            await ws.send_text(msg)
        except Exception:
            # Best-effort channel: a failed send just unsubscribes the socket.
            # It may already have been removed by its own handler.
            if ws in _trace_ws_clients:
                _trace_ws_clients.remove(ws)
|
|
|
|
# Async message pipeline state (shared by /api/send, /api/result, /api/stop)
_pipeline_task: asyncio.Task | None = None  # task running the current message, if any
_pipeline_result: dict = {"status": "idle"}  # latest result/progress snapshot
_pipeline_id: int = 0  # counter used to build "msg_<n>" ids
|
|
|
|
|
|
def _broadcast_sse(event: dict):
    """Push an event to all SSE subscribers + /ws/trace + update pipeline progress.

    This is a synchronous function: queue delivery uses ``put_nowait`` and
    the WebSocket fan-out is scheduled as a task on the running event loop.

    Args:
        event: Event payload. The ``node`` and ``event`` keys, when present,
            drive the pipeline progress update.
    """
    for q in _sse_subscribers:
        try:
            q.put_nowait(event)
        except asyncio.QueueFull:
            pass  # slow consumer: drop the event rather than block the caller

    # Push to /ws/trace subscribers (fire-and-forget).
    if _trace_ws_clients:
        try:
            # get_running_loop() raises RuntimeError when no loop is running,
            # which is exactly the "called outside the loop" case handled
            # below; get_event_loop() could instead hand back a loop that is
            # never run, leaving the task pending forever.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None  # no event loop (startup)
        if loop is not None:
            loop.create_task(_broadcast_trace(event))

    # Update pipeline progress from HUD events.
    if _pipeline_result.get("status") == "running":
        node = event.get("node", "")
        evt = event.get("event", "")
        if node and evt in ("thinking", "perceived", "decided", "streaming", "tool_call", "interpreted", "updated"):
            _pipeline_result["stage"] = node
            _pipeline_result["event"] = evt
|
|
|
|
|
|
def _state_hash(rt: Runtime = None) -> str:
    """Return a short fingerprint of a runtime's memorizer state and history length.

    Falls back to the module-level active runtime when ``rt`` is not given,
    and returns the literal ``"no_session"`` when neither is available.
    """
    runtime = rt or _active_runtime
    if not runtime:
        return "no_session"
    snapshot = {"mem": runtime.memorizer.state, "hlen": len(runtime.history)}
    encoded = json.dumps(snapshot, sort_keys=True).encode()
    # Only the first 12 hex characters of the MD5 digest are used.
    return hashlib.md5(encoded).hexdigest()[:12]
|
|
|
|
|
|
def register_routes(app):
|
|
"""Register all API routes on the FastAPI app."""
|
|
|
|
@app.on_event("startup")
async def startup():
    """App startup hook: initialise the ``db_sessions`` pool."""
    await db_sessions.init_pool()
|
|
|
|
@app.on_event("shutdown")
async def shutdown():
    """App shutdown hook: close the ``db_sessions`` pool."""
    await db_sessions.close_pool()
|
|
|
|
@app.get("/api/health")
@app.get("/health")  # K8s probes
async def health():
    """Liveness endpoint; always answers with a static OK payload."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.get("/api/health-stream")
async def health_stream(user=Depends(require_auth)):
    """SSE heartbeat + debug command stream.

    Registers a per-client queue in ``_debug_queues``; commands placed on it
    are emitted as ``debug_cmd`` events, interleaved with ``heartbeat``
    events carrying the loop time and the number of in-memory sessions.
    """
    q: Queue = Queue(maxsize=100)
    _debug_queues.append(q)

    def _frame(name, payload):
        # One SSE frame: named event with a JSON data line.
        return f"event: {name}\ndata: {json.dumps(payload)}\n\n"

    async def generate():
        try:
            while True:
                # Flush whatever debug commands are already queued.
                while not q.empty():
                    try:
                        pending = q.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    yield _frame("debug_cmd", pending)

                beat = {
                    "ts": int(asyncio.get_event_loop().time()),
                    "sessions": len(_sessions),
                }
                yield _frame("heartbeat", beat)

                # Block for at most 1s on the next command, then heartbeat again.
                try:
                    cmd = await asyncio.wait_for(q.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue
                yield _frame("debug_cmd", cmd)
        except asyncio.CancelledError:
            pass
        finally:
            # Unregister this client's queue once the stream ends.
            if q in _debug_queues:
                _debug_queues.remove(q)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers=sse_headers)
|
|
|
|
# --- Session CRUD ---
|
|
|
|
@app.post("/api/sessions")
async def create_session(body: dict = None, user=Depends(require_auth)):
    """Create a new session row for the caller and return its id.

    The graph defaults to ``"v4-eras"`` when the body is missing or has no
    ``graph`` key; the owner is the ``sub`` claim, or ``"anonymous"``.
    """
    owner = user.get("sub", "anonymous")
    payload = body or {}
    new_id = await db_sessions.create_session(owner, payload.get("graph", "v4-eras"))
    return {"session_id": new_id}
|
|
|
|
@app.get("/api/sessions")
async def list_sessions(user=Depends(require_auth)):
    """Return the sessions stored for the caller (``sub`` claim, else "anonymous")."""
    owner = user.get("sub", "anonymous")
    return {"sessions": await db_sessions.list_sessions(owner)}
|
|
|
|
@app.delete("/api/sessions/{session_id}")
async def delete_session(session_id: str, user=Depends(require_auth)):
    """Delete a session from memory and from the database.

    If the session is live in memory its sensor is stopped first. When the
    deleted runtime is also the legacy "active" runtime, that reference is
    cleared so the single-session endpoints do not keep operating on a
    runtime that has been torn down.

    NOTE(review): nothing here checks that ``session_id`` belongs to the
    caller — confirm whether ``db_sessions.delete_session`` enforces that.
    """
    global _active_runtime
    rt = _sessions.pop(session_id, None)
    if rt is not None:
        rt.sensor.stop()
        if _active_runtime is rt:
            _active_runtime = None
    await db_sessions.delete_session(session_id)
    return {"status": "deleted"}
|
|
|
|
@app.get("/auth/config")
async def auth_config():
    """Expose the auth settings read from the ``.auth`` module (no auth required)."""
    from . import auth as _auth

    config = {}
    config["enabled"] = _auth.AUTH_ENABLED
    config["issuer"] = _auth.ZITADEL_ISSUER
    config["clientId"] = _auth.ZITADEL_CLIENT_ID
    config["projectId"] = _auth.ZITADEL_PROJECT_ID
    return config
|
|
|
|
async def _get_or_create_session(session_id: str = None, user_claims=None, origin="") -> Runtime:
    """Resolve a Runtime for ``session_id``, creating a new session if needed.

    Lookup order: in-memory registry, then the database (state is restored
    into a fresh Runtime), then a brand-new session. Whatever is returned
    also becomes the module-level active runtime.

    NOTE(review): an unknown ``session_id`` silently yields a new session
    with a different id, and no ownership check is visible here.
    """
    global _active_runtime

    if session_id:
        # 1) Already live in this process.
        if session_id in _sessions:
            _active_runtime = _sessions[session_id]
            return _active_runtime

        # 2) Persisted earlier: rebuild the runtime from the saved state.
        saved = await db_sessions.load_session(session_id)
        if saved:
            restored = Runtime(
                user_claims=user_claims,
                origin=origin,
                broadcast=_broadcast_sse,
                graph_name=saved["graph_name"],
                session_id=session_id,
            )
            restored.restore_state(saved)
            _sessions[session_id] = restored
            _active_runtime = restored
            log.info(f"[api] restored session {session_id} from DB")
            return restored

    # 3) Nothing to reuse: allocate a new session id and runtime.
    owner = (user_claims or {}).get("sub", "anonymous")
    new_id = await db_sessions.create_session(owner)
    fresh = Runtime(user_claims=user_claims, origin=origin,
                    broadcast=_broadcast_sse, session_id=new_id)
    _sessions[new_id] = fresh
    _active_runtime = fresh
    log.info(f"[api] created new session {new_id}")
    return fresh
|
|
|
|
async def _save_session(rt: Runtime):
    """Write the runtime's current state to the database via ``db_sessions.save_session``."""
    snapshot = rt.to_state()
    await db_sessions.save_session(
        rt.session_id,
        snapshot["history"],
        snapshot["memorizer_state"],
        snapshot["ui_state"],
        user_id=rt.identity,
        # Falls back to "v4-eras" when the graph has no name.
        graph_name=rt.graph.get("name", "v4-eras"),
    )
|
|
|
|
def _ensure_runtime(user_claims=None, origin=""):
    """Legacy accessor: return the active runtime, creating it on first use.

    A newly created runtime is also registered in ``_sessions`` under its
    own session id.
    """
    global _active_runtime
    if _active_runtime is not None:
        return _active_runtime

    created = Runtime(user_claims=user_claims, origin=origin,
                      broadcast=_broadcast_sse)
    _active_runtime = created
    _sessions[created.session_id] = created
    log.info("[api] created persistent runtime (legacy)")
    return created
|
|
|
|
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
    """Check the token for a debug WebSocket.

    Returns True when auth is disabled or the token validates. Otherwise the
    socket is closed with code 4001 and False is returned.
    """
    if not AUTH_ENABLED:
        return True

    rejection = None
    if not token:
        rejection = "Missing token"
    else:
        try:
            await _validate_token(token)
        except HTTPException:
            rejection = "Invalid token"

    if rejection is None:
        return True
    await ws.close(code=4001, reason=rejection)
    return False
|
|
|
|
@app.websocket("/ws/test")
async def ws_test(ws: WebSocket, token: str | None = Query(None)):
    """Debug WebSocket that receives test-runner progress broadcasts.

    The socket is accepted first, then authenticated; on success it stays
    registered in ``_test_ws_clients`` until the peer disconnects.
    """
    await ws.accept()
    authorised = await _auth_debug_ws(ws, token)
    if not authorised:
        return

    clients = _test_ws_clients
    clients.append(ws)
    log.info(f"[api] /ws/test connected ({len(_test_ws_clients)} clients)")
    try:
        # Incoming text is read and discarded; the loop ends on disconnect.
        while True:
            await ws.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        if ws in clients:
            clients.remove(ws)
        log.info(f"[api] /ws/test disconnected ({len(_test_ws_clients)} clients)")
|
|
|
|
@app.websocket("/ws/trace")
async def ws_trace(ws: WebSocket, token: str | None = Query(None)):
    """Debug WebSocket that receives HUD/frame trace broadcasts.

    The socket is accepted first, then authenticated; on success it stays
    registered in ``_trace_ws_clients`` until the peer disconnects.
    """
    await ws.accept()
    authorised = await _auth_debug_ws(ws, token)
    if not authorised:
        return

    clients = _trace_ws_clients
    clients.append(ws)
    log.info(f"[api] /ws/trace connected ({len(_trace_ws_clients)} clients)")
    try:
        # Incoming text is read and discarded; the loop ends on disconnect.
        while True:
            await ws.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        if ws in clients:
            clients.remove(ws)
        log.info(f"[api] /ws/trace disconnected ({len(_trace_ws_clients)} clients)")
|
|
|
|
# --- Streamable HTTP chat endpoint ---
|
|
|
|
@app.post("/api/chat")
async def api_chat(request: Request, user=Depends(require_auth)):
    """Send a message (or UI action) and stream the pipeline's events as SSE.

    JSON body keys read here: ``session_id``, ``content``, ``action``,
    ``action_data``, ``dashboard`` and ``models`` (per-request model
    overrides, e.g. ``{"input": "x/y", "pa": "a/b"}``). One of ``content`` /
    ``action`` must be present.

    The pipeline runs in a background task that feeds a queue attached to
    the runtime's sink; the response streams that queue until a ``None``
    sentinel arrives.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when
            ``models`` is present but not an object, or when neither
            ``content`` nor ``action`` is supplied.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    session_id = body.get("session_id")
    # "content" may be absent, null or a non-string; treat all of those as empty
    # instead of crashing on .strip().
    raw_content = body.get("content")
    text = raw_content.strip() if isinstance(raw_content, str) else ""
    action = body.get("action")
    action_data = body.get("action_data")
    dashboard = body.get("dashboard")
    # Model overrides: {"models": {"input": "x/y", "pa": "a/b"}}
    model_overrides = body.get("models")
    if model_overrides is not None and not isinstance(model_overrides, dict):
        raise HTTPException(status_code=400, detail="'models' must be an object")

    if not text and not action:
        raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")

    origin = request.headers.get("origin", request.headers.get("host", ""))
    rt = await _get_or_create_session(
        session_id=session_id, user_claims=user, origin=origin)
    rt.update_identity(user, origin)

    # Attach SSE queue to sink for this request
    q: Queue = Queue(maxsize=500)
    rt.sink.attach_queue(q)

    def _signal_end_of_stream():
        # The sentinel MUST be delivered or generate() waits forever. If the
        # queue is full, sacrifice the oldest pending event to make room.
        try:
            q.put_nowait(None)
        except asyncio.QueueFull:
            q.get_nowait()
            q.put_nowait(None)

    async def run_and_close():
        try:
            if action:
                if getattr(rt, "use_frames", False):
                    # Frame engine takes actions as specially formatted messages.
                    action_text = f"ACTION:{action}"
                    if action_data:
                        action_text += f"|data:{json.dumps(action_data)}"
                    await rt.handle_message(action_text)
                else:
                    await rt.handle_action(action, action_data)
            else:
                await rt.handle_message(text, dashboard=dashboard,
                                        model_overrides=model_overrides)
            # Auto-save
            await _save_session(rt)
        except Exception as e:
            import traceback
            log.error(f"[chat] handler error: {e}\n{traceback.format_exc()}")
            try:
                q.put_nowait({"type": "error", "detail": str(e)[:200]})
            except asyncio.QueueFull:
                pass  # stream is saturated; the error is already logged
        finally:
            _signal_end_of_stream()

    # Run pipeline in background task
    task = asyncio.create_task(run_and_close())

    async def generate():
        try:
            while True:
                event = await q.get()
                if event is None:
                    break
                event_type = event.get("type", "message")
                yield f"event: {event_type}\ndata: {json.dumps(event)}\n\n"
        except asyncio.CancelledError:
            # The stream itself was cancelled: stop the pipeline too.
            task.cancel()
        finally:
            rt.sink.detach_queue()

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
|
|
|
@app.get("/api/session")
async def api_session(token: str = Query(None), session: str = Query(None),
                      user=Depends(require_auth)):
    """Get or create session — replaces WS connect handshake.

    ``token`` is accepted as a query parameter but not read in this body.
    """
    rt = await _get_or_create_session(
        session_id=session, user_claims=user, origin="")
    summary = {"session_id": rt.session_id}
    summary["graph"] = rt.graph.get("name", "unknown")
    summary["history_len"] = len(rt.history)
    summary["status"] = "ready"
    summary["memorizer"] = rt.memorizer.state
    return summary
|
|
|
|
@app.post("/api/stop")
async def api_stop(user=Depends(require_auth)):
    """Cancel the pipeline task if one is still running."""
    task = _pipeline_task
    if not task or task.done():
        return {"stopped": False, "detail": "No pipeline running"}
    task.cancel()
    return {"stopped": True}
|
|
|
|
@app.get("/api/events")
async def sse_events(user=Depends(require_auth)):
    """SSE stream of every event passed to ``_broadcast_sse``."""
    q: Queue = Queue(maxsize=100)
    _sse_subscribers.append(q)

    async def generate():
        try:
            while True:
                payload = json.dumps(await q.get())
                yield f"data: {payload}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            # Unsubscribe when the stream ends.
            _sse_subscribers.remove(q)

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers=sse_headers)
|
|
|
|
@app.get("/api/poll")
async def poll(since: str = "", user=Depends(require_auth)):
    """Cheap change detection against the active runtime.

    When ``since`` equals the current state hash only ``changed: False`` is
    returned; otherwise the memorizer state, history length and the last
    three history entries are included.
    """
    current = _state_hash()
    if since and current == since:
        return {"changed": False, "hash": current}

    active = _active_runtime
    if active:
        state = active.memorizer.state
        history_len = len(active.history)
        tail = active.history[-3:]
    else:
        state, history_len, tail = {}, 0, []
    return {
        "changed": True,
        "hash": current,
        "state": state,
        "history_len": history_len,
        "last_messages": tail,
    }
|
|
|
|
@app.post("/api/send/check")
async def api_send_check(user=Depends(require_auth)):
    """Report whether the legacy runtime can take a message right now.

    Makes sure the legacy runtime exists, then answers "busy" if a pipeline
    task is still in flight, otherwise a short readiness summary.
    """
    runtime = _ensure_runtime()
    in_flight = bool(_pipeline_task) and not _pipeline_task.done()
    if in_flight:
        return {"ready": False, "reason": "busy", "detail": "Pipeline already running"}

    summary = {"ready": True}
    summary["graph"] = runtime.graph.get("name", "unknown")
    summary["identity"] = runtime.identity
    summary["history_len"] = len(runtime.history)
    summary["ws_connected"] = runtime.sink.ws is not None
    return summary
|
|
|
|
@app.post("/api/send")
async def api_send(body: dict, user=Depends(require_auth)):
    """Queue a message for async processing. Returns immediately with a message ID.

    The message is handled by a background task on the legacy runtime;
    progress and the final outcome are published through the module-level
    ``_pipeline_result`` (read by /api/result). Final ``status`` values are
    ``"done"``, ``"error"`` and — when the task is cancelled, e.g. by
    /api/stop — ``"cancelled"``.

    Raises:
        HTTPException: 409 if a pipeline is already running; 400 if ``text``
            is missing, empty or not a string.
    """
    global _pipeline_task, _pipeline_result, _pipeline_id
    runtime = _ensure_runtime()
    if _pipeline_task and not _pipeline_task.done():
        raise HTTPException(status_code=409, detail="Pipeline already running")

    # "text" may be absent, null or a non-string; none of those may crash .strip().
    raw_text = body.get("text")
    text = raw_text.strip() if isinstance(raw_text, str) else ""
    if not text:
        raise HTTPException(status_code=400, detail="Missing 'text' field")

    _pipeline_id += 1
    msg_id = f"msg_{_pipeline_id}"
    dashboard = body.get("dashboard")

    _pipeline_result = {"status": "running", "id": msg_id, "stage": "queued", "text": text}

    async def _run_pipeline():
        global _pipeline_result
        try:
            _pipeline_result["stage"] = "input"
            result = await runtime.handle_message(text, dashboard=dashboard)
            # Frame engine returns a dict with response; imperative pipeline uses history
            if isinstance(result, dict) and "response" in result:
                response = result["response"]
                log.info(f"[api] frame engine response[{len(response)}]: {response[:80]}")
            else:
                response = runtime.history[-1]["content"] if runtime.history else ""
                log.info(f"[api] history response[{len(response)}]: {response[:80]}")
            _pipeline_result = {
                "status": "done",
                "id": msg_id,
                "stage": "done",
                "response": response,
                "memorizer": runtime.memorizer.state,
            }
            # Persist session after message
            await _save_session(runtime)
        except asyncio.CancelledError:
            # CancelledError is not an Exception subclass, so without this
            # branch a stopped pipeline would be reported as "running" forever.
            if _pipeline_result.get("id") == msg_id:
                _pipeline_result = {
                    "status": "cancelled",
                    "id": msg_id,
                    "stage": "cancelled",
                    "detail": "Pipeline cancelled",
                }
            raise
        except Exception as e:
            import traceback
            log.error(f"[api] pipeline error: {e}\n{traceback.format_exc()}")
            _pipeline_result = {
                "status": "error",
                "id": msg_id,
                "stage": "error",
                "detail": str(e),
            }

    _pipeline_task = asyncio.create_task(_run_pipeline())
    return {"status": "queued", "id": msg_id}
|
|
|
|
@app.get("/api/result")
async def api_result(user=Depends(require_auth)):
    """Return the current ``_pipeline_result`` snapshot as-is."""
    snapshot = _pipeline_result
    return snapshot
|
|
|
|
@app.get("/api/frames")
async def api_frames(user=Depends(require_auth), last: int = 5):
    """Get frame traces from the frame engine. Returns last N message traces.

    Args:
        last: Number of most recent traces to return. Zero or a negative
            value yields an empty list.

    Returns:
        A dict describing the frame engine and its traces, or an
        ``"imperative"`` placeholder when the runtime has no frame engine.
    """
    runtime = _ensure_runtime()
    if hasattr(runtime, 'frame_engine'):
        engine = runtime.frame_engine
        # [-0:] would return the whole list, so guard non-positive values.
        traces = engine.trace_history[-last:] if last > 0 else []
        return {
            "graph": engine.graph.get("name", "unknown"),
            "engine": "frames",
            "traces": traces,
            "last_trace": engine.last_trace.to_dict() if engine.last_trace.message else None,
        }
    return {"engine": "imperative", "traces": [], "detail": "Frame engine not active"}
|
|
|
|
@app.post("/api/clear")
async def api_clear(user=Depends(require_auth)):
    """Reset the legacy runtime's conversation state.

    Clears history and the UI node containers, replaces the memorizer state
    with defaults (keeping the current ``situation``), resets the pipeline
    result, and — if a WebSocket is attached — tells the frontend.
    """
    global _pipeline_result
    runtime = _ensure_runtime()
    runtime.history.clear()

    ui = runtime.ui_node
    for container in (ui.state, ui.bindings, ui.thinker_controls, ui.machines):
        container.clear()

    kept_situation = runtime.memorizer.state.get("situation", "")
    runtime.memorizer.state = {
        "user_name": runtime.identity,
        "user_mood": "neutral",
        "topic": None,
        "topic_history": [],
        "situation": kept_situation,
        "language": "en",
        "style_hint": "casual, technical",
        "facts": [],
        "user_expectation": "conversational",
    }
    _pipeline_result = {"status": "idle", "id": "", "stage": "cleared"}

    # Notify frontend via WS (best-effort: a failed send is ignored).
    socket = runtime.sink.ws
    if socket:
        try:
            await socket.send_text(json.dumps({"type": "cleared"}))
        except Exception:
            pass
    return {"status": "cleared"}
|
|
|
|
@app.get("/api/state")
async def get_state(user=Depends(require_auth)):
    """Summarise the legacy runtime: memorizer state, history length, WS attachment."""
    runtime = _ensure_runtime()
    summary = {"status": "active"}
    summary["memorizer"] = runtime.memorizer.state
    summary["history_len"] = len(runtime.history)
    summary["ws_connected"] = runtime.sink.ws is not None
    return summary
|
|
|
|
@app.get("/api/history")
async def get_history(last: int = 10, user=Depends(require_auth)):
    """Return the most recent history entries of the legacy runtime.

    Args:
        last: Number of trailing messages to return. Zero or a negative
            value yields an empty list.
    """
    runtime = _ensure_runtime()
    # [-0:] would return the whole history, so guard non-positive values.
    messages = runtime.history[-last:] if last > 0 else []
    return {
        "status": "active",
        "messages": messages,
    }
|
|
|
|
@app.get("/api/graph/active")
async def get_active_graph():
    """Describe the currently active graph.

    Loads the graph named by ``runtime._active_graph_name`` and, when an
    active runtime with a frame engine exists, adds per-role details (impl
    name, model, max context tokens) taken from the instantiated nodes.
    """
    from .engine import load_graph, get_graph_for_cytoscape
    from .runtime import _active_graph_name

    graph = load_graph(_active_graph_name)

    # Include model info from instantiated nodes if runtime exists
    node_details = {}
    engine = getattr(_active_runtime, 'frame_engine', None) if _active_runtime else None
    if engine and hasattr(engine, 'nodes'):
        for role, impl_name in graph["nodes"].items():
            inst = engine.nodes.get(role)
            if not inst:
                continue
            node_details[role] = {
                "impl": impl_name,
                "model": getattr(inst, 'model', None) or '',
                "max_tokens": getattr(inst, 'max_context_tokens', 0),
            }

    return {
        "name": graph["name"],
        "description": graph["description"],
        "nodes": graph["nodes"],
        "edges": graph["edges"],
        "node_details": node_details,
        "cytoscape": get_graph_for_cytoscape(graph),
    }
|
|
|
|
@app.get("/api/graph/list")
async def get_graph_list():
    """Return whatever ``engine.list_graphs()`` reports, under the "graphs" key."""
    from .engine import list_graphs

    available = list_graphs()
    return {"graphs": available}
|
|
|
|
@app.post("/api/graph/switch")
async def switch_graph(body: dict, user=Depends(require_auth)):
    """Switch the active graph and rebuild the legacy runtime on it.

    The graph named in ``body["name"]`` is loaded first, then recorded as
    the active graph name. The previous active runtime has its sensor
    stopped and is removed from the session registry, so that a later
    lookup by its session id cannot hand back a runtime whose sensor is no
    longer running. Its WebSocket, identity and channel are carried over to
    the replacement runtime.
    """
    global _active_runtime
    from .engine import load_graph
    import agent.runtime as rt

    name = body.get("name", "")
    graph = load_graph(name)  # validates it exists
    rt._active_graph_name = name

    # Preserve WS connection across graph switch
    old_ws = None
    old_claims = {}
    old_origin = ""
    previous = _active_runtime
    if previous:
        old_ws = previous.sink.ws
        old_claims = {"name": previous.identity}
        old_origin = previous.channel
        previous.sensor.stop()
        # Drop the stopped runtime from the registry (only if the entry is
        # really this object).
        if _sessions.get(previous.session_id) is previous:
            del _sessions[previous.session_id]
        _active_runtime = None

    # Create new runtime with new graph
    new_runtime = _ensure_runtime(user_claims=old_claims, origin=old_origin)

    # Re-attach WS if it was connected
    if old_ws:
        new_runtime.attach_ws(old_ws)
        log.info(f"[api] re-attached WS after graph switch to '{name}'")

    return {"status": "ok", "name": graph["name"],
            "note": "Graph switched. WS re-attached."}
|
|
|
|
# --- Test status (real-time) ---
|
|
_test_status = {"running": False, "current": "", "results": [], "last_green": None, "last_red": None, "total_expected": 0}
|
|
|
|
@app.post("/api/test/status")
|
|
async def post_test_status(body: dict, user=Depends(require_auth)):
|
|
"""Receive test status updates from the test runner."""
|
|
event = body.get("event", "")
|
|
if event == "suite_start":
|
|
_test_status["running"] = True
|
|
_test_status["current"] = body.get("suite", "")
|
|
if body.get("count"):
|
|
# First suite_start with count resets everything
|
|
_test_status["results"] = []
|
|
_test_status["total_expected"] = body["count"]
|
|
_test_status["last_green"] = None
|
|
_test_status["last_red"] = None
|
|
elif event == "step_result":
|
|
result = body.get("result", {})
|
|
_test_status["results"].append(result)
|
|
_test_status["current"] = f"{result.get('step', '')} — {result.get('check', '')}"
|
|
if result.get("status") == "FAIL":
|
|
_test_status["last_red"] = result
|
|
elif result.get("status") == "PASS":
|
|
_test_status["last_green"] = result
|
|
elif event == "suite_end":
|
|
_test_status["running"] = False
|
|
_test_status["current"] = ""
|
|
# Broadcast to /ws/test subscribers — must await to ensure delivery before response
|
|
await _broadcast_test({"type": "test_status", **_test_status})
|
|
# Also SSE for backward compat
|
|
_broadcast_sse({"type": "test_status", **_test_status})
|
|
return {"ok": True}
|
|
|
|
@app.get("/api/test/status")
async def get_test_status(user=Depends(require_auth)):
    """Return the live ``_test_status`` dict."""
    status = _test_status
    return status
|
|
|
|
@app.get("/api/tests")
async def get_tests():
    """Latest test results from runtime_test.py.

    Reads ``testcases/results.json`` two directories above this file and
    returns its parsed content, or ``{}`` when the file does not exist.
    """
    results_path = Path(__file__).parent.parent / "testcases" / "results.json"
    # Read first and handle absence, rather than exists()-then-read, so a
    # file removed between the two steps cannot raise.
    try:
        raw = results_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    return json.loads(raw)
|
|
|
|
@app.get("/api/trace")
async def get_trace(last: int = 30, user=Depends(require_auth)):
    """Return the last ``last`` parseable JSON lines of the trace file.

    Args:
        last: Number of trailing lines to consider. Zero or a negative
            value yields an empty list.

    Returns:
        ``{"lines": [...]}``; empty when the trace file does not exist.
        Lines that are not valid JSON are skipped.
    """
    # [-0:] would return every line, so guard non-positive values up front.
    if last <= 0:
        return {"lines": []}
    try:
        text = TRACE_FILE.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {"lines": []}

    parsed = []
    for line in text.strip().split("\n")[-last:]:
        try:
            parsed.append(json.loads(line))
        except json.JSONDecodeError:
            pass  # skip lines that are not valid JSON
    return {"lines": parsed}
|
|
|
|
# --- Debug command channel ---
|
|
# Flow: AI POSTs cmd → pushed to each /api/health-stream SSE queue → nyx receives it as a debug_cmd event → executes → POSTs result to /api/debug/result → AI gets response
|
|
|
|
@app.post("/api/debug/cmd")
async def debug_cmd(body: dict, user=Depends(require_auth)):
    """AI sends a command to execute in the browser. Waits up to 5s for result.

    The command is pushed onto every registered debug queue; the handler
    then waits for a matching result to be posted back and returns it.

    Returns:
        ``{"cmd_id": ..., "result": ..., "error": ...}`` on success, or a
        dict with an ``error`` string when no browser is connected or the
        wait times out.

    Raises:
        HTTPException: 400 when ``cmd`` is missing or empty.
    """
    import uuid

    cmd_id = str(uuid.uuid4())[:8]
    cmd = body.get("cmd", "")
    args = body.get("args", {})
    if not cmd:
        raise HTTPException(400, "Missing 'cmd'")

    if not _debug_queues:
        return {"cmd_id": cmd_id, "error": "no browser connected"}

    evt = asyncio.Event()
    _debug_results[cmd_id] = evt
    try:
        # Push to all connected SSE clients
        payload = {"cmd_id": cmd_id, "cmd": cmd, "args": args}
        for q in _debug_queues:
            try:
                q.put_nowait(payload)
            except asyncio.QueueFull:
                pass  # that client is backed up; others may still deliver

        # Wait for nyx to execute and POST result back
        try:
            await asyncio.wait_for(evt.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            return {"cmd_id": cmd_id, "error": "timeout — command took too long"}

        return {"cmd_id": cmd_id, **_debug_result_data.get(cmd_id, {})}
    finally:
        # Always release the bookkeeping for this command — also when the
        # request is cancelled mid-wait or a result lands right at the
        # timeout — so neither dict can accumulate dead entries.
        _debug_results.pop(cmd_id, None)
        _debug_result_data.pop(cmd_id, None)
|
|
|
|
@app.post("/api/debug/result")
async def debug_result(body: dict, user=Depends(require_auth)):
    """nyx posts command results back here.

    The payload is stored and the waiting /api/debug/cmd request is woken
    only if a waiter is still registered for ``cmd_id``. Results for
    unknown or already timed-out commands are acknowledged but discarded,
    since nothing would ever read (or clean up) them.

    Raises:
        HTTPException: 400 when ``cmd_id`` is missing or empty.
    """
    cmd_id = body.get("cmd_id", "")
    if not cmd_id:
        raise HTTPException(400, "Missing 'cmd_id'")

    evt = _debug_results.get(cmd_id)
    if evt:
        # Store the payload before signalling so the waiter always finds it.
        _debug_result_data[cmd_id] = {"result": body.get("result"), "error": body.get("error")}
        evt.set()
    return {"ok": True}
|
|
|
|
# ── Test results (fed by test-runner Job) ────────────────────────────────
|
|
|
|
@app.post("/api/test-results")
async def post_test_result(request: Request):
    """Receive a single test result from the test-runner Job. No auth (internal).

    A result carrying a new ``run_id`` discards the previous run. Within a
    run, a result replaces the stored entry with the same (test, suite)
    pair, or is appended. The raw body is then pushed to every
    test-results subscriber queue.
    """
    global _test_run_id
    body = await request.json()

    # New run → clear old results
    run_id = body.get("run_id", "")
    if run_id and run_id != _test_run_id:
        _test_results.clear()
        _test_run_id = run_id

    # Replace existing entry for same test+suite, or append
    key = (body.get("test"), body.get("suite"))
    slot = next(
        (idx for idx, seen in enumerate(_test_results)
         if (seen.get("test"), seen.get("suite")) == key),
        None,
    )
    if slot is None:
        _test_results.append(body)
    else:
        _test_results[slot] = body

    # Push to SSE subscribers (iterate a copy of the subscriber list)
    for subscriber in list(_test_results_subscribers):
        try:
            subscriber.put_nowait(body)
        except asyncio.QueueFull:
            pass
    return {"ok": True}
|
|
|
|
@app.get("/api/test-results")
async def get_test_results_sse(request: Request):
    """SSE stream of test results: a replay of stored results, then live updates.

    While idle, a ``heartbeat`` event is emitted every 15 seconds; the
    stream ends once the client is reported as disconnected.
    """
    q: Queue = Queue(maxsize=100)
    _test_results_subscribers.append(q)

    async def stream():
        try:
            # Replay what is already known (snapshot of the list).
            for stored in list(_test_results):
                yield f"data: {json.dumps(stored)}\n\n"

            # Then forward new results as they arrive.
            while not await request.is_disconnected():
                try:
                    fresh = await asyncio.wait_for(q.get(), timeout=15)
                except asyncio.TimeoutError:
                    yield "event: heartbeat\ndata: {}\n\n"
                    continue
                yield f"data: {json.dumps(fresh)}\n\n"
        finally:
            _test_results_subscribers.remove(q)

    return StreamingResponse(stream(), media_type="text/event-stream")
|
|
|
|
@app.get("/api/test-results/latest")
async def get_test_results_latest():
    """Return the current run id together with all stored results as plain JSON."""
    snapshot = {"run_id": _test_run_id}
    snapshot["results"] = _test_results
    return snapshot
|