| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
#!/usr/bin/env python3
"""
WebSocket 长连接并发压力测试

使用 asyncio + websockets 模拟大量并发 WebSocket 连接,
测试连接成功率、消息延迟、断连率,并记录压测客户端进程自身的资源消耗
(内存峰值;注意不是被测服务器的资源)。

运行方式:
    python websocket_stress.py [--host HOST] [--port PORT] [--clients N]
                               [--duration S] [--interval S] [--step] [--output FILE]

示例:
    python websocket_stress.py --clients 1000 --duration 300
    python websocket_stress.py --step --output results.json
"""
-
- import asyncio
- import argparse
- import json
- import time
- import random
- import statistics
- import resource
- import sys
- from datetime import datetime
- from collections import defaultdict
-
- try:
- import websockets
- from websockets.asyncio.client import connect as ws_connect
- except ImportError:
- try:
- import websockets
- from websockets import connect as ws_connect
- except ImportError:
- print("请先安装 websockets: pip install websockets")
- sys.exit(1)
-
-
# ==================== Configuration ====================

DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8765
DEFAULT_CLIENTS = 100
DEFAULT_DURATION = 120  # seconds each client keeps its connection open
DEFAULT_MESSAGE_INTERVAL = 2  # seconds; base delay between messages per client
WS_PATH = "/ws"

# Step (ramp-up) load test: every stage lasts 60 s; only the client count grows.
STEP_CONFIGS = [
    {"clients": stage_clients, "duration": 60}
    for stage_clients in (100, 500, 1000, 5000)
]
-
-
# ==================== Statistics ====================

class Stats:
    """Collect connection and message metrics shared by all client coroutines.

    Every mutation is serialised with an ``asyncio.Lock``. That makes the
    collector safe for concurrent coroutines on one event loop; it is NOT
    thread-safe.
    """

    def __init__(self):
        self.connected = 0          # successful handshakes
        self.disconnected = 0       # number of record_disconnect() calls
        self.messages_sent = 0
        self.messages_received = 0
        self.errors = 0             # number of record_error() calls; reported as failed connections
        self.latencies = []         # per-message latency samples, ms
        self.connect_times = []     # per-connection handshake durations, ms
        self._lock = asyncio.Lock()

    async def record_connect(self, connect_time_ms):
        """Count one successful connection and store its handshake time (ms)."""
        async with self._lock:
            self.connected += 1
            self.connect_times.append(connect_time_ms)

    async def record_disconnect(self):
        """Count one disconnection."""
        async with self._lock:
            self.disconnected += 1

    async def record_send(self):
        """Count one sent message."""
        async with self._lock:
            self.messages_sent += 1

    async def record_receive(self, latency_ms):
        """Count one received message and store its latency (ms)."""
        async with self._lock:
            self.messages_received += 1
            self.latencies.append(latency_ms)

    async def record_error(self):
        """Count one error (reported as a failed connection in summary())."""
        async with self._lock:
            self.errors += 1

    @staticmethod
    def _percentile(sorted_values, pct):
        """Return the nearest-rank ``pct``-th percentile of a non-empty ascending list.

        ``pct`` is an integer in 1..100.
        """
        # Nearest rank = ceil(pct/100 * n) (1-based). Integer ceiling division
        # avoids float rounding; max() guards the rank against being 0.
        rank = -(-pct * len(sorted_values) // 100)
        return sorted_values[max(rank, 1) - 1]

    def summary(self):
        """Return a JSON-serialisable dict summarising everything recorded so far.

        ``latency`` and ``connect_time`` are empty dicts when no samples exist.
        Percentiles use the nearest-rank method.
        """
        total_attempted = self.connected + self.errors
        success_rate = (self.connected / total_attempted * 100) if total_attempted > 0 else 0

        latency_stats = {}
        if self.latencies:
            ordered = sorted(self.latencies)  # sort once, reuse for every order statistic
            latency_stats = {
                "min_ms": round(ordered[0], 2),
                "max_ms": round(ordered[-1], 2),
                "avg_ms": round(statistics.mean(ordered), 2),
                "median_ms": round(statistics.median(ordered), 2),
                "p95_ms": round(self._percentile(ordered, 95), 2),
                "p99_ms": round(self._percentile(ordered, 99), 2),
            }

        connect_time_stats = {}
        if self.connect_times:
            connect_time_stats = {
                "min_ms": round(min(self.connect_times), 2),
                "avg_ms": round(statistics.mean(self.connect_times), 2),
                "max_ms": round(max(self.connect_times), 2),
            }

        return {
            "total_connections": total_attempted,
            "successful_connections": self.connected,
            "failed_connections": self.errors,
            "connection_success_rate": f"{success_rate:.1f}%",
            "disconnections": self.disconnected,
            "messages_sent": self.messages_sent,
            "messages_received": self.messages_received,
            "latency": latency_stats,
            "connect_time": connect_time_stats,
        }
-
-
# ==================== Client simulation ====================

def generate_message():
    """Build one simulated sensor report and serialise it to a JSON string.

    The payload carries a random sensor id (1-5000), the current local time in
    ISO-8601 form and random pressure / flow / temperature readings.
    """
    sensor_id = random.randint(1, 5000)
    stamp = datetime.now().isoformat()
    readings = {
        "pressure": round(random.uniform(0.1, 0.8), 3),
        "flow": round(random.uniform(10, 500), 2),
        "temperature": round(random.uniform(5, 35), 1),
    }
    payload = {
        "type": "sensor_data",
        "sensorId": sensor_id,
        "timestamp": stamp,
        "data": readings,
    }
    return json.dumps(payload)
-
-
async def client_worker(client_id, uri, stats, duration, message_interval):
    """Simulate one long-lived WebSocket client and record its metrics in *stats*.

    The client connects to *uri*, sends one ``subscribe`` message, then keeps
    sending generated sensor messages until *duration* seconds (measured from
    before the handshake) have elapsed. After each send it waits up to 1 s for
    any incoming frame and records the elapsed time as that message's latency.

    Outcome accounting:
      * handshake fails              -> ``stats.record_error()`` (failed connection)
      * connection lost mid-session  -> ``stats.record_disconnect()``
      * ran for the full duration    -> neither counter is touched
      * task cancelled by the caller -> nothing recorded; CancelledError propagates

    Args:
        client_id: Index of this client; currently unused.
        uri: WebSocket URL, e.g. ``ws://host:port/ws``.
        stats: Shared ``Stats`` collector.
        duration: Session length in seconds.
        message_interval: Base delay between messages; 0-1 s of random jitter is added.
    """
    # Monotonic clocks, so wall-clock adjustments cannot distort durations.
    session_start = time.monotonic()
    connected = False
    completed = False

    try:
        handshake_start = time.perf_counter()
        async with ws_connect(uri) as ws:
            connected = True
            await stats.record_connect((time.perf_counter() - handshake_start) * 1000)

            # Subscribe to one random channel.
            subscribe_msg = json.dumps({
                "type": "subscribe",
                "channels": [f"sensor_{random.randint(1, 100)}"],
            })
            await ws.send(subscribe_msg)

            while time.monotonic() - session_start < duration:
                msg = generate_message()
                send_time = time.perf_counter()
                await ws.send(msg)
                await stats.record_send()

                # NOTE(review): any incoming frame (including broadcasts not
                # caused by this message) is treated as the reply, so latency
                # is only meaningful against echo-style servers.
                try:
                    await asyncio.wait_for(ws.recv(), timeout=1.0)
                except asyncio.TimeoutError:
                    pass  # no reply within 1 s is acceptable
                else:
                    await stats.record_receive((time.perf_counter() - send_time) * 1000)

                await asyncio.sleep(message_interval + random.uniform(0, 1))
            completed = True
    except asyncio.CancelledError:
        # Cancelled by the test harness: neither a failure nor a disconnect.
        # Explicit because CancelledError is an Exception subclass before 3.8.
        raise
    except Exception:
        if not connected:
            await stats.record_error()        # the handshake never succeeded
        elif not completed:
            await stats.record_disconnect()   # dropped before the session ended
-
-
# ==================== Resource monitoring ====================

async def monitor_resources(duration, interval=5):
    """Periodically sample the resource usage of THIS (load-generating) process.

    NOTE: this uses ``resource.getrusage(RUSAGE_SELF)``, so it measures the
    stress-test client itself, not the server under test.

    Args:
        duration: Stop sampling once this many seconds have elapsed.
        interval: Seconds to sleep between samples.

    Returns:
        dict of equally long lists:
          ``cpu_percent`` - CPU time used since the previous sample divided by
                            the wall time elapsed, as a percentage (0.0 for the
                            first sample, which has no previous one)
          ``memory_mb``   - peak resident set size so far (ru_maxrss), in MiB
          ``timestamps``  - seconds since monitoring started
    """
    resource_stats = {
        "cpu_percent": [],
        "memory_mb": [],
        "timestamps": [],
    }
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux/BSD.
    maxrss_per_mb = 1024 * 1024 if sys.platform == "darwin" else 1024

    start = prev_wall = time.monotonic()
    prev_cpu = None  # cumulative CPU seconds at the previous successful sample
    while time.monotonic() - start < duration:
        try:
            usage = resource.getrusage(resource.RUSAGE_SELF)
        except (OSError, ValueError):
            usage = None  # best effort: skip this sample
        if usage is not None:
            now = time.monotonic()
            cpu_seconds = usage.ru_utime + usage.ru_stime
            wall_delta = now - prev_wall
            if prev_cpu is None or wall_delta <= 0:
                cpu_pct = 0.0
            else:
                cpu_pct = (cpu_seconds - prev_cpu) / wall_delta * 100
            resource_stats["cpu_percent"].append(round(cpu_pct, 1))
            resource_stats["memory_mb"].append(round(usage.ru_maxrss / maxrss_per_mb, 2))
            resource_stats["timestamps"].append(round(now - start, 1))
            prev_cpu, prev_wall = cpu_seconds, now
        await asyncio.sleep(interval)

    return resource_stats
-
-
# ==================== Main test ====================

def _print_banner(uri, num_clients, duration, message_interval):
    """Print the parameters of a stress-test round before it starts."""
    print(f"\n{'=' * 60}")
    print("🔌 WebSocket 压力测试")
    print(f"{'=' * 60}")
    print(f"目标: {uri}")
    print(f"并发连接数: {num_clients}")
    print(f"测试时长: {duration}s")
    print(f"消息间隔: {message_interval}s")
    print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'=' * 60}\n")


def _print_report(num_clients, summary, resource_data):
    """Print a finished round's results (*summary* comes from ``Stats.summary()``)."""
    print(f"\n{'=' * 60}")
    print("📊 WebSocket 压力测试结果")
    print(f"{'=' * 60}")
    print(f"并发连接数: {num_clients}")
    print(f"成功连接: {summary['successful_connections']}")
    print(f"失败连接: {summary['failed_connections']}")
    print(f"连接成功率: {summary['connection_success_rate']}")
    print(f"断连次数: {summary['disconnections']}")
    print(f"消息发送: {summary['messages_sent']}")
    print(f"消息接收: {summary['messages_received']}")

    latency = summary["latency"]
    if latency:
        print("\n消息延迟:")
        for label, key in (("最小", "min_ms"), ("最大", "max_ms"), ("平均", "avg_ms"),
                           ("中位数", "median_ms"), ("P95", "p95_ms"), ("P99", "p99_ms")):
            print(f" {label}: {latency[key]}ms")

    connect_time = summary["connect_time"]
    if connect_time:
        print("\n连接建立时间:")
        for label, key in (("最小", "min_ms"), ("平均", "avg_ms"), ("最大", "max_ms")):
            print(f" {label}: {connect_time[key]}ms")

    memory = resource_data.get("memory_mb")
    if memory:
        # The monitor samples the load-generating process, not the server.
        print("\n资源消耗 (压测客户端进程):")
        print(f" 峰值内存: {max(memory):.2f} MB")
        print(f" 平均内存: {statistics.mean(memory):.2f} MB")
        cpu = resource_data.get("cpu_percent")
        if cpu:
            print(f" 平均 CPU: {statistics.mean(cpu):.1f}%")

    print(f"{'=' * 60}\n")


async def run_stress_test(host, port, num_clients, duration, message_interval):
    """Run one stress-test round, print its report and return the summary.

    Clients are launched in batches of ``max(10, num_clients // 10)`` with a
    0.5 s pause after each full batch, to avoid one connection burst. The
    round waits up to ``duration + 5`` s after the last launch for all clients
    to finish, then cancels stragglers. It also waits for the resource monitor,
    whose budget spans the whole round, so the round takes at least that long.

    Returns:
        The dict produced by ``Stats.summary()``.
    """
    uri = f"ws://{host}:{port}{WS_PATH}"
    stats = Stats()
    _print_banner(uri, num_clients, duration, message_interval)

    batch_size = max(10, num_clients // 10)
    launch_pause = 0.5
    client_grace = 5  # extra seconds for clients to finish their last message cycle

    # Size the monitor to the whole round (launch + run + grace) so it ends on
    # its own and its samples can be collected instead of being cancelled away.
    monitor_budget = (num_clients // batch_size) * launch_pause + duration + client_grace
    resource_task = asyncio.create_task(monitor_resources(monitor_budget))
    monitor_started = time.monotonic()

    tasks = []
    for i in range(num_clients):
        tasks.append(asyncio.create_task(
            client_worker(i, uri, stats, duration, message_interval)
        ))
        # Launch in batches.
        if (i + 1) % batch_size == 0:
            await asyncio.sleep(launch_pause)

    print(f"✅ 已启动 {num_clients} 个 WebSocket 客户端")

    if tasks:  # asyncio.wait() rejects an empty collection
        # Returns early once every client has finished on its own.
        _, pending = await asyncio.wait(tasks, timeout=duration + client_grace)
        for task in pending:
            task.cancel()
        # Let cancelled clients unwind before the statistics are read.
        await asyncio.gather(*tasks, return_exceptions=True)

    # The monitor stops by itself at most one sampling interval (5 s by default)
    # after its budget; allow 10 s of slack on top of the remaining budget.
    remaining = monitor_budget - (time.monotonic() - monitor_started)
    try:
        resource_data = await asyncio.wait_for(resource_task, timeout=max(remaining, 0) + 10)
    except asyncio.TimeoutError:
        resource_data = {"memory_mb": [], "timestamps": []}

    summary = stats.summary()
    _print_report(num_clients, summary, resource_data)
    return summary
-
-
async def run_step_test(host, port, message_interval, configs=None, cooldown=10):
    """Run a step (ramp-up) load test: one ``run_stress_test`` round per stage.

    Args:
        host: Target server host.
        port: Target server port.
        message_interval: Base seconds between messages for every client.
        configs: Iterable of ``{"clients": int, "duration": int}`` stages;
            ``None`` means ``STEP_CONFIGS``.
        cooldown: Seconds to pause between consecutive stages (not after the last one).

    Returns:
        List of per-stage dicts shaped ``{"clients": n, **Stats.summary()}``.
    """
    print("\n" + "=" * 60)
    print("📈 WebSocket 梯度负载测试")
    print("=" * 60)

    stages = list(STEP_CONFIGS if configs is None else configs)
    results = []
    for index, config in enumerate(stages):
        if index:
            await asyncio.sleep(cooldown)  # let the server settle between stages
        result = await run_stress_test(
            host, port, config["clients"], config["duration"], message_interval
        )
        results.append({
            "clients": config["clients"],
            **result,
        })

    return results
-
-
# ==================== Entry point ====================

def main():
    """Parse CLI arguments, run a single or step test, and optionally save results.

    Returns:
        The summary dict of a single run, or the list of per-stage dicts when
        ``--step`` is given.
    """
    # Derive the --step help text from the actual stage table.
    step_clients = "/".join(str(cfg["clients"]) for cfg in STEP_CONFIGS)

    parser = argparse.ArgumentParser(description="WebSocket 压力测试")
    parser.add_argument("--host", default=DEFAULT_HOST, help=f"目标主机 (默认: {DEFAULT_HOST})")
    parser.add_argument("--port", type=int, default=DEFAULT_PORT, help=f"目标端口 (默认: {DEFAULT_PORT})")
    parser.add_argument("--clients", type=int, default=DEFAULT_CLIENTS, help=f"并发连接数 (默认: {DEFAULT_CLIENTS})")
    parser.add_argument("--duration", type=int, default=DEFAULT_DURATION, help=f"测试时长秒 (默认: {DEFAULT_DURATION})")
    parser.add_argument("--interval", type=float, default=DEFAULT_MESSAGE_INTERVAL, help=f"消息间隔秒 (默认: {DEFAULT_MESSAGE_INTERVAL})")
    parser.add_argument("--step", action="store_true", help=f"运行梯度负载测试 ({step_clients})")
    parser.add_argument("--output", default=None, help="结果输出 JSON 文件路径")

    args = parser.parse_args()

    # Reject values that would make the run meaningless.
    if not args.step and args.clients < 1:
        parser.error("--clients 必须 >= 1")
    if args.duration < 0 or args.interval < 0:
        parser.error("--duration 和 --interval 不能为负数")

    if args.step:
        results = asyncio.run(run_step_test(args.host, args.port, args.interval))
    else:
        results = asyncio.run(
            run_stress_test(args.host, args.port, args.clients, args.duration, args.interval)
        )

    # Save results. ensure_ascii=False may emit non-ASCII text, so pin the
    # encoding instead of relying on the platform's default locale.
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"结果已保存到: {args.output}")

    return results


if __name__ == "__main__":
    main()
|