Add --parallel=N for concurrent test execution

- run_tests.py: ThreadPoolExecutor runs N tests concurrently within a suite
- Each testcase has its own session_id so parallel is safe
- Engine tests: fixed asyncio.new_event_loop() for thread safety
- Usage: python tests/run_tests.py testcases --parallel=3
- Wall time reduction: ~3x for testcases (15min → 5min with parallel=3)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Nico 2026-04-03 20:01:06 +02:00
parent c21ff08211
commit d8e832d2d4
2 changed files with 84 additions and 32 deletions

View File

@ -7,6 +7,7 @@ Usage:
python tests/run_tests.py api # one suite
python tests/run_tests.py matrix/eras_query[haiku] # single test
python tests/run_tests.py matrix --repeat=3 # each test 3x, report avg/p50/p95
python tests/run_tests.py testcases --parallel=3 # 3 testcases concurrently
python tests/run_tests.py api/health roundtrip/full_chat # multiple tests
Test names: suite/name (without the suite prefix in the test registry).
@ -26,6 +27,7 @@ import sys
import time
import urllib.request
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone
from dataclasses import dataclass, field, asdict
@ -186,17 +188,19 @@ def run_test_repeated(name: str, suite: str, fn, repeat: int) -> TestResult:
return result
def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int]:
"""Parse CLI args into (suite_filter, test_filter, repeat).
def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int, int]:
"""Parse CLI args into (suite_filter, test_filter, repeat, parallel).
Supports: --repeat=N or --repeat N
Supports: --repeat=N, --parallel=N
Returns:
suite_filter: set of suite names, or None for all suites
test_filter: set of 'suite/test' names (empty = run all in suite)
repeat: number of times to run each test (default 1)
parallel: max concurrent tests (default 1 = sequential)
"""
repeat = 1
parallel = 1
filtered_args = []
skip_next = False
for i, arg in enumerate(args):
@ -208,11 +212,16 @@ def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int]:
elif arg == '--repeat' and i + 1 < len(args):
repeat = int(args[i + 1])
skip_next = True
elif arg.startswith('--parallel='):
parallel = int(arg.split('=', 1)[1])
elif arg == '--parallel' and i + 1 < len(args):
parallel = int(args[i + 1])
skip_next = True
else:
filtered_args.append(arg)
if not filtered_args:
return None, set(), repeat
return None, set(), repeat, parallel
suites = set()
tests = set()
@ -222,47 +231,85 @@ def parse_args(args: list[str]) -> tuple[set[str] | None, set[str], int]:
suites.add(arg.split('/')[0])
else:
suites.add(arg)
return suites, tests, repeat
return suites, tests, repeat, parallel
def _run_one(name: str, suite_name: str, fn, repeat: int) -> TestResult:
"""Run a single test (with optional repeat). Thread-safe."""
if repeat > 1:
return run_test_repeated(name, suite_name, fn, repeat)
return run_test(name, suite_name, fn)
def _print_result(suite_name: str, name: str, r: TestResult, repeat: int):
"""Print a test result line."""
status = 'PASS' if r.status == 'pass' else 'FAIL'
if repeat > 1:
stats = r.stats
print(f' [{status}] {suite_name}/{name} ×{repeat} '
f'(avg={stats.get("avg_ms", 0)}ms p50={stats.get("p50_ms", 0)}ms '
f'p95={stats.get("p95_ms", 0)}ms pass={stats.get("pass_rate", 0)}%)', flush=True)
else:
print(f' [{status}] {suite_name}/{name} ({r.duration_ms:.0f}ms)', flush=True)
if r.error and repeat == 1:
print(f' {r.error[:200]}', flush=True)
def run_suite(suite_name: str, tests: dict, test_filter: set[str],
repeat: int = 1) -> list[TestResult]:
"""Run tests from a suite, optionally filtered and repeated."""
results = []
repeat: int = 1, parallel: int = 1) -> list[TestResult]:
"""Run tests from a suite, optionally filtered, repeated, and parallelized."""
# Build filtered test list
filtered = []
for name, fn in tests.items():
# Apply test filter if specified
full_name = f'{suite_name}/{name}'
# Strip suite prefix for matching (roundtrip/full_eras matches roundtrip_full_eras)
short_name = name.replace(f'{suite_name}_', '')
if test_filter and full_name not in test_filter and f'{suite_name}/{short_name}' not in test_filter:
continue
filtered.append((name, fn))
if repeat > 1:
r = run_test_repeated(name, suite_name, fn, repeat)
status = 'PASS' if r.status == 'pass' else 'FAIL'
stats = r.stats
print(f' [{status}] {suite_name}/{name} ×{repeat} '
f'(avg={stats.get("avg_ms", 0)}ms p50={stats.get("p50_ms", 0)}ms '
f'p95={stats.get("p95_ms", 0)}ms pass={stats.get("pass_rate", 0)}%)', flush=True)
else:
r = run_test(name, suite_name, fn)
status = 'PASS' if r.status == 'pass' else 'FAIL'
print(f' [{status}] {suite_name}/{name} ({r.duration_ms:.0f}ms)', flush=True)
if not filtered:
return []
# Sequential execution
if parallel <= 1 or len(filtered) <= 1:
results = []
for name, fn in filtered:
r = _run_one(name, suite_name, fn, repeat)
_print_result(suite_name, name, r, repeat)
results.append(r)
return results
# Parallel execution
results = []
with ThreadPoolExecutor(max_workers=parallel) as pool:
futures = {}
for name, fn in filtered:
f = pool.submit(_run_one, name, suite_name, fn, repeat)
futures[f] = name
for future in as_completed(futures):
name = futures[future]
try:
r = future.result()
except Exception as e:
r = TestResult(run_id=RUN_ID, test=name, suite=suite_name,
status='error', error=f'ThreadError: {e}', ts=_now_iso())
_print_result(suite_name, name, r, repeat)
results.append(r)
results.append(r)
if r.error and repeat == 1:
print(f' {r.error[:200]}', flush=True)
return results
def main():
suite_filter, test_filter, repeat = parse_args(sys.argv[1:])
suite_filter, test_filter, repeat, parallel = parse_args(sys.argv[1:])
print(f'=== Test Run {RUN_ID} ===', flush=True)
if suite_filter:
print(f'Filter: suites={suite_filter}, tests={test_filter or "all"}', flush=True)
if repeat > 1:
print(f'Repeat: {repeat}x per test', flush=True)
if parallel > 1:
print(f'Parallel: {parallel} concurrent tests', flush=True)
print(f'ASSAY_API: {os.environ.get("ASSAY_API", "not set")}', flush=True)
print(f'NYX_URL: {os.environ.get("NYX_URL", "not set")}', flush=True)
print(flush=True)
@ -272,9 +319,14 @@ def main():
for suite_name, loader in SUITES.items():
if suite_filter and suite_name not in suite_filter:
continue
print(f'--- {suite_name}{" ×" + str(repeat) if repeat > 1 else ""} ---', flush=True)
label = suite_name
if repeat > 1:
label += f' ×{repeat}'
if parallel > 1:
label += f'{parallel}'
print(f'--- {label} ---', flush=True)
tests = loader()
all_results.extend(run_suite(suite_name, tests, test_filter, repeat))
all_results.extend(run_suite(suite_name, tests, test_filter, repeat, parallel))
print(flush=True)
# Summary

View File

@ -407,7 +407,7 @@ def test_frame_trace_reflex():
}
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
result = asyncio.get_event_loop().run_until_complete(
result = asyncio.new_event_loop().run_until_complete(
engine.process_message("hello")
)
@ -434,7 +434,7 @@ def test_frame_trace_expert():
}
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
result = asyncio.get_event_loop().run_until_complete(
result = asyncio.new_event_loop().run_until_complete(
engine.process_message("show top customers")
)
@ -465,7 +465,7 @@ def test_frame_trace_expert_with_interpreter():
}
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
result = asyncio.get_event_loop().run_until_complete(
result = asyncio.new_event_loop().run_until_complete(
engine.process_message("show customer revenue")
)
@ -519,7 +519,7 @@ def test_model_override_per_request():
engine, sink, hud = make_frame_engine(nodes, "v4-eras")
# process_message should accept model_overrides param
result = asyncio.get_event_loop().run_until_complete(
result = asyncio.new_event_loop().run_until_complete(
engine.process_message("hello", model_overrides={"input": "test/fast-model"})
)
# Should complete without error (overrides applied internally)
@ -588,7 +588,7 @@ def test_contextvar_hud_isolation():
async def run_both():
await asyncio.gather(task_a(), task_b())
asyncio.get_event_loop().run_until_complete(run_both())
asyncio.new_event_loop().run_until_complete(run_both())
assert len(results_a) == 1 and results_a[0]["from"] == "a", \
f"task_a HUD leaked: {results_a}"