#!/usr/bin/env python3
"""
WebSocket long-connection concurrency stress test.

Uses asyncio + websockets to simulate many concurrent WebSocket clients and
reports connection success rate, message latency, abnormal-disconnect count,
and the memory footprint of THIS load-generator process. The server's own
resource usage is not visible from here: ``resource.getrusage(RUSAGE_SELF)``
only describes the local process.

Usage:
    python websocket_stress.py [--host HOST] [--port PORT] [--clients N] [--duration S]

Example:
    python websocket_stress.py --clients 1000 --duration 300
"""

import asyncio
import argparse
import json
import time
import random
import statistics
import sys
from datetime import datetime
from collections import defaultdict

try:
    import resource  # POSIX only; memory sampling is skipped where missing (e.g. Windows)
except ImportError:
    resource = None

# The websockets dependency is resolved lazily: a missing package is reported
# by main() rather than aborting at import time, so the statistics helpers in
# this module remain importable without it.
try:
    import websockets
    from websockets.asyncio.client import connect as ws_connect  # websockets >= 13
except ImportError:
    try:
        import websockets
        from websockets import connect as ws_connect  # legacy client API
    except ImportError:
        websockets = None
        ws_connect = None

# ==================== Configuration ====================
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8765
DEFAULT_CLIENTS = 100
DEFAULT_DURATION = 120  # seconds
DEFAULT_MESSAGE_INTERVAL = 2  # seconds
WS_PATH = "/ws"

# Step (ramp-up) test stages
STEP_CONFIGS = [
    {"clients": 100, "duration": 60},
    {"clients": 500, "duration": 60},
    {"clients": 1000, "duration": 60},
    {"clients": 5000, "duration": 60},
]


# ==================== Statistics ====================
def _percentile(ordered, pct):
    """Return the nearest-rank ``pct``-th percentile of a sorted, non-empty list.

    ``pct`` is an integer percentage (e.g. 95). Integer arithmetic computes
    ``ceil(pct * n / 100)`` exactly, avoiding float rounding in the rank.
    """
    rank = max(1, (pct * len(ordered) + 99) // 100)
    return ordered[rank - 1]


class Stats:
    """Counters and samples shared by all client coroutines of one run.

    All updates happen on the single event-loop thread and contain no
    ``await``, so coroutines cannot interleave inside an update and no lock is
    required. This object is NOT safe to share across OS threads. The
    ``record_*`` methods remain coroutines for API compatibility.
    """

    def __init__(self):
        self.connected = 0          # handshakes that succeeded
        self.connect_failures = 0   # handshakes that raised
        self.disconnected = 0       # sessions that ended abnormally after connecting
        self.messages_sent = 0
        self.messages_received = 0
        self.errors = 0             # every exception observed by a worker (both kinds above)
        self.latencies = []         # ms: time from a send to the next received frame
        self.connect_times = []     # ms: handshake duration

    async def record_connect(self, connect_time_ms):
        self.connected += 1
        self.connect_times.append(connect_time_ms)

    async def record_connect_failure(self):
        self.connect_failures += 1

    async def record_disconnect(self):
        self.disconnected += 1

    async def record_send(self):
        self.messages_sent += 1

    async def record_receive(self, latency_ms):
        self.messages_received += 1
        self.latencies.append(latency_ms)

    async def record_error(self):
        self.errors += 1

    def summary(self):
        """Return a JSON-serialisable dict summarising the run.

        ``failed_connections`` counts only handshake failures; sessions that
        connected and later broke are reported under ``disconnections``.
        Percentiles use the nearest-rank definition.
        """
        total_attempted = self.connected + self.connect_failures
        success_rate = (self.connected / total_attempted * 100) if total_attempted else 0

        latency_stats = {}
        if self.latencies:
            ordered = sorted(self.latencies)  # sort once, reuse for all order statistics
            latency_stats = {
                "min_ms": round(ordered[0], 2),
                "max_ms": round(ordered[-1], 2),
                "avg_ms": round(statistics.mean(ordered), 2),
                "median_ms": round(statistics.median(ordered), 2),
                "p95_ms": round(_percentile(ordered, 95), 2),
                "p99_ms": round(_percentile(ordered, 99), 2),
            }

        connect_time_stats = {}
        if self.connect_times:
            connect_time_stats = {
                "min_ms": round(min(self.connect_times), 2),
                "avg_ms": round(statistics.mean(self.connect_times), 2),
                "max_ms": round(max(self.connect_times), 2),
            }

        return {
            "total_connections": total_attempted,
            "successful_connections": self.connected,
            "failed_connections": self.connect_failures,
            "connection_success_rate": f"{success_rate:.1f}%",
            "disconnections": self.disconnected,
            "errors": self.errors,
            "messages_sent": self.messages_sent,
            "messages_received": self.messages_received,
            "latency": latency_stats,
            "connect_time": connect_time_stats,
        }


# ==================== Client simulation ====================
def generate_message():
    """Return a JSON-encoded mock sensor-report message with random readings."""
    return json.dumps({
        "type": "sensor_data",
        "sensorId": random.randint(1, 5000),
        "timestamp": datetime.now().isoformat(),
        "data": {
            "pressure": round(random.uniform(0.1, 0.8), 3),
            "flow": round(random.uniform(10, 500), 2),
            "temperature": round(random.uniform(5, 35), 1),
        },
    })


async def client_worker(client_id, uri, stats, duration, message_interval):
    """Run one simulated WebSocket client for roughly ``duration`` seconds.

    Connects to ``uri``, sends one subscribe frame, then repeatedly: sends a
    sensor frame, waits up to 1 s for any incoming frame (recorded as a
    latency sample), and sleeps ``message_interval`` + U(0, 1) seconds.

    Accounting in ``stats``:
      * the handshake raises      -> ``record_error`` + ``record_connect_failure``
      * anything raises afterwards -> ``record_error`` + ``record_disconnect``
      * normal completion, or cancellation by the harness -> neither

    ``client_id`` is currently unused; it is kept for interface compatibility.
    """
    start = time.monotonic()  # monotonic: immune to wall-clock jumps
    connect_start = time.perf_counter()
    connected = False
    try:
        async with ws_connect(uri) as ws:
            connected = True
            await stats.record_connect((time.perf_counter() - connect_start) * 1000)

            # Subscribe message
            await ws.send(json.dumps({
                "type": "subscribe",
                "channels": [f"sensor_{random.randint(1, 100)}"],
            }))

            while time.monotonic() - start < duration:
                msg = generate_message()
                send_time = time.perf_counter()
                await ws.send(msg)
                await stats.record_send()

                # NOTE: the next frame received is not necessarily the reply to
                # `msg` (it may be a subscribe ack or a broadcast), so this is an
                # approximate round-trip latency.
                try:
                    await asyncio.wait_for(ws.recv(), timeout=1.0)
                except asyncio.TimeoutError:
                    pass  # no frame within 1 s is normal for push-style servers
                else:
                    await stats.record_receive((time.perf_counter() - send_time) * 1000)

                await asyncio.sleep(message_interval + random.uniform(0, 1))
    except asyncio.CancelledError:
        # Cancellation by the harness is not a failure. Re-raise explicitly so
        # Python 3.7 (where CancelledError subclasses Exception) behaves like 3.8+.
        raise
    except Exception:
        await stats.record_error()
        if connected:
            await stats.record_disconnect()
        else:
            await stats.record_connect_failure()


# ==================== Resource monitoring ====================
async def monitor_resources(duration, interval=5, stop_event=None):
    """Sample this process's peak RSS every ``interval`` seconds.

    Stops after ``duration`` seconds, or earlier once ``stop_event`` (an
    optional ``asyncio.Event``) is set. Returns a dict with ``memory_mb`` and
    ``timestamps`` lists. ``cpu_percent`` is kept for output compatibility but
    is not collected. On platforms without the ``resource`` module, the lists
    stay empty.
    """
    resource_stats = {
        "cpu_percent": [],
        "memory_mb": [],
        "timestamps": [],
    }
    if resource is None:
        return resource_stats

    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024

    start = time.monotonic()
    while time.monotonic() - start < duration:
        if stop_event is not None and stop_event.is_set():
            break
        try:
            usage = resource.getrusage(resource.RUSAGE_SELF)
        except (OSError, ValueError):
            pass  # best-effort sampling: skip this sample
        else:
            resource_stats["memory_mb"].append(round(usage.ru_maxrss / divisor, 2))
            resource_stats["timestamps"].append(round(time.monotonic() - start, 1))

        if stop_event is None:
            await asyncio.sleep(interval)
        else:
            # Sleep, but wake immediately when asked to stop.
            try:
                await asyncio.wait_for(stop_event.wait(), timeout=interval)
            except asyncio.TimeoutError:
                pass
    return resource_stats


# ==================== Main test functions ====================
def _print_banner(uri, num_clients, duration, message_interval):
    """Print the run parameters before clients are launched."""
    print(f"\n{'=' * 60}")
    print("🔌 WebSocket 压力测试")
    print(f"{'=' * 60}")
    print(f"目标: {uri}")
    print(f"并发连接数: {num_clients}")
    print(f"测试时长: {duration}s")
    print(f"消息间隔: {message_interval}s")
    print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'=' * 60}\n")


def _print_report(num_clients, summary, resource_data):
    """Print the result table for one run from ``Stats.summary()`` output."""
    print(f"\n{'=' * 60}")
    print("📊 WebSocket 压力测试结果")
    print(f"{'=' * 60}")
    print(f"并发连接数: {num_clients}")
    print(f"成功连接: {summary['successful_connections']}")
    print(f"失败连接: {summary['failed_connections']}")
    print(f"连接成功率: {summary['connection_success_rate']}")
    print(f"断连次数: {summary['disconnections']}")
    print(f"消息发送: {summary['messages_sent']}")
    print(f"消息接收: {summary['messages_received']}")

    latency = summary["latency"]
    if latency:
        print("\n消息延迟:")
        print(f"  最小: {latency['min_ms']}ms")
        print(f"  最大: {latency['max_ms']}ms")
        print(f"  平均: {latency['avg_ms']}ms")
        print(f"  中位数: {latency['median_ms']}ms")
        print(f"  P95: {latency['p95_ms']}ms")
        print(f"  P99: {latency['p99_ms']}ms")

    connect_time = summary["connect_time"]
    if connect_time:
        print("\n连接建立时间:")
        print(f"  最小: {connect_time['min_ms']}ms")
        print(f"  平均: {connect_time['avg_ms']}ms")
        print(f"  最大: {connect_time['max_ms']}ms")

    memory = resource_data.get("memory_mb")
    if memory:
        print("\n资源消耗 (压测客户端进程):")
        print(f"  峰值内存: {max(memory):.2f} MB")
        # ru_maxrss is a running maximum, so this is the mean of sampled peaks.
        print(f"  平均内存: {statistics.mean(memory):.2f} MB")
    print(f"{'=' * 60}\n")


async def run_stress_test(host, port, num_clients, duration, message_interval):
    """Run one stress test and return the ``Stats.summary()`` dict.

    Launches ``num_clients`` workers in batches. It waits for them to finish,
    with 5 s of slack after ``duration``, then cancels stragglers and awaits
    their unwinding before reading the counters. Finally it stops the
    resource monitor.
    """
    uri = f"ws://{host}:{port}{WS_PATH}"
    stats = Stats()
    _print_banner(uri, num_clients, duration, message_interval)

    stop_monitor = asyncio.Event()
    resource_task = asyncio.create_task(
        monitor_resources(duration + 30, stop_event=stop_monitor)
    )

    # Launch clients in batches to avoid a thundering herd of handshakes.
    tasks = []
    batch_size = max(10, num_clients // 10)
    for i in range(num_clients):
        tasks.append(asyncio.create_task(
            client_worker(i, uri, stats, duration, message_interval)
        ))
        if (i + 1) % batch_size == 0:
            await asyncio.sleep(0.5)

    print(f"✅ 已启动 {num_clients} 个 WebSocket 客户端")

    if tasks:  # asyncio.wait() rejects an empty collection
        # Workers exit on their own after `duration`; allow 5 s for the final
        # send/recv/sleep and the close handshake, then cancel what is left.
        _, pending = await asyncio.wait(tasks, timeout=duration + 5)
        for task in pending:
            task.cancel()
        # Let cancelled workers unwind so the counters are final.
        await asyncio.gather(*tasks, return_exceptions=True)

    # Stop the monitor explicitly instead of racing it with a timeout, which
    # previously discarded all samples.
    stop_monitor.set()
    resource_data = await resource_task

    summary = stats.summary()
    _print_report(num_clients, summary, resource_data)
    return summary


async def run_step_test(host, port, message_interval):
    """Run every stage in ``STEP_CONFIGS`` in order and return per-stage results.

    Each result is the stage's summary dict plus a ``clients`` key. Stages are
    separated by a 10 s cooldown, which is skipped after the last stage.
    """
    print("\n" + "=" * 60)
    print("📈 WebSocket 梯度负载测试")
    print("=" * 60)

    results = []
    last_index = len(STEP_CONFIGS) - 1
    for index, config in enumerate(STEP_CONFIGS):
        result = await run_stress_test(
            host, port, config["clients"], config["duration"], message_interval
        )
        results.append({
            "clients": config["clients"],
            **result,
        })
        if index < last_index:
            await asyncio.sleep(10)  # cooldown between stages
    return results


# ==================== Entry point ====================
def main():
    """Parse CLI arguments, run the selected test, optionally save JSON results."""
    parser = argparse.ArgumentParser(description="WebSocket 压力测试")
    parser.add_argument("--host", default=DEFAULT_HOST,
                        help=f"目标主机 (默认: {DEFAULT_HOST})")
    parser.add_argument("--port", type=int, default=DEFAULT_PORT,
                        help=f"目标端口 (默认: {DEFAULT_PORT})")
    parser.add_argument("--clients", type=int, default=DEFAULT_CLIENTS,
                        help=f"并发连接数 (默认: {DEFAULT_CLIENTS})")
    parser.add_argument("--duration", type=int, default=DEFAULT_DURATION,
                        help=f"测试时长秒 (默认: {DEFAULT_DURATION})")
    parser.add_argument("--interval", type=float, default=DEFAULT_MESSAGE_INTERVAL,
                        help=f"消息间隔秒 (默认: {DEFAULT_MESSAGE_INTERVAL})")
    parser.add_argument("--step", action="store_true",
                        help="运行梯度负载测试 (100/500/1000/5000)")
    parser.add_argument("--output", default=None,
                        help="结果输出 JSON 文件路径")
    args = parser.parse_args()

    if args.clients < 1:
        parser.error("--clients 必须 >= 1")
    if args.duration <= 0:
        parser.error("--duration 必须 > 0")
    if args.interval < 0:
        parser.error("--interval 必须 >= 0")

    # Checked here (not at import) so that --help works without websockets.
    if ws_connect is None:
        print("请先安装 websockets: pip install websockets")
        sys.exit(1)

    if args.step:
        results = asyncio.run(run_step_test(args.host, args.port, args.interval))
    else:
        results = asyncio.run(
            run_stress_test(args.host, args.port, args.clients, args.duration, args.interval)
        )

    # Save results
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"结果已保存到: {args.output}")

    return results


if __name__ == "__main__":
    main()