智慧水务管理系统 - 精河县供水工程综合管理平台

test_security.py 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. """
  2. 安全模块单元测试
  3. 覆盖 JWT、权限、加密、输入验证、审计日志
  4. """
  5. import sys
  6. import os
  7. import json
  8. import time
  9. import unittest
  10. from datetime import datetime, timezone
  11. # 将项目根目录加入路径
  12. sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
  13. from src.security.auth import (
  14. Role, hash_password, verify_password,
  15. create_access_token, create_refresh_token,
  16. verify_access_token, verify_refresh_token,
  17. get_user_roles, check_permission,
  18. extract_token_from_header, validate_password_strength,
  19. )
  20. from src.security.encryption import (
  21. FieldEncryptor, encrypt_field, decrypt_field,
  22. mask_phone, mask_id_card, mask_email,
  23. mask_bank_card, mask_name, mask_ip,
  24. mask_response,
  25. )
  26. from src.security.input_validator import (
  27. detect_sql_injection, sanitize_sql_value,
  28. detect_xss, escape_html, sanitize_html,
  29. sanitize_string, sanitize_number,
  30. validate_email_format, validate_phone_format,
  31. validate_file_upload,
  32. )
  33. from src.security.audit import (
  34. AuditAction, AuditResult, AuditSeverity, AuditEvent,
  35. AuditLogger, AuditStorage, AuditAlertManager,
  36. )
  37. from src.security.middleware import (
  38. RateLimiter, generate_csrf_token, validate_csrf_token,
  39. get_client_ip,
  40. )
  41. from src.security.config import security_config
  42. class TestAuth(unittest.TestCase):
  43. """认证模块测试"""
  44. def test_password_hash(self):
  45. """密码哈希与验证"""
  46. password = "TestPassword123"
  47. hashed = hash_password(password)
  48. self.assertNotEqual(password, hashed)
  49. self.assertTrue(verify_password(password, hashed))
  50. self.assertFalse(verify_password("wrong", hashed))
  51. def test_create_access_token(self):
  52. """生成 access token"""
  53. token = create_access_token("user1", [Role.ADMIN, Role.OPERATOR])
  54. self.assertIsInstance(token, str)
  55. self.assertTrue(len(token) > 50)
  56. def test_verify_access_token(self):
  57. """验证 access token"""
  58. token = create_access_token("user1", [Role.ADMIN])
  59. payload = verify_access_token(token)
  60. self.assertIsNotNone(payload)
  61. self.assertEqual(payload["sub"], "user1")
  62. self.assertEqual(payload["type"], "access")
  63. self.assertIn("admin", payload["roles"])
  64. def test_expired_token(self):
  65. """过期 token 验证"""
  66. # 通过修改配置来创建过期 token
  67. import jwt as pyjwt
  68. from src.security.config import security_config as cfg
  69. payload = {
  70. "sub": "user1",
  71. "roles": ["admin"],
  72. "iat": int(time.time()) - 7200,
  73. "exp": int(time.time()) - 3600, # 已过期
  74. "iss": cfg.jwt.issuer,
  75. "type": "access",
  76. }
  77. token = pyjwt.encode(payload, cfg.jwt.secret_key, algorithm=cfg.jwt.algorithm)
  78. result = verify_access_token(token)
  79. self.assertIsNone(result)
  80. def test_refresh_token(self):
  81. """Refresh token 生成与验证"""
  82. token = create_refresh_token("user1")
  83. payload = verify_refresh_token(token)
  84. self.assertIsNotNone(payload)
  85. self.assertEqual(payload["sub"], "user1")
  86. self.assertEqual(payload["type"], "refresh")
  87. def test_token_type_mismatch(self):
  88. """Token 类型不匹配"""
  89. access_token = create_access_token("user1", [Role.ADMIN])
  90. # 用 verify_refresh_token 验证 access token 应该失败
  91. result = verify_refresh_token(access_token)
  92. self.assertIsNone(result)
  93. def test_get_user_roles(self):
  94. """从 payload 提取角色"""
  95. payload = {"roles": ["admin", "operator"]}
  96. roles = get_user_roles(payload)
  97. self.assertEqual(len(roles), 2)
  98. self.assertIn(Role.ADMIN, roles)
  99. self.assertIn(Role.OPERATOR, roles)
  100. def test_check_permission(self):
  101. """权限检查"""
  102. self.assertTrue(check_permission([Role.ADMIN], "system:manage"))
  103. self.assertTrue(check_permission([Role.OPERATOR], "data:write"))
  104. self.assertFalse(check_permission([Role.VIEWER], "data:write"))
  105. self.assertFalse(check_permission([Role.VIEWER], "system:manage"))
  106. def test_extract_token(self):
  107. """从 Authorization header 提取 token"""
  108. self.assertEqual(
  109. extract_token_from_header("Bearer abc123"),
  110. "abc123"
  111. )
  112. self.assertIsNone(extract_token_from_header("Basic abc123"))
  113. self.assertIsNone(extract_token_from_header(""))
  114. self.assertIsNone(extract_token_from_header(None))
  115. def test_password_strength(self):
  116. """密码强度验证"""
  117. ok, msg = validate_password_strength("Abcdef1!")
  118. self.assertTrue(ok)
  119. ok, msg = validate_password_strength("123")
  120. self.assertFalse(ok)
  121. ok, msg = validate_password_strength("abcdefgh")
  122. self.assertFalse(ok) # 缺少大写和数字
  123. class TestEncryption(unittest.TestCase):
  124. """加密模块测试"""
  125. def setUp(self):
  126. self.encryptor = FieldEncryptor()
  127. def test_encrypt_decrypt(self):
  128. """加密和解密"""
  129. plaintext = "敏感数据123"
  130. encrypted = self.encryptor.encrypt(plaintext)
  131. self.assertNotEqual(plaintext, encrypted)
  132. decrypted = self.encryptor.decrypt(encrypted)
  133. self.assertEqual(plaintext, decrypted)
  134. def test_encrypt_empty_string(self):
  135. """空字符串加密"""
  136. result = encrypt_field("")
  137. self.assertEqual(result, "")
  138. def test_encrypt_field_roundtrip(self):
  139. """字段加密解密往返"""
  140. original = "13812345678"
  141. encrypted = encrypt_field(original)
  142. decrypted = decrypt_field(encrypted)
  143. self.assertEqual(original, decrypted)
  144. def test_mask_phone(self):
  145. """手机号脱敏"""
  146. self.assertEqual(mask_phone("13812345678"), "138****5678")
  147. self.assertEqual(mask_phone("138"), "138")
  148. self.assertEqual(mask_phone(""), "")
  149. def test_mask_id_card(self):
  150. """身份证脱敏"""
  151. self.assertEqual(mask_id_card("110101199001011234"), "1101**********1234")
  152. self.assertEqual(mask_id_card(""), "")
  153. def test_mask_email(self):
  154. """邮箱脱敏"""
  155. self.assertEqual(mask_email("test@example.com"), "t***@example.com")
  156. self.assertEqual(mask_email("a@b.com"), "a***@b.com")
  157. def test_mask_bank_card(self):
  158. """银行卡脱敏"""
  159. self.assertEqual(mask_bank_card("6225880137073520"), "6225********3520")
  160. def test_mask_name(self):
  161. """姓名脱敏"""
  162. self.assertEqual(mask_name("张三"), "张*")
  163. self.assertEqual(mask_name("欧阳修"), "欧阳*")
  164. def test_mask_ip(self):
  165. """IP 脱敏"""
  166. self.assertEqual(mask_ip("192.168.1.100"), "192.168.*.*")
  167. def test_mask_response(self):
  168. """响应脱敏"""
  169. data = {"phone": "13812345678", "name": "张三", "email": "test@test.com"}
  170. rules = {"phone": "phone", "name": "name", "email": "email"}
  171. result = mask_response(data, rules)
  172. self.assertEqual(result["phone"], "138****5678")
  173. self.assertEqual(result["name"], "张*")
  174. self.assertIn("***", result["email"])
  175. class TestInputValidator(unittest.TestCase):
  176. """输入验证测试"""
  177. def test_sql_injection_detection(self):
  178. """SQL 注入检测"""
  179. is_injection, msg = detect_sql_injection("'; DROP TABLE users; --")
  180. self.assertTrue(is_injection)
  181. is_injection, msg = detect_sql_injection("1 OR 1=1")
  182. self.assertTrue(is_injection)
  183. is_injection, msg = detect_sql_injection("normal input value")
  184. self.assertFalse(is_injection)
  185. def test_sanitize_sql(self):
  186. """SQL 值清理"""
  187. result = sanitize_sql_value("'; DROP TABLE users;--")
  188. self.assertNotIn("'", result)
  189. self.assertNotIn(";", result)
  190. self.assertNotIn("--", result)
  191. def test_xss_detection(self):
  192. """XSS 检测"""
  193. is_xss, msg = detect_xss("<script>alert('xss')</script>")
  194. self.assertTrue(is_xss)
  195. is_xss, msg = detect_xss("javascript:alert(1)")
  196. self.assertTrue(is_xss)
  197. is_xss, msg = detect_xss("onclick=alert(1)")
  198. self.assertTrue(is_xss)
  199. is_xss, msg = detect_xss("normal text content")
  200. self.assertFalse(is_xss)
  201. def test_escape_html(self):
  202. """HTML 转义"""
  203. result = escape_html("<script>alert('xss')</script>")
  204. self.assertNotIn("<script>", result)
  205. self.assertIn("&lt;", result)
  206. def test_sanitize_html(self):
  207. """清理 HTML"""
  208. result = sanitize_html("<p>Hello <b>World</b></p>")
  209. self.assertNotIn("<p>", result)
  210. self.assertNotIn("<b>", result)
  211. def test_sanitize_string(self):
  212. """字符串清理"""
  213. result = sanitize_string(" hello world ")
  214. self.assertEqual(result, "hello world")
  215. result = sanitize_string("hello\x00world")
  216. self.assertNotIn("\x00", result)
  217. result = sanitize_string("hello world", max_length=5)
  218. self.assertEqual(len(result), 5)
  219. def test_sanitize_number(self):
  220. """数字验证"""
  221. self.assertEqual(sanitize_number("123"), 123.0)
  222. self.assertEqual(sanitize_number("123.45"), 123.45)
  223. self.assertIsNone(sanitize_number("abc"))
  224. self.assertIsNone(sanitize_number("100", min_val=200))
  225. self.assertIsNone(sanitize_number("300", max_val=200))
  226. def test_validate_email(self):
  227. """邮箱格式验证"""
  228. self.assertTrue(validate_email_format("test@example.com"))
  229. self.assertFalse(validate_email_format("not-an-email"))
  230. self.assertFalse(validate_email_format("@example.com"))
  231. def test_validate_phone(self):
  232. """手机号格式验证"""
  233. self.assertTrue(validate_phone_format("13812345678"))
  234. self.assertFalse(validate_phone_format("12345678901"))
  235. self.assertFalse(validate_phone_format("138"))
  236. def test_file_upload_validation(self):
  237. """文件上传验证"""
  238. # 正常图片
  239. ok, msg = validate_file_upload("test.jpg", "image/jpeg", 1024)
  240. self.assertTrue(ok)
  241. # 危险文件类型
  242. ok, msg = validate_file_upload("hack.exe", "application/octet-stream", 1024)
  243. self.assertFalse(ok)
  244. # 超大文件
  245. ok, msg = validate_file_upload("big.pdf", "application/pdf", 100 * 1024 * 1024)
  246. self.assertFalse(ok)
  247. class TestAudit(unittest.TestCase):
  248. """审计日志测试"""
  249. def test_audit_event_creation(self):
  250. """审计事件创建"""
  251. event = AuditEvent(
  252. timestamp=datetime.now(timezone.utc),
  253. user_id="user1",
  254. action=AuditAction.LOGIN.value,
  255. resource="/auth/login",
  256. ip_address="127.0.0.1",
  257. result=AuditResult.SUCCESS,
  258. )
  259. self.assertTrue(len(event.event_id) > 0)
  260. self.assertEqual(event.user_id, "user1")
  261. def test_audit_event_to_dict(self):
  262. """审计事件转字典"""
  263. event = AuditEvent(
  264. timestamp=datetime.now(timezone.utc),
  265. user_id="user1",
  266. action=AuditAction.LOGIN.value,
  267. )
  268. d = event.to_dict()
  269. self.assertIn("timestamp", d)
  270. self.assertIn("user_id", d)
  271. self.assertIn("event_id", d)
  272. def test_audit_logger(self):
  273. """审计记录器"""
  274. logger = AuditLogger()
  275. logger.log_action(
  276. user_id="user1",
  277. action=AuditAction.LOGIN.value,
  278. ip_address="127.0.0.1",
  279. )
  280. logs = logger.query(user_id="user1")
  281. self.assertTrue(len(logs) > 0)
  282. def test_audit_query(self):
  283. """审计日志查询"""
  284. logger = AuditLogger()
  285. logger.log_action(user_id="user1", action=AuditAction.LOGIN.value)
  286. logger.log_action(user_id="user2", action=AuditAction.DATA_READ.value)
  287. logs = logger.query(action=AuditAction.LOGIN.value)
  288. self.assertTrue(all(l["action"] == AuditAction.LOGIN.value for l in logs))
  289. def test_audit_alert(self):
  290. """审计告警"""
  291. alert_mgr = AuditAlertManager()
  292. # 模拟多次登录失败
  293. for i in range(6):
  294. event = AuditEvent(
  295. timestamp=datetime.now(timezone.utc),
  296. user_id="attacker",
  297. action=AuditAction.LOGIN_FAILED.value,
  298. ip_address="10.0.0.1",
  299. result=AuditResult.FAILURE,
  300. )
  301. alert = alert_mgr.check_alert(event)
  302. # 应该触发告警
  303. alerts = alert_mgr.get_alerts()
  304. self.assertTrue(len(alerts) > 0)
  305. def test_audit_stats(self):
  306. """审计统计"""
  307. logger = AuditLogger()
  308. logger.log_action(user_id="user1", action=AuditAction.LOGIN.value)
  309. logger.log_action(user_id="user1", action=AuditAction.DATA_READ.value)
  310. stats = logger.get_stats(hours=1)
  311. self.assertIn("total_events", stats)
  312. self.assertIn("action_counts", stats)
  313. class TestMiddleware(unittest.TestCase):
  314. """中间件工具测试"""
  315. def test_rate_limiter(self):
  316. """限流器"""
  317. limiter = RateLimiter(max_requests=3, window_seconds=60)
  318. self.assertTrue(limiter.is_allowed("127.0.0.1"))
  319. self.assertTrue(limiter.is_allowed("127.0.0.1"))
  320. self.assertTrue(limiter.is_allowed("127.0.0.1"))
  321. self.assertFalse(limiter.is_allowed("127.0.0.1")) # 超过限制
  322. # 不同 IP 不受影响
  323. self.assertTrue(limiter.is_allowed("192.168.1.1"))
  324. def test_rate_limiter_remaining(self):
  325. """限流器剩余次数"""
  326. limiter = RateLimiter(max_requests=5, window_seconds=60)
  327. limiter.is_allowed("127.0.0.1")
  328. limiter.is_allowed("127.0.0.1")
  329. self.assertEqual(limiter.get_remaining("127.0.0.1"), 3)
  330. def test_csrf_token(self):
  331. """CSRF token 生成与验证"""
  332. session_id = "test-session-123"
  333. token = generate_csrf_token(session_id)
  334. self.assertTrue(validate_csrf_token(token, session_id))
  335. self.assertFalse(validate_csrf_token(token, "wrong-session"))
  336. self.assertFalse(validate_csrf_token("invalid-token", session_id))
  337. def test_get_client_ip(self):
  338. """获取客户端 IP"""
  339. # X-Forwarded-For
  340. ip = get_client_ip({"x-forwarded-for": "10.0.0.1, 10.0.0.2"}, "127.0.0.1")
  341. self.assertEqual(ip, "10.0.0.1")
  342. # X-Real-IP
  343. ip = get_client_ip({"x-real-ip": "10.0.0.3"}, "127.0.0.1")
  344. self.assertEqual(ip, "10.0.0.3")
  345. # 无代理头
  346. ip = get_client_ip({}, "127.0.0.1")
  347. self.assertEqual(ip, "127.0.0.1")
  348. class TestConfig(unittest.TestCase):
  349. """配置测试"""
  350. def test_default_config(self):
  351. """默认配置加载"""
  352. self.assertIsNotNone(security_config.jwt)
  353. self.assertIsNotNone(security_config.encryption)
  354. self.assertIsNotNone(security_config.rate_limit)
  355. self.assertIsNotNone(security_config.cors)
  356. self.assertIsNotNone(security_config.password_policy)
  357. self.assertIsNotNone(security_config.audit)
  358. def test_jwt_config(self):
  359. """JWT 配置"""
  360. self.assertEqual(security_config.jwt.algorithm, "HS256")
  361. self.assertTrue(security_config.jwt.access_token_expire_minutes > 0)
  362. def test_rate_limit_config(self):
  363. """限流配置"""
  364. self.assertTrue(security_config.rate_limit.max_requests_per_minute > 0)
  365. if __name__ == "__main__":
  366. unittest.main(verbosity=2)