第41章 测试进阶
学习目标
完成本章学习后,你将能够:
- 掌握pytest高级特性:fixture、参数化、标记、插件
- 使用mock进行模拟测试:Mock对象、patch、副作用
- 实现测试夹具:fixture作用域、依赖注入、工厂模式
- 进行参数化测试:数据驱动测试、组合测试、条件跳过
- 测量测试覆盖率:coverage.py、覆盖率报告、分支覆盖
- 编写集成测试:数据库测试、API测试、端到端测试
- 实现测试替身:Stub、Spy、Fake、Dummy
- 构建测试流水线:CI/CD集成、测试报告、质量门禁
41.1 pytest高级特性
41.1.1 Fixture深入
python
import pytest
from typing import Any, Callable, Dict, Generator, List, Optional
from dataclasses import dataclass
from functools import wraps
import tempfile
import os
@dataclass
class Database:
    """Minimal in-memory stand-in for a database connection.

    Both ``execute`` and ``insert`` require a prior ``connect()`` and raise
    ``RuntimeError("Database not connected")`` otherwise.
    """

    name: str
    connection_string: str
    is_connected: bool = False

    def connect(self) -> None:
        """Mark the connection as open."""
        self.is_connected = True

    def disconnect(self) -> None:
        """Mark the connection as closed."""
        self.is_connected = False

    def _ensure_connected(self) -> None:
        """Raise RuntimeError unless ``connect()`` has been called."""
        if not self.is_connected:
            raise RuntimeError("Database not connected")

    def execute(self, query: str) -> List[Dict]:
        """Run ``query``; this stub always returns an empty result list."""
        self._ensure_connected()
        return []

    def insert(self, table: str, data: Dict) -> int:
        """Insert ``data`` into ``table``; this stub always returns id 1.

        Guarded like ``execute`` so a disconnected database fails consistently.
        """
        self._ensure_connected()
        return 1
@pytest.fixture
def database():
    """Yield a freshly connected ``Database`` for one test (function scope).

    The connection is closed once the test has finished.
    """
    conn = Database(name="test_db", connection_string="sqlite:///:memory:")
    conn.connect()
    yield conn
    conn.disconnect()


@pytest.fixture(scope="session")
def session_database():
    """Yield a single connected ``Database`` shared by the whole test session."""
    shared = Database(name="session_db", connection_string="sqlite:///:memory:")
    shared.connect()
    yield shared
    shared.disconnect()
@pytest.fixture(scope="module")
def module_config():
    """Configuration dict built once per test module."""
    config = {"debug": True, "timeout": 30}
    return config


@pytest.fixture(scope="class")
def class_resource():
    """Resource shared by every test in a class; its ``data`` list is emptied afterwards."""
    shared = {"data": [], "initialized": True}
    yield shared
    shared["data"].clear()


@pytest.fixture
def sample_users():
    """Three sample user records with sequential ids starting at 1."""
    names = ("Alice", "Bob", "Charlie")
    return [
        {"id": idx, "name": who, "email": f"{who.lower()}@example.com"}
        for idx, who in enumerate(names, start=1)
    ]


@pytest.fixture
def temp_file():
    """Path to a temporary file containing "test content"; deleted after the test."""
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as handle:
        handle.write("test content")
    path = handle.name
    yield path
    if os.path.exists(path):
        os.unlink(path)


@pytest.fixture
def temp_directory():
    """Path to a fresh temporary directory, removed recursively after the test."""
    import shutil

    directory = tempfile.mkdtemp()
    yield directory
    shutil.rmtree(directory, ignore_errors=True)


@pytest.fixture(params=list(range(1, 6)))
def number(request):
    """Parametrized fixture: dependent tests run once for each value 1..5."""
    return request.param


@pytest.fixture(params=["admin", "user", "guest"])
def user_role(request):
    """Parametrized fixture: dependent tests run once per role name."""
    return request.param
class FixtureFactory:
@staticmethod
@pytest.fixture
def create_user():
def _create_user(name: str, email: str, role: str = "user") -> Dict:
return {
"id": 1,
"name": name,
"email": email,
"role": role
}
return _create_user
@staticmethod
@pytest.fixture
def create_product():
def _create_product(name: str, price: float, stock: int = 0) -> Dict:
return {
"id": 1,
"name": name,
"price": price,
"stock": stock
}
return _create_product
@pytest.fixture
def mock_api_response():
    """Factory fixture returning a builder for fake API response dicts."""

    def _mock_response(
        status_code: int = 200,
        data: Optional[Dict] = None,
        error: Optional[str] = None,
    ):
        # A falsy ``data`` (None or empty) becomes a fresh empty dict.
        payload = data if data else {}
        return {"status_code": status_code, "data": payload, "error": error}

    return _mock_response
def test_database_connection(database):
assert database.is_connected
result = database.execute("SELECT * FROM users")
assert isinstance(result, list)
def test_with_sample_users(sample_users):
assert len(sample_users) == 3
assert sample_users[0]["name"] == "Alice"
def test_with_temp_file(temp_file):
with open(temp_file, "r") as f:
content = f.read()
assert content == "test content"
def test_parametrized_fixture(number):
assert number in [1, 2, 3, 4, 5]
def test_user_roles(user_role):
assert user_role in ["admin", "user", "guest"]
def test_factory_fixture(create_user):
user = create_user("Test User", "test@example.com", "admin")
assert user["name"] == "Test User"
assert user["role"] == "admin"
class TestWithClassFixture:
def test_class_resource_1(self, class_resource):
class_resource["data"].append("item1")
assert len(class_resource["data"]) == 1
def test_class_resource_2(self, class_resource):
class_resource["data"].append("item2")
assert len(class_resource["data"]) == 241.1.2 参数化测试
python
import pytest
from typing import List, Tuple, Any
from dataclasses import dataclass
@dataclass
class TestCase:
    """One data-driven case: an input, its expected output, and a label."""

    # The "Test" prefix would make pytest try to collect this dataclass as a
    # test class and emit a "cannot collect test class ... because it has a
    # __init__ constructor" warning. Opt out explicitly. (Not annotated, so
    # it is not a dataclass field.)
    __test__ = False

    input_value: Any
    expected_output: Any
    description: str = ""


@pytest.mark.parametrize("value,expected", [
    (1, 2),
    (2, 4),
    (3, 6),
    (4, 8),
    (5, 10),
])
def test_double(value, expected):
    assert value * 2 == expected


@pytest.mark.parametrize("a,b,expected", [
    (1, 2, 3),
    (5, 5, 10),
    (-1, 1, 0),
    (0, 0, 0),
    (100, -50, 50),
])
def test_addition(a, b, expected):
    assert a + b == expected


@pytest.mark.parametrize("input_str,expected_length", [
    ("hello", 5),
    ("", 0),
    ("a", 1),
    ("12345", 5),
    ("hello world", 11),
])
def test_string_length(input_str, expected_length):
    assert len(input_str) == expected_length


@pytest.mark.parametrize("value,should_be_positive", [
    (1, True),
    (-1, False),
    (0, False),
    (100, True),
    (-100, False),
])
def test_is_positive(value, should_be_positive):
    result = value > 0
    assert result == should_be_positive


class TestParametrized:
    """Parametrize marks applied to methods, including stacked marks."""

    @pytest.mark.parametrize("x,y", [
        (1, 1),
        (2, 2),
        (3, 3),
    ])
    def test_equality(self, x, y):
        assert x == y

    # Stacked marks produce the cartesian product: 3 x 3 = 9 cases.
    @pytest.mark.parametrize("value", [1, 2, 3])
    @pytest.mark.parametrize("multiplier", [1, 2, 3])
    def test_multiplication(self, value, multiplier):
        result = value * multiplier
        # Compare against repeated addition. Re-computing value * multiplier
        # (as before) made the assertion a tautology that could never fail.
        assert result == sum([value] * multiplier)


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(input_value=1, expected_output=2, description="positive"),
        TestCase(input_value=-1, expected_output=-2, description="negative"),
        TestCase(input_value=0, expected_output=0, description="zero"),
    ],
    # Use each case's description as its test id (e.g. "test_with_test_case[zero]").
    ids=lambda case: case.description,
)
def test_with_test_case(test_case):
    assert test_case.input_value * 2 == test_case.expected_output
def generate_test_cases() -> List[Tuple]:
    """Return ``(n, n ** 2)`` pairs for n = 1..5, used to parametrize a test."""
    return [(i, i ** 2) for i in range(1, 6)]


# generate_test_cases() runs once, at import/collection time.
@pytest.mark.parametrize("value,expected", generate_test_cases())
def test_square_generated(value, expected):
    assert value ** 2 == expected


@pytest.mark.parametrize("value,expected", [
    pytest.param(1, 2, id="positive"),
    pytest.param(-1, -2, id="negative"),
    pytest.param(0, 0, id="zero"),
])
def test_with_ids(value, expected):
    assert value * 2 == expected


@pytest.mark.parametrize("value,expected", [
    (1, 2),
    pytest.param(2, 4, marks=pytest.mark.skip(reason="Not implemented")),
    (3, 6),
])
def test_with_skip(value, expected):
    assert value * 2 == expected


@pytest.mark.parametrize("value,expected", [
    (1, 2),
    # 2 * 2 != 5, so this case fails and is reported as xfail.
    pytest.param(2, 5, marks=pytest.mark.xfail(reason="Expected to fail")),
    (3, 6),
])
def test_with_xfail(value, expected):
    assert value * 2 == expected


# 41.2 Mock与模拟测试 (Mocking)
41.2.1 Mock对象
python
from unittest.mock import Mock, MagicMock, patch, call, ANY
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass
import pytest
class UserService:
    """Application service coordinating a user store and an email sender.

    ``database`` is expected to provide ``find_user``, ``user_exists``,
    ``create_user`` and ``delete_user``. ``email_service`` is expected to
    provide ``send_welcome`` and ``send_farewell``. These are the methods
    called below.
    """

    def __init__(self, database, email_service):
        self.database = database
        self.email_service = email_service

    def get_user(self, user_id: int) -> Optional[Dict]:
        """Return whatever the store returns for ``user_id``."""
        return self.database.find_user(user_id)

    def create_user(self, name: str, email: str) -> Dict:
        """Create a user and send a welcome email.

        Raises:
            ValueError: a user with ``email`` already exists. Nothing is
                created and no email is sent in that case.
        """
        if self.database.user_exists(email):
            raise ValueError("User already exists")
        user = self.database.create_user(name, email)
        self.email_service.send_welcome(email)
        return user

    def delete_user(self, user_id: int) -> bool:
        """Delete a user and send a farewell email; return False if not found."""
        user = self.get_user(user_id)
        if not user:
            return False
        self.database.delete_user(user_id)
        self.email_service.send_farewell(user["email"])
        return True


class TestUserServiceWithMock:
    """UserService tests with both collaborators replaced by ``Mock`` objects."""

    def test_get_user_success(self):
        mock_db = Mock()
        mock_db.find_user.return_value = {"id": 1, "name": "Alice"}
        service = UserService(mock_db, Mock())
        result = service.get_user(1)
        mock_db.find_user.assert_called_once_with(1)
        assert result["name"] == "Alice"

    def test_create_user_success(self):
        mock_db = Mock()
        mock_db.user_exists.return_value = False
        mock_db.create_user.return_value = {"id": 1, "name": "Bob", "email": "bob@example.com"}
        mock_email = Mock()
        service = UserService(mock_db, mock_email)
        result = service.create_user("Bob", "bob@example.com")
        mock_db.user_exists.assert_called_once_with("bob@example.com")
        mock_db.create_user.assert_called_once_with("Bob", "bob@example.com")
        mock_email.send_welcome.assert_called_once_with("bob@example.com")
        assert result["id"] == 1

    def test_create_user_already_exists(self):
        mock_db = Mock()
        mock_db.user_exists.return_value = True
        mock_email = Mock()
        service = UserService(mock_db, mock_email)
        with pytest.raises(ValueError, match="User already exists"):
            service.create_user("Bob", "bob@example.com")
        # The guard must short-circuit: nothing is created and no email is sent.
        mock_db.create_user.assert_not_called()
        mock_email.send_welcome.assert_not_called()

    def test_delete_user_success(self):
        mock_db = Mock()
        mock_db.find_user.return_value = {"id": 1, "name": "Alice", "email": "alice@example.com"}
        mock_email = Mock()
        service = UserService(mock_db, mock_email)
        result = service.delete_user(1)
        assert result is True
        mock_db.delete_user.assert_called_once_with(1)
        mock_email.send_farewell.assert_called_once_with("alice@example.com")

    def test_delete_user_not_found(self):
        mock_db = Mock()
        mock_db.find_user.return_value = None
        mock_email = Mock()
        service = UserService(mock_db, mock_email)
        result = service.delete_user(999)
        assert result is False
        mock_db.delete_user.assert_not_called()
        mock_email.send_farewell.assert_not_called()
class TestMockAdvanced:
    """Tour of Mock features: side effects, call inspection, magic methods, matchers."""

    def test_mock_with_side_effect(self):
        # An iterable side_effect hands out one element per call, in order.
        stub = Mock(side_effect=[1, 2, 3])
        for expected in (1, 2, 3):
            assert stub() == expected

    def test_mock_with_exception_side_effect(self):
        # An exception instance as side_effect is raised on call.
        failing = Mock(side_effect=ValueError("Test error"))
        with pytest.raises(ValueError, match="Test error"):
            failing()

    def test_mock_with_function_side_effect(self):
        # A callable side_effect computes the return value from the arguments.
        doubler = Mock(side_effect=lambda x: x * 2)
        assert doubler(5) == 10
        assert doubler(3) == 6

    def test_mock_attributes(self):
        obj = Mock()
        obj.name = "test"
        obj.value = 42
        assert (obj.name, obj.value) == ("test", 42)

    def test_mock_call_count(self):
        counter = Mock()
        for arg in (1, 2, 3):
            counter(arg)
        assert counter.call_count == 3

    def test_mock_call_args_list(self):
        recorder = Mock()
        pairs = [(1, 2), (3, 4), (5, 6)]
        for first, second in pairs:
            recorder(first, second)
        assert recorder.call_args_list == [call(a, b) for a, b in pairs]

    def test_mock_assert_called_with(self):
        target = Mock()
        target(1, 2, key="value")
        target.assert_called_with(1, 2, key="value")

    def test_mock_assert_any_call(self):
        target = Mock()
        for n in (1, 2, 3):
            target(n)
        target.assert_any_call(2)

    def test_mock_reset_mock(self):
        target = Mock()
        target(1)
        target(2)
        target.reset_mock()
        assert target.call_count == 0
        assert target.call_args_list == []

    def test_magic_mock(self):
        # MagicMock pre-configures dunder methods so they can be stubbed.
        magic = MagicMock()
        magic.__str__.return_value = "Custom String"
        magic.__len__.return_value = 10
        assert str(magic) == "Custom String"
        assert len(magic) == 10

    def test_mock_any_matcher(self):
        target = Mock()
        target.process(1, 2, 3)
        target.process.assert_called_with(ANY, ANY, 3)
from unittest.mock import mock_open


class TestWithPatch:
    """unittest.mock.patch used as a decorator and as a context manager."""

    # mock_open builds a MagicMock already wired for the ``with open(...) as f:
    # f.read()`` protocol, replacing hand-chained ``__enter__.return_value``
    # configuration. With new_callable, patch still injects the created mock as
    # an extra argument.
    @patch("builtins.open", new_callable=mock_open, read_data="test content")
    def test_patch_decorator(self, mocked_open):
        with open("test.txt", "r") as f:
            content = f.read()
        assert content == "test content"
        mocked_open.assert_called_once_with("test.txt", "r")

    def test_patch_context_manager(self):
        with patch("builtins.open", mock_open(read_data="content")) as mocked_open:
            with open("test.txt", "r") as f:
                content = f.read()
        assert content == "content"
        mocked_open.assert_called_once_with("test.txt", "r")

    @patch("os.path.exists")
    def test_patch_os(self, mock_exists):
        import os

        mock_exists.return_value = True
        result = os.path.exists("/some/path")
        assert result is True
        mock_exists.assert_called_once_with("/some/path")


# 41.2.2 测试替身 (Test doubles)
python
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
class Stub:
    """Stub that returns canned responses keyed by method name and logs each request."""

    def __init__(self, responses: Optional[Dict[str, Any]] = None):
        # A falsy ``responses`` (None or empty) is replaced by a fresh dict.
        self._responses = responses if responses else {}
        self._calls: List[Dict] = []

    def get_response(self, method: str, *args, **kwargs) -> Any:
        """Log the request, then return the canned response for ``method`` (or None)."""
        record = {"method": method, "args": args, "kwargs": kwargs}
        self._calls.append(record)
        return self._responses.get(method)

    def get_calls(self) -> List[Dict]:
        """Return a shallow copy of the request log."""
        return list(self._calls)

    def was_called(self, method: str) -> bool:
        """True if ``method`` has been requested at least once."""
        for entry in self._calls:
            if entry["method"] == method:
                return True
        return False
class Spy:
    """Records every method call made on it and optionally forwards to a target.

    Any attribute not found normally becomes a recording wrapper. When called,
    the wrapper returns, in priority order: a value registered with
    ``set_return_value``, the result of the same-named method on ``target``,
    or None.
    """

    # The spy's own state. These are never turned into recording wrappers.
    # Without this guard, looking them up before __init__ has run (e.g. on an
    # instance created by copy/pickle) returned a wrapper function instead of
    # failing.
    _OWN_ATTRS = frozenset({"_target", "_calls", "_return_values"})

    def __init__(self, target: Any = None):
        self._target = target
        self._calls: List[Dict] = []
        self._return_values: Dict[str, Any] = {}

    def __getattr__(self, name: str) -> Callable:
        # Dunder lookups (__setstate__, __deepcopy__, ...) come from Python
        # protocols such as copy/pickle. Answering them with a wrapper broke
        # those protocols, so report them as missing instead.
        if name in self._OWN_ATTRS or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        def method_wrapper(*args, **kwargs):
            self._calls.append({"method": name, "args": args, "kwargs": kwargs})
            if name in self._return_values:
                return self._return_values[name]
            # ``is not None`` so falsy targets (empty list/dict, 0, "") are
            # still forwarded to.
            if self._target is not None and hasattr(self._target, name):
                return getattr(self._target, name)(*args, **kwargs)
            return None

        return method_wrapper

    def set_return_value(self, method: str, value: Any) -> None:
        """Make calls to ``method`` return ``value`` instead of forwarding."""
        self._return_values[method] = value

    def get_call_count(self, method: str) -> int:
        """Number of recorded calls to ``method``."""
        return sum(1 for call in self._calls if call["method"] == method)

    def was_called_with(self, method: str, *args, **kwargs) -> bool:
        """True if any recorded call to ``method`` had exactly these arguments."""
        return any(
            call["method"] == method and call["args"] == args and call["kwargs"] == kwargs
            for call in self._calls
        )
class Fake:
    """Working in-memory CRUD store that stands in for a real database.

    Rows are dicts kept per table name; ``create`` assigns each an integer
    ``id``. Read methods return copies, so callers cannot alter stored rows.
    """

    def __init__(self):
        self._data: Dict[str, List] = {}

    def create(self, table: str, data: Dict) -> int:
        """Store a copy of ``data`` with a new ``id`` in ``table``; return the id.

        The new id is one more than the largest id currently in the table.
        The former ``len(rows) + 1`` rule handed out an id still in use after
        any row other than the last was deleted. The caller's dict is no
        longer modified.
        """
        rows = self._data.setdefault(table, [])
        new_id = max((row.get("id", 0) for row in rows), default=0) + 1
        rows.append({**data, "id": new_id})
        return new_id

    def read(self, table: str, id: int) -> Optional[Dict]:
        """Return a copy of the row with this id, or None if there is none."""
        for item in self._data.get(table, []):
            if item.get("id") == id:
                return item.copy()
        return None

    def update(self, table: str, id: int, data: Dict) -> bool:
        """Merge ``data`` into the row with this id; return False if absent."""
        for item in self._data.get(table, []):
            if item.get("id") == id:
                item.update(data)
                return True
        return False

    def delete(self, table: str, id: int) -> bool:
        """Remove the row with this id; return False if absent."""
        rows = self._data.get(table, [])
        for index, item in enumerate(rows):
            if item.get("id") == id:
                del rows[index]
                return True
        return False

    def list_all(self, table: str) -> List[Dict]:
        """Return copies of all rows in ``table`` (empty list if unknown)."""
        return [item.copy() for item in self._data.get(table, [])]
class Dummy:
    """Placeholder object passed only to fill a parameter slot; never used."""
class TestStubs:
    """Stub usage: canned responses and request logging."""

    def test_stub_responses(self):
        canned = {
            "get_user": {"id": 1, "name": "Alice"},
            "get_product": {"id": 1, "name": "Widget"},
        }
        stub = Stub(canned)
        fetched_user = stub.get_response("get_user")
        fetched_product = stub.get_response("get_product")
        assert fetched_user["name"] == "Alice"
        assert fetched_product["name"] == "Widget"

    def test_stub_records_calls(self):
        stub = Stub({"process": "done"})
        for args in ((1, 2), (3, 4)):
            stub.get_response("process", *args)
        assert len(stub.get_calls()) == 2
        assert stub.was_called("process")
class TestSpies:
    """Spy usage: call recording and canned return values."""

    def test_spy_records_calls(self):
        spy = Spy()
        spy.method1(1, 2, key="value")
        spy.method2("test")
        for method_name in ("method1", "method2"):
            assert spy.get_call_count(method_name) == 1
        assert spy.was_called_with("method1", 1, 2, key="value")

    def test_spy_with_return_values(self):
        spy = Spy()
        spy.set_return_value("get_value", 42)
        assert spy.get_value() == 42
class TestFakes:
    """Fake usage: a working in-memory store exercised through full CRUD."""

    def test_fake_crud_operations(self):
        fake_db = Fake()
        user_id = fake_db.create("users", {"name": "Alice", "email": "alice@example.com"})
        assert user_id == 1
        user = fake_db.read("users", user_id)
        assert user["name"] == "Alice"
        fake_db.update("users", user_id, {"name": "Alice Updated"})
        updated_user = fake_db.read("users", user_id)
        assert updated_user["name"] == "Alice Updated"
        deleted = fake_db.delete("users", user_id)
        assert deleted is True
        deleted_user = fake_db.read("users", user_id)
        assert deleted_user is None

    def test_fake_list_all(self):
        fake_db = Fake()
        fake_db.create("products", {"name": "Widget", "price": 10})
        fake_db.create("products", {"name": "Gadget", "price": 20})
        products = fake_db.list_all("products")
        assert len(products) == 2


# 41.3 测试覆盖率 (Test coverage)
41.3.1 Coverage.py使用
python
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
@dataclass
class CoverageReport:
    """Aggregate line-coverage numbers plus a per-file breakdown."""

    total_lines: int
    covered_lines: int
    missed_lines: int
    coverage_percent: float
    file_reports: Dict[str, Dict]


class CoverageAnalyzer:
    """Toy line-coverage tracker.

    For each file it stores the set of measurable ("total") lines and the set
    of executed ("covered") lines. Excluded files are left out of totals.
    """

    def __init__(self):
        self._covered_lines: Dict[str, set] = {}
        self._total_lines: Dict[str, set] = {}
        self._excluded_files: set = set()

    def add_covered_line(self, filename: str, line_number: int) -> None:
        """Record that ``line_number`` in ``filename`` was executed."""
        self._covered_lines.setdefault(filename, set()).add(line_number)

    def add_total_line(self, filename: str, line_number: int) -> None:
        """Record that ``line_number`` in ``filename`` is a measurable line."""
        self._total_lines.setdefault(filename, set()).add(line_number)

    def exclude_file(self, filename: str) -> None:
        """Leave ``filename`` out of all coverage figures."""
        self._excluded_files.add(filename)

    def get_file_coverage(self, filename: str) -> Dict:
        """Return ``{"coverage", "covered", "total", "missed"}`` for one file.

        Excluded files and files with no measurable lines report 100% with
        zero counts. Every return path now carries the same keys.
        """
        empty = {"coverage": 100.0, "covered": 0, "total": 0, "missed": 0}
        if filename in self._excluded_files:
            return empty
        total_set = self._total_lines.get(filename, set())
        if not total_set:
            return empty
        # Count only executed lines that are also known measurable lines.
        # Otherwise hits on unknown lines could push coverage above 100%.
        covered = len(self._covered_lines.get(filename, set()) & total_set)
        total = len(total_set)
        return {
            "coverage": (covered / total) * 100,
            "covered": covered,
            "total": total,
            "missed": total - covered,
        }

    def get_total_coverage(self) -> CoverageReport:
        """Aggregate every non-excluded file that has measurable lines."""
        total_covered = 0
        total_lines = 0
        file_reports = {}
        for filename in self._total_lines:
            if filename in self._excluded_files:
                continue
            file_report = self.get_file_coverage(filename)
            file_reports[filename] = file_report
            total_covered += file_report["covered"]
            total_lines += file_report["total"]
        if total_lines == 0:
            return CoverageReport(
                total_lines=0,
                covered_lines=0,
                missed_lines=0,
                coverage_percent=100.0,
                file_reports=file_reports,
            )
        return CoverageReport(
            total_lines=total_lines,
            covered_lines=total_covered,
            missed_lines=total_lines - total_covered,
            coverage_percent=(total_covered / total_lines) * 100,
            file_reports=file_reports,
        )

    def generate_report(self, output_format: str = "text") -> str:
        """Render the coverage report.

        Only ``"text"`` is implemented; any other format returns ``""``.
        """
        report = self.get_total_coverage()
        if output_format != "text":
            return ""
        lines = [
            "Coverage Report",
            "=" * 50,
            f"Total Lines: {report.total_lines}",
            f"Covered Lines: {report.covered_lines}",
            f"Missed Lines: {report.missed_lines}",
            f"Coverage: {report.coverage_percent:.2f}%",
            "",
            "File Details:",
            "-" * 50,
        ]
        for filename, file_report in report.file_reports.items():
            lines.append(
                f"{filename}: {file_report['coverage']:.2f}% "
                f"({file_report['covered']}/{file_report['total']})"
            )
        return "\n".join(lines)
class BranchCoverage:
    """Track which branch outcomes were taken, per file and per source line."""

    def __init__(self):
        # filename -> line -> branch_id -> taken?
        self._branches: Dict[str, Dict[int, Dict[str, bool]]] = {}

    def add_branch(self, filename: str, line: int, branch_id: str, taken: bool) -> None:
        """Record whether ``branch_id`` at ``filename:line`` was taken.

        A later call for the same branch overwrites the earlier value.
        """
        self._branches.setdefault(filename, {}).setdefault(line, {})[branch_id] = taken

    def get_branch_coverage(self, filename: str) -> Dict:
        """Return ``{"total", "covered", "coverage"}``; unknown files report 100%."""
        if filename not in self._branches:
            return {"total": 0, "covered": 0, "coverage": 100.0}
        outcomes = [
            taken
            for line_branches in self._branches[filename].values()
            for taken in line_branches.values()
        ]
        total = len(outcomes)
        covered = sum(1 for taken in outcomes if taken)
        coverage = (covered / total * 100) if total > 0 else 100.0
        return {"total": total, "covered": covered, "coverage": coverage}


# 41.4 集成测试 (Integration tests)
41.4.1 数据库测试
python
import pytest
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
import tempfile
import os
@dataclass
class TestDatabase:
    """In-memory table store used as the integration-test database.

    ``tables`` maps a table name to a list of row dicts. ``path`` is a scratch
    file created by ``create`` and removed by ``cleanup``.
    """

    # Helper, not a test class. Its "Test" prefix plus the generated __init__
    # would otherwise make pytest emit a "cannot collect test class" warning.
    # (Not annotated, so it is not a dataclass field.)
    __test__ = False

    path: str
    tables: Dict[str, List[Dict]]

    @classmethod
    def create(cls) -> "TestDatabase":
        """Return an empty database backed by a freshly created temp file."""
        # mkstemp atomically creates the file. tempfile.mktemp only returns a
        # name and is deprecated because another process may claim it first.
        fd, path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        return cls(path=path, tables={})

    def insert(self, table: str, data: Dict) -> int:
        """Insert a copy of ``data`` into ``table`` and return its new id.

        The id is one more than the largest id in the table. The former
        ``len(rows) + 1`` rule reused an id still held by a live row after a
        non-final row was deleted. The caller's dict is left unchanged.
        """
        rows = self.tables.setdefault(table, [])
        new_id = max((row.get("id", 0) for row in rows), default=0) + 1
        rows.append({**data, "id": new_id})
        return new_id

    def find(self, table: str, id: int) -> Optional[Dict]:
        """Return a copy of the row with this id, or None."""
        for item in self.tables.get(table, []):
            if item.get("id") == id:
                return item.copy()
        return None

    def find_all(self, table: str) -> List[Dict]:
        """Return copies of all rows in ``table`` (empty list if unknown)."""
        return [item.copy() for item in self.tables.get(table, [])]

    def delete(self, table: str, id: int) -> bool:
        """Remove the row with this id; return False if absent."""
        rows = self.tables.get(table, [])
        for i, item in enumerate(rows):
            if item.get("id") == id:
                del rows[i]
                return True
        return False

    def cleanup(self) -> None:
        """Drop all tables and delete the scratch file if it still exists."""
        self.tables.clear()
        if os.path.exists(self.path):
            os.unlink(self.path)


@pytest.fixture
def test_db():
    """Yield a fresh TestDatabase per test and clean it up afterwards."""
    db = TestDatabase.create()
    yield db
    db.cleanup()
class TestDatabaseIntegration:
    """CRUD behaviour of the ``test_db`` fixture database."""

    def test_insert_and_find(self, test_db):
        new_id = test_db.insert("users", {"name": "Alice", "email": "alice@example.com"})
        found = test_db.find("users", new_id)
        assert found is not None
        assert found["name"] == "Alice"
        assert found["email"] == "alice@example.com"

    def test_find_all(self, test_db):
        names = ["Alice", "Bob", "Charlie"]
        for name in names:
            test_db.insert("users", {"name": name})
        rows = test_db.find_all("users")
        assert len(rows) == 3
        for row, name in zip(rows, names):
            assert row["name"] == name

    def test_delete(self, test_db):
        new_id = test_db.insert("users", {"name": "Alice"})
        assert test_db.delete("users", new_id) is True
        assert test_db.find("users", new_id) is None

    def test_find_nonexistent(self, test_db):
        assert test_db.find("users", 999) is None
class TestAPIClient:
    """In-process fake HTTP client that returns responses registered per endpoint."""

    # Helper, not a test class. The "Test" prefix plus __init__ would
    # otherwise make pytest emit a "cannot collect test class" warning.
    __test__ = False

    def __init__(self, base_url: str):
        self.base_url = base_url
        self._responses: Dict[str, Any] = {}

    def set_response(self, endpoint: str, response: Any) -> None:
        """Register ``response`` (a value or a callable) for ``endpoint``."""
        self._responses[endpoint] = response

    def get(self, endpoint: str) -> Any:
        """Return the registered response as-is (callables are not invoked)."""
        return self._responses.get(endpoint)

    def post(self, endpoint: str, data: Dict) -> Any:
        """Return the registered response; a callable is called with ``data``."""
        response = self._responses.get(endpoint)
        if callable(response):
            return response(data)
        return response


@pytest.fixture
def api_client():
    """A fresh TestAPIClient per test."""
    return TestAPIClient("http://test-api")
class TestAPIIntegration:
    """Exercises the fake API client with static and computed responses."""

    def test_get_users(self, api_client):
        api_client.set_response("/users", [
            {"id": 1, "name": "Alice"},
            {"id": 2, "name": "Bob"},
        ])
        users = api_client.get("/users")
        assert len(users) == 2
        assert users[0]["name"] == "Alice"

    def test_create_user(self, api_client):
        # A callable response lets post() compute its result from the payload.
        def create_user(data):
            return {"id": 1, **data}

        api_client.set_response("/users", create_user)
        result = api_client.post("/users", {"name": "Charlie", "email": "charlie@example.com"})
        assert result["id"] == 1
        assert result["name"] == "Charlie"


# 41.5 知识图谱 (Knowledge map)
41.5.1 测试金字塔架构
┌─────────────────────────────────────────────────────────────────────┐
│ 软件测试金字塔 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ /\ │
│ / \ │
│ / E2E\ 端到端测试 (10%) │
│ / Test \ - 用户流程 │
│ /--------\ - UI自动化 │
│ / \ │
│ / Integration\ 集成测试 (20%) │
│ / Test \ - API测试 │
│ /----------------\ - 数据库测试 │
│ / \ │
│ / Unit Test \ 单元测试 (70%) │
│ / \ - 函数测试 │
│ /------------------------\ - 类测试 │
│ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 测试工具链 │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ pytest │ │ unittest │ │ coverage │ │ mock │ │ │
│ │ │ 主框架 │ │ 标准库 │ │ 覆盖率 │ │ 模拟 │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ pytest-cov│ │ pytest-mock│ │hypothesis│ │ faker │ │ │
│ │ │ 覆盖插件 │ │ mock插件 │ │ 属性测试 │ │ 测试数据 │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
41.5.2 测试执行流程
┌─────────────────────────────────────────────────────────────────────┐
│ 测试执行生命周期 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 1. 测试发现 (Discovery) │ │
│ │ - 扫描测试文件 (test_*.py, *_test.py) │ │
│ │ - 收集测试类和测试函数 │ │
│ │ - 解析参数化测试用例 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 2. Fixture准备 (Setup) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ session │ │ module │ │ class │ │ │
│ │ │ 会话级别 │ │ 模块级别 │ │ 类级别 │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ ┌──────────┐ ┌──────────┐ │ │
│ │ │ function │ │ 依赖注入 │ │ │
│ │ │ 函数级别 │ │ 自动处理 │ │ │
│ │ └──────────┘ └──────────┘ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 3. 测试执行 (Execution) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 运行测试 │ │ 断言检查 │ │ 异常捕获 │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ Mock处理 │ │ 参数化 │ │ 标记过滤 │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 4. Fixture清理 (Teardown) │ │
│ │ - 执行yield后的清理代码 │ │
│ │ - 释放资源、关闭连接 │ │
│ │ - 清理临时文件 │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────┐ │
│ │ 5. 报告生成 (Reporting) │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ 测试结果 │ │ 覆盖率 │ │ 性能数据 │ │ │
│ │ │ Pass/Fail│ │ Coverage │ │ Duration │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ │ │
│ └──────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
41.6 技术选型指南
41.6.1 测试框架选型
| 框架 | 适用场景 | 特点 | 学习曲线 | 推荐指数 |
|---|---|---|---|---|
| pytest | 通用测试、项目首选 | 插件丰富、简洁 | 低 | ★★★★★ |
| unittest | 标准库、简单测试 | Python内置、稳定 | 低 | ★★★☆☆ |
| nose2 | unittest扩展 | 兼容性好 | 低 | ★★☆☆☆ |
| hypothesis | 属性测试 | 自动生成测试用例 | 中 | ★★★★☆ |
| Robot Framework | 验收测试、关键字驱动 | 非技术人员友好 | 中 | ★★★☆☆ |
41.6.2 Mock工具选型
| 工具 | 适用场景 | 特点 | 推荐指数 |
|---|---|---|---|
| unittest.mock | 标准Mock | Python内置、功能完整 | ★★★★★ |
| pytest-mock | pytest集成 | 简化Mock使用 | ★★★★★ |
| responses | HTTP Mock | 专门Mock HTTP请求 | ★★★★☆ |
| freezegun | 时间Mock | 冻结时间、测试时间相关 | ★★★★☆ |
| faker | 测试数据 | 生成假数据 | ★★★★★ |
41.6.3 测试覆盖率工具选型
| 工具 | 特点 | 输出格式 | 推荐指数 |
|---|---|---|---|
| coverage.py | 标准工具、功能完整 | HTML/XML/JSON | ★★★★★ |
| pytest-cov | pytest插件、集成方便 | 同coverage | ★★★★★ |
| pytest-testmon | 增量测试、只运行变更相关 | - | ★★★★☆ |
| mutmut | 变异测试 | HTML | ★★★☆☆ |
41.7 常见问题与解决方案
41.7.1 测试隔离问题
python
import pytest
from typing import Dict, Any
import tempfile
import os
import shutil
class TestIsolation:
    """Patterns for isolating tests from each other's environment, cwd and data."""

    @pytest.fixture
    def isolated_env(self, monkeypatch):
        """Provide test-only env vars; the full environment is restored afterwards."""
        snapshot = dict(os.environ)
        monkeypatch.setenv("TEST_MODE", "true")
        monkeypatch.setenv("DATABASE_URL", "sqlite:///:memory:")
        yield
        # Reset os.environ to exactly what was captured at setup.
        os.environ.clear()
        os.environ.update(snapshot)

    @pytest.fixture
    def isolated_filesystem(self):
        """Run the test inside a fresh temp directory, then return and delete it."""
        workdir = tempfile.mkdtemp()
        previous_cwd = os.getcwd()
        os.chdir(workdir)
        yield workdir
        os.chdir(previous_cwd)
        shutil.rmtree(workdir, ignore_errors=True)

    @pytest.fixture
    def isolated_database(self):
        """Private in-memory SQLite database containing an empty ``users`` table."""
        import sqlite3

        connection = sqlite3.connect(":memory:")
        connection.cursor().execute(
            """
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY,
                name TEXT,
                email TEXT
            )
            """
        )
        connection.commit()
        yield connection
        connection.close()

    def test_with_isolated_env(self, isolated_env):
        assert os.environ.get("TEST_MODE") == "true"

    def test_with_isolated_fs(self, isolated_filesystem):
        target = os.path.join(isolated_filesystem, "test.txt")
        with open(target, "w") as fh:
            fh.write("test")
        assert os.path.exists(target)

    def test_with_isolated_db(self, isolated_database):
        cur = isolated_database.cursor()
        cur.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("Alice", "alice@test.com"))
        isolated_database.commit()
        cur.execute("SELECT * FROM users")
        rows = cur.fetchall()
        assert len(rows) == 1
class SharedStateCleaner:
    """Snapshot object attributes and put them back after a test."""

    # Marks "attribute did not exist when saved", so restoring deletes it
    # instead of leaving a stray None behind.
    _MISSING = object()

    def __init__(self):
        # (id(obj), attr) -> (obj, original value). The object itself is kept
        # so restore_all can reach it again; holding a reference also keeps
        # its id() from being reused by another object.
        self._original_states: Dict[tuple, tuple] = {}

    def save_state(self, obj, attr: str):
        """Remember the current value of ``obj.attr`` (or that it is absent)."""
        self._original_states[(id(obj), attr)] = (obj, getattr(obj, attr, self._MISSING))

    def restore_state(self, obj, attr: str):
        """Put back the saved value of ``obj.attr``; no-op if never saved."""
        entry = self._original_states.get((id(obj), attr))
        if entry is not None:
            self._apply(entry[0], attr, entry[1])

    def restore_all(self):
        """Restore every saved attribute (previously this was a silent no-op)."""
        for (_, attr), (obj, value) in self._original_states.items():
            self._apply(obj, attr, value)

    @classmethod
    def _apply(cls, obj, attr: str, value: Any) -> None:
        """Set ``obj.attr`` to ``value``, or delete it if it originally did not exist."""
        if value is cls._MISSING:
            try:
                delattr(obj, attr)
            except AttributeError:
                pass  # already absent on the instance
        else:
            setattr(obj, attr, value)


@pytest.fixture
def state_cleaner():
    """Yield a SharedStateCleaner and restore everything it saved after the test."""
    cleaner = SharedStateCleaner()
    yield cleaner
    cleaner.restore_all()


# 41.7.2 异步代码测试 (Testing async code)
python
import pytest
import asyncio
from typing import AsyncGenerator, Callable
class AsyncTestHelper:
    """Helpers for testing asyncio code."""

    # NOTE(review): pytest only discovers fixtures at module level, in
    # conftest.py, or inside collected Test* classes. This class is not named
    # Test*, so async_setup is never registered as a fixture. Running it would
    # also need an async-aware plugin such as pytest-asyncio. Kept as an example.
    @staticmethod
    @pytest.fixture
    async def async_setup():
        """Example async setup/teardown fixture (asyncio.sleep(0) yields None)."""
        resource = await asyncio.sleep(0)
        yield resource
        await asyncio.sleep(0)

    @staticmethod
    def async_test(coro_func):
        """Decorator: run an ``async def`` test synchronously via ``asyncio.run``."""
        import functools

        # wraps() copies __name__/__doc__ and sets __wrapped__, so pytest
        # reports the real test name and, through inspect.signature, still
        # sees the original parameters to inject fixtures into.
        @functools.wraps(coro_func)
        def wrapper(*args, **kwargs):
            return asyncio.run(coro_func(*args, **kwargs))

        return wrapper

    @staticmethod
    async def wait_for_condition(
        condition: Callable,
        timeout: float = 5.0,
        interval: float = 0.1,
    ) -> bool:
        """Poll ``condition`` until truthy or ``timeout`` seconds elapse.

        ``condition`` may be sync or async. Any awaitable it returns is
        awaited, which covers lambdas returning coroutines and AsyncMock.
        Previously such a coroutine object was treated as truthy without
        being awaited. Returns True on success, False on timeout.
        """
        import inspect

        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            result = condition()
            if inspect.isawaitable(result):
                result = await result
            if result:
                return True
            await asyncio.sleep(interval)
        return False
class AsyncService:
    """Example async service: must be connected before data can be read or written."""

    def __init__(self):
        self._connected = False
        self._data = []

    def _require_connection(self):
        """Raise RuntimeError("Not connected") unless connect() has completed."""
        if not self._connected:
            raise RuntimeError("Not connected")

    async def connect(self):
        await asyncio.sleep(0.1)  # simulated latency
        self._connected = True

    async def disconnect(self):
        await asyncio.sleep(0.1)  # simulated latency
        self._connected = False

    async def fetch_data(self) -> list:
        """Return the stored list itself (not a copy)."""
        self._require_connection()
        await asyncio.sleep(0.1)
        return self._data

    async def add_data(self, item):
        self._require_connection()
        await asyncio.sleep(0.1)
        self._data.append(item)
@pytest.fixture
async def async_service():
    """Yield a connected AsyncService and disconnect it after the test.

    NOTE(review): an async fixture declared with plain ``@pytest.fixture`` is
    only awaited when pytest-asyncio runs with ``asyncio_mode = auto``. Its
    default strict mode expects ``@pytest_asyncio.fixture``. Confirm the
    project's configuration.
    """
    svc = AsyncService()
    await svc.connect()
    yield svc
    await svc.disconnect()
@pytest.mark.asyncio
class TestAsyncService:
    """Async tests; the class-level asyncio mark is applied to each test method."""

    async def test_fetch_data(self, async_service):
        fetched = await async_service.fetch_data()
        assert isinstance(fetched, list)

    async def test_add_data(self, async_service):
        await async_service.add_data("item1")
        fetched = await async_service.fetch_data()
        assert "item1" in fetched

    async def test_not_connected_error(self):
        fresh = AsyncService()
        with pytest.raises(RuntimeError, match="Not connected"):
            await fresh.fetch_data()
class AsyncMockHelper:
    """Factories for small async callables and iterators used as test doubles."""

    @staticmethod
    def async_return(value):
        """Return an async function that ignores its arguments and resolves to ``value``."""
        async def _async_return(*args, **kwargs):
            return value

        return _async_return

    @staticmethod
    def async_raise(exception):
        """Return an async function that ignores its arguments and raises ``exception``."""
        async def _async_raise(*args, **kwargs):
            raise exception

        return _async_raise

    @staticmethod
    def async_iterator(items):
        """Return an async generator yielding each element of ``items``.

        The result is single-use: once exhausted it yields nothing more, so
        create a new one for each consumer.
        """
        async def _async_iterator():
            for item in items:
                yield item

        return _async_iterator()


# 41.7.3 测试数据管理 (Test data management)
python
import pytest
from typing import Dict, Any, List, Callable
from dataclasses import dataclass, field
from faker import Faker
import json
import random
# Shared Faker instance configured for the zh_CN locale.
fake = Faker("zh_CN")


@dataclass
class UserFactory:
    """Builds user dicts from Faker defaults; keyword arguments override fields."""

    @staticmethod
    def create(**kwargs) -> Dict[str, Any]:
        generated = {
            "id": fake.random_int(min=1, max=10000),
            "name": fake.name(),
            "email": fake.email(),
            "phone": fake.phone_number(),
            "address": fake.address(),
            "created_at": fake.date_time_this_year().isoformat(),
        }
        return {**generated, **kwargs}

    @staticmethod
    def create_batch(count: int, **kwargs) -> List[Dict]:
        """Build ``count`` users, applying the same overrides to each."""
        batch = []
        for _ in range(count):
            batch.append(UserFactory.create(**kwargs))
        return batch


@dataclass
class ProductFactory:
    """Builds product dicts from Faker defaults; keyword arguments override fields."""

    @staticmethod
    def create(**kwargs) -> Dict[str, Any]:
        generated = {
            "id": fake.random_int(min=1, max=10000),
            "name": fake.word().title(),
            "price": round(fake.pyfloat(min_value=10, max_value=1000), 2),
            "stock": fake.random_int(min=0, max=100),
            "category": fake.random_element(["电子", "服装", "食品", "家居"]),
            "description": fake.text(max_nb_chars=200),
        }
        return {**generated, **kwargs}

    @staticmethod
    def create_batch(count: int, **kwargs) -> List[Dict]:
        """Build ``count`` products, applying the same overrides to each."""
        batch = []
        for _ in range(count):
            batch.append(ProductFactory.create(**kwargs))
        return batch
class TestDataBuilder:
    """Fluent builder that assembles a dict of related test data."""

    # Helper, not a test class. The "Test" prefix plus __init__ would
    # otherwise make pytest emit a "cannot collect test class" warning.
    __test__ = False

    def __init__(self):
        self._data: Dict[str, Any] = {}

    def with_user(self, **kwargs) -> "TestDataBuilder":
        """Add a generated user under key "user" (kwargs override fields)."""
        self._data["user"] = UserFactory.create(**kwargs)
        return self

    def with_products(self, count: int = 1, **kwargs) -> "TestDataBuilder":
        """Add ``count`` generated products under key "products"."""
        self._data["products"] = ProductFactory.create_batch(count, **kwargs)
        return self

    def with_custom(self, key: str, value: Any) -> "TestDataBuilder":
        """Add an arbitrary ``key: value`` entry."""
        self._data[key] = value
        return self

    def build(self) -> Dict[str, Any]:
        """Return a shallow copy of the assembled data."""
        return self._data.copy()
@pytest.fixture
def user_factory():
    """Expose the UserFactory class to tests."""
    factory = UserFactory
    return factory


@pytest.fixture
def product_factory():
    """Expose the ProductFactory class to tests."""
    factory = ProductFactory
    return factory


@pytest.fixture
def test_data():
    """A fresh TestDataBuilder per test."""
    builder = TestDataBuilder()
    return builder
class TestWithDataFactories:
    """Tests built on the Faker-backed factories and the fluent builder."""

    def test_user_creation(self, user_factory):
        user = user_factory.create(name="测试用户")
        assert user["name"] == "测试用户"
        assert "@" in user["email"]

    def test_user_batch(self, user_factory):
        batch = user_factory.create_batch(10)
        assert len(batch) == 10

    def test_product_creation(self, product_factory):
        product = product_factory.create(price=99.99)
        assert product["price"] == 99.99

    def test_with_builder(self, test_data):
        built = (
            test_data.with_user(name="Alice")
            .with_products(3)
            .build()
        )
        assert built["user"]["name"] == "Alice"
        assert len(built["products"]) == 3
import os


class TestDataManager:
    """Load and save JSON test-data files under ``data_dir``, caching loads."""

    # Helper, not a test class. The "Test" prefix plus __init__ would
    # otherwise make pytest emit a "cannot collect test class" warning.
    __test__ = False

    def __init__(self, data_dir: str = "test_data"):
        self.data_dir = data_dir
        self._cache: Dict[str, Any] = {}

    def load_json(self, filename: str) -> Dict:
        """Return parsed JSON from ``data_dir/filename``, reading the file at most once."""
        if filename not in self._cache:
            filepath = os.path.join(self.data_dir, filename)
            with open(filepath, "r", encoding="utf-8") as f:
                self._cache[filename] = json.load(f)
        return self._cache[filename]

    def save_json(self, filename: str, data: Dict):
        """Write ``data`` as UTF-8 JSON, creating ``data_dir`` if needed."""
        if self.data_dir:
            os.makedirs(self.data_dir, exist_ok=True)
        filepath = os.path.join(self.data_dir, filename)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        # Drop any cached copy so the next load_json reads what was just
        # written; previously it kept returning stale data.
        self._cache.pop(filename, None)

    def get_test_case(self, name: str) -> Dict:
        """Return entry ``name`` from test_cases.json, or {} if absent."""
        data = self.load_json("test_cases.json")
        return data.get(name, {})


# 41.7.4 测试性能优化 (Optimizing test performance)
python
import pytest
import time
from typing import Callable, List, Dict, Any
from functools import wraps
class TestPerformance:
    """Performance helpers: a timing decorator and two example fixtures."""

    @staticmethod
    def measure_time(func: Callable) -> Callable:
        """Decorator that prints how long each call to ``func`` took."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            print(f"\n{func.__name__} took {elapsed:.4f} seconds")
            return result

        return wrapper

    @staticmethod
    @pytest.fixture(scope="session")
    def expensive_resource():
        """Session-scoped resource: the slow setup runs once per test session."""
        print("\n初始化昂贵资源...")
        time.sleep(1)
        resource = {"data": "expensive"}
        yield resource
        print("\n清理昂贵资源...")

    @staticmethod
    @pytest.fixture
    def cached_resource(request, cache):
        """Return a resource persisted across runs in pytest's cache.

        ``cache`` is pytest's ``Cache`` object, not a dict. It only offers
        ``get(key, default)`` / ``set(key, value)``, so the previous
        ``in`` / ``[]`` usage raised TypeError. Keys are "/"-separated with a
        namespace first, and values must be JSON-serializable.
        """
        cache_key = f"test_performance/resource_{request.node.name}"
        cached = cache.get(cache_key, None)
        if cached is not None:
            return cached
        resource = {"computed": True}
        cache.set(cache_key, resource)
        return resource
class TestOptimizer:
    """Collects tests whose duration exceeds a threshold and suggests optimizing them."""

    # Helper, not a test class. The "Test" prefix plus __init__ would
    # otherwise make pytest emit a "cannot collect test class" warning.
    __test__ = False

    def __init__(self, threshold: float = 1.0):
        """``threshold`` uses the same unit as the durations passed to mark_slow.

        The default of 1.0 presumably means seconds.
        """
        self._slow_tests: List[str] = []
        self._threshold = threshold

    def mark_slow(self, test_name: str, duration: float):
        """Record ``test_name`` if ``duration`` is strictly above the threshold."""
        if duration > self._threshold:
            self._slow_tests.append(test_name)

    def get_slow_tests(self) -> List[str]:
        """Return a copy of the recorded slow test names, in recording order."""
        return list(self._slow_tests)

    def suggest_optimizations(self) -> List[str]:
        """One suggestion message per recorded slow test."""
        return [f"考虑优化测试: {test}" for test in self._slow_tests]
@pytest.fixture
def test_optimizer():
    """A fresh TestOptimizer per test."""
    optimizer = TestOptimizer()
    return optimizer
# These hooks are only picked up from conftest.py (or a plugin), not from test
# modules. pytest_addoption in particular must live in the rootdir conftest.py.


def pytest_addoption(parser):
    """Register ``--runslow`` so tests marked ``slow`` can be opted back in.

    Without this option, ``config.getoption("--runslow", default=False)``
    always fell back to False, so slow tests could never run.
    """
    parser.addoption(
        "--runslow",
        action="store_true",
        default=False,
        help="run tests marked as slow",
    )


def pytest_configure(config):
    """Register the custom markers so pytest does not warn about unknown marks."""
    config.addinivalue_line(
        "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
    )
    config.addinivalue_line(
        "markers", "integration: marks tests as integration tests"
    )


def pytest_collection_modifyitems(config, items):
    """Skip tests marked ``slow`` unless ``--runslow`` was given."""
    # The option is the same for every item, so check it once.
    if config.getoption("--runslow", default=False):
        return
    skip_slow = pytest.mark.skip(reason="Skipping slow tests")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)


# 41.8 本章小结 (Chapter summary)
本章详细介绍了Python测试进阶的核心概念和实践:
- pytest高级特性:fixture作用域、参数化测试、标记
- Mock测试:Mock对象、patch、副作用、断言
- 测试替身:Stub、Spy、Fake、Dummy
- 测试覆盖率:Coverage.py、覆盖率报告、分支覆盖
- 集成测试:数据库测试、API测试
练习题
- 实现一个测试框架,支持fixture依赖注入和参数化
- 开发一个mock库,支持方法调用记录和返回值设置
- 实现一个覆盖率分析工具,生成HTML报告
- 开发一个测试数据生成器,支持多种数据类型
- 实现一个测试报告生成器,支持多种输出格式