#!/usr/bin/env python3
"""
IoT MQTT concurrent stress test.

Simulates many IoT devices that publish sensor readings to one MQTT broker and
reports connection success rate, publish-call latency, per-second publish
throughput and -- when some backend answers on ``iot/ack/<device>`` with the
``sendTime`` it received -- end-to-end latency.

Every simulated device owns a paho-mqtt client whose network loop runs in a
background thread (``loop_start``); the publish schedule is driven from the
main thread. Requires paho-mqtt >= 2.0 (callback API version 2).

Usage:
    python mqtt_iot_stress.py [--broker HOST] [--port PORT] [--devices N]
                              [--duration S] [--freq S] [--topic PREFIX]
                              [--qos {0,1,2}] [--step] [--output FILE]

Example:
    python mqtt_iot_stress.py --devices 1000 --duration 300 --freq 5
"""
import argparse
import asyncio  # NOTE(review): currently unused in this file
import contextlib
import json
import random
import statistics
import sys
import threading
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor  # NOTE(review): currently unused in this file
from datetime import datetime

try:
    import paho.mqtt.client as mqtt
except ImportError:
    print("请先安装 paho-mqtt: pip install paho-mqtt")
    sys.exit(1)

if not hasattr(mqtt, "CallbackAPIVersion"):
    # SimulatedDevice.connect() passes callback_api_version=CallbackAPIVersion.VERSION2,
    # which paho-mqtt 1.x does not provide; fail early with a clear message instead
    # of crashing while creating the first device.
    print("需要 paho-mqtt >= 2.0: pip install -U paho-mqtt")
    sys.exit(1)


# ==================== Configuration ====================
DEFAULT_BROKER = "localhost"
DEFAULT_PORT = 1883
DEFAULT_DEVICES = 100
DEFAULT_DURATION = 120  # seconds
DEFAULT_FREQUENCY = 5  # seconds between two publish rounds
DEFAULT_TOPIC_PREFIX = "iot/sensor"
DEFAULT_QOS = 1

# Upper bound (seconds) on how long to wait for connection outcomes before publishing starts.
CONNECT_WAIT_TIMEOUT = 30

# Step (ramp-up) test stages.
STEP_CONFIGS = [
    {"devices": 100, "duration": 60, "freq": 5},
    {"devices": 500, "duration": 60, "freq": 5},
    {"devices": 1000, "duration": 60, "freq": 5},
    {"devices": 5000, "duration": 60, "freq": 5},
    {"devices": 1000, "duration": 60, "freq": 1},  # high-frequency reporting
    {"devices": 1000, "duration": 60, "freq": 10},  # low-frequency reporting
]


# ==================== Statistics ====================
def _latency_stats(values):
    """Return min/avg/max/p95/p99 of *values* (milliseconds, rounded to 2 dp).

    Returns an empty dict when *values* is empty. Percentiles use the sample at
    index ``int(n * q)`` of the sorted values, which is always a valid index.
    """
    if not values:
        return {}
    ordered = sorted(values)
    count = len(ordered)
    return {
        "min_ms": round(ordered[0], 2),
        "avg_ms": round(statistics.mean(ordered), 2),
        "max_ms": round(ordered[-1], 2),
        "p95_ms": round(ordered[int(count * 0.95)], 2),
        "p99_ms": round(ordered[int(count * 0.99)], 2),
    }


class MQTTStats:
    """Collects MQTT test counters; every mutation and the summary hold one lock.

    The lock matters because the ``on_*`` callbacks of all devices run on the
    paho network threads while publishing happens on the main thread.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self.connected = 0
        self.disconnected = 0
        self.messages_published = 0
        self.messages_received = 0
        self.publish_errors = 0
        self.connect_errors = 0
        self.latencies = []  # end-to-end ACK latencies, ms
        self.publish_times = []  # duration of the client.publish() call, ms
        self.throughput_per_second = defaultdict(int)  # epoch second -> successful publishes

    def record_connect(self):
        with self._lock:
            self.connected += 1

    def record_disconnect(self):
        with self._lock:
            self.disconnected += 1

    def record_connect_error(self):
        with self._lock:
            self.connect_errors += 1

    def record_publish(self, latency_ms):
        """Record one successful publish call that took *latency_ms* milliseconds."""
        with self._lock:
            self.messages_published += 1
            self.publish_times.append(latency_ms)
            second = int(time.time())
            self.throughput_per_second[second] += 1

    def record_publish_error(self):
        with self._lock:
            self.publish_errors += 1

    def record_receive(self, latency_ms):
        """Record one received ACK with an end-to-end latency of *latency_ms* milliseconds."""
        with self._lock:
            self.messages_received += 1
            self.latencies.append(latency_ms)

    def summary(self):
        """Return a JSON-serialisable snapshot of all counters and derived statistics."""
        with self._lock:
            total_attempted = self.connected + self.connect_errors
            success_rate = (self.connected / total_attempted * 100) if total_attempted > 0 else 0

            throughput_stats = {}
            if self.throughput_per_second:
                per_second = list(self.throughput_per_second.values())
                window = max(self.throughput_per_second) - min(self.throughput_per_second) + 1
                throughput_stats = {
                    "min_per_sec": min(per_second),
                    "max_per_sec": max(per_second),
                    # Mean over the seconds in which at least one publish succeeded.
                    "avg_per_sec": round(statistics.mean(per_second), 2),
                    # Mean over the whole first-to-last publish window, counting idle
                    # seconds between rounds as zero, i.e. the sustained message rate.
                    "overall_avg_per_sec": round(sum(per_second) / window, 2),
                }

            return {
                "total_devices": total_attempted,
                "connected": self.connected,
                "connect_errors": self.connect_errors,
                "connection_success_rate": f"{success_rate:.1f}%",
                "disconnected": self.disconnected,
                "messages_published": self.messages_published,
                "messages_received": self.messages_received,
                "publish_errors": self.publish_errors,
                "publish_time": _latency_stats(self.publish_times),
                "message_latency": _latency_stats(self.latencies),
                "throughput": throughput_stats,
            }


# ==================== Device simulation ====================
def generate_sensor_payload(device_id):
    """Build one JSON-encoded sensor report for *device_id* with random readings.

    ``sendTime`` (epoch seconds) is included so that an ACK responder can echo
    it back on ``iot/ack/<device>``; SimulatedDevice.on_message reads that field
    to compute end-to-end latency.
    """
    now = time.time()
    return json.dumps({
        "deviceId": f"device_{device_id:06d}",
        "timestamp": datetime.fromtimestamp(now).isoformat(),
        "sendTime": now,
        "type": random.choice(["pressure", "flow", "quality", "level"]),
        "data": {
            "pressure": round(random.uniform(0.1, 0.8), 3),
            "flow": round(random.uniform(10, 500), 2),
            "temperature": round(random.uniform(5, 35), 1),
            "ph": round(random.uniform(6.5, 8.5), 2),
            "turbidity": round(random.uniform(0, 5), 2),
        },
        "location": {
            "lng": round(random.uniform(116.0, 117.0), 6),
            "lat": round(random.uniform(39.5, 40.5), 6),
        },
        "battery": round(random.uniform(20, 100), 1),
        "signal": random.randint(-90, -30),
    })


class SimulatedDevice:
    """One simulated IoT device backed by its own paho-mqtt client.

    After ``connect()`` the client's network loop runs in a background thread
    (``loop_start``); the ``on_*`` callbacks run there and report into the
    shared MQTTStats.
    """

    def __init__(self, device_id, broker, port, stats, topic_prefix, qos):
        self.device_id = device_id
        self.broker = broker
        self.port = port
        self.stats = stats
        self.topic = f"{topic_prefix}/{device_id:06d}/data"
        self.qos = qos
        self.client = None
        self._running = False
        # Each device contributes at most one "connected" and one "connect error"
        # to the stats, so automatic reconnects by the paho loop do not inflate
        # the connection counters beyond the number of devices.
        self._connect_counted = False
        self._connect_error_counted = False

    def _count_connect_error(self):
        """Record a connection failure for this device, at most once."""
        if not self._connect_error_counted:
            self._connect_error_counted = True
            self.stats.record_connect_error()

    def on_connect(self, client, userdata, flags, rc, properties=None):
        """CONNACK callback (API v2: *rc* is the reason code; 0 means success)."""
        if rc == 0:
            if not self._connect_counted:
                self._connect_counted = True
                self.stats.record_connect()
            # Subscribe to the ACK topic used for end-to-end latency. This runs on
            # every successful CONNACK, including reconnects.
            client.subscribe(f"iot/ack/{self.device_id:06d}", qos=0)
        else:
            self._count_connect_error()

    def on_disconnect(self, client, userdata, disconnect_flags, reason_code=None, properties=None):
        """Disconnect callback (API v2 argument order: flags, reason code, properties)."""
        self.stats.record_disconnect()

    def on_message(self, client, userdata, msg):
        """Handle an ACK message and record the end-to-end latency it implies.

        Expects a JSON object carrying a numeric ``sendTime`` (epoch seconds).
        Anything else is ignored, so a misbehaving responder cannot raise
        inside the client's network thread.
        """
        try:
            payload = json.loads(msg.payload)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            return
        if not isinstance(payload, dict):
            return
        send_time = payload.get("sendTime")
        if isinstance(send_time, (int, float)):
            latency = (time.time() - send_time) * 1000
            self.stats.record_receive(latency)

    def on_publish(self, client, userdata, mid, rc=None, properties=None):
        pass

    def connect(self):
        """Create the client, connect it to the broker and start its network loop.

        Returns True if ``connect()`` and ``loop_start()`` did not raise (the
        broker's CONNACK is handled later by on_connect), False otherwise.
        """
        self.client = mqtt.Client(
            client_id=f"device_{self.device_id:06d}",
            callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
            protocol=mqtt.MQTTv311,
        )
        self.client.on_connect = self.on_connect
        self.client.on_disconnect = self.on_disconnect
        self.client.on_message = self.on_message
        self.client.on_publish = self.on_publish

        # Connection timeout in seconds.
        self.client.connect_timeout = 10

        try:
            self.client.connect(self.broker, self.port, keepalive=60)
            self.client.loop_start()
        except Exception:  # socket/DNS errors, refused connections, invalid arguments, ...
            self._count_connect_error()
            return False
        return True

    def publish(self):
        """Publish one sensor report; return True if the client accepted it.

        Devices that are not (yet) connected are skipped and return False
        without counting as a publish error.
        """
        if not self.client or not self.client.is_connected():
            return False

        payload = generate_sensor_payload(self.device_id)
        started = time.perf_counter()  # monotonic, high-resolution duration measurement
        try:
            result = self.client.publish(self.topic, payload, qos=self.qos)
        except Exception:
            self.stats.record_publish_error()
            return False

        if result.rc == mqtt.MQTT_ERR_SUCCESS:
            latency_ms = (time.perf_counter() - started) * 1000
            self.stats.record_publish(latency_ms)
            return True
        self.stats.record_publish_error()
        return False

    def disconnect(self):
        """Disconnect from the broker and stop the network loop (best effort)."""
        self._running = False
        if self.client is None:
            return
        # Disconnect first so the still-running network thread sends DISCONNECT,
        # then stop the thread. Each step is independent so a failing disconnect
        # cannot leave the loop thread running.
        with contextlib.suppress(Exception):
            self.client.disconnect()
        with contextlib.suppress(Exception):
            self.client.loop_stop()


# ==================== Test runner ====================
def _print_test_header(broker, port, num_devices, duration, frequency, qos):
    """Print the parameter banner shown before a test stage."""
    print(f"\n{'=' * 60}")
    print(f"📡 IoT MQTT 压力测试")
    print(f"{'=' * 60}")
    print(f"Broker: {broker}:{port}")
    print(f"模拟设备数: {num_devices}")
    print(f"测试时长: {duration}s")
    print(f"上报频率: 每 {frequency}s")
    print(f"QoS: {qos}")
    print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'=' * 60}\n")


def _connect_devices(devices, num_devices, broker, port, stats, topic_prefix, qos):
    """Create devices with ids 1..num_devices, append each to *devices* and connect it.

    Devices are connected in batches with a 0.5 s pause in between, to avoid a
    single connection burst against the broker.
    """
    batch_size = max(10, num_devices // 10)
    for index in range(num_devices):
        device = SimulatedDevice(index + 1, broker, port, stats, topic_prefix, qos)
        devices.append(device)  # append first so it is torn down even if interrupted
        device.connect()
        if (index + 1) % batch_size == 0:
            time.sleep(0.5)


def _wait_for_connections(stats, num_devices, timeout, poll_interval=0.2):
    """Wait until every device reported a connect outcome, or *timeout* seconds pass."""
    deadline = time.monotonic() + timeout
    # The two counters are read without the lock; an approximate value is enough here.
    while stats.connected + stats.connect_errors < num_devices:
        if time.monotonic() >= deadline:
            break
        time.sleep(poll_interval)


def _publish_rounds(devices, stats, duration, frequency):
    """Let every device publish once per round, one round every *frequency* s, for *duration* s.

    Rounds are scheduled at a fixed rate, so the time spent publishing a round
    does not stretch the interval. If a round takes longer than *frequency*,
    the next one starts immediately; missed slots are dropped rather than
    caught up. The last sleep never extends past the end of the phase.
    """
    start = time.monotonic()
    deadline = start + duration
    next_round = start
    report_count = 0
    while time.monotonic() < deadline:
        for device in devices:
            device.publish()
        report_count += 1

        elapsed = time.monotonic() - start
        if report_count % 10 == 0:
            print(f"  [{elapsed:.0f}s] 已发送 {stats.messages_published} 条消息, "
                  f"错误: {stats.publish_errors}")

        now = time.monotonic()
        next_round = max(next_round + frequency, now)
        pause = min(next_round, deadline) - now
        if pause > 0:
            time.sleep(pause)


def _print_summary(summary, num_devices):
    """Print the human-readable result report for one test stage."""
    print(f"\n{'=' * 60}")
    print(f"📊 MQTT 压力测试结果")
    print(f"{'=' * 60}")
    print(f"模拟设备数: {num_devices}")
    print(f"成功连接: {summary['connected']}")
    print(f"连接失败: {summary['connect_errors']}")
    print(f"连接成功率: {summary['connection_success_rate']}")
    print(f"消息发布总数: {summary['messages_published']}")
    print(f"发布错误: {summary['publish_errors']}")

    publish_time = summary['publish_time']
    if publish_time:
        print(f"\n消息发布时间:")
        print(f"  最小: {publish_time['min_ms']}ms")
        print(f"  平均: {publish_time['avg_ms']}ms")
        print(f"  最大: {publish_time['max_ms']}ms")
        print(f"  P95: {publish_time['p95_ms']}ms")
        print(f"  P99: {publish_time['p99_ms']}ms")

    throughput = summary['throughput']
    if throughput:
        print(f"\n吞吐量:")
        print(f"  最小: {throughput['min_per_sec']} msg/s")
        print(f"  最大: {throughput['max_per_sec']} msg/s")
        print(f"  平均: {throughput['avg_per_sec']} msg/s")
        print(f"  整体平均: {throughput['overall_avg_per_sec']} msg/s")

    latency = summary['message_latency']
    if latency:
        print(f"\n端到端延迟 (ACK):")
        print(f"  最小: {latency['min_ms']}ms")
        print(f"  平均: {latency['avg_ms']}ms")
        print(f"  最大: {latency['max_ms']}ms")
        print(f"  P95: {latency['p95_ms']}ms")
        print(f"  P99: {latency['p99_ms']}ms")

    print(f"{'=' * 60}\n")


def run_stress_test(broker, port, num_devices, duration, frequency, topic_prefix, qos):
    """Run one MQTT stress-test stage and return its summary (see MQTTStats.summary).

    Args:
        broker: MQTT broker host.
        port: MQTT broker port.
        num_devices: number of simulated devices to create.
        duration: length of the publishing phase in seconds.
        frequency: seconds between two publish rounds (each device publishes once per round).
        topic_prefix: data is published to ``<topic_prefix>/<device id>/data``.
        qos: QoS level used for data publishes.
    """
    stats = MQTTStats()
    _print_test_header(broker, port, num_devices, duration, frequency, qos)

    devices = []
    try:
        print(f"正在连接 {num_devices} 个设备...")
        _connect_devices(devices, num_devices, broker, port, stats, topic_prefix, qos)
        _wait_for_connections(stats, num_devices, CONNECT_WAIT_TIMEOUT)
        print(f"✅ 设备连接完成: {stats.connected}/{num_devices}")

        print(f"\n开始数据上报(每 {frequency}s)...")
        _publish_rounds(devices, stats, duration, frequency)
    finally:
        # Always tear the clients down, also on Ctrl+C or an unexpected error.
        print(f"\n停止所有设备...")
        for device in devices:
            device.disconnect()
    time.sleep(2)

    summary = stats.summary()
    _print_summary(summary, num_devices)
    return summary


def run_step_test(broker, port, topic_prefix, qos):
    """Run every stage of STEP_CONFIGS in order; return one result dict per stage.

    Each result is the stage's summary plus its ``devices`` and ``frequency``.
    Consecutive stages are separated by a 10 s cool-down.
    """
    print("\n" + "=" * 60)
    print("📈 MQTT 梯度负载测试")
    print("=" * 60)

    results = []
    for stage_index, config in enumerate(STEP_CONFIGS):
        if stage_index > 0:
            time.sleep(10)  # cool-down between stages (not after the last one)
        result = run_stress_test(
            broker, port,
            config["devices"], config["duration"], config["freq"],
            topic_prefix, qos,
        )
        results.append({
            "devices": config["devices"],
            "frequency": config["freq"],
            **result,
        })
    return results


# ==================== Entry point ====================
def main():
    """Parse the command line, run the requested test and optionally save results as JSON."""
    parser = argparse.ArgumentParser(description="IoT MQTT 压力测试")
    parser.add_argument("--broker", default=DEFAULT_BROKER,
                        help=f"MQTT broker 地址 (默认: {DEFAULT_BROKER})")
    parser.add_argument("--port", type=int, default=DEFAULT_PORT,
                        help=f"MQTT broker 端口 (默认: {DEFAULT_PORT})")
    parser.add_argument("--devices", type=int, default=DEFAULT_DEVICES,
                        help=f"模拟设备数 (默认: {DEFAULT_DEVICES})")
    parser.add_argument("--duration", type=int, default=DEFAULT_DURATION,
                        help=f"测试时长秒 (默认: {DEFAULT_DURATION})")
    parser.add_argument("--freq", type=float, default=DEFAULT_FREQUENCY,
                        help=f"上报频率秒 (默认: {DEFAULT_FREQUENCY})")
    parser.add_argument("--topic", default=DEFAULT_TOPIC_PREFIX,
                        help=f"主题前缀 (默认: {DEFAULT_TOPIC_PREFIX})")
    parser.add_argument("--qos", type=int, default=DEFAULT_QOS, choices=[0, 1, 2],
                        help=f"QoS 级别 (默认: {DEFAULT_QOS})")
    parser.add_argument("--step", action="store_true",
                        help="运行梯度负载测试")
    parser.add_argument("--output", default=None,
                        help="结果输出 JSON 文件路径")

    args = parser.parse_args()

    # Reject values that make the run meaningless or would fail deep inside the test.
    if args.devices < 1:
        parser.error("--devices 必须 >= 1")
    if args.duration <= 0:
        parser.error("--duration 必须 > 0")
    if args.freq < 0:
        parser.error("--freq 必须 >= 0")

    if args.step:
        results = run_step_test(args.broker, args.port, args.topic, args.qos)
    else:
        results = run_stress_test(
            args.broker, args.port,
            args.devices, args.duration, args.freq,
            args.topic, args.qos,
        )

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"结果已保存到: {args.output}")

    return results


if __name__ == "__main__":
    main()