"""
安全模块单元测试
覆盖 JWT、权限、加密、输入验证、审计日志
"""
import sys
import os
import json
import time
import unittest
from datetime import datetime, timezone
# 将项目根目录加入路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.security.auth import (
Role, hash_password, verify_password,
create_access_token, create_refresh_token,
verify_access_token, verify_refresh_token,
get_user_roles, check_permission,
extract_token_from_header, validate_password_strength,
)
from src.security.encryption import (
FieldEncryptor, encrypt_field, decrypt_field,
mask_phone, mask_id_card, mask_email,
mask_bank_card, mask_name, mask_ip,
mask_response,
)
from src.security.input_validator import (
detect_sql_injection, sanitize_sql_value,
detect_xss, escape_html, sanitize_html,
sanitize_string, sanitize_number,
validate_email_format, validate_phone_format,
validate_file_upload,
)
from src.security.audit import (
AuditAction, AuditResult, AuditSeverity, AuditEvent,
AuditLogger, AuditStorage, AuditAlertManager,
)
from src.security.middleware import (
RateLimiter, generate_csrf_token, validate_csrf_token,
get_client_ip,
)
from src.security.config import security_config
class TestAuth(unittest.TestCase):
    """Tests for the authentication module: hashing, JWTs, roles, headers, password policy."""

    def test_password_hash(self):
        """A hash differs from the plaintext and verifies only the original password."""
        password = "TestPassword123"
        hashed = hash_password(password)
        self.assertNotEqual(password, hashed)
        self.assertTrue(verify_password(password, hashed))
        self.assertFalse(verify_password("wrong", hashed))

    def test_create_access_token(self):
        """create_access_token returns a string longer than 50 characters."""
        token = create_access_token("user1", [Role.ADMIN, Role.OPERATOR])
        self.assertIsInstance(token, str)
        # assertGreater reports the actual length on failure; assertTrue(len > 50) does not.
        self.assertGreater(len(token), 50)

    def test_verify_access_token(self):
        """A freshly issued access token decodes to its subject, type and roles."""
        token = create_access_token("user1", [Role.ADMIN])
        payload = verify_access_token(token)
        self.assertIsNotNone(payload)
        self.assertEqual(payload["sub"], "user1")
        self.assertEqual(payload["type"], "access")
        self.assertIn("admin", payload["roles"])

    def test_expired_token(self):
        """A correctly signed access token whose exp lies in the past is rejected."""
        # The token is signed directly with PyJWT (no config is modified) so that
        # the exp claim can be placed in the past. Key, issuer and algorithm come
        # from the module-level security_config so the signature itself is valid.
        import jwt as pyjwt

        now = int(time.time())  # one timestamp so iat/exp are mutually consistent
        payload = {
            "sub": "user1",
            "roles": ["admin"],
            "iat": now - 7200,
            "exp": now - 3600,  # expired one hour ago
            "iss": security_config.jwt.issuer,
            "type": "access",
        }
        token = pyjwt.encode(
            payload,
            security_config.jwt.secret_key,
            algorithm=security_config.jwt.algorithm,
        )
        self.assertIsNone(verify_access_token(token))

    def test_refresh_token(self):
        """A refresh token round-trips through verify_refresh_token."""
        token = create_refresh_token("user1")
        payload = verify_refresh_token(token)
        self.assertIsNotNone(payload)
        self.assertEqual(payload["sub"], "user1")
        self.assertEqual(payload["type"], "refresh")

    def test_token_type_mismatch(self):
        """An access token must not be accepted by verify_refresh_token."""
        access_token = create_access_token("user1", [Role.ADMIN])
        self.assertIsNone(verify_refresh_token(access_token))

    def test_get_user_roles(self):
        """Role names in a payload are converted to Role members."""
        payload = {"roles": ["admin", "operator"]}
        roles = get_user_roles(payload)
        self.assertEqual(len(roles), 2)
        self.assertIn(Role.ADMIN, roles)
        self.assertIn(Role.OPERATOR, roles)

    def test_check_permission(self):
        """ADMIN may manage the system, OPERATOR may write data, VIEWER may do neither."""
        self.assertTrue(check_permission([Role.ADMIN], "system:manage"))
        self.assertTrue(check_permission([Role.OPERATOR], "data:write"))
        self.assertFalse(check_permission([Role.VIEWER], "data:write"))
        self.assertFalse(check_permission([Role.VIEWER], "system:manage"))

    def test_extract_token(self):
        """Only a 'Bearer <token>' header yields a token; anything else yields None."""
        self.assertEqual(extract_token_from_header("Bearer abc123"), "abc123")
        self.assertIsNone(extract_token_from_header("Basic abc123"))
        self.assertIsNone(extract_token_from_header(""))
        self.assertIsNone(extract_token_from_header(None))

    def test_password_strength(self):
        """Password policy accepts a mixed password and rejects weak ones."""
        ok, _ = validate_password_strength("Abcdef1!")
        self.assertTrue(ok)
        ok, _ = validate_password_strength("123")
        self.assertFalse(ok)
        ok, _ = validate_password_strength("abcdefgh")
        self.assertFalse(ok)  # no uppercase letter and no digit
class TestEncryption(unittest.TestCase):
    """Tests for field encryption and the masking helpers."""

    def setUp(self):
        # Each test gets its own FieldEncryptor instance.
        self.encryptor = FieldEncryptor()

    def test_encrypt_decrypt(self):
        """Ciphertext differs from the input and decrypts back to it."""
        secret = "敏感数据123"
        cipher = self.encryptor.encrypt(secret)
        self.assertNotEqual(secret, cipher)
        self.assertEqual(secret, self.encryptor.decrypt(cipher))

    def test_encrypt_empty_string(self):
        """encrypt_field leaves an empty string as an empty string."""
        self.assertEqual(encrypt_field(""), "")

    def test_encrypt_field_roundtrip(self):
        """decrypt_field inverts encrypt_field."""
        phone_number = "13812345678"
        self.assertEqual(phone_number, decrypt_field(encrypt_field(phone_number)))

    def test_mask_phone(self):
        """Phone numbers hide the middle digits; short/empty input is unchanged."""
        for raw, expected in (
            ("13812345678", "138****5678"),
            ("138", "138"),
            ("", ""),
        ):
            self.assertEqual(mask_phone(raw), expected)

    def test_mask_id_card(self):
        """ID card numbers keep only the first and last four characters."""
        for raw, expected in (
            ("110101199001011234", "1101**********1234"),
            ("", ""),
        ):
            self.assertEqual(mask_id_card(raw), expected)

    def test_mask_email(self):
        """E-mail local parts collapse to their first character plus '***'."""
        for raw, expected in (
            ("test@example.com", "t***@example.com"),
            ("a@b.com", "a***@b.com"),
        ):
            self.assertEqual(mask_email(raw), expected)

    def test_mask_bank_card(self):
        """Bank card numbers keep only the first and last four digits."""
        self.assertEqual(mask_bank_card("6225880137073520"), "6225********3520")

    def test_mask_name(self):
        """Names keep all but the last character."""
        for raw, expected in (("张三", "张*"), ("欧阳修", "欧阳*")):
            self.assertEqual(mask_name(raw), expected)

    def test_mask_ip(self):
        """IPv4 addresses hide the last two octets."""
        self.assertEqual(mask_ip("192.168.1.100"), "192.168.*.*")

    def test_mask_response(self):
        """mask_response applies the per-field rule named in the rules mapping."""
        record = {"phone": "13812345678", "name": "张三", "email": "test@test.com"}
        field_rules = {"phone": "phone", "name": "name", "email": "email"}
        masked = mask_response(record, field_rules)
        self.assertEqual(masked["phone"], "138****5678")
        self.assertEqual(masked["name"], "张*")
        self.assertIn("***", masked["email"])
class TestInputValidator(unittest.TestCase):
"""输入验证测试"""
def test_sql_injection_detection(self):
    """Classic injection payloads are flagged; ordinary text is not."""
    for attack in ("'; DROP TABLE users; --", "1 OR 1=1"):
        flagged, _ = detect_sql_injection(attack)
        self.assertTrue(flagged)
    flagged, _ = detect_sql_injection("normal input value")
    self.assertFalse(flagged)
def test_sanitize_sql(self):
    """Sanitized output contains no quote, semicolon or '--' comment marker."""
    cleaned = sanitize_sql_value("'; DROP TABLE users;--")
    for forbidden in ("'", ";", "--"):
        self.assertNotIn(forbidden, cleaned)
def test_xss_detection(self):
    """Script tags, javascript: URIs and inline handlers are flagged; plain text is not."""
    for payload in (
        # A real <script> payload: an empty string carries no markup to detect.
        "<script>alert(1)</script>",
        "javascript:alert(1)",
        "onclick=alert(1)",
    ):
        is_xss, _ = detect_xss(payload)
        self.assertTrue(is_xss)
    is_xss, _ = detect_xss("normal text content")
    self.assertFalse(is_xss)
def test_escape_html(self):
"""HTML 转义"""
result = escape_html("")
self.assertNotIn("