- run_tests.py: ThreadPoolExecutor runs N tests concurrently within a suite - Each testcase has its own session_id so parallel is safe - Engine tests: fixed asyncio.new_event_loop() for thread safety - Usage: python tests/run_tests.py testcases --parallel=3 - Wall time reduction: ~3x for testcases (15min → 5min with parallel=3) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
352 lines
12 KiB
Python
352 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Test orchestrator — runs test suites and posts results to dev assay.
|
||
|
||
Usage:
|
||
python tests/run_tests.py # all suites
|
||
python tests/run_tests.py api # one suite
|
||
python tests/run_tests.py matrix/eras_query[haiku] # single test
|
||
python tests/run_tests.py matrix --repeat=3 # each test 3x, report avg/p50/p95
|
||
python tests/run_tests.py testcases --parallel=3 # 3 testcases concurrently
|
||
python tests/run_tests.py api/health roundtrip/full_chat # multiple tests
|
||
|
||
Test names: suite/name (without the suite prefix in the test registry).
|
||
engine tests: graph_load, node_instantiation, edge_types_complete,
|
||
condition_reflex, condition_tool_output,
|
||
frame_trace_reflex, frame_trace_expert, frame_trace_expert_with_interpreter
|
||
api tests: health, eras_umsatz_api, eras_umsatz_artifact
|
||
matrix tests: eras_query[variant], eras_artifact[variant], social_reflex[variant]
|
||
variants: gemini-flash, haiku, gpt-4o-mini
|
||
testcases: fast, reflex_path, expert_eras, domain_context, ... (from testcases/*.md)
|
||
roundtrip tests: nyx_loads, inject_artifact, inject_message, full_chat, full_eras
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
import urllib.request
|
||
import uuid
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from datetime import datetime, timezone
|
||
from dataclasses import dataclass, field, asdict
|
||
|
||
# Optional HTTP endpoint for live result reporting; empty string disables posting.
RESULTS_ENDPOINT = os.environ.get('RESULTS_ENDPOINT', '')
# Correlates all results of one orchestrator invocation; override via env in CI.
RUN_ID = os.environ.get('RUN_ID', str(uuid.uuid4())[:8])
|
||
|
||
|
||
def _now_iso() -> str:
|
||
return datetime.now(timezone.utc).isoformat()
|
||
|
||
|
||
@dataclass
class TestResult:
    """One test outcome; serialized with asdict() and posted as a JSON line."""
    run_id: str              # shared id for all results of one orchestrator run
    test: str                # test name within its suite
    suite: str               # suite name ('engine', 'api', ...)
    status: str              # 'pass', 'fail', 'running', 'error'
    duration_ms: float = 0   # wall-clock time; avg over runs when repeated
    error: str = ''          # failure detail, or aggregated stats summary when repeated
    ts: str = ''             # ISO-8601 UTC timestamp of the last status change
    stats: dict = field(default_factory=dict)  # {runs, min_ms, avg_ms, p50_ms, p95_ms, max_ms, pass_rate}
|
||
|
||
|
||
def post_result(result: TestResult):
    """Echo a result to stdout as JSON and POST it to RESULTS_ENDPOINT (best effort).

    The stdout line is always emitted. The HTTP POST is skipped when
    RESULTS_ENDPOINT is unset, and any failure is only warned about —
    losing a result must never abort the test run.
    """
    print(json.dumps(asdict(result)), flush=True)
    if not RESULTS_ENDPOINT:
        return
    try:
        payload = json.dumps(asdict(result)).encode()
        req = urllib.request.Request(
            RESULTS_ENDPOINT,
            data=payload,
            headers={'Content-Type': 'application/json'},
        )
        # Close the response explicitly — the original leaked the socket.
        with urllib.request.urlopen(req, timeout=5):
            pass
    except Exception as e:
        print(f' [warn] failed to post result: {e}', file=sys.stderr)
|
||
|
||
|
||
def run_test(name: str, suite: str, fn) -> TestResult:
    """Execute one test callable and report its outcome.

    Posts a 'running' status first, then calls fn(): an AssertionError
    means 'fail', any other exception means 'error', otherwise 'pass'.
    The final result (with duration) is posted and returned.
    """
    result = TestResult(run_id=RUN_ID, test=name, suite=suite, status='running', ts=_now_iso())
    post_result(result)

    started = time.time()
    try:
        fn()
    except AssertionError as exc:
        result.status = 'fail'
        result.error = str(exc)
    except Exception as exc:
        result.status = 'error'
        result.error = f'{type(exc).__name__}: {exc}'
    else:
        result.status = 'pass'
    result.duration_ms = round((time.time() - started) * 1000)
    result.ts = _now_iso()

    post_result(result)
    return result
|
||
|
||
|
||
def get_api_tests() -> dict:
    """Load API tests from e2e_harness.py, excluding browser-dependent ones."""
    sys.path.insert(0, os.path.dirname(__file__))
    import e2e_harness
    base = os.environ.get('ASSAY_API', 'http://assay-runtime-test:8000')
    # BUG FIX: rstrip('/api') strips any trailing run of the characters
    # '/', 'a', 'p', 'i' (str.rstrip takes a character set, not a suffix),
    # which can mangle hosts ending in those letters. Strip the literal
    # '/api' suffix (and a trailing slash) instead.
    e2e_harness.ASSAY_BASE = base.rstrip('/').removesuffix('/api')
    # Skip browser-dependent tests — this orchestrator runs headless.
    return {k: v for k, v in e2e_harness.TESTS.items() if 'takeover' not in k and 'panes' not in k}
|
||
|
||
|
||
def get_roundtrip_tests() -> dict:
    """Load Playwright roundtrip tests."""
    sys.path.insert(0, os.path.dirname(__file__))
    import test_roundtrip
    return test_roundtrip.TESTS
|
||
|
||
|
||
def get_engine_tests() -> dict:
    """Load engine-level tests (no LLM, no network)."""
    sys.path.insert(0, os.path.dirname(__file__))
    import test_engine
    return test_engine.TESTS
|
||
|
||
|
||
def get_matrix_tests() -> dict:
    """Load model matrix tests (real LLM calls, test×variant combos)."""
    sys.path.insert(0, os.path.dirname(__file__))
    import test_matrix
    return test_matrix.get_matrix_tests()
|
||
|
||
|
||
def get_testcase_tests() -> dict:
    """Load markdown testcases from testcases/ (integration tests, real LLM)."""
    sys.path.insert(0, os.path.dirname(__file__))
    import test_testcases
    return test_testcases.get_testcase_tests()
|
||
|
||
|
||
# Suite name -> loader returning {test_name: callable}. Loaders are invoked
# lazily in main() so unselected suites never import their dependencies.
SUITES = {
    'engine': get_engine_tests,
    'api': get_api_tests,
    'matrix': get_matrix_tests,
    'testcases': get_testcase_tests,
    'roundtrip': get_roundtrip_tests,
}
|
||
|
||
|
||
def _compute_stats(durations: list[float], passed: int, total: int) -> dict:
|
||
"""Compute timing stats from a list of durations."""
|
||
if not durations:
|
||
return {}
|
||
durations.sort()
|
||
n = len(durations)
|
||
return {
|
||
'runs': total,
|
||
'passed': passed,
|
||
'pass_rate': round(100 * passed / total) if total else 0,
|
||
'min_ms': round(durations[0]),
|
||
'avg_ms': round(sum(durations) / n),
|
||
'p50_ms': round(durations[n // 2]),
|
||
'p95_ms': round(durations[min(int(n * 0.95), n - 1)]),
|
||
'max_ms': round(durations[-1]),
|
||
}
|
||
|
||
|
||
def run_test_repeated(name: str, suite: str, fn, repeat: int) -> TestResult:
    """Run a test N times and aggregate timing stats into one result.

    Status is 'pass' when every run passed, 'fail' when only some did,
    'error' when none did. result.error carries the stats summary, plus
    the last exception message when any run failed.
    """
    # Post running status so the dashboard shows progress immediately.
    result = TestResult(run_id=RUN_ID, test=name, suite=suite, status='running', ts=_now_iso())
    post_result(result)

    durations = []
    passed_count = 0
    last_error = ''

    for _ in range(repeat):
        start = time.time()
        try:
            fn()
            passed_count += 1
        # FIX: the original caught (AssertionError, Exception) — redundant,
        # AssertionError is already an Exception subclass.
        except Exception as e:
            last_error = str(e)[:200]
        finally:
            # Failed runs still contribute their duration to the stats.
            durations.append(round((time.time() - start) * 1000))

    stats = _compute_stats(durations, passed_count, repeat)
    result.stats = stats
    result.duration_ms = stats.get('avg_ms', 0)
    result.status = 'pass' if passed_count == repeat else ('fail' if passed_count > 0 else 'error')
    result.error = f'{stats["pass_rate"]}% pass, avg={stats["avg_ms"]}ms p50={stats["p50_ms"]}ms p95={stats["p95_ms"]}ms'
    if last_error and passed_count < repeat:
        result.error += f' | last err: {last_error}'
    result.ts = _now_iso()
    post_result(result)
    return result
|
||
|
||
|
||
def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int, int]:
|
||
"""Parse CLI args into (suite_filter, test_filter, repeat, parallel).
|
||
|
||
Supports: --repeat=N, --parallel=N
|
||
|
||
Returns:
|
||
suite_filter: set of suite names, or None for all suites
|
||
test_filter: set of 'suite/test' names (empty = run all in suite)
|
||
repeat: number of times to run each test (default 1)
|
||
parallel: max concurrent tests (default 1 = sequential)
|
||
"""
|
||
repeat = 1
|
||
parallel = 1
|
||
filtered_args = []
|
||
skip_next = False
|
||
for i, arg in enumerate(args):
|
||
if skip_next:
|
||
skip_next = False
|
||
continue
|
||
if arg.startswith('--repeat='):
|
||
repeat = int(arg.split('=', 1)[1])
|
||
elif arg == '--repeat' and i + 1 < len(args):
|
||
repeat = int(args[i + 1])
|
||
skip_next = True
|
||
elif arg.startswith('--parallel='):
|
||
parallel = int(arg.split('=', 1)[1])
|
||
elif arg == '--parallel' and i + 1 < len(args):
|
||
parallel = int(args[i + 1])
|
||
skip_next = True
|
||
else:
|
||
filtered_args.append(arg)
|
||
|
||
if not filtered_args:
|
||
return None, set(), repeat, parallel
|
||
|
||
suites = set()
|
||
tests = set()
|
||
for arg in filtered_args:
|
||
if '/' in arg:
|
||
tests.add(arg)
|
||
suites.add(arg.split('/')[0])
|
||
else:
|
||
suites.add(arg)
|
||
return suites, tests, repeat, parallel
|
||
|
||
|
||
def _run_one(name: str, suite_name: str, fn, repeat: int) -> TestResult:
    """Dispatch one test to the plain or repeated runner. Thread-safe."""
    if repeat <= 1:
        return run_test(name, suite_name, fn)
    return run_test_repeated(name, suite_name, fn, repeat)
|
||
|
||
|
||
def _print_result(suite_name: str, name: str, r: TestResult, repeat: int):
    """Print a one-line console summary for a finished test (plus its error)."""
    if r.status == 'pass':
        label = 'PASS'
    else:
        label = 'FAIL'
    if repeat > 1:
        s = r.stats
        summary = (f' [{label}] {suite_name}/{name} ×{repeat} '
                   f'(avg={s.get("avg_ms", 0)}ms p50={s.get("p50_ms", 0)}ms '
                   f'p95={s.get("p95_ms", 0)}ms pass={s.get("pass_rate", 0)}%)')
        print(summary, flush=True)
    else:
        print(f' [{label}] {suite_name}/{name} ({r.duration_ms:.0f}ms)', flush=True)
    if r.error and repeat == 1:
        print(f' {r.error[:200]}', flush=True)
|
||
|
||
|
||
def run_suite(suite_name: str, tests: dict, test_filter: set[str],
              repeat: int = 1, parallel: int = 1) -> list[TestResult]:
    """Run tests from a suite, optionally filtered, repeated, and parallelized."""
    # Keep only tests matching the filter, by full name or by short name
    # (the test name with the 'suite_' prefix removed).
    selected = []
    for name, fn in tests.items():
        full_name = f'{suite_name}/{name}'
        short_name = name.replace(f'{suite_name}_', '')
        wanted = (not test_filter
                  or full_name in test_filter
                  or f'{suite_name}/{short_name}' in test_filter)
        if wanted:
            selected.append((name, fn))

    if not selected:
        return []

    results = []

    # Sequential path: run and report in declaration order.
    if parallel <= 1 or len(selected) <= 1:
        for name, fn in selected:
            outcome = _run_one(name, suite_name, fn, repeat)
            _print_result(suite_name, name, outcome, repeat)
            results.append(outcome)
        return results

    # Parallel path: fan out to a thread pool, report in completion order.
    with ThreadPoolExecutor(max_workers=parallel) as pool:
        pending = {pool.submit(_run_one, name, suite_name, fn, repeat): name
                   for name, fn in selected}
        for future in as_completed(pending):
            name = pending[future]
            try:
                outcome = future.result()
            except Exception as exc:
                # _run_one itself crashed (not a test failure) — synthesize an error result.
                outcome = TestResult(run_id=RUN_ID, test=name, suite=suite_name,
                                     status='error', error=f'ThreadError: {exc}', ts=_now_iso())
            _print_result(suite_name, name, outcome, repeat)
            results.append(outcome)

    return results
|
||
|
||
|
||
def main():
    """CLI entry point: run the selected suites and exit non-zero on failure."""
    suite_filter, test_filter, repeat, parallel = parse_args(sys.argv[1:])

    # Banner: run id, active filters/modifiers, and target environments.
    print(f'=== Test Run {RUN_ID} ===', flush=True)
    if suite_filter:
        print(f'Filter: suites={suite_filter}, tests={test_filter or "all"}', flush=True)
    if repeat > 1:
        print(f'Repeat: {repeat}x per test', flush=True)
    if parallel > 1:
        print(f'Parallel: {parallel} concurrent tests', flush=True)
    print(f'ASSAY_API: {os.environ.get("ASSAY_API", "not set")}', flush=True)
    print(f'NYX_URL: {os.environ.get("NYX_URL", "not set")}', flush=True)
    print(flush=True)

    all_results = []
    for suite_name, loader in SUITES.items():
        if suite_filter and suite_name not in suite_filter:
            continue
        # Suite header shows repeat/parallel modifiers when active.
        parts = [suite_name]
        if repeat > 1:
            parts.append(f'×{repeat}')
        if parallel > 1:
            parts.append(f'∥{parallel}')
        print(f'--- {" ".join(parts)} ---', flush=True)
        all_results.extend(run_suite(suite_name, loader(), test_filter, repeat, parallel))
        print(flush=True)

    # Summary (a bool sums as 0/1, so these are plain counts).
    passed = sum(r.status == 'pass' for r in all_results)
    failed = sum(r.status in ('fail', 'error') for r in all_results)
    total_ms = sum(r.duration_ms for r in all_results)
    print(f'=== {passed} passed, {failed} failed, {len(all_results)} total ({total_ms:.0f}ms) ===', flush=True)

    if RESULTS_ENDPOINT:
        post_result(TestResult(
            run_id=RUN_ID, test='__summary__', suite='summary',
            status='pass' if failed == 0 else 'fail',
            duration_ms=total_ms,
            error=f'{passed} passed, {failed} failed',
        ))

    sys.exit(1 if failed else 0)
|
||
|
||
|
||
# Script entry point — exit status 1 signals at least one failed test.
if __name__ == '__main__':
    main()
|