Parcourir la source

feat: 添加安全模块 - RBAC/JWT/加密/审计 (#95)

- src/security/auth.py: RBAC 角色权限 (ADMIN/OPERATOR/VIEWER/DEVICE) + JWT 认证
- src/security/middleware.py: 安全中间件 (JWT认证/CORS/限流/CSRF/安全Headers)
- src/security/encryption.py: AES-256-GCM 数据加密 + 响应脱敏
- src/security/input_validator.py: SQL注入/XSS防护 + 文件上传验证
- src/security/audit.py: 操作审计日志 (事件记录/告警/查询/双写存储)
- src/security/config.py: 安全配置 (从环境变量读取)
- src/api/rest_api.py: 认证端点 + 审计查询 + 权限示例
- tests/test_security.py: 43 个单元测试全部通过
- docs/SECURITY.md: 安全架构文档
- requirements.txt: 安全依赖 (PyJWT/passlib/bcrypt/cryptography)
bot_dev2 il y a 3 jours
Parent
révision
81ca9d6e74

+ 309
- 0
docs/SECURITY.md Voir le fichier

1
+# 安全架构文档
2
+
3
+## 概述
4
+
5
+水务管理系统安全模块提供完整的身份认证、权限控制、数据加密和安全审计能力,保障系统在生产环境中的安全性。
6
+
7
+## 安全架构
8
+
9
+```
10
+┌─────────────────────────────────────────────────────────┐
11
+│                     客户端请求                            │
12
+└─────────────────┬───────────────────────────────────────┘
13
+                  │
14
+                  ▼
15
+┌─────────────────────────────────────────────────────────┐
16
+│  SecurityMiddleware(安全中间件层)                        │
17
+│  ├─ Rate Limiting (限流: 60 req/min/IP)                  │
18
+│  ├─ Security Headers (HSTS, CSP, X-Frame-Options...)    │
19
+│  ├─ JWT Authentication (Token 验证)                      │
20
+│  ├─ CORS Protection (跨域限制)                           │
21
+│  └─ Audit Logging (审计记录)                              │
22
+└─────────────────┬───────────────────────────────────────┘
23
+                  │
24
+                  ▼
25
+┌─────────────────────────────────────────────────────────┐
26
+│  路由层 + 权限装饰器                                      │
27
+│  ├─ @require_role(ADMIN, OPERATOR)                       │
28
+│  ├─ @require_permission("data:read")                     │
29
+│  └─ 输入验证 (SQL注入/XSS 防护)                           │
30
+└─────────────────┬───────────────────────────────────────┘
31
+                  │
32
+                  ▼
33
+┌─────────────────────────────────────────────────────────┐
34
+│  业务逻辑层                                              │
35
+│  ├─ 数据加密 (AES-256-GCM)                               │
36
+│  ├─ 响应脱敏 (手机号/身份证/邮箱)                          │
37
+│  └─ 审计事件触发                                         │
38
+└─────────────────────────────────────────────────────────┘
39
+```
40
+
41
+## JWT 认证流程
42
+
43
+### 登录流程
44
+
45
+```
46
+Client                    Server
47
+  │                         │
48
+  │  POST /auth/login       │
49
+  │  {username, password}   │
50
+  │ ───────────────────────>│
51
+  │                         │  1. 验证用户名/密码
52
+  │                         │  2. 检查账户状态
53
+  │                         │  3. 生成 JWT tokens
54
+  │  {access_token,         │
55
+  │   refresh_token}        │
56
+  │ <───────────────────────│
57
+  │                         │
58
+```
59
+
60
+### Token 使用
61
+
62
+```
63
+Client                    Server
64
+  │                         │
65
+  │  GET /api/resource      │
66
+  │  Authorization: Bearer  │
67
+  │  <access_token>         │
68
+  │ ───────────────────────>│
69
+  │                         │  1. 提取 Bearer token
70
+  │                         │  2. 验证签名和有效期
71
+  │                         │  3. 提取用户信息和角色
72
+  │                         │  4. 检查权限
73
+  │  200 OK                 │
74
+  │ <───────────────────────│
75
+```
76
+
77
+### Token 刷新
78
+
79
+```
80
+Client                    Server
81
+  │                         │
82
+  │  POST /auth/refresh     │
83
+  │  {refresh_token}        │
84
+  │ ───────────────────────>│
85
+  │                         │  1. 验证 refresh_token
86
+  │                         │  2. 生成新 access_token
87
+  │  {access_token}         │
88
+  │ <───────────────────────│
89
+```
90
+
91
+### Token 说明
92
+
93
+| 类型 | 有效期 | 用途 |
94
+|------|--------|------|
95
+| Access Token | 30 分钟 | API 请求认证 |
96
+| Refresh Token | 7 天 | 刷新 Access Token |
97
+
98
+## RBAC 角色说明
99
+
100
+### 角色定义
101
+
102
+| 角色 | 说明 | 典型用户 |
103
+|------|------|----------|
104
+| ADMIN | 系统管理员,拥有所有权限 | IT 管理员 |
105
+| OPERATOR | 操作员,可读写业务数据 | 值班操作员 |
106
+| VIEWER | 只读用户,只能查看数据 | 管理层查看 |
107
+| DEVICE | IoT 设备专用 | 传感器/网关 |
108
+
109
+### 权限矩阵
110
+
111
+| 权限 | ADMIN | OPERATOR | VIEWER | DEVICE |
112
+|------|-------|----------|--------|--------|
113
+| system:manage | ✅ | ❌ | ❌ | ❌ |
114
+| users:manage | ✅ | ❌ | ❌ | ❌ |
115
+| roles:manage | ✅ | ❌ | ❌ | ❌ |
116
+| devices:manage | ✅ | ✅ | ❌ | ❌ |
117
+| data:read | ✅ | ✅ | ✅ | ✅ |
118
+| data:write | ✅ | ✅ | ❌ | ✅ |
119
+| data:delete | ✅ | ❌ | ❌ | ❌ |
120
+| audit:read | ✅ | ❌ | ❌ | ❌ |
121
+| config:manage | ✅ | ❌ | ❌ | ❌ |
122
+| reports:read | ✅ | ✅ | ✅ | ❌ |
123
+| reports:export | ✅ | ✅ | ❌ | ❌ |
124
+
125
+## 加密方案
126
+
127
+### AES-256-GCM
128
+
129
+系统使用 AES-256-GCM 模式加密敏感数据:
130
+
131
+- **密钥**: 从环境变量 `AES_ENCRYPTION_KEY` 读取
132
+- **IV**: 每次加密自动生成 96-bit 随机 nonce
133
+- **认证**: GCM 模式自带完整性验证
134
+- **编码**: 密文以 Base64 格式存储
135
+
136
+### 适用场景
137
+
138
+- 数据库敏感字段加密(手机号、身份证号等)
139
+- 配置文件加密存储
140
+- API 响应数据脱敏前的安全存储
141
+
142
+### 密钥管理
143
+
144
+```bash
145
+# 生成 32 字节密钥
146
+python -c "import secrets; print(secrets.token_hex(16))"
147
+
148
+# 设置环境变量
149
+export AES_ENCRYPTION_KEY="your-hex-key-here"
150
+export JWT_SECRET_KEY="your-jwt-secret-here"
151
+```
152
+
153
+## API 安全最佳实践
154
+
155
+### 1. 始终使用 HTTPS
156
+
157
+生产环境必须使用 HTTPS,禁止明文 HTTP 传输敏感数据。
158
+
159
+### 2. 请求限流
160
+
161
+系统默认限制每 IP 每分钟 60 次请求,可通过环境变量调整:
162
+
163
+```bash
164
+export RATE_LIMIT_PER_MINUTE=100
165
+```
166
+
167
+### 3. 输入验证
168
+
169
+所有用户输入都经过以下检查:
170
+
171
+- **SQL 注入防护**: 检测常见注入模式,使用参数化查询
172
+- **XSS 防护**: HTML 转义所有输出内容
173
+- **文件上传验证**: 类型、大小、内容三重检查
174
+
175
+### 4. 安全 Headers
176
+
177
+系统自动添加以下安全响应头:
178
+
179
+| Header | 值 | 说明 |
180
+|--------|------|------|
181
+| X-Content-Type-Options | nosniff | 禁止 MIME 嗅探 |
182
+| X-Frame-Options | DENY | 禁止嵌套框架 |
183
+| X-XSS-Protection | 1; mode=block | XSS 过滤器 |
184
+| Strict-Transport-Security | max-age=31536000 | 强制 HTTPS |
185
+| Content-Security-Policy | default-src 'self' | CSP 策略 |
186
+| Referrer-Policy | strict-origin-when-cross-origin | 引用来源策略 |
187
+
188
+### 5. 响应脱敏
189
+
190
+敏感数据在返回客户端前自动脱敏:
191
+
192
+```python
193
+from src.security import mask_phone, mask_id_card, mask_email
194
+
195
+mask_phone("13812345678")     # → "138****5678"
196
+mask_id_card("110101199001011234")  # → "1101**********1234"
197
+mask_email("test@example.com")    # → "t***@example.com"
198
+```
199
+
200
+## 安全审计
201
+
202
+### 审计事件类型
203
+
204
+| 事件 | 说明 | 严重程度 |
205
+|------|------|----------|
206
+| login | 用户登录成功 | INFO |
207
+| login_failed | 登录失败 | WARNING |
208
+| logout | 用户登出 | INFO |
209
+| user_create | 创建用户 | INFO |
210
+| role_change | 角色变更 | WARNING |
211
+| permission_change | 权限变更 | WARNING |
212
+| data_delete | 数据删除 | WARNING |
213
+| config_change | 配置变更 | CRITICAL |
214
+| data_export | 数据导出 | INFO |
215
+
216
+### 告警规则
217
+
218
+| 规则 | 触发条件 | 严重程度 |
219
+|------|----------|----------|
220
+| 多次登录失败 | 15 分钟内 5 次失败 | CRITICAL |
221
+| 频繁权限变更 | 1 小时内 10 次变更 | WARNING |
222
+| 批量数据删除 | 10 分钟内 20 次删除 | CRITICAL |
223
+
224
+### 审计日志查询
225
+
226
+```bash
227
+# 查询最近审计日志
228
+curl -H "Authorization: Bearer <token>" \
229
+  http://localhost:8000/audit/logs
230
+
231
+# 按用户过滤
232
+curl -H "Authorization: Bearer <token>" \
233
+  "http://localhost:8000/audit/logs?user_id=admin"
234
+
235
+# 查看安全告警
236
+curl -H "Authorization: Bearer <token>" \
237
+  http://localhost:8000/audit/alerts
238
+```
239
+
240
+## 安全事件响应流程
241
+
242
+### 1. 检测阶段
243
+
244
+- 审计日志自动记录所有操作
245
+- 高危操作自动触发告警
246
+- 异常登录行为(多次失败、异地登录)标红
247
+
248
+### 2. 响应阶段
249
+
250
+1. **确认事件**: 查看审计日志确认事件详情
251
+2. **隔离影响**: 必要时禁用相关账户
252
+3. **修复漏洞**: 定位并修复安全问题
253
+4. **通知相关方**: 如涉及数据泄露,通知受影响用户
254
+
255
+### 3. 复盘阶段
256
+
257
+1. 记录事件处理过程
258
+2. 分析根本原因
259
+3. 更新安全策略
260
+4. 优化告警规则
261
+
262
+## 环境变量配置
263
+
264
+```bash
265
+# JWT 配置
266
+JWT_SECRET_KEY=your-secret-key-here
267
+JWT_ALGORITHM=HS256
268
+JWT_ACCESS_TOKEN_EXPIRE=30
269
+JWT_REFRESH_TOKEN_EXPIRE=7
270
+
271
+# 加密配置
272
+AES_ENCRYPTION_KEY=your-32-char-hex-key
273
+
274
+# 限流配置
275
+RATE_LIMIT_PER_MINUTE=60
276
+MAX_LOGIN_ATTEMPTS=5
277
+LOCKOUT_DURATION_MINUTES=15
278
+
279
+# CORS 配置
280
+CORS_ALLOWED_ORIGINS=https://yourdomain.com,https://admin.yourdomain.com
281
+
282
+# 审计配置
283
+AUDIT_ENABLED=true
284
+AUDIT_LOG_FILE=logs/audit.log
285
+AUDIT_DB_STORAGE=true
286
+AUDIT_ALERT_FAILED_LOGINS=5
287
+
288
+# 密码策略
289
+PASSWORD_MIN_LENGTH=8
290
+PASSWORD_REQUIRE_UPPERCASE=true
291
+PASSWORD_REQUIRE_LOWERCASE=true
292
+PASSWORD_REQUIRE_DIGIT=true
293
+PASSWORD_REQUIRE_SPECIAL=false
294
+```
295
+
296
+## 依赖版本
297
+
298
+| 包 | 版本 | 用途 |
299
+|-----|------|------|
300
+| PyJWT | ≥2.8.0 | JWT token 生成与验证 |
301
+| passlib | ≥1.7.4 | 密码哈希(bcrypt) |
302
+| bcrypt | ≥4.0.0 | 密码哈希算法 |
303
+| cryptography | ≥41.0.0 | AES-256-GCM 加密 |
304
+| FastAPI | ≥0.100.0 | Web 框架 |
305
+| starlette | ≥0.27.0 | ASGI 中间件基础 |
306
+
307
+---
308
+
309
+*最后更新: 2026-06-16*

+ 21
- 0
requirements.txt Voir le fichier

1
+# 核心依赖
2
+fastapi>=0.100.0
3
+uvicorn[standard]>=0.23.0
4
+python-multipart>=0.0.6
5
+
6
+# 安全模块依赖
7
+PyJWT>=2.8.0
8
+passlib[bcrypt]>=1.7.4
9
+bcrypt>=4.0.0
10
+cryptography>=41.0.0
11
+
12
+# 可选依赖
13
+slowapi>=0.1.9
14
+starlette>=0.27.0
15
+httpx>=0.24.0
16
+
17
+# 数据处理
18
+pandas>=2.0.0
19
+
20
+# WebSocket
21
+websockets>=11.0

+ 3
- 0
src/__init__.py Voir le fichier

1
+"""
2
+src 模块
3
+"""

+ 3
- 0
src/api/__init__.py Voir le fichier

1
+"""
2
+API 模块
3
+"""

+ 306
- 0
src/api/rest_api.py Voir le fichier

1
+"""
2
+REST API 路由(安全模块集成)
3
+包含认证端点、审计日志查询端点、权限示例
4
+"""
5
+import sys
6
+import os
7
+
8
+# 将项目根目录加入 Python 路径
9
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
10
+
11
+try:
12
+    from fastapi import FastAPI, HTTPException, Depends, Header, Request
13
+    from fastapi.responses import JSONResponse
14
+except ImportError:
15
+    FastAPI = None
16
+
17
+from src.security.auth import (
18
+    Role, hash_password, verify_password,
19
+    create_access_token, create_refresh_token,
20
+    verify_access_token, verify_refresh_token,
21
+    extract_token_from_header, validate_password_strength,
22
+    require_role,
23
+)
24
+from src.security.audit import AuditAction, AuditResult, AuditSeverity, get_audit_logger
25
+from src.security.middleware import SecurityMiddleware, cors_middleware, get_rate_limiter
26
+from src.security.encryption import mask_phone, mask_response
27
+from src.security.input_validator import sanitize_string, detect_sql_injection, detect_xss
28
+from src.security.config import security_config
29
+
30
+
31
+# 内存用户存储(演示用,生产环境应使用数据库)
32
+_users_db = {
33
+    "admin": {
34
+        "user_id": "admin",
35
+        "username": "admin",
36
+        "password_hash": hash_password("Admin@123") if True else "",
37
+        "roles": [Role.ADMIN.value],
38
+        "phone": "13800138000",
39
+    },
40
+    "operator": {
41
+        "user_id": "operator01",
42
+        "username": "operator",
43
+        "password_hash": hash_password("Operator@123") if True else "",
44
+        "roles": [Role.OPERATOR.value],
45
+        "phone": "13900139000",
46
+    },
47
+}
48
+
49
+
50
+def create_app() -> "FastAPI":
51
+    """创建并配置 FastAPI 应用"""
52
+    if FastAPI is None:
53
+        raise RuntimeError("FastAPI 未安装,请运行: pip install fastapi uvicorn")
54
+
55
+    app = FastAPI(
56
+        title="水务管理系统 API",
57
+        version="2.0.0",
58
+        description="集成安全模块的 REST API",
59
+    )
60
+
61
+    # 注册安全中间件
62
+    app.add_middleware(SecurityMiddleware, excluded_paths=[
63
+        "/auth/login", "/auth/refresh", "/health",
64
+        "/docs", "/openapi.json", "/redoc",
65
+    ])
66
+
67
+    # 注册 CORS 中间件
68
+    cors_middleware(app)
69
+
70
+    # ============ 认证端点 ============
71
+
72
+    @app.post("/auth/login")
73
+    async def login(request: Request):
74
+        """用户登录"""
75
+        try:
76
+            body = await request.json()
77
+        except Exception:
78
+            raise HTTPException(status_code=400, detail="请求体格式错误")
79
+
80
+        username = body.get("username", "")
81
+        password = body.get("password", "")
82
+
83
+        # 输入验证
84
+        username = sanitize_string(username, max_length=50)
85
+        if not username or not password:
86
+            raise HTTPException(status_code=400, detail="用户名和密码不能为空")
87
+
88
+        # SQL 注入检测
89
+        is_injection, msg = detect_sql_injection(username)
90
+        if is_injection:
91
+            raise HTTPException(status_code=400, detail="输入包含非法字符")
92
+
93
+        user = _users_db.get(username)
94
+        audit_logger = get_audit_logger()
95
+        client_ip = request.client.host if request.client else ""
96
+
97
+        if not user or not verify_password(password, user["password_hash"]):
98
+            audit_logger.log_action(
99
+                user_id=username,
100
+                action=AuditAction.LOGIN_FAILED.value,
101
+                ip_address=client_ip,
102
+                result=AuditResult.FAILURE,
103
+                severity=AuditSeverity.WARNING,
104
+                details={"reason": "用户名或密码错误"},
105
+            )
106
+            raise HTTPException(status_code=401, detail="用户名或密码错误")
107
+
108
+        # 生成 token
109
+        roles = [Role(r) for r in user["roles"]]
110
+        access_token = create_access_token(user["user_id"], roles)
111
+        refresh_token = create_refresh_token(user["user_id"])
112
+
113
+        audit_logger.log_action(
114
+            user_id=user["user_id"],
115
+            action=AuditAction.LOGIN.value,
116
+            ip_address=client_ip,
117
+            result=AuditResult.SUCCESS,
118
+        )
119
+
120
+        return {
121
+            "access_token": access_token,
122
+            "refresh_token": refresh_token,
123
+            "token_type": "bearer",
124
+            "user": {
125
+                "user_id": user["user_id"],
126
+                "username": user["username"],
127
+                "roles": user["roles"],
128
+            },
129
+        }
130
+
131
+    @app.post("/auth/refresh")
132
+    async def refresh_token(request: Request):
133
+        """刷新 access token"""
134
+        try:
135
+            body = await request.json()
136
+        except Exception:
137
+            raise HTTPException(status_code=400, detail="请求体格式错误")
138
+
139
+        refresh_token = body.get("refresh_token", "")
140
+        if not refresh_token:
141
+            raise HTTPException(status_code=400, detail="refresh_token 不能为空")
142
+
143
+        payload = verify_refresh_token(refresh_token)
144
+        if payload is None:
145
+            raise HTTPException(status_code=401, detail="refresh_token 无效或已过期")
146
+
147
+        user_id = payload.get("sub", "")
148
+        user = next((u for u in _users_db.values() if u["user_id"] == user_id), None)
149
+        if not user:
150
+            raise HTTPException(status_code=401, detail="用户不存在")
151
+
152
+        roles = [Role(r) for r in user["roles"]]
153
+        new_access_token = create_access_token(user_id, roles)
154
+
155
+        return {
156
+            "access_token": new_access_token,
157
+            "token_type": "bearer",
158
+        }
159
+
160
+    @app.post("/auth/logout")
161
+    async def logout(request: Request):
162
+        """用户登出"""
163
+        audit_logger = get_audit_logger()
164
+        user = getattr(request.scope, "user", {})
165
+        user_id = user.get("sub", "unknown") if isinstance(user, dict) else "unknown"
166
+        client_ip = request.client.host if request.client else ""
167
+
168
+        audit_logger.log_action(
169
+            user_id=user_id,
170
+            action=AuditAction.LOGOUT.value,
171
+            ip_address=client_ip,
172
+        )
173
+
174
+        return {"message": "已登出"}
175
+
176
+    # ============ 审计日志端点 ============
177
+
178
+    @app.get("/audit/logs")
179
+    async def get_audit_logs(
180
+        user_id: str = None,
181
+        action: str = None,
182
+        result: str = None,
183
+        limit: int = 100,
184
+        offset: int = 0,
185
+        current_user: dict = None,
186
+    ):
187
+        """查询审计日志(仅管理员)"""
188
+        # 权限检查
189
+        if not current_user:
190
+            raise HTTPException(status_code=401, detail="未认证")
191
+
192
+        user_roles = current_user.get("roles", [])
193
+        if Role.ADMIN.value not in user_roles:
194
+            raise HTTPException(status_code=403, detail="仅管理员可查询审计日志")
195
+
196
+        audit_logger = get_audit_logger()
197
+        logs = audit_logger.query(
198
+            user_id=user_id,
199
+            action=action,
200
+            result=result,
201
+            limit=limit,
202
+            offset=offset,
203
+        )
204
+        return {"logs": logs, "total": len(logs)}
205
+
206
+    @app.get("/audit/stats")
207
+    async def get_audit_stats(hours: int = 24, current_user: dict = None):
208
+        """获取审计统计(仅管理员)"""
209
+        if not current_user:
210
+            raise HTTPException(status_code=401, detail="未认证")
211
+
212
+        user_roles = current_user.get("roles", [])
213
+        if Role.ADMIN.value not in user_roles:
214
+            raise HTTPException(status_code=403, detail="仅管理员可查看统计")
215
+
216
+        audit_logger = get_audit_logger()
217
+        stats = audit_logger.get_stats(hours)
218
+        return stats
219
+
220
+    @app.get("/audit/alerts")
221
+    async def get_audit_alerts(limit: int = 50, current_user: dict = None):
222
+        """获取安全告警(仅管理员)"""
223
+        if not current_user:
224
+            raise HTTPException(status_code=401, detail="未认证")
225
+
226
+        user_roles = current_user.get("roles", [])
227
+        if Role.ADMIN.value not in user_roles:
228
+            raise HTTPException(status_code=403, detail="仅管理员可查看告警")
229
+
230
+        audit_logger = get_audit_logger()
231
+        alerts = audit_logger.get_alerts(limit)
232
+        return {"alerts": alerts}
233
+
234
+    # ============ 业务路由示例(带权限控制) ============
235
+
236
+    @app.get("/health")
237
+    async def health_check():
238
+        """健康检查(无需认证)"""
239
+        return {"status": "ok", "version": "2.0.0"}
240
+
241
+    @app.get("/api/devices")
242
+    async def list_devices(current_user: dict = None):
243
+        """设备列表(需要 data:read 权限)"""
244
+        if not current_user:
245
+            raise HTTPException(status_code=401, detail="未认证")
246
+        return {"devices": [], "message": "设备列表"}
247
+
248
+    @app.post("/api/devices")
249
+    async def create_device(request: Request, current_user: dict = None):
250
+        """创建设备(需要 OPERATOR 或 ADMIN 角色)"""
251
+        if not current_user:
252
+            raise HTTPException(status_code=401, detail="未认证")
253
+
254
+        user_roles = current_user.get("roles", [])
255
+        if Role.ADMIN.value not in user_roles and Role.OPERATOR.value not in user_roles:
256
+            raise HTTPException(status_code=403, detail="权限不足")
257
+
258
+        return {"message": "设备创建成功"}
259
+
260
+    @app.delete("/api/devices/{device_id}")
261
+    async def delete_device(device_id: str, current_user: dict = None):
262
+        """删除设备(仅管理员)"""
263
+        if not current_user:
264
+            raise HTTPException(status_code=401, detail="未认证")
265
+
266
+        user_roles = current_user.get("roles", [])
267
+        if Role.ADMIN.value not in user_roles:
268
+            raise HTTPException(status_code=403, detail="仅管理员可删除设备")
269
+
270
+        audit_logger = get_audit_logger()
271
+        audit_logger.log_action(
272
+            user_id=current_user.get("sub", ""),
273
+            action=AuditAction.DATA_DELETE.value,
274
+            resource=f"device:{device_id}",
275
+            severity=AuditSeverity.WARNING,
276
+        )
277
+
278
+        return {"message": f"设备 {device_id} 已删除"}
279
+
280
+    @app.get("/api/users/me")
281
+    async def get_current_user_info(current_user: dict = None):
282
+        """获取当前用户信息"""
283
+        if not current_user:
284
+            raise HTTPException(status_code=401, detail="未认证")
285
+
286
+        user_id = current_user.get("sub", "")
287
+        user = next((u for u in _users_db.values() if u["user_id"] == user_id), None)
288
+        if not user:
289
+            raise HTTPException(status_code=404, detail="用户不存在")
290
+
291
+        # 脱敏手机号
292
+        return {
293
+            "user_id": user["user_id"],
294
+            "username": user["username"],
295
+            "roles": user["roles"],
296
+            "phone": mask_phone(user.get("phone", "")),
297
+        }
298
+
299
+    return app
300
+
301
+
302
+# 导出 app 实例(供 uvicorn 使用)
303
+try:
304
+    app = create_app()
305
+except RuntimeError:
306
+    app = None

+ 137
- 0
src/security/__init__.py Voir le fichier

1
+"""
2
+安全模块 (src/security)
3
+统一导出所有安全组件
4
+
5
+包含:
6
+- auth: RBAC 角色权限 + JWT 认证
7
+- middleware: 安全中间件(JWT 认证、CORS、限流、CSRF、安全 Headers)
8
+- encryption: 数据加密(AES-256-GCM)+ 响应脱敏
9
+- input_validator: 输入验证 + 注入防护
10
+- audit: 操作审计日志
11
+- config: 安全配置
12
+"""
13
+
14
+# 认证与权限
15
+from .auth import (
16
+    Role,
17
+    ROLE_PERMISSIONS,
18
+    hash_password,
19
+    verify_password,
20
+    create_access_token,
21
+    create_refresh_token,
22
+    decode_token,
23
+    verify_access_token,
24
+    verify_refresh_token,
25
+    get_user_roles,
26
+    check_permission,
27
+    require_role,
28
+    require_permission,
29
+    extract_token_from_header,
30
+    validate_password_strength,
31
+)
32
+
33
+# 中间件
34
+from .middleware import (
35
+    SecurityMiddleware,
36
+    RateLimiter,
37
+    get_rate_limiter,
38
+    generate_csrf_token,
39
+    validate_csrf_token,
40
+    get_client_ip,
41
+    cors_middleware,
42
+)
43
+
44
+# 加密
45
+from .encryption import (
46
+    FieldEncryptor,
47
+    get_encryptor,
48
+    encrypt_field,
49
+    decrypt_field,
50
+    encrypt_dict,
51
+    decrypt_dict,
52
+    encrypt_config_file,
53
+    decrypt_config_file,
54
+    mask_phone,
55
+    mask_id_card,
56
+    mask_email,
57
+    mask_bank_card,
58
+    mask_name,
59
+    mask_ip,
60
+    mask_response,
61
+)
62
+
63
+# 输入验证
64
+from .input_validator import (
65
+    detect_sql_injection,
66
+    sanitize_sql_value,
67
+    detect_xss,
68
+    escape_html,
69
+    sanitize_html,
70
+    FileUploadValidator,
71
+    file_validator,
72
+    validate_file_upload,
73
+    sanitize_string,
74
+    sanitize_number,
75
+    validate_email_format,
76
+    validate_phone_format,
77
+    validate_id_card_format,
78
+)
79
+
80
+# 审计
81
+from .audit import (
82
+    AuditAction,
83
+    AuditResult,
84
+    AuditSeverity,
85
+    AuditEvent,
86
+    AuditLogger,
87
+    AuditStorage,
88
+    AuditAlertManager,
89
+    get_audit_logger,
90
+)
91
+
92
+# 配置
93
+from .config import (
94
+    security_config,
95
+    SecurityConfig,
96
+    JWTConfig,
97
+    EncryptionConfig,
98
+    RateLimitConfig,
99
+    CORSConfig,
100
+    PasswordPolicyConfig,
101
+    AuditConfig,
102
+)
103
+
104
+__all__ = [
105
+    # Auth
106
+    "Role", "ROLE_PERMISSIONS",
107
+    "hash_password", "verify_password",
108
+    "create_access_token", "create_refresh_token",
109
+    "decode_token", "verify_access_token", "verify_refresh_token",
110
+    "get_user_roles", "check_permission",
111
+    "require_role", "require_permission",
112
+    "extract_token_from_header", "validate_password_strength",
113
+    # Middleware
114
+    "SecurityMiddleware", "RateLimiter", "get_rate_limiter",
115
+    "generate_csrf_token", "validate_csrf_token",
116
+    "get_client_ip", "cors_middleware",
117
+    # Encryption
118
+    "FieldEncryptor", "get_encryptor",
119
+    "encrypt_field", "decrypt_field",
120
+    "encrypt_dict", "decrypt_dict",
121
+    "encrypt_config_file", "decrypt_config_file",
122
+    "mask_phone", "mask_id_card", "mask_email",
123
+    "mask_bank_card", "mask_name", "mask_ip", "mask_response",
124
+    # Input Validator
125
+    "detect_sql_injection", "sanitize_sql_value",
126
+    "detect_xss", "escape_html", "sanitize_html",
127
+    "FileUploadValidator", "file_validator", "validate_file_upload",
128
+    "sanitize_string", "sanitize_number",
129
+    "validate_email_format", "validate_phone_format", "validate_id_card_format",
130
+    # Audit
131
+    "AuditAction", "AuditResult", "AuditSeverity", "AuditEvent",
132
+    "AuditLogger", "AuditStorage", "AuditAlertManager", "get_audit_logger",
133
+    # Config
134
+    "security_config", "SecurityConfig",
135
+    "JWTConfig", "EncryptionConfig", "RateLimitConfig",
136
+    "CORSConfig", "PasswordPolicyConfig", "AuditConfig",
137
+]

+ 367
- 0
src/security/audit.py Voir le fichier

1
+"""
2
+操作审计日志模块
3
+审计事件模型、记录器、查询 API、高危告警、双写存储
4
+"""
5
+import os
6
+import json
7
+import time
8
+import logging
9
+from enum import Enum
10
+from typing import Optional, List, Dict, Any
11
+from datetime import datetime, timezone, timedelta
12
+from dataclasses import dataclass, field, asdict
13
+from collections import defaultdict
14
+
15
+from .config import security_config
16
+
17
+logger = logging.getLogger(__name__)
18
+
19
+
20
+class AuditAction(str, Enum):
21
+    """审计操作类型"""
22
+    LOGIN = "login"
23
+    LOGOUT = "logout"
24
+    LOGIN_FAILED = "login_failed"
25
+    USER_CREATE = "user_create"
26
+    USER_UPDATE = "user_update"
27
+    USER_DELETE = "user_delete"
28
+    ROLE_CHANGE = "role_change"
29
+    PERMISSION_CHANGE = "permission_change"
30
+    PASSWORD_CHANGE = "password_change"
31
+    PASSWORD_RESET = "password_reset"
32
+    DATA_READ = "data_read"
33
+    DATA_WRITE = "data_write"
34
+    DATA_DELETE = "data_delete"
35
+    DATA_EXPORT = "data_export"
36
+    CONFIG_CHANGE = "config_change"
37
+    SYSTEM_RESTART = "system_restart"
38
+    DEVICE_REGISTER = "device_register"
39
+    DEVICE_DELETE = "device_delete"
40
+    API_ACCESS = "api_access"
41
+    FILE_UPLOAD = "file_upload"
42
+    FILE_DOWNLOAD = "file_download"
43
+
44
+
45
+class AuditResult(str, Enum):
46
+    """审计结果"""
47
+    SUCCESS = "success"
48
+    FAILURE = "failure"
49
+    DENIED = "denied"
50
+    ERROR = "error"
51
+
52
+
53
+class AuditSeverity(str, Enum):
54
+    """审计事件严重程度"""
55
+    INFO = "info"
56
+    WARNING = "warning"
57
+    CRITICAL = "critical"
58
+
59
+
60
+@dataclass
61
+class AuditEvent:
62
+    """审计事件模型"""
63
+    timestamp: datetime
64
+    user_id: str
65
+    action: str
66
+    resource: str = ""
67
+    ip_address: str = ""
68
+    user_agent: str = ""
69
+    result: AuditResult = AuditResult.SUCCESS
70
+    severity: AuditSeverity = AuditSeverity.INFO
71
+    details: Dict[str, Any] = field(default_factory=dict)
72
+    event_id: str = ""
73
+
74
+    def __post_init__(self):
75
+        if not self.event_id:
76
+            import secrets
77
+            self.event_id = secrets.token_hex(16)
78
+        if isinstance(self.result, str):
79
+            self.result = AuditResult(self.result)
80
+        if isinstance(self.severity, str):
81
+            self.severity = AuditSeverity(self.severity)
82
+
83
+    def to_dict(self) -> dict:
84
+        """转换为字典"""
85
+        d = asdict(self)
86
+        d["timestamp"] = self.timestamp.isoformat()
87
+        d["result"] = self.result.value
88
+        d["severity"] = self.severity.value
89
+        return d
90
+
91
+    def to_json(self) -> str:
92
+        """转换为 JSON 字符串"""
93
+        return json.dumps(self.to_dict(), ensure_ascii=False)
94
+
95
+
96
+class AlertRule:
97
+    """告警规则"""
98
+
99
+    def __init__(
100
+        self,
101
+        name: str,
102
+        action: str,
103
+        max_count: int,
104
+        window_seconds: int,
105
+        severity: AuditSeverity = AuditSeverity.WARNING,
106
+    ):
107
+        self.name = name
108
+        self.action = action
109
+        self.max_count = max_count
110
+        self.window_seconds = window_seconds
111
+        self.severity = severity
112
+
113
+
114
+class AuditAlertManager:
115
+    """审计告警管理器"""
116
+
117
+    def __init__(self):
118
+        self._event_counts: Dict[str, List[float]] = defaultdict(list)
119
+        self.rules: List[AlertRule] = [
120
+            AlertRule(
121
+                name="多次登录失败",
122
+                action=AuditAction.LOGIN_FAILED.value,
123
+                max_count=security_config.audit.alert_on_failed_logins,
124
+                window_seconds=900,  # 15 分钟
125
+                severity=AuditSeverity.CRITICAL,
126
+            ),
127
+            AlertRule(
128
+                name="频繁权限变更",
129
+                action=AuditAction.PERMISSION_CHANGE.value,
130
+                max_count=10,
131
+                window_seconds=3600,  # 1 小时
132
+                severity=AuditSeverity.WARNING,
133
+            ),
134
+            AlertRule(
135
+                name="批量数据删除",
136
+                action=AuditAction.DATA_DELETE.value,
137
+                max_count=20,
138
+                window_seconds=600,  # 10 分钟
139
+                severity=AuditSeverity.CRITICAL,
140
+            ),
141
+        ]
142
+        self._alerts: List[Dict[str, Any]] = []
143
+
144
+    def check_alert(self, event: AuditEvent) -> Optional[Dict[str, Any]]:
145
+        """检查事件是否触发告警"""
146
+        now = time.time()
147
+
148
+        for rule in self.rules:
149
+            if event.action == rule.action:
150
+                key = f"{rule.name}:{event.user_id}"
151
+                cutoff = now - rule.window_seconds
152
+
153
+                # 清理过期记录
154
+                self._event_counts[key] = [
155
+                    t for t in self._event_counts[key] if t > cutoff
156
+                ]
157
+                self._event_counts[key].append(now)
158
+
159
+                if len(self._event_counts[key]) >= rule.max_count:
160
+                    alert = {
161
+                        "rule_name": rule.name,
162
+                        "user_id": event.user_id,
163
+                        "action": event.action,
164
+                        "count": len(self._event_counts[key]),
165
+                        "window_seconds": rule.window_seconds,
166
+                        "severity": rule.severity.value,
167
+                        "timestamp": datetime.now(timezone.utc).isoformat(),
168
+                        "ip_address": event.ip_address,
169
+                    }
170
+                    self._alerts.append(alert)
171
+                    logger.warning(f"审计告警: {json.dumps(alert, ensure_ascii=False)}")
172
+                    return alert
173
+
174
+        return None
175
+
176
+    def get_alerts(self, limit: int = 50) -> List[Dict[str, Any]]:
177
+        """获取最近的告警列表"""
178
+        return self._alerts[-limit:]
179
+
180
+
181
+class AuditStorage:
182
+    """审计日志存储(文件 + 内存双写)"""
183
+
184
+    def __init__(self, log_file_path: Optional[str] = None):
185
+        self.log_file_path = log_file_path or security_config.audit.log_file_path
186
+        self._memory_store: List[Dict[str, Any]] = []
187
+        self._max_memory = 10000  # 内存中最多保留 10000 条
188
+
189
+        # 确保日志目录存在
190
+        log_dir = os.path.dirname(self.log_file_path)
191
+        if log_dir:
192
+            os.makedirs(log_dir, exist_ok=True)
193
+
194
+    def write(self, event: AuditEvent) -> None:
195
+        """写入审计日志"""
196
+        event_dict = event.to_dict()
197
+
198
+        # 写入内存
199
+        self._memory_store.append(event_dict)
200
+        if len(self._memory_store) > self._max_memory:
201
+            self._memory_store = self._memory_store[-self._max_memory:]
202
+
203
+        # 写入文件
204
+        try:
205
+            with open(self.log_file_path, "a", encoding="utf-8") as f:
206
+                f.write(json.dumps(event_dict, ensure_ascii=False) + "\n")
207
+        except Exception as e:
208
+            logger.error(f"审计日志写入文件失败: {e}")
209
+
210
+    def query(
211
+        self,
212
+        user_id: Optional[str] = None,
213
+        action: Optional[str] = None,
214
+        result: Optional[str] = None,
215
+        start_time: Optional[datetime] = None,
216
+        end_time: Optional[datetime] = None,
217
+        severity: Optional[str] = None,
218
+        limit: int = 100,
219
+        offset: int = 0,
220
+    ) -> List[Dict[str, Any]]:
221
+        """
222
+        从内存存储查询审计日志
223
+
224
+        Args:
225
+            user_id: 按用户 ID 过滤
226
+            action: 按操作类型过滤
227
+            result: 按结果过滤
228
+            start_time: 开始时间
229
+            end_time: 结束时间
230
+            severity: 按严重程度过滤
231
+            limit: 返回条数限制
232
+            offset: 偏移量
233
+
234
+        Returns:
235
+            审计日志列表
236
+        """
237
+        results = self._memory_store.copy()
238
+
239
+        if user_id:
240
+            results = [r for r in results if r.get("user_id") == user_id]
241
+        if action:
242
+            results = [r for r in results if r.get("action") == action]
243
+        if result:
244
+            results = [r for r in results if r.get("result") == result]
245
+        if severity:
246
+            results = [r for r in results if r.get("severity") == severity]
247
+        if start_time:
248
+            start_iso = start_time.isoformat()
249
+            results = [r for r in results if r.get("timestamp", "") >= start_iso]
250
+        if end_time:
251
+            end_iso = end_time.isoformat()
252
+            results = [r for r in results if r.get("timestamp", "") <= end_iso]
253
+
254
+        # 按时间倒序
255
+        results.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
256
+
257
+        return results[offset:offset + limit]
258
+
259
+    def get_stats(self, hours: int = 24) -> Dict[str, Any]:
260
+        """获取统计信息"""
261
+        cutoff = (datetime.now(timezone.utc) - timedelta(hours=hours)).isoformat()
262
+        recent = [r for r in self._memory_store if r.get("timestamp", "") >= cutoff]
263
+
264
+        action_counts = defaultdict(int)
265
+        result_counts = defaultdict(int)
266
+        user_counts = defaultdict(int)
267
+
268
+        for r in recent:
269
+            action_counts[r.get("action", "unknown")] += 1
270
+            result_counts[r.get("result", "unknown")] += 1
271
+            user_counts[r.get("user_id", "unknown")] += 1
272
+
273
+        return {
274
+            "period_hours": hours,
275
+            "total_events": len(recent),
276
+            "action_counts": dict(action_counts),
277
+            "result_counts": dict(result_counts),
278
+            "top_users": dict(sorted(user_counts.items(), key=lambda x: x[1], reverse=True)[:10]),
279
+        }
280
+
281
+
282
+class AuditLogger:
283
+    """审计日志记录器(主入口)"""
284
+
285
+    def __init__(self):
286
+        self.storage = AuditStorage()
287
+        self.alert_manager = AuditAlertManager()
288
+        self._enabled = security_config.audit.enabled
289
+
290
+    def log(self, event: AuditEvent) -> None:
291
+        """记录审计事件"""
292
+        if not self._enabled:
293
+            return
294
+
295
+        # 存储日志
296
+        self.storage.write(event)
297
+
298
+        # 检查告警
299
+        self.alert_manager.check_alert(event)
300
+
301
+    def log_action(
302
+        self,
303
+        user_id: str,
304
+        action: str,
305
+        resource: str = "",
306
+        ip_address: str = "",
307
+        user_agent: str = "",
308
+        result: AuditResult = AuditResult.SUCCESS,
309
+        severity: AuditSeverity = AuditSeverity.INFO,
310
+        details: Optional[Dict[str, Any]] = None,
311
+    ) -> None:
312
+        """便捷方法:记录审计操作"""
313
+        event = AuditEvent(
314
+            timestamp=datetime.now(timezone.utc),
315
+            user_id=user_id,
316
+            action=action,
317
+            resource=resource,
318
+            ip_address=ip_address,
319
+            user_agent=user_agent,
320
+            result=result,
321
+            severity=severity,
322
+            details=details or {},
323
+        )
324
+        self.log(event)
325
+
326
+    def query(
327
+        self,
328
+        user_id: Optional[str] = None,
329
+        action: Optional[str] = None,
330
+        result: Optional[str] = None,
331
+        start_time: Optional[datetime] = None,
332
+        end_time: Optional[datetime] = None,
333
+        severity: Optional[str] = None,
334
+        limit: int = 100,
335
+        offset: int = 0,
336
+    ) -> List[Dict[str, Any]]:
337
+        """查询审计日志"""
338
+        return self.storage.query(
339
+            user_id=user_id,
340
+            action=action,
341
+            result=result,
342
+            start_time=start_time,
343
+            end_time=end_time,
344
+            severity=severity,
345
+            limit=limit,
346
+            offset=offset,
347
+        )
348
+
349
+    def get_stats(self, hours: int = 24) -> Dict[str, Any]:
350
+        """获取审计统计"""
351
+        return self.storage.get_stats(hours)
352
+
353
+    def get_alerts(self, limit: int = 50) -> List[Dict[str, Any]]:
354
+        """获取告警列表"""
355
+        return self.alert_manager.get_alerts(limit)
356
+
357
+
358
+# 全局审计记录器实例
359
+_audit_logger: Optional[AuditLogger] = None
360
+
361
+
362
+def get_audit_logger() -> AuditLogger:
363
+    """获取全局审计记录器"""
364
+    global _audit_logger
365
+    if _audit_logger is None:
366
+        _audit_logger = AuditLogger()
367
+    return _audit_logger

+ 286
- 0
src/security/auth.py Voir le fichier

1
+"""
2
+RBAC 角色权限 + JWT 认证模块
3
+提供角色枚举、JWT token 生成/验证、权限装饰器、密码哈希
4
+"""
5
+import os
6
+import time
7
+import secrets
8
+from enum import Enum
9
+from typing import List, Optional, Set, Callable, Any
10
+from functools import wraps
11
+from datetime import datetime, timedelta, timezone
12
+
13
+try:
14
+    import jwt
15
+except ImportError:
16
+    jwt = None
17
+
18
+try:
19
+    from passlib.context import CryptContext
20
+except ImportError:
21
+    CryptContext = None
22
+
23
+from .config import security_config
24
+
25
+
26
+class Role(str, Enum):
27
+    """系统角色枚举"""
28
+    ADMIN = "admin"          # 管理员:完全权限
29
+    OPERATOR = "operator"    # 操作员:可读写,不能管理系统
30
+    VIEWER = "viewer"        # 只读用户
31
+    DEVICE = "device"        # IoT 设备专用
32
+
33
+
34
+# 角色权限映射
35
+ROLE_PERMISSIONS = {
36
+    Role.ADMIN: {
37
+        "system:manage", "users:manage", "roles:manage",
38
+        "devices:manage", "data:read", "data:write", "data:delete",
39
+        "audit:read", "config:manage", "reports:read", "reports:export",
40
+    },
41
+    Role.OPERATOR: {
42
+        "devices:manage", "data:read", "data:write",
43
+        "reports:read", "reports:export",
44
+    },
45
+    Role.VIEWER: {
46
+        "data:read", "reports:read",
47
+    },
48
+    Role.DEVICE: {
49
+        "data:write", "data:read",
50
+    },
51
+}
52
+
53
+
54
+# 密码哈希上下文
55
+pwd_context = None
56
+if CryptContext:
57
+    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
58
+
59
+
60
+def hash_password(password: str) -> str:
61
+    """对密码进行哈希处理"""
62
+    if pwd_context is None:
63
+        raise RuntimeError("passlib 未安装,请运行: pip install passlib[bcrypt]")
64
+    return pwd_context.hash(password)
65
+
66
+
67
+def verify_password(plain_password: str, hashed_password: str) -> bool:
68
+    """验证明文密码与哈希密码"""
69
+    if pwd_context is None:
70
+        raise RuntimeError("passlib 未安装,请运行: pip install passlib[bcrypt]")
71
+    return pwd_context.verify(plain_password, hashed_password)
72
+
73
+
74
+def create_access_token(
75
+    subject: str,
76
+    roles: List[Role],
77
+    extra_claims: Optional[dict] = None,
78
+) -> str:
79
+    """
80
+    生成 JWT access token
81
+
82
+    Args:
83
+        subject: 用户标识(user_id 或 username)
84
+        roles: 用户角色列表
85
+        extra_claims: 额外的 claims
86
+
87
+    Returns:
88
+        JWT token 字符串
89
+    """
90
+    if jwt is None:
91
+        raise RuntimeError("PyJWT 未安装,请运行: pip install PyJWT")
92
+
93
+    cfg = security_config.jwt
94
+    now = datetime.now(timezone.utc)
95
+    expire = now + timedelta(minutes=cfg.access_token_expire_minutes)
96
+
97
+    payload = {
98
+        "sub": subject,
99
+        "roles": [r.value for r in roles],
100
+        "iat": int(now.timestamp()),
101
+        "exp": int(expire.timestamp()),
102
+        "iss": cfg.issuer,
103
+        "type": "access",
104
+        "jti": secrets.token_hex(16),
105
+    }
106
+    if extra_claims:
107
+        payload.update(extra_claims)
108
+
109
+    return jwt.encode(payload, cfg.secret_key, algorithm=cfg.algorithm)
110
+
111
+
112
+def create_refresh_token(subject: str) -> str:
113
+    """生成 JWT refresh token"""
114
+    if jwt is None:
115
+        raise RuntimeError("PyJWT 未安装,请运行: pip install PyJWT")
116
+
117
+    cfg = security_config.jwt
118
+    now = datetime.now(timezone.utc)
119
+    expire = now + timedelta(days=cfg.refresh_token_expire_days)
120
+
121
+    payload = {
122
+        "sub": subject,
123
+        "iat": int(now.timestamp()),
124
+        "exp": int(expire.timestamp()),
125
+        "iss": cfg.issuer,
126
+        "type": "refresh",
127
+        "jti": secrets.token_hex(16),
128
+    }
129
+    return jwt.encode(payload, cfg.secret_key, algorithm=cfg.algorithm)
130
+
131
+
132
+def decode_token(token: str) -> dict:
133
+    """
134
+    解码并验证 JWT token
135
+
136
+    Returns:
137
+        解码后的 payload 字典
138
+
139
+    Raises:
140
+        jwt.ExpiredSignatureError: token 已过期
141
+        jwt.InvalidTokenError: token 无效
142
+    """
143
+    if jwt is None:
144
+        raise RuntimeError("PyJWT 未安装,请运行: pip install PyJWT")
145
+
146
+    cfg = security_config.jwt
147
+    return jwt.decode(
148
+        token,
149
+        cfg.secret_key,
150
+        algorithms=[cfg.algorithm],
151
+        issuer=cfg.issuer,
152
+    )
153
+
154
+
155
+def verify_access_token(token: str) -> Optional[dict]:
156
+    """验证 access token,返回 payload 或 None"""
157
+    try:
158
+        payload = decode_token(token)
159
+        if payload.get("type") != "access":
160
+            return None
161
+        return payload
162
+    except Exception:
163
+        return None
164
+
165
+
166
+def verify_refresh_token(token: str) -> Optional[dict]:
167
+    """验证 refresh token,返回 payload 或 None"""
168
+    try:
169
+        payload = decode_token(token)
170
+        if payload.get("type") != "refresh":
171
+            return None
172
+        return payload
173
+    except Exception:
174
+        return None
175
+
176
+
177
+def get_user_roles(payload: dict) -> List[Role]:
178
+    """从 token payload 提取角色列表"""
179
+    roles_str = payload.get("roles", [])
180
+    roles = []
181
+    for r in roles_str:
182
+        try:
183
+            roles.append(Role(r))
184
+        except ValueError:
185
+            continue
186
+    return roles
187
+
188
+
189
+def check_permission(roles: List[Role], permission: str) -> bool:
190
+    """检查角色列表是否拥有指定权限"""
191
+    user_permissions: Set[str] = set()
192
+    for role in roles:
193
+        user_permissions.update(ROLE_PERMISSIONS.get(role, set()))
194
+    return permission in user_permissions
195
+
196
+
197
+def require_role(*allowed_roles: Role):
198
+    """
199
+    FastAPI 路由权限装饰器
200
+
201
+    用法:
202
+        @app.get("/admin/users")
203
+        @require_role(Role.ADMIN)
204
+        async def get_users(current_user: dict = Depends(get_current_user)):
205
+            ...
206
+    """
207
+    def decorator(func: Callable) -> Callable:
208
+        @wraps(func)
209
+        async def wrapper(*args, **kwargs):
210
+            # 从 kwargs 中提取 current_user
211
+            current_user = kwargs.get("current_user")
212
+            if current_user is None:
213
+                raise PermissionError("未提供用户信息")
214
+
215
+            user_roles = current_user.get("roles", [])
216
+            if not any(r in [role.value for role in allowed_roles] for r in user_roles):
217
+                raise PermissionError(
218
+                    f"权限不足,需要角色: {[r.value for r in allowed_roles]}"
219
+                )
220
+            return await func(*args, **kwargs)
221
+        return wrapper
222
+    return decorator
223
+
224
+
225
+def require_permission(*permissions: str):
226
+    """
227
+    FastAPI 路由权限装饰器(基于权限名称)
228
+
229
+    用法:
230
+        @app.get("/data")
231
+        @require_permission("data:read")
232
+        async def get_data(current_user: dict = Depends(get_current_user)):
233
+            ...
234
+    """
235
+    def decorator(func: Callable) -> Callable:
236
+        @wraps(func)
237
+        async def wrapper(*args, **kwargs):
238
+            current_user = kwargs.get("current_user")
239
+            if current_user is None:
240
+                raise PermissionError("未提供用户信息")
241
+
242
+            user_roles = get_user_roles(current_user)
243
+            for perm in permissions:
244
+                if not check_permission(user_roles, perm):
245
+                    raise PermissionError(f"权限不足,缺少权限: {perm}")
246
+            return await func(*args, **kwargs)
247
+        return wrapper
248
+    return decorator
249
+
250
+
251
+def extract_token_from_header(authorization: Optional[str]) -> Optional[str]:
252
+    """从 Authorization header 提取 Bearer token"""
253
+    if not authorization:
254
+        return None
255
+    parts = authorization.split()
256
+    if len(parts) != 2 or parts[0].lower() != "bearer":
257
+        return None
258
+    return parts[1]
259
+
260
+
261
+def validate_password_strength(password: str) -> tuple[bool, str]:
262
+    """
263
+    验证密码强度是否符合策略
264
+
265
+    Returns:
266
+        (是否通过, 错误信息)
267
+    """
268
+    policy = security_config.password_policy
269
+    errors = []
270
+
271
+    if len(password) < policy.min_length:
272
+        errors.append(f"密码长度至少 {policy.min_length} 位")
273
+    if policy.require_uppercase and not any(c.isupper() for c in password):
274
+        errors.append("密码需包含大写字母")
275
+    if policy.require_lowercase and not any(c.islower() for c in password):
276
+        errors.append("密码需包含小写字母")
277
+    if policy.require_digit and not any(c.isdigit() for c in password):
278
+        errors.append("密码需包含数字")
279
+    if policy.require_special:
280
+        special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
281
+        if not any(c in special_chars for c in password):
282
+            errors.append("密码需包含特殊字符")
283
+
284
+    if errors:
285
+        return False, "; ".join(errors)
286
+    return True, ""

+ 79
- 0
src/security/config.py Voir le fichier

1
+"""
2
+安全配置模块
3
+从环境变量读取所有安全相关配置,提供默认值
4
+"""
5
+import os
6
+from dataclasses import dataclass, field
7
+from typing import List
8
+
9
+
10
+@dataclass
11
+class JWTConfig:
12
+    """JWT 配置"""
13
+    secret_key: str = os.getenv("JWT_SECRET_KEY", "change-me-in-production-use-a-real-secret")
14
+    algorithm: str = os.getenv("JWT_ALGORITHM", "HS256")
15
+    access_token_expire_minutes: int = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRE", "30"))
16
+    refresh_token_expire_days: int = int(os.getenv("JWT_REFRESH_TOKEN_EXPIRE", "7"))
17
+    issuer: str = os.getenv("JWT_ISSUER", "water-management-system")
18
+
19
+
20
+@dataclass
21
+class EncryptionConfig:
22
+    """加密配置"""
23
+    aes_key: str = os.getenv("AES_ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef")
24
+    aes_iv: str = os.getenv("AES_IV", "")  # 留空则自动生成随机 IV
25
+
26
+
27
+@dataclass
28
+class RateLimitConfig:
29
+    """限流配置"""
30
+    max_requests_per_minute: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60"))
31
+    max_login_attempts: int = int(os.getenv("MAX_LOGIN_ATTEMPTS", "5"))
32
+    lockout_duration_minutes: int = int(os.getenv("LOCKOUT_DURATION_MINUTES", "15"))
33
+
34
+
35
+@dataclass
36
+class CORSConfig:
37
+    """CORS 配置"""
38
+    allowed_origins: List[str] = field(default_factory=lambda: [
39
+        origin.strip()
40
+        for origin in os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:8080").split(",")
41
+    ])
42
+    allowed_methods: List[str] = field(default_factory=lambda: ["GET", "POST", "PUT", "DELETE", "OPTIONS"])
43
+    allowed_headers: List[str] = field(default_factory=lambda: ["Authorization", "Content-Type", "X-CSRF-Token"])
44
+    allow_credentials: bool = True
45
+    max_age: int = 3600
46
+
47
+
48
+@dataclass
49
+class PasswordPolicyConfig:
50
+    """密码策略配置"""
51
+    min_length: int = int(os.getenv("PASSWORD_MIN_LENGTH", "8"))
52
+    require_uppercase: bool = os.getenv("PASSWORD_REQUIRE_UPPERCASE", "true").lower() == "true"
53
+    require_lowercase: bool = os.getenv("PASSWORD_REQUIRE_LOWERCASE", "true").lower() == "true"
54
+    require_digit: bool = os.getenv("PASSWORD_REQUIRE_DIGIT", "true").lower() == "true"
55
+    require_special: bool = os.getenv("PASSWORD_REQUIRE_SPECIAL", "false").lower() == "true"
56
+
57
+
58
+@dataclass
59
+class AuditConfig:
60
+    """审计日志配置"""
61
+    enabled: bool = os.getenv("AUDIT_ENABLED", "true").lower() == "true"
62
+    log_file_path: str = os.getenv("AUDIT_LOG_FILE", "logs/audit.log")
63
+    db_storage_enabled: bool = os.getenv("AUDIT_DB_STORAGE", "true").lower() == "true"
64
+    alert_on_failed_logins: int = int(os.getenv("AUDIT_ALERT_FAILED_LOGINS", "5"))
65
+
66
+
67
+@dataclass
68
+class SecurityConfig:
69
+    """统一安全配置"""
70
+    jwt: JWTConfig = field(default_factory=JWTConfig)
71
+    encryption: EncryptionConfig = field(default_factory=EncryptionConfig)
72
+    rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
73
+    cors: CORSConfig = field(default_factory=CORSConfig)
74
+    password_policy: PasswordPolicyConfig = field(default_factory=PasswordPolicyConfig)
75
+    audit: AuditConfig = field(default_factory=AuditConfig)
76
+
77
+
78
+# 全局配置实例
79
+security_config = SecurityConfig()

+ 270
- 0
src/security/encryption.py Voir le fichier

1
+"""
2
+数据加密模块
3
+AES-256-GCM 加密、字段级加密、配置加密、响应脱敏
4
+"""
5
+import os
6
+import re
7
+import base64
8
+import json
9
+from typing import Optional, Any, Dict
10
+
11
+try:
12
+    from cryptography.hazmat.primitives.ciphers.aead import AESGCM
13
+    from cryptography.hazmat.primitives import hashes
14
+    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
15
+except ImportError:
16
+    AESGCM = None
17
+
18
+from .config import security_config
19
+
20
+
21
+class FieldEncryptor:
22
+    """AES-256-GCM 字段加密器"""
23
+
24
+    def __init__(self, key: Optional[bytes] = None):
25
+        if AESGCM is None:
26
+            raise RuntimeError("cryptography 未安装,请运行: pip install cryptography")
27
+
28
+        if key is None:
29
+            key_hex = security_config.encryption.aes_key
30
+            if len(key_hex) == 32:
31
+                # 32 字节 hex = 16 bytes, 需要扩展到 32 bytes
32
+                key = self._derive_key(key_hex.encode(), b"water-mgmt-salt")
33
+            else:
34
+                key = key_hex.encode()[:32].ljust(32, b"\0")
35
+
36
+        self.aesgcm = AESGCM(key)
37
+
38
+    @staticmethod
39
+    def _derive_key(password: bytes, salt: bytes) -> bytes:
40
+        """使用 PBKDF2 派生 256 位密钥"""
41
+        kdf = PBKDF2HMAC(
42
+            algorithm=hashes.SHA256(),
43
+            length=32,
44
+            salt=salt,
45
+            iterations=100000,
46
+        )
47
+        return kdf.derive(password)
48
+
49
+    def encrypt(self, plaintext: str) -> str:
50
+        """
51
+        加密字符串
52
+
53
+        Args:
54
+            plaintext: 明文
55
+
56
+        Returns:
57
+            Base64 编码的密文(nonce + ciphertext)
58
+        """
59
+        nonce = os.urandom(12)  # 96-bit nonce for AES-GCM
60
+        ciphertext = self.aesgcm.encrypt(nonce, plaintext.encode("utf-8"), None)
61
+        # 拼接 nonce + ciphertext 并 base64 编码
62
+        combined = nonce + ciphertext
63
+        return base64.b64encode(combined).decode("ascii")
64
+
65
+    def decrypt(self, encrypted: str) -> str:
66
+        """
67
+        解密字符串
68
+
69
+        Args:
70
+            encrypted: Base64 编码的密文
71
+
72
+        Returns:
73
+            明文字符串
74
+        """
75
+        combined = base64.b64decode(encrypted)
76
+        nonce = combined[:12]
77
+        ciphertext = combined[12:]
78
+        plaintext = self.aesgcm.decrypt(nonce, ciphertext, None)
79
+        return plaintext.decode("utf-8")
80
+
81
+
82
+# 全局加密器实例(延迟初始化)
83
+_encryptor: Optional[FieldEncryptor] = None
84
+
85
+
86
+def get_encryptor() -> FieldEncryptor:
87
+    """获取全局加密器实例"""
88
+    global _encryptor
89
+    if _encryptor is None:
90
+        _encryptor = FieldEncryptor()
91
+    return _encryptor
92
+
93
+
94
+def encrypt_field(value: str) -> str:
95
+    """加密单个字段"""
96
+    if not value:
97
+        return value
98
+    return get_encryptor().encrypt(value)
99
+
100
+
101
+def decrypt_field(value: str) -> str:
102
+    """解密单个字段"""
103
+    if not value:
104
+        return value
105
+    try:
106
+        return get_encryptor().decrypt(value)
107
+    except Exception:
108
+        return value  # 解密失败返回原值(兼容未加密数据)
109
+
110
+
111
+def encrypt_dict(data: Dict[str, Any], fields: list) -> Dict[str, Any]:
112
+    """
113
+    加密字典中的指定字段
114
+
115
+    Args:
116
+        data: 原始字典
117
+        fields: 需要加密的字段列表
118
+
119
+    Returns:
120
+        加密后的字典
121
+    """
122
+    result = data.copy()
123
+    for field in fields:
124
+        if field in result and isinstance(result[field], str):
125
+            result[field] = encrypt_field(result[field])
126
+    return result
127
+
128
+
129
+def decrypt_dict(data: Dict[str, Any], fields: list) -> Dict[str, Any]:
130
+    """
131
+    解密字典中的指定字段
132
+
133
+    Args:
134
+        data: 加密后的字典
135
+        fields: 需要解密的字段列表
136
+
137
+    Returns:
138
+        解密后的字典
139
+    """
140
+    result = data.copy()
141
+    for field in fields:
142
+        if field in result and isinstance(result[field], str):
143
+            result[field] = decrypt_field(result[field])
144
+    return result
145
+
146
+
147
+def encrypt_config_file(input_path: str, output_path: str) -> None:
148
+    """
149
+    加密配置文件
150
+
151
+    Args:
152
+        input_path: 输入文件路径(JSON)
153
+        output_path: 输出文件路径(加密后)
154
+    """
155
+    with open(input_path, "r", encoding="utf-8") as f:
156
+        content = f.read()
157
+
158
+    encrypted = get_encryptor().encrypt(content)
159
+    with open(output_path, "w", encoding="utf-8") as f:
160
+        f.write(encrypted)
161
+
162
+
163
+def decrypt_config_file(input_path: str, output_path: Optional[str] = None) -> dict:
164
+    """
165
+    解密配置文件
166
+
167
+    Args:
168
+        input_path: 加密文件路径
169
+        output_path: 输出文件路径(可选)
170
+
171
+    Returns:
172
+        解密后的 JSON 字典
173
+    """
174
+    with open(input_path, "r", encoding="utf-8") as f:
175
+        encrypted = f.read()
176
+
177
+    decrypted = get_encryptor().decrypt(encrypted)
178
+    data = json.loads(decrypted)
179
+
180
+    if output_path:
181
+        with open(output_path, "w", encoding="utf-8") as f:
182
+            json.dump(data, f, indent=2, ensure_ascii=False)
183
+
184
+    return data
185
+
186
+
187
+# ============ 响应脱敏 ============
188
+
189
+def mask_phone(phone: str) -> str:
190
+    """手机号脱敏: 13812345678 -> 138****5678"""
191
+    if not phone or len(phone) < 7:
192
+        return phone
193
+    return phone[:3] + "****" + phone[-4:]
194
+
195
+
196
+def mask_id_card(id_card: str) -> str:
197
+    """身份证号脱敏: 110101199001011234 -> 1101**********1234"""
198
+    if not id_card or len(id_card) < 8:
199
+        return id_card
200
+    return id_card[:4] + "*" * (len(id_card) - 8) + id_card[-4:]
201
+
202
+
203
+def mask_email(email: str) -> str:
204
+    """邮箱脱敏: test@example.com -> t***@example.com"""
205
+    if not email or "@" not in email:
206
+        return email
207
+    local, domain = email.split("@", 1)
208
+    if len(local) <= 1:
209
+        masked_local = local + "***"
210
+    else:
211
+        masked_local = local[0] + "***"
212
+    return f"{masked_local}@{domain}"
213
+
214
+
215
+def mask_bank_card(card: str) -> str:
216
+    """银行卡号脱敏: 6225880137073520 -> 6225********3520"""
217
+    if not card or len(card) < 8:
218
+        return card
219
+    return card[:4] + "*" * (len(card) - 8) + card[-4:]
220
+
221
+
222
+def mask_name(name: str) -> str:
223
+    """姓名脱敏: 张三 -> 张*, 欧阳修 -> 欧阳*"""
224
+    if not name:
225
+        return name
226
+    if len(name) <= 1:
227
+        return name
228
+    return name[:-1] + "*"
229
+
230
+
231
+def mask_ip(ip: str) -> str:
232
+    """IP 脱敏: 192.168.1.100 -> 192.168.*.*"""
233
+    if not ip:
234
+        return ip
235
+    parts = ip.split(".")
236
+    if len(parts) == 4:
237
+        return f"{parts[0]}.{parts[1]}.*.*"
238
+    return ip
239
+
240
+
241
+# 脱敏函数映射
242
+MASK_FUNCTIONS = {
243
+    "phone": mask_phone,
244
+    "id_card": mask_id_card,
245
+    "email": mask_email,
246
+    "bank_card": mask_bank_card,
247
+    "name": mask_name,
248
+    "ip": mask_ip,
249
+}
250
+
251
+
252
+def mask_response(data: Dict[str, Any], mask_rules: Dict[str, str]) -> Dict[str, Any]:
253
+    """
254
+    对响应数据进行脱敏
255
+
256
+    Args:
257
+        data: 原始响应数据
258
+        mask_rules: 脱敏规则 {field_name: mask_type}
259
+                   mask_type: phone, id_card, email, bank_card, name, ip
260
+
261
+    Returns:
262
+        脱敏后的数据
263
+    """
264
+    result = data.copy()
265
+    for field, mask_type in mask_rules.items():
266
+        if field in result and result[field]:
267
+            mask_func = MASK_FUNCTIONS.get(mask_type)
268
+            if mask_func:
269
+                result[field] = mask_func(str(result[field]))
270
+    return result

+ 306
- 0
src/security/input_validator.py Voir le fichier

1
+"""
2
+输入验证 + 注入防护模块
3
+SQL 注入防护、XSS 防护、文件上传验证、通用输入清理
4
+"""
5
+import re
6
+import html
7
+import os
8
+import mimetypes
9
+from typing import Optional, Tuple, List, Dict, Any
10
+
11
+
12
+# ============ SQL 注入防护 ============
13
+
14
+# SQL 注入关键词模式
15
+SQL_INJECTION_PATTERNS = [
16
+    re.compile(r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER|CREATE|EXEC)\b.*\b(FROM|INTO|TABLE|WHERE|SET)\b)", re.IGNORECASE),
17
+    re.compile(r"(--|#|/\*|\*/|;)", re.IGNORECASE),
18
+    re.compile(r"(\bOR\b\s+\d+\s*=\s*\d+)", re.IGNORECASE),
19
+    re.compile(r"('\s*(OR|AND)\s+')", re.IGNORECASE),
20
+    re.compile(r"(WAITFOR\s+DELAY)", re.IGNORECASE),
21
+    re.compile(r"(BENCHMARK\s*\()", re.IGNORECASE),
22
+    re.compile(r"(SLEEP\s*\()", re.IGNORECASE),
23
+]
24
+
25
+
26
+def detect_sql_injection(value: str) -> Tuple[bool, str]:
27
+    """
28
+    检测字符串是否包含 SQL 注入模式
29
+
30
+    Returns:
31
+        (是否检测到注入, 描述信息)
32
+    """
33
+    if not isinstance(value, str):
34
+        return False, ""
35
+
36
+    for pattern in SQL_INJECTION_PATTERNS:
37
+        match = pattern.search(value)
38
+        if match:
39
+            return True, f"检测到疑似 SQL 注入: {match.group()}"
40
+
41
+    return False, ""
42
+
43
+
44
+def sanitize_sql_value(value: str) -> str:
45
+    """
46
+    清理可能包含 SQL 注入的输入值
47
+    注意:这是辅助手段,主要防护应使用参数化查询
48
+    """
49
+    if not isinstance(value, str):
50
+        return value
51
+
52
+    # 移除危险字符
53
+    dangerous_chars = ["'", '"', ";", "--", "#", "/*", "*/", "\\", "\x00"]
54
+    result = value
55
+    for char in dangerous_chars:
56
+        result = result.replace(char, "")
57
+
58
+    return result.strip()
59
+
60
+
61
+def parameterized_query(query_template: str, params: Dict[str, Any]) -> str:
62
+    """
63
+    生成参数化查询的安全说明
64
+    实际使用时应直接传入参数给数据库驱动
65
+
66
+    示例:
67
+        # 正确做法(使用 ORM 或数据库驱动的参数化)
68
+        cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
69
+
70
+        # 不要拼接 SQL
71
+        # cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
72
+    """
73
+    return query_template
74
+
75
+
76
+# ============ XSS 防护 ============
77
+
78
+# XSS 攻击模式
79
+XSS_PATTERNS = [
80
+    re.compile(r"<\s*script", re.IGNORECASE),
81
+    re.compile(r"javascript\s*:", re.IGNORECASE),
82
+    re.compile(r"on\w+\s*=", re.IGNORECASE),  # onclick, onerror, etc.
83
+    re.compile(r"<\s*iframe", re.IGNORECASE),
84
+    re.compile(r"<\s*object", re.IGNORECASE),
85
+    re.compile(r"<\s*embed", re.IGNORECASE),
86
+    re.compile(r"<\s*img[^>]+onerror", re.IGNORECASE),
87
+    re.compile(r"expression\s*\(", re.IGNORECASE),
88
+    re.compile(r"url\s*\(", re.IGNORECASE),
89
+    re.compile(r"data\s*:\s*text/html", re.IGNORECASE),
90
+]
91
+
92
+
93
+def detect_xss(value: str) -> Tuple[bool, str]:
94
+    """
95
+    检测字符串是否包含 XSS 攻击模式
96
+
97
+    Returns:
98
+        (是否检测到 XSS, 描述信息)
99
+    """
100
+    if not isinstance(value, str):
101
+        return False, ""
102
+
103
+    for pattern in XSS_PATTERNS:
104
+        match = pattern.search(value)
105
+        if match:
106
+            return True, f"检测到疑似 XSS 攻击: {match.group()}"
107
+
108
+    return False, ""
109
+
110
+
111
+def escape_html(value: str) -> str:
112
+    """HTML 转义,防止 XSS"""
113
+    if not isinstance(value, str):
114
+        return value
115
+    return html.escape(value, quote=True)
116
+
117
+
118
+def sanitize_html(value: str) -> str:
119
+    """
120
+    清理 HTML 内容,移除所有标签
121
+    用于需要纯文本的场景
122
+    """
123
+    if not isinstance(value, str):
124
+        return value
125
+
126
+    # 移除所有 HTML 标签
127
+    clean = re.sub(r"<[^>]+>", "", value)
128
+    # 转义剩余的特殊字符
129
+    return html.escape(clean, quote=True)
130
+
131
+
132
+# ============ 文件上传验证 ============
133
+
134
+# 允许的文件类型
135
+ALLOWED_MIME_TYPES = {
136
+    "image/jpeg": [".jpg", ".jpeg"],
137
+    "image/png": [".png"],
138
+    "image/gif": [".gif"],
139
+    "image/webp": [".webp"],
140
+    "application/pdf": [".pdf"],
141
+    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": [".xlsx"],
142
+    "application/vnd.ms-excel": [".xls"],
143
+    "text/csv": [".csv"],
144
+    "application/json": [".json"],
145
+    "text/plain": [".txt", ".log"],
146
+}
147
+
148
+# 最大文件大小 (10 MB)
149
+MAX_FILE_SIZE = 10 * 1024 * 1024
150
+
151
+# 危险文件扩展名
152
+DANGEROUS_EXTENSIONS = {
153
+    ".exe", ".bat", ".cmd", ".sh", ".bash", ".csh",
154
+    ".py", ".pyw", ".php", ".jsp", ".asp", ".aspx",
155
+    ".html", ".htm", ".js", ".vbs", ".ps1",
156
+}
157
+
158
+
159
+class FileUploadValidator:
160
+    """文件上传验证器"""
161
+
162
+    def __init__(
163
+        self,
164
+        allowed_types: Optional[Dict[str, List[str]]] = None,
165
+        max_size: int = MAX_FILE_SIZE,
166
+    ):
167
+        self.allowed_types = allowed_types or ALLOWED_MIME_TYPES
168
+        self.max_size = max_size
169
+
170
+    def validate(
171
+        self,
172
+        filename: str,
173
+        content_type: Optional[str],
174
+        size: int,
175
+        content: Optional[bytes] = None,
176
+    ) -> Tuple[bool, str]:
177
+        """
178
+        验证上传文件
179
+
180
+        Args:
181
+            filename: 文件名
182
+            content_type: MIME 类型
183
+            size: 文件大小(字节)
184
+            content: 文件内容(可选,用于内容检查)
185
+
186
+        Returns:
187
+            (是否通过, 错误信息)
188
+        """
189
+        errors = []
190
+
191
+        # 1. 检查文件大小
192
+        if size > self.max_size:
193
+            max_mb = self.max_size / (1024 * 1024)
194
+            errors.append(f"文件大小超过限制(最大 {max_mb:.1f} MB)")
195
+
196
+        # 2. 检查文件扩展名
197
+        _, ext = os.path.splitext(filename.lower())
198
+        if ext in DANGEROUS_EXTENSIONS:
199
+            errors.append(f"不允许上传 {ext} 类型的文件")
200
+
201
+        # 3. 检查 MIME 类型
202
+        if content_type and content_type not in self.allowed_types:
203
+            errors.append(f"不支持的文件类型: {content_type}")
204
+
205
+        # 4. 检查扩展名与 MIME 类型匹配
206
+        if content_type and content_type in self.allowed_types:
207
+            allowed_exts = self.allowed_types[content_type]
208
+            if ext and ext not in allowed_exts:
209
+                errors.append(f"文件扩展名 {ext} 与类型 {content_type} 不匹配")
210
+
211
+        # 5. 内容检查(如果提供了内容)
212
+        if content:
213
+            # 检查文件头(magic bytes)
214
+            if content[:4] == b"\x89PNG" and content_type != "image/png":
215
+                errors.append("文件内容与声明的类型不匹配")
216
+            elif content[:2] == b"\xff\xd8" and content_type != "image/jpeg":
217
+                errors.append("文件内容与声明的类型不匹配")
218
+            elif content[:4] == b"%PDF" and content_type != "application/pdf":
219
+                errors.append("文件内容与声明的类型不匹配")
220
+
221
+            # 检查是否包含可执行代码
222
+            if b"<?php" in content or b"<%" in content:
223
+                errors.append("文件包含可疑内容")
224
+
225
+        if errors:
226
+            return False, "; ".join(errors)
227
+        return True, ""
228
+
229
+
230
+# 全局验证器实例
231
+file_validator = FileUploadValidator()
232
+
233
+
234
+def validate_file_upload(
235
+    filename: str,
236
+    content_type: Optional[str],
237
+    size: int,
238
+    content: Optional[bytes] = None,
239
+) -> Tuple[bool, str]:
240
+    """便捷函数:验证文件上传"""
241
+    return file_validator.validate(filename, content_type, size, content)
242
+
243
+
244
+# ============ 通用输入清理 ============
245
+
246
+def sanitize_string(value: str, max_length: int = 0) -> str:
247
+    """
248
+    通用字符串清理
249
+
250
+    Args:
251
+        value: 输入字符串
252
+        max_length: 最大长度(0 表示不限制)
253
+
254
+    Returns:
255
+        清理后的字符串
256
+    """
257
+    if not isinstance(value, str):
258
+        return str(value) if value is not None else ""
259
+
260
+    # 移除 null 字节
261
+    result = value.replace("\x00", "")
262
+
263
+    # 移除控制字符(保留换行和制表符)
264
+    result = re.sub(r"[\x01-\x08\x0b\x0c\x0e-\x1f\x7f]", "", result)
265
+
266
+    # 限制长度
267
+    if max_length > 0:
268
+        result = result[:max_length]
269
+
270
+    return result.strip()
271
+
272
+
273
+def sanitize_number(value: Any, min_val: Optional[float] = None, max_val: Optional[float] = None) -> Optional[float]:
274
+    """
275
+    验证和清理数字输入
276
+
277
+    Returns:
278
+        清理后的数字,无效时返回 None
279
+    """
280
+    try:
281
+        num = float(value)
282
+        if min_val is not None and num < min_val:
283
+            return None
284
+        if max_val is not None and num > max_val:
285
+            return None
286
+        return num
287
+    except (TypeError, ValueError):
288
+        return None
289
+
290
+
291
+def validate_email_format(email: str) -> bool:
292
+    """验证邮箱格式"""
293
+    pattern = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
294
+    return bool(pattern.match(email))
295
+
296
+
297
+def validate_phone_format(phone: str) -> bool:
298
+    """验证中国手机号格式"""
299
+    pattern = re.compile(r"^1[3-9]\d{9}$")
300
+    return bool(pattern.match(phone))
301
+
302
+
303
+def validate_id_card_format(id_card: str) -> bool:
304
+    """验证身份证号格式(简单校验)"""
305
+    pattern = re.compile(r"^\d{17}[\dXx]$")
306
+    return bool(pattern.match(id_card))

+ 291
- 0
src/security/middleware.py Voir le fichier

1
+"""
2
+安全中间件模块
3
+JWT 认证、CORS、限流、审计日志、CSRF、安全 Headers
4
+"""
5
+import os
6
+import re
7
+import time
8
+import secrets
9
+import hashlib
10
+import logging
11
+from typing import Optional, Dict, List, Callable
12
+from collections import defaultdict
13
+from datetime import datetime, timezone
14
+
15
+from .config import security_config
16
+from .auth import verify_access_token, extract_token_from_header
17
+from .audit import AuditLogger, AuditEvent, AuditResult
18
+
19
+logger = logging.getLogger(__name__)
20
+
21
+
22
+class RateLimiter:
23
+    """基于内存的 IP 限流器(令牌桶算法)"""
24
+
25
+    def __init__(self, max_requests: int = 60, window_seconds: int = 60):
26
+        self.max_requests = max_requests
27
+        self.window_seconds = window_seconds
28
+        self._requests: Dict[str, List[float]] = defaultdict(list)
29
+
30
+    def is_allowed(self, ip: str) -> bool:
31
+        """检查 IP 是否超过限流"""
32
+        now = time.time()
33
+        cutoff = now - self.window_seconds
34
+
35
+        # 清理过期记录
36
+        self._requests[ip] = [
37
+            t for t in self._requests[ip] if t > cutoff
38
+        ]
39
+
40
+        if len(self._requests[ip]) >= self.max_requests:
41
+            return False
42
+
43
+        self._requests[ip].append(now)
44
+        return True
45
+
46
+    def get_remaining(self, ip: str) -> int:
47
+        """获取剩余请求数"""
48
+        now = time.time()
49
+        cutoff = now - self.window_seconds
50
+        self._requests[ip] = [
51
+            t for t in self._requests[ip] if t > cutoff
52
+        ]
53
+        return max(0, self.max_requests - len(self._requests[ip]))
54
+
55
+    def cleanup(self):
56
+        """清理所有过期记录"""
57
+        now = time.time()
58
+        cutoff = now - self.window_seconds
59
+        expired_ips = []
60
+        for ip, timestamps in self._requests.items():
61
+            self._requests[ip] = [t for t in timestamps if t > cutoff]
62
+            if not self._requests[ip]:
63
+                expired_ips.append(ip)
64
+        for ip in expired_ips:
65
+            del self._requests[ip]
66
+
67
+
68
+# 全局限流器实例
69
+_rate_limiter = RateLimiter(
70
+    max_requests=security_config.rate_limit.max_requests_per_minute,
71
+    window_seconds=60,
72
+)
73
+
74
+
75
+def get_rate_limiter() -> RateLimiter:
76
+    """获取全局限流器"""
77
+    return _rate_limiter
78
+
79
+
80
+def generate_csrf_token(session_id: str) -> str:
81
+    """生成 CSRF token"""
82
+    secret = security_config.jwt.secret_key
83
+    timestamp = str(int(time.time()))
84
+    nonce = secrets.token_hex(8)
85
+    raw = f"{session_id}:{timestamp}:{nonce}:{secret}"
86
+    token = hashlib.sha256(raw.encode()).hexdigest()
87
+    return f"{timestamp}:{nonce}:{token}"
88
+
89
+
90
+def validate_csrf_token(token: str, session_id: str) -> bool:
91
+    """验证 CSRF token"""
92
+    try:
93
+        parts = token.split(":")
94
+        if len(parts) != 3:
95
+            return False
96
+        timestamp, nonce, hash_value = parts
97
+
98
+        # 检查时间戳(1小时内有效)
99
+        if abs(time.time() - int(timestamp)) > 3600:
100
+            return False
101
+
102
+        secret = security_config.jwt.secret_key
103
+        raw = f"{session_id}:{timestamp}:{nonce}:{secret}"
104
+        expected = hashlib.sha256(raw.encode()).hexdigest()
105
+        return secrets.compare_digest(hash_value, expected)
106
+    except Exception:
107
+        return False
108
+
109
+
110
+def get_client_ip(headers: dict, remote_addr: str = "") -> str:
111
+    """从请求头中提取客户端真实 IP"""
112
+    forwarded_for = headers.get("x-forwarded-for", "")
113
+    if forwarded_for:
114
+        return forwarded_for.split(",")[0].strip()
115
+
116
+    real_ip = headers.get("x-real-ip", "")
117
+    if real_ip:
118
+        return real_ip
119
+
120
+    return remote_addr
121
+
122
+
123
+class SecurityMiddleware:
124
+    """
125
+    FastAPI 安全中间件
126
+
127
+    用法:
128
+        app.add_middleware(SecurityMiddleware)
129
+    """
130
+
131
+    def __init__(
132
+        self,
133
+        app,
134
+        excluded_paths: Optional[List[str]] = None,
135
+    ):
136
+        self.app = app
137
+        self.excluded_paths = excluded_paths or [
138
+            "/auth/login",
139
+            "/auth/refresh",
140
+            "/health",
141
+            "/docs",
142
+            "/openapi.json",
143
+            "/redoc",
144
+        ]
145
+        self.audit_logger = AuditLogger()
146
+
147
+    async def __call__(self, scope, receive, send):
148
+        if scope["type"] != "http":
149
+            await self.app(scope, receive, send)
150
+            return
151
+
152
+        path = scope.get("path", "")
153
+
154
+        # 跳过排除路径
155
+        if any(path.startswith(excluded) for excluded in self.excluded_paths):
156
+            await self.app(scope, receive, send)
157
+            return
158
+
159
+        headers = dict(scope.get("headers", []))
160
+        headers = {
161
+            k.decode() if isinstance(k, bytes) else k:
162
+            v.decode() if isinstance(v, bytes) else v
163
+            for k, v in headers.items()
164
+        }
165
+
166
+        method = scope.get("method", "GET")
167
+        client_ip = get_client_ip(headers, scope.get("client", ("", ""))[0])
168
+
169
+        # 1. Rate Limiting
170
+        rate_limiter = get_rate_limiter()
171
+        if not rate_limiter.is_allowed(client_ip):
172
+            response_body = '{"detail": "请求过于频繁,请稍后再试"}'.encode("utf-8")
173
+            await send({
174
+                "type": "http.response.start",
175
+                "status": 429,
176
+                "headers": [
177
+                    (b"content-type", b"application/json"),
178
+                    (b"retry-after", b"60"),
179
+                ],
180
+            })
181
+            await send({
182
+                "type": "http.response.body",
183
+                "body": response_body,
184
+            })
185
+            return
186
+
187
+        # 2. Security Headers
188
+        original_send = send
189
+
190
+        async def send_with_security_headers(message):
191
+            if message["type"] == "http.response.start":
192
+                extra_headers = [
193
+                    (b"x-content-type-options", b"nosniff"),
194
+                    (b"x-frame-options", b"DENY"),
195
+                    (b"x-xss-protection", b"1; mode=block"),
196
+                    (b"strict-transport-security", b"max-age=31536000; includeSubDomains"),
197
+                    (b"content-security-policy", b"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'"),
198
+                    (b"referrer-policy", b"strict-origin-when-cross-origin"),
199
+                    (b"permissions-policy", b"camera=(), microphone=(), geolocation=()"),
200
+                ]
201
+                message["headers"] = list(message.get("headers", [])) + extra_headers
202
+            await original_send(message)
203
+
204
+        # 3. JWT 认证(对需要认证的路径)
205
+        auth_required = not any(
206
+            path.startswith(p) for p in self.excluded_paths
207
+        )
208
+
209
+        if auth_required and method not in ("OPTIONS",):
210
+            auth_header = headers.get("authorization", "")
211
+            token = extract_token_from_header(auth_header)
212
+
213
+            if not token:
214
+                response_body = '{"detail": "未提供认证凭据"}'.encode("utf-8")
215
+                await send({
216
+                    "type": "http.response.start",
217
+                    "status": 401,
218
+                    "headers": [(b"content-type", b"application/json")],
219
+                })
220
+                await send({
221
+                    "type": "http.response.body",
222
+                    "body": response_body,
223
+                })
224
+                return
225
+
226
+            payload = verify_access_token(token)
227
+            if payload is None:
228
+                response_body = '{"detail": "认证凭据无效或已过期"}'.encode("utf-8")
229
+                await send({
230
+                    "type": "http.response.start",
231
+                    "status": 401,
232
+                    "headers": [(b"content-type", b"application/json")],
233
+                })
234
+                await send({
235
+                    "type": "http.response.body",
236
+                    "body": response_body,
237
+                })
238
+                return
239
+
240
+            # 将用户信息注入 scope
241
+            scope["user"] = payload
242
+
243
+        # 4. 审计日志
244
+        start_time = time.time()
245
+        audit_event = AuditEvent(
246
+            timestamp=datetime.now(timezone.utc),
247
+            user_id=scope.get("user", {}).get("sub", "anonymous"),
248
+            action=f"{method} {path}",
249
+            resource=path,
250
+            ip_address=client_ip,
251
+            user_agent=headers.get("user-agent", ""),
252
+            result=AuditResult.SUCCESS,
253
+        )
254
+
255
+        await self.app(scope, receive, send_with_security_headers)
256
+
257
+        # 记录审计日志
258
+        duration = time.time() - start_time
259
+        audit_event.details = {"duration_ms": round(duration * 1000, 2)}
260
+        self.audit_logger.log(audit_event)
261
+
262
+
263
+def cors_middleware(app, config=None):
264
+    """
265
+    CORS 中间件包装器
266
+
267
+    用法:
268
+        from starlette.middleware.cors import CORSMiddleware
269
+        app.add_middleware(
270
+            CORSMiddleware,
271
+            allow_origins=security_config.cors.allowed_origins,
272
+            allow_credentials=security_config.cors.allow_credentials,
273
+            allow_methods=security_config.cors.allowed_methods,
274
+            allow_headers=security_config.cors.allowed_headers,
275
+            max_age=security_config.cors.max_age,
276
+        )
277
+    """
278
+    cfg = config or security_config.cors
279
+    try:
280
+        from starlette.middleware.cors import CORSMiddleware
281
+        app.add_middleware(
282
+            CORSMiddleware,
283
+            allow_origins=cfg.allowed_origins,
284
+            allow_credentials=cfg.allow_credentials,
285
+            allow_methods=cfg.allowed_methods,
286
+            allow_headers=cfg.allowed_headers,
287
+            max_age=cfg.max_age,
288
+        )
289
+    except ImportError:
290
+        logger.warning("starlette 未安装,CORS 中间件未启用")
291
+    return app

+ 444
- 0
tests/test_security.py Voir le fichier

1
+"""
2
+安全模块单元测试
3
+覆盖 JWT、权限、加密、输入验证、审计日志
4
+"""
5
+import sys
6
+import os
7
+import json
8
+import time
9
+import unittest
10
+from datetime import datetime, timezone
11
+
12
+# 将项目根目录加入路径
13
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
14
+
15
+from src.security.auth import (
16
+    Role, hash_password, verify_password,
17
+    create_access_token, create_refresh_token,
18
+    verify_access_token, verify_refresh_token,
19
+    get_user_roles, check_permission,
20
+    extract_token_from_header, validate_password_strength,
21
+)
22
+from src.security.encryption import (
23
+    FieldEncryptor, encrypt_field, decrypt_field,
24
+    mask_phone, mask_id_card, mask_email,
25
+    mask_bank_card, mask_name, mask_ip,
26
+    mask_response,
27
+)
28
+from src.security.input_validator import (
29
+    detect_sql_injection, sanitize_sql_value,
30
+    detect_xss, escape_html, sanitize_html,
31
+    sanitize_string, sanitize_number,
32
+    validate_email_format, validate_phone_format,
33
+    validate_file_upload,
34
+)
35
+from src.security.audit import (
36
+    AuditAction, AuditResult, AuditSeverity, AuditEvent,
37
+    AuditLogger, AuditStorage, AuditAlertManager,
38
+)
39
+from src.security.middleware import (
40
+    RateLimiter, generate_csrf_token, validate_csrf_token,
41
+    get_client_ip,
42
+)
43
+from src.security.config import security_config
44
+
45
+
46
+class TestAuth(unittest.TestCase):
47
+    """认证模块测试"""
48
+
49
+    def test_password_hash(self):
50
+        """密码哈希与验证"""
51
+        password = "TestPassword123"
52
+        hashed = hash_password(password)
53
+        self.assertNotEqual(password, hashed)
54
+        self.assertTrue(verify_password(password, hashed))
55
+        self.assertFalse(verify_password("wrong", hashed))
56
+
57
+    def test_create_access_token(self):
58
+        """生成 access token"""
59
+        token = create_access_token("user1", [Role.ADMIN, Role.OPERATOR])
60
+        self.assertIsInstance(token, str)
61
+        self.assertTrue(len(token) > 50)
62
+
63
+    def test_verify_access_token(self):
64
+        """验证 access token"""
65
+        token = create_access_token("user1", [Role.ADMIN])
66
+        payload = verify_access_token(token)
67
+        self.assertIsNotNone(payload)
68
+        self.assertEqual(payload["sub"], "user1")
69
+        self.assertEqual(payload["type"], "access")
70
+        self.assertIn("admin", payload["roles"])
71
+
72
+    def test_expired_token(self):
73
+        """过期 token 验证"""
74
+        # 通过修改配置来创建过期 token
75
+        import jwt as pyjwt
76
+        from src.security.config import security_config as cfg
77
+        payload = {
78
+            "sub": "user1",
79
+            "roles": ["admin"],
80
+            "iat": int(time.time()) - 7200,
81
+            "exp": int(time.time()) - 3600,  # 已过期
82
+            "iss": cfg.jwt.issuer,
83
+            "type": "access",
84
+        }
85
+        token = pyjwt.encode(payload, cfg.jwt.secret_key, algorithm=cfg.jwt.algorithm)
86
+        result = verify_access_token(token)
87
+        self.assertIsNone(result)
88
+
89
+    def test_refresh_token(self):
90
+        """Refresh token 生成与验证"""
91
+        token = create_refresh_token("user1")
92
+        payload = verify_refresh_token(token)
93
+        self.assertIsNotNone(payload)
94
+        self.assertEqual(payload["sub"], "user1")
95
+        self.assertEqual(payload["type"], "refresh")
96
+
97
+    def test_token_type_mismatch(self):
98
+        """Token 类型不匹配"""
99
+        access_token = create_access_token("user1", [Role.ADMIN])
100
+        # 用 verify_refresh_token 验证 access token 应该失败
101
+        result = verify_refresh_token(access_token)
102
+        self.assertIsNone(result)
103
+
104
+    def test_get_user_roles(self):
105
+        """从 payload 提取角色"""
106
+        payload = {"roles": ["admin", "operator"]}
107
+        roles = get_user_roles(payload)
108
+        self.assertEqual(len(roles), 2)
109
+        self.assertIn(Role.ADMIN, roles)
110
+        self.assertIn(Role.OPERATOR, roles)
111
+
112
+    def test_check_permission(self):
113
+        """权限检查"""
114
+        self.assertTrue(check_permission([Role.ADMIN], "system:manage"))
115
+        self.assertTrue(check_permission([Role.OPERATOR], "data:write"))
116
+        self.assertFalse(check_permission([Role.VIEWER], "data:write"))
117
+        self.assertFalse(check_permission([Role.VIEWER], "system:manage"))
118
+
119
+    def test_extract_token(self):
120
+        """从 Authorization header 提取 token"""
121
+        self.assertEqual(
122
+            extract_token_from_header("Bearer abc123"),
123
+            "abc123"
124
+        )
125
+        self.assertIsNone(extract_token_from_header("Basic abc123"))
126
+        self.assertIsNone(extract_token_from_header(""))
127
+        self.assertIsNone(extract_token_from_header(None))
128
+
129
+    def test_password_strength(self):
130
+        """密码强度验证"""
131
+        ok, msg = validate_password_strength("Abcdef1!")
132
+        self.assertTrue(ok)
133
+
134
+        ok, msg = validate_password_strength("123")
135
+        self.assertFalse(ok)
136
+
137
+        ok, msg = validate_password_strength("abcdefgh")
138
+        self.assertFalse(ok)  # 缺少大写和数字
139
+
140
+
141
+class TestEncryption(unittest.TestCase):
142
+    """加密模块测试"""
143
+
144
+    def setUp(self):
145
+        self.encryptor = FieldEncryptor()
146
+
147
+    def test_encrypt_decrypt(self):
148
+        """加密和解密"""
149
+        plaintext = "敏感数据123"
150
+        encrypted = self.encryptor.encrypt(plaintext)
151
+        self.assertNotEqual(plaintext, encrypted)
152
+
153
+        decrypted = self.encryptor.decrypt(encrypted)
154
+        self.assertEqual(plaintext, decrypted)
155
+
156
+    def test_encrypt_empty_string(self):
157
+        """空字符串加密"""
158
+        result = encrypt_field("")
159
+        self.assertEqual(result, "")
160
+
161
+    def test_encrypt_field_roundtrip(self):
162
+        """字段加密解密往返"""
163
+        original = "13812345678"
164
+        encrypted = encrypt_field(original)
165
+        decrypted = decrypt_field(encrypted)
166
+        self.assertEqual(original, decrypted)
167
+
168
+    def test_mask_phone(self):
169
+        """手机号脱敏"""
170
+        self.assertEqual(mask_phone("13812345678"), "138****5678")
171
+        self.assertEqual(mask_phone("138"), "138")
172
+        self.assertEqual(mask_phone(""), "")
173
+
174
+    def test_mask_id_card(self):
175
+        """身份证脱敏"""
176
+        self.assertEqual(mask_id_card("110101199001011234"), "1101**********1234")
177
+        self.assertEqual(mask_id_card(""), "")
178
+
179
+    def test_mask_email(self):
180
+        """邮箱脱敏"""
181
+        self.assertEqual(mask_email("test@example.com"), "t***@example.com")
182
+        self.assertEqual(mask_email("a@b.com"), "a***@b.com")
183
+
184
+    def test_mask_bank_card(self):
185
+        """银行卡脱敏"""
186
+        self.assertEqual(mask_bank_card("6225880137073520"), "6225********3520")
187
+
188
+    def test_mask_name(self):
189
+        """姓名脱敏"""
190
+        self.assertEqual(mask_name("张三"), "张*")
191
+        self.assertEqual(mask_name("欧阳修"), "欧阳*")
192
+
193
+    def test_mask_ip(self):
194
+        """IP 脱敏"""
195
+        self.assertEqual(mask_ip("192.168.1.100"), "192.168.*.*")
196
+
197
+    def test_mask_response(self):
198
+        """响应脱敏"""
199
+        data = {"phone": "13812345678", "name": "张三", "email": "test@test.com"}
200
+        rules = {"phone": "phone", "name": "name", "email": "email"}
201
+        result = mask_response(data, rules)
202
+        self.assertEqual(result["phone"], "138****5678")
203
+        self.assertEqual(result["name"], "张*")
204
+        self.assertIn("***", result["email"])
205
+
206
+
207
+class TestInputValidator(unittest.TestCase):
208
+    """输入验证测试"""
209
+
210
+    def test_sql_injection_detection(self):
211
+        """SQL 注入检测"""
212
+        is_injection, msg = detect_sql_injection("'; DROP TABLE users; --")
213
+        self.assertTrue(is_injection)
214
+
215
+        is_injection, msg = detect_sql_injection("1 OR 1=1")
216
+        self.assertTrue(is_injection)
217
+
218
+        is_injection, msg = detect_sql_injection("normal input value")
219
+        self.assertFalse(is_injection)
220
+
221
+    def test_sanitize_sql(self):
222
+        """SQL 值清理"""
223
+        result = sanitize_sql_value("'; DROP TABLE users;--")
224
+        self.assertNotIn("'", result)
225
+        self.assertNotIn(";", result)
226
+        self.assertNotIn("--", result)
227
+
228
+    def test_xss_detection(self):
229
+        """XSS 检测"""
230
+        is_xss, msg = detect_xss("<script>alert('xss')</script>")
231
+        self.assertTrue(is_xss)
232
+
233
+        is_xss, msg = detect_xss("javascript:alert(1)")
234
+        self.assertTrue(is_xss)
235
+
236
+        is_xss, msg = detect_xss("onclick=alert(1)")
237
+        self.assertTrue(is_xss)
238
+
239
+        is_xss, msg = detect_xss("normal text content")
240
+        self.assertFalse(is_xss)
241
+
242
+    def test_escape_html(self):
243
+        """HTML 转义"""
244
+        result = escape_html("<script>alert('xss')</script>")
245
+        self.assertNotIn("<script>", result)
246
+        self.assertIn("&lt;", result)
247
+
248
+    def test_sanitize_html(self):
249
+        """清理 HTML"""
250
+        result = sanitize_html("<p>Hello <b>World</b></p>")
251
+        self.assertNotIn("<p>", result)
252
+        self.assertNotIn("<b>", result)
253
+
254
+    def test_sanitize_string(self):
255
+        """字符串清理"""
256
+        result = sanitize_string("  hello world  ")
257
+        self.assertEqual(result, "hello world")
258
+
259
+        result = sanitize_string("hello\x00world")
260
+        self.assertNotIn("\x00", result)
261
+
262
+        result = sanitize_string("hello world", max_length=5)
263
+        self.assertEqual(len(result), 5)
264
+
265
+    def test_sanitize_number(self):
266
+        """数字验证"""
267
+        self.assertEqual(sanitize_number("123"), 123.0)
268
+        self.assertEqual(sanitize_number("123.45"), 123.45)
269
+        self.assertIsNone(sanitize_number("abc"))
270
+        self.assertIsNone(sanitize_number("100", min_val=200))
271
+        self.assertIsNone(sanitize_number("300", max_val=200))
272
+
273
+    def test_validate_email(self):
274
+        """邮箱格式验证"""
275
+        self.assertTrue(validate_email_format("test@example.com"))
276
+        self.assertFalse(validate_email_format("not-an-email"))
277
+        self.assertFalse(validate_email_format("@example.com"))
278
+
279
+    def test_validate_phone(self):
280
+        """手机号格式验证"""
281
+        self.assertTrue(validate_phone_format("13812345678"))
282
+        self.assertFalse(validate_phone_format("12345678901"))
283
+        self.assertFalse(validate_phone_format("138"))
284
+
285
+    def test_file_upload_validation(self):
286
+        """文件上传验证"""
287
+        # 正常图片
288
+        ok, msg = validate_file_upload("test.jpg", "image/jpeg", 1024)
289
+        self.assertTrue(ok)
290
+
291
+        # 危险文件类型
292
+        ok, msg = validate_file_upload("hack.exe", "application/octet-stream", 1024)
293
+        self.assertFalse(ok)
294
+
295
+        # 超大文件
296
+        ok, msg = validate_file_upload("big.pdf", "application/pdf", 100 * 1024 * 1024)
297
+        self.assertFalse(ok)
298
+
299
+
300
+class TestAudit(unittest.TestCase):
301
+    """审计日志测试"""
302
+
303
+    def test_audit_event_creation(self):
304
+        """审计事件创建"""
305
+        event = AuditEvent(
306
+            timestamp=datetime.now(timezone.utc),
307
+            user_id="user1",
308
+            action=AuditAction.LOGIN.value,
309
+            resource="/auth/login",
310
+            ip_address="127.0.0.1",
311
+            result=AuditResult.SUCCESS,
312
+        )
313
+        self.assertTrue(len(event.event_id) > 0)
314
+        self.assertEqual(event.user_id, "user1")
315
+
316
+    def test_audit_event_to_dict(self):
317
+        """审计事件转字典"""
318
+        event = AuditEvent(
319
+            timestamp=datetime.now(timezone.utc),
320
+            user_id="user1",
321
+            action=AuditAction.LOGIN.value,
322
+        )
323
+        d = event.to_dict()
324
+        self.assertIn("timestamp", d)
325
+        self.assertIn("user_id", d)
326
+        self.assertIn("event_id", d)
327
+
328
+    def test_audit_logger(self):
329
+        """审计记录器"""
330
+        logger = AuditLogger()
331
+        logger.log_action(
332
+            user_id="user1",
333
+            action=AuditAction.LOGIN.value,
334
+            ip_address="127.0.0.1",
335
+        )
336
+        logs = logger.query(user_id="user1")
337
+        self.assertTrue(len(logs) > 0)
338
+
339
+    def test_audit_query(self):
340
+        """审计日志查询"""
341
+        logger = AuditLogger()
342
+        logger.log_action(user_id="user1", action=AuditAction.LOGIN.value)
343
+        logger.log_action(user_id="user2", action=AuditAction.DATA_READ.value)
344
+
345
+        logs = logger.query(action=AuditAction.LOGIN.value)
346
+        self.assertTrue(all(l["action"] == AuditAction.LOGIN.value for l in logs))
347
+
348
+    def test_audit_alert(self):
349
+        """审计告警"""
350
+        alert_mgr = AuditAlertManager()
351
+        # 模拟多次登录失败
352
+        for i in range(6):
353
+            event = AuditEvent(
354
+                timestamp=datetime.now(timezone.utc),
355
+                user_id="attacker",
356
+                action=AuditAction.LOGIN_FAILED.value,
357
+                ip_address="10.0.0.1",
358
+                result=AuditResult.FAILURE,
359
+            )
360
+            alert = alert_mgr.check_alert(event)
361
+
362
+        # 应该触发告警
363
+        alerts = alert_mgr.get_alerts()
364
+        self.assertTrue(len(alerts) > 0)
365
+
366
+    def test_audit_stats(self):
367
+        """审计统计"""
368
+        logger = AuditLogger()
369
+        logger.log_action(user_id="user1", action=AuditAction.LOGIN.value)
370
+        logger.log_action(user_id="user1", action=AuditAction.DATA_READ.value)
371
+        stats = logger.get_stats(hours=1)
372
+        self.assertIn("total_events", stats)
373
+        self.assertIn("action_counts", stats)
374
+
375
+
376
+class TestMiddleware(unittest.TestCase):
377
+    """中间件工具测试"""
378
+
379
+    def test_rate_limiter(self):
380
+        """限流器"""
381
+        limiter = RateLimiter(max_requests=3, window_seconds=60)
382
+
383
+        self.assertTrue(limiter.is_allowed("127.0.0.1"))
384
+        self.assertTrue(limiter.is_allowed("127.0.0.1"))
385
+        self.assertTrue(limiter.is_allowed("127.0.0.1"))
386
+        self.assertFalse(limiter.is_allowed("127.0.0.1"))  # 超过限制
387
+
388
+        # 不同 IP 不受影响
389
+        self.assertTrue(limiter.is_allowed("192.168.1.1"))
390
+
391
+    def test_rate_limiter_remaining(self):
392
+        """限流器剩余次数"""
393
+        limiter = RateLimiter(max_requests=5, window_seconds=60)
394
+        limiter.is_allowed("127.0.0.1")
395
+        limiter.is_allowed("127.0.0.1")
396
+        self.assertEqual(limiter.get_remaining("127.0.0.1"), 3)
397
+
398
+    def test_csrf_token(self):
399
+        """CSRF token 生成与验证"""
400
+        session_id = "test-session-123"
401
+        token = generate_csrf_token(session_id)
402
+        self.assertTrue(validate_csrf_token(token, session_id))
403
+        self.assertFalse(validate_csrf_token(token, "wrong-session"))
404
+        self.assertFalse(validate_csrf_token("invalid-token", session_id))
405
+
406
+    def test_get_client_ip(self):
407
+        """获取客户端 IP"""
408
+        # X-Forwarded-For
409
+        ip = get_client_ip({"x-forwarded-for": "10.0.0.1, 10.0.0.2"}, "127.0.0.1")
410
+        self.assertEqual(ip, "10.0.0.1")
411
+
412
+        # X-Real-IP
413
+        ip = get_client_ip({"x-real-ip": "10.0.0.3"}, "127.0.0.1")
414
+        self.assertEqual(ip, "10.0.0.3")
415
+
416
+        # 无代理头
417
+        ip = get_client_ip({}, "127.0.0.1")
418
+        self.assertEqual(ip, "127.0.0.1")
419
+
420
+
421
+class TestConfig(unittest.TestCase):
422
+    """配置测试"""
423
+
424
+    def test_default_config(self):
425
+        """默认配置加载"""
426
+        self.assertIsNotNone(security_config.jwt)
427
+        self.assertIsNotNone(security_config.encryption)
428
+        self.assertIsNotNone(security_config.rate_limit)
429
+        self.assertIsNotNone(security_config.cors)
430
+        self.assertIsNotNone(security_config.password_policy)
431
+        self.assertIsNotNone(security_config.audit)
432
+
433
+    def test_jwt_config(self):
434
+        """JWT 配置"""
435
+        self.assertEqual(security_config.jwt.algorithm, "HS256")
436
+        self.assertTrue(security_config.jwt.access_token_expire_minutes > 0)
437
+
438
+    def test_rate_limit_config(self):
439
+        """限流配置"""
440
+        self.assertTrue(security_config.rate_limit.max_requests_per_minute > 0)
441
+
442
+
443
+if __name__ == "__main__":
444
+    unittest.main(verbosity=2)