| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- #!/usr/bin/env python3
- """
- REST API 性能压力测试 - Locust
-
- 使用 Locust 框架对水务管理系统的 REST API 进行并发压力测试。
- 覆盖主要 API 端点:登录、数据查询、数据上报、报表查询等。
-
运行方式:
    locust -f locustfile.py --host=http://localhost:8080
    # 无头模式:
    locust -f locustfile.py --host=http://localhost:8080 --headless -u 100 -r 10 --run-time 5m

支持 10/50/100/500/1000 并发用户梯度测试。
注意:本文件定义了 LoadTestShape 子类 StepLoadShape,Locust 检测到后会自动使用它
来控制用户数和孵化速率,此时命令行中的 -u / -r 参数不会生效。
- """
-
import json
import os
import random
import time
from datetime import datetime, timedelta

try:
    from locust import HttpUser, task, between, events, LoadTestShape
except ImportError:
    print("请先安装 locust: pip install locust")
    raise SystemExit(1)
-
-
# ==================== Configuration ====================

# Fallback bearer token, used when the login performed in on_start does not
# yield one. Set the API_TOKEN environment variable to override the default.
API_TOKEN = os.environ.get("API_TOKEN", "test-token")

# Inclusive range of simulated pipe-network node IDs (unpacked into random.randint).
NODE_ID_RANGE = (1, 10000)

# Inclusive range of simulated sensor IDs (unpacked into random.randint).
SENSOR_ID_RANGE = (1, 5000)
-
-
# ==================== User behaviour ====================

class WaterManagementUser(HttpUser):
    """Simulated user of the water-management system's REST API.

    Each virtual user logs in once in ``on_start`` and then repeatedly runs
    the weighted ``@task`` methods below, waiting 1-5 s between tasks.

    GET query tasks rely on Locust's default outcome: an HTTP error status
    (4xx/5xx) is recorded as a failure. The POST tasks and the report export
    additionally mark 401 and 404 as success (see ``_WRITE_OK_STATUSES`` and
    ``_EXPORT_OK_STATUSES``), so endpoints that reject the token or are not
    deployed do not show up as failures for those tasks.
    """

    wait_time = between(1, 5)  # pause 1-5 s between tasks

    # Statuses explicitly marked as success for the POST tasks.
    _WRITE_OK_STATUSES = (200, 201, 401, 404)
    # Statuses explicitly marked as success for the report export.
    _EXPORT_OK_STATUSES = (200, 401, 404)

    def on_start(self):
        """Log in once for this virtual user and store the bearer token.

        Starts from ``API_TOKEN`` and only replaces it when the login call
        returns HTTP 200 with a JSON body containing ``data.token``. Every
        other outcome is recorded as a failed login request and the default
        token stays in use.
        """
        self.token = API_TOKEN
        self.headers = {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }
        with self.client.post(
            "/api/auth/login",
            json={"username": "admin", "password": "admin123"},
            catch_response=True,
            timeout=10,
        ) as resp:
            if resp.status_code != 200:
                resp.failure(f"Login failed: {resp.status_code}")
                return
            try:
                body = resp.json()
            except ValueError:
                # An exception escaping this `with` would leave the request
                # unreported, so record the bad body as a failure instead.
                resp.failure("Login response is not valid JSON")
                return
            # Guard each level: "data" may be missing, null or not an object.
            data = body.get("data") if isinstance(body, dict) else None
            token = data.get("token") if isinstance(data, dict) else None
            if not token:
                resp.failure("No token in response")
                return
            self.token = token
            self.headers["Authorization"] = f"Bearer {self.token}"
            resp.success()

    # ---------- Helpers ----------

    def _api_get(self, path, params=None, timeout=30):
        """Send an authenticated GET to ``path``, grouped in the stats under ``path``.

        Query values go through ``params`` so that requests URL-encodes them
        (for example the ``:`` characters in ISO-8601 timestamps).
        """
        return self.client.get(
            path,
            params=params,
            headers=self.headers,
            name=path,
            timeout=timeout,
        )

    def _api_post_tolerant(self, path, payload, timeout):
        """POST ``payload`` as JSON to ``path``, counting ``_WRITE_OK_STATUSES`` as success.

        Any other status falls through to Locust's default handling, which
        records HTTP error statuses as failures.
        """
        with self.client.post(
            path,
            json=payload,
            headers=self.headers,
            catch_response=True,
            name=path,
            timeout=timeout,
        ) as resp:
            if resp.status_code in self._WRITE_OK_STATUSES:
                resp.success()

    # ---------- Data queries ----------

    @task(10)
    def query_monitoring_data(self):
        """Query the last 24 h of monitoring data for a random node (high-frequency)."""
        now = datetime.now()  # one "now" so the window is exactly 24 h wide
        self._api_get(
            "/api/monitoring/data",
            {
                "nodeId": random.randint(*NODE_ID_RANGE),
                "startTime": (now - timedelta(hours=24)).isoformat(),
                "endTime": now.isoformat(),
            },
        )

    @task(5)
    def query_pressure_data(self):
        """Query pipe-network pressure data for a random area ID in 1-100."""
        self._api_get("/api/monitoring/pressure", {"areaId": random.randint(1, 100)})

    @task(5)
    def query_flow_data(self):
        """Query flow data for a random sensor."""
        self._api_get(
            "/api/monitoring/flow",
            {"sensorId": random.randint(*SENSOR_ID_RANGE)},
        )

    @task(3)
    def query_water_quality(self):
        """Query water-quality data for a random station ID in 1-50."""
        self._api_get(
            "/api/monitoring/water-quality",
            {"stationId": random.randint(1, 50)},
        )

    # ---------- Data reporting ----------

    @task(8)
    def report_sensor_data(self):
        """Report one randomly generated sensor reading (high-frequency)."""
        payload = {
            "sensorId": random.randint(*SENSOR_ID_RANGE),
            "timestamp": datetime.now().isoformat(),
            "pressure": round(random.uniform(0.1, 0.8), 3),
            "flow": round(random.uniform(10, 500), 2),
            "temperature": round(random.uniform(5, 35), 1),
            "quality": {
                "turbidity": round(random.uniform(0, 5), 2),
                "ph": round(random.uniform(6.5, 8.5), 2),
                "chlorine": round(random.uniform(0.1, 0.8), 3),
            },
        }
        self._api_post_tolerant("/api/data/report", payload, timeout=15)

    @task(3)
    def report_alarm(self):
        """Report a randomly generated alarm event."""
        alarm_types = ["PRESSURE_HIGH", "PRESSURE_LOW", "FLOW_ABNORMAL",
                       "QUALITY_WARNING", "LEAK_DETECTED"]
        alarm_type = random.choice(alarm_types)
        payload = {
            "sensorId": random.randint(*SENSOR_ID_RANGE),
            "alarmType": alarm_type,
            "severity": random.choice(["INFO", "WARNING", "CRITICAL"]),
            "timestamp": datetime.now().isoformat(),
            # Describe the same type as "alarmType" so the payload is self-consistent.
            "description": f"自动告警 - {alarm_type}",
            "value": round(random.uniform(0, 100), 2),
            "threshold": round(random.uniform(50, 100), 2),
        }
        self._api_post_tolerant("/api/alarm/report", payload, timeout=15)

    # ---------- Report queries ----------

    @task(4)
    def query_daily_report(self):
        """Query the daily report for today or a day up to 30 days back."""
        day = datetime.now() - timedelta(days=random.randint(0, 30))
        self._api_get("/api/report/daily", {"date": day.strftime("%Y-%m-%d")})

    @task(2)
    def query_monthly_report(self):
        """Query the monthly report for the current month or one of the 11 before it."""
        today = datetime.now()
        months_back = random.randint(0, 11)
        # Count months from year 0 and step back, so the pick never lands in
        # a future month (which presumably has no data yet) and wraps
        # correctly across year boundaries.
        month_index = today.year * 12 + (today.month - 1) - months_back
        year, month_zero_based = divmod(month_index, 12)
        self._api_get(
            "/api/report/monthly",
            {"year": year, "month": month_zero_based + 1},
        )

    @task(2)
    def export_report(self):
        """Export the last 30 days of reports as xlsx (heavyweight operation)."""
        now = datetime.now()
        params = {
            "startDate": (now - timedelta(days=30)).strftime("%Y-%m-%d"),
            "endDate": now.strftime("%Y-%m-%d"),
            "format": "xlsx",
        }
        with self.client.get(
            "/api/report/export",
            params=params,
            headers=self.headers,
            catch_response=True,
            name="/api/report/export",
            timeout=60,
        ) as resp:
            if resp.status_code in self._EXPORT_OK_STATUSES:
                resp.success()

    # ---------- Emergency dispatch ----------

    @task(1)
    def query_emergency_list(self):
        """Fetch the first page (size 20) of the emergency event list."""
        self._api_get("/api/emergency/list", {"page": 1, "size": 20})

    @task(1)
    def simulate_pipe_burst(self):
        """Trigger a quick pipe-burst dispatch at a random location (heavyweight)."""
        payload = {
            "lng": round(random.uniform(116.0, 117.0), 4),
            "lat": round(random.uniform(39.5, 40.5), 4),
            "pipeDiameter": random.choice(["DN50", "DN100", "DN200", "DN300"]),
            "operatorName": f"operator_{random.randint(1, 50)}",
        }
        self._api_post_tolerant(
            "/api/emergency/dispatch/quick-pipe-burst", payload, timeout=60
        )

    # ---------- GIS spatial queries ----------

    @task(3)
    def gis_query_nearby(self):
        """GIS query: devices within a randomly chosen radius of a random point."""
        self._api_get(
            "/api/gis/nearby",
            {
                "lng": round(random.uniform(116.0, 117.0), 6),
                "lat": round(random.uniform(39.5, 40.5), 6),
                "radius": random.choice([500, 1000, 2000, 5000]),
            },
        )

    @task(2)
    def gis_query_area_stats(self):
        """GIS statistics for a random area ID in 1-50."""
        self._api_get("/api/gis/area-stats", {"areaId": random.randint(1, 50)})
-
-
# ==================== Step load shape ====================

class StepLoadShape(LoadTestShape):
    """Step load profile: 10 -> 50 -> 100 -> 500 -> 1000 users, then 0.

    Each entry in ``stages`` stays active until the cumulative run time (in
    seconds) stored under its ``"duration"`` key is reached. That key is an
    end time, not a per-stage length. With the values below every step lasts
    180 s: five load steps (15 min) followed by a 3-minute stage at 0 users.
    ``tick`` then returns ``None`` and the test ends at 1080 s (18 min).
    """

    stages = [
        {"duration": 180, "users": 10, "spawn_rate": 5},       # warm-up, 10 users
        {"duration": 360, "users": 50, "spawn_rate": 10},      # 50 users
        {"duration": 540, "users": 100, "spawn_rate": 20},     # 100 users
        {"duration": 720, "users": 500, "spawn_rate": 50},     # 500 users
        {"duration": 900, "users": 1000, "spawn_rate": 100},   # 1000 users
        {"duration": 1080, "users": 0, "spawn_rate": 100},     # stop all users
    ]

    def tick(self):
        """Return ``(user_count, spawn_rate)`` for the current stage.

        Returns ``None`` once the run time has passed the last stage's end,
        which tells Locust to stop the test.
        """
        elapsed = self.get_run_time()
        current = next(
            (stage for stage in self.stages if elapsed < stage["duration"]),
            None,
        )
        if current is None:
            return None
        return current["users"], current["spawn_rate"]
-
-
# ==================== Event hooks: custom statistics ====================

@events.test_stop.add_listener
def on_test_stop(environment, **kwargs):
    """Print a per-endpoint summary of the collected statistics when the test stops.

    For every recorded (method, name) entry, and for the aggregated total,
    prints request count, failure count, average and 95th-percentile
    response time in milliseconds.
    """
    from locust.runners import WorkerRunner

    # In distributed mode workers fire test_stop too, but they only hold
    # their own slice of the data; print the summary on master/standalone.
    if isinstance(environment.runner, WorkerRunner):
        return

    stats = environment.stats
    row = "{:<7}{:<44}{:>8}{:>7}{:>10}{:>9}"

    print("\n" + "=" * 60)
    print("📊 REST API 压力测试报告")
    print("=" * 60)
    print(row.format("Method", "Name", "Reqs", "Fails", "Avg(ms)", "P95(ms)"))
    entries = sorted(
        stats.entries.values(),
        key=lambda e: (str(e.name), str(e.method)),
    )
    for entry in [*entries, stats.total]:
        print(row.format(
            str(entry.method or ""),
            str(entry.name),
            entry.num_requests,
            entry.num_failures,
            f"{entry.avg_response_time:.1f}",
            entry.get_response_time_percentile(0.95),
        ))
    print("=" * 60)
|