Remove old /ws chat endpoint (replaced by Streamable HTTP)

Kept /ws/test and /ws/trace debug WebSockets for dev tooling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Nico 2026-04-01 17:45:29 +02:00
parent e94f1c437c
commit 1e64b0a58c

View File

@ -7,7 +7,7 @@ import logging
from asyncio import Queue
from pathlib import Path
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi import Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
from starlette.responses import StreamingResponse
import httpx
@ -110,10 +110,25 @@ def register_routes(app):
async def shutdown():
    """Close the pool managed by ``db_sessions``."""
    closing = db_sessions.close_pool()
    await closing
@app.get("/api/health")
@app.get("/health")  # K8s probes
async def health():
    """Health check served on both ``/api/health`` and ``/health``.

    Returns:
        The static body ``{"status": "ok"}``.
    """
    # Each path is registered exactly once; a second "/health" decorator
    # would only add a redundant duplicate route for the same handler.
    return {"status": "ok"}
@app.get("/api/health-stream")
async def health_stream(user=Depends(require_auth)):
    """Stream a ``heartbeat`` SSE event every 15 seconds.

    Clients use the stream for presence detection. Each event carries the
    event-loop clock reading (``ts``) and the current number of entries in
    ``_sessions`` (``sessions``).
    """
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}

    async def generate():
        # The generator always runs inside the serving loop, so the running
        # loop is the one whose clock we report.
        loop = asyncio.get_running_loop()
        try:
            while True:
                # `ts` is the loop's own clock, truncated to whole seconds;
                # it is not a wall-clock epoch timestamp.
                payload = json.dumps({'ts': int(loop.time()), 'sessions': len(_sessions)})
                yield f"event: heartbeat\ndata: {payload}\n\n"
                await asyncio.sleep(15)
        except asyncio.CancelledError:
            # Cancellation simply ends the stream.
            pass

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers=sse_headers)
# --- Session CRUD ---
@app.post("/api/sessions")
@ -202,124 +217,6 @@ def register_routes(app):
log.info("[api] created persistent runtime (legacy)")
return _active_runtime
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
                      access_token: str | None = Query(None),
                      session: str | None = Query(None)):
    """Chat WebSocket: bind the socket to a session runtime and dispatch frames.

    Query parameters:
        token: Validated with ``_validate_token`` when ``AUTH_ENABLED`` is
            set; a rejected token closes the socket with code 4001.
        access_token: Used for a userinfo lookup when the validated claims
            carry no ``name``.
        session: Session id handed to ``_get_or_create_session``.

    After accepting, the handler sends ``session_info`` and ``ready`` frames,
    then loops over incoming JSON frames and routes them on their ``type``
    (``ping``, ``auth``/``connect``, ``action``, ``cancel_process``, ``new``,
    ``stop``, ``message``; anything else is treated as the legacy
    ``{text, dashboard}`` format). Malformed frames and handler errors are
    reported to the client as ``hud`` error events without dropping the
    connection. The socket is detached from the runtime on every exit path.
    """

    async def _send_runtime_error(detail: str) -> None:
        """Best-effort push of a runtime error HUD event to the client."""
        try:
            await ws.send_text(json.dumps({"type": "hud", "node": "runtime", "event": "error", "detail": detail[:200]}))
        except Exception:
            # The socket may already be gone; there is nobody left to tell.
            pass

    user_claims = {"sub": "anonymous"}
    # NOTE(review): with AUTH_ENABLED set but no `token` supplied, the
    # connection proceeds with the anonymous claims above — confirm that
    # unauthenticated access is intended here.
    if AUTH_ENABLED and token:
        try:
            user_claims = await _validate_token(token)
        except HTTPException:
            await ws.close(code=4001, reason="Invalid token")
            return
        if not user_claims.get("name") and access_token:
            # Enrichment is optional: a failed lookup must not abort an
            # otherwise valid connection.
            info = None
            try:
                async with httpx.AsyncClient() as client:
                    resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
                                            headers={"Authorization": f"Bearer {access_token}"})
                if resp.status_code == 200:
                    info = resp.json()
            except (httpx.HTTPError, ValueError) as e:
                log.warning(f"[auth] userinfo enrichment failed: {e}")
            if isinstance(info, dict):
                log.info(f"[auth] userinfo enrichment: {info}")
                user_claims["name"] = info.get("name")
                user_claims["preferred_username"] = info.get("preferred_username")
                user_claims["email"] = info.get("email")

    origin = ws.headers.get("origin", ws.headers.get("host", ""))
    await ws.accept()

    # Get or create session, attach WS
    runtime = await _get_or_create_session(
        session_id=session, user_claims=user_claims, origin=origin)
    runtime.update_identity(user_claims, origin)
    runtime.attach_ws(ws)

    # Tell client which session they're on + ready signal
    try:
        await ws.send_text(json.dumps({
            "type": "session_info", "session_id": runtime.session_id,
            "graph": runtime.graph.get("name", "unknown"),
            "history_len": len(runtime.history)}))
        await ws.send_text(json.dumps({
            "type": "ready", "session_id": runtime.session_id}))
    except Exception:
        # Best effort: a dead socket surfaces in the receive loop below.
        log.debug("[ws] initial session_info/ready send failed", exc_info=True)

    try:
        while True:
            data = await ws.receive_text()
            # Parse defensively: one bad frame must not tear down the
            # connection.
            try:
                msg = json.loads(data)
            except ValueError:
                msg = None
            if not isinstance(msg, dict):
                log.warning("[ws] ignoring malformed frame (expected a JSON object)")
                await _send_runtime_error("Malformed message: expected a JSON object")
                continue

            msg_type = msg.get("type", "")
            # Prefer the runtime currently registered under this session id,
            # falling back to the one this socket is holding.
            rt = _sessions.get(runtime.session_id, runtime)

            # Ping — keep-alive, no processing
            if msg_type == "ping":
                continue
            # Auth message from nyx — already authed via query param, ignore
            if msg_type in ("auth", "connect"):
                continue

            try:
                if msg_type == "action":
                    action = msg.get("action", "unknown")
                    data_payload = msg.get("data")
                    if hasattr(rt, 'use_frames') and rt.use_frames:
                        action_text = f"ACTION:{action}"
                        if data_payload:
                            action_text += f"|data:{json.dumps(data_payload)}"
                        await rt.handle_message(action_text)
                    else:
                        await rt.handle_action(action, data_payload)
                elif msg_type == "cancel_process":
                    rt.process_manager.cancel(msg.get("pid", 0))
                elif msg_type == "new":
                    # New session requested: move this socket to a fresh
                    # runtime.
                    rt.detach_ws()
                    new_rt = await _get_or_create_session(
                        user_claims=user_claims, origin=origin)
                    new_rt.attach_ws(ws)
                    runtime = new_rt
                    rt = new_rt
                    await ws.send_text(json.dumps({
                        "type": "session_info", "session_id": rt.session_id,
                        "graph": rt.graph.get("name", "unknown"),
                        "history_len": 0}))
                    await ws.send_text(json.dumps({"type": "cleared"}))
                    continue
                elif msg_type == "stop":
                    # Cancel running pipeline
                    if _pipeline_task and not _pipeline_task.done():
                        _pipeline_task.cancel()
                    continue
                elif msg_type == "message":
                    # nyx format: {type: 'message', content: '...'}
                    text = msg.get("content", "").strip()
                    if text:
                        await rt.handle_message(text, dashboard=msg.get("dashboard"))
                else:
                    # Legacy/assay format: {text: '...', dashboard: [...]}
                    text = msg.get("text", "").strip()
                    if text:
                        await rt.handle_message(text, dashboard=msg.get("dashboard"))
                # Auto-save after each processed message
                asyncio.create_task(_save_session(rt))
            except Exception as e:
                import traceback
                log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
                await _send_runtime_error(str(e))
    except WebSocketDisconnect:
        log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
    finally:
        # Detach on every exit path so the runtime never keeps a dead socket.
        runtime.detach_ws()
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
"""Validate token for debug WS. Returns True if auth OK."""
if not AUTH_ENABLED:
@ -370,6 +267,96 @@ def register_routes(app):
_trace_ws_clients.remove(ws)
log.info(f"[api] /ws/trace disconnected ({len(_trace_ws_clients)} clients)")
# --- Streamable HTTP chat endpoint ---
@app.post("/api/chat")
async def api_chat(request: Request, user=Depends(require_auth)):
    """Send a message or action and receive the response as an SSE stream.

    The JSON request body is an object with these keys:
        session_id: Session passed to ``_get_or_create_session`` (optional).
        content: Message text; surrounding whitespace is stripped.
        action / action_data: Action name and optional payload; when
            ``action`` is set it takes precedence over ``content``.
        dashboard: Forwarded to ``handle_message`` with the text.

    Returns:
        A ``text/event-stream`` response. Each queued event is emitted under
        its ``type`` (default ``message``); the stream ends when the
        background pipeline task finishes.

    Raises:
        HTTPException: 400 when the body is not a JSON object, ``content`` is
            not a string, or neither ``content`` nor ``action`` is given.
    """
    try:
        body = await request.json()
    except ValueError:
        # Covers JSON syntax errors and undecodable bytes alike.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    session_id = body.get("session_id")
    content = body.get("content") or ""
    if not isinstance(content, str):
        raise HTTPException(status_code=400, detail="'content' must be a string")
    text = content.strip()
    action = body.get("action")
    action_data = body.get("action_data")
    dashboard = body.get("dashboard")

    if not text and not action:
        raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")

    origin = request.headers.get("origin", request.headers.get("host", ""))
    rt = await _get_or_create_session(
        session_id=session_id, user_claims=user, origin=origin)
    rt.update_identity(user, origin)

    # Attach SSE queue to sink for this request
    q: Queue = Queue(maxsize=500)
    rt.sink.attach_queue(q)

    async def run_and_close():
        try:
            if action:
                if getattr(rt, 'use_frames', False):
                    action_text = f"ACTION:{action}"
                    if action_data:
                        action_text += f"|data:{json.dumps(action_data)}"
                    await rt.handle_message(action_text)
                else:
                    await rt.handle_action(action, action_data)
            else:
                await rt.handle_message(text, dashboard=dashboard)
            # Auto-save
            await _save_session(rt)
        except Exception as e:
            import traceback
            log.error(f"[chat] handler error: {e}\n{traceback.format_exc()}")
            try:
                q.put_nowait({"type": "error", "detail": str(e)[:200]})
            except asyncio.QueueFull:
                pass
        finally:
            # Signal end-of-stream. The sentinel must always be delivered,
            # otherwise the generator below waits on q.get() forever; if the
            # queue is full, discard the oldest event to make room for it.
            try:
                q.put_nowait(None)
            except asyncio.QueueFull:
                q.get_nowait()
                q.put_nowait(None)

    # Run pipeline in background task
    task = asyncio.create_task(run_and_close())

    async def generate():
        try:
            while True:
                event = await q.get()
                if event is None:
                    break
                event_type = event.get("type", "message")
                yield f"event: {event_type}\ndata: {json.dumps(event)}\n\n"
        except asyncio.CancelledError:
            # Stream was cancelled: stop the pipeline too.
            task.cancel()
        finally:
            rt.sink.detach_queue()

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@app.get("/api/session")
async def api_session(token: str = Query(None), session: str = Query(None),
                      user=Depends(require_auth)):
    """Get or create a session — replaces the WS connect handshake.

    ``session`` is the id handed to ``_get_or_create_session``; ``token`` is
    accepted as a query parameter but not read by this handler.

    Returns:
        A dict with the session id, graph name (``"unknown"`` when the graph
        has no ``name``), history length, a fixed ``"ready"`` status and the
        memorizer state.
    """
    runtime = await _get_or_create_session(
        session_id=session, user_claims=user, origin="")

    current_id = runtime.session_id
    graph_name = runtime.graph.get("name", "unknown")
    turns = len(runtime.history)
    memorizer_state = runtime.memorizer.state

    return {
        "session_id": current_id,
        "graph": graph_name,
        "history_len": turns,
        "status": "ready",
        "memorizer": memorizer_state,
    }
@app.post("/api/stop")
async def api_stop(user=Depends(require_auth)):
    """Cancel the running pipeline task, if there is one.

    Returns:
        ``{"stopped": True}`` when a task was cancelled, otherwise
        ``{"stopped": False, "detail": "No pipeline running"}``.
    """
    # Guard clause: nothing to cancel when no task exists or it has finished.
    idle = not _pipeline_task or _pipeline_task.done()
    if idle:
        return {"stopped": False, "detail": "No pipeline running"}

    _pipeline_task.cancel()
    return {"stopped": True}
@app.get("/api/events")
async def sse_events(user=Depends(require_auth)):
q: Queue = Queue(maxsize=100)