| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444 |
- """
- 安全模块单元测试
- 覆盖 JWT、权限、加密、输入验证、审计日志
- """
- import sys
- import os
- import json
- import time
- import unittest
- from datetime import datetime, timezone
-
- # 将项目根目录加入路径
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
-
- from src.security.auth import (
- Role, hash_password, verify_password,
- create_access_token, create_refresh_token,
- verify_access_token, verify_refresh_token,
- get_user_roles, check_permission,
- extract_token_from_header, validate_password_strength,
- )
- from src.security.encryption import (
- FieldEncryptor, encrypt_field, decrypt_field,
- mask_phone, mask_id_card, mask_email,
- mask_bank_card, mask_name, mask_ip,
- mask_response,
- )
- from src.security.input_validator import (
- detect_sql_injection, sanitize_sql_value,
- detect_xss, escape_html, sanitize_html,
- sanitize_string, sanitize_number,
- validate_email_format, validate_phone_format,
- validate_file_upload,
- )
- from src.security.audit import (
- AuditAction, AuditResult, AuditSeverity, AuditEvent,
- AuditLogger, AuditStorage, AuditAlertManager,
- )
- from src.security.middleware import (
- RateLimiter, generate_csrf_token, validate_csrf_token,
- get_client_ip,
- )
- from src.security.config import security_config
-
-
class TestAuth(unittest.TestCase):
    """Tests for the auth helpers: password hashing, JWT handling, roles, permissions."""

    def test_password_hash(self):
        """A hash differs from the plaintext and verifies only the right password."""
        plain = "TestPassword123"
        digest = hash_password(plain)
        self.assertNotEqual(plain, digest)
        self.assertTrue(verify_password(plain, digest))
        self.assertFalse(verify_password("wrong", digest))

    def test_create_access_token(self):
        """An access token is issued as a string longer than 50 characters."""
        token = create_access_token("user1", [Role.ADMIN, Role.OPERATOR])
        self.assertIsInstance(token, str)
        self.assertGreater(len(token), 50)

    def test_verify_access_token(self):
        """A freshly issued access token verifies and exposes its claims."""
        claims = verify_access_token(create_access_token("user1", [Role.ADMIN]))
        self.assertIsNotNone(claims)
        self.assertEqual(claims["sub"], "user1")
        self.assertEqual(claims["type"], "access")
        self.assertIn("admin", claims["roles"])

    def test_expired_token(self):
        """A token whose ``exp`` lies in the past is rejected (verification yields None)."""
        import jwt as pyjwt

        # Hand-build a token that was issued two hours ago and expired one hour ago,
        # signed with the configured key so that only the expiry can cause rejection.
        now = int(time.time())
        jwt_cfg = security_config.jwt
        claims = {
            "sub": "user1",
            "roles": ["admin"],
            "iat": now - 7200,
            "exp": now - 3600,
            "iss": jwt_cfg.issuer,
            "type": "access",
        }
        stale_token = pyjwt.encode(claims, jwt_cfg.secret_key, algorithm=jwt_cfg.algorithm)
        self.assertIsNone(verify_access_token(stale_token))

    def test_refresh_token(self):
        """A refresh token round-trips through create/verify with type ``refresh``."""
        claims = verify_refresh_token(create_refresh_token("user1"))
        self.assertIsNotNone(claims)
        self.assertEqual(claims["sub"], "user1")
        self.assertEqual(claims["type"], "refresh")

    def test_token_type_mismatch(self):
        """An access token must not pass refresh-token verification."""
        access_token = create_access_token("user1", [Role.ADMIN])
        self.assertIsNone(verify_refresh_token(access_token))

    def test_get_user_roles(self):
        """Role names in a payload are converted to ``Role`` members."""
        roles = get_user_roles({"roles": ["admin", "operator"]})
        self.assertEqual(len(roles), 2)
        for expected_role in (Role.ADMIN, Role.OPERATOR):
            self.assertIn(expected_role, roles)

    def test_check_permission(self):
        """Each role is granted or denied the permissions expected of it."""
        expectations = (
            (Role.ADMIN, "system:manage", True),
            (Role.OPERATOR, "data:write", True),
            (Role.VIEWER, "data:write", False),
            (Role.VIEWER, "system:manage", False),
        )
        for role, permission, allowed in expectations:
            with self.subTest(role=role, permission=permission):
                # Compare truthiness only, as assertTrue/assertFalse would.
                self.assertIs(bool(check_permission([role], permission)), allowed)

    def test_extract_token(self):
        """Only a ``Bearer`` Authorization header yields a token."""
        self.assertEqual(extract_token_from_header("Bearer abc123"), "abc123")
        for header in ("Basic abc123", "", None):
            with self.subTest(header=header):
                self.assertIsNone(extract_token_from_header(header))

    def test_password_strength(self):
        """A mixed-character password passes; short or lowercase-only ones fail."""
        ok, _ = validate_password_strength("Abcdef1!")
        self.assertTrue(ok)

        # "123" is too short; "abcdefgh" has no uppercase letter and no digit.
        for weak in ("123", "abcdefgh"):
            with self.subTest(password=weak):
                ok, _ = validate_password_strength(weak)
                self.assertFalse(ok)
-
-
class TestEncryption(unittest.TestCase):
    """Tests for field encryption and the data-masking helpers."""

    def setUp(self):
        """Create a fresh ``FieldEncryptor`` for every test."""
        self.encryptor = FieldEncryptor()

    def test_encrypt_decrypt(self):
        """Ciphertext differs from the plaintext and decrypts back to it."""
        plaintext = "敏感数据123"
        ciphertext = self.encryptor.encrypt(plaintext)
        self.assertNotEqual(plaintext, ciphertext)
        self.assertEqual(plaintext, self.encryptor.decrypt(ciphertext))

    def test_encrypt_empty_string(self):
        """Encrypting an empty string returns an empty string."""
        self.assertEqual(encrypt_field(""), "")

    def test_encrypt_field_roundtrip(self):
        """``decrypt_field`` reverses ``encrypt_field``."""
        original = "13812345678"
        self.assertEqual(original, decrypt_field(encrypt_field(original)))

    def test_mask_phone(self):
        """Phone numbers keep the first three and last four digits; short input is unchanged."""
        samples = (
            ("13812345678", "138****5678"),
            ("138", "138"),
            ("", ""),
        )
        for raw, masked in samples:
            with self.subTest(raw=raw):
                self.assertEqual(mask_phone(raw), masked)

    def test_mask_id_card(self):
        """ID-card numbers keep the first four and last four characters."""
        self.assertEqual(mask_id_card("110101199001011234"), "1101**********1234")
        self.assertEqual(mask_id_card(""), "")

    def test_mask_email(self):
        """E-mail addresses keep the first character of the local part and the domain."""
        samples = (
            ("test@example.com", "t***@example.com"),
            ("a@b.com", "a***@b.com"),
        )
        for raw, masked in samples:
            with self.subTest(raw=raw):
                self.assertEqual(mask_email(raw), masked)

    def test_mask_bank_card(self):
        """Bank-card numbers keep the first four and last four digits."""
        self.assertEqual(mask_bank_card("6225880137073520"), "6225********3520")

    def test_mask_name(self):
        """Names have their final character replaced by an asterisk."""
        self.assertEqual(mask_name("张三"), "张*")
        self.assertEqual(mask_name("欧阳修"), "欧阳*")

    def test_mask_ip(self):
        """IPv4 addresses have their last two octets masked."""
        self.assertEqual(mask_ip("192.168.1.100"), "192.168.*.*")

    def test_mask_response(self):
        """``mask_response`` applies the per-field masking rules to a dict."""
        record = {"phone": "13812345678", "name": "张三", "email": "test@test.com"}
        field_rules = {"phone": "phone", "name": "name", "email": "email"}
        masked = mask_response(record, field_rules)
        self.assertEqual(masked["phone"], "138****5678")
        self.assertEqual(masked["name"], "张*")
        self.assertIn("***", masked["email"])
-
-
class TestInputValidator(unittest.TestCase):
    """Tests for input validation: SQL injection, XSS, sanitizers, format checks, uploads."""

    def test_sql_injection_detection(self):
        """Known injection payloads are flagged; ordinary text is not."""
        for payload in ("'; DROP TABLE users; --", "1 OR 1=1"):
            with self.subTest(payload=payload):
                is_injection, _ = detect_sql_injection(payload)
                self.assertTrue(is_injection)

        is_injection, _ = detect_sql_injection("normal input value")
        self.assertFalse(is_injection)

    def test_sanitize_sql(self):
        """Quote, statement-terminator and comment markers are removed from a value."""
        cleaned = sanitize_sql_value("'; DROP TABLE users;--")
        for forbidden in ("'", ";", "--"):
            with self.subTest(forbidden=forbidden):
                self.assertNotIn(forbidden, cleaned)

    def test_xss_detection(self):
        """Script tags, ``javascript:`` URLs and inline handlers are flagged as XSS."""
        for payload in (
            "<script>alert('xss')</script>",
            "javascript:alert(1)",
            "onclick=alert(1)",
        ):
            with self.subTest(payload=payload):
                is_xss, _ = detect_xss(payload)
                self.assertTrue(is_xss)

        is_xss, _ = detect_xss("normal text content")
        self.assertFalse(is_xss)

    def test_escape_html(self):
        """Markup is neutralised by entity-escaping rather than left as raw tags."""
        escaped = escape_html("<script>alert('xss')</script>")
        self.assertNotIn("<script>", escaped)
        # The opening angle bracket must appear as the entity "&lt;"; checking for
        # a raw "<" would contradict the escaping this test is meant to verify.
        self.assertIn("&lt;", escaped)

    def test_sanitize_html(self):
        """HTML tags are stripped from the input."""
        stripped = sanitize_html("<p>Hello <b>World</b></p>")
        for tag in ("<p>", "<b>"):
            with self.subTest(tag=tag):
                self.assertNotIn(tag, stripped)

    def test_sanitize_string(self):
        """Strings are trimmed, stripped of NUL bytes and truncated to ``max_length``."""
        self.assertEqual(sanitize_string(" hello world "), "hello world")
        self.assertNotIn("\x00", sanitize_string("hello\x00world"))
        self.assertEqual(len(sanitize_string("hello world", max_length=5)), 5)

    def test_sanitize_number(self):
        """Numeric text is parsed; non-numeric or out-of-range input yields None."""
        self.assertEqual(sanitize_number("123"), 123.0)
        self.assertEqual(sanitize_number("123.45"), 123.45)
        self.assertIsNone(sanitize_number("abc"))
        self.assertIsNone(sanitize_number("100", min_val=200))
        self.assertIsNone(sanitize_number("300", max_val=200))

    def test_validate_email(self):
        """Well-formed addresses pass; a missing ``@`` or local part fails."""
        self.assertTrue(validate_email_format("test@example.com"))
        for bad_address in ("not-an-email", "@example.com"):
            with self.subTest(address=bad_address):
                self.assertFalse(validate_email_format(bad_address))

    def test_validate_phone(self):
        """An 11-digit number starting with 138 passes; the malformed samples fail."""
        self.assertTrue(validate_phone_format("13812345678"))
        for bad_number in ("12345678901", "138"):
            with self.subTest(number=bad_number):
                self.assertFalse(validate_phone_format(bad_number))

    def test_file_upload_validation(self):
        """Small images are accepted; executables and oversized files are rejected."""
        # Ordinary image upload.
        ok, _ = validate_file_upload("test.jpg", "image/jpeg", 1024)
        self.assertTrue(ok)

        # Dangerous file type.
        ok, _ = validate_file_upload("hack.exe", "application/octet-stream", 1024)
        self.assertFalse(ok)

        # Oversized file (100 MiB).
        ok, _ = validate_file_upload("big.pdf", "application/pdf", 100 * 1024 * 1024)
        self.assertFalse(ok)
-
-
class TestAudit(unittest.TestCase):
    """Tests for audit events, the audit logger and the alert manager."""

    def test_audit_event_creation(self):
        """A new event gets a non-empty ``event_id`` and keeps the supplied fields."""
        event = AuditEvent(
            timestamp=datetime.now(timezone.utc),
            user_id="user1",
            action=AuditAction.LOGIN.value,
            resource="/auth/login",
            ip_address="127.0.0.1",
            result=AuditResult.SUCCESS,
        )
        self.assertGreater(len(event.event_id), 0)
        self.assertEqual(event.user_id, "user1")

    def test_audit_event_to_dict(self):
        """``to_dict`` exposes the timestamp, user id and event id."""
        event = AuditEvent(
            timestamp=datetime.now(timezone.utc),
            user_id="user1",
            action=AuditAction.LOGIN.value,
        )
        as_dict = event.to_dict()
        for key in ("timestamp", "user_id", "event_id"):
            with self.subTest(key=key):
                self.assertIn(key, as_dict)

    def test_audit_logger(self):
        """A logged action can be found again by querying on its user id."""
        logger = AuditLogger()
        logger.log_action(
            user_id="user1",
            action=AuditAction.LOGIN.value,
            ip_address="127.0.0.1",
        )
        entries = logger.query(user_id="user1")
        self.assertGreater(len(entries), 0)

    def test_audit_query(self):
        """Filtering by action returns only entries with that action."""
        logger = AuditLogger()
        logger.log_action(user_id="user1", action=AuditAction.LOGIN.value)
        logger.log_action(user_id="user2", action=AuditAction.DATA_READ.value)

        entries = logger.query(action=AuditAction.LOGIN.value)
        # Without this check an empty result would satisfy all() vacuously and
        # the test would pass even if the query matched nothing.
        self.assertGreater(len(entries), 0)
        self.assertTrue(
            all(entry["action"] == AuditAction.LOGIN.value for entry in entries)
        )

    def test_audit_alert(self):
        """Six consecutive failed logins from one user/IP produce at least one alert."""
        alert_mgr = AuditAlertManager()
        # Simulate repeated login failures; the per-call return value is not
        # needed because the accumulated alerts are inspected afterwards.
        for _ in range(6):
            event = AuditEvent(
                timestamp=datetime.now(timezone.utc),
                user_id="attacker",
                action=AuditAction.LOGIN_FAILED.value,
                ip_address="10.0.0.1",
                result=AuditResult.FAILURE,
            )
            alert_mgr.check_alert(event)

        self.assertGreater(len(alert_mgr.get_alerts()), 0)

    def test_audit_stats(self):
        """Statistics include the total event count and per-action counts."""
        logger = AuditLogger()
        logger.log_action(user_id="user1", action=AuditAction.LOGIN.value)
        logger.log_action(user_id="user1", action=AuditAction.DATA_READ.value)
        stats = logger.get_stats(hours=1)
        for key in ("total_events", "action_counts"):
            with self.subTest(key=key):
                self.assertIn(key, stats)
-
-
class TestMiddleware(unittest.TestCase):
    """Tests for the middleware utilities: rate limiting, CSRF tokens, client IP."""

    def test_rate_limiter(self):
        """Requests beyond ``max_requests`` are refused, independently per client."""
        limiter = RateLimiter(max_requests=3, window_seconds=60)

        for _ in range(3):
            self.assertTrue(limiter.is_allowed("127.0.0.1"))
        # The fourth request inside the window exceeds the limit.
        self.assertFalse(limiter.is_allowed("127.0.0.1"))

        # A different IP is not affected by the first one's usage.
        self.assertTrue(limiter.is_allowed("192.168.1.1"))

    def test_rate_limiter_remaining(self):
        """``get_remaining`` reports the limit minus the requests already made."""
        limiter = RateLimiter(max_requests=5, window_seconds=60)
        for _ in range(2):
            limiter.is_allowed("127.0.0.1")
        self.assertEqual(limiter.get_remaining("127.0.0.1"), 3)

    def test_csrf_token(self):
        """A CSRF token validates only against the session it was generated for."""
        session_id = "test-session-123"
        token = generate_csrf_token(session_id)
        self.assertTrue(validate_csrf_token(token, session_id))
        self.assertFalse(validate_csrf_token(token, "wrong-session"))
        self.assertFalse(validate_csrf_token("invalid-token", session_id))

    def test_get_client_ip(self):
        """Proxy headers take precedence; otherwise the fallback address is used."""
        fallback = "127.0.0.1"
        cases = (
            # X-Forwarded-For: the first address in the list is returned.
            ({"x-forwarded-for": "10.0.0.1, 10.0.0.2"}, "10.0.0.1"),
            # X-Real-IP.
            ({"x-real-ip": "10.0.0.3"}, "10.0.0.3"),
            # No proxy headers at all.
            ({}, "127.0.0.1"),
        )
        for headers, expected in cases:
            with self.subTest(headers=headers):
                self.assertEqual(get_client_ip(headers, fallback), expected)
-
-
class TestConfig(unittest.TestCase):
    """Tests for the loaded security configuration."""

    def test_default_config(self):
        """Every configuration section is present on the default config."""
        sections = ("jwt", "encryption", "rate_limit", "cors", "password_policy", "audit")
        for section in sections:
            with self.subTest(section=section):
                self.assertIsNotNone(getattr(security_config, section))

    def test_jwt_config(self):
        """JWT settings use HS256 and a positive access-token lifetime."""
        jwt_cfg = security_config.jwt
        self.assertEqual(jwt_cfg.algorithm, "HS256")
        self.assertGreater(jwt_cfg.access_token_expire_minutes, 0)

    def test_rate_limit_config(self):
        """The per-minute request limit is positive."""
        self.assertGreater(security_config.rate_limit.max_requests_per_minute, 0)
-
-
if __name__ == "__main__":
    # Run all test cases in this module; verbosity 2 lists each test by name.
    unittest.main(verbosity=2)
|