# NOTE: in agent/api.py these routes are declared inside register_routes(app).

@app.get("/api/health")
@app.get("/health")  # K8s probes
async def health():
    """Liveness check, served under both the API prefix and the bare path."""
    return {"status": "ok"}


@app.get("/api/health-stream")
async def health_stream(user=Depends(require_auth)):
    """SSE heartbeat stream — client uses this for presence detection.

    Emits one ``heartbeat`` event every 15 seconds until the consumer goes
    away. Each event's JSON payload carries:

    * ``ts`` — the event loop clock, truncated to whole seconds. This is a
      monotonic reading, not a Unix timestamp, so it is only meaningful for
      measuring gaps between heartbeats.
    * ``sessions`` — ``len(_sessions)`` at the time the event is produced.

    Returns:
        A ``text/event-stream`` StreamingResponse with caching and proxy
        buffering disabled.
    """
    # get_running_loop() is the recommended accessor inside a coroutine.
    loop = asyncio.get_running_loop()

    async def generate():
        # No CancelledError handler on purpose: when the response is torn
        # down, cancellation must propagate rather than be swallowed.
        while True:
            payload = {"ts": int(loop.time()), "sessions": len(_sessions)}
            yield f"event: heartbeat\ndata: {json.dumps(payload)}\n\n"
            await asyncio.sleep(15)

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
"anonymous"} - if AUTH_ENABLED and token: - try: - user_claims = await _validate_token(token) - if not user_claims.get("name") and access_token: - async with httpx.AsyncClient() as client: - resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo", - headers={"Authorization": f"Bearer {access_token}"}) - if resp.status_code == 200: - info = resp.json() - log.info(f"[auth] userinfo enrichment: {info}") - user_claims["name"] = info.get("name") - user_claims["preferred_username"] = info.get("preferred_username") - user_claims["email"] = info.get("email") - except HTTPException: - await ws.close(code=4001, reason="Invalid token") - return - origin = ws.headers.get("origin", ws.headers.get("host", "")) - await ws.accept() - - # Get or create session, attach WS - runtime = await _get_or_create_session( - session_id=session, user_claims=user_claims, origin=origin) - runtime.update_identity(user_claims, origin) - runtime.attach_ws(ws) - - # Tell client which session they're on + ready signal - try: - await ws.send_text(json.dumps({ - "type": "session_info", "session_id": runtime.session_id, - "graph": runtime.graph.get("name", "unknown"), - "history_len": len(runtime.history)})) - await ws.send_text(json.dumps({ - "type": "ready", "session_id": runtime.session_id})) - except Exception: - pass - - try: - while True: - data = await ws.receive_text() - msg = json.loads(data) - msg_type = msg.get("type", "") - rt = _sessions.get(runtime.session_id, runtime) - - # Ping — keep-alive, no processing - if msg_type == "ping": - continue - - # Auth message from nyx — already authed via query param, ignore - if msg_type in ("auth", "connect"): - continue - - try: - if msg_type == "action": - action = msg.get("action", "unknown") - data_payload = msg.get("data") - if hasattr(rt, 'use_frames') and rt.use_frames: - action_text = f"ACTION:{action}" - if data_payload: - action_text += f"|data:{json.dumps(data_payload)}" - await rt.handle_message(action_text) - else: - await 
rt.handle_action(action, data_payload) - - elif msg_type == "cancel_process": - rt.process_manager.cancel(msg.get("pid", 0)) - - elif msg_type == "new": - # New session requested - rt.detach_ws() - new_rt = await _get_or_create_session( - user_claims=user_claims, origin=origin) - new_rt.attach_ws(ws) - runtime = new_rt - rt = new_rt - await ws.send_text(json.dumps({ - "type": "session_info", "session_id": rt.session_id, - "graph": rt.graph.get("name", "unknown"), - "history_len": 0})) - await ws.send_text(json.dumps({"type": "cleared"})) - continue - - elif msg_type == "stop": - # Cancel running pipeline - if _pipeline_task and not _pipeline_task.done(): - _pipeline_task.cancel() - continue - - elif msg_type == "message": - # nyx format: {type: 'message', content: '...'} - text = msg.get("content", "").strip() - if text: - await rt.handle_message(text, dashboard=msg.get("dashboard")) - - else: - # Legacy/assay format: {text: '...', dashboard: [...]} - text = msg.get("text", "").strip() - if text: - await rt.handle_message(text, dashboard=msg.get("dashboard")) - - # Auto-save after each processed message - asyncio.create_task(_save_session(rt)) - - except Exception as e: - import traceback - log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}") - try: - await ws.send_text(json.dumps({"type": "hud", "node": "runtime", "event": "error", "detail": str(e)[:200]})) - except Exception: - pass - except WebSocketDisconnect: - runtime.detach_ws() - log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive") - async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool: """Validate token for debug WS. 
# --- Streamable HTTP chat endpoint ---
# NOTE: in agent/api.py these routes are declared inside register_routes(app).

@app.post("/api/chat")
async def api_chat(request: Request, user=Depends(require_auth)):
    """Send a message and receive streaming SSE response.

    The JSON request body is an object with these optional keys:

    * ``session_id`` — passed to ``_get_or_create_session``.
    * ``content`` — message text; surrounding whitespace is stripped.
    * ``action`` / ``action_data`` — an action and its payload, used instead
      of ``content`` when ``action`` is truthy.
    * ``dashboard`` — forwarded to ``handle_message`` for text messages.

    Returns:
        A ``text/event-stream`` StreamingResponse. Every dict taken from the
        per-request queue becomes one SSE event whose name is the dict's
        ``"type"`` (default ``"message"``); the stream ends when the worker
        task posts its ``None`` sentinel.

    Raises:
        HTTPException: 400 if the body is not a JSON object, ``content`` is
            not a string, or neither ``content`` nor ``action`` is given.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object")

    session_id = body.get("session_id")
    # `or ""` also covers an explicit JSON null, which .get()'s default does not.
    content = body.get("content") or ""
    if not isinstance(content, str):
        raise HTTPException(status_code=400, detail="'content' must be a string")
    text = content.strip()
    action = body.get("action")
    action_data = body.get("action_data")
    dashboard = body.get("dashboard")

    if not text and not action:
        raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")

    origin = request.headers.get("origin", request.headers.get("host", ""))
    rt = await _get_or_create_session(
        session_id=session_id, user_claims=user, origin=origin)
    rt.update_identity(user, origin)

    # Attach SSE queue to sink for this request.
    # NOTE(review): attach_queue/detach_queue take no handle, so two concurrent
    # requests on the same session look like they would replace or detach each
    # other's queue — confirm against the sink implementation.
    q: Queue = Queue(maxsize=500)
    rt.sink.attach_queue(q)

    async def run_and_close():
        """Run the pipeline for this request, then post the end-of-stream sentinel."""
        try:
            if action:
                if getattr(rt, "use_frames", False):
                    action_text = f"ACTION:{action}"
                    if action_data:
                        action_text += f"|data:{json.dumps(action_data)}"
                    await rt.handle_message(action_text)
                else:
                    await rt.handle_action(action, action_data)
            else:
                await rt.handle_message(text, dashboard=dashboard)
            # Auto-save
            await _save_session(rt)
        except Exception as e:
            import traceback
            log.error(f"[chat] handler error: {e}\n{traceback.format_exc()}")
            try:
                q.put_nowait({"type": "error", "detail": str(e)[:200]})
            except asyncio.QueueFull:
                # Best effort: the failure is already logged above.
                log.warning("[chat] SSE queue full, error event dropped")
        finally:
            # Signal end-of-stream. This must never fail: a lost sentinel
            # leaves the consumer blocked on q.get() forever. It must also
            # not await, because this runs on the cancellation path too.
            # So when the queue is saturated, evict the oldest pending event
            # to make room.
            if q.full():
                q.get_nowait()
            q.put_nowait(None)

    # Run pipeline in background task
    task = asyncio.create_task(run_and_close())

    async def generate():
        """Relay queued events as SSE frames until the sentinel arrives."""
        try:
            while True:
                event = await q.get()
                if event is None:
                    break
                event_type = event.get("type", "message")
                yield f"event: {event_type}\ndata: {json.dumps(event)}\n\n"
        finally:
            # Covers normal completion, cancellation and generator close
            # alike: stop the worker if it is still running, then release
            # the queue. Cancellation itself is left to propagate.
            if not task.done():
                task.cancel()
            rt.sink.detach_queue()

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.get("/api/session")
async def api_session(token: str | None = Query(None),
                      session: str | None = Query(None),
                      user=Depends(require_auth)):
    """Get or create session — replaces WS connect handshake.

    Args:
        token: Accepted as a query parameter but not read by this handler.
        session: Optional session id passed to ``_get_or_create_session``.
        user: Claims supplied by the ``require_auth`` dependency.

    Returns:
        A dict with the session id, the graph name (``"unknown"`` when the
        graph has no ``name``), the history length, a fixed ``"ready"``
        status, and the memorizer state.
    """
    rt = await _get_or_create_session(
        session_id=session, user_claims=user, origin="")
    return {
        "session_id": rt.session_id,
        "graph": rt.graph.get("name", "unknown"),
        "history_len": len(rt.history),
        "status": "ready",
        "memorizer": rt.memorizer.state,
    }


@app.post("/api/stop")
async def api_stop(user=Depends(require_auth)):
    """Cancel running pipeline.

    Returns:
        ``{"stopped": True}`` if a not-yet-finished ``_pipeline_task`` was
        cancelled, otherwise ``{"stopped": False, ...}`` with a detail message.
    """
    # NOTE(review): _pipeline_task is not scoped to the caller's session here;
    # presumably a single shared task — confirm this is intended.
    if _pipeline_task and not _pipeline_task.done():
        _pipeline_task.cancel()
        return {"stopped": True}
    return {"stopped": False, "detail": "No pipeline running"}