"""
Unit tests for the security modules.

Covers JWT, permissions, encryption, input validation and audit logging.
"""

import sys
import os
import json
import time
import unittest
from datetime import datetime, timezone

# Put the parent of this file's directory (the project root) on sys.path so
# the ``src.security.*`` imports below resolve.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from src.security.auth import (
    Role,
    hash_password,
    verify_password,
    create_access_token,
    create_refresh_token,
    verify_access_token,
    verify_refresh_token,
    get_user_roles,
    check_permission,
    extract_token_from_header,
    validate_password_strength,
)
from src.security.encryption import (
    FieldEncryptor,
    encrypt_field,
    decrypt_field,
    mask_phone,
    mask_id_card,
    mask_email,
    mask_bank_card,
    mask_name,
    mask_ip,
    mask_response,
)
from src.security.input_validator import (
    detect_sql_injection,
    sanitize_sql_value,
    detect_xss,
    escape_html,
    sanitize_html,
    sanitize_string,
    sanitize_number,
    validate_email_format,
    validate_phone_format,
    validate_file_upload,
)
from src.security.audit import (
    AuditAction,
    AuditResult,
    AuditSeverity,
    AuditEvent,
    AuditLogger,
    AuditStorage,
    AuditAlertManager,
)
from src.security.middleware import (
    RateLimiter,
    generate_csrf_token,
    validate_csrf_token,
    get_client_ip,
)
from src.security.config import security_config


class TestAuth(unittest.TestCase):
    """Tests for the authentication module (``src.security.auth``)."""

    def test_password_hash(self):
        """A password hashes to something else and verifies against its hash."""
        password = "TestPassword123"
        hashed = hash_password(password)
        # The stored hash must never be the plaintext itself.
        self.assertNotEqual(password, hashed)
        self.assertTrue(verify_password(password, hashed))
        self.assertFalse(verify_password("wrong", hashed))

    def test_create_access_token(self):
        """An access token is generated as a non-trivial string."""
        token = create_access_token("user1", [Role.ADMIN, Role.OPERATOR])
        self.assertIsInstance(token, str)
        # assertGreater reports the actual length when it fails.
        self.assertGreater(len(token), 50)

    def test_verify_access_token(self):
        """A freshly issued access token verifies and exposes its claims."""
        token = create_access_token("user1", [Role.ADMIN])
        payload = verify_access_token(token)
        self.assertIsNotNone(payload)
        self.assertEqual(payload["sub"], "user1")
        self.assertEqual(payload["type"], "access")
        self.assertIn("admin", payload["roles"])

    def test_expired_token(self):
        """An access token whose ``exp`` is in the past is rejected."""
        # Hand-craft the token with raw PyJWT (imported locally because only
        # this test needs it), signing with the same settings the module
        # uses, so that expiry is the only thing wrong with it.
        import jwt as pyjwt

        now = int(time.time())
        payload = {
            "sub": "user1",
            "roles": ["admin"],
            "iat": now - 7200,
            "exp": now - 3600,  # expired one hour ago
            "iss": security_config.jwt.issuer,
            "type": "access",
        }
        token = pyjwt.encode(
            payload,
            security_config.jwt.secret_key,
            algorithm=security_config.jwt.algorithm,
        )
        self.assertIsNone(verify_access_token(token))

    def test_refresh_token(self):
        """A refresh token is generated and verifies with type ``refresh``."""
        token = create_refresh_token("user1")
        payload = verify_refresh_token(token)
        self.assertIsNotNone(payload)
        self.assertEqual(payload["sub"], "user1")
        self.assertEqual(payload["type"], "refresh")

    def test_token_type_mismatch(self):
        """An access token must not pass refresh-token verification."""
        access_token = create_access_token("user1", [Role.ADMIN])
        self.assertIsNone(verify_refresh_token(access_token))

    def test_get_user_roles(self):
        """Role names in a payload are mapped to ``Role`` members."""
        payload = {"roles": ["admin", "operator"]}
        roles = get_user_roles(payload)
        # Same check as "exactly 2 roles, both present", in one assertion.
        self.assertCountEqual(roles, [Role.ADMIN, Role.OPERATOR])

    def test_check_permission(self):
        """Permission checks grant and deny according to role."""
        self.assertTrue(check_permission([Role.ADMIN], "system:manage"))
        self.assertTrue(check_permission([Role.OPERATOR], "data:write"))
        self.assertFalse(check_permission([Role.VIEWER], "data:write"))
        self.assertFalse(check_permission([Role.VIEWER], "system:manage"))

    def test_extract_token(self):
        """Only ``Bearer`` Authorization headers yield a token."""
        self.assertEqual(extract_token_from_header("Bearer abc123"), "abc123")
        self.assertIsNone(extract_token_from_header("Basic abc123"))
        self.assertIsNone(extract_token_from_header(""))
        self.assertIsNone(extract_token_from_header(None))

    def test_password_strength(self):
        """Strong passwords are accepted; weak ones are rejected."""
        ok, msg = validate_password_strength("Abcdef1!")
        # Surface the validator's own message if a strong password is refused.
        self.assertTrue(ok, msg)
        # "123" is too short; "abcdefgh" lacks an uppercase letter and a digit.
        for weak in ("123", "abcdefgh"):
            with self.subTest(password=weak):
                ok, _ = validate_password_strength(weak)
                self.assertFalse(ok)


class TestEncryption(unittest.TestCase):
    """Tests for field encryption and data masking (``src.security.encryption``)."""

    def setUp(self):
        self.encryptor = FieldEncryptor()

    def test_encrypt_decrypt(self):
        """Ciphertext differs from plaintext and decrypts back to it."""
        plaintext = "敏感数据123"
        encrypted = self.encryptor.encrypt(plaintext)
        self.assertNotEqual(plaintext, encrypted)
        decrypted = self.encryptor.decrypt(encrypted)
        self.assertEqual(plaintext, decrypted)

    def test_encrypt_empty_string(self):
        """Encrypting an empty string yields an empty string."""
        self.assertEqual(encrypt_field(""), "")

    def test_encrypt_field_roundtrip(self):
        """encrypt_field followed by decrypt_field restores the original."""
        original = "13812345678"
        encrypted = encrypt_field(original)
        self.assertEqual(original, decrypt_field(encrypted))

    def test_mask_phone(self):
        """Phone numbers keep the first 3 and last 4 digits."""
        self.assertEqual(mask_phone("13812345678"), "138****5678")
        self.assertEqual(mask_phone("138"), "138")
        self.assertEqual(mask_phone(""), "")

    def test_mask_id_card(self):
        """ID card numbers keep the first 4 and last 4 characters."""
        self.assertEqual(mask_id_card("110101199001011234"), "1101**********1234")
        self.assertEqual(mask_id_card(""), "")

    def test_mask_email(self):
        """Emails keep the first local-part character and the domain."""
        self.assertEqual(mask_email("test@example.com"), "t***@example.com")
        self.assertEqual(mask_email("a@b.com"), "a***@b.com")

    def test_mask_bank_card(self):
        """Bank card numbers keep the first 4 and last 4 digits."""
        self.assertEqual(mask_bank_card("6225880137073520"), "6225********3520")

    def test_mask_name(self):
        """Names have their last character masked."""
        self.assertEqual(mask_name("张三"), "张*")
        self.assertEqual(mask_name("欧阳修"), "欧阳*")

    def test_mask_ip(self):
        """IPv4 addresses have their last two octets masked."""
        self.assertEqual(mask_ip("192.168.1.100"), "192.168.*.*")

    def test_mask_response(self):
        """mask_response applies the per-field masking rules to a dict."""
        data = {"phone": "13812345678", "name": "张三", "email": "test@test.com"}
        rules = {"phone": "phone", "name": "name", "email": "email"}
        result = mask_response(data, rules)
        self.assertEqual(result["phone"], "138****5678")
        self.assertEqual(result["name"], "张*")
        self.assertIn("***", result["email"])


class TestInputValidator(unittest.TestCase):
    """Tests for input validation (``src.security.input_validator``)."""

    def test_sql_injection_detection(self):
        """SQL injection detection."""
        is_injection, msg = detect_sql_injection("'; DROP TABLE users; --")
self.assertTrue(is_injection) is_injection, msg = detect_sql_injection("1 OR 1=1") self.assertTrue(is_injection) is_injection, msg = detect_sql_injection("normal input value") self.assertFalse(is_injection) def test_sanitize_sql(self): """SQL 值清理""" result = sanitize_sql_value("'; DROP TABLE users;--") self.assertNotIn("'", result) self.assertNotIn(";", result) self.assertNotIn("--", result) def test_xss_detection(self): """XSS 检测""" is_xss, msg = detect_xss("") self.assertTrue(is_xss) is_xss, msg = detect_xss("javascript:alert(1)") self.assertTrue(is_xss) is_xss, msg = detect_xss("onclick=alert(1)") self.assertTrue(is_xss) is_xss, msg = detect_xss("normal text content") self.assertFalse(is_xss) def test_escape_html(self): """HTML 转义""" result = escape_html("") self.assertNotIn("