第45章 安全编程
学习目标
完成本章学习后,你将能够:
- 掌握密码学基础:对称加密、非对称加密、哈希算法、数字签名
- 实现安全认证:密码存储、JWT、OAuth2、多因素认证
- 防范常见攻击:SQL注入、XSS、CSRF、命令注入
- 进行安全编码:输入验证、输出编码、安全配置
- 实施安全审计:日志记录、入侵检测、漏洞扫描
- 进行渗透测试:安全测试、漏洞利用、修复建议
- 遵循OWASP指南:Top 10漏洞、安全最佳实践
- 实现安全通信:TLS/SSL、证书管理、安全API
45.1 密码学基础
45.1.1 加密与解密
python
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from typing import Tuple, Optional
import base64
import os
import secrets
import hashlib
from dataclasses import dataclass
@dataclass
class EncryptedData:
    """Bundle of ciphertext with the IV it was produced with and an optional tag."""

    # Encrypted payload bytes.
    ciphertext: bytes
    # Initialisation vector / nonce used for this ciphertext.
    iv: bytes
    # Optional authentication tag (e.g. from an AEAD mode); None when absent.
    tag: Optional[bytes] = None
class SymmetricEncryption:
    """Authenticated symmetric encryption of text and dicts using Fernet.

    Ciphertexts are URL-safe Fernet tokens returned as ``str``. Decryption
    raises ``cryptography.fernet.InvalidToken`` for tampered, foreign or
    (when ``ttl`` is given) expired tokens.
    """

    def __init__(self, key: Optional[bytes] = None):
        """Use ``key`` if given, otherwise generate a fresh random key.

        Only ``None`` triggers key generation. A falsy key such as ``b""`` is
        passed to Fernet, which rejects it, instead of being silently replaced
        by a random key that would make the data unrecoverable.
        """
        self.key = key if key is not None else self.generate_key()
        self.fernet = Fernet(self.key)

    @classmethod
    def generate_key(cls) -> bytes:
        """Return a new random Fernet key (url-safe base64 bytes)."""
        return Fernet.generate_key()

    def encrypt(self, plaintext: str) -> str:
        """Encrypt UTF-8 text and return the Fernet token as a string."""
        return self.fernet.encrypt(plaintext.encode()).decode()

    def decrypt(self, ciphertext: str, ttl: Optional[int] = None) -> str:
        """Decrypt a token produced by :meth:`encrypt`.

        Args:
            ciphertext: Fernet token string.
            ttl: Optional maximum token age in seconds. Older tokens are
                rejected with ``InvalidToken``. ``None`` disables the check
                (the original behaviour).
        """
        return self.fernet.decrypt(ciphertext.encode(), ttl=ttl).decode()

    def encrypt_dict(self, data: dict) -> str:
        """JSON-serialise ``data`` and encrypt it."""
        import json
        return self.encrypt(json.dumps(data))

    def decrypt_dict(self, ciphertext: str) -> dict:
        """Decrypt a token from :meth:`encrypt_dict` and parse the JSON."""
        import json
        return json.loads(self.decrypt(ciphertext))
class AsymmetricEncryption:
    """RSA helper: OAEP(SHA-256) encryption and PSS(SHA-256) signatures.

    Either key may be omitted. Operations that need a missing key raise
    ``ValueError``. When only ``private_key`` is supplied, the public key is
    derived from it so encrypt/verify also work.
    """

    def __init__(self, private_key=None, public_key=None):
        if public_key is None and private_key is not None:
            public_key = private_key.public_key()
        self.private_key = private_key
        self.public_key = public_key

    @classmethod
    def generate_key_pair(cls, key_size: int = 2048) -> "AsymmetricEncryption":
        """Generate a new RSA key pair (public exponent 65537)."""
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend()
        )
        return cls(private_key=private_key, public_key=private_key.public_key())

    @staticmethod
    def _oaep():
        """OAEP padding shared by encrypt/decrypt; both sides must match exactly."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None
        )

    @staticmethod
    def _pss():
        """PSS padding shared by sign/verify."""
        return padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH
        )

    def _require_private(self):
        """Return the private key or raise ValueError if none is loaded."""
        if self.private_key is None:
            raise ValueError("This operation requires a private key")
        return self.private_key

    def _require_public(self):
        """Return the public key or raise ValueError if none is loaded."""
        if self.public_key is None:
            raise ValueError("This operation requires a public key")
        return self.public_key

    def encrypt(self, plaintext: bytes) -> bytes:
        """Encrypt with the public key.

        OAEP-SHA256 limits plaintext to (key size in bytes - 66) bytes,
        e.g. 190 bytes for a 2048-bit key.
        """
        return self._require_public().encrypt(plaintext, self._oaep())

    def decrypt(self, ciphertext: bytes) -> bytes:
        """Decrypt with the private key."""
        return self._require_private().decrypt(ciphertext, self._oaep())

    def sign(self, message: bytes) -> bytes:
        """Sign ``message`` with the private key (RSA-PSS, SHA-256)."""
        return self._require_private().sign(message, self._pss(), hashes.SHA256())

    def verify(self, message: bytes, signature: bytes) -> bool:
        """Return True iff ``signature`` is a valid PSS signature of ``message``.

        Only an invalid signature yields False. A missing public key or
        wrong argument types raise, instead of masquerading as a bad signature.
        """
        from cryptography.exceptions import InvalidSignature

        public_key = self._require_public()
        try:
            public_key.verify(signature, message, self._pss(), hashes.SHA256())
        except InvalidSignature:
            return False
        return True

    def export_private_key(self, password: Optional[str] = None) -> bytes:
        """Serialise the private key as PKCS#8 PEM.

        A non-empty ``password`` encrypts the PEM. ``None`` or ``""`` exports
        it unencrypted.
        """
        if password:
            encryption = serialization.BestAvailableEncryption(password.encode())
        else:
            encryption = serialization.NoEncryption()
        return self._require_private().private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=encryption
        )

    def export_public_key(self) -> bytes:
        """Serialise the public key as SubjectPublicKeyInfo PEM."""
        return self._require_public().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
class HashUtils:
    """Hex-digest helpers that hash the UTF-8 encoding of a string."""

    @staticmethod
    def sha256(data: str) -> str:
        """Return the SHA-256 hex digest of ``data``."""
        digest = hashlib.sha256(data.encode())
        return digest.hexdigest()

    @staticmethod
    def sha512(data: str) -> str:
        """Return the SHA-512 hex digest of ``data``."""
        digest = hashlib.sha512(data.encode())
        return digest.hexdigest()

    @staticmethod
    def md5(data: str) -> str:
        """Return the MD5 hex digest of ``data``.

        MD5 is not collision resistant; use it only for non-security
        checksums, never for passwords or integrity against attackers.
        """
        digest = hashlib.md5(data.encode())
        return digest.hexdigest()

    @staticmethod
    def hmac_sha256(key: bytes, message: str) -> str:
        """Return the HMAC-SHA256 hex digest of ``message`` under ``key``."""
        import hmac
        mac = hmac.new(key, message.encode(), hashlib.sha256)
        return mac.hexdigest()
class PasswordHasher:
    """PBKDF2-HMAC-SHA256 password hashing.

    Stored format: ``base64(salt[16] || derived_key[32])``. The iteration
    count is not stored, so a hash only verifies with the same
    ``iterations`` value it was created with.
    """

    _SALT_LEN = 16
    _KEY_LEN = 32

    def __init__(self, iterations: int = 100000):
        # NOTE(review): OWASP's Password Storage Cheat Sheet currently suggests
        # 600,000 iterations for PBKDF2-HMAC-SHA256; the default is kept so
        # existing hashes still verify.
        self.iterations = iterations

    def _kdf(self, salt: bytes):
        """Build a single-use PBKDF2 instance for ``salt`` (derive/verify consume it)."""
        return PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=self._KEY_LEN,
            salt=salt,
            iterations=self.iterations,
            backend=default_backend()
        )

    def hash(self, password: str) -> str:
        """Hash ``password`` with a fresh random salt and return the encoded string."""
        salt = os.urandom(self._SALT_LEN)
        key = self._kdf(salt).derive(password.encode())
        return base64.b64encode(salt + key).decode()

    def verify(self, password: str, stored_hash: str) -> bool:
        """Return True iff ``password`` matches ``stored_hash``.

        A malformed ``stored_hash`` (not base64, or wrong length) returns
        False instead of raising.
        """
        try:
            # validate=True rejects non-alphabet characters instead of
            # silently discarding them; binascii.Error subclasses ValueError.
            decoded = base64.b64decode(stored_hash.encode(), validate=True)
        except ValueError:
            return False
        if len(decoded) != self._SALT_LEN + self._KEY_LEN:
            return False
        salt, stored_key = decoded[:self._SALT_LEN], decoded[self._SALT_LEN:]

        from cryptography.exceptions import InvalidKey

        try:
            # Constant-time comparison is performed inside kdf.verify.
            self._kdf(salt).verify(password.encode(), stored_key)
        except InvalidKey:
            return False
        return True

    def generate_random_password(self, length: int = 16) -> str:
        """Return a random password of letters, digits and ``!@#$%^&*``."""
        import string
        alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
        return ''.join(secrets.choice(alphabet) for _ in range(length))
class SecureToken:
    """Unguessable token helpers backed by the :mod:`secrets` module."""

    @staticmethod
    def generate(length: int = 32) -> str:
        """Return ``length`` random bytes as hex (``2 * length`` characters)."""
        token = secrets.token_hex(length)
        return token

    @staticmethod
    def generate_url_safe(length: int = 32) -> str:
        """Return ``length`` random bytes as unpadded URL-safe base64."""
        token = secrets.token_urlsafe(length)
        return token

    @staticmethod
    def compare(a: str, b: str) -> bool:
        """Compare two tokens in constant time to resist timing attacks."""
        same = secrets.compare_digest(a, b)
        return same


# 45.1.2 JWT认证
python
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import jwt
@dataclass
class TokenPayload:
    """Typed container for the claims of an authenticated user's token."""

    # Identifier of the authenticated user.
    user_id: str
    # Human-readable login name.
    username: str
    # Role names granted to the user.
    roles: List[str]
    # Expiry timestamp.
    exp: datetime
    # Issued-at timestamp.
    iat: datetime
class JWTManager:
    """Issue and validate signed access/refresh JWTs with PyJWT.

    Access tokens live ``access_token_expire`` minutes and refresh tokens
    ``refresh_token_expire`` days. Every token carries a ``type`` claim
    ("access" or "refresh") that is checked wherever a token is consumed,
    so a long-lived refresh token cannot be used as an access token.
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        access_token_expire: int = 30,
        refresh_token_expire: int = 7
    ):
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.access_token_expire = access_token_expire
        self.refresh_token_expire = refresh_token_expire

    @staticmethod
    def _now() -> datetime:
        """Current time as an aware UTC datetime (``datetime.utcnow`` is deprecated)."""
        from datetime import timezone
        return datetime.now(timezone.utc)

    def create_access_token(
        self,
        user_id: str,
        username: str,
        roles: List[str] = None,
        additional_claims: Dict = None
    ) -> str:
        """Return a signed access token for ``user_id``.

        ``additional_claims`` are merged first. They can add claims but never
        override ``sub``, ``username``, ``roles``, ``type``, ``iat`` or ``exp``
        (e.g. they cannot turn the token into a refresh token or extend expiry).
        """
        now = self._now()
        payload: Dict[str, Any] = dict(additional_claims or {})
        payload.update({
            "sub": user_id,
            "username": username,
            "roles": roles or [],
            "type": "access",
            "iat": now,
            "exp": now + timedelta(minutes=self.access_token_expire),
        })
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def create_refresh_token(self, user_id: str) -> str:
        """Return a signed refresh token for ``user_id``."""
        now = self._now()
        payload = {
            "sub": user_id,
            "type": "refresh",
            "iat": now,
            "exp": now + timedelta(days=self.refresh_token_expire),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def decode_token(self, token: str) -> Optional[Dict]:
        """Verify signature and expiry and return the claims, or None if invalid.

        ``exp``, ``iat`` and ``sub`` must be present. This method does NOT
        check ``type``; callers that need a particular kind of token should use
        :meth:`verify_token` or the typed accessors below.
        """
        try:
            return jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"require": ["exp", "iat", "sub"]},
            )
        except jwt.InvalidTokenError:
            # Base class of ExpiredSignatureError, DecodeError,
            # InvalidSignatureError, MissingRequiredClaimError, ...
            return None

    def _decode_typed(self, token: str, token_type: str) -> Optional[Dict]:
        """Return the claims only if the token is valid AND of ``token_type``."""
        payload = self.decode_token(token)
        if payload is None or payload.get("type") != token_type:
            return None
        return payload

    def verify_token(self, token: str, token_type: str = "access") -> bool:
        """Return True iff ``token`` is valid and of ``token_type``."""
        return self._decode_typed(token, token_type) is not None

    def refresh_access_token(self, refresh_token: str) -> Optional[str]:
        """Mint a new access token from a valid refresh token, else None.

        NOTE(review): refresh tokens carry no username/roles, so the new access
        token has an empty username and no roles; callers needing them should
        look the user up and call :meth:`create_access_token` directly.
        """
        payload = self._decode_typed(refresh_token, "refresh")
        if payload is None:
            return None
        return self.create_access_token(payload["sub"], "", [])

    def get_user_id(self, token: str, token_type: str = "access") -> Optional[str]:
        """Return the ``sub`` claim of a valid token of ``token_type``, else None."""
        payload = self._decode_typed(token, token_type)
        return payload.get("sub") if payload else None

    def get_roles(self, token: str, token_type: str = "access") -> List[str]:
        """Return the ``roles`` claim of a valid token of ``token_type``, else []."""
        payload = self._decode_typed(token, token_type)
        return payload.get("roles", []) if payload else []
class APIKeyManager:
    """In-memory registry of API keys with per-key permissions and expiry.

    NOTE(review): keys are kept in plaintext in process memory. A persistent
    store should keep only a hash of each key (e.g. SHA-256) and look keys up
    by that hash.
    """

    def __init__(self):
        # api_key -> {"user_id", "name", "permissions", "created_at", "expires_at"}
        self._api_keys: Dict[str, Dict] = {}

    @staticmethod
    def _public_view(info: Dict) -> Dict:
        """Copy a key record so callers cannot mutate the registry's state."""
        return {**info, "permissions": list(info["permissions"])}

    def generate_api_key(
        self,
        user_id: str,
        name: str,
        permissions: List[str] = None,
        expires_days: int = 365
    ) -> str:
        """Create, register and return a new ``pk_``-prefixed API key."""
        api_key = f"pk_{secrets.token_urlsafe(32)}"
        # One timestamp so expires_at - created_at is exactly expires_days.
        now = datetime.utcnow()
        self._api_keys[api_key] = {
            "user_id": user_id,
            "name": name,
            # Copy so later changes to the caller's list cannot alter grants.
            "permissions": list(permissions) if permissions else [],
            "created_at": now,
            "expires_at": now + timedelta(days=expires_days),
        }
        return api_key

    def verify_api_key(self, api_key: str) -> Optional[Dict]:
        """Return a copy of the key's record, or None if unknown or expired.

        Expired keys are removed from the registry when encountered.
        """
        key_info = self._api_keys.get(api_key)
        if key_info is None:
            return None
        if key_info["expires_at"] < datetime.utcnow():
            del self._api_keys[api_key]
            return None
        return self._public_view(key_info)

    def revoke_api_key(self, api_key: str) -> bool:
        """Remove ``api_key``; return True if it existed."""
        return self._api_keys.pop(api_key, None) is not None

    def list_user_keys(self, user_id: str) -> List[Dict]:
        """Return copies of all of ``user_id``'s key records, each with its ``key``."""
        return [
            {"key": key, **self._public_view(info)}
            for key, info in self._api_keys.items()
            if info["user_id"] == user_id
        ]


# 45.2 安全编码实践
45.2.1 输入验证
python
import re
from typing import Any, List, Optional, Callable
from dataclasses import dataclass
from enum import Enum
class ValidationError(Exception):
    """Raised by a field rule when a value fails validation.

    Attributes:
        field: Name of the offending field.
        message: Human-readable reason.
    """

    def __init__(self, field: str, message: str):
        self.field, self.message = field, message
        super().__init__("{}: {}".format(field, message))
@dataclass
class ValidationResult:
    """Outcome of a validation; truthy exactly when ``is_valid`` is True."""

    is_valid: bool
    # Failure messages. Omitted/None is normalised to a fresh [] so callers
    # can always iterate it (None default avoids a shared mutable default).
    errors: Optional[List[str]] = None

    def __post_init__(self) -> None:
        if self.errors is None:
            self.errors = []

    def __bool__(self) -> bool:
        return self.is_valid
class InputValidator:
    """Format validators returning ValidationResult, plus simple sanitizers.

    Patterns are applied with ``re.fullmatch``. ``re.match`` with a ``$``
    anchor also accepts a trailing newline, which would let CR/LF slip into
    values later written to headers, logs or file names.
    """

    @staticmethod
    def validate_email(email: str) -> "ValidationResult":
        """Check a simple ``local@domain.tld`` shape."""
        pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        if re.fullmatch(pattern, email):
            return ValidationResult(is_valid=True)
        return ValidationResult(is_valid=False, errors=["Invalid email format"])

    @staticmethod
    def validate_password(
        password: str,
        min_length: int = 8,
        require_upper: bool = True,
        require_lower: bool = True,
        require_digit: bool = True,
        require_special: bool = True
    ) -> "ValidationResult":
        """Check length and required character classes; report every failure."""
        errors = []
        if len(password) < min_length:
            errors.append(f"Password must be at least {min_length} characters")
        if require_upper and not re.search(r'[A-Z]', password):
            errors.append("Password must contain at least one uppercase letter")
        if require_lower and not re.search(r'[a-z]', password):
            errors.append("Password must contain at least one lowercase letter")
        if require_digit and not re.search(r'\d', password):
            errors.append("Password must contain at least one digit")
        if require_special and not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
            errors.append("Password must contain at least one special character")
        return ValidationResult(is_valid=len(errors) == 0, errors=errors)

    @staticmethod
    def validate_username(username: str) -> "ValidationResult":
        """Require 3-20 characters of letters, digits and underscores."""
        errors = []
        if len(username) < 3:
            errors.append("Username must be at least 3 characters")
        if len(username) > 20:
            errors.append("Username must be at most 20 characters")
        if not re.fullmatch(r'[a-zA-Z0-9_]+', username):
            errors.append("Username can only contain letters, numbers, and underscores")
        return ValidationResult(is_valid=len(errors) == 0, errors=errors)

    @staticmethod
    def validate_phone(phone: str) -> "ValidationResult":
        """Accept an optional ``+``, optional ``1`` and 9-15 digits."""
        pattern = r'\+?1?\d{9,15}'
        if re.fullmatch(pattern, phone):
            return ValidationResult(is_valid=True)
        return ValidationResult(is_valid=False, errors=["Invalid phone number format"])

    @staticmethod
    def validate_url(url: str) -> "ValidationResult":
        """Accept http(s) URLs without whitespace."""
        pattern = r'https?://[^\s/$.?#].[^\s]*'
        if re.fullmatch(pattern, url):
            return ValidationResult(is_valid=True)
        return ValidationResult(is_valid=False, errors=["Invalid URL format"])

    @staticmethod
    def sanitize_html(html: str) -> str:
        """HTML-escape ``html`` (``& < > " '``) for safe inclusion in markup."""
        import html as html_module
        return html_module.escape(html)

    @staticmethod
    def sanitize_sql(value: str) -> str:
        """Strip a fixed list of SQL metacharacter sequences from ``value``.

        WARNING: this is NOT an SQL-injection defence. Blocklist stripping is
        bypassable and corrupts legitimate data ("O'Brien" -> "OBrien",
        "wasp_nest" -> "wanest"). Always bind values as query parameters;
        this helper is kept only for backward compatibility.
        """
        dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "xp_", "sp_"]
        sanitized = value
        for char in dangerous_chars:
            sanitized = sanitized.replace(char, "")
        return sanitized
class SchemaValidator:
    """Holds per-field rule lists and applies them to a dict of values.

    Rules are callables that raise ValidationError on failure. Every rule
    for every registered field runs, and all failure messages are collected.
    """

    def __init__(self):
        # field name -> rules, in the order they were registered
        self._rules: Dict[str, List[Callable]] = {}

    def field(self, name: str) -> "FieldValidator":
        """Start a fluent rule chain for field ``name``."""
        return FieldValidator(name, self)

    def add_rule(self, field: str, rule: Callable) -> None:
        """Append ``rule`` to the rules of ``field``."""
        self._rules.setdefault(field, []).append(rule)

    def validate(self, data: Dict) -> ValidationResult:
        """Run all rules against ``data`` (missing keys are passed as None)."""
        messages = []
        for field_name, field_rules in self._rules.items():
            current = data.get(field_name)
            for check in field_rules:
                try:
                    check(current)
                except ValidationError as failure:
                    messages.append(str(failure))
        return ValidationResult(is_valid=not messages, errors=messages)
class FieldValidator:
    """Fluent builder that registers validation rules for one field.

    Each method adds a rule to the owning SchemaValidator and returns
    ``self`` so calls can be chained. Every rule except ``required`` skips
    falsy values, so optional fields are only checked when present.
    """

    def __init__(self, name: str, schema: SchemaValidator):
        self.name = name
        self.schema = schema

    def _register(self, rule: Callable) -> "FieldValidator":
        """Attach ``rule`` to this field on the schema and return self."""
        self.schema.add_rule(self.name, rule)
        return self

    def required(self) -> "FieldValidator":
        """Reject None and the empty string."""
        def rule(value):
            if value is None or value == "":
                raise ValidationError(self.name, "This field is required")
        return self._register(rule)

    def min_length(self, length: int) -> "FieldValidator":
        """Reject present values whose ``str()`` is shorter than ``length``."""
        def rule(value):
            if not value:
                return
            if len(str(value)) < length:
                raise ValidationError(self.name, f"Minimum length is {length}")
        return self._register(rule)

    def max_length(self, length: int) -> "FieldValidator":
        """Reject present values whose ``str()`` is longer than ``length``."""
        def rule(value):
            if not value:
                return
            if len(str(value)) > length:
                raise ValidationError(self.name, f"Maximum length is {length}")
        return self._register(rule)

    def email(self) -> "FieldValidator":
        """Reject present values that InputValidator.validate_email rejects."""
        def rule(value):
            if not value:
                return
            outcome = InputValidator.validate_email(value)
            if not outcome.is_valid:
                raise ValidationError(self.name, outcome.errors[0])
        return self._register(rule)

    def pattern(self, regex: str) -> "FieldValidator":
        """Reject present values not matching ``regex`` at the start (``re.match``)."""
        def rule(value):
            if not value:
                return
            if re.match(regex, str(value)) is None:
                raise ValidationError(self.name, "Invalid format")
        return self._register(rule)

    def custom(self, validator: Callable) -> "FieldValidator":
        """Register ``validator`` as-is; it should raise ValidationError on failure."""
        return self._register(validator)


# 45.2.2 SQL注入防护
python
import sqlite3
from typing import List, Dict, Any, Optional
from contextlib import contextmanager
class SecureDatabase:
    """SQLite helper that binds every value as a ``?`` parameter.

    Identifiers (table/column names, ORDER BY terms) cannot be bound, so they
    are checked against a plain-identifier pattern. Anything else (e.g.
    ``"users; DROP TABLE users"``) raises ValueError. ``where`` clauses are
    inserted verbatim: they must be developer-written SQL using ``?``
    placeholders for every value, never text built from user input.
    """

    # Letters/digits/underscore, optionally qualified once ("main.users").
    _IDENTIFIER_PATTERN = r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?"

    def __init__(self, connection_string: str):
        self.connection_string = connection_string

    @contextmanager
    def get_connection(self):
        """Yield a connection returning ``sqlite3.Row`` rows; always closed on exit.

        Changes not committed before exit are discarded by ``close()``.
        """
        conn = sqlite3.connect(self.connection_string)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    @classmethod
    def _identifier(cls, name: str) -> str:
        """Return ``name`` if it is a plain (optionally qualified) identifier, else raise ValueError."""
        import re
        if not isinstance(name, str) or re.fullmatch(cls._IDENTIFIER_PATTERN, name) is None:
            raise ValueError(f"Invalid SQL identifier: {name!r}")
        return name

    @classmethod
    def _columns_clause(cls, columns: Optional[List[str]]) -> str:
        """Validated, comma-joined column list; ``*`` when no columns are given."""
        if not columns:
            return "*"
        return ", ".join(col if col == "*" else cls._identifier(col) for col in columns)

    @classmethod
    def _order_by_clause(cls, order_by: str) -> str:
        """Validate a comma-separated ``column [ASC|DESC]`` list and return it normalised."""
        terms = []
        for raw_term in order_by.split(","):
            parts = raw_term.split()
            if len(parts) == 1:
                terms.append(cls._identifier(parts[0]))
            elif len(parts) == 2 and parts[1].upper() in ("ASC", "DESC"):
                terms.append(f"{cls._identifier(parts[0])} {parts[1].upper()}")
            else:
                raise ValueError(f"Invalid ORDER BY term: {raw_term!r}")
        return ", ".join(terms)

    def execute_safe(
        self,
        query: str,
        params: tuple = None
    ) -> List[Dict]:
        """Execute ``query`` with bound ``params`` and return rows as dicts.

        The connection is committed so INSERT/UPDATE/DELETE issued through
        this helper persist (committing after a SELECT is a no-op).
        """
        with self.get_connection() as conn:
            cursor = conn.execute(query, params or ())
            rows = [dict(row) for row in cursor.fetchall()]
            conn.commit()
            return rows

    def insert(self, table: str, data: Dict) -> int:
        """Insert one row and return its rowid.

        Raises:
            ValueError: empty ``data`` or an invalid table/column name.
        """
        if not data:
            raise ValueError("insert() requires at least one column")
        columns = ", ".join(self._identifier(col) for col in data)
        placeholders = ", ".join("?" for _ in data)
        query = f"INSERT INTO {self._identifier(table)} ({columns}) VALUES ({placeholders})"
        with self.get_connection() as conn:
            cursor = conn.execute(query, tuple(data.values()))
            conn.commit()
            return cursor.lastrowid

    def update(
        self,
        table: str,
        data: Dict,
        where: str,
        where_params: tuple = None
    ) -> int:
        """Update rows matching ``where`` (trusted SQL with ``?``) and return the row count.

        Raises:
            ValueError: empty ``data`` or an invalid table/column name.
        """
        if not data:
            raise ValueError("update() requires at least one column")
        set_clause = ", ".join(f"{self._identifier(col)} = ?" for col in data)
        query = f"UPDATE {self._identifier(table)} SET {set_clause} WHERE {where}"
        params = tuple(data.values()) + tuple(where_params or ())
        with self.get_connection() as conn:
            cursor = conn.execute(query, params)
            conn.commit()
            return cursor.rowcount

    def delete(
        self,
        table: str,
        where: str,
        where_params: tuple = None
    ) -> int:
        """Delete rows matching ``where`` (trusted SQL with ``?``) and return the row count."""
        query = f"DELETE FROM {self._identifier(table)} WHERE {where}"
        with self.get_connection() as conn:
            cursor = conn.execute(query, where_params or ())
            conn.commit()
            return cursor.rowcount

    def select(
        self,
        table: str,
        columns: List[str] = None,
        where: str = None,
        where_params: tuple = None,
        order_by: str = None,
        limit: int = None
    ) -> List[Dict]:
        """Select rows as dicts.

        ``columns`` and ``order_by`` are validated. ``limit`` must be
        int-convertible, and a falsy limit means no LIMIT. ``where`` is
        trusted SQL.
        """
        query = f"SELECT {self._columns_clause(columns)} FROM {self._identifier(table)}"
        if where:
            query += f" WHERE {where}"
        if order_by:
            query += f" ORDER BY {self._order_by_clause(order_by)}"
        if limit:
            query += f" LIMIT {int(limit)}"
        return self.execute_safe(query, where_params)
class SQLInjectionPrevention:
    """Identifier quoting, parameterized SELECT building and a heuristic detector."""

    @staticmethod
    def detect_injection(input_string: str) -> bool:
        """Return True if ``input_string`` resembles a SQL injection payload.

        This is a blocklist heuristic with many false positives (any ``;``,
        ``|``, ``#`` or "... and x =" text). Use it for logging/alerting only;
        the real defence is binding values as query parameters.
        """
        patterns = [
            r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER|CREATE|TRUNCATE)\b)",
            r"(--|#|/\*|\*/)",
            r"(\bOR\b|\bAND\b).*?=",
            r"(\bunion\b.*?\bselect\b)",
            r"(\bexec\b|\bexecute\b)",
            r"(\bxp_|sp_)",
            r"(;|\|)",
        ]
        return any(re.search(pattern, input_string, re.IGNORECASE) for pattern in patterns)

    @staticmethod
    def escape_identifier(identifier: str) -> str:
        """Quote ``identifier`` as an SQL delimited identifier, doubling embedded quotes."""
        # Escape outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12.
        escaped = identifier.replace('"', '""')
        return f'"{escaped}"'

    @staticmethod
    def build_safe_query(
        table: str,
        columns: List[str],
        where_conditions: Dict[str, Any]
    ) -> tuple:
        """Build ``SELECT cols FROM table [WHERE c1 = ? AND ...]`` plus its params.

        Identifiers are quoted and values bound. Empty ``columns`` selects
        ``*``; empty ``where_conditions`` omits the WHERE clause.

        Returns:
            ``(query, params_tuple)``.
        """
        quote = SQLInjectionPrevention.escape_identifier
        safe_columns = ", ".join(quote(col) for col in columns) if columns else "*"
        where_parts = []
        params = []
        for col, value in (where_conditions or {}).items():
            where_parts.append(f"{quote(col)} = ?")
            params.append(value)
        query = f"SELECT {safe_columns} FROM {quote(table)}"
        if where_parts:
            query += " WHERE " + " AND ".join(where_parts)
        return query, tuple(params)


# 45.3 OWASP Top 10防护
45.3.1 常见攻击防护
python
from dataclasses import dataclass
from typing import List, Dict, Optional
import re
import html
@dataclass
class SecurityHeaders:
    """Default values for common HTTP security response headers."""

    content_type_options: str = "nosniff"
    frame_options: str = "DENY"
    # NOTE(review): X-XSS-Protection is deprecated in modern browsers; the
    # OWASP Secure Headers Project recommends "0" — confirm before changing.
    xss_protection: str = "1; mode=block"
    content_security_policy: str = "default-src 'self'"
    strict_transport_security: str = "max-age=31536000; includeSubDomains"

    def to_dict(self) -> Dict[str, str]:
        """Return header-name -> value pairs ready to attach to a response."""
        pairs = (
            ("X-Content-Type-Options", self.content_type_options),
            ("X-Frame-Options", self.frame_options),
            ("X-XSS-Protection", self.xss_protection),
            ("Content-Security-Policy", self.content_security_policy),
            ("Strict-Transport-Security", self.strict_transport_security),
        )
        return dict(pairs)
class XSSPrevention:
    """Regex helpers against cross-site scripting.

    ``escape_html`` (output encoding) is the dependable control.
    ``sanitize_input`` and ``validate_content`` are blocklists that catch
    common payloads only and should not be the sole protection.
    """

    @staticmethod
    def escape_html(text: str) -> str:
        """Escape ``& < > " '`` so ``text`` renders as literal text in HTML."""
        return html.escape(text, quote=True)

    @staticmethod
    def sanitize_input(text: str) -> str:
        """Remove script/iframe blocks, ``javascript:`` and ``onX=`` handlers.

        The removal passes repeat until the text stops changing. A single pass
        can reassemble a payload from its pieces, e.g.
        ``<scr<script>x</script>ipt>`` -> ``<script>``. The loop terminates
        because every substitution strictly shortens the string.
        """
        dangerous_patterns = [
            (r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL),
            (r'javascript:', re.IGNORECASE),
            (r'on\w+\s*=', re.IGNORECASE),
            (r'<iframe[^>]*>.*?</iframe>', re.IGNORECASE | re.DOTALL),
        ]
        sanitized = text
        while True:
            before = sanitized
            for pattern, flags in dangerous_patterns:
                sanitized = re.sub(pattern, '', sanitized, flags=flags)
            if sanitized == before:
                return sanitized

    @staticmethod
    def validate_content(content: str) -> bool:
        """Return False if ``content`` contains any known-dangerous HTML marker."""
        xss_patterns = [
            r'<script',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe',
            r'<object',
            r'<embed',
            r'<link',
            r'<style',
        ]
        return not any(re.search(pattern, content, re.IGNORECASE) for pattern in xss_patterns)
class CSRFProtection:
    """Per-session CSRF tokens held in memory.

    Each session has at most one live token; generating a new one replaces
    the previous one. Tokens expire ``token_lifetime`` after issue (one hour
    by default).
    """

    def __init__(self, token_lifetime: Optional[timedelta] = None):
        # session_id -> {"token": str, "created_at": aware UTC datetime}
        self._tokens: Dict[str, Dict] = {}
        self._token_lifetime = token_lifetime if token_lifetime is not None else timedelta(hours=1)

    @staticmethod
    def _now():
        """Aware UTC timestamp (avoids the deprecated naive ``datetime.utcnow``)."""
        from datetime import timezone
        return datetime.now(timezone.utc)

    def generate_token(self, session_id: str) -> str:
        """Issue a fresh token for ``session_id``, replacing any previous one."""
        token = secrets.token_urlsafe(32)
        self._tokens[session_id] = {"token": token, "created_at": self._now()}
        return token

    def validate_token(self, session_id: str, token: str) -> bool:
        """Return True iff ``token`` is the live token for ``session_id``.

        Expired tokens are discarded. A missing/non-string ``token`` returns
        False instead of raising.
        """
        stored = self._tokens.get(session_id)
        if stored is None:
            return False
        if self._now() - stored["created_at"] > self._token_lifetime:
            del self._tokens[session_id]
            return False
        if not isinstance(token, str):
            return False
        # Compare bytes: compare_digest raises TypeError on non-ASCII str,
        # which an attacker could otherwise trigger with a crafted token.
        return secrets.compare_digest(stored["token"].encode("utf-8"), token.encode("utf-8"))

    def invalidate_token(self, session_id: str) -> None:
        """Forget the token for ``session_id`` (no-op if none exists)."""
        self._tokens.pop(session_id, None)
class CommandInjectionPrevention:
    """Helpers for running external commands without a shell."""

    @staticmethod
    def sanitize_command_arg(arg: str) -> str:
        """Drop every character except ASCII letters, digits, ``_``, ``-`` and ``.``.

        Note this also removes ``/`` and spaces, so path arguments are altered.
        """
        disallowed_chars = r'[^a-zA-Z0-9_\-\.]'
        return re.sub(disallowed_chars, '', arg)

    @staticmethod
    def validate_command(command: str) -> bool:
        """Return False if a command string contains shell metacharacters."""
        dangerous_patterns = [
            r'[;&|`$]',
            r'\$\(',
            r'`.*`',
            r'\|\|',
            r'&&',
            r'>',
            r'<',
        ]
        return not any(re.search(pattern, command) for pattern in dangerous_patterns)

    @staticmethod
    def safe_subprocess_run(
        command: List[str],
        **kwargs
    ) -> Any:
        """Run ``command`` via ``subprocess.run`` with ``shell=False``.

        ``command[0]`` (the program) is passed through unchanged, because
        sanitising it would mangle paths such as ``/usr/bin/ls``. It must
        therefore come from trusted code, never from user input. The
        remaining arguments go through :meth:`sanitize_command_arg`. Other
        ``kwargs`` are forwarded to ``subprocess.run``.

        Raises:
            ValueError: if ``command`` is empty or a truthy ``shell`` is passed.
        """
        import subprocess
        if not command:
            raise ValueError("command must contain at least the program to run")
        if kwargs.pop("shell", False):
            raise ValueError("shell=True is not allowed")
        program, *args = command
        safe_command = [program] + [
            CommandInjectionPrevention.sanitize_command_arg(arg) for arg in args
        ]
        return subprocess.run(safe_command, shell=False, **kwargs)
class PathTraversalPrevention:
    """Helpers that keep user-supplied paths inside a base directory."""

    @staticmethod
    def sanitize_path(path: str, base_dir: str) -> Optional[str]:
        """Resolve ``path`` under ``base_dir``; return None if it escapes.

        Leading ``/`` or ``\\`` are stripped so absolute input is treated as
        relative to ``base_dir``. Containment is checked on symlink-resolved
        paths with ``os.path.commonpath``: a plain prefix test would accept
        ``/base2`` for base ``/base``. ``..`` segments that stay inside the base
        are allowed; ones that leave it cause None. File names are never
        rewritten.

        Returns:
            The absolute, normalised path inside ``base_dir``, or None.
        """
        import os
        relative = path.lstrip('/\\')
        base = os.path.abspath(base_dir)
        candidate = os.path.abspath(os.path.join(base, relative))
        real_base = os.path.realpath(base)
        try:
            inside = os.path.commonpath([real_base, os.path.realpath(candidate)]) == real_base
        except ValueError:
            # e.g. different drives on Windows
            return None
        return candidate if inside else None

    @staticmethod
    def is_safe_path(path: str) -> bool:
        """Return False if ``path`` contains ``..``, ``~`` or a sensitive system prefix."""
        dangerous_patterns = [
            r'\.\.',
            r'~',
            r'/etc/',
            r'/proc/',
            r'/sys/',
        ]
        return not any(re.search(pattern, path) for pattern in dangerous_patterns)
class SecurityAuditor:
    """Heuristic checks for weak passwords, risky SQL text and XSS markers."""

    def __init__(self):
        # Reported by generate_report(); none of the checks below append to it.
        self.vulnerabilities: List[Dict] = []

    def check_password_strength(self, password: str) -> Dict:
        """List password weaknesses; ``is_secure`` is True when there are none."""
        checks = (
            (len(password) < 8, "Password too short"),
            (not re.search(r'[A-Z]', password), "No uppercase letters"),
            (not re.search(r'[a-z]', password), "No lowercase letters"),
            (not re.search(r'\d', password), "No digits"),
            (not re.search(r'[!@#$%^&*(),.?":{}|<>]', password), "No special characters"),
        )
        found = [message for failed, message in checks if failed]
        return {"is_secure": not found, "issues": found}

    def check_sql_injection_risk(self, query: str) -> Dict:
        """Flag f-string/concatenation markers and missing placeholders in SQL text."""
        has_placeholder = bool(re.search(r'\?', query) or re.search(r'%s', query))
        checks = (
            (re.search(r'f["\'].*\{.*\}.*["\']', query), "String formatting in SQL query"),
            (re.search(r'\+', query), "String concatenation detected"),
            (not has_placeholder, "No parameterized query detected"),
        )
        found = [message for hit, message in checks if hit]
        return {"has_risk": bool(found), "risks": found}

    def check_xss_risk(self, html_content: str) -> Dict:
        """Flag script tags, ``javascript:`` URLs and inline event handlers."""
        checks = (
            (r'<script', "Script tag detected"),
            (r'javascript:', "JavaScript protocol detected"),
            (r'on\w+\s*=', "Event handler detected"),
        )
        found = [
            message
            for pattern, message in checks
            if re.search(pattern, html_content, re.IGNORECASE)
        ]
        return {"has_risk": bool(found), "risks": found}

    def generate_report(self) -> Dict:
        """Summarise the recorded vulnerabilities."""
        recorded = self.vulnerabilities
        return {"total_vulnerabilities": len(recorded), "vulnerabilities": recorded}


# 45.4 知识图谱
45.4.1 安全编程体系
┌─────────────────────────────────────────────────────────────────────┐
│ 安全编程全景图 │
├─────────────────────────────────────────────────────────────────────┤
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 应用层安全 │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 认证授权 │ │ 会话管理 │ │ 访问控制 │ │ 审计日志 │ │ │
│ │ │ OAuth2 │ │ JWT │ │ RBAC │ │ Audit │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 数据安全 │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 加密存储 │ │ 传输加密 │ │ 密钥管理 │ │ 数据脱敏 │ │ │
│ │ │ AES │ │ TLS │ │ KMS │ │ Masking │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 输入安全 │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 输入验证 │ │ SQL注入 │ │ XSS防护 │ │ CSRF防护 │ │ │
│ │ │ Validate│ │ SQLi │ │ Filter │ │ Token │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────────────┴───────────────────────────────┐ │
│ │ 基础设施安全 │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 网络安全 │ │ 容器安全 │ │ 依赖安全 │ │ 配置安全 │ │ │
│ │ │ Firewall│ │ Docker │ │ SCA │ │ Secrets │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
45.4.2 OWASP Top 10 防护
┌─────────────────────────────────────────────────────────────────────┐
│ OWASP Top 10 安全风险 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A01: 访问控制失效 (Broken Access Control) │ │
│ │ 防护:最小权限原则、RBAC、ABAC、访问控制列表 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A02: 加密失败 (Cryptographic Failures) │ │
│ │ 防护:强加密算法、密钥管理、TLS传输、敏感数据加密 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A03: 注入攻击 (Injection) │ │
│ │ 防护:参数化查询、输入验证、输出编码、ORM框架 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A04: 不安全设计 (Insecure Design) │ │
│ │ 防护:安全设计模式、威胁建模、安全架构审查 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A05: 安全配置错误 (Security Misconfiguration) │ │
│ │ 防护:安全默认配置、配置审计、环境隔离 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A06: 易受攻击组件 (Vulnerable Components) │ │
│ │ 防护:依赖扫描、版本更新、SBOM管理 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A07: 身份认证失败 (Authentication Failures) │ │
│ │ 防护:多因素认证、密码策略、会话管理、登录限制 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A08: 软件完整性失败 (Software Integrity Failures) │ │
│ │ 防护:代码签名、CI/CD安全、依赖验证 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A09: 日志监控失败 (Logging Failures) │ │
│ │ 防护:安全日志、异常检测、告警机制 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ A10: 服务端请求伪造 (SSRF) │ │
│ │ 防护:URL白名单、输入验证、网络隔离 │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
45.5 技术选型指南
45.5.1 加密算法选型
| 用途 | 推荐算法 | 密钥长度 | 安全等级 | 备注 |
|---|---|---|---|---|
| 对称加密 | AES-GCM | 256位 | 高 | 推荐用于数据加密 |
| 非对称加密 | RSA-OAEP | ≥2048位(推荐3072/4096位) | 高 | 密钥封装、小数据加密(签名应使用RSA-PSS) |
| 哈希算法 | SHA-256 | - | 高 | 密码存储用bcrypt |
| 密码哈希 | bcrypt/Argon2 | - | 极高 | 内置盐值和工作因子 |
| 数字签名 | Ed25519 | - | 高 | 高性能、安全 |
45.5.2 认证方案选型
| 方案 | 适用场景 | 安全性 | 复杂度 | 无状态 |
|---|---|---|---|---|
| JWT | SPA、移动应用、微服务 | 高 | 中 | ✅ |
| OAuth2 | 第三方授权、社交登录 | 高 | 高 | ❌ |
| Session | 传统Web应用 | 中 | 低 | ❌ |
| API Key | 服务间调用 | 中 | 低 | ✅ |
| mTLS | 零信任网络 | 极高 | 高 | ✅ |
45.5.3 安全工具选型
| 工具类型 | 工具名称 | 功能 | 推荐指数 |
|---|---|---|---|
| 依赖扫描 | Safety, Snyk | 检测依赖漏洞 | ★★★★★ |
| 代码审计 | Bandit, Semgrep | 静态代码分析 | ★★★★★ |
| 密钥检测 | git-secrets, truffleHog | 检测泄露密钥 | ★★★★☆ |
| 容器扫描 | Trivy, Clair | 镜像漏洞扫描 | ★★★★★ |
| DAST | OWASP ZAP | 动态安全测试 | ★★★★☆ |
45.6 常见问题与解决方案
45.6.1 密码安全存储
python
import secrets
import hashlib
from typing import Tuple, Optional
from dataclasses import dataclass
@dataclass
class PasswordPolicy:
    """Password rules consumed by SecurePasswordManager."""

    # Allowed length range, in characters.
    min_length: int = 12
    max_length: int = 128
    # Character classes that must each appear at least once.
    require_uppercase: bool = True
    require_lowercase: bool = True
    require_digits: bool = True
    require_special: bool = True
    # Characters that count as "special" (also used when generating passwords).
    special_chars: str = "!@#$%^&*()_+-=[]{}|;:,.<>?"
    # NOTE(review): not enforced by any code in this section.
    max_age_days: int = 90
    # How many previous hashes are kept for reuse checks.
    history_count: int = 5
class SecurePasswordManager:
    """bcrypt hashing plus policy checks, password generation and reuse history."""

    # Compared case-insensitively in validate_password().
    _COMMON_PASSWORDS = frozenset({"password", "123456", "qwerty", "admin", "letmein"})

    def __init__(self, policy: PasswordPolicy = None):
        self.policy = policy or PasswordPolicy()
        # user_id -> recent bcrypt hashes, oldest first (at most history_count)
        self._password_history: dict = {}

    def hash_password(self, password: str) -> str:
        """Hash ``password`` with bcrypt (cost factor 12, random salt).

        NOTE(review): bcrypt only uses the first 72 bytes of input while the
        policy allows up to ``max_length`` characters, so longer passwords
        sharing a 72-byte prefix hash identically (newer bcrypt releases may
        raise ValueError instead — confirm for the installed version).
        """
        import bcrypt
        salt = bcrypt.gensalt(rounds=12)
        return bcrypt.hashpw(password.encode(), salt).decode()

    def verify_password(self, password: str, hashed: str) -> bool:
        """Return True iff ``password`` matches the bcrypt ``hashed`` value."""
        import bcrypt
        try:
            return bcrypt.checkpw(password.encode(), hashed.encode())
        except Exception:
            # Deliberate best effort: a malformed or missing hash is a mismatch.
            return False

    def validate_password(self, password: str) -> Tuple[bool, list]:
        """Check ``password`` against the policy; return ``(ok, error_messages)``."""
        policy = self.policy
        errors = []
        if len(password) < policy.min_length:
            errors.append(f"密码长度至少{policy.min_length}位")
        if len(password) > policy.max_length:
            errors.append(f"密码长度不能超过{policy.max_length}位")
        if policy.require_uppercase and not any(c.isupper() for c in password):
            errors.append("密码必须包含大写字母")
        if policy.require_lowercase and not any(c.islower() for c in password):
            errors.append("密码必须包含小写字母")
        if policy.require_digits and not any(c.isdigit() for c in password):
            errors.append("密码必须包含数字")
        if policy.require_special and not any(c in policy.special_chars for c in password):
            errors.append("密码必须包含特殊字符")
        if password.lower() in self._COMMON_PASSWORDS:
            errors.append("密码过于常见,请使用更强的密码")
        return len(errors) == 0, errors

    def generate_secure_password(self, length: int = 16) -> str:
        """Generate a random password of ``length`` characters that passes the policy.

        Raises:
            ValueError: if no password of ``length`` can satisfy the policy.
                Previously such lengths (e.g. below ``min_length``) looped
                forever.
        """
        import string
        policy = self.policy
        if not policy.min_length <= length <= policy.max_length:
            raise ValueError(f"密码长度必须在{policy.min_length}到{policy.max_length}位之间")
        required_classes = sum((
            policy.require_uppercase,
            policy.require_lowercase,
            policy.require_digits,
            policy.require_special,
        ))
        if length < required_classes:
            raise ValueError("密码长度不足以包含所有必需的字符类型")
        if policy.require_special and not policy.special_chars:
            raise ValueError("策略要求特殊字符,但未配置 special_chars")
        alphabet = string.ascii_letters + string.digits + policy.special_chars
        # Rejection sampling: draw until the candidate satisfies every rule.
        while True:
            password = ''.join(secrets.choice(alphabet) for _ in range(length))
            valid, _ = self.validate_password(password)
            if valid:
                return password

    def check_history(self, user_id: str, password: str) -> bool:
        """Return True if ``password`` matches none of the user's stored hashes."""
        previous = self._password_history.get(user_id, [])
        return not any(self.verify_password(password, old_hash) for old_hash in previous)

    def update_history(self, user_id: str, hashed_password: str):
        """Record ``hashed_password`` and keep only the newest ``history_count`` hashes."""
        history = self._password_history.setdefault(user_id, [])
        history.append(hashed_password)
        while len(history) > self.policy.history_count:
            history.pop(0)
class SecureTokenManager:
    """Static helpers for random tokens, prefixed API keys and token hashing."""

    @staticmethod
    def generate_token(length: int = 32) -> str:
        """Return a URL-safe token built from ``length`` random bytes."""
        token = secrets.token_urlsafe(length)
        return token

    @staticmethod
    def generate_api_key(prefix: str = "sk") -> str:
        """Return ``<prefix>_<43-char url-safe random string>``."""
        body = secrets.token_urlsafe(32)
        return "{}_{}".format(prefix, body)

    @staticmethod
    def hash_token(token: str) -> str:
        """Return the SHA-256 hex digest of ``token`` for at-rest storage."""
        digest = hashlib.sha256(token.encode())
        return digest.hexdigest()

    @staticmethod
    def constant_time_compare(a: str, b: str) -> bool:
        """Compare ``a`` and ``b`` without leaking the mismatch position via timing."""
        same = secrets.compare_digest(a, b)
        return same


# 45.6.2 安全输入验证
python
import re
import html
from typing import Any, Dict, List, Optional, Callable
from dataclasses import dataclass, field
@dataclass
class ValidationResult:
    """Result of validating one input value."""

    # True when no check failed.
    is_valid: bool
    # Accepted (possibly normalised) value; validators here leave it None on failure.
    value: Any = None
    # One message per failed check; each instance gets its own list.
    errors: List[str] = field(default_factory=list)
class InputValidator:
    """Validators and sanitizers for untrusted input.

    Validators return a ValidationResult carrying the accepted value or the
    error messages. Format patterns are applied with ``re.fullmatch`` so a
    trailing newline cannot slip past the ``$`` anchor.
    """

    EMAIL_PATTERN = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    PHONE_PATTERN = r'^1[3-9]\d{9}$'
    USERNAME_PATTERN = r'^[a-zA-Z0-9_]{3,20}$'

    @staticmethod
    def sanitize_string(value: str, max_length: int = 1000) -> str:
        """Strip, truncate to ``max_length`` and HTML-escape; non-strings become ""."""
        if not isinstance(value, str):
            return ""
        value = value.strip()
        value = value[:max_length]
        value = html.escape(value)
        return value

    @staticmethod
    def sanitize_html(value: str, allowed_tags: List[str] = None) -> str:
        """Keep only ``allowed_tags`` (with safe href/src attributes); escape all text.

        URL attributes are dropped when, after removing ASCII whitespace and
        control characters and lower-casing, they start with ``javascript:``,
        ``vbscript:`` or ``data:``. Browsers apply the same normalisation, so
        ``JavaScript:`` or ``java\\tscript:`` must be caught too.
        """
        from html.parser import HTMLParser

        blocked_schemes = ("javascript:", "vbscript:", "data:")

        def is_safe_url(url):
            normalized = re.sub(r"[\x00-\x20]+", "", url).lower()
            return not normalized.startswith(blocked_schemes)

        class HTMLSanitizer(HTMLParser):
            def __init__(self, allowed):
                super().__init__()
                self.allowed_tags = allowed or []
                self.result = []

            def handle_starttag(self, tag, attrs):
                if tag not in self.allowed_tags:
                    return
                safe_attrs = [
                    f'{name}="{html.escape(attr_value)}"'
                    for name, attr_value in attrs
                    # attr_value is None for valueless attributes like <a href>
                    if name in ('href', 'src') and attr_value is not None and is_safe_url(attr_value)
                ]
                self.result.append(f"<{' '.join([tag, *safe_attrs])}>")

            def handle_endtag(self, tag):
                if tag in self.allowed_tags:
                    self.result.append(f'</{tag}>')

            def handle_data(self, data):
                self.result.append(html.escape(data))

        sanitizer = HTMLSanitizer(allowed_tags)
        sanitizer.feed(value)
        # close() flushes text the parser buffers (e.g. a trailing "&x").
        sanitizer.close()
        return ''.join(sanitizer.result)

    @classmethod
    def validate_email(cls, value: str) -> "ValidationResult":
        """Validate an email address; the accepted value is lower-cased."""
        errors = []
        if not value:
            errors.append("邮箱不能为空")
        elif len(value) > 254:
            errors.append("邮箱长度超出限制")
        elif not re.fullmatch(cls.EMAIL_PATTERN, value):
            errors.append("邮箱格式不正确")
        return ValidationResult(
            is_valid=len(errors) == 0,
            value=value.lower() if not errors else None,
            errors=errors
        )

    @classmethod
    def validate_phone(cls, value: str) -> "ValidationResult":
        """Validate an 11-digit mobile number starting with 1[3-9]."""
        errors = []
        if not value:
            errors.append("手机号不能为空")
        elif not re.fullmatch(cls.PHONE_PATTERN, value):
            errors.append("手机号格式不正确")
        return ValidationResult(
            is_valid=len(errors) == 0,
            value=value if not errors else None,
            errors=errors
        )

    @classmethod
    def validate_url(cls, value: str, allowed_schemes: List[str] = None) -> "ValidationResult":
        """Validate that ``value`` has an allowed scheme (default http/https) and a host."""
        from urllib.parse import urlparse
        errors = []
        allowed_schemes = allowed_schemes or ['http', 'https']
        if not value:
            errors.append("URL不能为空")
        else:
            try:
                parsed = urlparse(value)
                if parsed.scheme.lower() not in allowed_schemes:
                    errors.append(f"只允许 {allowed_schemes} 协议")
                if not parsed.netloc:
                    errors.append("URL格式不正确")
            except Exception:
                errors.append("URL解析失败")
        return ValidationResult(
            is_valid=len(errors) == 0,
            value=value if not errors else None,
            errors=errors
        )

    @staticmethod
    def validate_integer(value: Any, min_val: int = None, max_val: int = None) -> "ValidationResult":
        """Convert with ``int()`` and check optional inclusive bounds."""
        errors = []
        try:
            int_value = int(value)
            if min_val is not None and int_value < min_val:
                errors.append(f"值不能小于{min_val}")
            if max_val is not None and int_value > max_val:
                errors.append(f"值不能大于{max_val}")
        except (ValueError, TypeError):
            errors.append("必须是有效的整数")
            int_value = None
        return ValidationResult(
            is_valid=len(errors) == 0,
            value=int_value if not errors else None,
            errors=errors
        )
class SQLInjectionProtector:
    """Heuristic SQL-injection helpers.

    ``detect_injection`` is a pattern blacklist. It can be bypassed and can
    flag benign text, so treat it as a secondary signal (logging/alerting),
    never as a substitute for parameterized queries. ``sanitize_identifier``
    whitelists table/column names, which cannot be passed as query
    parameters.
    """

    # Regexes searched case-insensitively by detect_injection().
    DANGEROUS_PATTERNS = [
        r"('\s*(OR|AND)\s*')",                      # ' OR ' style tautology
        r"(--|#|\/\*)",                             # SQL comment starters
        r"(;\s*(DROP|DELETE|INSERT|UPDATE|EXEC))",  # stacked statements
        r"(UNION\s+SELECT)",
        r"(xp_cmdshell)",
        r"(CONCAT\s*\()",
    ]

    @classmethod
    def detect_injection(cls, input_string: str) -> bool:
        """Return True if *input_string* matches any DANGEROUS_PATTERNS entry.

        Empty or ``None`` input returns False. Note that "#" and "--" alone
        trigger a match, so ordinary text containing them is flagged too.
        """
        if not input_string:
            return False
        # re.IGNORECASE already handles letter case; no need to upper-case.
        return any(
            re.search(pattern, input_string, re.IGNORECASE)
            for pattern in cls.DANGEROUS_PATTERNS
        )

    @classmethod
    def sanitize_identifier(cls, identifier: str) -> str:
        """Return *identifier* unchanged if it is a plain ASCII SQL identifier.

        Allowed: a letter or underscore followed by letters, digits or
        underscores. Empty or ``None`` input returns ``""``.

        Raises:
            ValueError: if *identifier* contains any other character.
        """
        if not identifier:
            return ""
        # fullmatch rather than match with "^...$": "$" also matches just
        # before a trailing newline, so "users\n" used to pass validation.
        if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', identifier):
            raise ValueError(f"无效的标识符: {identifier}")
        return identifier


# 45.6.3 安全配置管理
python
import os
import json
from typing import Dict, Any, Optional
from dataclasses import dataclass
@dataclass
class SecurityConfig:
    """Application security settings, loadable from environment variables."""
    secret_key: str
    jwt_secret: str
    database_url: str
    redis_url: str
    allowed_hosts: list
    cors_origins: list
    debug: bool = False

    @staticmethod
    def _split_csv(raw: str) -> list:
        """Split a comma-separated value, trimming whitespace and dropping empties."""
        return [item.strip() for item in raw.split(",") if item.strip()]

    @classmethod
    def from_env(cls) -> "SecurityConfig":
        """Build a config from environment variables.

        SECRET_KEY, JWT_SECRET, DATABASE_URL and REDIS_URL default to "".
        ALLOWED_HOSTS (default "localhost") and CORS_ORIGINS (default "*")
        are comma-separated lists; items are trimmed and empty items dropped.
        DEBUG is true for "1", "true", "yes" or "on" (case-insensitive).
        """
        return cls(
            secret_key=os.environ.get("SECRET_KEY", ""),
            jwt_secret=os.environ.get("JWT_SECRET", ""),
            database_url=os.environ.get("DATABASE_URL", ""),
            redis_url=os.environ.get("REDIS_URL", ""),
            allowed_hosts=cls._split_csv(os.environ.get("ALLOWED_HOSTS", "localhost")),
            cors_origins=cls._split_csv(os.environ.get("CORS_ORIGINS", "*")),
            debug=os.environ.get("DEBUG", "false").strip().lower() in {"1", "true", "yes", "on"},
        )

    def validate(self) -> list:
        """Return a list of configuration problems (empty when the config is acceptable)."""
        errors = []
        if not self.secret_key or len(self.secret_key) < 32:
            errors.append("SECRET_KEY必须设置且至少32位")
        if not self.jwt_secret or len(self.jwt_secret) < 32:
            errors.append("JWT_SECRET必须设置且至少32位")
        # The rule is "restrict to localhost": any non-loopback host is an
        # error. The old check only required "localhost" to appear somewhere.
        local_hosts = {"localhost", "127.0.0.1", "::1"}
        if self.debug and any(host not in local_hosts for host in self.allowed_hosts):
            errors.append("DEBUG模式下ALLOWED_HOSTS应限制为localhost")
        # from_env() defaults CORS_ORIGINS to "*" (any origin); flag it outside debug.
        if not self.debug and "*" in self.cors_origins:
            errors.append("生产环境下CORS_ORIGINS不应使用通配符*")
        return errors
class SecretsManager:
    """In-memory store of secrets loaded from environment variables or a KEY=VALUE file."""

    # Substrings that mark a dict key as sensitive in mask_sensitive().
    DEFAULT_SENSITIVE_KEYS = ('password', 'secret', 'token', 'key', 'api_key')
    MASK = '***MASKED***'

    def __init__(self):
        self._secrets: Dict[str, str] = {}

    def load_from_env(self, prefix: str = ""):
        """Copy environment variables into the store.

        With a *prefix*, only variables whose names start with it are loaded,
        stored under the name with the prefix removed. Without one, every
        variable is loaded under its own name.
        """
        for name, val in os.environ.items():
            if not prefix:
                self._secrets[name] = val
            elif name.startswith(prefix):
                self._secrets[name[len(prefix):]] = val

    def load_from_file(self, filepath: str, encoding: str = "utf-8"):
        """Load ``KEY=VALUE`` lines from *filepath*.

        Blank lines, lines starting with ``#`` and lines without ``=`` are
        skipped. Only the first ``=`` separates key from value, and both sides
        are whitespace-stripped. Later duplicates overwrite earlier ones.
        The file is decoded with *encoding* (UTF-8 by default) rather than
        the platform-dependent locale encoding.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            for raw in f:
                line = raw.strip()
                if not line or line.startswith('#') or '=' not in line:
                    continue
                key, value = line.split('=', 1)
                self._secrets[key.strip()] = value.strip()

    def get(self, key: str, default: str = None) -> Optional[str]:
        """Return the secret stored under *key*, or *default* if absent."""
        return self._secrets.get(key, default)

    def require(self, key: str) -> str:
        """Return the secret stored under *key*.

        Raises:
            KeyError: if the secret is missing or stored as None.
        """
        value = self._secrets.get(key)
        if value is None:
            raise KeyError(f"Required secret not found: {key}")
        return value

    def mask_sensitive(self, data: Dict, keys: list = None) -> Dict:
        """Return a copy of *data* with sensitive values replaced by ``MASK``.

        A key is sensitive when any entry of *keys* (default
        ``DEFAULT_SENSITIVE_KEYS``) occurs in it, compared case-insensitively.
        Dicts nested in values, including dicts inside lists, are masked too.
        *data* itself is not modified. Assumes *data* is acyclic (e.g. parsed
        JSON).
        """
        needles = [k.lower() for k in (keys or self.DEFAULT_SENSITIVE_KEYS)]
        return self._mask(data, needles)

    @classmethod
    def _mask(cls, obj, needles):
        """Recursively mask sensitive keys in dicts contained in *obj*."""
        if isinstance(obj, dict):
            masked = obj.copy()
            for k, v in obj.items():
                # str(k): non-string keys (e.g. ints) used to raise AttributeError.
                if any(n in str(k).lower() for n in needles):
                    masked[k] = cls.MASK
                else:
                    masked[k] = cls._mask(v, needles)
            return masked
        if isinstance(obj, list):
            return [cls._mask(item, needles) for item in obj]
        return obj
class SecurityHeaders:
    """Builders for HTTP security response headers."""

    @staticmethod
    def get_default_headers() -> Dict[str, str]:
        """Return a fresh dict of recommended security response headers.

        X-XSS-Protection is "0": OWASP and MDN recommend disabling the legacy
        browser XSS auditor, because its "1; mode=block" mode could itself be
        abused for cross-site information leaks. XSS protection should come
        from the Content-Security-Policy instead.
        """
        return {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "0",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Content-Security-Policy": "default-src 'self'",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
        }

    @staticmethod
    def get_csp_header(
        default_src: str = "'self'",
        script_src: str = "'self'",
        style_src: str = "'self' 'unsafe-inline'",
        img_src: str = "'self' data:",
        connect_src: str = "'self'",
        font_src: str = "'self'",
        frame_src: str = "'none'",
        object_src: Optional[str] = None,
        base_uri: Optional[str] = None,
        frame_ancestors: Optional[str] = None,
    ) -> str:
        """Build a Content-Security-Policy header value.

        Directives appear in parameter order, separated by "; ". A directive
        whose value is None is omitted. object-src, base-uri and
        frame-ancestors are therefore emitted only when explicitly given.
        """
        directives = [
            ("default-src", default_src),
            ("script-src", script_src),
            ("style-src", style_src),
            ("img-src", img_src),
            ("connect-src", connect_src),
            ("font-src", font_src),
            ("frame-src", frame_src),
            ("object-src", object_src),
            ("base-uri", base_uri),
            ("frame-ancestors", frame_ancestors),
        ]
        return "; ".join(
            f"{name} {value}" for name, value in directives if value is not None
        )


# 45.7 本章小结
本章详细介绍了Python安全编程的核心概念和实践:
- 密码学基础:对称加密、非对称加密、哈希算法、数字签名
- 安全认证:密码存储、JWT、API密钥管理
- 输入验证:数据校验、模式匹配、安全编码
- SQL注入防护:参数化查询(根本手段)、输入检测与标识符白名单(辅助手段)
- OWASP防护:XSS、CSRF、命令注入、路径遍历
- 安全配置:配置校验、密钥管理、敏感数据脱敏、安全响应头
练习题
- 实现一个完整的密码管理系统,支持加密存储和安全验证
- 开发一个JWT认证中间件,支持令牌刷新和权限验证
- 实现一个SQL注入检测器,分析代码中的安全风险
- 开发一个XSS过滤器,支持HTML内容的安全渲染
- 实现一个安全审计工具,自动检测代码中的安全漏洞