第22章 测试驱动开发
学习目标
完成本章学习后,读者应能够:
- 理解TDD方法论:掌握红-绿-重构循环与测试金字塔
- 精通pytest框架:灵活运用夹具、参数化、标记与插件
- 掌握unittest体系:使用标准库构建完整测试套件
- 实现Mock与桩代码:隔离外部依赖,编写独立可重复的单元测试
- 掌握集成测试:测试数据库、API与Web应用
- 运用性能测试:使用cProfile与pytest-benchmark进行性能分析
- 精通调试技术:使用pdb、日志与性能分析定位问题
22.1 测试驱动开发方法论
22.1.1 TDD核心循环
测试驱动开发(Test-Driven Development)遵循"红-绿-重构"循环:
┌──────────────────────────────────────────────┐
│                                              │
│   ┌──────────┐     ┌──────────┐              │
│   │ 🔴 RED   │────▶│ 🟢 GREEN │────┐         │
│   │ 编写失败 │     │ 快速通过 │    │         │
│   │ 的测试   │     │ 的实现   │    │         │
│   └──────────┘     └──────────┘    │         │
│        ▲                           │         │
│        │                     ┌─────▼───────┐ │
│        │                     │ 🔵 REFACTOR │ │
│        └─────────────────────│ 重构优化    │ │
│                              └─────────────┘ │
└──────────────────────────────────────────────┘
22.1.2 测试金字塔
                 ╱╲
                ╱  ╲            E2E测试(少量)
               ╱ E2E╲           - 完整用户流程
              ╱──────╲          - 执行慢、脆弱
             ╱        ╲
            ╱ 集成测试 ╲        集成测试(适量)
           ╱────────────╲       - 组件交互
          ╱              ╲      - 数据库/API
         ╱    单元测试    ╲     单元测试(大量)
        ╱──────────────────╲    - 快速、独立
       ╱                    ╲   - 覆盖率高
      ╱──────────────────────╲
22.1.3 TDD实践示例
以开发一个购物车系统为例:
python
class ShoppingCart:
def __init__(self):
self._items: dict[str, dict] = {}
def add_item(self, name: str, price: float, quantity: int = 1):
if name in self._items:
self._items[name]["quantity"] += quantity
else:
self._items[name] = {"price": price, "quantity": quantity}
def remove_item(self, name: str, quantity: int | None = None):
if name not in self._items:
raise KeyError(f"商品 '{name}' 不在购物车中")
if quantity is None or quantity >= self._items[name]["quantity"]:
del self._items[name]
else:
self._items[name]["quantity"] -= quantity
def get_total(self) -> float:
return sum(item["price"] * item["quantity"] for item in self._items.values())
def get_item_count(self) -> int:
return sum(item["quantity"] for item in self._items.values())
def clear(self):
self._items.clear()
@property
def is_empty(self) -> bool:
return len(self._items) == 022.2 pytest框架
22.2.1 pytest核心特性
python
import pytest
from decimal import Decimal
class TestShoppingCart:
    """Behavioral tests for ShoppingCart; each test builds its own cart."""

    def test_add_item(self):
        basket = ShoppingCart()
        basket.add_item("Python书", 89.00, 2)
        assert not basket.is_empty
        assert basket.get_item_count() == 2

    def test_add_existing_item(self):
        basket = ShoppingCart()
        # Adding the same product twice merges the quantities: 1 + 2.
        for qty in (1, 2):
            basket.add_item("Python书", 89.00, qty)
        assert basket.get_item_count() == 3

    def test_remove_item(self):
        basket = ShoppingCart()
        basket.add_item("Python书", 89.00, 3)
        basket.remove_item("Python书", 1)
        assert basket.get_item_count() == 2

    def test_remove_item_completely(self):
        basket = ShoppingCart()
        basket.add_item("Python书", 89.00, 1)
        # Omitting the quantity drops the whole line item.
        basket.remove_item("Python书")
        assert basket.is_empty

    def test_remove_nonexistent_item(self):
        basket = ShoppingCart()
        with pytest.raises(KeyError, match="不在购物车中"):
            basket.remove_item("不存在")

    def test_get_total(self):
        basket = ShoppingCart()
        basket.add_item("Python书", 89.00, 2)
        basket.add_item("键盘", 399.00, 1)
        # 89 * 2 + 399 = 577; approx absorbs float rounding.
        assert basket.get_total() == pytest.approx(577.00)

    def test_clear(self):
        basket = ShoppingCart()
        basket.add_item("Python书", 89.00)
        basket.clear()
        assert basket.is_empty

    def test_empty_cart_total(self):
        assert ShoppingCart().get_total() == 0.0
# 22.2.2 夹具系统
python
import pytest
from pathlib import Path
import json
import tempfile
@pytest.fixture
def cart():
    """Provide a brand-new, empty ShoppingCart to each requesting test."""
    fresh_cart = ShoppingCart()
    return fresh_cart
@pytest.fixture
def cart_with_items(cart):
    """Return the ``cart`` fixture pre-loaded with two books and one keyboard."""
    for name, price, quantity in (("Python书", 89.00, 2), ("键盘", 399.00, 1)):
        cart.add_item(name, price, quantity)
    return cart
@pytest.fixture
def temp_json_file(tmp_path):
    """Return the path of a JSON file holding one sample cart entry.

    The file is written into pytest's per-test ``tmp_path`` directory, which
    pytest manages itself. This replaces the hand-rolled
    ``NamedTemporaryFile(delete=False)`` + ``unlink`` pair, which leaked the
    file whenever writing failed before the fixture reached its teardown.
    """
    data = {"items": {"Python书": {"price": 89.00, "quantity": 2}}}
    filepath = tmp_path / "cart.json"
    # json.dumps escapes non-ASCII by default, so the content is plain ASCII
    # and reads back identically whatever encoding the reader assumes.
    filepath.write_text(json.dumps(data), encoding="utf-8")
    return filepath
class TestWithFixtures:
    """Tests that receive all their setup through pytest fixtures."""

    def test_empty_cart(self, cart):
        assert cart.is_empty

    def test_cart_with_items_total(self, cart_with_items):
        total = cart_with_items.get_total()
        assert total == pytest.approx(577.00)

    def test_cart_with_items_count(self, cart_with_items):
        count = cart_with_items.get_item_count()
        assert count == 3

    def test_json_file_load(self, temp_json_file):
        payload = json.loads(temp_json_file.read_text())
        assert "items" in payload
        book = payload["items"]["Python书"]
        assert book["price"] == 89.00
@pytest.fixture(scope="session")
def db_connection():
    """Session-wide in-memory SQLite database seeded with one product row."""
    import sqlite3

    schema = """
    CREATE TABLE products (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL,
        price REAL NOT NULL
    )
    """
    seed = "INSERT INTO products (name, price) VALUES ('Python书', 89.00)"

    connection = sqlite3.connect(":memory:")
    for statement in (schema, seed):
        connection.execute(statement)
    connection.commit()
    yield connection
    # Teardown: session scope means this runs once, after the last test.
    connection.close()
def test_db_query(db_connection):
    """The seeded product row is readable through the shared connection."""
    first_row = db_connection.execute("SELECT name, price FROM products").fetchone()
    assert first_row == ("Python书", 89.00)
# 22.2.3 参数化与标记
python
import sys

import pytest
@pytest.mark.parametrize("price, quantity, expected_total", [
    (89.00, 1, 89.00),
    (89.00, 2, 178.00),
    (0.00, 5, 0.00),
    (99.99, 3, 299.97),
])
def test_item_total(price, quantity, expected_total):
    """A single line item totals to price × quantity (compared with approx)."""
    single = ShoppingCart()
    single.add_item("商品", price, quantity)
    total = single.get_total()
    assert total == pytest.approx(expected_total)
@pytest.mark.parametrize("items, expected_count, expected_total", [
    ([], 0, 0.0),
    ([("A", 10.0, 1)], 1, 10.0),
    ([("A", 10.0, 2), ("B", 20.0, 1)], 3, 40.0),
    ([("A", 100.0, 1), ("B", 200.0, 2), ("C", 50.0, 3)], 6, 650.0),
])
def test_cart_scenarios(items, expected_count, expected_total):
    """Multiple line items: count sums quantities, total sums price × quantity."""
    cart = ShoppingCart()
    # Each entry is a (name, price, quantity) triple matching add_item's signature.
    for entry in items:
        cart.add_item(*entry)
    assert cart.get_item_count() == expected_count
    assert cart.get_total() == pytest.approx(expected_total)
@pytest.mark.slow
def test_large_cart():
    """Stress test: 10 000 distinct products of one unit each (marked slow)."""
    cart = ShoppingCart()
    count = 10000
    for index in range(count):
        cart.add_item(f"商品{index}", 10.0, 1)
    assert cart.get_item_count() == count
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="Unix-specific test",
)
def test_unix_feature():
    """Run only on non-Windows platforms.

    The skipif condition is evaluated when pytest imports this module, so
    it needs a module-level ``import sys``; a function-local import would
    come too late.
    """
    import os

    # Minimal Unix-only expectation instead of an empty placeholder body.
    assert os.sep == "/"
@pytest.mark.xfail(reason="Known bug: #123", strict=True)
def test_known_bug():
    """Document known defect #123, which still makes this assertion fail.

    ``strict=True`` turns an unexpected pass (XPASS) into a failure, so the
    marker cannot silently outlive the bug fix.
    """
    assert 1 == 2
# pytest配置文件 pytest.ini:
ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short --strict-markers
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    integration: marks integration tests
22.3 Mock与桩代码
22.3.1 unittest.mock
python
from unittest.mock import Mock, patch, MagicMock, call
import pytest
class EmailService:
    """Send plain-text e-mail through a fixed SMTP relay."""

    def send_email(self, to: str, subject: str, body: str) -> bool:
        """Send one message to *to*; return True once it has been handed off.

        The message is built with ``email.message.EmailMessage`` so a
        non-ASCII subject or body (e.g. Chinese text) is MIME-encoded.
        Passing such text as a plain ``str`` to ``sendmail`` raises
        ``UnicodeEncodeError``, because smtplib encodes str messages as ASCII.
        Exceptions raised by smtplib propagate to the caller.
        """
        import smtplib
        from email import policy
        from email.message import EmailMessage

        sender = "noreply@example.com"
        # policy.SMTP serializes with CRLF line endings, as SMTP requires.
        message = EmailMessage(policy=policy.SMTP)
        message["From"] = sender
        message["To"] = to
        message["Subject"] = subject
        # quoted-printable keeps the body 7-bit clean for any relay.
        message.set_content(body, cte="quoted-printable")

        server = smtplib.SMTP("smtp.example.com", 587)
        try:
            server.sendmail(sender, to, message.as_bytes())
        finally:
            # Close the session even when sending fails.
            server.quit()
        return True
class UserService:
    """Registers users in memory and sends each new user a welcome e-mail."""

    def __init__(self, email_service: EmailService):
        self.email_service = email_service
        # username -> {"username": ..., "email": ...}
        self.users: dict[str, dict] = {}

    def register(self, username: str, email: str) -> dict:
        """Store a new user, send the welcome mail and return the user record.

        Raises:
            ValueError: if *username* is already registered.

        Exceptions from ``email_service.send_email`` propagate unchanged; the
        user has already been stored by the time the mail is sent.
        """
        if username in self.users:
            raise ValueError(f"用户 '{username}' 已存在")
        record = dict(username=username, email=email)
        self.users[username] = record
        greeting = f"你好 {username},欢迎加入!"
        self.email_service.send_email(to=email, subject="欢迎注册", body=greeting)
        return record
class TestUserServiceWithMock:
    """UserService tested against a Mock e-mail service (no SMTP traffic)."""

    def test_register_success(self):
        email_stub = Mock()
        email_stub.send_email.return_value = True
        user = UserService(email_stub).register("alice", "alice@example.com")
        assert user["username"] == "alice"
        assert user["email"] == "alice@example.com"
        # The welcome mail went out exactly once, with the expected content.
        email_stub.send_email.assert_called_once_with(
            to="alice@example.com",
            subject="欢迎注册",
            body="你好 alice,欢迎加入!",
        )

    def test_register_duplicate_user(self):
        email_stub = Mock()
        service = UserService(email_stub)
        service.register("alice", "alice@example.com")
        with pytest.raises(ValueError, match="已存在"):
            service.register("alice", "alice2@example.com")
        # The rejected second registration must not send another mail.
        email_stub.send_email.assert_called_once()

    def test_register_email_failure(self):
        email_stub = Mock()
        email_stub.send_email.side_effect = ConnectionError("SMTP连接失败")
        service = UserService(email_stub)
        with pytest.raises(ConnectionError):
            service.register("alice", "alice@example.com")
class TestWithPatch:
    """Examples of replacing real dependencies with ``unittest.mock.patch``."""

    @patch("smtplib.SMTP")
    def test_send_email(self, mock_smtp_class):
        """EmailService opens one SMTP session, sends once and quits."""
        mock_instance = MagicMock()
        mock_smtp_class.return_value = mock_instance
        service = EmailService()
        result = service.send_email("test@example.com", "测试", "内容")
        assert result is True
        mock_smtp_class.assert_called_once_with("smtp.example.com", 587)
        mock_instance.sendmail.assert_called_once()
        mock_instance.quit.assert_called_once()

    def test_file_read(self):
        """Patch ``open`` with ``mock_open``; check the data and the call args.

        ``mock_open`` builds the context-manager / ``read()`` chain that would
        otherwise be wired by hand. ``create=True`` is unnecessary because
        ``builtins.open`` always exists.
        """
        from unittest.mock import mock_open

        fake_open = mock_open(read_data="test content")
        with patch("builtins.open", fake_open):
            with open("test.txt") as f:
                content = f.read()
        assert content == "test content"
        fake_open.assert_called_once_with("test.txt")
class TestCallAssertions:
    """Inspecting how a Mock was called."""

    def test_multiple_calls(self):
        recorder = Mock()
        for value in (1, 2, 3):
            recorder(value)
        assert recorder.call_count == 3
        assert recorder.call_args_list == [call(v) for v in (1, 2, 3)]
        # Subset check; any_order=True ignores the sequence of the calls.
        recorder.assert_has_calls([call(1), call(3)], any_order=True)
# 22.3.2 pytest-mock
python
def test_with_mocker(mocker):
    """``mocker.patch`` swaps ``requests.get`` and restores it after the test."""
    fake_get = mocker.patch("requests.get")
    fake_get.return_value.json.return_value = {"status": "ok"}

    import requests  # the module whose ``get`` attribute is now patched

    url = "https://api.example.com/status"
    assert requests.get(url).json() == {"status": "ok"}
    fake_get.assert_called_once_with(url)
def test_spy_with_mocker(mocker):
    """``mocker.spy`` records calls while still running the real method.

    The spy target must accept attribute assignment. Built-in instances such
    as ``list`` reject it ("'list' object attribute 'append' is read-only"),
    so spying on ``[1, 2, 3].append`` errors before the test body runs. An
    instance of an ordinary Python class works.
    """
    cart = ShoppingCart()
    spy = mocker.spy(cart, "add_item")
    cart.add_item("Python书", 89.00, 2)
    spy.assert_called_once_with("Python书", 89.00, 2)
    # The real implementation ran, so the cart's state changed as well.
    assert cart.get_item_count() == 2
    assert spy.spy_return is None
# 22.4 集成测试
22.4.1 Flask应用测试
python
import pytest
from flask import Flask, jsonify, request
def create_app():
    """Build a Flask app exposing a tiny ``/api/users`` resource."""
    app = Flask(__name__)

    @app.route("/api/users", methods=["GET"])
    def get_users():
        return jsonify({"users": [{"id": 1, "name": "Alice"}]})

    @app.route("/api/users", methods=["POST"])
    def create_user():
        # silent=True: a missing or malformed JSON body yields None, so the
        # endpoint answers with its own 400 message instead of Flask aborting.
        data = request.get_json(silent=True)
        # Only JSON objects are valid. For a list such as ["name"] the `in`
        # test succeeds and data["name"] would then raise -> HTTP 500.
        if not isinstance(data, dict) or "name" not in data:
            return jsonify({"error": "name is required"}), 400
        return jsonify({"id": 2, "name": data["name"]}), 201

    return app
@pytest.fixture
def app():
    """Flask application built by create_app() with testing mode enabled."""
    application = create_app()
    application.config.update(TESTING=True)
    return application
@pytest.fixture
def client(app):
    """Flask test client bound to the ``app`` fixture."""
    test_client = app.test_client()
    return test_client
class TestUserAPI:
    """HTTP-level tests for /api/users through the Flask test client."""

    def test_get_users(self, client):
        response = client.get("/api/users")
        assert response.status_code == 200
        body = response.get_json()
        assert "users" in body
        assert body["users"]  # at least one user is listed

    def test_create_user(self, client):
        # json= serializes the body and already sets Content-Type: application/json.
        response = client.post("/api/users", json={"name": "Bob"})
        assert response.status_code == 201
        assert response.get_json()["name"] == "Bob"

    def test_create_user_missing_name(self, client):
        response = client.post("/api/users", json={})
        assert response.status_code == 400
# 22.4.2 数据库集成测试
python
import pytest
import sqlite3
from contextlib import contextmanager
class UserRepository:
def __init__(self, conn: sqlite3.Connection):
self.conn = conn
def create_table(self):
self.conn.execute("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
email TEXT UNIQUE NOT NULL
)
""")
self.conn.commit()
def add_user(self, username: str, email: str) -> int:
cursor = self.conn.execute(
"INSERT INTO users (username, email) VALUES (?, ?)",
(username, email),
)
self.conn.commit()
return cursor.lastrowid
def get_user(self, user_id: int) -> dict | None:
cursor = self.conn.execute(
"SELECT id, username, email FROM users WHERE id = ?", (user_id,),
)
row = cursor.fetchone()
if row:
return {"id": row[0], "username": row[1], "email": row[2]}
return None
def get_all_users(self) -> list[dict]:
cursor = self.conn.execute("SELECT id, username, email FROM users")
return [{"id": r[0], "username": r[1], "email": r[2]} for r in cursor.fetchall()]
@pytest.fixture
def db_conn():
    """Fresh in-memory SQLite connection per test, closed on teardown."""
    connection = sqlite3.connect(":memory:")
    # Plain tuples: UserRepository reads columns by position.
    connection.row_factory = None
    yield connection
    connection.close()
@pytest.fixture
def user_repo(db_conn):
    """UserRepository over ``db_conn`` with the users table already created."""
    repository = UserRepository(db_conn)
    repository.create_table()
    return repository
class TestUserRepository:
    """Integration tests running UserRepository against in-memory SQLite."""

    def test_add_user(self, user_repo):
        # A fresh database hands out id 1 to the first row.
        assert user_repo.add_user("alice", "alice@example.com") == 1

    def test_get_user(self, user_repo):
        new_id = user_repo.add_user("alice", "alice@example.com")
        fetched = user_repo.get_user(new_id)
        assert fetched is not None
        assert (fetched["username"], fetched["email"]) == ("alice", "alice@example.com")

    def test_get_nonexistent_user(self, user_repo):
        assert user_repo.get_user(999) is None

    def test_get_all_users(self, user_repo):
        for name in ("alice", "bob"):
            user_repo.add_user(name, f"{name}@example.com")
        users = user_repo.get_all_users()
        assert len(users) == 2
        assert [u["username"] for u in users] == ["alice", "bob"]

    def test_unique_username_constraint(self, user_repo):
        user_repo.add_user("alice", "alice@example.com")
        with pytest.raises(sqlite3.IntegrityError):
            user_repo.add_user("alice", "alice2@example.com")
# 22.5 性能测试
22.5.1 pytest-benchmark
python
def fibonacci_recursive(n: int) -> int:
    """Return the n-th Fibonacci number by naive double recursion (exponential time).

    Values n <= 1, including negatives, are returned unchanged.
    """
    return n if n <= 1 else fibonacci_recursive(n - 1) + fibonacci_recursive(n - 2)
def fibonacci_memo(n: int, cache: dict = None) -> int:
if cache is None:
cache = {0: 0, 1: 1}
if n not in cache:
cache[n] = fibonacci_memo(n - 1, cache) + fibonacci_memo(n - 2, cache)
return cache[n]
def fibonacci_iterative(n: int) -> int:
    """Return the n-th Fibonacci number in O(n) time and O(1) extra space.

    Values n <= 1, including negatives, are returned unchanged.
    """
    if n <= 1:
        return n
    previous, current = 0, 1
    # n - 1 steps advance (F0, F1) to (F(n-1), F(n)).
    for _step in range(n - 1):
        previous, current = current, previous + current
    return current
def test_fibonacci_correctness():
    """All three implementations agree with known values, base cases included.

    The recursive version was previously unchecked, and so were n = 0 and n = 1.
    """
    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
    for n, value in enumerate(expected):
        assert fibonacci_recursive(n) == value
        assert fibonacci_memo(n) == value
        assert fibonacci_iterative(n) == value
def test_fibonacci_benchmark(benchmark):
    """Benchmark fibonacci_iterative(30) and check the value it produced.

    ``benchmark(...)`` returns the function's result, so the timed code is
    also verified rather than merely executed.
    """
    result = benchmark(fibonacci_iterative, 30)
    assert result == 832040
def test_fibonacci_compare(benchmark):
    """Benchmark fibonacci_memo(30) for comparison and check its result."""
    result = benchmark(fibonacci_memo, 30)
    assert result == 832040
# 22.5.2 cProfile分析
python
import cProfile
import pstats
import io
def profile_function(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` under cProfile, print stats, return its result.

    The 20 most expensive entries, sorted by cumulative time, are printed.
    The profiler is disabled in a ``finally`` block, so an exception raised
    by *func* propagates without leaving profiling switched on. (On Python
    3.12+ a still-active profiler can make the next ``enable()`` fail.)
    """
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = func(*args, **kwargs)
    finally:
        profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative")
    stats.print_stats(20)
    print(stream.getvalue())
    return result
def analyze_data():
    """Sample CPU workload for profiling: square, sum and reverse-sort 100 000 ints.

    The intermediate results are discarded; only the work itself matters.
    """
    numbers = list(range(100000))
    squares = [value ** 2 for value in numbers]
    total = sum(numbers)
    descending = sorted(numbers, reverse=True)
    del squares, total, descending
if __name__ == "__main__":
    # Only when the file is executed directly: profile the sample workload.
    profile_function(func=analyze_data)
# 22.6 调试技术
22.6.1 pdb调试器
python
import pdb
def calculate_discount(price: float, discount_rate: float, min_price: float = 0.0) -> float:
    """Apply *discount_rate* (a fraction, e.g. 0.1 for 10 %) to *price*.

    The result is floored at *min_price*.
    """
    discounted = price * (1 - discount_rate)
    # Same selection rule as max(discounted, min_price).
    return min_price if min_price > discounted else discounted
def process_order(order: dict) -> dict:
    """Compute an order's subtotal and discounted total.

    ``order["items"]`` is a list of dicts with "price" and "quantity"; the
    optional ``order["discount"]`` rate defaults to 0.0.
    Returns ``{"subtotal": ..., "total": ...}``.
    """
    subtotal = 0.0
    for line in order["items"]:
        # A convenient place for ``breakpoint()`` to step through each line under pdb.
        subtotal += line["price"] * line["quantity"]
    total = calculate_discount(subtotal, order.get("discount", 0.0))
    return {"subtotal": subtotal, "total": total}
# Sample order used to exercise process_order (e.g. when run under pdb).
order = {
    "items": [
        {"name": "Python书", "price": 89.00, "quantity": 2},
        {"name": "键盘", "price": 399.00, "quantity": 1},
    ],
    "discount": 0.1,
}
print(result := process_order(order))
# pdb常用命令:
| 命令 | 缩写 | 说明 |
|---|---|---|
| next | n | 执行下一行(不进入函数) |
| step | s | 单步执行(进入函数) |
| continue | c | 继续执行到下一个断点 |
| break | b | 设置断点 |
| print | p | 打印变量值 |
| pp | — | 美化打印变量值 |
| list | l | 显示源代码 |
| where | w | 显示调用栈 |
| up/down | u/d | 在调用栈中上下移动 |
| quit | q | 退出调试 |
22.6.2 结构化日志
python
import logging
import logging.config
from functools import wraps
from datetime import datetime
# dictConfig schema: the "app" logger writes DEBUG+ to stdout and to a
# rotating file (10 MiB per file, 5 backups). Loggers not configured here
# fall back to root, which prints WARNING+ to stdout.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "detailed": {
            "format": "%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        },
        "simple": {"format": "[%(levelname)s] %(message)s"},
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "simple",
            "stream": "ext://sys.stdout",
        },
        "file": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "detailed",
            "filename": "app.log",
            "maxBytes": 10 * 1024 * 1024,
            "backupCount": 5,
        },
    },
    "loggers": {
        # propagate=False stops app records from also reaching root's console
        # handler, which would print them twice.
        "app": {"level": "DEBUG", "handlers": ["console", "file"], "propagate": False},
    },
    "root": {"level": "WARNING", "handlers": ["console"]},
}

logging.config.dictConfig(config=LOGGING_CONFIG)
logger = logging.getLogger(name="app")
def log_execution(func):
    """Decorator that logs entry, exit or failure, and the duration of *func*.

    Records go to the module-level ``logger``: DEBUG on entry and on success;
    ERROR with traceback (``logger.exception``) when *func* raises. The
    exception is always re-raised unchanged.
    """
    import time

    @wraps(func)
    def wrapper(*args, **kwargs):
        func_name = func.__qualname__
        # Log every positional argument. Slicing off args[0] assumed a bound
        # ``self`` and silently hid the first argument of plain functions,
        # e.g. the amount passed to process_payment.
        logger.debug("→ %s args=%r kwargs=%r", func_name, args, kwargs)
        # perf_counter is monotonic; wall-clock datetime.now() can jump.
        start = time.perf_counter()
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            elapsed = time.perf_counter() - start
            logger.exception("✗ %s failed in %.4fs: %s", func_name, elapsed, e)
            raise
        elapsed = time.perf_counter() - start
        logger.debug("← %s returned in %.4fs", func_name, elapsed)
        return result

    return wrapper
@log_execution
def process_payment(amount: float, currency: str = "CNY") -> dict:
    """Validate *amount* and return a success record for the payment.

    Raises:
        ValueError: if *amount* is not a finite number greater than zero.
    """
    import math

    # `amount <= 0` alone lets NaN through (every NaN comparison is False)
    # and also accepts infinity.
    if not (math.isfinite(amount) and amount > 0):
        raise ValueError(f"无效金额: {amount}")
    logger.info("处理支付: %s %s", amount, currency)
    return {"status": "success", "amount": amount, "currency": currency}
# 22.7 测试覆盖率
22.7.1 coverage.py
bash
pip install pytest-cov
pytest --cov=src --cov-report=html --cov-report=term-missing

pyproject.toml 中的覆盖率配置:
toml
[tool.coverage.run]
source = ["src"]
omit = ["tests/*", "*/migrations/*"]
[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "if __name__ == .__main__.:",
    "raise NotImplementedError",
    '^\s*pass\s*$',  # regex: only a bare "pass" line; plain "pass" would also hit "password"
]
fail_under = 80
22.7.2 覆盖率实践
python
class Calculator:
    """Small arithmetic helper used in the coverage examples."""

    def add(self, a: float, b: float) -> float:
        """Return a + b."""
        return a + b

    def divide(self, a: float, b: float) -> float:
        """Return a / b.

        Raises:
            ZeroDivisionError: if b is zero.
        """
        if b == 0:
            raise ZeroDivisionError("除数不能为零")
        return a / b

    def factorial(self, n: int) -> int:
        """Return n! for non-negative n (0! == 1! == 1).

        Computed with a loop rather than recursion, so large n (thousands)
        no longer exceeds Python's recursion limit.

        Raises:
            ValueError: if n is negative.
        """
        if n < 0:
            raise ValueError("负数没有阶乘")
        result = 1
        while n > 1:
            result *= n
            n -= 1
        return result
class TestCalculator:
    """Unit tests covering every Calculator branch, happy paths and errors."""

    def test_add(self):
        assert Calculator().add(1, 2) == 3

    def test_divide(self):
        assert Calculator().divide(6, 3) == 2.0

    def test_divide_by_zero(self):
        calculator = Calculator()
        with pytest.raises(ZeroDivisionError):
            calculator.divide(1, 0)

    def test_factorial(self):
        calculator = Calculator()
        # 5! plus the 0! base case.
        for n, expected in ((5, 120), (0, 1)):
            assert calculator.factorial(n) == expected

    def test_factorial_negative(self):
        calculator = Calculator()
        with pytest.raises(ValueError):
            calculator.factorial(-1)
# 22.8 前沿技术动态
22.8.1 现代测试工具
- pytest-asyncio:异步代码测试
- hypothesis:基于属性的测试(Property-Based Testing)
- mutmut:变异测试(Mutation Testing)
- pytest-cov:覆盖率集成
- allure-pytest:测试报告生成
22.8.2 基于属性的测试
python
from hypothesis import given, strategies as st, settings
@given(st.integers(), st.integers())
def test_add_commutative(a, b):
    """Property: addition is commutative for arbitrary integers."""
    calculator = Calculator()
    assert calculator.add(a, b) == calculator.add(b, a)
@given(st.integers(min_value=0, max_value=20))
def test_factorial_positive(n):
    """Properties of n!: always >= 1, and n! == n * (n-1)! for n >= 1.

    ``result >= 1`` alone would accept a function that always returns 1;
    the recurrence pins down the exact value.
    """
    calc = Calculator()
    result = calc.factorial(n)
    assert result >= 1
    if n >= 1:
        assert result == n * calc.factorial(n - 1)
@given(st.integers(min_value=1, max_value=100), st.integers(min_value=1, max_value=100))
def test_divide_multiplication_inverse(a, b):
    """Property: dividing a*b by b gives back a."""
    product = a * b
    assert Calculator().divide(product, b) == pytest.approx(float(a))
# 22.9 本章小结
本章系统阐述了测试驱动开发的核心知识体系:
- TDD方法论:红-绿-重构循环与测试金字塔
- pytest框架:夹具系统、参数化、标记与配置
- Mock与桩代码:unittest.mock、pytest-mock与依赖隔离
- 集成测试:Flask应用测试与数据库测试
- 性能测试:pytest-benchmark与cProfile分析
- 调试技术:pdb调试器与结构化日志
- 测试覆盖率:coverage.py与覆盖率最佳实践
- 前沿工具:基于属性的测试与变异测试
22.10 习题与项目练习
基础题
1. 使用TDD方式开发一个Stack类,要求先写测试再写实现,覆盖push、pop、peek、is_empty、size等方法。
2. 为一个StringCalculator类编写pytest测试,要求支持:空字符串返回0、单个数字、逗号分隔、换行分隔、自定义分隔符。
3. 使用Mock对象测试一个发送HTTP请求的函数,不实际发送网络请求。
进阶题
为Flask应用编写完整的测试套件,包含单元测试、集成测试和API端点测试,使用夹具管理测试数据库。
使用hypothesis编写基于属性的测试,验证排序算法的正确性(幂等性、稳定性、长度保持)。
实现一个日志装饰器,支持函数调用追踪、异常捕获和性能计时,编写测试验证其行为。
综合项目
电商系统测试套件:为一个电商系统编写完整的测试套件,包含:
- 用户注册/登录的单元测试
- 购物车逻辑的参数化测试
- 订单处理的集成测试(含数据库)
- 支付服务的Mock测试
- API端点的端到端测试
- 覆盖率目标 ≥ 90%
CI/CD测试流水线:配置一个完整的测试流水线,包含:
- pytest配置与conftest.py
- 单元测试/集成测试/端到端测试分层
- 覆盖率报告与阈值检查
- GitHub Actions工作流配置
- 测试报告生成(Allure)
思考题
在TDD中,如何决定测试的粒度?过度测试与测试不足各有什么问题?请结合测试金字塔理论分析。
Mock对象在测试中可能带来哪些问题?如何避免"Mock过度"导致的测试脆弱性?请讨论何时应该使用真实依赖而非Mock。
22.11 延伸阅读
22.11.1 测试理论
- 《Test Driven Development》 (Kent Beck) — TDD经典著作
- 《xUnit Test Patterns》 (Gerard Meszaros) — 单元测试模式
- 《Growing Object-Oriented Software, Guided by Tests》 (Steve Freeman, Nat Pryce) — 测试驱动开发实践
22.11.2 pytest生态
- pytest官方文档 (https://docs.pytest.org/) — pytest权威指南
- pytest插件索引 (https://docs.pytest.org/en/latest/reference/plugin_list.html) — 官方插件列表
- hypothesis (https://hypothesis.readthedocs.io/) — 基于属性的测试
22.11.3 测试工具
- coverage.py (https://coverage.readthedocs.io/) — 代码覆盖率
- pytest-cov (https://pytest-cov.readthedocs.io/) — pytest覆盖率插件
- mutmut (https://mutmut.readthedocs.io/) — 变异测试
- Allure Report (https://allurereport.org/docs/) — 测试报告
22.11.4 调试与日志
- Python调试器 (https://docs.python.org/3/library/pdb.html) — pdb文档
- logging模块 (https://docs.python.org/3/library/logging.html) — 日志模块
- structlog (https://www.structlog.org/) — 结构化日志
下一章:第23章 版本控制与协作