"""Model matrix tests — run the same API test with different LLM model configs. Each variant defines model overrides for specific node roles. The test runner generates one test per (base_test × variant) combination. Results are posted with the variant name in brackets, e.g. "eras_umsatz_api[haiku]". Usage via run_tests.py: python tests/run_tests.py matrix # all variants × all tests python tests/run_tests.py matrix/eras_query[haiku] # single combo """ import json import os import sys import urllib.request sys.path.insert(0, os.path.dirname(__file__)) _api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api') ASSAY_BASE = _api_url.removesuffix('/api') if _api_url.endswith('/api') else _api_url SERVICE_TOKEN = '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g' # --- Model variants --- # Each variant overrides specific node models. Omitted roles keep graph defaults. VARIANTS = { 'gemini-flash': { # Graph default — this is the baseline }, 'haiku': { 'pa': 'anthropic/claude-haiku-4.5', 'expert_eras': 'anthropic/claude-haiku-4.5', 'interpreter': 'anthropic/claude-haiku-4.5', }, 'gpt-4o-mini': { 'pa': 'openai/gpt-4o-mini', 'expert_eras': 'openai/gpt-4o-mini', 'interpreter': 'openai/gpt-4o-mini', }, } # --- API helper with model overrides --- def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]: """Send message via /api/chat with optional model overrides. Returns SSE events.""" body = {'content': text} if models: body['models'] = models payload = json.dumps(body).encode() req = urllib.request.Request( f'{ASSAY_BASE}/api/chat', data=payload, method='POST', headers={ 'Authorization': f'Bearer {SERVICE_TOKEN}', 'Content-Type': 'application/json', }, ) resp = urllib.request.urlopen(req, timeout=timeout) output = resp.read().decode('utf-8') events = [] for block in output.split('\n\n'): event_type, data = '', '' for line in block.strip().split('\n'): if line.startswith('event: '): event_type = line[7:] elif line.startswith('data: '): data = line[6:] if event_type and data: events.append((event_type, data)) return events # --- Base tests (same logic, different models) --- def _test_eras_query(models: dict): """ERAS query produces correct SQL path (artikelposition, not geraeteverbraeuche).""" events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models) tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1] and 'query_db' in e[1]] assert tool_calls, 'no query_db tool call found' query_data = json.loads(tool_calls[0][1]) query = query_data.get('args', {}).get('query', '') assert 'artikelposition' in query.lower(), f'query missing artikelposition: {query}' assert 'geraeteverbraeuche' not in query.lower(), f'query uses wrong table: {query}' # Check it completed done_events = [e for e in events if e[0] == 'done'] assert done_events, 'no done event' def _test_eras_artifact(models: dict): """ERAS query produces artifact with data rows.""" events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models) artifact_events = [e for e in events if e[0] == 'artifacts'] assert artifact_events, 'no artifact event' artifacts = json.loads(artifact_events[0][1]).get('artifacts', []) assert len(artifacts) >= 1, f'expected >=1 artifact, got {len(artifacts)}' has_data = any( art.get('data', {}).get('fields') or art.get('data', {}).get('rows') for art in artifacts ) assert has_data, 'no artifact contains data' def _test_social_reflex(models: dict): """Social greeting takes reflex path (fast, no expert).""" events = api_chat('Hallo!', models=models) # Should get a response (delta events) deltas = [e for e in events if e[0] == 'delta'] assert deltas, 'no delta events' # Should complete done = [e for e in events if e[0] == 'done'] assert done, 'no done event' # Should NOT call any tools tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1]] assert not tool_calls, f'reflex path should not use tools, got {len(tool_calls)} calls' # --- Test registry: base tests that get multiplied by variants --- BASE_TESTS = { 'eras_query': _test_eras_query, 'eras_artifact': _test_eras_artifact, 'social_reflex': _test_social_reflex, } def get_matrix_tests() -> dict: """Generate test×variant matrix. Returns {name: callable} dict for run_tests.py.""" tests = {} for variant_name, models in VARIANTS.items(): for test_name, test_fn in BASE_TESTS.items(): combo_name = f'{test_name}[{variant_name}]' # Capture current values in closure tests[combo_name] = (lambda fn, m: lambda: fn(m))(test_fn, models) return tests