From a699b53e413c230f26c4a696735297c6af480e19 Mon Sep 17 00:00:00 2001 From: Stardream Date: Sat, 21 Mar 2026 10:57:25 +1000 Subject: [PATCH] feat(stress_test): replace hardcoded config with interactive prompts All parameters (URL, concurrency, total requests, timeout, rate limit, log file, stop keywords) are now collected at startup via stdin. Target URL is required with no default; other params retain sensible defaults. Remove hardcoded target URL from source. Co-Authored-By: Claude Sonnet 4.6 --- py/stress_test.py | 270 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 py/stress_test.py diff --git a/py/stress_test.py b/py/stress_test.py new file mode 100644 index 0000000..03de228 --- /dev/null +++ b/py/stress_test.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +""" +API 压力测试脚本 +方法: GET (与浏览器一致) +""" + +import asyncio +import aiohttp +import time +import json +import sys +from datetime import datetime +from collections import Counter + +HEADERS = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36 Edg/145.0.0.0", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", + "Accept-Language": "en,zh-CN;q=0.9,zh;q=0.8", + "Cache-Control": "no-cache", + "Pragma": "no-cache", + "Sec-Fetch-Dest": "document", + "Sec-Fetch-Mode": "navigate", + "Sec-Fetch-Site": "none", +} + + +def prompt_config(): + print("=" * 60) + print(" API 压力测试 — 参数配置(直接回车使用默认值)") + print("=" * 60) + + def ask(prompt, default, cast=str): + raw = input(f" {prompt} [{default}]: ").strip() + return cast(raw) if raw else default + + while True: + target_url = input(" 目标 URL: ").strip() + if target_url: + break + print(" 目标 URL 不能为空,请重新输入。") + concurrency = ask("并发数", 10, int) + total_reqs = ask("总请求数 (0=无限/Ctrl+C 停止)", 0, int) + timeout_sec = ask("超时秒数", 5, int) + rate_per_sec = ask("速率限制 (次/秒)", 1, int) + log_file = ask("日志文件名", "stress_test_log.jsonl") + + kw_raw = input(" 停止关键词 (逗号分隔) [恭喜,机器人]: ").strip() + stop_keywords = [k.strip() for k in kw_raw.split(",") if k.strip()] if kw_raw else ["恭喜", "机器人"] + + print() + return target_url, concurrency, total_reqs, timeout_sec, rate_per_sec, log_file, stop_keywords + +stats = { + "sent": 0, + "success": 0, + "failed": 0, + "total_ms": 0.0, + "status_counter": Counter(), + "responses": [], +} + +lock = asyncio.Lock() +stop_event = asyncio.Event() # 触发后所有 worker 停止 +triggered_info = {} # 记录触发停止的那条响应 + + +class RateLimiter: + def __init__(self, rate: int): + self.rate = rate + self.tokens = float(rate) + self.last = time.perf_counter() + self._lock = asyncio.Lock() + + async def acquire(self): + async with self._lock: + now = time.perf_counter() + self.tokens += (now - self.last) * self.rate + self.last = now + if self.tokens > self.rate: + self.tokens = float(self.rate) + if self.tokens < 1.0: + wait = (1.0 - self.tokens) / self.rate + await asyncio.sleep(wait) + self.tokens = 0.0 + else: + self.tokens -= 1.0 + + +async def single_request(session: aiohttp.ClientSession, req_id: int, log_fh, + target_url: str, stop_keywords: list): + t0 = time.perf_counter() + result = { + "id": req_id, + "timestamp": datetime.utcnow().isoformat(), + "status": None, + "elapsed_ms": None, + "body": None, + "error": None, + } + try: + async with session.get(target_url, headers=HEADERS) as resp: + body = await resp.text(errors="replace") + elapsed = (time.perf_counter() - t0) * 1000 + result.update(status=resp.status, elapsed_ms=round(elapsed, 2), body=body) + + async with lock: + stats["sent"] += 1 + stats["success"] += 1 + stats["total_ms"] += elapsed + stats["status_counter"][resp.status] += 1 + if body not in stats["responses"]: + stats["responses"].append(body) + print(f"\n ★ 新响应体 [{len(stats['responses'])}]: {body[:200]}") + + # 关键词检测 + for kw in stop_keywords: + if kw in body: + print(f"\n 🎯 检测到关键词「{kw}」,正在停止...") + print(f" 完整响应: {body}") + triggered_info["keyword"] = kw + triggered_info["body"] = body + triggered_info["id"] = req_id + triggered_info["timestamp"] = result["timestamp"] + stop_event.set() + break + + except Exception as e: + elapsed = (time.perf_counter() - t0) * 1000 + result.update(elapsed_ms=round(elapsed, 2), error=str(e)) + async with lock: + stats["sent"] += 1 + stats["failed"] += 1 + + log_fh.write(json.dumps(result, ensure_ascii=False) + "\n") + log_fh.flush() + return result + + +async def worker(queue: asyncio.Queue, session: aiohttp.ClientSession, + log_fh, limiter: RateLimiter, target_url: str, stop_keywords: list): + while not stop_event.is_set(): + try: + req_id = queue.get_nowait() + except asyncio.QueueEmpty: + await asyncio.sleep(0.05) + continue + if req_id is None: + queue.task_done() + break + await limiter.acquire() + if stop_event.is_set(): + queue.task_done() + break + await single_request(session, req_id, log_fh, target_url, stop_keywords) + queue.task_done() + + +def print_progress(): + s = stats + if s["sent"] == 0: + return + avg = s["total_ms"] / max(s["success"], 1) + status_str = " | ".join(f"HTTP {k}: {v}" for k, v in sorted(s["status_counter"].items())) + print( + f"\r[{datetime.now().strftime('%H:%M:%S')}] " + f"发送:{s['sent']} 成功:{s['success']} 失败:{s['failed']} " + f"均耗时:{avg:.0f}ms {status_str} ", + end="", flush=True, + ) + + +async def progress_loop(): + while not stop_event.is_set(): + await asyncio.sleep(1) + print_progress() + + +async def run(target_url, concurrency, total_reqs, timeout_sec, rate_per_sec, log_file, stop_keywords): + connector = aiohttp.TCPConnector(limit=concurrency, ssl=False) + timeout = aiohttp.ClientTimeout(total=timeout_sec) + limiter = RateLimiter(rate_per_sec) + + print("=" * 60) + print(f" 目标 URL : {target_url}") + print(f" 请求方法 : GET") + print(f" 并发数 : {concurrency}") + print(f" 速率限制 : {rate_per_sec} 次/秒") + print(f" 总请求数 : {'∞ (Ctrl+C 停止)' if total_reqs == 0 else total_reqs}") + print(f" 停止关键词 : {stop_keywords}") + print(f" 超时 : {timeout_sec}s") + print(f" 日志文件 : {log_file}") + print("=" * 60 + "\n") + + queue = asyncio.Queue(maxsize=concurrency * 2) + prog = None + + with open(log_file, "w", encoding="utf-8") as log_fh: + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + workers = [ + asyncio.create_task(worker(queue, session, log_fh, limiter, target_url, stop_keywords)) + for _ in range(concurrency) + ] + prog = asyncio.create_task(progress_loop()) + + # 持续往队列塞任务,直到 stop_event 或 Ctrl+C + try: + req_id = 0 + infinite = (total_reqs == 0) + limit = total_reqs if not infinite else float("inf") + + while req_id < limit and not stop_event.is_set(): + try: + queue.put_nowait(req_id) + req_id += 1 + except asyncio.QueueFull: + await asyncio.sleep(0.05) + + if not infinite and not stop_event.is_set(): + for _ in range(concurrency): + await queue.put(None) + await queue.join() + else: + # 等待 stop_event(关键词触发或 Ctrl+C) + await stop_event.wait() + + except (KeyboardInterrupt, asyncio.CancelledError): + stop_event.set() + finally: + if prog: + prog.cancel() + # 清空队列 + while not queue.empty(): + try: + queue.get_nowait() + queue.task_done() + except Exception: + break + await asyncio.gather(*workers, return_exceptions=True) + + print_progress() + s = stats + print("\n\n" + "=" * 60) + print(" 压测结束 — 汇总") + print("=" * 60) + print(f" 总发送 : {s['sent']}") + print(f" 成功 : {s['success']}") + print(f" 失败 : {s['failed']}") + if s["success"]: + print(f" 平均耗时 : {s['total_ms']/s['success']:.1f} ms") + print(f" 状态码 : {dict(s['status_counter'])}") + print(f"\n ── 全部不同响应体(共 {len(s['responses'])} 种)──") + for i, body in enumerate(s["responses"], 1): + print(f" [{i}] {body[:400]}") + print(f"\n 详细日志 → {log_file}") + print("=" * 60) + + if triggered_info: + print(f"\n {'='*60}") + print(f" 🎯 触发停止的响应(第 {triggered_info['id']} 条 / {triggered_info['timestamp']})") + print(f" 关键词 : 「{triggered_info['keyword']}」") + print(f" 内容 : {triggered_info['body']}") + print(f" {'='*60}") + + +if __name__ == "__main__": + try: + cfg = prompt_config() + asyncio.run(run(*cfg)) + except KeyboardInterrupt: + sys.exit(0)