Skip to content

第41章 测试进阶

学习目标

完成本章学习后,你将能够:

  1. 掌握pytest高级特性:fixture、参数化、标记、插件
  2. 使用mock进行模拟测试:Mock对象、patch、副作用
  3. 实现测试夹具:fixture作用域、依赖注入、工厂模式
  4. 进行参数化测试:数据驱动测试、组合测试、条件跳过
  5. 测量测试覆盖率:coverage.py、覆盖率报告、分支覆盖
  6. 编写集成测试:数据库测试、API测试、端到端测试
  7. 实现测试替身:Stub、Spy、Fake、Dummy
  8. 构建测试流水线:CI/CD集成、测试报告、质量门禁

41.1 pytest高级特性

41.1.1 Fixture深入

python
import pytest
from typing import Any, Callable, Dict, Generator, List, Optional
from dataclasses import dataclass
from functools import wraps
import tempfile
import os


@dataclass
class Database:
    """In-memory stand-in for a database connection, used by the fixtures below.

    Attributes:
        name: Logical database name.
        connection_string: DSN the fake pretends to use; nothing is dialed.
        is_connected: Toggled by ``connect`` / ``disconnect``.
    """

    name: str
    connection_string: str
    is_connected: bool = False

    def connect(self) -> None:
        """Mark the fake connection as open."""
        self.is_connected = True

    def disconnect(self) -> None:
        """Mark the fake connection as closed."""
        self.is_connected = False

    def execute(self, query: str) -> List[Dict]:
        """Pretend to run *query*; always yields an empty result set.

        Raises:
            RuntimeError: when called while the connection is closed.
        """
        if self.is_connected:
            return []
        raise RuntimeError("Database not connected")

    def insert(self, table: str, data: Dict) -> int:
        """Pretend to insert *data* into *table*; always reports row id 1.

        Unlike ``execute`` this does not check the connection state.
        """
        return 1


@pytest.fixture
def database():
    """Function-scoped fixture: a freshly connected ``Database`` for each test.

    Code after ``yield`` is the teardown and disconnects the instance once the
    requesting test has finished.
    """
    conn = Database(name="test_db", connection_string="sqlite:///:memory:")
    conn.connect()
    yield conn
    conn.disconnect()


@pytest.fixture(scope="session")
def session_database():
    """Session-scoped fixture: one connected ``Database`` shared by the whole run.

    It is disconnected once, after the last test that uses it.
    """
    shared_conn = Database(name="session_db", connection_string="sqlite:///:memory:")
    shared_conn.connect()
    yield shared_conn
    shared_conn.disconnect()


@pytest.fixture(scope="module")
def module_config():
    """Module-scoped config dict, built once and shared by every test in the module.

    The same dict object is handed to each test, so mutations made by one test
    are visible to later tests in the same module.
    """
    config = dict(debug=True, timeout=30)
    return config


@pytest.fixture(scope="class")
def class_resource():
    """Class-scoped dict shared by all tests of one test class.

    The ``"data"`` list is emptied after the class's last test finishes.
    """
    shared = {"data": [], "initialized": True}
    yield shared
    shared["data"].clear()


@pytest.fixture
def sample_users():
    """Return three user dicts (ids 1-3); every test gets its own fresh list."""
    rows = (
        (1, "Alice", "alice@example.com"),
        (2, "Bob", "bob@example.com"),
        (3, "Charlie", "charlie@example.com"),
    )
    return [{"id": uid, "name": name, "email": email} for uid, name, email in rows]


@pytest.fixture
def temp_file():
    """Yield the path of a temporary file pre-filled with ``"test content"``.

    ``delete=False`` keeps the file on disk after the ``with`` block closes it,
    so the test can reopen it by path. Teardown removes it if it still exists.
    """
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as handle:
        handle.write("test content")
        path = handle.name

    yield path

    if os.path.exists(path):
        os.unlink(path)


@pytest.fixture
def temp_directory():
    """Yield a fresh temporary directory and remove it (best effort) afterwards."""
    import shutil

    workdir = tempfile.mkdtemp()
    yield workdir
    shutil.rmtree(workdir, ignore_errors=True)


@pytest.fixture(params=list(range(1, 6)))
def number(request):
    """Parametrized fixture: every dependent test runs once for each of 1..5."""
    value = request.param
    return value


@pytest.fixture(params=["admin", "user", "guest"])
def user_role(request):
    """Parametrized fixture: every dependent test runs once per role name."""
    role = request.param
    return role


class FixtureFactory:
    """Namespace grouping the "factory as fixture" examples.

    Each fixture returns a builder function instead of a ready-made object, so
    a test can create as many customised records as it needs.

    pytest only discovers fixtures defined at module level, in ``conftest.py``,
    or on ``Test*`` classes. Fixtures that existed only on this helper class
    would never be found, so they are re-exported at module level right after
    the class body.
    """

    @staticmethod
    @pytest.fixture
    def create_user():
        """Return a builder ``(name, email, role="user") -> dict`` for users."""
        def _create_user(name: str, email: str, role: str = "user") -> Dict:
            return {
                "id": 1,
                "name": name,
                "email": email,
                "role": role
            }
        return _create_user

    @staticmethod
    @pytest.fixture
    def create_product():
        """Return a builder ``(name, price, stock=0) -> dict`` for products."""
        def _create_product(name: str, price: float, stock: int = 0) -> Dict:
            return {
                "id": 1,
                "name": name,
                "price": price,
                "stock": stock
            }
        return _create_product


# Module-level aliases so pytest registers these fixtures under their names;
# ``test_factory_fixture`` requests ``create_user``.
create_user = FixtureFactory.create_user
create_product = FixtureFactory.create_product


@pytest.fixture
def mock_api_response():
    """Return a builder for fake API response dicts.

    The builder takes ``status_code`` (default 200), ``data`` (any falsy value,
    including ``None``, becomes ``{}``) and ``error`` (default ``None``).
    """
    def _build(status_code: int = 200, data: Optional[Dict] = None, error: Optional[str] = None):
        payload = data if data else {}
        return {"status_code": status_code, "data": payload, "error": error}
    return _build


def test_database_connection(database):
    """The ``database`` fixture hands over an already-connected instance."""
    assert database.is_connected
    rows = database.execute("SELECT * FROM users")
    assert isinstance(rows, list)


def test_with_sample_users(sample_users):
    """The fixture supplies three users, Alice first."""
    first, *_ = sample_users
    assert len(sample_users) == 3
    assert first["name"] == "Alice"


def test_with_temp_file(temp_file):
    """The temp file fixture's path can be reopened and holds the seeded text."""
    with open(temp_file, "r") as handle:
        text = handle.read()
    assert text == "test content"


def test_parametrized_fixture(number):
    """Runs five times, once per value of the ``number`` fixture."""
    assert 1 <= number <= 5 and number in [1, 2, 3, 4, 5]


def test_user_roles(user_role):
    """Runs three times, once per role supplied by ``user_role``."""
    known_roles = ["admin", "user", "guest"]
    assert user_role in known_roles


def test_factory_fixture(create_user):
    """A factory fixture can be called with custom arguments inside the test."""
    admin = create_user("Test User", "test@example.com", "admin")
    assert (admin["name"], admin["role"]) == ("Test User", "admin")


class TestWithClassFixture:
    """Both tests receive the same ``class_resource`` dict (scope="class").

    Because the dict is shared, items appended by one test are still present
    in the next. Each test therefore asserts only on the item it appended
    itself, so the tests pass in any order and when run individually
    (``-k``, random ordering plugins, xdist).
    """

    def test_class_resource_1(self, class_resource):
        class_resource["data"].append("item1")
        assert class_resource["data"][-1] == "item1"
        assert "item1" in class_resource["data"]

    def test_class_resource_2(self, class_resource):
        class_resource["data"].append("item2")
        assert class_resource["data"][-1] == "item2"
        assert "item2" in class_resource["data"]

41.1.2 参数化测试

python
import pytest
from typing import List, Tuple, Any
from dataclasses import dataclass


@dataclass
class TestCase:
    """One data-driven test case: an input, its expected output and a label."""

    # Data container, not a test class: without this pytest tries to collect
    # it (its name starts with "Test") and warns because the dataclass defines
    # ``__init__``. Unannotated, so it is not a dataclass field.
    __test__ = False

    input_value: Any
    expected_output: Any
    description: str = ""  # human-readable label, usable as a test ID


@pytest.mark.parametrize("value,expected", [(n, n * 2) for n in range(1, 6)])
def test_double(value, expected):
    """Doubling each of 1..5 yields its paired expected value."""
    assert expected == value * 2


@pytest.mark.parametrize(
    ("a", "b", "expected"),
    [
        (1, 2, 3),
        (5, 5, 10),
        (-1, 1, 0),
        (0, 0, 0),
        (100, -50, 50),
    ],
)
def test_addition(a, b, expected):
    """``a + b`` matches the tabulated sum, including negatives and zero."""
    total = a + b
    assert total == expected


@pytest.mark.parametrize(
    ("input_str", "expected_length"),
    [
        ("hello", 5),
        ("", 0),
        ("a", 1),
        ("12345", 5),
        ("hello world", 11),
    ],
)
def test_string_length(input_str, expected_length):
    """``len`` counts characters, including the empty-string edge case."""
    measured = len(input_str)
    assert measured == expected_length


@pytest.mark.parametrize(
    ("value", "should_be_positive"),
    [(1, True), (-1, False), (0, False), (100, True), (-100, False)],
)
def test_is_positive(value, should_be_positive):
    """Only values strictly greater than zero count as positive."""
    assert should_be_positive == (value > 0)


class TestParametrized:
    """``parametrize`` applied to methods of a test class."""

    @pytest.mark.parametrize("x,y", [
        (1, 1),
        (2, 2),
        (3, 3)
    ])
    def test_equality(self, x, y):
        """Each (x, y) pair becomes its own test case."""
        assert x == y

    @pytest.mark.parametrize("value", [1, 2, 3])
    @pytest.mark.parametrize("multiplier", [1, 2, 3])
    def test_multiplication(self, value, multiplier):
        """Stacked ``parametrize`` decorators run the cartesian product (3 x 3 = 9 cases).

        The product is checked against an independent computation (repeated
        addition), so the assertion can actually fail.
        """
        result = value * multiplier
        assert result == sum([value] * multiplier)


@pytest.mark.parametrize("test_case", [
    TestCase(input_value=1, expected_output=2, description="positive"),
    TestCase(input_value=-1, expected_output=-2, description="negative"),
    TestCase(input_value=0, expected_output=0, description="zero")
], ids=lambda case: case.description or None)
def test_with_test_case(test_case):
    """Each ``TestCase`` record bundles an input with its doubled expectation.

    The record's ``description`` is used as the test ID (``test_with_test_case[positive]``).
    An empty description makes the ``ids`` callable return ``None``, so pytest
    falls back to its automatic ID.
    """
    assert test_case.input_value * 2 == test_case.expected_output


def generate_test_cases() -> List[Tuple]:
    """Build ``(n, n**2)`` pairs for n = 1..5, used as ``parametrize`` data."""
    cases: List[Tuple] = []
    for n in range(1, 6):
        cases.append((n, n * n))
    return cases


@pytest.mark.parametrize("value,expected", generate_test_cases())
def test_square_generated(value, expected):
    """Cases come from ``generate_test_cases()``, evaluated when the module is imported."""
    squared = value ** 2
    assert squared == expected


@pytest.mark.parametrize(
    "value,expected",
    [
        pytest.param(1, 2, id="positive"),
        pytest.param(-1, -2, id="negative"),
        pytest.param(0, 0, id="zero"),
    ],
)
def test_with_ids(value, expected):
    """``pytest.param(..., id=...)`` gives each case a readable test ID."""
    assert expected == value * 2


@pytest.mark.parametrize(
    "value,expected",
    [
        (1, 2),
        # marks= attaches a marker to just this case; it is reported as skipped
        pytest.param(2, 4, marks=pytest.mark.skip(reason="Not implemented")),
        (3, 6),
    ],
)
def test_with_skip(value, expected):
    """Only the unmarked cases actually execute."""
    assert expected == value * 2


@pytest.mark.parametrize(
    "value,expected",
    [
        (1, 2),
        # deliberately wrong expectation, reported as xfail rather than a failure
        pytest.param(2, 5, marks=pytest.mark.xfail(reason="Expected to fail")),
        (3, 6),
    ],
)
def test_with_xfail(value, expected):
    """An ``xfail`` mark on a single case tolerates its known failure."""
    assert expected == value * 2

41.2 Mock与模拟测试

41.2.1 Mock对象

python
from unittest.mock import Mock, MagicMock, patch, call, ANY
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass
import pytest


class UserService:
    """Application service coordinating a user store and an email sender.

    Both collaborators are injected through the constructor, which is what
    lets the tests below replace them with ``Mock`` objects.
    """

    def __init__(self, database, email_service):
        self.database = database
        self.email_service = email_service

    def get_user(self, user_id: int) -> Optional[Dict]:
        """Return whatever ``database.find_user(user_id)`` returns."""
        return self.database.find_user(user_id)

    def create_user(self, name: str, email: str) -> Dict:
        """Create a user via the database, then send a welcome email.

        Raises:
            ValueError: if ``database.user_exists(email)`` is truthy. Nothing
                is created and no email is sent in that case.
        """
        already_there = self.database.user_exists(email)
        if already_there:
            raise ValueError("User already exists")
        created = self.database.create_user(name, email)
        self.email_service.send_welcome(email)
        return created

    def delete_user(self, user_id: int) -> bool:
        """Delete a user and send a farewell email to their address.

        Returns:
            True on deletion; False (with no other side effects) when
            ``get_user`` returns a falsy value.
        """
        existing = self.get_user(user_id)
        if existing:
            self.database.delete_user(user_id)
            self.email_service.send_farewell(existing["email"])
            return True
        return False


class TestUserServiceWithMock:
    """Unit tests for ``UserService`` with both collaborators replaced by ``Mock``.

    Failure paths check return values and exceptions, and also that no
    unwanted side effect (record creation, deletion, email) took place.
    """

    def test_get_user_success(self):
        mock_db = Mock()
        mock_db.find_user.return_value = {"id": 1, "name": "Alice"}

        service = UserService(mock_db, Mock())
        result = service.get_user(1)

        mock_db.find_user.assert_called_once_with(1)
        assert result["name"] == "Alice"

    def test_create_user_success(self):
        mock_db = Mock()
        mock_db.user_exists.return_value = False
        mock_db.create_user.return_value = {"id": 1, "name": "Bob", "email": "bob@example.com"}

        mock_email = Mock()

        service = UserService(mock_db, mock_email)
        result = service.create_user("Bob", "bob@example.com")

        mock_db.user_exists.assert_called_once_with("bob@example.com")
        mock_db.create_user.assert_called_once_with("Bob", "bob@example.com")
        mock_email.send_welcome.assert_called_once_with("bob@example.com")
        assert result["id"] == 1

    def test_create_user_already_exists(self):
        mock_db = Mock()
        mock_db.user_exists.return_value = True
        mock_email = Mock()

        service = UserService(mock_db, mock_email)

        with pytest.raises(ValueError, match="User already exists"):
            service.create_user("Bob", "bob@example.com")

        # The duplicate check must short-circuit before any side effect.
        mock_db.create_user.assert_not_called()
        mock_email.send_welcome.assert_not_called()

    def test_delete_user_success(self):
        mock_db = Mock()
        mock_db.find_user.return_value = {"id": 1, "name": "Alice", "email": "alice@example.com"}

        mock_email = Mock()

        service = UserService(mock_db, mock_email)
        result = service.delete_user(1)

        assert result is True
        mock_db.delete_user.assert_called_once_with(1)
        mock_email.send_farewell.assert_called_once_with("alice@example.com")

    def test_delete_user_not_found(self):
        mock_db = Mock()
        mock_db.find_user.return_value = None
        mock_email = Mock()

        service = UserService(mock_db, mock_email)
        result = service.delete_user(999)

        assert result is False
        mock_db.delete_user.assert_not_called()
        # No user, so nobody to say goodbye to.
        mock_email.send_farewell.assert_not_called()


class TestMockAdvanced:
    """Tour of Mock features: side effects, call inspection, magic methods, matchers."""

    def test_mock_with_side_effect(self):
        """An iterable ``side_effect`` supplies one return value per call."""
        stub = Mock(side_effect=[1, 2, 3])

        assert [stub(), stub(), stub()] == [1, 2, 3]

    def test_mock_with_exception_side_effect(self):
        """An exception instance as ``side_effect`` is raised when the mock is called."""
        failing = Mock(side_effect=ValueError("Test error"))

        with pytest.raises(ValueError, match="Test error"):
            failing()

    def test_mock_with_function_side_effect(self):
        """A callable ``side_effect`` computes the return value from the arguments."""
        doubler = Mock(side_effect=lambda x: x * 2)

        assert doubler(5) == 10
        assert doubler(3) == 6

    def test_mock_attributes(self):
        """Plain attributes can be assigned onto a Mock and read back."""
        obj = Mock()
        obj.name = "test"
        obj.value = 42

        assert (obj.name, obj.value) == ("test", 42)

    def test_mock_call_count(self):
        """``call_count`` tallies every invocation."""
        counter = Mock()
        for arg in (1, 2, 3):
            counter(arg)

        assert counter.call_count == 3

    def test_mock_call_args_list(self):
        """``call_args_list`` records the arguments of each call in order."""
        recorder = Mock()
        for pair in ((1, 2), (3, 4), (5, 6)):
            recorder(*pair)

        assert recorder.call_args_list == [call(1, 2), call(3, 4), call(5, 6)]

    def test_mock_assert_called_with(self):
        """``assert_called_with`` checks the most recent call, keywords included."""
        target = Mock()
        target(1, 2, key="value")

        target.assert_called_with(1, 2, key="value")

    def test_mock_assert_any_call(self):
        """``assert_any_call`` passes if any earlier call matches."""
        target = Mock()
        for arg in (1, 2, 3):
            target(arg)

        target.assert_any_call(2)

    def test_mock_reset_mock(self):
        """``reset_mock`` clears the recorded call history."""
        target = Mock()
        target(1)
        target(2)

        target.reset_mock()

        assert target.call_count == 0
        assert target.call_args_list == []

    def test_magic_mock(self):
        """MagicMock pre-configures dunder methods, whose results can be set."""
        magic = MagicMock()
        magic.__str__.return_value = "Custom String"
        magic.__len__.return_value = 10

        assert str(magic) == "Custom String"
        assert len(magic) == 10

    def test_mock_any_matcher(self):
        """``ANY`` matches any value in a call assertion."""
        service = Mock()
        service.process(1, 2, 3)

        service.process.assert_called_with(ANY, ANY, 3)


class TestWithPatch:
    """``patch`` used as a decorator and as a context manager."""

    @patch("builtins.open")
    def test_patch_decorator(self, fake_open):
        """The decorator injects the replacement as an extra positional argument."""
        handle = fake_open.return_value.__enter__.return_value
        handle.read.return_value = "test content"

        with open("test.txt", "r") as f:
            content = f.read()

        assert content == "test content"

    def test_patch_context_manager(self):
        """Inside the ``with`` block the target is replaced; on exit it is restored."""
        with patch("builtins.open") as fake_open:
            handle = fake_open.return_value.__enter__.return_value
            handle.read.return_value = "content"

            with open("test.txt", "r") as f:
                content = f.read()

            assert content == "content"

    @patch("os.path.exists")
    def test_patch_os(self, fake_exists):
        """Patching ``os.path.exists`` lets the test control its answer and inspect calls."""
        fake_exists.return_value = True

        import os
        answer = os.path.exists("/some/path")

        assert answer is True
        fake_exists.assert_called_once_with("/some/path")

41.2.2 测试替身

python
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass, field
from abc import ABC, abstractmethod


class Stub:
    """Test double that returns canned answers keyed by method name.

    Every ``get_response`` call is also logged so tests can inspect it.
    """

    def __init__(self, responses: Optional[Dict[str, Any]] = None):
        # Any falsy argument (None, {}) starts with an empty response table.
        self._responses = responses or {}
        self._calls: List[Dict] = []

    def get_response(self, method: str, *args, **kwargs) -> Any:
        """Log the call and return the canned answer for *method* (``None`` if absent)."""
        self._calls.append({"method": method, "args": args, "kwargs": kwargs})
        return self._responses.get(method)

    def get_calls(self) -> List[Dict]:
        """Return a shallow copy of the call log."""
        return list(self._calls)

    def was_called(self, method: str) -> bool:
        """Return True if *method* appears anywhere in the call log."""
        for entry in self._calls:
            if entry["method"] == method:
                return True
        return False


class Spy:
    """Test double that records every method call made on it.

    Accessing any attribute that is not a real attribute returns a recording
    wrapper. A call through the wrapper is answered, in order of preference:
    by a value registered with ``set_return_value``; by delegating to
    ``target`` when one was given and it has that attribute; otherwise with
    ``None``.
    """

    def __init__(self, target: Any = None):
        self._target = target
        self._calls: List[Dict] = []
        self._return_values: Dict[str, Any] = {}

    def __getattr__(self, name: str) -> Callable:
        # The interpreter and stdlib protocols (copy, pickle, ...) probe for
        # dunders with getattr(obj, "__x__", None). Answering those with a
        # recorder would hijack the protocol: copy.deepcopy(spy) would return
        # None. Report them as missing instead.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        def method_wrapper(*args, **kwargs):
            self._calls.append({
                "method": name,
                "args": args,
                "kwargs": kwargs
            })

            if name in self._return_values:
                return self._return_values[name]

            # ``is not None`` so falsy targets (empty containers, 0, ...) still
            # receive delegated calls.
            if self._target is not None and hasattr(self._target, name):
                return getattr(self._target, name)(*args, **kwargs)

            return None

        return method_wrapper

    def set_return_value(self, method: str, value: Any) -> None:
        """Make calls to *method* return *value* (takes precedence over the target)."""
        self._return_values[method] = value

    def get_call_count(self, method: str) -> int:
        """Return how many recorded calls were made to *method*."""
        return sum(1 for entry in self._calls if entry["method"] == method)

    def was_called_with(self, method: str, *args, **kwargs) -> bool:
        """Return True if some call to *method* used exactly these arguments."""
        return any(
            entry["method"] == method
            and entry["args"] == args
            and entry["kwargs"] == kwargs
            for entry in self._calls
        )


class Fake:
    """Working in-memory table store (a "fake" test double).

    Tables are created lazily. Stored rows are copies of the caller's dicts
    with an ``id`` added, and reads return copies, so neither side can mutate
    the other's data by accident.
    """

    def __init__(self):
        self._data: Dict[str, List] = {}

    def create(self, table: str, data: Dict) -> int:
        """Insert a copy of *data* into *table* and return the new row's ``id``.

        The id is one more than the largest id currently in the table, so it
        never collides with a live row. ``len + 1`` would collide after a
        delete. The caller's dict is left untouched.
        """
        rows = self._data.setdefault(table, [])
        new_id = max((row.get("id", 0) for row in rows), default=0) + 1
        rows.append({**data, "id": new_id})
        return new_id

    def read(self, table: str, id: int) -> Optional[Dict]:
        """Return a copy of the row with this *id*, or ``None`` if absent."""
        for row in self._data.get(table, []):
            if row.get("id") == id:
                return row.copy()
        return None

    def update(self, table: str, id: int, data: Dict) -> bool:
        """Merge *data* into the row with this *id*; return whether it existed."""
        for row in self._data.get(table, []):
            if row.get("id") == id:
                row.update(data)
                return True
        return False

    def delete(self, table: str, id: int) -> bool:
        """Remove the row with this *id*; return whether it existed."""
        rows = self._data.get(table, [])
        for index, row in enumerate(rows):
            if row.get("id") == id:
                del rows[index]
                return True
        return False

    def list_all(self, table: str) -> List[Dict]:
        """Return copies of every row in *table* (empty list for unknown tables)."""
        return [dict(row) for row in self._data.get(table, [])]


class Dummy:
    """Dummy test double: an object passed only to fill a parameter slot.

    It defines no behavior of its own.
    """


class TestStubs:
    """Stubs return canned answers and keep a log of how they were queried."""

    def test_stub_responses(self):
        canned = {
            "get_user": {"id": 1, "name": "Alice"},
            "get_product": {"id": 1, "name": "Widget"},
        }
        stub = Stub(canned)

        user = stub.get_response("get_user")
        product = stub.get_response("get_product")

        assert (user["name"], product["name"]) == ("Alice", "Widget")

    def test_stub_records_calls(self):
        stub = Stub({"process": "done"})

        for first, second in ((1, 2), (3, 4)):
            stub.get_response("process", first, second)

        assert len(stub.get_calls()) == 2
        assert stub.was_called("process")


class TestSpies:
    """Spies accept any method call and record it for later assertions."""

    def test_spy_records_calls(self):
        recorder = Spy()

        recorder.method1(1, 2, key="value")
        recorder.method2("test")

        for method in ("method1", "method2"):
            assert recorder.get_call_count(method) == 1
        assert recorder.was_called_with("method1", 1, 2, key="value")

    def test_spy_with_return_values(self):
        recorder = Spy()
        recorder.set_return_value("get_value", 42)

        assert recorder.get_value() == 42


class TestFakes:
    """Fakes are working lightweight implementations; exercise full CRUD."""

    def test_fake_crud_operations(self):
        store = Fake()

        new_id = store.create("users", {"name": "Alice", "email": "alice@example.com"})
        assert new_id == 1

        assert store.read("users", new_id)["name"] == "Alice"

        store.update("users", new_id, {"name": "Alice Updated"})
        assert store.read("users", new_id)["name"] == "Alice Updated"

        assert store.delete("users", new_id) is True
        assert store.read("users", new_id) is None

    def test_fake_list_all(self):
        store = Fake()

        for name, price in (("Widget", 10), ("Gadget", 20)):
            store.create("products", {"name": name, "price": price})

        assert len(store.list_all("products")) == 2

41.3 测试覆盖率

41.3.1 Coverage.py使用

python
from typing import Any, Dict, List, Optional
from dataclasses import dataclass


@dataclass
class CoverageReport:
    """Aggregate line-coverage figures plus a per-file breakdown."""

    total_lines: int  # executable lines counted across all reported files
    covered_lines: int  # of those, lines recorded as executed
    missed_lines: int  # total_lines - covered_lines
    coverage_percent: float  # percentage on a 0-100 scale (as CoverageAnalyzer computes it)
    file_reports: Dict[str, Dict]  # filename -> per-file stats dict


class CoverageAnalyzer:
    """Minimal line-coverage bookkeeper.

    Callers register the executable lines of each file (``add_total_line``)
    and the lines that actually ran (``add_covered_line``). The analyzer then
    reports per-file and aggregate percentages.
    """

    def __init__(self):
        self._covered_lines: Dict[str, set] = {}
        self._total_lines: Dict[str, set] = {}
        self._excluded_files: set = set()

    def add_covered_line(self, filename: str, line_number: int) -> None:
        """Record that *line_number* in *filename* was executed."""
        self._covered_lines.setdefault(filename, set()).add(line_number)

    def add_total_line(self, filename: str, line_number: int) -> None:
        """Record *line_number* in *filename* as an executable (countable) line."""
        self._total_lines.setdefault(filename, set()).add(line_number)

    def exclude_file(self, filename: str) -> None:
        """Drop *filename* from aggregate totals.

        ``get_file_coverage`` then reports it as 100% with zero counts.
        """
        self._excluded_files.add(filename)

    def get_file_coverage(self, filename: str) -> Dict:
        """Return coverage stats for one file.

        Returns:
            A dict with keys ``coverage`` (percent), ``covered``, ``total``
            and ``missed``. Only executed lines that were also registered as
            executable count as covered, so the percentage never exceeds 100.
            Excluded files and files with no registered lines report 100%
            with zero counts.
        """
        if filename in self._excluded_files:
            return {"coverage": 100.0, "covered": 0, "total": 0, "missed": 0}

        executable = self._total_lines.get(filename, set())
        total = len(executable)
        if total == 0:
            return {"coverage": 100.0, "covered": 0, "total": 0, "missed": 0}

        # Intersect so stray "covered" lines that were never registered as
        # executable cannot push coverage above 100%.
        covered = len(self._covered_lines.get(filename, set()) & executable)

        return {
            "coverage": (covered / total) * 100,
            "covered": covered,
            "total": total,
            "missed": total - covered
        }

    def get_total_coverage(self) -> "CoverageReport":
        """Aggregate the stats of every non-excluded file with registered lines."""
        total_covered = 0
        total_lines = 0
        file_reports = {}

        for filename in self._total_lines:
            if filename in self._excluded_files:
                continue

            file_report = self.get_file_coverage(filename)
            file_reports[filename] = file_report
            total_covered += file_report["covered"]
            total_lines += file_report["total"]

        if total_lines == 0:
            return CoverageReport(
                total_lines=0,
                covered_lines=0,
                missed_lines=0,
                coverage_percent=100.0,
                file_reports=file_reports
            )

        return CoverageReport(
            total_lines=total_lines,
            covered_lines=total_covered,
            missed_lines=total_lines - total_covered,
            coverage_percent=(total_covered / total_lines) * 100,
            file_reports=file_reports
        )

    def generate_report(self, output_format: str = "text") -> str:
        """Render the aggregate report.

        Only ``"text"`` is supported; any other format yields an empty string.
        """
        report = self.get_total_coverage()

        if output_format == "text":
            lines = [
                "Coverage Report",
                "=" * 50,
                f"Total Lines: {report.total_lines}",
                f"Covered Lines: {report.covered_lines}",
                f"Missed Lines: {report.missed_lines}",
                f"Coverage: {report.coverage_percent:.2f}%",
                "",
                "File Details:",
                "-" * 50
            ]

            for filename, file_report in report.file_reports.items():
                lines.append(
                    f"{filename}: {file_report['coverage']:.2f}% "
                    f"({file_report['covered']}/{file_report['total']})"
                )

            return "\n".join(lines)

        return ""


class BranchCoverage:
    """Tracks, per file and line, whether each named branch was taken."""

    def __init__(self):
        # filename -> line number -> branch id -> taken flag
        self._branches: Dict[str, Dict[int, Dict[str, bool]]] = {}

    def add_branch(self, filename: str, line: int, branch_id: str, taken: bool) -> None:
        """Record (or overwrite) whether *branch_id* at *line* of *filename* was taken."""
        per_line = self._branches.setdefault(filename, {}).setdefault(line, {})
        per_line[branch_id] = taken

    def get_branch_coverage(self, filename: str) -> Dict:
        """Return ``{"total", "covered", "coverage"}`` branch stats for *filename*.

        Unknown files report 100% coverage with zero branches.
        """
        if filename not in self._branches:
            return {"total": 0, "covered": 0, "coverage": 100.0}

        outcomes = [
            flag
            for branches in self._branches[filename].values()
            for flag in branches.values()
        ]
        total = len(outcomes)
        covered = sum(1 for flag in outcomes if flag)
        coverage = (covered / total * 100) if total > 0 else 100.0

        return {"total": total, "covered": covered, "coverage": coverage}

41.4 集成测试

41.4.1 数据库测试

python
import pytest
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
import tempfile
import os


@dataclass
class TestDatabase:
    """Dict-backed fake database paired with a real temporary file path.

    Each table is a list of row dicts; inserted rows receive an integer ``id``.
    """

    # Helper, not a test class: without this pytest tries to collect it (its
    # name starts with "Test") and warns because the dataclass defines
    # ``__init__``. Unannotated, so it is not a dataclass field.
    __test__ = False

    path: str
    tables: Dict[str, List[Dict]]

    @classmethod
    def create(cls) -> "TestDatabase":
        """Create an empty database whose ``path`` is a fresh temporary file.

        ``tempfile.mkstemp`` atomically creates the file. The deprecated
        ``mktemp`` only returns a name and is open to races. The descriptor is
        closed at once because only the path is kept; ``cleanup`` removes the
        file.
        """
        fd, path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        return cls(path=path, tables={})

    def insert(self, table: str, data: Dict) -> int:
        """Store a copy of *data* in *table* and return its new ``id``.

        The id is one more than the largest live id, so it never collides with
        an existing row. ``len + 1`` would collide after a delete. The
        caller's dict is left untouched.
        """
        rows = self.tables.setdefault(table, [])
        new_id = max((row.get("id", 0) for row in rows), default=0) + 1
        rows.append({**data, "id": new_id})
        return new_id

    def find(self, table: str, id: int) -> Optional[Dict]:
        """Return a copy of the row with this *id*, or ``None``."""
        for row in self.tables.get(table, []):
            if row.get("id") == id:
                return row.copy()
        return None

    def find_all(self, table: str) -> List[Dict]:
        """Return copies of all rows in *table* (empty list for unknown tables)."""
        return [dict(row) for row in self.tables.get(table, [])]

    def delete(self, table: str, id: int) -> bool:
        """Remove the row with this *id*; return whether it existed."""
        rows = self.tables.get(table, [])
        for index, row in enumerate(rows):
            if row.get("id") == id:
                del rows[index]
                return True
        return False

    def cleanup(self) -> None:
        """Drop all tables and delete the backing file if it exists."""
        self.tables.clear()
        if os.path.exists(self.path):
            os.unlink(self.path)


@pytest.fixture
def test_db():
    """Yield a fresh ``TestDatabase`` for each test and clean it up afterwards."""
    instance = TestDatabase.create()
    yield instance
    instance.cleanup()


class TestDatabaseIntegration:
    """Storage round-trips against the ``test_db`` fixture."""

    def test_insert_and_find(self, test_db):
        new_id = test_db.insert("users", {"name": "Alice", "email": "alice@example.com"})

        found = test_db.find("users", new_id)

        assert found is not None
        assert (found["name"], found["email"]) == ("Alice", "alice@example.com")

    def test_find_all(self, test_db):
        expected_names = ["Alice", "Bob", "Charlie"]
        for name in expected_names:
            test_db.insert("users", {"name": name})

        everyone = test_db.find_all("users")

        assert len(everyone) == 3
        assert [row["name"] for row in everyone] == expected_names

    def test_delete(self, test_db):
        new_id = test_db.insert("users", {"name": "Alice"})

        assert test_db.delete("users", new_id) is True
        assert test_db.find("users", new_id) is None

    def test_find_nonexistent(self, test_db):
        assert test_db.find("users", 999) is None


class TestAPIClient:
    """In-memory stand-in for an HTTP client: endpoints map to canned responses.

    ``base_url`` is stored but not used to build requests.
    """

    # Helper, not a test class: without this pytest tries to collect it (its
    # name starts with "Test") and warns because it defines ``__init__``.
    __test__ = False

    def __init__(self, base_url: str):
        self.base_url = base_url
        self._responses: Dict[str, Any] = {}

    def set_response(self, endpoint: str, response: Any) -> None:
        """Register *response* for *endpoint*.

        A callable response is invoked with the request body by ``post``.
        """
        self._responses[endpoint] = response

    def get(self, endpoint: str) -> Any:
        """Return the registered response for *endpoint* as-is (``None`` if unset)."""
        return self._responses.get(endpoint)

    def post(self, endpoint: str, data: Dict) -> Any:
        """Return ``response(data)`` if the registered response is callable, else the response itself."""
        response = self._responses.get(endpoint)
        if callable(response):
            return response(data)
        return response


@pytest.fixture
def api_client():
    """Return a fresh in-memory ``TestAPIClient`` for each test."""
    client = TestAPIClient("http://test-api")
    return client


class TestAPIIntegration:
    """API-level tests against the in-memory ``api_client`` fixture."""

    def test_get_users(self, api_client):
        canned_users = [
            {"id": 1, "name": "Alice"},
            {"id": 2, "name": "Bob"},
        ]
        api_client.set_response("/users", canned_users)

        users = api_client.get("/users")

        assert len(users) == 2
        assert users[0]["name"] == "Alice"

    def test_create_user(self, api_client):
        # A callable response echoes the posted payload with an assigned id.
        api_client.set_response("/users", lambda payload: {"id": 1, **payload})

        created = api_client.post("/users", {"name": "Charlie", "email": "charlie@example.com"})

        assert (created["id"], created["name"]) == (1, "Charlie")

41.5 知识图谱

41.5.1 测试金字塔架构

┌─────────────────────────────────────────────────────────────────────┐
│                      软件测试金字塔                                  │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│                          /\                                         │
│                         /  \                                        │
│                        / E2E\        端到端测试 (10%)               │
│                       / Test \       - 用户流程                      │
│                      /--------\      - UI自动化                     │
│                     /          \                                    │
│                    / Integration\    集成测试 (20%)                 │
│                   /    Test     \    - API测试                      │
│                  /----------------\  - 数据库测试                   │
│                 /                  \                                │
│                /    Unit Test       \ 单元测试 (70%)               │
│               /                      \ - 函数测试                   │
│              /------------------------\ - 类测试                    │
│                                                                     │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      测试工具链                               │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ pytest   │ │ unittest │ │ coverage │ │ mock     │       │   │
│  │  │ 主框架   │ │ 标准库   │ │ 覆盖率   │ │ 模拟     │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ pytest-cov│ │ pytest-mock│ │hypothesis│ │ faker   │       │   │
│  │  │ 覆盖插件  │ │ mock插件  │ │ 属性测试 │ │ 测试数据 │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

41.5.2 测试执行流程

┌─────────────────────────────────────────────────────────────────────┐
│                        测试执行生命周期                              │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    1. 测试发现 (Discovery)                 │     │
│   │  - 扫描测试文件 (test_*.py, *_test.py)                    │     │
│   │  - 收集测试类和测试函数                                    │     │
│   │  - 解析参数化测试用例                                      │     │
│   └──────────────────────────────────────────────────────────┘     │
│                                │                                    │
│                                ▼                                    │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    2. Fixture准备 (Setup)                  │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ session  │ │ module   │ │ class    │                 │     │
│   │  │ 会话级别 │ │ 模块级别 │ │ 类级别   │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   │  ┌──────────┐ ┌──────────┐                              │     │
│   │  │ function │ │ 依赖注入 │                              │     │
│   │  │ 函数级别 │ │ 自动处理 │                              │     │
│   │  └──────────┘ └──────────┘                              │     │
│   └──────────────────────────────────────────────────────────┘     │
│                                │                                    │
│                                ▼                                    │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    3. 测试执行 (Execution)                 │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ 运行测试  │ │ 断言检查  │ │ 异常捕获  │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ Mock处理 │ │ 参数化   │ │ 标记过滤  │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   └──────────────────────────────────────────────────────────┘     │
│                                │                                    │
│                                ▼                                    │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    4. Fixture清理 (Teardown)               │     │
│   │  - 执行yield后的清理代码                                   │     │
│   │  - 释放资源、关闭连接                                      │     │
│   │  - 清理临时文件                                            │     │
│   └──────────────────────────────────────────────────────────┘     │
│                                │                                    │
│                                ▼                                    │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    5. 报告生成 (Reporting)                  │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ 测试结果  │ │ 覆盖率   │ │ 性能数据  │                 │     │
│   │  │ Pass/Fail│ │ Coverage │ │ Duration │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   └──────────────────────────────────────────────────────────┘     │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

41.6 技术选型指南

41.6.1 测试框架选型

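| 框架 | 适用场景 | 特点 | 推荐指数 |
|------|----------|------|----------|
| pytest | 通用测试、项目首选 | 插件丰富、简洁 | ★★★★★ |
| unittest | 标准库、简单测试 | Python内置、稳定 | ★★★☆☆ |
| nose2 | unittest扩展 | 兼容性好 | ★★☆☆☆ |
| hypothesis | 属性测试 | 自动生成测试用例 | ★★★★☆ |
| robot | 验收测试、关键字驱动 | 非技术人员友好 | ★★★☆☆ |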

41.6.2 Mock工具选型

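| 工具 | 适用场景 | 特点 | 推荐指数 |
|------|----------|------|----------|
| unittest.mock | 标准Mock | Python内置、功能完整 | ★★★★★ |
| pytest-mock | pytest集成 | 简化Mock使用 | ★★★★★ |
| responses | HTTP Mock | 专门Mock HTTP请求 | ★★★★☆ |
| freezegun | 时间Mock | 冻结时间、测试时间相关 | ★★★★☆ |
| faker | 测试数据 | 生成假数据 | ★★★★★ |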

41.6.3 测试覆盖率工具选型

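| 工具 | 特点 | 输出格式 | 推荐指数 |
|------|------|----------|----------|
| coverage.py | 标准工具、功能完整 | HTML/XML/JSON | ★★★★★ |
| pytest-cov | pytest插件、集成方便 | 同coverage | ★★★★★ |
| pytest-testmon | 增量测试、只运行变更相关 | - | ★★★★☆ |
| mutmut | 变异测试 | HTML | ★★★☆☆ |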

41.7 常见问题与解决方案

41.7.1 测试隔离问题

python
import pytest
from typing import Dict, Any
import tempfile
import os
import shutil

class TestIsolation:
    """Fixtures demonstrating three isolation techniques: env vars, cwd, database."""

    @pytest.fixture
    def isolated_env(self, monkeypatch):
        """Set TEST_MODE/DATABASE_URL for one test, then restore ``os.environ``.

        A full snapshot of the environment is taken up front and written back
        on teardown, so variables the test modifies directly (bypassing
        ``monkeypatch``) are reverted too.
        """
        snapshot = dict(os.environ)

        overrides = {
            "TEST_MODE": "true",
            "DATABASE_URL": "sqlite:///:memory:",
        }
        for name, value in overrides.items():
            monkeypatch.setenv(name, value)

        yield

        # Wipe and re-populate: the simplest exact restore of the snapshot.
        os.environ.clear()
        os.environ.update(snapshot)

    @pytest.fixture
    def isolated_filesystem(self):
        """Chdir into a fresh temporary directory and yield its path.

        Teardown returns to the previous working directory first, then
        deletes the temporary tree (errors ignored).
        """
        previous_cwd = os.getcwd()
        sandbox = tempfile.mkdtemp()
        os.chdir(sandbox)

        yield sandbox

        os.chdir(previous_cwd)
        shutil.rmtree(sandbox, ignore_errors=True)

    @pytest.fixture
    def isolated_database(self):
        """Yield a private in-memory SQLite connection with an empty ``users`` table."""
        import sqlite3

        conn = sqlite3.connect(":memory:")
        conn.cursor().execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY,
                name TEXT,
                email TEXT
            )
        """)
        conn.commit()

        yield conn

        conn.close()

    def test_with_isolated_env(self, isolated_env):
        assert os.environ.get("TEST_MODE") == "true"

    def test_with_isolated_fs(self, isolated_filesystem):
        target = os.path.join(isolated_filesystem, "test.txt")
        with open(target, "w") as handle:
            handle.write("test")
        assert os.path.exists(target)

    def test_with_isolated_db(self, isolated_database):
        cur = isolated_database.cursor()
        cur.execute(
            "INSERT INTO users (name, email) VALUES (?, ?)",
            ("Alice", "alice@test.com"),
        )
        isolated_database.commit()

        rows = cur.execute("SELECT * FROM users").fetchall()
        assert len(rows) == 1


class SharedStateCleaner:
    """Snapshot object attributes so they can be put back after a test.

    Each snapshot stores a reference to the object itself, not just its
    ``id()``. This lets ``restore_all`` actually reach the object (an id
    string cannot be turned back into an object), and it keeps the object
    alive so its id cannot be recycled while the snapshot exists.
    """

    # Sentinel recorded when the attribute did not exist at save time, so
    # restoring removes it instead of setting it to None.
    _MISSING = object()

    def __init__(self):
        # "<id(obj)>.<attr>" -> (obj, attr, saved value or _MISSING)
        self._original_states: Dict[str, Any] = {}

    @staticmethod
    def _key(obj, attr: str) -> str:
        """Build the snapshot key for ``obj.attr``."""
        return f"{id(obj)}.{attr}"

    @classmethod
    def _apply(cls, obj, attr: str, value) -> None:
        """Write a saved value back, or delete the attribute if it was absent."""
        if value is cls._MISSING:
            try:
                delattr(obj, attr)
            except AttributeError:
                # Never created by the test (or only reachable via the class).
                pass
        else:
            setattr(obj, attr, value)

    def save_state(self, obj, attr: str):
        """Record the current value of ``obj.attr``, or the fact that it is absent.

        Saving the same attribute again overwrites the earlier snapshot.
        """
        self._original_states[self._key(obj, attr)] = (
            obj,
            attr,
            getattr(obj, attr, self._MISSING),
        )

    def restore_state(self, obj, attr: str):
        """Restore ``obj.attr`` from its snapshot; do nothing if it was never saved."""
        entry = self._original_states.get(self._key(obj, attr))
        if entry is not None:
            self._apply(*entry)

    def restore_all(self):
        """Restore every saved attribute, then discard all snapshots."""
        for entry in self._original_states.values():
            self._apply(*entry)
        self._original_states.clear()


@pytest.fixture
def state_cleaner():
    """Provide a SharedStateCleaner and call its ``restore_all()`` after the test."""
    tracker = SharedStateCleaner()
    yield tracker
    # Teardown: put back every attribute the test asked us to track.
    tracker.restore_all()

41.7.2 异步代码测试

python
import pytest
import asyncio
from typing import AsyncGenerator, Callable

class AsyncTestHelper:
    """Helpers for testing asynchronous code."""

    # NOTE(review): pytest registers fixtures found in test modules, conftest.py,
    # plugins and *collected* Test* classes. This helper class is none of those,
    # so ``async_setup`` is presumably never registered as a fixture from here.
    # Async fixtures also need an async plugin (e.g. pytest-asyncio) to be
    # awaited. Confirm the intended usage.
    @staticmethod
    @pytest.fixture
    async def async_setup():
        """Skeleton async setup/teardown fixture (placeholder awaits only)."""
        resource = await asyncio.sleep(0)  # asyncio.sleep(0) returns None
        yield resource
        await asyncio.sleep(0)

    @staticmethod
    def async_test(coro_func):
        """Decorator: run an ``async def`` test synchronously via ``asyncio.run``.

        ``functools.wraps`` copies the name/docstring and sets ``__wrapped__``.
        ``inspect.signature`` (which pytest uses) follows ``__wrapped__``, so
        fixture arguments of the original coroutine function are still injected.
        """
        from functools import wraps

        @wraps(coro_func)
        def wrapper(*args, **kwargs):
            return asyncio.run(coro_func(*args, **kwargs))
        return wrapper

    @staticmethod
    async def wait_for_condition(
        condition: Callable,
        timeout: float = 5.0,
        interval: float = 0.1
    ) -> bool:
        """Poll ``condition`` until it is truthy or ``timeout`` seconds elapse.

        ``condition`` may be a plain callable or any callable whose result is
        awaitable (``async def`` function, lambda returning a coroutine,
        ``functools.partial`` of a coroutine function, ...).

        Returns:
            True as soon as the condition holds, False after the timeout.
        """
        import inspect

        # Inside a coroutine the running loop is the correct clock source.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            result = condition()
            if inspect.isawaitable(result):
                result = await result
            if result:
                return True
            await asyncio.sleep(interval)
        return False


class AsyncService:
    """Minimal async service used as a test subject.

    Every coroutine awaits ``asyncio.sleep(0.1)`` to simulate I/O latency.
    ``fetch_data`` and ``add_data`` raise ``RuntimeError("Not connected")``
    until ``connect()`` has completed.
    """

    def __init__(self):
        self._connected = False
        self._data = []

    def _ensure_connected(self):
        """Raise unless connected. Checked before any simulated delay."""
        if not self._connected:
            raise RuntimeError("Not connected")

    async def connect(self):
        await asyncio.sleep(0.1)
        self._connected = True

    async def disconnect(self):
        await asyncio.sleep(0.1)
        self._connected = False

    async def fetch_data(self) -> list:
        """Return the stored items (the internal list itself, not a copy)."""
        self._ensure_connected()
        await asyncio.sleep(0.1)
        return self._data

    async def add_data(self, item):
        """Append ``item`` to the stored items."""
        self._ensure_connected()
        await asyncio.sleep(0.1)
        self._data.append(item)


# NOTE(review): this async generator fixture is registered with plain
# ``@pytest.fixture``. Under pytest-asyncio's "strict" mode such fixtures are
# not driven by the plugin, so tests would receive an un-awaited async
# generator. Confirm the project uses ``asyncio_mode = auto``, or switch this
# to ``@pytest_asyncio.fixture``.
@pytest.fixture
async def async_service():
    """Yield a connected AsyncService; disconnect it once the test finishes."""
    svc = AsyncService()
    await svc.connect()
    yield svc
    await svc.disconnect()


@pytest.mark.asyncio
class TestAsyncService:
    """Coroutine tests for AsyncService (the asyncio mark is applied to every method)."""

    async def test_fetch_data(self, async_service):
        assert isinstance(await async_service.fetch_data(), list)

    async def test_add_data(self, async_service):
        await async_service.add_data("item1")
        fetched = await async_service.fetch_data()
        assert "item1" in fetched

    async def test_not_connected_error(self):
        disconnected = AsyncService()
        with pytest.raises(RuntimeError, match="Not connected"):
            await disconnected.fetch_data()


class AsyncMockHelper:
    """Factories producing tiny async stand-ins for mocking async code."""

    @staticmethod
    def async_return(value):
        """Build a coroutine function that accepts any arguments and returns ``value``."""
        async def _async_return(*args, **kwargs):
            return value

        return _async_return

    @staticmethod
    def async_raise(exception):
        """Build a coroutine function that raises ``exception`` when awaited.

        The same exception instance is raised on every call.
        """
        async def _async_raise(*args, **kwargs):
            raise exception

        return _async_raise

    @staticmethod
    def async_iterator(items):
        """Return an async generator that yields each element of ``items`` in order."""
        async def _async_iterator():
            for element in items:
                yield element

        generator = _async_iterator()
        return generator

41.7.3 测试数据管理

python
import json
import os
import random
from dataclasses import dataclass, field
from typing import Dict, Any, List, Callable

import pytest
from faker import Faker

# Module-wide Faker instance for the Simplified Chinese locale ("zh_CN").
# UserFactory and ProductFactory draw their random field values from it.
# It is not seeded in this snippet, so generated data differs between runs
# unless a seed is set elsewhere.
fake = Faker("zh_CN")


@dataclass
class UserFactory:
    """Factory for fake user records (plain dicts) generated with Faker.

    Keyword arguments override the generated field of the same name; unknown
    keys are added as-is.
    """

    @staticmethod
    def create(**kwargs) -> Dict[str, Any]:
        """Return one user dict.

        ``id`` comes from ``fake.random_int(min=1, max=10000)`` and is not
        guaranteed to be unique across calls.
        """
        generated: Dict[str, Any] = dict(
            id=fake.random_int(min=1, max=10000),
            name=fake.name(),
            email=fake.email(),
            phone=fake.phone_number(),
            address=fake.address(),
            created_at=fake.date_time_this_year().isoformat(),
        )
        return {**generated, **kwargs}

    @staticmethod
    def create_batch(count: int, **kwargs) -> List[Dict]:
        """Return ``count`` users, each built with the same overrides."""
        batch = [UserFactory.create(**kwargs) for _index in range(count)]
        return batch


@dataclass
class ProductFactory:
    """Factory for fake product records (plain dicts) generated with Faker.

    Keyword arguments override the generated field of the same name; unknown
    keys are added as-is.
    """

    @staticmethod
    def create(**kwargs) -> Dict[str, Any]:
        """Return one product dict with a price rounded to 2 decimals."""
        generated: Dict[str, Any] = dict(
            id=fake.random_int(min=1, max=10000),
            name=fake.word().title(),
            price=round(fake.pyfloat(min_value=10, max_value=1000), 2),
            stock=fake.random_int(min=0, max=100),
            category=fake.random_element(["电子", "服装", "食品", "家居"]),
            description=fake.text(max_nb_chars=200),
        )
        return {**generated, **kwargs}

    @staticmethod
    def create_batch(count: int, **kwargs) -> List[Dict]:
        """Return ``count`` products, each built with the same overrides."""
        batch = [ProductFactory.create(**kwargs) for _index in range(count)]
        return batch


class TestDataBuilder:
    """Fluent builder that assembles a dict of related test data.

    Every ``with_*`` method stores one entry and returns ``self`` so calls can
    be chained. ``build()`` returns a shallow copy of the collected data.
    """

    # This is a "Test"-prefixed helper, not a test class. Without this flag
    # pytest tries to collect it and emits a PytestCollectionWarning because
    # it defines __init__.
    __test__ = False

    def __init__(self):
        self._data: Dict[str, Any] = {}

    def with_user(self, **kwargs) -> "TestDataBuilder":
        """Store one UserFactory user under ``"user"`` (kwargs override fields)."""
        self._data["user"] = UserFactory.create(**kwargs)
        return self

    def with_products(self, count: int = 1, **kwargs) -> "TestDataBuilder":
        """Store ``count`` ProductFactory products under ``"products"``."""
        self._data["products"] = ProductFactory.create_batch(count, **kwargs)
        return self

    def with_custom(self, key: str, value: Any) -> "TestDataBuilder":
        """Store an arbitrary ``value`` under ``key``."""
        self._data[key] = value
        return self

    def build(self) -> Dict[str, Any]:
        """Return a shallow copy; nested values are shared with the builder."""
        return self._data.copy()


@pytest.fixture
def user_factory():
    """Expose the UserFactory class itself; tests call ``.create``/``.create_batch``."""
    factory = UserFactory
    return factory


@pytest.fixture
def product_factory():
    """Expose the ProductFactory class itself; tests call ``.create``/``.create_batch``."""
    factory = ProductFactory
    return factory


@pytest.fixture
def test_data():
    """Provide a fresh, empty TestDataBuilder for each test."""
    builder = TestDataBuilder()
    return builder


class TestWithDataFactories:
    """Tests exercising the factory fixtures and the fluent data builder."""

    def test_user_creation(self, user_factory):
        created = user_factory.create(name="测试用户")
        assert created["name"] == "测试用户"
        assert "@" in created["email"]

    def test_user_batch(self, user_factory):
        assert len(user_factory.create_batch(10)) == 10

    def test_product_creation(self, product_factory):
        assert product_factory.create(price=99.99)["price"] == 99.99

    def test_with_builder(self, test_data):
        built = (
            test_data
            .with_user(name="Alice")
            .with_products(3)
            .build()
        )

        assert built["user"]["name"] == "Alice"
        assert len(built["products"]) == 3


class TestDataManager:
    """Load and save JSON test-data files under ``data_dir``, caching reads."""

    # This is a "Test"-prefixed helper, not a test class. The flag stops pytest
    # from trying to collect it and warning about its __init__.
    __test__ = False

    def __init__(self, data_dir: str = "test_data"):
        self.data_dir = data_dir
        # filename -> parsed JSON, filled lazily by load_json
        self._cache: Dict[str, Any] = {}

    def _path(self, filename: str) -> str:
        """Return the on-disk path of ``filename`` inside ``data_dir``."""
        return os.path.join(self.data_dir, filename)

    def load_json(self, filename: str) -> Dict:
        """Return the parsed contents of ``filename``, reading the file at most once.

        The cached object itself is returned, so mutating it changes what
        later calls see.
        """
        if filename not in self._cache:
            with open(self._path(filename), "r", encoding="utf-8") as f:
                self._cache[filename] = json.load(f)
        return self._cache[filename]

    def save_json(self, filename: str, data: Dict):
        """Write ``data`` as pretty-printed UTF-8 JSON, creating ``data_dir`` if needed.

        Any cached copy of ``filename`` is dropped, so the next ``load_json``
        re-reads the file instead of returning stale data.
        """
        if self.data_dir:
            os.makedirs(self.data_dir, exist_ok=True)
        with open(self._path(filename), "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        self._cache.pop(filename, None)

    def get_test_case(self, name: str) -> Dict:
        """Return entry ``name`` from ``test_cases.json``, or ``{}`` if it is absent."""
        return self.load_json("test_cases.json").get(name, {})

41.7.4 测试性能优化

python
import pytest
import time
from typing import Callable, List, Dict, Any
from functools import wraps

class TestPerformance:
    """Utilities for measuring and speeding up slow tests.

    NOTE(review): fixtures declared on a class are only visible to test
    methods of that class, and this class defines none. Move them to module
    level or conftest.py if other tests are meant to use them.
    """

    @staticmethod
    def measure_time(func: Callable) -> Callable:
        """Decorator: print the wall-clock duration (``time.perf_counter``) of each call."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            print(f"\n{func.__name__} took {elapsed:.4f} seconds")
            return result
        return wrapper

    @staticmethod
    @pytest.fixture(scope="session")
    def expensive_resource():
        """Session-scoped fixture simulating a costly (1 s) setup done only once."""
        print("\n初始化昂贵资源...")
        time.sleep(1)
        resource = {"data": "expensive"}
        yield resource
        print("\n清理昂贵资源...")

    @staticmethod
    @pytest.fixture
    def cached_resource(request, cache):
        """Return a resource memoised in pytest's cross-run cache.

        ``cache`` is pytest's built-in ``Cache`` object. It is accessed through
        ``get``/``set`` and is not a mapping, so ``in`` and ``[]`` do not work
        on it. Keys use the documented "<namespace>/<name>" form, and values
        must be JSON-serialisable.
        """
        cache_key = f"test_performance/resource_{request.node.name}"
        resource = cache.get(cache_key, None)
        if resource is None:
            resource = {"computed": True}
            cache.set(cache_key, resource)
        return resource


class TestOptimizer:
    """Record tests whose duration exceeds a threshold and suggest optimisations."""

    # This is a "Test"-prefixed helper, not a test class. The flag stops pytest
    # from trying to collect it and warning about its __init__.
    __test__ = False

    def __init__(self, threshold: float = 1.0):
        self._slow_tests: List[str] = []
        # Durations strictly greater than this count as slow. The unit is
        # whatever callers pass to mark_slow (presumably seconds).
        self._threshold = threshold

    def mark_slow(self, test_name: str, duration: float):
        """Record ``test_name`` if ``duration`` exceeds the threshold; otherwise ignore it."""
        if duration > self._threshold:
            self._slow_tests.append(test_name)

    def get_slow_tests(self) -> List[str]:
        """Return a copy of the recorded slow test names, in recording order."""
        return list(self._slow_tests)

    def suggest_optimizations(self) -> List[str]:
        """Return one suggestion line per recorded slow test."""
        return [f"考虑优化测试: {name}" for name in self._slow_tests]


@pytest.fixture
def test_optimizer():
    """Provide a fresh TestOptimizer, constructed with no arguments, for each test."""
    optimizer = TestOptimizer()
    return optimizer


def pytest_configure(config):
    """Register the custom ``slow`` and ``integration`` markers.

    NOTE(review): pytest only invokes this hook from conftest.py or a plugin
    module, so this code is presumably meant to live in conftest.py.
    """
    marker_lines = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "integration: marks tests as integration tests",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)


def pytest_collection_modifyitems(config, items):
    """Skip every test carrying the ``slow`` keyword unless ``--runslow`` was given.

    NOTE(review): ``--runslow`` has to be registered in a ``pytest_addoption``
    hook, e.g. ``parser.addoption("--runslow", action="store_true",
    default=False)``. No such registration is visible here. For an
    unregistered option ``getoption`` just returns ``default=False``, so slow
    tests would always be skipped. Confirm it exists in conftest.py.
    """
    # The option value is the same for every item, so decide once up front.
    if config.getoption("--runslow", default=False):
        return

    skip_slow = pytest.mark.skip(reason="Skipping slow tests")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)

41.8 本章小结

本章详细介绍了Python测试进阶的核心概念和实践:

  1. pytest高级特性:fixture作用域、参数化测试、标记
  2. Mock测试:Mock对象、patch、副作用、断言
  3. 测试替身:Stub、Spy、Fake、Dummy
  4. 测试覆盖率:Coverage.py、覆盖率报告、分支覆盖
  5. 集成测试:数据库测试、API测试
  6. 技术选型:测试框架、Mock工具与覆盖率工具的选择
  7. 常见问题与解决方案:测试隔离、异步代码测试、测试数据管理、测试性能优化

练习题

  1. 实现一个测试框架,支持fixture依赖注入和参数化
  2. 开发一个mock库,支持方法调用记录和返回值设置
  3. 实现一个覆盖率分析工具,生成HTML报告
  4. 开发一个测试数据生成器,支持多种数据类型
  5. 实现一个测试报告生成器,支持多种输出格式

扩展阅读

Python技术丛书 - 江苏省宿城中等专业学校