Skip to content

第45章 安全编程

学习目标

完成本章学习后,你将能够:

  1. 掌握密码学基础:对称加密、非对称加密、哈希算法、数字签名
  2. 实现安全认证:密码存储、JWT、OAuth2、多因素认证
  3. 防范常见攻击:SQL注入、XSS、CSRF、命令注入
  4. 进行安全编码:输入验证、输出编码、安全配置
  5. 实施安全审计:日志记录、入侵检测、漏洞扫描
  6. 进行渗透测试:安全测试、漏洞利用、修复建议
  7. 遵循OWASP指南:Top 10漏洞、安全最佳实践
  8. 实现安全通信:TLS/SSL、证书管理、安全API

45.1 密码学基础

45.1.1 加密与解密

python
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from typing import Tuple, Optional
import base64
import os
import secrets
import hashlib
from dataclasses import dataclass


@dataclass
class EncryptedData:
    """Plain container grouping the parts of an encrypted message.

    NOTE(review): no other class in the visible part of this file constructs
    or consumes this type.
    """

    # Encrypted payload.
    ciphertext: bytes
    # presumably the initialization vector / nonce used for encryption -- TODO confirm
    iv: bytes
    # Optional, defaults to None; presumably an AEAD authentication tag -- TODO confirm
    tag: Optional[bytes] = None


class SymmetricEncryption:
    """String-oriented convenience wrapper around a ``Fernet`` instance.

    Attributes:
        key: The Fernet key in use (supplied by the caller or generated).
        fernet: The ``Fernet`` object built from ``key``.
    """

    def __init__(self, key: bytes = None):
        """Use *key* when given; otherwise generate a fresh Fernet key."""
        # Any falsy key (None, b"") triggers generation of a new one.
        if not key:
            key = Fernet.generate_key()
        self.key = key
        self.fernet = Fernet(self.key)

    @classmethod
    def generate_key(cls) -> bytes:
        """Return a newly generated Fernet key."""
        fresh_key = Fernet.generate_key()
        return fresh_key

    def encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext* and return the Fernet token as text."""
        token = self.fernet.encrypt(plaintext.encode())
        return token.decode()

    def decrypt(self, ciphertext: str) -> str:
        """Decrypt a token produced by :meth:`encrypt` and return the text."""
        recovered = self.fernet.decrypt(ciphertext.encode())
        return recovered.decode()

    def encrypt_dict(self, data: dict) -> str:
        """Serialize *data* to JSON and encrypt the resulting string."""
        import json
        return self.encrypt(json.dumps(data))

    def decrypt_dict(self, ciphertext: str) -> dict:
        """Decrypt *ciphertext* and parse the plaintext as JSON."""
        import json
        return json.loads(self.decrypt(ciphertext))


class AsymmetricEncryption:
    """RSA helper: OAEP encryption, PSS signatures and PEM key export.

    Either key may be ``None``; operations that need a missing key fail when
    they try to use it.
    """

    def __init__(self, private_key=None, public_key=None):
        """Store the given key objects without validating them."""
        self.private_key = private_key
        self.public_key = public_key

    @classmethod
    def generate_key_pair(cls, key_size: int = 2048) -> "AsymmetricEncryption":
        """Generate a new RSA key pair (public exponent 65537) and wrap it."""
        generated = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        return cls(private_key=generated, public_key=generated.public_key())

    @staticmethod
    def _oaep_padding():
        """Build the OAEP padding (SHA-256 / MGF1-SHA-256, no label)."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    @staticmethod
    def _pss_padding():
        """Build the PSS padding (MGF1-SHA-256, maximum salt length)."""
        return padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH,
        )

    def encrypt(self, plaintext: bytes) -> bytes:
        """Encrypt *plaintext* with the public key using OAEP."""
        return self.public_key.encrypt(plaintext, self._oaep_padding())

    def decrypt(self, ciphertext: bytes) -> bytes:
        """Decrypt *ciphertext* with the private key using OAEP."""
        return self.private_key.decrypt(ciphertext, self._oaep_padding())

    def sign(self, message: bytes) -> bytes:
        """Sign *message* with the private key (PSS, SHA-256)."""
        return self.private_key.sign(
            message,
            self._pss_padding(),
            hashes.SHA256(),
        )

    def verify(self, message: bytes, signature: bytes) -> bool:
        """Return True when *signature* verifies for *message*.

        Any exception raised during verification (including a missing public
        key) is reported as ``False``.
        """
        try:
            self.public_key.verify(
                signature,
                message,
                self._pss_padding(),
                hashes.SHA256(),
            )
        except Exception:
            return False
        return True

    def export_private_key(self, password: str = None) -> bytes:
        """Serialize the private key as PKCS8 PEM.

        A truthy *password* encrypts the PEM with the best available scheme;
        otherwise the key is exported unencrypted.
        """
        protection = (
            serialization.BestAvailableEncryption(password.encode())
            if password
            else serialization.NoEncryption()
        )
        return self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=protection,
        )

    def export_public_key(self) -> bytes:
        """Serialize the public key as SubjectPublicKeyInfo PEM."""
        return self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )


class HashUtils:
    """Static helpers returning lowercase hex digests of text input.

    Text is encoded with ``str.encode()`` (UTF-8) before hashing.
    """

    @staticmethod
    def sha256(data: str) -> str:
        """Return the SHA-256 hex digest of *data*."""
        hasher = hashlib.new("sha256")
        hasher.update(data.encode())
        return hasher.hexdigest()

    @staticmethod
    def sha512(data: str) -> str:
        """Return the SHA-512 hex digest of *data*."""
        hasher = hashlib.new("sha512")
        hasher.update(data.encode())
        return hasher.hexdigest()

    @staticmethod
    def md5(data: str) -> str:
        """Return the MD5 hex digest of *data*.

        MD5 is collision-broken; use it only for non-security checksums.
        """
        hasher = hashlib.new("md5")
        hasher.update(data.encode())
        return hasher.hexdigest()

    @staticmethod
    def hmac_sha256(key: bytes, message: str) -> str:
        """Return the HMAC-SHA256 hex digest of *message* under *key*."""
        import hmac
        mac = hmac.new(key, digestmod=hashlib.sha256)
        mac.update(message.encode())
        return mac.hexdigest()


class PasswordHasher:
    """PBKDF2-HMAC-SHA256 password hashing.

    The stored form is ``base64(salt || derived_key)`` with a 16-byte random
    salt and a 32-byte derived key. The iteration count is NOT embedded in
    the stored value, so a hash can only be verified by an instance
    configured with the same ``iterations`` that produced it.
    """

    _SALT_BYTES = 16
    _KEY_BYTES = 32

    def __init__(self, iterations: int = 100000):
        """Remember the PBKDF2 iteration count used by hash() and verify()."""
        self.iterations = iterations

    def _new_kdf(self, salt: bytes):
        """Build a single-use PBKDF2 instance for *salt*."""
        return PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=self._KEY_BYTES,
            salt=salt,
            iterations=self.iterations,
            backend=default_backend()
        )

    def hash(self, password: str) -> str:
        """Derive a key from *password* with a fresh random salt.

        Returns:
            Base64 text encoding ``salt + derived_key``.
        """
        salt = os.urandom(self._SALT_BYTES)
        key = self._new_kdf(salt).derive(password.encode())
        return base64.b64encode(salt + key).decode()

    def verify(self, password: str, stored_hash: str) -> bool:
        """Check *password* against a value produced by :meth:`hash`.

        Returns False (instead of raising) when *stored_hash* is not a
        string, is not valid base64, or does not decode to exactly
        salt + key bytes.
        """
        import binascii

        try:
            decoded = base64.b64decode(stored_hash.encode())
        except (AttributeError, TypeError, ValueError, binascii.Error):
            # Malformed or non-string stored value: treat as "no match".
            return False

        if len(decoded) != self._SALT_BYTES + self._KEY_BYTES:
            return False

        salt = decoded[:self._SALT_BYTES]
        stored_key = decoded[self._SALT_BYTES:]

        kdf = self._new_kdf(salt)
        try:
            # kdf.verify raises when the derived key does not match.
            kdf.verify(password.encode(), stored_key)
        except Exception:
            return False
        return True

    def generate_random_password(self, length: int = 16) -> str:
        """Return *length* characters drawn with ``secrets`` from letters,
        digits and ``!@#$%^&*``."""
        import string
        alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
        return ''.join(secrets.choice(alphabet) for _ in range(length))


class SecureToken:
    """Helpers for generating random tokens and comparing them safely."""

    @staticmethod
    def generate(length: int = 32) -> str:
        """Return a hex token built from *length* random bytes.

        The returned string is ``2 * length`` characters long.
        """
        return secrets.token_hex(length)

    @staticmethod
    def generate_url_safe(length: int = 32) -> str:
        """Return a URL-safe base64 token built from *length* random bytes."""
        return secrets.token_urlsafe(length)

    @staticmethod
    def compare(a: str, b: str) -> bool:
        """Compare two tokens in constant time.

        ``secrets.compare_digest`` rejects ``str`` arguments containing
        non-ASCII characters with ``TypeError``; strings are therefore
        encoded to UTF-8 bytes first so any text can be compared.
        Bytes-like arguments are passed through unchanged.
        """
        left = a.encode("utf-8") if isinstance(a, str) else a
        right = b.encode("utf-8") if isinstance(b, str) else b
        return secrets.compare_digest(left, right)

45.1.2 JWT认证

python
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import jwt


@dataclass
class TokenPayload:
    """Typed record describing the contents of a token.

    NOTE(review): JWTManager in this file builds and returns plain dicts and
    does not use this class; presumably it is meant as a typed view of the
    decoded claims -- confirm against callers.
    """

    # presumably the value carried in the "sub" claim -- TODO confirm
    user_id: str
    username: str
    roles: List[str]
    # presumably mirror the "exp" / "iat" claims -- TODO confirm
    exp: datetime
    iat: datetime


class JWTManager:
    """Create and validate signed JWT access and refresh tokens.

    Tokens carry a ``type`` claim (``"access"`` or ``"refresh"``) in addition
    to ``sub``, ``iat`` and ``exp``.

    Args:
        secret_key: Key used to sign and verify tokens.
        algorithm: JWT signing algorithm name.
        access_token_expire: Access-token lifetime in minutes.
        refresh_token_expire: Refresh-token lifetime in days.
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        access_token_expire: int = 30,
        refresh_token_expire: int = 7
    ):
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.access_token_expire = access_token_expire
        self.refresh_token_expire = refresh_token_expire

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Replaces the deprecated ``datetime.utcnow()``; the encoded numeric
        ``iat``/``exp`` values are the same.
        """
        from datetime import timezone
        return datetime.now(timezone.utc)

    def _encode(self, payload: Dict) -> str:
        """Sign *payload* with the configured key and algorithm."""
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def create_access_token(
        self,
        user_id: str,
        username: str,
        roles: List[str] = None,
        additional_claims: Dict = None
    ) -> str:
        """Issue an access token for *user_id*.

        ``additional_claims`` are merged last, so they can add to or
        override the standard claims set here.
        """
        now = self._now()
        payload = {
            "sub": user_id,
            "username": username,
            "roles": roles or [],
            "type": "access",
            "iat": now,
            "exp": now + timedelta(minutes=self.access_token_expire)
        }

        if additional_claims:
            payload.update(additional_claims)

        return self._encode(payload)

    def create_refresh_token(self, user_id: str) -> str:
        """Issue a refresh token carrying only ``sub``, ``type``, ``iat``, ``exp``."""
        now = self._now()
        payload = {
            "sub": user_id,
            "type": "refresh",
            "iat": now,
            "exp": now + timedelta(days=self.refresh_token_expire)
        }
        return self._encode(payload)

    def decode_token(self, token: str) -> Optional[Dict]:
        """Return the verified claims of *token*, or None if it is invalid.

        Expired tokens are reported as None too: ``ExpiredSignatureError``
        is a subclass of ``InvalidTokenError``.
        """
        try:
            return jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm]
            )
        except jwt.InvalidTokenError:
            return None

    def verify_token(self, token: str, token_type: str = "access") -> bool:
        """Return True when *token* is valid and its ``type`` claim matches."""
        payload = self.decode_token(token)
        if not payload:
            return False
        return payload.get("type") == token_type

    def refresh_access_token(
        self,
        refresh_token: str,
        username: str = "",
        roles: List[str] = None
    ) -> Optional[str]:
        """Exchange a valid refresh token for a new access token.

        Refresh tokens carry no username or roles, so callers that need
        those claims in the new access token should look them up and pass
        them in; by default they are left empty.

        Returns:
            The new access token, or None when the refresh token is invalid,
            is not of type ``"refresh"``, or has no ``sub`` claim.
        """
        payload = self.decode_token(refresh_token)
        if not payload or payload.get("type") != "refresh":
            return None

        user_id = payload.get("sub")
        if not user_id:
            # Never mint an access token that identifies nobody.
            return None
        return self.create_access_token(user_id, username, roles)

    def get_user_id(self, token: str) -> Optional[str]:
        """Return the ``sub`` claim of a valid token, else None."""
        payload = self.decode_token(token)
        if payload:
            return payload.get("sub")
        return None

    def get_roles(self, token: str) -> List[str]:
        """Return the ``roles`` claim of a valid token, else an empty list."""
        payload = self.decode_token(token)
        if payload:
            return payload.get("roles", [])
        return []


class APIKeyManager:
    """In-memory registry of API keys and their metadata.

    Records live in a dict keyed by the API key string. Lookups return
    copies so callers cannot alter the stored permissions or expiry.
    """

    def __init__(self):
        self._api_keys: Dict[str, Dict] = {}

    @staticmethod
    def _snapshot(record: Dict) -> Dict:
        """Return a copy of *record* whose permissions list is also copied."""
        copied = dict(record)
        copied["permissions"] = list(record["permissions"])
        return copied

    def generate_api_key(
        self,
        user_id: str,
        name: str,
        permissions: List[str] = None,
        expires_days: int = 365
    ) -> str:
        """Create, store and return a new ``pk_``-prefixed API key.

        Args:
            user_id: Owner of the key.
            name: Label stored with the key.
            permissions: Permission strings; defaults to an empty list.
            expires_days: Lifetime in days counted from creation.
        """
        api_key = f"pk_{secrets.token_urlsafe(32)}"

        # One timestamp for both fields so expires_at - created_at is exact.
        created_at = datetime.utcnow()
        self._api_keys[api_key] = {
            "user_id": user_id,
            "name": name,
            # Copy so later changes to the caller's list do not leak in.
            "permissions": list(permissions) if permissions else [],
            "created_at": created_at,
            "expires_at": created_at + timedelta(days=expires_days)
        }

        return api_key

    def verify_api_key(self, api_key: str) -> Optional[Dict]:
        """Return a copy of the key's metadata, or None if unknown or expired.

        An expired key is removed from the registry as a side effect.
        """
        key_info = self._api_keys.get(api_key)
        if key_info is None:
            return None

        if key_info["expires_at"] < datetime.utcnow():
            del self._api_keys[api_key]
            return None

        return self._snapshot(key_info)

    def revoke_api_key(self, api_key: str) -> bool:
        """Delete *api_key*; return True if it existed."""
        return self._api_keys.pop(api_key, None) is not None

    def list_user_keys(self, user_id: str) -> List[Dict]:
        """Return copies of all records owned by *user_id*, each with its
        key string under ``"key"``."""
        return [
            {"key": key, **self._snapshot(record)}
            for key, record in self._api_keys.items()
            if record["user_id"] == user_id
        ]

45.2 安全编码实践

45.2.1 输入验证

python
import re
from typing import Any, List, Optional, Callable
from dataclasses import dataclass
from enum import Enum


class ValidationError(Exception):
    """Raised by validation rules; carries the field name and the message.

    ``str(error)`` renders as ``"<field>: <message>"``.
    """

    def __init__(self, field: str, message: str):
        super().__init__(f"{field}: {message}")
        # Keep both parts available separately for programmatic handling.
        self.message = message
        self.field = field


@dataclass
class ValidationResult:
    """Outcome of a validation: a flag plus the list of error messages.

    ``errors`` is always a list after construction (empty when omitted or
    passed as None), so callers can iterate or index it without a None
    check. The instance is truthy exactly when ``is_valid`` is true.
    """

    is_valid: bool
    errors: Optional[List[str]] = None

    def __post_init__(self) -> None:
        # Normalize the "no errors" case to a fresh per-instance list.
        if self.errors is None:
            self.errors = []

    def __bool__(self) -> bool:
        return self.is_valid


class InputValidator:
    """Static validators and sanitizers for common user-supplied fields.

    Pattern checks use ``re.fullmatch`` so that a trailing newline, which
    ``$`` alone tolerates, cannot slip past validation.
    """

    @staticmethod
    def _fully_matches(pattern: str, value) -> bool:
        """Return True when *value* is a string matched entirely by *pattern*."""
        return isinstance(value, str) and re.fullmatch(pattern, value) is not None

    @staticmethod
    def validate_email(email: str) -> ValidationResult:
        """Check *email* against a basic ``local@domain.tld`` pattern."""
        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        if InputValidator._fully_matches(pattern, email):
            return ValidationResult(is_valid=True)
        return ValidationResult(is_valid=False, errors=["Invalid email format"])

    @staticmethod
    def validate_password(
        password: str,
        min_length: int = 8,
        require_upper: bool = True,
        require_lower: bool = True,
        require_digit: bool = True,
        require_special: bool = True
    ) -> ValidationResult:
        """Check *password* against the enabled strength requirements.

        Every failed requirement contributes one message to ``errors``.
        """
        errors = []

        if len(password) < min_length:
            errors.append(f"Password must be at least {min_length} characters")

        if require_upper and not re.search(r'[A-Z]', password):
            errors.append("Password must contain at least one uppercase letter")

        if require_lower and not re.search(r'[a-z]', password):
            errors.append("Password must contain at least one lowercase letter")

        if require_digit and not re.search(r'\d', password):
            errors.append("Password must contain at least one digit")

        if require_special and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
            errors.append("Password must contain at least one special character")

        return ValidationResult(is_valid=not errors, errors=errors)

    @staticmethod
    def validate_username(username: str) -> ValidationResult:
        """Require 3-20 characters made of letters, digits and underscores."""
        errors = []

        if len(username) < 3:
            errors.append("Username must be at least 3 characters")
        if len(username) > 20:
            errors.append("Username must be at most 20 characters")
        if not re.fullmatch(r'^[a-zA-Z0-9_]+$', username):
            errors.append("Username can only contain letters, numbers, and underscores")

        return ValidationResult(is_valid=not errors, errors=errors)

    @staticmethod
    def validate_phone(phone: str) -> ValidationResult:
        """Accept an optional ``+``, an optional ``1``, then 9-15 digits."""
        pattern = r'^\+?1?\d{9,15}$'
        if InputValidator._fully_matches(pattern, phone):
            return ValidationResult(is_valid=True)
        return ValidationResult(is_valid=False, errors=["Invalid phone number format"])

    @staticmethod
    def validate_url(url: str) -> ValidationResult:
        """Accept ``http://`` / ``https://`` URLs that contain no whitespace."""
        pattern = r'^https?://[^\s/$.?#].[^\s]*$'
        if InputValidator._fully_matches(pattern, url):
            return ValidationResult(is_valid=True)
        return ValidationResult(is_valid=False, errors=["Invalid URL format"])

    @staticmethod
    def sanitize_html(html: str) -> str:
        """HTML-escape *html* (``&``, ``<``, ``>`` and both quote characters)."""
        import html as html_module
        return html_module.escape(html)

    @staticmethod
    def sanitize_sql(value: str) -> str:
        """Strip a fixed set of SQL metacharacters and prefixes from *value*.

        Stripping repeats until the text stops changing: a single pass can
        leave a forbidden token behind when removing one token joins the
        pieces of another (``"xsp_p_"`` -> ``"xp_"``).

        NOTE: blacklist stripping is defense in depth only. Bind values
        with parameterized queries instead of relying on this function.
        """
        dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
        sanitized = value
        previous = None
        # Each changing pass shortens the string, so the loop terminates.
        while sanitized != previous:
            previous = sanitized
            for char in dangerous_chars:
                sanitized = sanitized.replace(char, "")
        return sanitized


class SchemaValidator:
    """Collects per-field rule callables and runs them against a dict.

    A rule is any callable taking the field value; it signals failure by
    raising ``ValidationError``.
    """

    def __init__(self):
        self._rules: Dict[str, List[Callable]] = {}

    def field(self, name: str) -> "FieldValidator":
        """Start a fluent rule definition for the field called *name*."""
        return FieldValidator(name, self)

    def add_rule(self, field: str, rule: Callable) -> None:
        """Append *rule* to the rules registered for *field*."""
        self._rules.setdefault(field, []).append(rule)

    def validate(self, data: Dict) -> ValidationResult:
        """Run every registered rule against *data*.

        Missing fields are passed to their rules as ``None``. All
        ``ValidationError`` messages are collected; other exceptions
        propagate to the caller.
        """
        failures = []

        for field_name, checks in self._rules.items():
            candidate = data.get(field_name)
            for check in checks:
                try:
                    check(candidate)
                except ValidationError as exc:
                    failures.append(str(exc))

        return ValidationResult(is_valid=not failures, errors=failures)


class FieldValidator:
    """Fluent builder that registers rules for one field on a schema.

    Every rule method adds a closure to the owning ``SchemaValidator`` and
    returns ``self`` so calls can be chained. Except for ``required``, the
    built-in rules skip falsy values.
    """

    def __init__(self, name: str, schema: SchemaValidator):
        self.name = name
        self.schema = schema

    def _register(self, rule: Callable) -> "FieldValidator":
        """Attach *rule* to this field on the schema and return self."""
        self.schema.add_rule(self.name, rule)
        return self

    def required(self) -> "FieldValidator":
        """Reject ``None`` and the empty string."""
        def check_present(value):
            missing = value is None or value == ""
            if missing:
                raise ValidationError(self.name, "This field is required")
        return self._register(check_present)

    def min_length(self, length: int) -> "FieldValidator":
        """Require ``len(str(value))`` to be at least *length*."""
        def check_min(value):
            if not value:
                return
            if len(str(value)) < length:
                raise ValidationError(self.name, f"Minimum length is {length}")
        return self._register(check_min)

    def max_length(self, length: int) -> "FieldValidator":
        """Require ``len(str(value))`` to be at most *length*."""
        def check_max(value):
            if not value:
                return
            if len(str(value)) > length:
                raise ValidationError(self.name, f"Maximum length is {length}")
        return self._register(check_max)

    def email(self) -> "FieldValidator":
        """Require the value to pass ``InputValidator.validate_email``."""
        def check_email(value):
            if not value:
                return
            outcome = InputValidator.validate_email(value)
            if not outcome.is_valid:
                raise ValidationError(self.name, outcome.errors[0])
        return self._register(check_email)

    def pattern(self, regex: str) -> "FieldValidator":
        """Require ``str(value)`` to match *regex* from its start (``re.match``)."""
        def check_pattern(value):
            if not value:
                return
            if not re.match(regex, str(value)):
                raise ValidationError(self.name, "Invalid format")
        return self._register(check_pattern)

    def custom(self, validator: Callable) -> "FieldValidator":
        """Register a caller-supplied rule callable unchanged."""
        return self._register(validator)

45.2.2 SQL注入防护

python
import sqlite3
from typing import List, Dict, Any, Optional
from contextlib import contextmanager


class SecureDatabase:
    """Small sqlite3 helper that keeps untrusted data out of SQL text.

    Values are always bound through ``?`` placeholders. Identifiers (table
    names, column names, ORDER BY terms) cannot be bound as parameters, so
    they are validated against a strict pattern and rejected with
    ``ValueError`` otherwise. ``where`` arguments are raw SQL fragments
    written by the developer: put ``?`` placeholders in them and pass the
    values through ``where_params``.

    A new connection is opened and closed for every operation.
    """

    def __init__(self, connection_string: str):
        self.connection_string = connection_string

    @staticmethod
    def _identifier(name: str) -> str:
        """Return *name* if it is a plain (optionally dotted) SQL identifier.

        Raises:
            ValueError: if *name* contains anything but letters, digits and
                underscores (with at most one ``schema.`` prefix).
        """
        import re
        part = r"[A-Za-z_][A-Za-z0-9_]*"
        if not isinstance(name, str) or not re.fullmatch(rf"{part}(\.{part})?", name):
            raise ValueError(f"Unsafe SQL identifier: {name!r}")
        return name

    @classmethod
    def _order_by_clause(cls, order_by: str) -> str:
        """Validate a comma-separated ``column [ASC|DESC]`` list and rebuild it."""
        terms = []
        for raw_term in order_by.split(","):
            words = raw_term.split()
            if not words or len(words) > 2:
                raise ValueError(f"Unsafe ORDER BY term: {raw_term!r}")
            term = cls._identifier(words[0])
            if len(words) == 2:
                direction = words[1].upper()
                if direction not in ("ASC", "DESC"):
                    raise ValueError(f"Unsafe ORDER BY direction: {words[1]!r}")
                term = f"{term} {direction}"
            terms.append(term)
        return ", ".join(terms)

    @contextmanager
    def get_connection(self):
        """Yield a connection using ``sqlite3.Row`` rows; always close it."""
        conn = sqlite3.connect(self.connection_string)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def execute_safe(
        self,
        query: str,
        params: tuple = None
    ) -> List[Dict]:
        """Execute *query* with bound *params* and return rows as dicts.

        The transaction is committed before the connection closes, so
        data-modifying statements run through this method are persisted.
        """
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(query, params or ())
            rows = [dict(row) for row in cursor.fetchall()]
            conn.commit()
            return rows

    def insert(self, table: str, data: Dict) -> int:
        """Insert one row and return its ``lastrowid``.

        Raises:
            ValueError: if *data* is empty or any identifier is unsafe.
        """
        if not data:
            raise ValueError("insert() requires at least one column")
        safe_table = self._identifier(table)
        columns = ", ".join(self._identifier(col) for col in data)
        placeholders = ", ".join("?" for _ in data)
        query = f"INSERT INTO {safe_table} ({columns}) VALUES ({placeholders})"

        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(query, tuple(data.values()))
            conn.commit()
            return cursor.lastrowid

    def update(
        self,
        table: str,
        data: Dict,
        where: str,
        where_params: tuple = None
    ) -> int:
        """Update rows matching *where*; return the number of rows changed.

        Raises:
            ValueError: if *data* is empty or any identifier is unsafe.
        """
        if not data:
            raise ValueError("update() requires at least one column")
        safe_table = self._identifier(table)
        set_clause = ", ".join(f"{self._identifier(col)} = ?" for col in data)
        query = f"UPDATE {safe_table} SET {set_clause} WHERE {where}"

        # SET values come first, then the WHERE values, matching ? order.
        params = tuple(data.values())
        if where_params:
            params = params + tuple(where_params)

        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(query, params)
            conn.commit()
            return cursor.rowcount

    def delete(
        self,
        table: str,
        where: str,
        where_params: tuple = None
    ) -> int:
        """Delete rows matching *where*; return the number of rows removed."""
        query = f"DELETE FROM {self._identifier(table)} WHERE {where}"

        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(query, where_params or ())
            conn.commit()
            return cursor.rowcount

    def select(
        self,
        table: str,
        columns: List[str] = None,
        where: str = None,
        where_params: tuple = None,
        order_by: str = None,
        limit: int = None
    ) -> List[Dict]:
        """Run a SELECT and return the rows as dicts.

        Args:
            table: Table name (validated).
            columns: Column names (validated; ``"*"`` allowed). All columns
                when omitted.
            where: Raw condition with ``?`` placeholders.
            where_params: Values for the placeholders in *where*.
            order_by: Comma-separated ``column [ASC|DESC]`` terms (validated).
            limit: Maximum row count; ``0`` returns no rows, ``None`` means
                no limit.

        Raises:
            ValueError: if an identifier is unsafe or *limit* is not an int.
        """
        if columns:
            cols = ", ".join(
                col if col == "*" else self._identifier(col) for col in columns
            )
        else:
            cols = "*"
        query = f"SELECT {cols} FROM {self._identifier(table)}"

        if where:
            query += f" WHERE {where}"
        if order_by:
            query += f" ORDER BY {self._order_by_clause(order_by)}"
        if limit is not None:
            # int() guarantees only a number reaches the SQL text.
            query += f" LIMIT {int(limit)}"

        return self.execute_safe(query, where_params)


class SQLInjectionPrevention:
    """Heuristic injection detection plus helpers for building safe queries."""

    @staticmethod
    def detect_injection(input_string: str) -> bool:
        """Return True if *input_string* matches any SQL-injection heuristic.

        The patterns are deliberately broad (SQL keywords, comment markers,
        ``;`` and ``|``, ``OR``/``AND`` followed by ``=``), so ordinary text
        can trigger false positives. Use this as a signal for logging or
        alerting, not as the only defense.
        """
        patterns = [
            r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER|CREATE|TRUNCATE)\b)",
            r"(--|#|/\*|\*/)",
            r"(\bOR\b|\bAND\b).*?=",
            r"(\bunion\b.*?\bselect\b)",
            r"(\bexec\b|\bexecute\b)",
            r"(\bxp_|sp_)",
            r"(;|\|)",
        ]
        return any(
            re.search(pattern, input_string, re.IGNORECASE)
            for pattern in patterns
        )

    @staticmethod
    def escape_identifier(identifier: str) -> str:
        """Quote *identifier* as a SQL delimited identifier.

        Embedded double quotes are doubled and the result is wrapped in
        double quotes. Built with plain concatenation because backslashes
        inside an f-string expression are a SyntaxError before Python 3.12.
        """
        return '"' + identifier.replace('"', '""') + '"'

    @staticmethod
    def build_safe_query(
        table: str,
        columns: List[str],
        where_conditions: Dict[str, Any]
    ) -> tuple:
        """Build a parameterized SELECT with quoted identifiers.

        Args:
            table: Table name.
            columns: Column names to select; must not be empty.
            where_conditions: Mapping of column -> value, combined with
                ``AND`` as equality tests. When empty, no WHERE clause is
                emitted.

        Returns:
            ``(query, params)`` where *params* is a tuple matching the ``?``
            placeholders in *query*.

        Raises:
            ValueError: if *columns* is empty.
        """
        if not columns:
            raise ValueError("columns must not be empty")

        escape = SQLInjectionPrevention.escape_identifier
        safe_columns = ", ".join(escape(col) for col in columns)
        query = f"SELECT {safe_columns} FROM {escape(table)}"

        where_parts = [f"{escape(col)} = ?" for col in where_conditions]
        if where_parts:
            query += " WHERE " + " AND ".join(where_parts)

        return query, tuple(where_conditions.values())

45.3 OWASP Top 10防护

45.3.1 常见攻击防护

python
from dataclasses import dataclass
from typing import List, Dict, Optional
import re
import html


@dataclass
class SecurityHeaders:
    """Values for a set of security-related HTTP response headers.

    Each attribute holds the value of one header; :meth:`to_dict` maps them
    to their HTTP header names.
    """

    content_type_options: str = "nosniff"
    frame_options: str = "DENY"
    xss_protection: str = "1; mode=block"
    content_security_policy: str = "default-src 'self'"
    strict_transport_security: str = "max-age=31536000; includeSubDomains"

    def to_dict(self) -> Dict[str, str]:
        """Return the headers as a ``{header-name: value}`` dict."""
        header_to_attribute = (
            ("X-Content-Type-Options", "content_type_options"),
            ("X-Frame-Options", "frame_options"),
            ("X-XSS-Protection", "xss_protection"),
            ("Content-Security-Policy", "content_security_policy"),
            ("Strict-Transport-Security", "strict_transport_security"),
        )
        return {
            header: getattr(self, attribute)
            for header, attribute in header_to_attribute
        }


class XSSPrevention:
    """Helpers for escaping, stripping and screening HTML-bound text."""

    # Compiled once: script/iframe elements, javascript: URLs and inline
    # event-handler attributes.
    _STRIP_PATTERNS = (
        re.compile(r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL),
        re.compile(r'javascript:', re.IGNORECASE),
        re.compile(r'on\w+\s*=', re.IGNORECASE),
        re.compile(r'<iframe[^>]*>.*?</iframe>', re.IGNORECASE | re.DOTALL),
    )

    @staticmethod
    def escape_html(text: str) -> str:
        """HTML-escape *text*, including both quote characters."""
        return html.escape(text, quote=True)

    @staticmethod
    def sanitize_input(text: str) -> str:
        """Remove script/iframe elements, ``javascript:`` and ``on*=`` handlers.

        Stripping repeats until the text stops changing. A single pass is
        defeated by nested payloads, e.g. ``<scr<script></script>ipt>``
        collapses into ``<script>`` once the inner element is removed.

        NOTE: pattern stripping is a secondary defense. Escape output with
        :meth:`escape_html` (or a templating engine's auto-escaping) first.
        """
        sanitized = text
        while True:
            stripped = sanitized
            for pattern in XSSPrevention._STRIP_PATTERNS:
                stripped = pattern.sub('', stripped)
            # Every changing pass shortens the text, so this terminates.
            if stripped == sanitized:
                return sanitized
            sanitized = stripped

    @staticmethod
    def validate_content(content: str) -> bool:
        """Return False if *content* contains any known XSS marker."""
        xss_patterns = [
            r'<script',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe',
            r'<object',
            r'<embed',
            r'<link',
            r'<style',
        ]
        return not any(
            re.search(pattern, content, re.IGNORECASE)
            for pattern in xss_patterns
        )


class CSRFProtection:
    """In-memory per-session CSRF token store with expiry.

    Args:
        token_ttl_seconds: Lifetime of a token in seconds (default one hour).
    """

    def __init__(self, token_ttl_seconds: int = 3600):
        self._tokens: Dict[str, Dict] = {}
        self._token_ttl = timedelta(seconds=token_ttl_seconds)

    def generate_token(self, session_id: str) -> str:
        """Create, store and return a token for *session_id*.

        Any token previously stored for the session is replaced.
        """
        token = secrets.token_urlsafe(32)
        self._tokens[session_id] = {
            "token": token,
            "created_at": datetime.utcnow()
        }
        return token

    def validate_token(self, session_id: str, token: str) -> bool:
        """Return True when *token* matches the live token of *session_id*.

        Expired tokens are deleted and rejected. The submitted token is
        untrusted input: non-string values and text that cannot be encoded
        are rejected rather than raising, and the comparison is done on
        UTF-8 bytes because ``compare_digest`` refuses non-ASCII ``str``.
        """
        stored = self._tokens.get(session_id)
        if stored is None:
            return False

        if datetime.utcnow() - stored["created_at"] > self._token_ttl:
            del self._tokens[session_id]
            return False

        if not isinstance(token, str):
            return False
        try:
            submitted = token.encode("utf-8")
        except UnicodeError:
            # e.g. lone surrogates; cannot equal a generated token anyway.
            return False

        return secrets.compare_digest(stored["token"].encode("utf-8"), submitted)

    def invalidate_token(self, session_id: str) -> None:
        """Forget the token of *session_id*, if one is stored."""
        self._tokens.pop(session_id, None)


class CommandInjectionPrevention:
    """Helpers for screening and running external commands without a shell."""

    @staticmethod
    def sanitize_command_arg(arg: str) -> str:
        """Drop every character except ASCII letters, digits, ``_``, ``-``, ``.``.

        Path separators and spaces are removed as well.
        """
        disallowed = r'[^a-zA-Z0-9_\-\.]'
        cleaned = re.sub(disallowed, '', arg)
        return cleaned

    @staticmethod
    def validate_command(command: str) -> bool:
        """Return False if *command* contains shell metacharacters
        (``;``, ``&``, ``|``, backtick, ``$``, ``<``, ``>``)."""
        dangerous_patterns = [
            r'[;&|`$]',
            r'\$\(',
            r'`.*`',
            r'\|\|',
            r'&&',
            r'>',
            r'<',
        ]
        found = any(re.search(pattern, command) for pattern in dangerous_patterns)
        return not found

    @staticmethod
    def safe_subprocess_run(
        command: List[str],
        **kwargs
    ) -> Any:
        """Run *command* via ``subprocess.run`` with ``shell=False``.

        Every element, including the program name, is first passed through
        :meth:`sanitize_command_arg`. Extra keyword arguments are forwarded
        to ``subprocess.run``.
        """
        import subprocess

        scrub = CommandInjectionPrevention.sanitize_command_arg
        safe_command = list(map(scrub, command))

        return subprocess.run(safe_command, shell=False, **kwargs)


class PathTraversalPrevention:
    """Helpers for keeping user-supplied paths inside a base directory."""

    @staticmethod
    def sanitize_path(path: str, base_dir: str) -> Optional[str]:
        """Resolve *path* under *base_dir* and confirm it stays inside.

        ``..`` sequences and leading slashes/backslashes are stripped from
        *path* before joining. Containment is then checked on path
        components against the absolute form of *base_dir*, so a relative
        *base_dir* works and a sibling such as ``/srv/www-evil`` is not
        mistaken for being inside ``/srv/www``.

        Returns:
            The normalized absolute path, or None if it falls outside
            *base_dir*.
        """
        import os

        cleaned = path.replace('..', '')
        cleaned = cleaned.lstrip('/\\')

        base_abs = os.path.abspath(base_dir)
        full_path = os.path.normpath(os.path.join(base_abs, cleaned))

        try:
            inside = os.path.commonpath([base_abs, full_path]) == base_abs
        except ValueError:
            # Raised e.g. for paths on different drives: not inside.
            return None

        return full_path if inside else None

    @staticmethod
    def is_safe_path(path: str) -> bool:
        """Return False if *path* contains ``..``, ``~`` or a reference to
        ``/etc/``, ``/proc/`` or ``/sys/``."""
        dangerous_patterns = [
            r'\.\.',
            r'~',
            r'/etc/',
            r'/proc/',
            r'/sys/',
        ]
        return not any(re.search(pattern, path) for pattern in dangerous_patterns)


class SecurityAuditor:
    """Regex-based spot checks for passwords, SQL text and HTML content.

    The check methods return their findings directly; none of them writes
    to ``vulnerabilities``, which is only read by :meth:`generate_report`.
    """

    def __init__(self):
        self.vulnerabilities: List[Dict] = []

    def check_password_strength(self, password: str) -> Dict:
        """Report which basic strength requirements *password* fails."""
        issues = []
        if len(password) < 8:
            issues.append("Password too short")

        # (pattern that must be present, message reported when it is absent)
        required_classes = (
            (r'[A-Z]', "No uppercase letters"),
            (r'[a-z]', "No lowercase letters"),
            (r'\d', "No digits"),
            (r'[!@#$%^&*(),.?":{}|<>]', "No special characters"),
        )
        for pattern, message in required_classes:
            if re.search(pattern, password) is None:
                issues.append(message)

        return {"is_secure": not issues, "issues": issues}

    def check_sql_injection_risk(self, query: str) -> Dict:
        """Flag signs that *query* text was built by string manipulation."""
        risks = []

        if re.search(r'f["\'].*\{.*\}.*["\']', query):
            risks.append("String formatting in SQL query")

        if re.search(r'\+', query):
            risks.append("String concatenation detected")

        has_placeholder = bool(re.search(r'\?', query) or re.search(r'%s', query))
        if not has_placeholder:
            risks.append("No parameterized query detected")

        return {"has_risk": bool(risks), "risks": risks}

    def check_xss_risk(self, html_content: str) -> Dict:
        """Flag script tags, ``javascript:`` URLs and inline event handlers."""
        markers = (
            (r'<script', "Script tag detected"),
            (r'javascript:', "JavaScript protocol detected"),
            (r'on\w+\s*=', "Event handler detected"),
        )
        risks = [
            message
            for pattern, message in markers
            if re.search(pattern, html_content, re.IGNORECASE)
        ]
        return {"has_risk": bool(risks), "risks": risks}

    def generate_report(self) -> Dict:
        """Return the recorded vulnerabilities together with their count."""
        recorded = self.vulnerabilities
        return {
            "total_vulnerabilities": len(recorded),
            "vulnerabilities": recorded
        }

45.4 知识图谱

45.4.1 安全编程体系

┌─────────────────────────────────────────────────────────────────────┐
│                      安全编程全景图                                   │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      应用层安全                               │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 认证授权  │ │ 会话管理  │ │ 访问控制  │ │ 审计日志  │       │   │
│  │  │ OAuth2  │ │ JWT     │ │ RBAC    │ │ Audit   │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      数据安全                                 │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 加密存储  │ │ 传输加密  │ │ 密钥管理  │ │ 数据脱敏  │       │   │
│  │  │ AES     │ │ TLS     │ │ KMS     │ │ Masking │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      输入安全                                 │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 输入验证  │ │ SQL注入  │ │ XSS防护  │ │ CSRF防护 │       │   │
│  │  │ Validate│ │ SQLi    │ │ Filter  │ │ Token   │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│  ┌─────────────────────────────┴───────────────────────────────┐   │
│  │                      基础设施安全                             │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 网络安全  │ │ 容器安全  │ │ 依赖安全  │ │ 配置安全  │       │   │
│  │  │ Firewall│ │ Docker  │ │ SCA     │ │ Secrets │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘

45.4.2 OWASP Top 10 防护

┌─────────────────────────────────────────────────────────────────────┐
│                      OWASP Top 10 安全风险                          │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A01: 访问控制失效 (Broken Access Control)                     │  │
│  │ 防护:最小权限原则、RBAC、ABAC、访问控制列表                    │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A02: 加密失败 (Cryptographic Failures)                        │  │
│  │ 防护:强加密算法、密钥管理、TLS传输、敏感数据加密               │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A03: 注入攻击 (Injection)                                     │  │
│  │ 防护:参数化查询、输入验证、输出编码、ORM框架                   │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A04: 不安全设计 (Insecure Design)                             │  │
│  │ 防护:安全设计模式、威胁建模、安全架构审查                      │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A05: 安全配置错误 (Security Misconfiguration)                 │  │
│  │ 防护:安全默认配置、配置审计、环境隔离                         │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A06: 易受攻击组件 (Vulnerable Components)                      │  │
│  │ 防护:依赖扫描、版本更新、SBOM管理                             │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A07: 身份认证失败 (Authentication Failures)                    │  │
│  │ 防护:多因素认证、密码策略、会话管理、登录限制                  │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A08: 软件完整性失败 (Software Integrity Failures)              │  │
│  │ 防护:代码签名、CI/CD安全、依赖验证                            │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A09: 日志监控失败 (Logging Failures)                           │  │
│  │ 防护:安全日志、异常检测、告警机制                              │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ A10: 服务端请求伪造 (SSRF)                                    │  │
│  │ 防护:URL白名单、输入验证、网络隔离                            │  │
│  └──────────────────────────────────────────────────────────────┘  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

45.5 技术选型指南

45.5.1 加密算法选型

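| 用途 | 推荐算法 | 密钥长度 | 安全等级 | 备注 |
| --- | --- | --- | --- | --- |
| 对称加密 | AES-GCM | 256位 |  | 推荐用于数据加密 |
| 非对称加密 | RSA-OAEP | 4096位 |  | 密钥交换、数字签名 |
| 哈希算法 | SHA-256 | - |  | 密码存储用bcrypt |
| 密码哈希 | bcrypt/Argon2 | - | 极高 | 内置盐值和工作因子 |
| 数字签名 | Ed25519 | - |  | 高性能、安全 |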

45.5.2 认证方案选型

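| 方案 | 适用场景 | 安全性 | 复杂度 | 无状态 |
| --- | --- | --- | --- | --- |
| JWT | SPA、移动应用、微服务 |  |  |  |
| OAuth2 | 第三方授权、社交登录 |  |  |  |
| Session | 传统Web应用 |  |  |  |
| API Key | 服务间调用 |  |  |  |
| mTLS | 零信任网络 | 极高 |  |  |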

45.5.3 安全工具选型

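| 工具类型 | 工具名称 | 功能 | 推荐指数 |
| --- | --- | --- | --- |
| 依赖扫描 | Safety, Snyk | 检测依赖漏洞 | ★★★★★ |
| 代码审计 | Bandit, Semgrep | 静态代码分析 | ★★★★★ |
| 密钥检测 | git-secrets, truffleHog | 检测泄露密钥 | ★★★★☆ |
| 容器扫描 | Trivy, Clair | 镜像漏洞扫描 | ★★★★★ |
| DAST | OWASP ZAP | 动态安全测试 | ★★★★☆ |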

45.6 常见问题与解决方案

45.6.1 密码安全存储

python
import secrets
import hashlib
from typing import Tuple, Optional
from dataclasses import dataclass

@dataclass
class PasswordPolicy:
    """Password policy settings consumed by the password manager."""
    # Shortest password length accepted by strength validation.
    min_length: int = 12
    # Longest password length accepted by strength validation.
    max_length: int = 128
    # When True, at least one character must satisfy str.isupper().
    require_uppercase: bool = True
    # When True, at least one character must satisfy str.islower().
    require_lowercase: bool = True
    # When True, at least one character must satisfy str.isdigit().
    require_digits: bool = True
    # When True, at least one character must come from special_chars.
    require_special: bool = True
    # The characters that count as "special" for require_special.
    special_chars: str = "!@#$%^&*()_+-=[]{}|;:,.<>?"
    # Not read by any code visible in this section; presumably the maximum
    # password age in days before rotation -- TODO confirm against callers.
    max_age_days: int = 90
    # How many previous password hashes are retained per user.
    history_count: int = 5


class SecurePasswordManager:
    """Hash, verify, validate and generate passwords under a password policy.

    Password history is held in an in-memory dict on the instance, keyed by
    user id; nothing is persisted by this class.
    """

    # Lower-cased passwords that are rejected regardless of the other rules.
    _COMMON_PASSWORDS = frozenset(
        {"password", "123456", "qwerty", "admin", "letmein"}
    )

    def __init__(self, policy: Optional["PasswordPolicy"] = None):
        """Create a manager.

        Args:
            policy: Policy to enforce; a default ``PasswordPolicy()`` is used
                when omitted.
        """
        self.policy = policy if policy is not None else PasswordPolicy()
        self._password_history: dict = {}

    def hash_password(self, password: str) -> str:
        """Return a bcrypt hash (cost factor 12, random salt) of ``password``.

        NOTE(review): bcrypt only considers the first 72 bytes of its input
        (and newer bcrypt releases may reject longer input) while the default
        policy allows up to 128 characters -- confirm whether pre-hashing or a
        lower ``max_length`` is wanted.
        """
        import bcrypt
        salt = bcrypt.gensalt(rounds=12)
        return bcrypt.hashpw(password.encode(), salt).decode()

    def verify_password(self, password: str, hashed: str) -> bool:
        """Return True when ``password`` matches the bcrypt hash ``hashed``.

        Any error raised while checking (for example a malformed hash) is
        deliberately treated as "does not match" so verification fails closed.
        """
        import bcrypt
        try:
            return bcrypt.checkpw(password.encode(), hashed.encode())
        except Exception:
            return False

    def validate_password(self, password: str) -> Tuple[bool, list]:
        """Check ``password`` against the policy.

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists one message per
            violated rule, in rule order.
        """
        policy = self.policy
        errors = []

        if len(password) < policy.min_length:
            errors.append(f"密码长度至少{policy.min_length}位")

        if len(password) > policy.max_length:
            errors.append(f"密码长度不能超过{policy.max_length}位")

        if policy.require_uppercase and not any(c.isupper() for c in password):
            errors.append("密码必须包含大写字母")

        if policy.require_lowercase and not any(c.islower() for c in password):
            errors.append("密码必须包含小写字母")

        if policy.require_digits and not any(c.isdigit() for c in password):
            errors.append("密码必须包含数字")

        if policy.require_special and not any(
            c in policy.special_chars for c in password
        ):
            errors.append("密码必须包含特殊字符")

        if password.lower() in self._COMMON_PASSWORDS:
            errors.append("密码过于常见,请使用更强的密码")

        return not errors, errors

    def generate_secure_password(self, length: int = 16) -> str:
        """Generate a random password of ``length`` characters that passes the policy.

        Candidates are drawn with ``secrets`` and re-drawn until one validates.

        Raises:
            ValueError: If no password of ``length`` characters can satisfy the
                policy. Without this check the retry loop would never end.
        """
        import string

        policy = self.policy
        if not policy.min_length <= length <= policy.max_length:
            raise ValueError(
                f"length must be between {policy.min_length} and "
                f"{policy.max_length}, got {length}"
            )
        if policy.require_special and not policy.special_chars:
            raise ValueError("policy requires a special character but defines none")

        # Each required character class needs at least one position.
        required_classes = sum(
            map(
                bool,
                (
                    policy.require_uppercase,
                    policy.require_lowercase,
                    policy.require_digits,
                    policy.require_special,
                ),
            )
        )
        if length < required_classes:
            raise ValueError(
                f"length {length} is too short for {required_classes} "
                "required character classes"
            )

        alphabet = string.ascii_letters + string.digits + policy.special_chars

        while True:
            password = ''.join(secrets.choice(alphabet) for _ in range(length))
            valid, _ = self.validate_password(password)
            if valid:
                return password

    def check_history(self, user_id: str, password: str) -> bool:
        """Return True when ``password`` matches none of the user's stored hashes.

        Users with no recorded history always pass.
        """
        previous_hashes = self._password_history.get(user_id, ())
        return not any(
            self.verify_password(password, old_hash) for old_hash in previous_hashes
        )

    def update_history(self, user_id: str, hashed_password: str):
        """Record ``hashed_password`` for ``user_id``, keeping only the newest entries.

        At most ``policy.history_count`` hashes are retained; the oldest are
        discarded first.
        """
        history = self._password_history.setdefault(user_id, [])
        history.append(hashed_password)

        excess = len(history) - self.policy.history_count
        if excess > 0:
            del history[:excess]


class SecureTokenManager:
    """Stateless helpers for creating, hashing and comparing secret tokens."""

    @staticmethod
    def generate_token(length: int = 32) -> str:
        """Return a URL-safe random token built from ``length`` random bytes.

        The returned text is longer than ``length`` because the bytes are
        Base64-encoded.
        """
        return secrets.token_urlsafe(length)

    @staticmethod
    def generate_api_key(prefix: str = "sk") -> str:
        """Return an API key of the form ``<prefix>_<random token>``."""
        key = secrets.token_urlsafe(32)
        return f"{prefix}_{key}"

    @staticmethod
    def hash_token(token: str) -> str:
        """Return the SHA-256 hex digest of ``token`` (for storage instead of the raw token)."""
        return hashlib.sha256(token.encode()).hexdigest()

    @staticmethod
    def constant_time_compare(a: str, b: str) -> bool:
        """Compare two secrets in constant time to resist timing attacks.

        ``compare_digest`` only accepts ``str`` arguments that are pure ASCII,
        so text is encoded to UTF-8 bytes first; this keeps non-ASCII input
        from raising ``TypeError``. Bytes arguments are passed through as-is.
        """
        left = a.encode("utf-8") if isinstance(a, str) else a
        right = b.encode("utf-8") if isinstance(b, str) else b
        return secrets.compare_digest(left, right)

45.6.2 安全输入验证

python
import re
import html
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field

@dataclass
class ValidationResult:
    """Outcome of one validation call."""
    # Whether validation succeeded; the validators in this file set it to
    # "no errors were recorded".
    is_valid: bool
    # The value handed back by the validator; defaults to None.
    value: Any = None
    # Error messages collected during validation; each instance gets its own list.
    errors: List[str] = field(default_factory=list)


class InputValidator:
    """Validate and sanitise untrusted input.

    ``sanitize_*`` methods return cleaned text; ``validate_*`` methods return a
    ``ValidationResult`` whose ``value`` is ``None`` whenever errors were found.
    """

    EMAIL_PATTERN = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    PHONE_PATTERN = r'^1[3-9]\d{9}$'
    USERNAME_PATTERN = r'^[a-zA-Z0-9_]{3,20}$'

    # URL schemes that can carry script or inline documents; never emitted in
    # href/src attributes by sanitize_html.
    _BLOCKED_URL_SCHEMES = frozenset({'javascript', 'vbscript', 'data'})

    @staticmethod
    def _is_safe_url(value: Optional[str]) -> bool:
        """Return True when ``value`` is present and does not use a blocked scheme."""
        if value is None:
            # Attribute written without a value, e.g. ``<a href>``.
            return False
        # Drop whitespace/control characters before reading the scheme so that
        # variants such as " JavaScript:" or "java\tscript:" are still caught.
        compact = ''.join(ch for ch in value if ch > ' ' and ch != '\x7f')
        scheme, colon, _ = compact.partition(':')
        return not (colon and scheme.lower() in InputValidator._BLOCKED_URL_SCHEMES)

    @staticmethod
    def _result(value: Any, errors: List[str]) -> "ValidationResult":
        """Build a ValidationResult, withholding ``value`` when errors exist."""
        return ValidationResult(
            is_valid=not errors,
            value=None if errors else value,
            errors=errors
        )

    @staticmethod
    def sanitize_string(value: str, max_length: int = 1000) -> str:
        """Strip, truncate to ``max_length`` and HTML-escape ``value``.

        Non-string input yields an empty string. Truncation happens before
        escaping, so the escaped result may be longer than ``max_length``.
        """
        if not isinstance(value, str):
            return ""

        return html.escape(value.strip()[:max_length])

    @staticmethod
    def sanitize_html(value: str, allowed_tags: List[str] = None) -> str:
        """Return ``value`` with every tag outside ``allowed_tags`` removed.

        Kept tags retain only ``href``/``src`` attributes whose URL does not
        use a blocked scheme; all text content is HTML-escaped. With no
        ``allowed_tags`` every tag is dropped and only escaped text remains.
        """
        from html.parser import HTMLParser

        class HTMLSanitizer(HTMLParser):
            def __init__(self, allowed_tags):
                super().__init__()
                # The parser reports tag names in lower case.
                self.allowed_tags = {tag.lower() for tag in (allowed_tags or [])}
                self.result = []

            def handle_starttag(self, tag, attrs):
                if tag not in self.allowed_tags:
                    return
                safe_attrs = [
                    f'{name}="{html.escape(attr_value)}"'
                    for name, attr_value in attrs
                    if name in ('href', 'src')
                    and InputValidator._is_safe_url(attr_value)
                ]
                # Join tag and attributes so a bare tag renders as <b>, not <b >.
                self.result.append(f'<{" ".join([tag, *safe_attrs])}>')

            def handle_endtag(self, tag):
                if tag in self.allowed_tags:
                    self.result.append(f'</{tag}>')

            def handle_data(self, data):
                self.result.append(html.escape(data))

        sanitizer = HTMLSanitizer(allowed_tags)
        sanitizer.feed(value)
        # feed() may hold back trailing text (e.g. "AT&T") waiting for more
        # input; close() flushes it so nothing is silently dropped.
        sanitizer.close()
        return ''.join(sanitizer.result)

    @classmethod
    def validate_email(cls, value: str) -> "ValidationResult":
        """Validate an e-mail address; on success ``value`` is the lower-cased address."""
        errors = []

        if not value:
            errors.append("邮箱不能为空")
        elif len(value) > 254:
            errors.append("邮箱长度超出限制")
        # fullmatch: with re.match, "$" would also accept a trailing newline.
        elif not re.fullmatch(cls.EMAIL_PATTERN, value):
            errors.append("邮箱格式不正确")

        return cls._result(value.lower() if not errors else None, errors)

    @classmethod
    def validate_phone(cls, value: str) -> "ValidationResult":
        """Validate an 11-digit mobile number matching ``PHONE_PATTERN``."""
        errors = []

        if not value:
            errors.append("手机号不能为空")
        # re.ASCII keeps \d from matching non-ASCII digit characters.
        elif not re.fullmatch(cls.PHONE_PATTERN, value, re.ASCII):
            errors.append("手机号格式不正确")

        return cls._result(value, errors)

    @classmethod
    def validate_url(cls, value: str, allowed_schemes: List[str] = None) -> "ValidationResult":
        """Validate that ``value`` is a URL with an allowed scheme and a host part.

        Args:
            value: URL text to check.
            allowed_schemes: Accepted schemes; defaults to http and https.
        """
        from urllib.parse import urlparse

        errors = []
        allowed_schemes = allowed_schemes or ['http', 'https']

        if not value:
            errors.append("URL不能为空")
        else:
            try:
                parsed = urlparse(value)
            except (ValueError, TypeError, AttributeError):
                errors.append("URL解析失败")
            else:
                accepted = {scheme.lower() for scheme in allowed_schemes}
                if parsed.scheme.lower() not in accepted:
                    errors.append(f"只允许 {allowed_schemes} 协议")
                if not parsed.netloc:
                    errors.append("URL格式不正确")

        return cls._result(value, errors)

    @staticmethod
    def validate_integer(value: Any, min_val: int = None, max_val: int = None) -> "ValidationResult":
        """Validate that ``value`` is an integer within the optional bounds.

        Strings and other objects accepted by ``int()`` are converted. Floats
        are accepted only when they are whole numbers, so ``3.7`` is rejected
        instead of being silently truncated to ``3``.
        """
        errors = []
        int_value = None

        if isinstance(value, float) and not value.is_integer():
            # Covers fractional values as well as inf/nan.
            errors.append("必须是有效的整数")
        else:
            try:
                int_value = int(value)
            except (ValueError, TypeError, OverflowError):
                errors.append("必须是有效的整数")

        if int_value is not None:
            if min_val is not None and int_value < min_val:
                errors.append(f"值不能小于{min_val}")
            if max_val is not None and int_value > max_val:
                errors.append(f"值不能大于{max_val}")

        return InputValidator._result(int_value, errors)


class SQLInjectionProtector:
    """Heuristic SQL-injection detection and strict identifier validation.

    Pattern matching is a secondary safeguard only; it does not replace
    parameterised queries.
    """

    DANGEROUS_PATTERNS = [
        r"('\s*(OR|AND)\s*')",
        r"(--|#|\/\*)",
        r"(;\s*(DROP|DELETE|INSERT|UPDATE|EXEC))",
        r"(UNION\s+SELECT)",
        r"(xp_cmdshell)",
        r"(CONCAT\s*\()",
    ]

    @classmethod
    def detect_injection(cls, input_string: str) -> bool:
        """Return True when ``input_string`` matches any of ``DANGEROUS_PATTERNS``.

        Matching is case-insensitive. Empty or ``None`` input is reported as clean.
        """
        if not input_string:
            return False

        upper_input = input_string.upper()

        return any(
            re.search(pattern, upper_input, re.IGNORECASE)
            for pattern in cls.DANGEROUS_PATTERNS
        )

    @classmethod
    def sanitize_identifier(cls, identifier: str) -> str:
        """Return ``identifier`` unchanged if it is a plain SQL identifier.

        A valid identifier starts with an ASCII letter or underscore and
        continues with ASCII letters, digits or underscores. Empty input
        yields an empty string.

        Raises:
            ValueError: If ``identifier`` contains any other character.
        """
        if not identifier:
            return ""

        # fullmatch, not match: with re.match the trailing "$" also succeeds
        # just before a final newline, which let "users\n" through.
        if not re.fullmatch(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier):
            raise ValueError(f"无效的标识符: {identifier}")

        return identifier

45.6.3 安全配置管理

python
import os
import json
from typing import Dict, Any, Optional
from dataclasses import dataclass

@dataclass
class SecurityConfig:
    """Security-related application settings.

    NOTE(review): ``from_env`` falls back to ``*`` for CORS_ORIGINS when the
    variable is unset -- confirm a wildcard default is intended.
    """
    secret_key: str
    jwt_secret: str
    database_url: str
    redis_url: str
    allowed_hosts: list
    cors_origins: list
    debug: bool = False

    @staticmethod
    def _split_csv(raw: str) -> list:
        """Split a comma-separated string, trimming items and dropping blanks."""
        return [item.strip() for item in raw.split(",") if item.strip()]

    @classmethod
    def from_env(cls) -> "SecurityConfig":
        """Build a config from environment variables.

        Reads SECRET_KEY, JWT_SECRET, DATABASE_URL, REDIS_URL, ALLOWED_HOSTS,
        CORS_ORIGINS and DEBUG. The two list variables are comma-separated;
        whitespace around items is ignored so ``"a, b"`` yields ``["a", "b"]``.
        DEBUG is true only for the text ``true`` (any case).
        """
        env = os.environ
        return cls(
            secret_key=env.get("SECRET_KEY", ""),
            jwt_secret=env.get("JWT_SECRET", ""),
            database_url=env.get("DATABASE_URL", ""),
            redis_url=env.get("REDIS_URL", ""),
            allowed_hosts=cls._split_csv(env.get("ALLOWED_HOSTS", "localhost")),
            cors_origins=cls._split_csv(env.get("CORS_ORIGINS", "*")),
            debug=env.get("DEBUG", "false").strip().lower() == "true"
        )

    def validate(self) -> list:
        """Return a list of configuration problems (empty when the config is acceptable)."""
        errors = []

        if not self.secret_key or len(self.secret_key) < 32:
            errors.append("SECRET_KEY必须设置且至少32位")

        if not self.jwt_secret or len(self.jwt_secret) < 32:
            errors.append("JWT_SECRET必须设置且至少32位")

        if self.debug and "localhost" not in self.allowed_hosts:
            errors.append("DEBUG模式下ALLOWED_HOSTS应限制为localhost")

        return errors


class SecretsManager:
    """In-memory store for secrets loaded from the environment or a key=value file."""

    _MASK = '***MASKED***'

    def __init__(self):
        self._secrets: Dict[str, str] = {}

    def load_from_env(self, prefix: str = ""):
        """Load environment variables into the store.

        With a ``prefix``, only variables starting with it are loaded and the
        prefix is removed from the stored name; without one, every variable
        is loaded under its own name.
        """
        for name, value in os.environ.items():
            if name.startswith(prefix):
                self._secrets[name[len(prefix):]] = value

    def load_from_file(self, filepath: str):
        """Load ``KEY=VALUE`` lines from a UTF-8 text file.

        Blank lines, lines starting with ``#`` and lines without ``=`` are
        skipped. Keys and values are stripped of surrounding whitespace; the
        value is everything after the first ``=``.
        """
        # Explicit encoding so the result does not depend on the platform default.
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#') or '=' not in line:
                    continue
                key, _, value = line.partition('=')
                self._secrets[key.strip()] = value.strip()

    def get(self, key: str, default: str = None) -> Optional[str]:
        """Return the secret stored under ``key``, or ``default`` when absent."""
        return self._secrets.get(key, default)

    def require(self, key: str) -> str:
        """Return the secret stored under ``key``.

        Raises:
            KeyError: If no secret is stored under ``key``.
        """
        value = self._secrets.get(key)
        if value is None:
            raise KeyError(f"Required secret not found: {key}")
        return value

    def mask_sensitive(self, data: Dict, keys: list = None) -> Dict:
        """Return a copy of ``data`` with sensitive values replaced by a mask.

        A value is masked when its key is a string containing any marker from
        ``keys`` (case-insensitive substring match). Nested dicts, and dicts
        inside lists, are masked the same way. ``data`` itself is not modified.

        Args:
            data: Mapping to mask.
            keys: Marker substrings; defaults to password/secret/token/key/api_key.
        """
        markers = [
            str(marker).lower()
            for marker in (keys or ['password', 'secret', 'token', 'key', 'api_key'])
        ]
        return self._mask_mapping(data, markers)

    def _mask_mapping(self, data: Dict, markers: list) -> Dict:
        """Mask one mapping level and recurse into nested dicts / lists of dicts."""
        masked = data.copy()

        for key, value in data.items():
            # Non-string keys cannot name a secret; skip the substring test.
            if isinstance(key, str) and any(m in key.lower() for m in markers):
                masked[key] = self._MASK
            elif isinstance(value, dict):
                masked[key] = self._mask_mapping(value, markers)
            elif isinstance(value, list):
                masked[key] = [
                    self._mask_mapping(item, markers) if isinstance(item, dict) else item
                    for item in value
                ]

        return masked


class SecurityHeaders:
    """Builders for security-related HTTP response headers."""

    @staticmethod
    def get_default_headers() -> Dict[str, str]:
        """Return a fresh dict holding the baseline set of security headers."""
        # NOTE(review): X-XSS-Protection is a legacy header; confirm whether
        # current guidance still wants "1; mode=block" here.
        header_pairs = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
            ("Content-Security-Policy", "default-src 'self'"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
        )
        return dict(header_pairs)

    @staticmethod
    def get_csp_header(
        default_src: str = "'self'",
        script_src: str = "'self'",
        style_src: str = "'self' 'unsafe-inline'",
        img_src: str = "'self' data:",
        connect_src: str = "'self'",
        font_src: str = "'self'",
        frame_src: str = "'none'"
    ) -> str:
        """Assemble a Content-Security-Policy value from per-directive source lists.

        Directives are emitted in a fixed order, separated by ``"; "``, with
        no trailing separator.
        """
        directives = [
            ("default-src", default_src),
            ("script-src", script_src),
            ("style-src", style_src),
            ("img-src", img_src),
            ("connect-src", connect_src),
            ("font-src", font_src),
            ("frame-src", frame_src),
        ]
        return "; ".join(f"{name} {sources}" for name, sources in directives)

45.7 本章小结

本章详细介绍了Python安全编程的核心概念和实践:

  1. 密码学基础:对称加密、非对称加密、哈希算法、数字签名
  2. 安全认证:密码存储、JWT、API密钥管理
  3. 输入验证:数据校验、模式匹配、安全编码
  4. SQL注入防护:参数化查询、输入过滤
  5. OWASP防护:XSS、CSRF、命令注入、路径遍历

练习题

  1. 实现一个完整的密码管理系统,支持加密存储和安全验证
  2. 开发一个JWT认证中间件,支持令牌刷新和权限验证
  3. 实现一个SQL注入检测器,分析代码中的安全风险
  4. 开发一个XSS过滤器,支持HTML内容的安全渲染
  5. 实现一个安全审计工具,自动检测代码中的安全漏洞

扩展阅读

Python技术丛书 - 江苏省宿城中等专业学校