#!/usr/bin/env python3
"""
Large-volume database query stress test.

Measures the performance of several query classes against tables holding
up to millions of rows:
- simple queries (single-table WHERE)
- aggregate queries (GROUP BY / SUM / AVG)
- JOIN queries (multi-table)
- GIS spatial queries (PostGIS ST_DWithin, etc.)
- a comparison of query performance with and without indexes

Usage:
    python db_query_stress.py [--host HOST] [--port PORT] [--db DB] [--user USER] [--password PASS]

Example:
    python db_query_stress.py --host localhost --db water_management --generate-data 1000000
"""

import argparse
import json
import time
import random
import statistics
import sys
from datetime import datetime, timedelta
from contextlib import contextmanager

try:
    import psycopg2
    from psycopg2 import extras
    from psycopg2 import pool
except ImportError:
    print("请先安装 psycopg2: pip install psycopg2-binary")
    sys.exit(1)


# ==================== Configuration ====================

DEFAULT_HOST = "localhost"
DEFAULT_PORT = 5432
DEFAULT_DB = "water_management"
DEFAULT_USER = "postgres"
DEFAULT_PASSWORD = "postgres"

# Number of times each benchmark query is repeated; timings are aggregated.
QUERY_ROUNDS = 5

# Upper bound on generated devices; benchmark queries pick sensor ids from
# 1..MAX_DEVICE_COUNT, so fewer generated devices means some picks match nothing.
MAX_DEVICE_COUNT = 5000
# Generated rows are assigned an area_id in 1..AREA_COUNT.
AREA_COUNT = 50
# Rows per multi-row INSERT statement sent by execute_values.
INSERT_PAGE_SIZE = 1000

# Tables touched by the benchmark, as a comma-separated SQL list.
PERF_TABLES = "perf_sensor_data, perf_alarms, perf_devices"

# Index name -> "table(columns)" target. Shared by create_indexes and
# drop_indexes so the two can never drift apart.
PERF_INDEXES = {
    "idx_sensor_data_sensor_id": "perf_sensor_data(sensor_id)",
    "idx_sensor_data_timestamp": "perf_sensor_data(timestamp)",
    "idx_sensor_data_area_id": "perf_sensor_data(area_id)",
    "idx_sensor_data_device_id": "perf_sensor_data(device_id)",
    "idx_sensor_data_composite": "perf_sensor_data(sensor_id, timestamp)",
    "idx_alarms_sensor_id": "perf_alarms(sensor_id)",
    "idx_alarms_timestamp": "perf_alarms(timestamp)",
    "idx_alarms_type": "perf_alarms(alarm_type)",
    "idx_devices_area_id": "perf_devices(area_id)",
    "idx_devices_type": "perf_devices(device_type)",
}


# ==================== Database connection ====================

class DBConnectionPool:
    """Thin wrapper around a psycopg2 ThreadedConnectionPool."""

    def __init__(self, host, port, db, user, password, min_conn=2, max_conn=10):
        """Open the pool.

        Args:
            host, port, db, user, password: PostgreSQL connection settings.
            min_conn: Minimum number of connections, passed to the pool.
            max_conn: Maximum number of connections, passed to the pool.

        Raises:
            psycopg2.Error: If the initial connections cannot be established.
        """
        self.pool = pool.ThreadedConnectionPool(
            min_conn,
            max_conn,
            host=host,
            port=port,
            dbname=db,
            user=user,
            password=password,
            connect_timeout=10,
        )

    @contextmanager
    def get_connection(self):
        """Yield a pooled connection and always hand it back to the pool."""
        conn = self.pool.getconn()
        try:
            yield conn
        finally:
            self.pool.putconn(conn)

    def close(self):
        """Close every connection held by the pool."""
        self.pool.closeall()


# ==================== Test data generation ====================

def _create_perf_tables(cur):
    """Create the three perf_* tables if they do not exist yet."""
    cur.execute("""
        CREATE TABLE IF NOT EXISTS perf_sensor_data (
            id BIGSERIAL PRIMARY KEY,
            sensor_id INTEGER NOT NULL,
            device_id VARCHAR(32) NOT NULL,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            pressure REAL,
            flow REAL,
            temperature REAL,
            ph REAL,
            turbidity REAL,
            chlorine REAL,
            battery REAL,
            signal_strength INTEGER,
            lng DOUBLE PRECISION,
            lat DOUBLE PRECISION,
            area_id INTEGER,
            created_at TIMESTAMPTZ DEFAULT NOW()
        )
    """)
    cur.execute("""
        CREATE TABLE IF NOT EXISTS perf_alarms (
            id BIGSERIAL PRIMARY KEY,
            sensor_id INTEGER NOT NULL,
            alarm_type VARCHAR(32) NOT NULL,
            severity VARCHAR(16) NOT NULL,
            value REAL,
            threshold REAL,
            description TEXT,
            acknowledged BOOLEAN DEFAULT FALSE,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            created_at TIMESTAMPTZ DEFAULT NOW()
        )
    """)
    cur.execute("""
        CREATE TABLE IF NOT EXISTS perf_devices (
            id SERIAL PRIMARY KEY,
            device_id VARCHAR(32) UNIQUE NOT NULL,
            name VARCHAR(128),
            device_type VARCHAR(32),
            area_id INTEGER,
            lng DOUBLE PRECISION,
            lat DOUBLE PRECISION,
            status VARCHAR(16) DEFAULT 'active',
            install_date DATE,
            created_at TIMESTAMPTZ DEFAULT NOW()
        )
    """)


def _random_timestamp(base_time):
    """Return base_time plus a random offset of up to 365 days."""
    return base_time + timedelta(seconds=random.randint(0, 365 * 24 * 3600))


def _insert_devices(cur, device_count):
    """Insert device_count randomly populated rows into perf_devices."""
    devices = [
        (
            f"DEV{i:06d}",
            f"传感器-{i}",
            random.choice(["pressure", "flow", "quality", "level", "valve"]),
            random.randint(1, AREA_COUNT),
            round(random.uniform(116.0, 117.0), 6),
            round(random.uniform(39.5, 40.5), 6),
            random.choice(["active", "inactive", "maintenance"]),
            (datetime.now() - timedelta(days=random.randint(0, 730))).date(),
        )
        for i in range(1, device_count + 1)
    ]
    if devices:
        extras.execute_values(
            cur,
            """
            INSERT INTO perf_devices
                (device_id, name, device_type, area_id, lng, lat, status, install_date)
            VALUES %s
            """,
            devices,
            page_size=INSERT_PAGE_SIZE,
        )


def _insert_sensor_data(conn, cur, num_records, device_count, base_time):
    """Insert num_records random sensor readings in committed batches.

    Returns:
        The number of rows inserted.
    """
    batch_size = 10000
    total_inserted = 0

    while total_inserted < num_records:
        batch = []
        for _ in range(min(batch_size, num_records - total_inserted)):
            sensor_id = random.randint(1, device_count)
            batch.append((
                sensor_id,
                f"DEV{sensor_id:06d}",
                _random_timestamp(base_time),
                round(random.uniform(0.1, 0.8), 3),      # pressure
                round(random.uniform(10, 500), 2),       # flow
                round(random.uniform(5, 35), 1),         # temperature
                round(random.uniform(6.5, 8.5), 2),      # ph
                round(random.uniform(0, 5), 2),          # turbidity
                round(random.uniform(0.1, 0.8), 3),      # chlorine
                round(random.uniform(20, 100), 1),       # battery
                random.randint(-90, -30),                # signal
                round(random.uniform(116.0, 117.0), 6),  # lng
                round(random.uniform(39.5, 40.5), 6),    # lat
                random.randint(1, AREA_COUNT),           # area_id
            ))

        extras.execute_values(
            cur,
            """
            INSERT INTO perf_sensor_data
                (sensor_id, device_id, timestamp, pressure, flow, temperature,
                 ph, turbidity, chlorine, battery, signal_strength, lng, lat, area_id)
            VALUES %s
            """,
            batch,
            page_size=INSERT_PAGE_SIZE,
        )
        # Commit per batch so one transaction never holds the whole data set.
        conn.commit()
        total_inserted += len(batch)

        if total_inserted % 100000 == 0:
            print(f" 进度: {total_inserted}/{num_records} ({total_inserted*100//num_records}%)")

    return total_inserted


def _insert_alarms(cur, alarm_count, device_count, base_time):
    """Insert alarm_count random rows into perf_alarms."""
    alarm_types = ["PRESSURE_HIGH", "PRESSURE_LOW", "FLOW_ABNORMAL",
                   "QUALITY_WARNING", "LEAK_DETECTED", "BATTERY_LOW"]
    severities = ["INFO", "WARNING", "CRITICAL"]

    alarms = [
        (
            random.randint(1, device_count),
            random.choice(alarm_types),
            random.choice(severities),
            round(random.uniform(0, 100), 2),
            round(random.uniform(50, 100), 2),
            f"自动告警描述 - {random.choice(alarm_types)}",
            random.choice([True, False]),
            _random_timestamp(base_time),
        )
        for _ in range(alarm_count)
    ]
    if alarms:
        extras.execute_values(
            cur,
            """
            INSERT INTO perf_alarms
                (sensor_id, alarm_type, severity, value, threshold,
                 description, acknowledged, timestamp)
            VALUES %s
            """,
            alarms,
            page_size=INSERT_PAGE_SIZE,
        )


def generate_test_data(db_pool, num_records):
    """Recreate the benchmark data set.

    Creates the perf_* tables when missing, TRUNCATEs them (destroying any
    existing rows), then inserts devices, ``num_records`` sensor readings and
    ``num_records // 20`` alarms, all with random values.

    Args:
        db_pool: A DBConnectionPool.
        num_records: Number of sensor readings to generate.
    """
    print(f"\n📦 生成 {num_records} 条测试数据...")

    with db_pool.get_connection() as conn:
        with conn.cursor() as cur:
            _create_perf_tables(cur)
            conn.commit()

            # RESTART IDENTITY resets the SERIAL sequences; without it a
            # re-run would give perf_devices ids that no longer match the
            # sensor_id values (1..device_count) the alarm JOINs rely on.
            print(" 清空旧数据...")
            cur.execute(f"TRUNCATE TABLE {PERF_TABLES} RESTART IDENTITY")
            conn.commit()

            print(" 生成设备数据...")
            # At least one device, otherwise random.randint(1, 0) below
            # raises ValueError for num_records < 100.
            device_count = max(1, min(MAX_DEVICE_COUNT, num_records // 100))
            _insert_devices(cur, device_count)
            conn.commit()
            print(f" ✅ 已生成 {device_count} 个设备")

            print(f" 生成 {num_records} 条传感器数据...")
            base_time = datetime.now() - timedelta(days=365)
            total_inserted = _insert_sensor_data(
                conn, cur, num_records, device_count, base_time
            )
            print(f" ✅ 已生成 {total_inserted} 条传感器数据")

            # Alarms amount to 5% of the sensor readings.
            alarm_count = num_records // 20
            print(f" 生成 {alarm_count} 条告警数据...")
            _insert_alarms(cur, alarm_count, device_count, base_time)
            conn.commit()
            print(f" ✅ 已生成 {alarm_count} 条告警数据")

    print("✅ 测试数据生成完成\n")


# ==================== Index management ====================

def create_indexes(db_pool):
    """Create every index listed in PERF_INDEXES (no-op for existing ones)."""
    with db_pool.get_connection() as conn:
        with conn.cursor() as cur:
            for index_name, target in PERF_INDEXES.items():
                cur.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {target}")
        conn.commit()
    print("✅ 索引已创建")


def drop_indexes(db_pool):
    """Drop every index listed in PERF_INDEXES (no-op for missing ones)."""
    with db_pool.get_connection() as conn:
        with conn.cursor() as cur:
            # Names come from the module constant, never from user input.
            for index_name in PERF_INDEXES:
                cur.execute(f"DROP INDEX IF EXISTS {index_name}")
        conn.commit()
    print("🗑️ 索引已删除")


def analyze_tables(db_pool):
    """Run ANALYZE on the perf_* tables to refresh planner statistics."""
    with db_pool.get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(f"ANALYZE {PERF_TABLES}")
        conn.commit()


# ==================== Query benchmarks ====================

class QueryBenchmark:
    """Runs a fixed set of benchmark queries and collects timing results."""

    def __init__(self, db_pool):
        """Store the pool and start with an empty result list."""
        self.db_pool = db_pool
        self.results = []

    def run_query(self, name, sql, params=None, rounds=QUERY_ROUNDS):
        """Execute one query ``rounds`` times and record timing statistics.

        Each round measures execute + fetchall in milliseconds. Failed rounds
        are reported and excluded from the statistics.

        Args:
            name: Label used in the printed line and in the result dict.
            sql: SQL text, with ``%s`` placeholders when ``params`` is given.
            params: Optional parameter tuple for the query.
            rounds: Number of repetitions.

        Returns:
            The result dict, which is also appended to ``self.results``.
        """
        valid_times = []
        row_count = 0  # row count of the most recent successful round

        for _ in range(rounds):
            with self.db_pool.get_connection() as conn:
                with conn.cursor() as cur:
                    # perf_counter is monotonic and high-resolution, unlike
                    # time.time(), which can jump with clock adjustments.
                    start = time.perf_counter()
                    try:
                        cur.execute(sql, params)
                        rows = cur.fetchall()
                    except Exception as e:
                        # Benchmark boundary: report the failure and keep going.
                        print(f" ⚠️ 查询错误: {e}")
                        conn.rollback()
                        continue
                    valid_times.append((time.perf_counter() - start) * 1000)
                    row_count = len(rows)

        result = {
            "name": name,
            "sql": sql[:100] + ("..." if len(sql) > 100 else ""),
            "rounds": rounds,
            "success_rounds": len(valid_times),
        }
        if valid_times:
            result.update({
                "rows_returned": row_count,
                "min_ms": round(min(valid_times), 2),
                "max_ms": round(max(valid_times), 2),
                "avg_ms": round(statistics.mean(valid_times), 2),
                "median_ms": round(statistics.median(valid_times), 2),
            })
        else:
            result["error"] = "all rounds failed"

        self.results.append(result)

        status = "✅" if valid_times else "❌"
        avg = result.get("avg_ms", "N/A")
        print(f" {status} {name}: avg={avg}ms, rows={row_count}")
        return result

    def _postgis_available(self):
        """Return True when ``SELECT PostGIS_Version()`` succeeds."""
        try:
            with self.db_pool.get_connection() as conn:
                try:
                    with conn.cursor() as cur:
                        cur.execute("SELECT PostGIS_Version()")
                    return True
                except psycopg2.Error:
                    # Clear the aborted transaction before the connection
                    # goes back to the pool.
                    conn.rollback()
                    return False
        except psycopg2.Error:
            # Could not even obtain a connection; treat as unavailable.
            return False

    def run_all_benchmarks(self):
        """Run every benchmark query and return the accumulated result list."""
        print("\n🔍 运行查询性能测试...\n")

        # 1. Simple queries
        print(" --- 简单查询 ---")

        self.run_query(
            "简单查询 - 按传感器ID查询最近24小时",
            """SELECT * FROM perf_sensor_data
               WHERE sensor_id = %s
                 AND timestamp > NOW() - INTERVAL '24 hours'
               ORDER BY timestamp DESC
               LIMIT 100""",
            (random.randint(1, MAX_DEVICE_COUNT),),
        )

        self.run_query(
            "简单查询 - 按区域ID查询",
            """SELECT * FROM perf_sensor_data
               WHERE area_id = %s
               ORDER BY timestamp DESC
               LIMIT 100""",
            (random.randint(1, AREA_COUNT),),
        )

        self.run_query(
            "简单查询 - 时间范围查询",
            """SELECT * FROM perf_sensor_data
               WHERE timestamp BETWEEN %s AND %s
               ORDER BY timestamp DESC
               LIMIT 1000""",
            (
                (datetime.now() - timedelta(days=7)).isoformat(),
                datetime.now().isoformat(),
            ),
        )

        # 2. Aggregate queries
        print("\n --- 聚合查询 ---")

        self.run_query(
            "聚合查询 - 每小时平均压力",
            """SELECT date_trunc('hour', timestamp) AS hour,
                      AVG(pressure) AS avg_pressure,
                      MAX(pressure) AS max_pressure,
                      MIN(pressure) AS min_pressure,
                      COUNT(*) AS readings
               FROM perf_sensor_data
               WHERE sensor_id = %s
                 AND timestamp > NOW() - INTERVAL '7 days'
               GROUP BY hour
               ORDER BY hour""",
            (random.randint(1, MAX_DEVICE_COUNT),),
        )

        self.run_query(
            "聚合查询 - 各区域统计",
            """SELECT area_id,
                      COUNT(*) AS total_readings,
                      AVG(pressure) AS avg_pressure,
                      AVG(flow) AS avg_flow,
                      AVG(temperature) AS avg_temp
               FROM perf_sensor_data
               WHERE timestamp > NOW() - INTERVAL '1 day'
               GROUP BY area_id
               ORDER BY total_readings DESC""",
        )

        self.run_query(
            "聚合查询 - 日统计报表",
            """SELECT date_trunc('day', timestamp) AS day,
                      COUNT(*) AS readings,
                      AVG(pressure) AS avg_pressure,
                      AVG(flow) AS avg_flow,
                      STDDEV(pressure) AS pressure_stddev
               FROM perf_sensor_data
               WHERE timestamp > NOW() - INTERVAL '30 days'
               GROUP BY day
               ORDER BY day""",
        )

        # 3. JOIN queries
        print("\n --- JOIN 查询 ---")

        self.run_query(
            "JOIN 查询 - 传感器数据 + 设备信息",
            """SELECT s.timestamp, s.pressure, s.flow, s.temperature,
                      d.name AS device_name, d.device_type, d.area_id, d.status
               FROM perf_sensor_data s
               JOIN perf_devices d ON s.device_id = d.device_id
               WHERE s.sensor_id = %s
               ORDER BY s.timestamp DESC
               LIMIT 100""",
            (random.randint(1, MAX_DEVICE_COUNT),),
        )

        self.run_query(
            "JOIN 查询 - 告警 + 设备信息",
            """SELECT a.timestamp, a.alarm_type, a.severity, a.value, a.threshold,
                      d.name AS device_name, d.device_type, d.lng, d.lat
               FROM perf_alarms a
               JOIN perf_devices d ON a.sensor_id = d.id
               WHERE a.timestamp > NOW() - INTERVAL '7 days'
               ORDER BY a.timestamp DESC
               LIMIT 200""",
        )

        # NOTE(review): the two LEFT JOINs multiply readings by alarms per
        # device, so reading_count / alarm_count are inflated. Left unchanged
        # because it is the workload being measured - confirm this is intended.
        self.run_query(
            "JOIN 查询 - 三表关联统计",
            """SELECT d.area_id,
                      COUNT(DISTINCT d.id) AS device_count,
                      COUNT(s.id) AS reading_count,
                      COUNT(a.id) AS alarm_count,
                      AVG(s.pressure) AS avg_pressure
               FROM perf_devices d
               LEFT JOIN perf_sensor_data s
                      ON s.device_id = d.device_id
                     AND s.timestamp > NOW() - INTERVAL '1 day'
               LEFT JOIN perf_alarms a
                      ON a.sensor_id = d.id
                     AND a.timestamp > NOW() - INTERVAL '1 day'
               GROUP BY d.area_id
               ORDER BY reading_count DESC""",
        )

        # 4. GIS spatial queries
        print("\n --- GIS 空间查询 ---")

        if self._postgis_available():
            self.run_query(
                "GIS 查询 - 圆形范围内设备 (1km)",
                """SELECT d.id, d.device_id, d.name, d.device_type,
                          ST_Distance(
                              ST_SetSRID(ST_MakePoint(d.lng, d.lat), 4326)::geography,
                              ST_SetSRID(ST_MakePoint(%s, %s), 4326)::geography
                          ) AS distance_m
                   FROM perf_devices d
                   WHERE ST_DWithin(
                       ST_SetSRID(ST_MakePoint(d.lng, d.lat), 4326)::geography,
                       ST_SetSRID(ST_MakePoint(%s, %s), 4326)::geography,
                       1000
                   )
                   ORDER BY distance_m""",
                (116.4, 39.9, 116.4, 39.9),
            )
        else:
            # Without PostGIS, approximate the distance from the raw
            # coordinate differences.
            self.run_query(
                "空间查询 - 距离近似计算 (无PostGIS)",
                """SELECT id, device_id, name, device_type,
                          SQRT(POW(lng - %s, 2) + POW(lat - %s, 2)) * 111000 AS approx_distance_m
                   FROM perf_devices
                   WHERE ABS(lng - %s) < 0.01 AND ABS(lat - %s) < 0.01
                   ORDER BY approx_distance_m
                   LIMIT 100""",
                (116.4, 39.9, 116.4, 39.9),
            )

        self.run_query(
            "空间查询 - 区域设备统计",
            """SELECT area_id,
                      COUNT(*) AS device_count,
                      AVG(lng) AS center_lng,
                      AVG(lat) AS center_lat,
                      COUNT(CASE WHEN status = 'active' THEN 1 END) AS active_count
               FROM perf_devices
               GROUP BY area_id
               ORDER BY device_count DESC""",
        )

        # 5. Complex queries
        print("\n --- 复杂查询 ---")

        self.run_query(
            "复杂查询 - 异常检测 (压力突变)",
            """WITH sensor_stats AS (
                   SELECT sensor_id,
                          AVG(pressure) AS avg_p,
                          STDDEV(pressure) AS std_p
                   FROM perf_sensor_data
                   WHERE timestamp > NOW() - INTERVAL '7 days'
                   GROUP BY sensor_id
                   HAVING COUNT(*) > 10
               )
               SELECT s.sensor_id, s.timestamp, s.pressure, ss.avg_p, ss.std_p,
                      ABS(s.pressure - ss.avg_p) / NULLIF(ss.std_p, 0) AS z_score
               FROM perf_sensor_data s
               JOIN sensor_stats ss ON s.sensor_id = ss.sensor_id
               WHERE ABS(s.pressure - ss.avg_p) > 3 * NULLIF(ss.std_p, 0)
                 AND s.timestamp > NOW() - INTERVAL '1 day'
               ORDER BY z_score DESC
               LIMIT 50""",
        )

        self.run_query(
            "复杂查询 - 窗口函数 (滑动平均)",
            """SELECT sensor_id, timestamp, pressure,
                      AVG(pressure) OVER (
                          PARTITION BY sensor_id
                          ORDER BY timestamp
                          ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING
                      ) AS moving_avg
               FROM perf_sensor_data
               WHERE sensor_id = %s
                 AND timestamp > NOW() - INTERVAL '1 day'
               ORDER BY timestamp
               LIMIT 500""",
            (random.randint(1, MAX_DEVICE_COUNT),),
        )

        return self.results


# ==================== Index comparison ====================

def run_index_comparison(db_pool):
    """Run the benchmark suite without and then with indexes and compare.

    The indexes are left in place when the function returns.

    Returns:
        A list of dicts with keys ``query``, ``no_index_ms`` and
        ``with_index_ms`` (0 when the query produced no timing).
    """
    print("\n" + "=" * 60)
    print("📊 索引性能对比测试")
    print("=" * 60)

    # Baseline: no indexes.
    print("\n--- 无索引状态 ---")
    drop_indexes(db_pool)
    results_no_idx = QueryBenchmark(db_pool).run_all_benchmarks()

    # With indexes; ANALYZE first so the planner statistics are current.
    print("\n--- 有索引状态 ---")
    create_indexes(db_pool)
    analyze_tables(db_pool)
    results_with_idx = QueryBenchmark(db_pool).run_all_benchmarks()

    print("\n" + "=" * 60)
    print("📊 索引性能对比结果")
    print("=" * 60)
    print(f"{'查询名称':<40} {'无索引(ms)':>12} {'有索引(ms)':>12} {'提升':>10}")
    print("-" * 74)

    comparison = []
    for no_idx, with_idx in zip(results_no_idx, results_with_idx):
        no_avg = no_idx.get("avg_ms", 0)
        with_avg = with_idx.get("avg_ms", 0)
        # A query that failed in either run has no timing; reporting it as a
        # 100% improvement would be misleading.
        if no_avg > 0 and "avg_ms" in with_idx:
            improvement = f"{(1 - with_avg / no_avg) * 100:.1f}%"
        else:
            improvement = "N/A"

        print(f"{no_idx['name']:<40} {no_avg:>12.2f} {with_avg:>12.2f} {improvement:>10}")
        comparison.append({
            "query": no_idx["name"],
            "no_index_ms": no_avg,
            "with_index_ms": with_avg,
        })

    return comparison


# ==================== Entry point ====================

def main():
    """Parse CLI arguments, run the requested benchmarks and report results.

    Returns:
        A dict with key ``query_benchmark`` and, when ``--index-compare`` is
        given, ``index_comparison``.
    """
    parser = argparse.ArgumentParser(description="数据库查询压力测试")
    parser.add_argument("--host", default=DEFAULT_HOST, help="数据库主机")
    parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="数据库端口")
    parser.add_argument("--db", default=DEFAULT_DB, help="数据库名")
    parser.add_argument("--user", default=DEFAULT_USER, help="数据库用户")
    parser.add_argument("--password", default=DEFAULT_PASSWORD, help="数据库密码")
    parser.add_argument("--generate-data", type=int, default=0,
                        help="生成指定条数的测试数据 (如 1000000)")
    parser.add_argument("--skip-data-gen", action="store_true", help="跳过数据生成")
    parser.add_argument("--index-compare", action="store_true", help="运行索引对比测试")
    parser.add_argument("--output", default=None, help="结果输出 JSON 文件路径")
    args = parser.parse_args()

    print(f"\n{'=' * 60}")
    print("🗄️ 数据库查询压力测试")
    print(f"{'=' * 60}")
    print(f"数据库: {args.host}:{args.port}/{args.db}")
    print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'=' * 60}")

    try:
        db_pool = DBConnectionPool(
            args.host, args.port, args.db, args.user, args.password
        )
    except Exception as e:
        print(f"❌ 数据库连接失败: {e}")
        sys.exit(1)

    all_results = {}

    try:
        if args.generate_data > 0 and not args.skip_data_gen:
            generate_test_data(db_pool, args.generate_data)
            create_indexes(db_pool)
            analyze_tables(db_pool)

        bench = QueryBenchmark(db_pool)
        all_results["query_benchmark"] = bench.run_all_benchmarks()

        if args.index_compare:
            all_results["index_comparison"] = run_index_comparison(db_pool)
    finally:
        db_pool.close()

    print(f"\n{'=' * 60}")
    print("📊 查询测试汇总")
    print(f"{'=' * 60}")
    for r in all_results.get("query_benchmark", []):
        avg = r.get("avg_ms", "ERR")
        print(f" {r['name']}: {avg}ms ({r.get('rows_returned', 0)} rows)")

    if args.output:
        # The results contain non-ASCII query names and ensure_ascii=False
        # writes them verbatim, so the encoding must not depend on the locale.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2, ensure_ascii=False, default=str)
        print(f"\n结果已保存到: {args.output}")

    print(f"\n{'=' * 60}")
    return all_results


if __name__ == "__main__":
    main()