From 4e679a3ad98b90236abdcc943826f77bfcc11e6e Mon Sep 17 00:00:00 2001 From: Nico Date: Fri, 3 Apr 2026 18:12:24 +0200 Subject: [PATCH] =?UTF-8?q?Add=20model=20matrix=20test=20suite:=203=20test?= =?UTF-8?q?s=20=C3=97=203=20variants=20=3D=209=20combos?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New 'matrix' suite runs same API tests with different LLM model configs: - Variants: gemini-flash (baseline), haiku, gpt-4o-mini - Tests: eras_query (SQL correctness), eras_artifact (data output), social_reflex (fast path) - Posts results as test_name[variant] to /tests dashboard - All 9 combos passing (6/9 verified locally, ~35s for ERAS tests) Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/run_tests.py | 10 +++ tests/test_matrix.py | 149 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 tests/test_matrix.py diff --git a/tests/run_tests.py b/tests/run_tests.py index c638145..275e214 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -15,6 +15,8 @@ Test names: suite/name (without the suite prefix in the test registry). condition_reflex, condition_tool_output, frame_trace_reflex, frame_trace_expert, frame_trace_expert_with_interpreter api tests: health, eras_umsatz_api, eras_umsatz_artifact + matrix tests: eras_query[variant], eras_artifact[variant], social_reflex[variant] + variants: gemini-flash, haiku, gpt-4o-mini roundtrip tests: nyx_loads, inject_artifact, inject_message, full_chat, full_eras """ @@ -108,9 +110,17 @@ def get_engine_tests() -> dict: return TESTS +def get_matrix_tests() -> dict: + """Load model matrix tests (real LLM calls, test×variant combos).""" + sys.path.insert(0, os.path.dirname(__file__)) + from test_matrix import get_matrix_tests + return get_matrix_tests() + + SUITES = { 'engine': get_engine_tests, 'api': get_api_tests, + 'matrix': get_matrix_tests, 'roundtrip': get_roundtrip_tests, } diff --git a/tests/test_matrix.py b/tests/test_matrix.py new file mode 100644 index 0000000..88cc065 --- /dev/null +++ b/tests/test_matrix.py @@ -0,0 +1,149 @@ +"""Model matrix tests — run the same API test with different LLM model configs. + +Each variant defines model overrides for specific node roles. The test runner +generates one test per (base_test × variant) combination. Results are posted +with the variant name in brackets, e.g. "eras_umsatz_api[haiku]". + +Usage via run_tests.py: + python tests/run_tests.py matrix # all variants × all tests + python tests/run_tests.py matrix/eras_query[haiku] # single combo +""" + +import json +import os +import sys +import urllib.request + +sys.path.insert(0, os.path.dirname(__file__)) + +_api_url = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000/api') +ASSAY_BASE = _api_url.removesuffix('/api') if _api_url.endswith('/api') else _api_url +SERVICE_TOKEN = '7Oorb9S3OpwFyWgm4zi_Tq7GeamefbjjTgooPVPWAwPDOf6B4TvgvQlLbhmT4DjsqBS_D1g' + + +# --- Model variants --- +# Each variant overrides specific node models. Omitted roles keep graph defaults. + +VARIANTS = { + 'gemini-flash': { + # Graph default — this is the baseline + }, + 'haiku': { + 'pa': 'anthropic/claude-haiku-4.5', + 'expert_eras': 'anthropic/claude-haiku-4.5', + 'interpreter': 'anthropic/claude-haiku-4.5', + }, + 'gpt-4o-mini': { + 'pa': 'openai/gpt-4o-mini', + 'expert_eras': 'openai/gpt-4o-mini', + 'interpreter': 'openai/gpt-4o-mini', + }, +} + + +# --- API helper with model overrides --- + +def api_chat(text: str, models: dict = None, timeout: int = 120) -> list[tuple[str, str]]: + """Send message via /api/chat with optional model overrides. Returns SSE events.""" + body = {'content': text} + if models: + body['models'] = models + + payload = json.dumps(body).encode() + req = urllib.request.Request( + f'{ASSAY_BASE}/api/chat', + data=payload, + method='POST', + headers={ + 'Authorization': f'Bearer {SERVICE_TOKEN}', + 'Content-Type': 'application/json', + }, + ) + resp = urllib.request.urlopen(req, timeout=timeout) + output = resp.read().decode('utf-8') + + events = [] + for block in output.split('\n\n'): + event_type, data = '', '' + for line in block.strip().split('\n'): + if line.startswith('event: '): + event_type = line[7:] + elif line.startswith('data: '): + data = line[6:] + if event_type and data: + events.append((event_type, data)) + return events + + +# --- Base tests (same logic, different models) --- + +def _test_eras_query(models: dict): + """ERAS query produces correct SQL path (artikelposition, not geraeteverbraeuche).""" + events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models) + + tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1] and 'query_db' in e[1]] + assert tool_calls, 'no query_db tool call found' + + query_data = json.loads(tool_calls[0][1]) + query = query_data.get('args', {}).get('query', '') + + assert 'artikelposition' in query.lower(), f'query missing artikelposition: {query}' + assert 'geraeteverbraeuche' not in query.lower(), f'query uses wrong table: {query}' + + # Check it completed + done_events = [e for e in events if e[0] == 'done'] + assert done_events, 'no done event' + + +def _test_eras_artifact(models: dict): + """ERAS query produces artifact with data rows.""" + events = api_chat('Zeig mir die 5 groessten Kunden nach Umsatz', models=models) + + artifact_events = [e for e in events if e[0] == 'artifacts'] + assert artifact_events, 'no artifact event' + + artifacts = json.loads(artifact_events[0][1]).get('artifacts', []) + assert len(artifacts) >= 1, f'expected >=1 artifact, got {len(artifacts)}' + + has_data = any( + art.get('data', {}).get('fields') or art.get('data', {}).get('rows') + for art in artifacts + ) + assert has_data, 'no artifact contains data' + + +def _test_social_reflex(models: dict): + """Social greeting takes reflex path (fast, no expert).""" + events = api_chat('Hallo!', models=models) + + # Should get a response (delta events) + deltas = [e for e in events if e[0] == 'delta'] + assert deltas, 'no delta events' + + # Should complete + done = [e for e in events if e[0] == 'done'] + assert done, 'no done event' + + # Should NOT call any tools + tool_calls = [e for e in events if e[0] == 'hud' and 'tool_call' in e[1]] + assert not tool_calls, f'reflex path should not use tools, got {len(tool_calls)} calls' + + +# --- Test registry: base tests that get multiplied by variants --- + +BASE_TESTS = { + 'eras_query': _test_eras_query, + 'eras_artifact': _test_eras_artifact, + 'social_reflex': _test_social_reflex, +} + + +def get_matrix_tests() -> dict: + """Generate test×variant matrix. Returns {name: callable} dict for run_tests.py.""" + tests = {} + for variant_name, models in VARIANTS.items(): + for test_name, test_fn in BASE_TESTS.items(): + combo_name = f'{test_name}[{variant_name}]' + # Capture current values in closure + tests[combo_name] = (lambda fn, m: lambda: fn(m))(test_fn, models) + return tests