Remove old /ws chat endpoint (replaced by Streamable HTTP)
Kept /ws/test and /ws/trace debug WebSockets for dev tooling. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e94f1c437c
commit
1e64b0a58c
227
agent/api.py
227
agent/api.py
@ -7,7 +7,7 @@ import logging
|
|||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
from fastapi import Depends, HTTPException, Query, Request, WebSocket, WebSocketDisconnect
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -110,10 +110,25 @@ def register_routes(app):
|
|||||||
async def shutdown():
|
async def shutdown():
|
||||||
await db_sessions.close_pool()
|
await db_sessions.close_pool()
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/api/health")
|
||||||
|
@app.get("/health") # K8s probes
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
@app.get("/api/health-stream")
|
||||||
|
async def health_stream(user=Depends(require_auth)):
|
||||||
|
"""SSE heartbeat stream — client uses this for presence detection."""
|
||||||
|
async def generate():
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
yield f"event: heartbeat\ndata: {json.dumps({'ts': int(asyncio.get_event_loop().time()), 'sessions': len(_sessions)})}\n\n"
|
||||||
|
await asyncio.sleep(15)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||||
|
|
||||||
# --- Session CRUD ---
|
# --- Session CRUD ---
|
||||||
|
|
||||||
@app.post("/api/sessions")
|
@app.post("/api/sessions")
|
||||||
@ -202,124 +217,6 @@ def register_routes(app):
|
|||||||
log.info("[api] created persistent runtime (legacy)")
|
log.info("[api] created persistent runtime (legacy)")
|
||||||
return _active_runtime
|
return _active_runtime
|
||||||
|
|
||||||
@app.websocket("/ws")
|
|
||||||
async def ws_endpoint(ws: WebSocket, token: str | None = Query(None),
|
|
||||||
access_token: str | None = Query(None),
|
|
||||||
session: str | None = Query(None)):
|
|
||||||
user_claims = {"sub": "anonymous"}
|
|
||||||
if AUTH_ENABLED and token:
|
|
||||||
try:
|
|
||||||
user_claims = await _validate_token(token)
|
|
||||||
if not user_claims.get("name") and access_token:
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.get(f"{ZITADEL_ISSUER}/oidc/v1/userinfo",
|
|
||||||
headers={"Authorization": f"Bearer {access_token}"})
|
|
||||||
if resp.status_code == 200:
|
|
||||||
info = resp.json()
|
|
||||||
log.info(f"[auth] userinfo enrichment: {info}")
|
|
||||||
user_claims["name"] = info.get("name")
|
|
||||||
user_claims["preferred_username"] = info.get("preferred_username")
|
|
||||||
user_claims["email"] = info.get("email")
|
|
||||||
except HTTPException:
|
|
||||||
await ws.close(code=4001, reason="Invalid token")
|
|
||||||
return
|
|
||||||
origin = ws.headers.get("origin", ws.headers.get("host", ""))
|
|
||||||
await ws.accept()
|
|
||||||
|
|
||||||
# Get or create session, attach WS
|
|
||||||
runtime = await _get_or_create_session(
|
|
||||||
session_id=session, user_claims=user_claims, origin=origin)
|
|
||||||
runtime.update_identity(user_claims, origin)
|
|
||||||
runtime.attach_ws(ws)
|
|
||||||
|
|
||||||
# Tell client which session they're on + ready signal
|
|
||||||
try:
|
|
||||||
await ws.send_text(json.dumps({
|
|
||||||
"type": "session_info", "session_id": runtime.session_id,
|
|
||||||
"graph": runtime.graph.get("name", "unknown"),
|
|
||||||
"history_len": len(runtime.history)}))
|
|
||||||
await ws.send_text(json.dumps({
|
|
||||||
"type": "ready", "session_id": runtime.session_id}))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
data = await ws.receive_text()
|
|
||||||
msg = json.loads(data)
|
|
||||||
msg_type = msg.get("type", "")
|
|
||||||
rt = _sessions.get(runtime.session_id, runtime)
|
|
||||||
|
|
||||||
# Ping — keep-alive, no processing
|
|
||||||
if msg_type == "ping":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Auth message from nyx — already authed via query param, ignore
|
|
||||||
if msg_type in ("auth", "connect"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
if msg_type == "action":
|
|
||||||
action = msg.get("action", "unknown")
|
|
||||||
data_payload = msg.get("data")
|
|
||||||
if hasattr(rt, 'use_frames') and rt.use_frames:
|
|
||||||
action_text = f"ACTION:{action}"
|
|
||||||
if data_payload:
|
|
||||||
action_text += f"|data:{json.dumps(data_payload)}"
|
|
||||||
await rt.handle_message(action_text)
|
|
||||||
else:
|
|
||||||
await rt.handle_action(action, data_payload)
|
|
||||||
|
|
||||||
elif msg_type == "cancel_process":
|
|
||||||
rt.process_manager.cancel(msg.get("pid", 0))
|
|
||||||
|
|
||||||
elif msg_type == "new":
|
|
||||||
# New session requested
|
|
||||||
rt.detach_ws()
|
|
||||||
new_rt = await _get_or_create_session(
|
|
||||||
user_claims=user_claims, origin=origin)
|
|
||||||
new_rt.attach_ws(ws)
|
|
||||||
runtime = new_rt
|
|
||||||
rt = new_rt
|
|
||||||
await ws.send_text(json.dumps({
|
|
||||||
"type": "session_info", "session_id": rt.session_id,
|
|
||||||
"graph": rt.graph.get("name", "unknown"),
|
|
||||||
"history_len": 0}))
|
|
||||||
await ws.send_text(json.dumps({"type": "cleared"}))
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif msg_type == "stop":
|
|
||||||
# Cancel running pipeline
|
|
||||||
if _pipeline_task and not _pipeline_task.done():
|
|
||||||
_pipeline_task.cancel()
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif msg_type == "message":
|
|
||||||
# nyx format: {type: 'message', content: '...'}
|
|
||||||
text = msg.get("content", "").strip()
|
|
||||||
if text:
|
|
||||||
await rt.handle_message(text, dashboard=msg.get("dashboard"))
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Legacy/assay format: {text: '...', dashboard: [...]}
|
|
||||||
text = msg.get("text", "").strip()
|
|
||||||
if text:
|
|
||||||
await rt.handle_message(text, dashboard=msg.get("dashboard"))
|
|
||||||
|
|
||||||
# Auto-save after each processed message
|
|
||||||
asyncio.create_task(_save_session(rt))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
log.error(f"[ws] handler error: {e}\n{traceback.format_exc()}")
|
|
||||||
try:
|
|
||||||
await ws.send_text(json.dumps({"type": "hud", "node": "runtime", "event": "error", "detail": str(e)[:200]}))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
runtime.detach_ws()
|
|
||||||
log.info(f"[api] WS disconnected — session {runtime.session_id} stays alive")
|
|
||||||
|
|
||||||
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
async def _auth_debug_ws(ws: WebSocket, token: str | None) -> bool:
|
||||||
"""Validate token for debug WS. Returns True if auth OK."""
|
"""Validate token for debug WS. Returns True if auth OK."""
|
||||||
if not AUTH_ENABLED:
|
if not AUTH_ENABLED:
|
||||||
@ -370,6 +267,96 @@ def register_routes(app):
|
|||||||
_trace_ws_clients.remove(ws)
|
_trace_ws_clients.remove(ws)
|
||||||
log.info(f"[api] /ws/trace disconnected ({len(_trace_ws_clients)} clients)")
|
log.info(f"[api] /ws/trace disconnected ({len(_trace_ws_clients)} clients)")
|
||||||
|
|
||||||
|
# --- Streamable HTTP chat endpoint ---
|
||||||
|
|
||||||
|
@app.post("/api/chat")
|
||||||
|
async def api_chat(request: Request, user=Depends(require_auth)):
|
||||||
|
"""Send a message and receive streaming SSE response."""
|
||||||
|
body = await request.json()
|
||||||
|
session_id = body.get("session_id")
|
||||||
|
text = body.get("content", "").strip()
|
||||||
|
action = body.get("action")
|
||||||
|
action_data = body.get("action_data")
|
||||||
|
dashboard = body.get("dashboard")
|
||||||
|
|
||||||
|
if not text and not action:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing 'content' or 'action'")
|
||||||
|
|
||||||
|
origin = request.headers.get("origin", request.headers.get("host", ""))
|
||||||
|
rt = await _get_or_create_session(
|
||||||
|
session_id=session_id, user_claims=user, origin=origin)
|
||||||
|
rt.update_identity(user, origin)
|
||||||
|
|
||||||
|
# Attach SSE queue to sink for this request
|
||||||
|
q: Queue = Queue(maxsize=500)
|
||||||
|
rt.sink.attach_queue(q)
|
||||||
|
|
||||||
|
async def run_and_close():
|
||||||
|
try:
|
||||||
|
if action:
|
||||||
|
if hasattr(rt, 'use_frames') and rt.use_frames:
|
||||||
|
action_text = f"ACTION:{action}"
|
||||||
|
if action_data:
|
||||||
|
action_text += f"|data:{json.dumps(action_data)}"
|
||||||
|
await rt.handle_message(action_text)
|
||||||
|
else:
|
||||||
|
await rt.handle_action(action, action_data)
|
||||||
|
else:
|
||||||
|
await rt.handle_message(text, dashboard=dashboard)
|
||||||
|
# Auto-save
|
||||||
|
await _save_session(rt)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
log.error(f"[chat] handler error: {e}\n{traceback.format_exc()}")
|
||||||
|
try:
|
||||||
|
q.put_nowait({"type": "error", "detail": str(e)[:200]})
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
# Signal end-of-stream
|
||||||
|
q.put_nowait(None)
|
||||||
|
|
||||||
|
# Run pipeline in background task
|
||||||
|
task = asyncio.create_task(run_and_close())
|
||||||
|
|
||||||
|
async def generate():
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
event = await q.get()
|
||||||
|
if event is None:
|
||||||
|
break
|
||||||
|
event_type = event.get("type", "message")
|
||||||
|
yield f"event: {event_type}\ndata: {json.dumps(event)}\n\n"
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
task.cancel()
|
||||||
|
finally:
|
||||||
|
rt.sink.detach_queue()
|
||||||
|
|
||||||
|
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||||
|
|
||||||
|
@app.get("/api/session")
|
||||||
|
async def api_session(token: str = Query(None), session: str = Query(None),
|
||||||
|
user=Depends(require_auth)):
|
||||||
|
"""Get or create session — replaces WS connect handshake."""
|
||||||
|
rt = await _get_or_create_session(
|
||||||
|
session_id=session, user_claims=user, origin="")
|
||||||
|
return {
|
||||||
|
"session_id": rt.session_id,
|
||||||
|
"graph": rt.graph.get("name", "unknown"),
|
||||||
|
"history_len": len(rt.history),
|
||||||
|
"status": "ready",
|
||||||
|
"memorizer": rt.memorizer.state,
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/api/stop")
|
||||||
|
async def api_stop(user=Depends(require_auth)):
|
||||||
|
"""Cancel running pipeline."""
|
||||||
|
if _pipeline_task and not _pipeline_task.done():
|
||||||
|
_pipeline_task.cancel()
|
||||||
|
return {"stopped": True}
|
||||||
|
return {"stopped": False, "detail": "No pipeline running"}
|
||||||
|
|
||||||
@app.get("/api/events")
|
@app.get("/api/events")
|
||||||
async def sse_events(user=Depends(require_auth)):
|
async def sse_events(user=Depends(require_auth)):
|
||||||
q: Queue = Queue(maxsize=100)
|
q: Queue = Queue(maxsize=100)
|
||||||
|
|||||||
Reference in New Issue
Block a user