Skip to content

第40章 API开发进阶

学习目标

完成本章学习后,你将能够:

  1. 设计RESTful API:资源设计、URL规范、HTTP方法、状态码
  2. 使用FastAPI:路由、依赖注入、请求验证、响应模型
  3. 使用Django REST Framework:序列化器、视图集、路由器、权限
  4. 实现API认证:JWT认证、OAuth2、API密钥、权限控制
  5. 实现GraphQL API:Schema设计、Query、Mutation、Subscription
  6. 编写API文档:OpenAPI规范、Swagger UI、ReDoc
  7. 实现API限流:限流策略、缓存、版本控制
  8. 处理API错误:错误处理、日志记录、监控告警

40.1 RESTful API设计

40.1.1 设计原则

python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, TypeVar, Generic
from enum import Enum
from datetime import datetime
import json


# Type variable for the payload type carried by generic containers such as
# APIResponse (declared as Generic[T] below).
T = TypeVar("T")


class HTTPMethod(Enum):
    """HTTP request methods supported by the REST helpers.

    Each member's value is the upper-case method name, which is what
    ``RESTfulResource.add_endpoint`` stores via ``method.value``.
    """
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    PATCH = "PATCH"
    DELETE = "DELETE"


class HTTPStatus(Enum):
    """HTTP status codes used by the response helpers in this module.

    Members are a plain ``Enum`` (not ``IntEnum``), so ``HTTPStatus.OK == 200``
    is False; callers read the numeric code through ``.value``.

    NOTE(review): the standard library already provides ``http.HTTPStatus``
    with these members; this local Enum uses the same name.
    """
    # 2xx: success
    OK = 200
    CREATED = 201
    NO_CONTENT = 204
    # 4xx: client errors
    BAD_REQUEST = 400
    UNAUTHORIZED = 401
    FORBIDDEN = 403
    NOT_FOUND = 404
    METHOD_NOT_ALLOWED = 405
    CONFLICT = 409
    UNPROCESSABLE_ENTITY = 422
    # 5xx: server errors
    INTERNAL_SERVER_ERROR = 500
    SERVICE_UNAVAILABLE = 503


@dataclass
class APIResponse(Generic[T]):
    """Uniform response envelope: ``{"success": ..., "data": ..., ...}``.

    Attributes:
        success: Whether the request succeeded.
        data: Payload; emitted whenever it is not None (even if falsy).
        message: Optional human-readable message; omitted when empty.
        errors: Optional list of error dicts; omitted when empty.
        meta: Optional metadata (e.g. pagination); omitted when empty.
        status_code: HTTP status to send with this envelope. It is transport
            metadata, so ``to_dict`` deliberately leaves it out of the body.
    """

    success: bool
    data: Optional[T] = None
    message: Optional[str] = None
    errors: Optional[List[Dict]] = None
    meta: Optional[Dict] = None
    # Declared (with a default, appended last so positional construction is
    # unchanged) because error helpers construct responses with status_code=.
    status_code: int = 200

    def to_dict(self) -> Dict:
        """Return the JSON body, skipping unset optional members."""
        result = {"success": self.success}
        if self.data is not None:
            result["data"] = self.data
        if self.message:
            result["message"] = self.message
        if self.errors:
            result["errors"] = self.errors
        if self.meta:
            result["meta"] = self.meta
        return result


@dataclass
class PaginationMeta:
    """Page-navigation metadata for a paginated collection."""

    page: int
    per_page: int
    total: int
    total_pages: int
    has_next: bool
    has_prev: bool

    @classmethod
    def create(cls, page: int, per_page: int, total: int) -> "PaginationMeta":
        """Build metadata for ``page`` (1-based) of ``total`` items.

        Raises:
            ValueError: if ``per_page`` is not positive. Previously this
                surfaced as ZeroDivisionError (0) or a bogus page count (<0).
        """
        if per_page < 1:
            raise ValueError(f"per_page must be a positive integer, got {per_page}")
        # Ceiling division: a partially filled last page still counts.
        total_pages = (total + per_page - 1) // per_page
        return cls(
            page=page,
            per_page=per_page,
            total=total,
            total_pages=total_pages,
            has_next=page < total_pages,
            has_prev=page > 1
        )


@dataclass
class APIError:
    """A single machine-readable API error.

    ``to_dict`` always emits ``code`` and ``message``; ``field`` and
    ``details`` are added (in that order) only when truthy.
    """

    code: str
    message: str
    field: Optional[str] = None
    details: Optional[Dict] = None

    def to_dict(self) -> Dict:
        payload: Dict = {"code": self.code, "message": self.message}
        extras = (("field", self.field), ("details", self.details))
        payload.update((key, val) for key, val in extras if val)
        return payload


class ErrorHandler:
    """Factories for failed ``APIResponse`` envelopes.

    Each helper returns an ``APIResponse`` with ``success=False`` and a
    ``status_code`` attribute holding the matching HTTP status.
    """

    @staticmethod
    def _failure(status: HTTPStatus, message: str, errors: Optional[List[Dict]] = None) -> APIResponse:
        """Build a failed response and attach ``status.value`` as ``status_code``."""
        response = APIResponse(success=False, message=message, errors=errors)
        # Attach the status after construction so this works whether or not
        # APIResponse declares a ``status_code`` field (passing it as a
        # constructor keyword raises TypeError when the field is absent).
        response.status_code = status.value
        return response

    @staticmethod
    def bad_request(message: str = "Bad request", errors: Optional[List[Dict]] = None) -> APIResponse:
        """400 Bad Request, optionally with per-field ``errors``."""
        return ErrorHandler._failure(HTTPStatus.BAD_REQUEST, message, errors)

    @staticmethod
    def unauthorized(message: str = "Unauthorized") -> APIResponse:
        """401 Unauthorized."""
        return ErrorHandler._failure(HTTPStatus.UNAUTHORIZED, message)

    @staticmethod
    def forbidden(message: str = "Forbidden") -> APIResponse:
        """403 Forbidden."""
        return ErrorHandler._failure(HTTPStatus.FORBIDDEN, message)

    @staticmethod
    def not_found(message: str = "Resource not found") -> APIResponse:
        """404 Not Found."""
        return ErrorHandler._failure(HTTPStatus.NOT_FOUND, message)

    @staticmethod
    def validation_error(errors: List[Dict]) -> APIResponse:
        """422 Unprocessable Entity carrying validation ``errors``."""
        return ErrorHandler._failure(HTTPStatus.UNPROCESSABLE_ENTITY, "Validation error", errors)

    @staticmethod
    def internal_error(message: str = "Internal server error") -> APIResponse:
        """500 Internal Server Error."""
        return ErrorHandler._failure(HTTPStatus.INTERNAL_SERVER_ERROR, message)


class RESTfulResource:
    """Collects the endpoints of one REST resource under ``base_path/name``."""

    def __init__(self, name: str, base_path: str = "/api/v1"):
        self.name = name
        self.base_path = base_path
        # Path -> endpoint config. Keyed by path only, so when several methods
        # share a path the last registration wins here (kept for get_endpoints).
        self._endpoints: Dict[str, Dict] = {}
        # (method, path) -> endpoint record, in registration order. Unlike
        # _endpoints this keeps GET and POST on the same path side by side.
        self._registrations: Dict[tuple, Dict] = {}

    def add_endpoint(self, method: "HTTPMethod", path: str, handler: callable, description: str = "") -> None:
        """Register ``handler`` for ``method`` at ``base_path/name + path``.

        Re-registering the same method and path replaces the earlier entry.
        """
        full_path = f"{self.base_path}/{self.name}{path}"
        config = {
            "method": method.value,
            "handler": handler,
            "description": description
        }
        self._endpoints[full_path] = config
        self._registrations[(config["method"], full_path)] = {"path": full_path, **config}

    def get_endpoints(self) -> Dict[str, Dict]:
        """Return a copy of the path -> config map (last method per path)."""
        return self._endpoints.copy()

    def generate_docs(self) -> List[Dict]:
        """Describe every registered (method, path) pair in registration order."""
        return [
            {
                "path": record["path"],
                "method": record["method"],
                "description": record["description"]
            }
            for record in self._registrations.values()
        ]


class APIVersioning:
    """Registry of API versions and their deprecation/sunset status.

    ``current_version`` is stored for callers; the registry itself only
    tracks versions passed to ``register_version``.
    """

    def __init__(self, current_version: str = "v1"):
        self.current_version = current_version
        self._versions: Dict[str, Dict] = {}

    def register_version(self, version: str, deprecated: bool = False, sunset_date: str = None) -> None:
        """Record (or overwrite) the status of ``version``."""
        self._versions[version] = dict(deprecated=deprecated, sunset_date=sunset_date)

    def get_version_info(self, version: str) -> Optional[Dict]:
        """Return the stored status dict, or None for unknown versions."""
        return self._versions.get(version, None)

    def is_deprecated(self, version: str) -> bool:
        """True if ``version`` was registered as deprecated; unknown -> False."""
        record = self._versions.get(version) or {}
        return record.get("deprecated", False)

40.1.2 请求验证

python
from typing import Any, Callable, List, Optional, Type
from dataclasses import dataclass
import re


@dataclass
class ValidationResult:
    """Outcome of one or more validation checks.

    Attributes:
        is_valid: True when no check failed.
        errors: Human-readable error messages (empty when valid).
    """

    is_valid: bool
    errors: List[str]

    @classmethod
    def valid(cls) -> "ValidationResult":
        """Return a passing result with no errors."""
        return cls(is_valid=True, errors=[])

    @classmethod
    def invalid(cls, errors: List[str]) -> "ValidationResult":
        """Return a failing result carrying ``errors``."""
        return cls(is_valid=False, errors=errors)


class FieldValidator:
    """Stateless single-field checks returning ``ValidationResult``.

    Apart from ``required``, every check passes when there is nothing to
    validate: ``None`` for the numeric/enum checks, any falsy value for the
    string checks. Combine with ``required`` to demand a value.
    """

    # Patterns are applied with fullmatch: re.match only anchors the start and
    # lets "$" match before a trailing newline, so "a@b.com\n" used to pass.
    _EMAIL_PATTERN = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
    _URL_PATTERN = re.compile(r'^https?://[^\s/$.?#].[^\s]*$')

    @staticmethod
    def _matches(pattern, value: Any) -> bool:
        """True if ``value`` is a str that ``pattern`` matches in full.

        Non-str values are reported as non-matching instead of letting the
        regex engine raise TypeError.
        """
        return isinstance(value, str) and re.fullmatch(pattern, value) is not None

    @staticmethod
    def _to_number(value: Any) -> Optional[float]:
        """Return ``value`` as a number, or None if it cannot be converted.

        Real ints/floats are returned unchanged so huge ints compare exactly
        (float() would raise OverflowError for them).
        """
        if isinstance(value, (int, float)):
            return value
        try:
            return float(value)
        except (ValueError, TypeError, OverflowError):
            return None

    @staticmethod
    def required(value: Any, field_name: str) -> ValidationResult:
        """Fail on None or a blank (whitespace-only) string."""
        if value is None or (isinstance(value, str) and not value.strip()):
            return ValidationResult.invalid([f"{field_name} is required"])
        return ValidationResult.valid()

    @staticmethod
    def min_length(value: str, min_len: int, field_name: str) -> ValidationResult:
        """Fail when a non-empty ``value`` is shorter than ``min_len``."""
        if value and len(value) < min_len:
            return ValidationResult.invalid([f"{field_name} must be at least {min_len} characters"])
        return ValidationResult.valid()

    @staticmethod
    def max_length(value: str, max_len: int, field_name: str) -> ValidationResult:
        """Fail when a non-empty ``value`` is longer than ``max_len``."""
        if value and len(value) > max_len:
            return ValidationResult.invalid([f"{field_name} must be at most {max_len} characters"])
        return ValidationResult.valid()

    @staticmethod
    def email(value: str, field_name: str) -> ValidationResult:
        """Fail when a non-empty ``value`` is not a plausible email address."""
        if value and not FieldValidator._matches(FieldValidator._EMAIL_PATTERN, value):
            return ValidationResult.invalid([f"{field_name} must be a valid email"])
        return ValidationResult.valid()

    @staticmethod
    def url(value: str, field_name: str) -> ValidationResult:
        """Fail when a non-empty ``value`` is not an http(s) URL."""
        if value and not FieldValidator._matches(FieldValidator._URL_PATTERN, value):
            return ValidationResult.invalid([f"{field_name} must be a valid URL"])
        return ValidationResult.valid()

    @staticmethod
    def integer(value: Any, field_name: str) -> ValidationResult:
        """Fail unless ``value`` (if not None) represents a whole number."""
        if value is None:
            return ValidationResult.valid()
        failure = ValidationResult.invalid([f"{field_name} must be an integer"])
        if isinstance(value, float):
            # int() would silently truncate 3.7 -> 3 and raise OverflowError
            # for inf; only integral finite floats are integers.
            return ValidationResult.valid() if value.is_integer() else failure
        try:
            int(value)
        except (ValueError, TypeError, OverflowError):
            return failure
        return ValidationResult.valid()

    @staticmethod
    def float_val(value: Any, field_name: str) -> ValidationResult:
        """Fail unless ``value`` (if not None) converts to float."""
        if value is not None:
            try:
                float(value)
            except (ValueError, TypeError, OverflowError):
                return ValidationResult.invalid([f"{field_name} must be a number"])
        return ValidationResult.valid()

    @staticmethod
    def min_value(value: Any, min_val: float, field_name: str) -> ValidationResult:
        """Fail when a numeric ``value`` is below ``min_val``.

        Non-numeric values pass; type checking is the job of ``integer`` /
        ``float_val``.
        """
        number = FieldValidator._to_number(value)
        if number is not None and number < min_val:
            return ValidationResult.invalid([f"{field_name} must be at least {min_val}"])
        return ValidationResult.valid()

    @staticmethod
    def max_value(value: Any, max_val: float, field_name: str) -> ValidationResult:
        """Fail when a numeric ``value`` exceeds ``max_val`` (non-numeric passes)."""
        number = FieldValidator._to_number(value)
        if number is not None and number > max_val:
            return ValidationResult.invalid([f"{field_name} must be at most {max_val}"])
        return ValidationResult.valid()

    @staticmethod
    def in_list(value: Any, allowed: List[Any], field_name: str) -> ValidationResult:
        """Fail when ``value`` (if not None) is not one of ``allowed``."""
        if value is not None and value not in allowed:
            return ValidationResult.invalid([f"{field_name} must be one of: {allowed}"])
        return ValidationResult.valid()

    @staticmethod
    def regex(value: str, pattern: str, field_name: str) -> ValidationResult:
        """Fail when a non-empty ``value`` does not match ``pattern`` in full."""
        if value and not FieldValidator._matches(pattern, value):
            return ValidationResult.invalid([f"{field_name} format is invalid"])
        return ValidationResult.valid()


class RequestValidator:
    """Apply per-field validator callables to a request payload.

    Each validator is called as ``validator(value, field_name)`` and must
    return a ``ValidationResult``; errors from all failing validators are
    concatenated in registration order.
    """

    def __init__(self):
        self._rules: Dict[str, List[Callable]] = {}

    def add_rule(self, field: str, validator: Callable) -> "RequestValidator":
        """Attach ``validator`` to ``field``; returns self for chaining."""
        self._rules.setdefault(field, []).append(validator)
        return self

    def validate(self, data: Dict) -> ValidationResult:
        """Run every rule against ``data`` (missing fields are seen as None)."""
        messages: List[str] = []
        for name, checks in self._rules.items():
            candidate = data.get(name)
            outcomes = (check(candidate, name) for check in checks)
            for outcome in outcomes:
                if not outcome.is_valid:
                    messages += outcome.errors
        return ValidationResult.invalid(messages) if messages else ValidationResult.valid()


class SchemaValidator:
    """Validate a dict against a declarative per-field schema.

    ``schema`` maps each field name to a rule dict. Supported keys:
    ``required``, ``type`` (string/integer/float/email/url), ``min_length``,
    ``max_length``, ``min_value``, ``max_value``, ``enum``, ``pattern``.
    Missing optional fields (value None) are not checked further.
    """

    def __init__(self, schema: Dict):
        self.schema = schema

    @staticmethod
    def _check_string(value: Any, field: str) -> ValidationResult:
        """Fail unless ``value`` is a str.

        The "string" type used to accept anything, which let later length or
        pattern rules receive non-str values (len(5) raises TypeError).
        """
        if isinstance(value, str):
            return ValidationResult.valid()
        return ValidationResult.invalid([f"{field} must be a string"])

    def validate(self, data: Dict) -> ValidationResult:
        """Validate ``data`` and return all collected errors."""
        type_validators = {
            "string": self._check_string,
            "integer": FieldValidator.integer,
            "float": FieldValidator.float_val,
            "email": FieldValidator.email,
            "url": FieldValidator.url
        }
        # Constraint rules, applied in this order after the type check.
        constraint_checks = (
            ("min_length", FieldValidator.min_length),
            ("max_length", FieldValidator.max_length),
            ("min_value", FieldValidator.min_value),
            ("max_value", FieldValidator.max_value),
            ("enum", FieldValidator.in_list),
            ("pattern", FieldValidator.regex),
        )
        errors: List[str] = []

        for field, rules in self.schema.items():
            value = data.get(field)

            if rules.get("required", False):
                result = FieldValidator.required(value, field)
                if not result.is_valid:
                    errors.extend(result.errors)
                    continue

            if value is None:
                continue

            type_check = type_validators.get(rules.get("type"))
            if type_check is not None:
                result = type_check(value, field)
                if not result.is_valid:
                    errors.extend(result.errors)
                    # Constraints assume the declared type; running them on a
                    # mistyped value can crash or produce misleading errors.
                    continue

            for rule_key, check in constraint_checks:
                if rule_key in rules:
                    result = check(value, rules[rule_key], field)
                    if not result.is_valid:
                        errors.extend(result.errors)

        if errors:
            return ValidationResult.invalid(errors)
        return ValidationResult.valid()

40.2 FastAPI进阶

40.2.1 依赖注入

python
from typing import Optional, List, Dict, Any, Callable
from dataclasses import dataclass
from functools import wraps


@dataclass
class User:
    """Authenticated user record used by the dependency helpers.

    ``role`` is compared by exact string equality in ``require_role``.
    """
    id: int
    username: str
    email: str
    role: str
    # NOTE(review): not consulted by require_auth/require_role in this snippet.
    is_active: bool = True


@dataclass
class Request:
    """Minimal framework-agnostic HTTP request model."""
    method: str
    path: str
    # Header names are looked up by exact key (e.g. "Authorization").
    headers: Dict[str, str]
    # Raw query-string values; parsed by PaginationParams / FilterParams.
    query_params: Dict[str, str]
    # Parsed request body, if any.
    body: Optional[Dict] = None


class DependencyContainer:
    """Service locator holding ready instances and lazily built singletons.

    ``register`` stores an instance directly. ``register_factory`` stores a
    zero-argument callable that ``get`` invokes on first lookup; the result
    is then cached as the instance for all later lookups.
    """

    def __init__(self):
        self._dependencies: Dict[str, Any] = {}
        self._factories: Dict[str, Callable] = {}

    def register(self, name: str, instance: Any) -> None:
        self._dependencies[name] = instance

    def register_factory(self, name: str, factory: Callable) -> None:
        self._factories[name] = factory

    def get(self, name: str) -> Any:
        """Return the instance for ``name``, building it from its factory once.

        Raises:
            KeyError: if ``name`` has neither an instance nor a factory.
        """
        if name not in self._dependencies and name not in self._factories:
            raise KeyError(f"Dependency '{name}' not found")
        if name not in self._dependencies:
            self._dependencies[name] = self._factories[name]()
        return self._dependencies[name]

    def has(self, name: str) -> bool:
        return any(name in registry for registry in (self._dependencies, self._factories))


class Depends:
    """Dependency marker that calls ``dependency(request, container)``.

    With ``use_cache=True`` the result is memoized for as long as the given
    request object is alive, so several consumers within one request share
    a single evaluation.
    """

    def __init__(self, dependency: Callable, use_cache: bool = True):
        self.dependency = dependency
        self.use_cache = use_cache
        # id(request) -> (weak reference to that request, cached result).
        # The weak reference guards against id() reuse: CPython recycles ids
        # once an object is freed, so keying on id() alone could hand one
        # request's result (e.g. its authenticated user) to a later request,
        # and entries would never be removed.
        self._cache: Dict[int, Any] = {}

    @staticmethod
    def _make_evictor(cache: Dict[int, Any], key: int) -> Callable:
        """Return a weakref callback that drops ``key`` once its request dies."""
        def evict(dead_ref) -> None:
            entry = cache.get(key)
            # Only remove the entry that belongs to the dead request.
            if entry is not None and entry[0] is dead_ref:
                del cache[key]
        return evict

    def __call__(self, request: "Request", container: "DependencyContainer") -> Any:
        if not self.use_cache:
            return self.dependency(request, container)

        import weakref

        key = id(request)
        entry = self._cache.get(key)
        if entry is not None and entry[0]() is request:
            return entry[1]

        result = self.dependency(request, container)
        try:
            ref = weakref.ref(request, self._make_evictor(self._cache, key))
        except TypeError:
            # Objects that cannot be weakly referenced are not cached rather
            # than risking an entry that outlives (and is confused with) them.
            return result
        self._cache[key] = (ref, result)
        return result


def get_current_user(request: "Request", container: "DependencyContainer") -> "Optional[User]":
    """Resolve the user from an ``Authorization: Bearer <token>`` header.

    The header name and the "Bearer" scheme are matched case-insensitively
    (HTTP header names and auth schemes are case-insensitive). Returns None
    when the header is missing, uses another scheme, or carries no token;
    otherwise returns ``container.get("user_service").verify_token(token)``.
    """
    headers = request.headers
    auth_header = headers.get("Authorization")
    if auth_header is None:
        auth_header = next(
            (value for key, value in headers.items() if key.lower() == "authorization"),
            ""
        )

    scheme, _, token = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None

    token = token.strip()
    if not token:
        # Do not hand an empty credential to the token verifier.
        return None

    user_service = container.get("user_service")
    return user_service.verify_token(token)


def require_auth(user: User = Depends(get_current_user)) -> User:
    """Return the authenticated ``user`` or raise PermissionError.

    ``user`` is meant to be supplied by resolving the ``Depends`` default.
    If nothing resolves it (the function is called directly without a user),
    the default is the unresolved ``Depends`` marker itself. That object is
    not None and must not count as an authenticated user.

    Raises:
        PermissionError: if no user was resolved.
    """
    if user is None or isinstance(user, Depends):
        raise PermissionError("Authentication required")
    return user


def require_role(role: str):
    """Decorator factory restricting a handler to users whose role is ``role``.

    The wrapped function receives the checked user as the ``user`` keyword.

    Raises (from the wrapper):
        PermissionError: if no user was resolved, or the role differs.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(user: User = Depends(get_current_user), *args, **kwargs):
            # An unresolved Depends default means no user was injected; treat
            # it as unauthenticated instead of failing on ``.role`` access.
            if user is None or isinstance(user, Depends):
                raise PermissionError("Authentication required")
            if user.role != role:
                raise PermissionError(f"Role '{role}' required")
            return func(*args, user=user, **kwargs)
        return wrapper
    return decorator


class PaginationParams:
    """Normalized offset/limit pagination parameters.

    ``page`` is clamped to >= 1 and ``per_page`` to ``[1, max_per_page]``;
    ``offset``/``limit`` are derived for database queries.
    """

    def __init__(self, page: int = 1, per_page: int = 20, max_per_page: int = 100):
        self.page = max(1, page)
        self.per_page = min(max(1, per_page), max_per_page)
        self.offset = (self.page - 1) * self.per_page
        self.limit = self.per_page

    @staticmethod
    def _parse_int(raw: Any, default: int) -> int:
        """Parse a query-string value, falling back to ``default`` if invalid."""
        try:
            return int(raw)
        except (TypeError, ValueError):
            return default

    @classmethod
    def from_request(
        cls,
        request: "Request",
        default_per_page: int = 20,
        max_per_page: int = 100
    ) -> "PaginationParams":
        """Build params from ``page``/``per_page`` query parameters.

        Malformed values (e.g. ``page=abc``) fall back to the defaults
        instead of raising ValueError, since query strings are user input.
        """
        query = request.query_params
        page = cls._parse_int(query.get("page", 1), 1)
        per_page = cls._parse_int(query.get("per_page", default_per_page), default_per_page)
        return cls(page=page, per_page=per_page, max_per_page=max_per_page)


class FilterParams:
    """Filter, sort and search criteria parsed from query parameters.

    Sort entries use the ``-field`` convention for descending order.
    """

    def __init__(self):
        self._filters: Dict[str, Any] = {}
        self._sort: List[str] = []
        self._search: Optional[str] = None

    def add_filter(self, field: str, value: Any) -> None:
        self._filters[field] = value

    def add_sort(self, field: str, descending: bool = False) -> None:
        prefix = "-" if descending else ""
        self._sort.append(f"{prefix}{field}")

    def set_search(self, query: str) -> None:
        self._search = query

    def to_dict(self) -> Dict:
        """Return only the criteria that were set (filter/sort/search)."""
        result = {}
        if self._filters:
            result["filter"] = self._filters
        if self._sort:
            result["sort"] = self._sort
        if self._search:
            result["search"] = self._search
        return result

    @classmethod
    def from_request(cls, request: "Request", allowed_fields: Optional[List[str]] = None) -> "FilterParams":
        """Parse ``sort``, ``search`` and whitelisted filter parameters.

        ``sort`` may list several comma-separated fields (``sort=name,-age``).
        Filters are accepted only for ``allowed_fields``. When a whitelist is
        given, sort fields are restricted to it as well, so untrusted input
        cannot order by arbitrary columns. Without a whitelist, any sort field
        is accepted.
        """
        params = cls()
        allowed = set(allowed_fields or [])

        for key, value in request.query_params.items():
            if key == "sort":
                for token in value.split(","):
                    token = token.strip()
                    descending = token.startswith("-")
                    # Strip exactly one leading "-" (lstrip removed all of them).
                    name = token[1:] if descending else token
                    if not name:
                        continue
                    if allowed and name not in allowed:
                        continue
                    params.add_sort(name, descending)
            elif key == "search":
                params.set_search(value)
            elif key in allowed:
                params.add_filter(key, value)

        return params

40.2.2 响应模型

python
from typing import Generic, TypeVar, List, Optional, Dict, Any
from dataclasses import dataclass, asdict
from abc import ABC, abstractmethod
import json


# Generic type variable for response payloads. Not referenced by the
# definitions that follow in this section.
T = TypeVar("T")


@dataclass
class ResponseModel(ABC):
    """Abstract base for typed response payloads.

    Subclasses supply ``to_dict``; ``to_json`` serializes that dict with
    ``json.dumps`` using its default settings.
    """

    @abstractmethod
    def to_dict(self) -> Dict:
        """Return the payload as a JSON-serializable dict."""

    def to_json(self) -> str:
        """Return ``to_dict()`` encoded as a JSON string."""
        payload = self.to_dict()
        return json.dumps(payload)


@dataclass
class UserResponse(ResponseModel):
    """Public representation of a user account."""

    id: int
    username: str
    email: str
    role: str
    created_at: str

    def to_dict(self) -> Dict:
        """Return the user's public fields, in declaration order."""
        keys = ("id", "username", "email", "role", "created_at")
        return {key: getattr(self, key) for key in keys}


@dataclass
class UserListResponse(ResponseModel):
    """A page of users plus its pagination metadata.

    The pagination object has the same keys as ``ResponseBuilder.paginated``
    (including ``has_next``/``has_prev``), so clients see one pagination
    shape across endpoints.
    """

    users: List[UserResponse]
    pagination: PaginationMeta

    def to_dict(self) -> Dict:
        meta = self.pagination
        return {
            "users": [u.to_dict() for u in self.users],
            "pagination": {
                "page": meta.page,
                "per_page": meta.per_page,
                "total": meta.total,
                "total_pages": meta.total_pages,
                "has_next": meta.has_next,
                "has_prev": meta.has_prev
            }
        }


@dataclass
class ErrorResponse(ResponseModel):
    """Error payload: ``code`` and ``message``, plus ``details`` when truthy."""

    code: str
    message: str
    details: Optional[Dict] = None

    def to_dict(self) -> Dict:
        body = dict(code=self.code, message=self.message)
        if self.details:
            body.update(details=self.details)
        return body


class ResponseBuilder:
    """Build plain-dict response envelopes.

    Success bodies look like ``{"success": True, "data": ...}``. Failure
    bodies look like ``{"success": False, "error": {"code", "message", ...}}``.
    The ``status`` arguments are accepted but are not written into the body.
    """

    @staticmethod
    def success(data: Any, message: str = None, status: int = 200) -> Dict:
        """Success envelope; ``message`` is included only when truthy."""
        envelope = dict(success=True, data=data)
        if message:
            envelope["message"] = message
        return envelope

    @staticmethod
    def error(code: str, message: str, details: Dict = None, status: int = 400) -> Dict:
        """Failure envelope; ``details`` is included only when truthy."""
        problem = {"code": code, "message": message}
        if details:
            problem["details"] = details
        return {"success": False, "error": problem}

    @staticmethod
    def paginated(
        items: List[Any],
        page: int,
        per_page: int,
        total: int,
        item_key: str = "items"
    ) -> Dict:
        """Success envelope holding ``items`` under ``item_key`` plus pagination."""
        meta = PaginationMeta.create(page, per_page, total)
        pagination = {
            name: getattr(meta, name)
            for name in ("page", "per_page", "total", "total_pages", "has_next", "has_prev")
        }
        return {"success": True, "data": {item_key: items, "pagination": pagination}}

    @staticmethod
    def created(data: Any, message: str = "Resource created successfully") -> Dict:
        """Envelope for a newly created resource; ``message`` is always present."""
        return dict(success=True, data=data, message=message)

    @staticmethod
    def no_content() -> Dict:
        """Minimal success envelope with no data."""
        return dict(success=True)

    @staticmethod
    def validation_error(errors: List[Dict]) -> Dict:
        """Failure envelope wrapping ``errors`` under ``details.errors``."""
        problem = {
            "code": "VALIDATION_ERROR",
            "message": "Validation failed",
            "details": {"errors": errors},
        }
        return {"success": False, "error": problem}

40.3 Django REST Framework

40.3.1 序列化器

python
from typing import Dict, List, Any, Optional, Type
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import json


class Field(ABC):
    """Abstract serializer field.

    Subclasses implement the two conversion hooks. ``validate`` enforces the
    ``required``/``allow_null`` flags for ``None`` values. The field learns
    its attribute name through ``bind``, which the serializer metaclass calls.
    """

    def __init__(
        self,
        required: bool = True,
        allow_null: bool = False,
        default: Any = None,
        help_text: str = ""
    ):
        self.required, self.allow_null = required, allow_null
        self.default, self.help_text = default, help_text
        self._name = None

    def bind(self, name: str) -> None:
        """Record the attribute name used in error messages."""
        self._name = name

    @abstractmethod
    def to_internal_value(self, data: Any) -> Any:
        """Convert raw input into the Python value to store."""

    @abstractmethod
    def to_representation(self, value: Any) -> Any:
        """Convert a stored value into its output form."""

    def validate(self, value: Any) -> Any:
        """Return ``value`` unchanged, rejecting a disallowed ``None``.

        Raises:
            ValueError: "<name> is required" for None on a required field,
                otherwise "<name> cannot be null" unless ``allow_null``.
        """
        if value is not None:
            return value
        if self.required:
            raise ValueError(f"{self._name} is required")
        if not self.allow_null:
            raise ValueError(f"{self._name} cannot be null")
        return value


class CharField(Field):
    """String field with optional length bounds (inclusive)."""

    def __init__(self, max_length: int = None, min_length: int = None, **kwargs):
        super().__init__(**kwargs)
        self.max_length = max_length
        self.min_length = min_length

    def to_internal_value(self, data: Any) -> str:
        """Coerce scalar ``data`` to str and enforce the length bounds.

        Raises:
            ValueError: for container input or a length outside the bounds.
        """
        if data is None:
            return None
        if isinstance(data, (dict, list, tuple, set)):
            # str() would silently store a Python repr such as "{'a': 1}".
            raise ValueError(f"{self._name} must be a string")
        value = str(data)
        # Compare against None so a bound of 0 is still enforced.
        if self.max_length is not None and len(value) > self.max_length:
            raise ValueError(f"{self._name} exceeds max length of {self.max_length}")
        if self.min_length is not None and len(value) < self.min_length:
            raise ValueError(f"{self._name} below min length of {self.min_length}")
        return value

    def to_representation(self, value: Any) -> str:
        return str(value) if value is not None else None


class IntegerField(Field):
    """Integer field with optional inclusive ``min_value``/``max_value`` bounds."""

    def __init__(self, min_value: int = None, max_value: int = None, **kwargs):
        super().__init__(**kwargs)
        self.min_value = min_value
        self.max_value = max_value

    def to_internal_value(self, data: Any) -> int:
        """Convert ``data`` to int and check the bounds.

        Raises:
            ValueError: if ``data`` is not a whole number or is out of range.
        """
        if data is None:
            return None
        if isinstance(data, float) and not data.is_integer():
            # int() would truncate 3.7 to 3 and raise OverflowError for inf.
            raise ValueError(f"{self._name} must be an integer")
        try:
            value = int(data)
        except (ValueError, TypeError, OverflowError):
            raise ValueError(f"{self._name} must be an integer") from None

        if self.min_value is not None and value < self.min_value:
            raise ValueError(f"{self._name} must be >= {self.min_value}")
        if self.max_value is not None and value > self.max_value:
            raise ValueError(f"{self._name} must be <= {self.max_value}")

        return value

    def to_representation(self, value: Any) -> int:
        return int(value) if value is not None else None


class BooleanField(Field):
    """Boolean field accepting bools and common textual spellings."""

    # Compared after strip() + lower().
    _TRUE_STRINGS = ("true", "1", "yes", "y", "t", "on")
    _FALSE_STRINGS = ("false", "0", "no", "n", "f", "off", "")

    def to_internal_value(self, data: Any) -> bool:
        """Convert ``data`` to bool.

        Raises:
            ValueError: for a string that is not a recognised true/false
                spelling. Such strings were previously coerced to False silently.
        """
        if data is None:
            return None
        if isinstance(data, bool):
            return data
        if isinstance(data, str):
            text = data.strip().lower()
            if text in self._TRUE_STRINGS:
                return True
            if text in self._FALSE_STRINGS:
                return False
            raise ValueError(f"{self._name} must be a boolean")
        return bool(data)

    def to_representation(self, value: Any) -> bool:
        return bool(value) if value is not None else None


class DateTimeField(Field):
    """Datetime field stored and rendered as text in ``format`` (strftime)."""

    def __init__(self, format: str = "%Y-%m-%d %H:%M:%S", **kwargs):
        super().__init__(**kwargs)
        self.format = format

    def to_internal_value(self, data: Any) -> str:
        """Normalize ``data`` to a string in ``self.format``.

        Accepts ``datetime`` objects, strings matching ``self.format``, or
        ISO 8601 strings.

        Raises:
            ValueError: for anything that is not a parseable datetime. Input
                was previously passed through ``str()`` unvalidated.
        """
        if data is None:
            return None
        from datetime import datetime
        if isinstance(data, datetime):
            return data.strftime(self.format)
        if isinstance(data, str):
            text = data.strip()
            try:
                parsed = datetime.strptime(text, self.format)
            except ValueError:
                try:
                    parsed = datetime.fromisoformat(text)
                except ValueError:
                    raise ValueError(
                        f"{self._name} must be a datetime in format '{self.format}' or ISO 8601"
                    ) from None
            return parsed.strftime(self.format)
        raise ValueError(f"{self._name} must be a datetime or string")

    def to_representation(self, value: Any) -> str:
        """Render datetimes with ``self.format``; other values via ``str()``."""
        if value is None:
            return None
        from datetime import datetime
        if isinstance(value, datetime):
            return value.strftime(self.format)
        return str(value)


class ListField(Field):
    """List field; each item is converted by ``child`` when one is given."""

    def __init__(self, child: Field = None, **kwargs):
        super().__init__(**kwargs)
        self.child = child

    def bind(self, name: str) -> None:
        """Bind this field and give the child a derived name.

        The serializer only binds top-level fields, so without this the
        child's errors would read "None must be ...".
        """
        super().bind(name)
        if self.child:
            self.child.bind(f"{name} item")

    def to_internal_value(self, data: Any) -> List:
        """Convert a list/tuple, delegating each item to ``child`` if set.

        Raises:
            ValueError: if ``data`` is not a list/tuple, or a child rejects an item.
        """
        if data is None:
            return None
        if not isinstance(data, (list, tuple)):
            raise ValueError(f"{self._name} must be a list")

        if self.child:
            return [self.child.to_internal_value(item) for item in data]
        return list(data)

    def to_representation(self, value: Any) -> List:
        if value is None:
            return None
        if self.child:
            return [self.child.to_representation(item) for item in value]
        return list(value)


class SerializerMeta(type):
    """Metaclass that collects ``Field`` attributes into ``cls._fields``.

    Fields declared on base serializers are inherited (bases are merged from
    the last to the first, so earlier bases win), and fields declared in the
    class body are bound to their attribute names and override inherited
    ones. Previously only the class's own namespace was collected, so
    subclassing a serializer silently dropped all parent fields.
    """

    def __new__(mcs, name, bases, namespace):
        cls = super().__new__(mcs, name, bases, namespace)
        fields = {}
        for base in reversed(bases):
            inherited = getattr(base, "_fields", None)
            if isinstance(inherited, dict):
                fields.update(inherited)
        for attr_name, value in namespace.items():
            if isinstance(value, Field):
                value.bind(attr_name)
                fields[attr_name] = value
        cls._fields = fields
        return cls


class BaseSerializer(metaclass=SerializerMeta):
    """Minimal DRF-style serializer over the ``Field`` classes above.

    Pass ``instance`` (and ``many=True`` for an iterable) to serialize, or
    ``data`` to validate with ``is_valid()``.
    """

    def __init__(self, instance=None, data=None, many: bool = False):
        self.instance = instance
        self._data = data
        self.many = many
        self._errors: Dict[str, List[str]] = {}

    def to_internal_value(self, data: Dict) -> Dict:
        """Validate and convert raw input ``data`` field by field.

        Errors are collected per field in ``self._errors`` instead of being
        raised. A non-required field that is absent from ``data`` uses its
        ``default`` when one is set. Otherwise it is left out of the result,
        rather than being validated as None (which ``Field.validate`` rejects
        unless ``allow_null`` is set).
        """
        result = {}
        for name, field in self._fields.items():
            if name not in data and not field.required:
                if field.default is None:
                    continue
                value = field.default
            else:
                value = data.get(name, field.default)
            try:
                validated = field.validate(value)
                result[name] = field.to_internal_value(validated)
            except ValueError as e:
                self._errors.setdefault(name, []).append(str(e))
        return result

    def to_representation(self, instance) -> Dict:
        """Serialize ``instance`` attributes through their fields."""
        result = {}
        for name, field in self._fields.items():
            value = getattr(instance, name, None)
            result[name] = field.to_representation(value)
        return result

    def is_valid(self) -> bool:
        """Validate ``data``; returns False when no data was supplied."""
        self._errors = {}
        if self._data is None:
            return False
        self._validated_data = self.to_internal_value(self._data)
        return len(self._errors) == 0

    @property
    def errors(self) -> Dict:
        return self._errors

    @property
    def validated_data(self) -> Dict:
        return getattr(self, "_validated_data", {})

    def data(self) -> Any:
        """Serialized output for ``instance`` (a list when ``many`` is set)."""
        if self.many:
            return [self.to_representation(item) for item in self.instance]
        return self.to_representation(self.instance)

    def save(self) -> Any:
        raise NotImplementedError("Subclasses must implement save()")


class ModelSerializer(BaseSerializer):
    """Serializer that builds or updates ``model_class`` instances."""

    model_class = None

    def create(self, validated_data: Dict) -> Any:
        """Instantiate ``model_class`` with the validated fields as kwargs."""
        if self.model_class is None:
            raise NotImplementedError("model_class must be defined")
        instance = self.model_class(**validated_data)
        return instance

    def update(self, instance: Any, validated_data: Dict) -> Any:
        """Assign each validated field onto ``instance``."""
        for name, value in validated_data.items():
            setattr(instance, name, value)
        return instance

    def save(self) -> Any:
        """Create (no instance) or update (existing instance) from validated data.

        Raises:
            RuntimeError: if ``is_valid()`` has not succeeded first. Without this
                check, unvalidated or invalid input was persisted (an empty or
                partial ``validated_data``).
        """
        if not hasattr(self, "_validated_data") or self.errors:
            raise RuntimeError("save() requires a successful is_valid() call first")
        if self.instance is None:
            self.instance = self.create(self.validated_data)
        else:
            self.instance = self.update(self.instance, self.validated_data)
        return self.instance

40.3.2 视图集与路由

python
from typing import Dict, List, Any, Optional, Callable, Type
from dataclasses import dataclass
from abc import ABC, abstractmethod
from functools import wraps


@dataclass
class Route:
    """One registered URL route, as produced by ``Router.register``."""
    # URL template, e.g. "/api/users/{pk}".
    path: str
    # HTTP method name, e.g. "GET".
    method: str
    # Bound view-set action invoked for this route.
    handler: Callable
    # Route name, e.g. "users-detail".
    name: str = ""


class ViewSet(ABC):
    """Base class declaring the standard resource actions.

    Every action raises NotImplementedError until a subclass overrides it.
    """

    serializer_class: Type[BaseSerializer] = None
    # NOTE: class-level list shared by subclasses that do not reassign it;
    # assign a new list in subclasses rather than appending to this one.
    permission_classes: List[Callable] = []

    def __init__(self):
        self._routes: List[Route] = []

    def list(self, request: Request) -> Dict:
        """Return the collection."""
        raise NotImplementedError

    def create(self, request: Request) -> Dict:
        """Create a resource from ``request.body``."""
        raise NotImplementedError

    def retrieve(self, request: Request, pk: int) -> Dict:
        """Return the resource identified by ``pk``."""
        raise NotImplementedError

    def update(self, request: Request, pk: int) -> Dict:
        """Replace the resource identified by ``pk``."""
        raise NotImplementedError

    def partial_update(self, request: Request, pk: int) -> Dict:
        """Partially update the resource identified by ``pk``."""
        raise NotImplementedError

    def destroy(self, request: Request, pk: int) -> Dict:
        """Delete the resource identified by ``pk``."""
        raise NotImplementedError


class ModelViewSet(ViewSet):
    """CRUD view set backed by ``model_class`` and ``serializer_class``.

    ``model_class`` is expected to expose a ``query`` object with ``all()``
    and ``get(pk)``, and instances with ``delete()``.
    NOTE(review): ``lookup_field`` is declared but not used by ``get_object``.
    """

    model_class = None
    lookup_field = "id"

    def get_queryset(self):
        if not self.model_class:
            return []
        return self.model_class.query.all()

    def get_object(self, pk: int):
        if not self.model_class:
            return None
        return self.model_class.query.get(pk)

    def get_serializer(self, instance=None, data=None, many: bool = False):
        serializer_cls = self.serializer_class
        return serializer_cls(instance=instance, data=data, many=many)

    @staticmethod
    def _not_found() -> Dict:
        """Standard 404 envelope for a missing ``pk``."""
        return ResponseBuilder.error("NOT_FOUND", "Resource not found", status=404)

    def list(self, request: Request) -> Dict:
        payload = self.get_serializer(self.get_queryset(), many=True).data()
        return ResponseBuilder.success(payload)

    def create(self, request: Request) -> Dict:
        serializer = self.get_serializer(data=request.body)
        if not serializer.is_valid():
            return ResponseBuilder.validation_error(serializer.errors)
        created = serializer.save()
        return ResponseBuilder.created(self.get_serializer(created).data())

    def retrieve(self, request: Request, pk: int) -> Dict:
        target = self.get_object(pk)
        if target is None:
            return self._not_found()
        return ResponseBuilder.success(self.get_serializer(target).data())

    def update(self, request: Request, pk: int) -> Dict:
        target = self.get_object(pk)
        if target is None:
            return self._not_found()
        serializer = self.get_serializer(target, data=request.body)
        if not serializer.is_valid():
            return ResponseBuilder.validation_error(serializer.errors)
        serializer.save()
        return ResponseBuilder.success(serializer.data())

    def destroy(self, request: Request, pk: int) -> Dict:
        target = self.get_object(pk)
        if target is None:
            return self._not_found()
        target.delete()
        return ResponseBuilder.no_content()


class Router:
    """Map a view set's standard actions to REST routes under ``prefix``.

    ``register`` adds, in order: list (GET collection), create (POST
    collection), detail (GET item), update (PUT item), destroy (DELETE item).
    """

    def __init__(self, prefix: str = ""):
        self.prefix = prefix
        self._routes: List[Route] = []

    def register(self, viewset: Type[ViewSet], basename: str) -> None:
        """Instantiate ``viewset`` and register its five standard routes."""
        actions = viewset()
        collection = f"{self.prefix}/{basename}"
        detail = f"{collection}/{{pk}}"
        spec = (
            (collection, "GET", "list", "list"),
            (collection, "POST", "create", "create"),
            (detail, "GET", "retrieve", "detail"),
            (detail, "PUT", "update", "update"),
            (detail, "DELETE", "destroy", "destroy"),
        )
        for path, verb, attr, suffix in spec:
            self._routes.append(Route(
                path=path,
                method=verb,
                handler=getattr(actions, attr),
                name=f"{basename}-{suffix}"
            ))

    def get_routes(self) -> List[Route]:
        """Return the internal route list (not a copy)."""
        return self._routes

    def get_url_patterns(self) -> List[Dict]:
        """Describe each route as ``{"path", "method", "name"}``."""
        patterns = []
        for entry in self._routes:
            patterns.append({"path": entry.path, "method": entry.method, "name": entry.name})
        return patterns

40.4 API认证与授权

40.4.1 JWT认证

python
import hmac
import hashlib
import base64
import json
import time
from typing import Dict, Optional, Any
from dataclasses import dataclass


@dataclass
class JWTConfig:
    """Settings used by ``JWTHandler`` when issuing and checking tokens."""

    # HMAC key used to sign and verify tokens.
    secret_key: str
    # Value written into the token header's "alg" field.
    algorithm: str = "HS256"
    # Lifetime of access tokens, in minutes.
    access_token_expire_minutes: int = 30
    # Lifetime of refresh tokens, in days.
    refresh_token_expire_days: int = 7
    # Value of the "iss" claim that is issued and expected.
    issuer: str = "myapp"
    # Value of the "aud" claim that is issued and expected.
    audience: str = "myapp-users"


class JWTHandler:
    """Minimal JSON Web Token encoder/decoder using HMAC-SHA256.

    Tokens are ``base64url(header).base64url(payload).base64url(signature)``.
    The signature is always HMAC-SHA256 over the first two segments, keyed by
    ``config.secret_key``. ``decode`` returns ``None`` for any token that is
    malformed, wrongly signed, declares a different ``alg``, is expired, or
    carries the wrong issuer/audience.
    """

    def __init__(self, config: "JWTConfig"):
        self.config = config

    def _base64url_encode(self, data: bytes) -> str:
        """Base64url-encode ``data`` with the ``=`` padding stripped (RFC 7515 §2)."""
        return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8")

    def _base64url_decode(self, data: str) -> bytes:
        """Decode unpadded base64url text, restoring the stripped padding first."""
        data += "=" * (-len(data) % 4)
        return base64.urlsafe_b64decode(data)

    def _create_signature(self, header: str, payload: str) -> str:
        """Return the base64url HMAC-SHA256 of ``"{header}.{payload}"``."""
        message = f"{header}.{payload}".encode("utf-8")
        signature = hmac.new(
            self.config.secret_key.encode("utf-8"),
            message,
            hashlib.sha256
        ).digest()
        return self._base64url_encode(signature)

    def encode(self, payload: Dict) -> str:
        """Sign ``payload`` and return the compact JWT string.

        ``iat``, ``iss`` and ``aud`` are always set from the current time and
        the config. ``exp`` defaults to now + ``access_token_expire_minutes``
        unless the caller supplied one. The caller's dict is not modified.
        """
        header = {
            "alg": self.config.algorithm,
            "typ": "JWT"
        }
        header_encoded = self._base64url_encode(
            json.dumps(header).encode("utf-8")
        )

        now = int(time.time())
        # Work on a copy so the registered claims are not written back into the caller's dict.
        claims = dict(payload)
        claims["iat"] = now
        claims["iss"] = self.config.issuer
        claims["aud"] = self.config.audience
        if "exp" not in claims:
            claims["exp"] = now + self.config.access_token_expire_minutes * 60

        payload_encoded = self._base64url_encode(
            json.dumps(claims).encode("utf-8")
        )
        signature = self._create_signature(header_encoded, payload_encoded)
        return f"{header_encoded}.{payload_encoded}.{signature}"

    def decode(self, token: str) -> Optional[Dict]:
        """Verify ``token`` and return its claims, or ``None`` if it is not acceptable."""
        try:
            parts = token.split(".")
            if len(parts) != 3:
                return None
            header_encoded, payload_encoded, signature = parts

            expected_signature = self._create_signature(header_encoded, payload_encoded)
            if not hmac.compare_digest(signature, expected_signature):
                return None

            header = json.loads(self._base64url_decode(header_encoded).decode("utf-8"))
            # Only accept tokens that declare the algorithm this handler issues (RFC 7515 §4.1.1).
            if not isinstance(header, dict) or header.get("alg") != self.config.algorithm:
                return None

            payload = json.loads(self._base64url_decode(payload_encoded).decode("utf-8"))
            if not isinstance(payload, dict):
                return None
        except (ValueError, TypeError, AttributeError):
            # Bad base64/UTF-8/JSON (all ValueError subclasses), non-ASCII signature
            # text passed to compare_digest (TypeError), or a non-str token.
            return None

        exp = payload.get("exp")
        # RFC 7519 §4.1.4: the token must not be accepted on or after its "exp" time.
        if not isinstance(exp, (int, float)) or time.time() >= exp:
            return None
        if payload.get("iss") != self.config.issuer:
            return None
        if payload.get("aud") != self.config.audience:
            return None
        return payload

    def create_access_token(self, user_id: int, extra_data: Optional[Dict] = None) -> str:
        """Issue an access token for ``user_id`` carrying any ``extra_data`` claims.

        ``sub`` and ``type`` are always set by this method and cannot be
        overridden through ``extra_data``.
        """
        payload = dict(extra_data or {})
        payload["sub"] = str(user_id)
        payload["type"] = "access"
        return self.encode(payload)

    def create_refresh_token(self, user_id: int) -> str:
        """Issue a refresh token that expires after ``refresh_token_expire_days`` days."""
        payload = {
            "sub": str(user_id),
            "type": "refresh",
            "exp": int(time.time()) + self.config.refresh_token_expire_days * 86400
        }
        return self.encode(payload)

    def verify_token(self, token: str) -> Optional[Dict]:
        """Return the claims of a valid *access* token; refresh tokens are rejected."""
        payload = self.decode(token)
        if payload is None:
            return None
        if payload.get("type") != "access":
            return None
        return payload


class AuthenticationMiddleware:
    """Attach JWT identity to the request when a valid bearer token is present.

    Requests without a valid access token are passed on unchanged. Enforcing
    authentication is left to later permission checks.
    """

    def __init__(self, jwt_handler: "JWTHandler"):
        self.jwt_handler = jwt_handler

    def authenticate(self, request: "Request") -> Optional[Dict]:
        """Return the verified access-token claims from ``Authorization``, or ``None``.

        The header must look like ``<scheme> <token>``. The scheme ``Bearer``
        is matched case-insensitively (RFC 7235 §2.1), and surrounding
        whitespace is ignored.
        """
        auth_header = request.headers.get("Authorization", "")
        scheme, _, token = auth_header.strip().partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            return None
        return self.jwt_handler.verify_token(token)

    def __call__(self, request: "Request", next_handler: Callable) -> Dict:
        """Populate ``request.user_id``/``request.user_data`` when authenticated, then continue."""
        user_data = self.authenticate(request)
        if user_data:
            request.user_id = user_data.get("sub")
            request.user_data = user_data

        return next_handler(request)

40.4.2 权限控制

python
from typing import List, Callable, Optional
from dataclasses import dataclass
from functools import wraps


@dataclass
class Permission:
    """A named permission paired with a human-readable description."""

    # Identifier of the permission.
    name: str
    # Free-form explanation of what the permission grants.
    description: str


class BasePermission:
    """Permission hook interface whose default implementation allows everything.

    Subclasses override ``has_permission`` (view-level check) and/or
    ``has_object_permission`` (per-object check).
    """

    def has_permission(self, request: Request, view: Any) -> bool:
        """View-level check; the base class always grants access."""
        return True

    def has_object_permission(self, request: Request, view: Any, obj: Any) -> bool:
        """Object-level check; the base class always grants access."""
        return True


class IsAuthenticated(BasePermission):
    """Grant access only when the request carries a non-None ``user_id`` attribute."""

    def has_permission(self, request: Request, view: Any) -> bool:
        return getattr(request, "user_id", None) is not None


class IsAdmin(BasePermission):
    """Grant access only when ``request.user_data["role"]`` equals ``"admin"``."""

    def has_permission(self, request: Request, view: Any) -> bool:
        role = getattr(request, "user_data", {}).get("role")
        return role == "admin"


class IsOwner(BasePermission):
    """Object-level check: the request's ``user_id`` must equal ``obj.user_id``.

    Both ids are compared as strings. ``create_access_token`` stores ``sub``
    as a string, while object ids may be integers.
    """

    def has_object_permission(self, request: Request, view: Any, obj: Any) -> bool:
        requester = getattr(request, "user_id", None)
        owner = getattr(obj, "user_id", None)
        if requester is None:
            return False
        return str(requester) == str(owner)


class AllowAny(BasePermission):
    """Explicit "no restriction" permission: every request passes the view-level check."""

    def has_permission(self, request: Request, view: Any) -> bool:
        return True


class PermissionChecker:
    """Evaluate a list of permission objects; access requires every one of them to pass."""

    def __init__(self, permissions: List[BasePermission]):
        self.permissions = permissions

    def check_permissions(self, request: Request, view: Any) -> bool:
        """True when every ``has_permission`` passes; stops at the first failure."""
        return all(perm.has_permission(request, view) for perm in self.permissions)

    def check_object_permissions(self, request: Request, view: Any, obj: Any) -> bool:
        """True when every ``has_object_permission`` passes; stops at the first failure."""
        return all(perm.has_object_permission(request, view, obj) for perm in self.permissions)


def permission_required(*permission_classes: BasePermission):
    """Decorator that rejects calls failing any of the given permission checks.

    Each argument may be a permission instance (``IsAuthenticated()``) or a
    permission class (``IsAuthenticated``). Classes are instantiated with no
    arguments. The wrapped function must take ``request`` as its first
    positional argument. When a check fails, a 403 ``FORBIDDEN`` error payload
    is returned instead of calling the function.
    """
    # Accept classes as well as instances. Calling ``has_permission`` on a bare
    # class would bind ``request`` to ``self`` and fail with a TypeError.
    permissions = [p() if isinstance(p, type) else p for p in permission_classes]
    # The checker holds no per-request state, so build it once per decoration.
    checker = PermissionChecker(permissions)

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(request: Request, *args, **kwargs):
            if not checker.check_permissions(request, func):
                return ResponseBuilder.error(
                    "FORBIDDEN",
                    "You do not have permission to perform this action",
                    status=403
                )
            return func(request, *args, **kwargs)
        return wrapper
    return decorator


class RoleBasedAccessControl:
    """In-memory registry mapping role names to lists of permission names."""

    def __init__(self):
        self._roles: Dict[str, List[str]] = {}
        self._permissions: Dict[str, Permission] = {}

    def add_permission(self, name: str, description: str = "") -> None:
        """Register (or replace) a permission definition under ``name``."""
        self._permissions[name] = Permission(name=name, description=description)

    def add_role(self, role_name: str, permissions: List[str]) -> None:
        """Define ``role_name`` as granting ``permissions``, replacing any previous definition.

        The list is copied, so later changes to the caller's list do not
        silently alter the role.
        """
        self._roles[role_name] = list(permissions)

    def get_role_permissions(self, role_name: str) -> List[str]:
        """Return a copy of the role's permission names (``[]`` for unknown roles)."""
        return list(self._roles.get(role_name, []))

    def has_permission(self, role_name: str, permission_name: str) -> bool:
        """Whether ``role_name`` grants ``permission_name`` (False for unknown roles)."""
        return permission_name in self._roles.get(role_name, ())

    def check_access(self, user_role: str, required_permission: str) -> bool:
        """Alias of ``has_permission`` phrased from the caller's perspective."""
        return self.has_permission(user_role, required_permission)

40.5 GraphQL API

40.5.1 Schema设计

python
from typing import Dict, List, Any, Optional, Callable, TypeVar, Generic
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from enum import Enum


class GraphQLType(Enum):
    """GraphQL's built-in scalar types; each member's value is its SDL name."""

    STRING = "String"
    INT = "Int"
    FLOAT = "Float"
    BOOLEAN = "Boolean"
    ID = "ID"


@dataclass
class GraphQLField:
    """Definition of one GraphQL field (object/input field, query, or mutation).

    ``type`` is either an SDL type string such as ``"User"`` or ``"[User]"``,
    or an enum member whose ``.value`` is the SDL name (e.g. ``GraphQLType``).
    """

    name: str
    # SDL type string or enum member (rendered via ``.value``).
    type: Any
    # Rendered with a trailing "!" (non-null) when True.
    required: bool = False
    description: str = ""
    # Argument name -> argument type, e.g. {"id": "ID!"}.
    args: Dict[str, Any] = field(default_factory=dict)
    # Optional resolver; the resolvers in this module take (obj, args, ctx).
    resolver: Optional[Callable] = None


@dataclass
class GraphQLObjectType:
    """A GraphQL object type that can render itself as an SDL ``type`` block."""

    name: str
    fields: Dict[str, GraphQLField]
    description: str = ""

    def to_schema(self) -> str:
        """Render ``type Name { ... }`` with one ``name: Type`` line per field.

        String types are emitted verbatim and other types use ``.value``.
        Required fields get a trailing ``!``. Descriptions are not emitted.
        """
        body = []
        for field_name, spec in self.fields.items():
            rendered = spec.type if isinstance(spec.type, str) else spec.type.value
            suffix = "!" if spec.required else ""
            body.append(f"  {field_name}: {rendered}{suffix}")
        return "\n".join([f"type {self.name} {{", *body, "}"])


@dataclass
class GraphQLInputObjectType:
    """A GraphQL input type that can render itself as an SDL ``input`` block."""

    name: str
    fields: Dict[str, GraphQLField]
    description: str = ""

    def to_schema(self) -> str:
        """Render ``input Name { ... }`` with one ``name: Type`` line per field (``!`` when required)."""
        rendered = [f"input {self.name} {{"]
        for key, spec in self.fields.items():
            if isinstance(spec.type, str):
                type_name = spec.type
            else:
                type_name = spec.type.value
            rendered.append(f"  {key}: {type_name}{'!' if spec.required else ''}")
        rendered.append("}")
        return "\n".join(rendered)


class GraphQLSchema:
    """Collects object types, input types, queries and mutations and renders them as SDL."""

    def __init__(self):
        self._types: Dict[str, GraphQLObjectType] = {}
        self._inputs: Dict[str, GraphQLInputObjectType] = {}
        self._queries: Dict[str, GraphQLField] = {}
        self._mutations: Dict[str, GraphQLField] = {}

    def add_type(self, graphql_type: "GraphQLObjectType") -> None:
        self._types[graphql_type.name] = graphql_type

    def add_input(self, input_type: "GraphQLInputObjectType") -> None:
        self._inputs[input_type.name] = input_type

    def add_query(self, field: "GraphQLField") -> None:
        self._queries[field.name] = field

    def add_mutation(self, field: "GraphQLField") -> None:
        self._mutations[field.name] = field

    @staticmethod
    def _type_name(graphql_type: Any) -> str:
        """Render a type reference: strings verbatim, otherwise the enum member's ``.value``."""
        if isinstance(graphql_type, str):
            return graphql_type
        return str(getattr(graphql_type, "value", graphql_type))

    def _render_root(self, root_name: str, fields: Dict[str, "GraphQLField"]) -> str:
        """Render a root operation type (``Query``/``Mutation``) including field arguments."""
        lines = [f"type {root_name} {{"]
        for name, fld in fields.items():
            args_str = ""
            if fld.args:
                # Argument types may be SDL strings or enum members, just like field types.
                args_list = [f"{n}: {self._type_name(t)}" for n, t in fld.args.items()]
                args_str = f"({', '.join(args_list)})"
            type_str = self._type_name(fld.type)
            if fld.required:
                type_str += "!"
            lines.append(f"  {name}{args_str}: {type_str}")
        lines.append("}")
        return "\n".join(lines)

    def to_schema_string(self) -> str:
        """Join types, inputs, ``Query`` and ``Mutation`` blocks with blank lines between them."""
        parts = [graphql_type.to_schema() for graphql_type in self._types.values()]
        parts.extend(input_type.to_schema() for input_type in self._inputs.values())
        if self._queries:
            parts.append(self._render_root("Query", self._queries))
        if self._mutations:
            parts.append(self._render_root("Mutation", self._mutations))
        return "\n\n".join(parts)


class GraphQLResolver:
    """Dispatch ``Type.field`` lookups to registered resolvers, with a default fallback."""

    def __init__(self, schema: "GraphQLSchema"):
        self.schema = schema
        self._resolvers: Dict[str, Callable] = {}

    def register_resolver(self, type_name: str, field_name: str, resolver: Callable) -> None:
        """Register ``resolver(obj, args, context)`` for ``type_name.field_name``."""
        self._resolvers[f"{type_name}.{field_name}"] = resolver

    def resolve(self, type_name: str, field_name: str, obj: Any, args: Dict, context: Dict) -> Any:
        """Resolve one field value.

        A registered resolver wins. Otherwise the value is read from ``obj``:
        dict parents (such as the dicts returned by the resolvers in
        ``create_user_schema``) use ``obj.get(field_name)``, and other objects
        use attribute access. Returns ``None`` when ``obj`` is ``None`` or lacks
        the field.
        """
        resolver = self._resolvers.get(f"{type_name}.{field_name}")
        if resolver is not None:
            return resolver(obj, args, context)
        if obj is None:
            return None
        if isinstance(obj, dict):
            return obj.get(field_name)
        return getattr(obj, field_name, None)


def create_user_schema() -> GraphQLSchema:
    """Build the demo schema: ``User`` type, ``UserInput`` input, the
    ``user``/``users`` queries and the ``createUser`` mutation.

    The resolvers return canned data built only from their arguments.
    """
    schema = GraphQLSchema()

    def string_field(name: str, required: bool = False) -> GraphQLField:
        return GraphQLField(name=name, type=GraphQLType.STRING, required=required)

    schema.add_type(GraphQLObjectType(
        name="User",
        fields={
            "id": GraphQLField(name="id", type=GraphQLType.ID, required=True),
            "username": string_field("username", required=True),
            "email": string_field("email", required=True),
            "role": string_field("role"),
            "createdAt": string_field("createdAt"),
        },
    ))

    schema.add_input(GraphQLInputObjectType(
        name="UserInput",
        fields={
            key: string_field(key, required=True)
            for key in ("username", "email", "password")
        },
    ))

    def resolve_user(obj, args, ctx):
        return {"id": args["id"], "username": "test", "email": "test@example.com"}

    def resolve_users(obj, args, ctx):
        return []

    def resolve_create_user(obj, args, ctx):
        return {"id": "1", **args["input"]}

    schema.add_query(GraphQLField(
        name="user", type="User", args={"id": "ID!"}, resolver=resolve_user
    ))
    schema.add_query(GraphQLField(name="users", type="[User]", resolver=resolve_users))
    schema.add_mutation(GraphQLField(
        name="createUser", type="User", args={"input": "UserInput!"}, resolver=resolve_create_user
    ))

    return schema

40.6 知识图谱

40.6.1 API架构体系

┌─────────────────────────────────────────────────────────────────────┐
│                      API开发技术架构                                  │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      客户端层 (Client)                        │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ Web应用  │ │ 移动应用  │ │ 桌面应用  │ │ IoT设备  │       │   │
│  │  └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘       │   │
│  └───────┼────────────┼────────────┼────────────┼──────────────┘   │
│          └────────────┴────────────┴────────────┘                   │
│                                │ HTTP/HTTPS                         │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      API网关层 (Gateway)                      │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 路由转发  │ │ 认证授权  │ │ 限流熔断  │ │ 日志监控  │       │   │
│  │  │ Kong     │ │ JWT/OAuth│ │ Rate Limit│ │ ELK      │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│          ┌─────────────────────┼─────────────────────┐             │
│          ▼                     ▼                     ▼             │
│  ┌──────────────┐      ┌──────────────┐      ┌──────────────┐     │
│  │  REST API    │      │  GraphQL API │      │  gRPC API    │     │
│  │  ┌────────┐  │      │  ┌────────┐  │      │  ┌────────┐  │     │
│  │  │FastAPI │  │      │  │Strawber│  │      │  │grpcio  │  │     │
│  │  │Django  │  │      │  │Ariadne │  │      │  │protobuf│  │     │
│  │  │Flask   │  │      │  │Graphene│  │      │  │        │  │     │
│  │  └────────┘  │      │  └────────┘  │      │  └────────┘  │     │
│  └───────┬──────┘      └───────┬──────┘      └───────┬──────┘     │
│          │                     │                     │             │
│          └─────────────────────┼─────────────────────┘             │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      服务层 (Services)                        │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ 用户服务  │ │ 订单服务  │ │ 支付服务  │ │ 通知服务  │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      数据层 (Data)                            │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │PostgreSQL│ │ MongoDB  │ │  Redis   │ │Elasticsrch│       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘

40.6.2 RESTful API请求流程

┌─────────────────────────────────────────────────────────────────────┐
│                      RESTful API请求处理流程                         │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│   ┌──────────┐                                                      │
│   │ 客户端   │                                                      │
│   │ Client  │                                                      │
│   └────┬─────┘                                                      │
│        │ 1. HTTP Request                                            │
│        ▼                                                            │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    中间件层 (Middleware)                   │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐    │     │
│   │  │ CORS     │ │ 认证     │ │ 限流     │ │ 日志     │    │     │
│   │  └──────────┘ └──────────┘ └──────────┘ └──────────┘    │     │
│   └──────────────────────────────────────────────────────────┘     │
│        │                                                            │
│        │ 2. 路由匹配                                                │
│        ▼                                                            │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    路由层 (Router)                         │     │
│   │  GET    /api/v1/users     → list_users()                 │     │
│   │  POST   /api/v1/users     → create_user()                │     │
│   │  GET    /api/v1/users/{id}→ get_user()                   │     │
│   │  PUT    /api/v1/users/{id}→ update_user()                │     │
│   │  DELETE /api/v1/users/{id}→ delete_user()                │     │
│   └──────────────────────────────────────────────────────────┘     │
│        │                                                            │
│        │ 3. 请求验证                                                │
│        ▼                                                            │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    验证层 (Validation)                     │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ 参数验证  │ │ 类型检查  │ │ 业务规则  │                 │     │
│   │  │ Pydantic │ │ JSON Schm│ │ 自定义   │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   └──────────────────────────────────────────────────────────┘     │
│        │                                                            │
│        │ 4. 业务处理                                                │
│        ▼                                                            │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    业务层 (Business)                       │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ 权限检查  │ │ 业务逻辑  │ │ 数据操作  │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   └──────────────────────────────────────────────────────────┘     │
│        │                                                            │
│        │ 5. 响应序列化                                              │
│        ▼                                                            │
│   ┌──────────────────────────────────────────────────────────┐     │
│   │                    响应层 (Response)                       │     │
│   │  ┌──────────┐ ┌──────────┐ ┌──────────┐                 │     │
│   │  │ 序列化   │ │ 格式化   │ │ 状态码   │                 │     │
│   │  │ JSON     │ │ 分页     │ │ HTTP     │                 │     │
│   │  └──────────┘ └──────────┘ └──────────┘                 │     │
│   └──────────────────────────────────────────────────────────┘     │
│        │                                                            │
│        │ 6. HTTP Response                                           │
│        ▼                                                            │
│   ┌──────────┐                                                      │
│   │ 客户端   │                                                      │
│   └──────────┘                                                      │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

40.7 技术选型指南

40.7.1 API框架选型

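| 框架 | 适用场景 | 性能 | 异步支持 | 学习曲线 | 推荐指数 |
|------|----------|------|----------|----------|----------|
| FastAPI | 高性能API、微服务 | 极高 | ✅ 原生 | | ★★★★★ |
| Django REST | 全功能Web应用 | | | | ★★★★★ |
| Flask | 小型API、原型开发 | | | | ★★★★☆ |
| Starlette | 轻量级异步API | 极高 | ✅ 原生 | | ★★★★☆ |
| Falcon | 高性能REST API | 极高 | | | ★★★☆☆ |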

40.7.2 API风格选型

| 风格 | 适用场景 | 优点 | 缺点 |
|------|----------|------|------|
| REST | CRUD操作、资源导向 | 简单直观、缓存友好 | 多次请求、过度获取 |
| GraphQL | 复杂数据需求、灵活查询 | 单次请求、精确获取 | 学习成本、缓存复杂 |
| gRPC | 微服务通信、高性能场景 | 高效、强类型、双向流 | 浏览器不友好 |
| WebSocket | 实时通信、推送场景 | 双向通信、低延迟 | 连接管理复杂 |

40.7.3 认证方案选型

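| 方案 | 适用场景 | 安全性 | 复杂度 | 无状态 |
|------|----------|--------|--------|--------|
| JWT | SPA、移动应用、微服务 | | | |
| OAuth2 | 第三方授权、社交登录 | | | |
| API Key | 服务间调用、简单认证 | | | |
| Session | 传统Web应用 | | | |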

40.8 常见问题与解决方案

40.8.1 API版本控制

python
from fastapi import FastAPI, APIRouter
from typing import Dict, Any, List

class APIVersionManager:
    """API version manager: mounts versioned routers and flags deprecated versions."""

    def __init__(self, app: "FastAPI"):
        self.app = app
        self._versions: Dict[str, APIRouter] = {}

    def register_version(self, version: str, router: "APIRouter"):
        """Register a version's router and mount it under ``/api/{version}``."""
        self._versions[version] = router
        self.app.include_router(router, prefix=f"/api/{version}")

    def deprecate_version(self, version: str, sunset_date: str, successor: str = "v2"):
        """Mark every response under ``/api/{version}/`` as deprecated.

        Adds ``X-API-Deprecated: true``, ``X-API-Sunset: <sunset_date>`` and a
        ``Link`` header with ``rel="successor-version"`` pointing at the same
        sub-path under ``/api/{successor}``. Each call registers one more HTTP
        middleware on the app.
        """
        prefix = f"/api/{version}"

        @self.app.middleware("http")
        async def add_deprecation_header(request, call_next):
            response = await call_next(request)
            path = request.url.path
            # Require the version prefix at the start of the path; a substring
            # test would also match paths like "/foo/api/v1/...".
            if path.startswith(prefix + "/"):
                response.headers["X-API-Deprecated"] = "true"
                response.headers["X-API-Sunset"] = sunset_date
                successor_path = f"/api/{successor}{path[len(prefix):]}"
                response.headers["Link"] = f'<{successor_path}>; rel="successor-version"'
            return response


class VersionedAPI:
    """Example user-creation handlers for two API versions."""

    @staticmethod
    def create_v1_user(data: Dict) -> Dict:
        """V1 user creation (legacy flat format: id, name, email)."""
        user = {"id": 1}
        user["name"] = data.get("name")
        user["email"] = data.get("email")
        return user

    @staticmethod
    def create_v2_user(data: Dict) -> Dict:
        """V2 user creation (new format: nested profile, settings and a creation timestamp)."""
        profile_keys = ("name", "email", "phone")
        return {
            "id": 1,
            "profile": {key: data.get(key) for key in profile_keys},
            "settings": data.get("settings", {}),
            "created_at": "2026-01-01T00:00:00Z",
        }


class BackwardCompatibility:
    """Backward-compatibility helper: converts user payloads between v1 and v2.

    v1 payloads are flat (``name``, ``email``). v2 payloads nest those fields
    under ``profile``. ``_field_mappings`` drives the conversion in both
    directions, and the fields in ``_SHARED_FIELDS`` (such as ``id``) are
    carried over unchanged.
    """

    # Top-level fields with the same name and position in v1 and v2.
    _SHARED_FIELDS = ("id",)

    def __init__(self):
        # v1 flat key -> dotted path of the same value in a v2 payload.
        self._field_mappings = {
            "v1_to_v2": {
                "name": "profile.name",
                "email": "profile.email"
            }
        }

    def transform_response(self, data: Dict, from_version: str, to_version: str) -> Dict:
        """Convert ``data`` between versions; unsupported pairs return ``data`` unchanged."""
        if from_version == "v1" and to_version == "v2":
            return self._v1_to_v2(data)
        elif from_version == "v2" and to_version == "v1":
            return self._v2_to_v1(data)
        return data

    def _v1_to_v2(self, data: Dict) -> Dict:
        """Nest mapped v1 fields (missing ones become ``None``) and keep shared fields."""
        result = {key: data[key] for key in self._SHARED_FIELDS if key in data}
        for flat_key, dotted_path in self._field_mappings["v1_to_v2"].items():
            self._set_path(result, dotted_path, data.get(flat_key))
        return result

    def _v2_to_v1(self, data: Dict) -> Dict:
        """Flatten mapped v2 fields (missing ones become ``None``) and keep shared fields."""
        result = {key: data[key] for key in self._SHARED_FIELDS if key in data}
        for flat_key, dotted_path in self._field_mappings["v1_to_v2"].items():
            result[flat_key] = self._get_path(data, dotted_path)
        return result

    @staticmethod
    def _set_path(target: Dict, dotted_path: str, value: Any) -> None:
        """Assign ``value`` at ``a.b.c`` inside ``target``, creating intermediate dicts."""
        *parents, leaf = dotted_path.split(".")
        node = target
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value

    @staticmethod
    def _get_path(source: Any, dotted_path: str) -> Any:
        """Read ``a.b.c`` from nested dicts; ``None`` if any step is missing or not a dict."""
        node = source
        for key in dotted_path.split("."):
            if not isinstance(node, dict):
                return None
            node = node.get(key)
        return node

40.8.2 API限流实现

python
import time
from typing import Dict, Optional
from dataclasses import dataclass
from collections import defaultdict

@dataclass
class RateLimitConfig:
    """Rate-limit settings: at most ``requests`` requests per ``window`` seconds."""

    # Maximum number of requests allowed in one window.
    requests: int
    # Window length in seconds.
    window: int
    # Prefix for limiter storage keys (presumably for a shared store; not used in this section).
    key_prefix: str = "rate_limit"


class TokenBucketLimiter:
    """Token-bucket rate limiter with one bucket per key.

    A bucket starts full with ``capacity`` tokens and regains ``rate`` tokens
    per second, capped at ``capacity``. Each allowed request spends one token.
    Buckets live in process memory and are never evicted.
    """

    def __init__(self, rate: float, capacity: int):
        self.rate = rate  # tokens added per second
        self.capacity = capacity  # maximum number of tokens a bucket holds
        self._tokens: Dict[str, Dict] = {}

    def _get_bucket(self, key: str) -> Dict:
        """Return the bucket for ``key``, creating a full one on first use."""
        if key not in self._tokens:
            self._tokens[key] = {
                "tokens": self.capacity,
                "last_update": time.time()
            }
        return self._tokens[key]

    def _refill(self, bucket: Dict) -> None:
        """Credit tokens accrued since the bucket's last update, capped at ``capacity``."""
        now = time.time()
        elapsed = now - bucket["last_update"]
        bucket["tokens"] = min(bucket["tokens"] + elapsed * self.rate, self.capacity)
        bucket["last_update"] = now

    def is_allowed(self, key: str) -> bool:
        """Spend one token for ``key`` if available; return whether the request is allowed."""
        bucket = self._get_bucket(key)
        self._refill(bucket)

        if bucket["tokens"] >= 1:
            bucket["tokens"] -= 1
            return True
        return False

    def get_remaining(self, key: str) -> int:
        """Number of whole tokens currently available for ``key``."""
        bucket = self._get_bucket(key)
        self._refill(bucket)
        return int(bucket["tokens"])

    def get_reset_time(self, key: str) -> float:
        """Epoch time (``time.time()`` scale) at which ``key`` will next have a whole token.

        Returns the current time if a token is already available, and ``inf``
        if the bucket is short of a token and ``rate`` is not positive. This
        mirrors ``SlidingWindowLimiter.get_reset_time`` so either limiter can
        back ``RateLimitMiddleware``.
        """
        bucket = self._get_bucket(key)
        self._refill(bucket)
        now = bucket["last_update"]
        missing = 1 - bucket["tokens"]
        if missing <= 0:
            return now
        if self.rate <= 0:
            return float("inf")
        return now + missing / self.rate


class SlidingWindowLimiter:
    """Sliding-window limiter: at most ``max_requests`` per trailing ``window_seconds``.

    Request timestamps are stored per key in process memory. Keys with no
    request inside the current window are dropped, so idle clients do not
    accumulate entries.
    """

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._requests: Dict[str, List[float]] = defaultdict(list)

    def _prune(self, key: str, now: float) -> list:
        """Return ``key``'s timestamps still inside the window, storing (or dropping) them.

        Uses ``.get`` so that looking up an unknown key never creates an entry.
        """
        window_start = now - self.window_seconds
        recent = [ts for ts in self._requests.get(key, ()) if ts > window_start]
        if recent:
            self._requests[key] = recent
        else:
            self._requests.pop(key, None)
        return recent

    def is_allowed(self, key: str) -> bool:
        """Record and allow the request if ``key`` is under its limit for the window."""
        now = time.time()
        recent = self._prune(key, now)
        if len(recent) < self.max_requests:
            recent.append(now)
            # Re-store: _prune drops the key when it had no recent requests.
            self._requests[key] = recent
            return True
        return False

    def get_remaining(self, key: str) -> int:
        """Requests still allowed for ``key`` in the current window (never negative)."""
        return max(0, self.max_requests - len(self._prune(key, time.time())))

    def get_reset_time(self, key: str) -> float:
        """Epoch time when the oldest in-window request expires; now if there is none.

        Expired timestamps are pruned first, so the result is never a time
        that has already passed.
        """
        now = time.time()
        recent = self._prune(key, now)
        if not recent:
            return now
        return recent[0] + self.window_seconds


class RateLimitMiddleware:
    """Rate-limiting middleware.

    ``limiter`` must provide ``is_allowed(key)`` and ``get_remaining(key)``.
    If it also provides ``get_reset_time(key)`` (an epoch timestamp), the
    rejection carries a ``retry_after`` delay in whole seconds plus a matching
    ``Retry-After`` header (RFC 9110 §10.2.3).
    """

    def __init__(self, limiter, key_func=None):
        self.limiter = limiter
        self.key_func = key_func or self._client_host

    @staticmethod
    def _client_host(request) -> str:
        """Default key: the client host, or ``"unknown"`` when ``request.client`` is missing/None."""
        client = getattr(request, "client", None)
        return client.host if client is not None else "unknown"

    def _retry_after_seconds(self, key) -> Optional[int]:
        """Whole seconds until ``key`` should be allowed again, or ``None`` if unknown."""
        import math

        get_reset_time = getattr(self.limiter, "get_reset_time", None)
        if get_reset_time is None:
            return None
        delay = get_reset_time(key) - time.time()
        if not math.isfinite(delay):
            return None
        return max(0, math.ceil(delay))

    async def __call__(self, request, call_next):
        """Reject over-limit requests with a 429 payload; otherwise add ``X-RateLimit-Remaining``."""
        key = self.key_func(request)

        if not self.limiter.is_allowed(key):
            retry_after = self._retry_after_seconds(key)
            rejection = {
                "status_code": 429,
                "body": {
                    "error": "Too Many Requests",
                    "message": "Rate limit exceeded",
                    # Delay in seconds (not an absolute timestamp), matching Retry-After semantics.
                    "retry_after": retry_after
                }
            }
            if retry_after is not None:
                rejection["headers"] = {"Retry-After": str(retry_after)}
            return rejection

        response = await call_next(request)
        remaining = self.limiter.get_remaining(key)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response

40.8.3 API错误处理

python
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
import traceback

class ErrorCode(Enum):
    """Machine-readable API error codes; each member's value equals its name."""

    VALIDATION_ERROR = "VALIDATION_ERROR"
    AUTHENTICATION_ERROR = "AUTHENTICATION_ERROR"
    AUTHORIZATION_ERROR = "AUTHORIZATION_ERROR"
    NOT_FOUND = "NOT_FOUND"
    CONFLICT = "CONFLICT"
    RATE_LIMIT = "RATE_LIMIT"
    INTERNAL_ERROR = "INTERNAL_ERROR"


@dataclass
class APIError:
    """Structured description of one API error."""

    # Machine-readable category of the error.
    code: ErrorCode
    # Human-readable explanation.
    message: str
    # Optional extra context for the client.
    details: Optional[Dict] = None
    # Name of the offending input field, when the error concerns a single field.
    field: Optional[str] = None


class APIErrorHandler:
    """API error handler: builds ``{"success": False, "error": {...}}`` envelopes."""

    @staticmethod
    def _failure(code: ErrorCode, message: str, **extra: Any) -> Dict:
        """Build the envelope; ``extra`` keys are appended to ``error`` after code/message."""
        error = {"code": code.value, "message": message}
        error.update(extra)
        return {"success": False, "error": error}

    @staticmethod
    def handle_validation_error(errors: List[Dict]) -> Dict:
        """Envelope for invalid request parameters; ``errors`` goes under ``details``."""
        return APIErrorHandler._failure(
            ErrorCode.VALIDATION_ERROR, "请求参数验证失败", details={"errors": errors}
        )

    @staticmethod
    def handle_authentication_error(message: str = "认证失败") -> Dict:
        """Envelope for failed authentication."""
        return APIErrorHandler._failure(ErrorCode.AUTHENTICATION_ERROR, message)

    @staticmethod
    def handle_authorization_error(message: str = "权限不足") -> Dict:
        """Envelope for insufficient permissions."""
        return APIErrorHandler._failure(ErrorCode.AUTHORIZATION_ERROR, message)

    @staticmethod
    def handle_not_found(resource: str = "资源") -> Dict:
        """Envelope stating that ``resource`` does not exist."""
        return APIErrorHandler._failure(ErrorCode.NOT_FOUND, f"{resource}不存在")

    @staticmethod
    def handle_conflict(message: str) -> Dict:
        """Envelope for a conflicting request."""
        return APIErrorHandler._failure(ErrorCode.CONFLICT, message)

    @staticmethod
    def handle_internal_error(debug: bool = False) -> Dict:
        """Generic internal-error envelope; with ``debug`` the current traceback is included."""
        extra = {"traceback": traceback.format_exc()} if debug else {}
        return APIErrorHandler._failure(ErrorCode.INTERNAL_ERROR, "服务器内部错误", **extra)


class ExceptionHandler:
    """Unified exception handling: map exception types to handlers, with a generic fallback."""

    def __init__(self, debug: bool = False):
        self.debug = debug
        self._handlers = {}

    def register(self, exception_class, handler):
        """Register ``handler(exc)`` for ``exception_class`` (a class or tuple of classes)."""
        self._handlers[exception_class] = handler

    def handle(self, exception: Exception) -> Dict:
        """Run the most specific registered handler for ``exception``.

        The exception's MRO is walked first, so a handler for ``KeyError`` wins
        over one for ``Exception`` regardless of registration order. Keys the
        MRO cannot match (tuples of classes, ABCs with virtual subclasses) are
        then tried with ``isinstance`` in registration order. If nothing
        matches, the generic internal-error payload is returned.
        """
        for exc_class in type(exception).__mro__:
            handler = self._handlers.get(exc_class)
            if handler is not None:
                return handler(exception)

        for exc_class, handler in self._handlers.items():
            if isinstance(exception, exc_class):
                return handler(exception)

        return APIErrorHandler.handle_internal_error(self.debug)


def setup_exception_handlers(app):
    """Register FastAPI exception handlers that return JSON error envelopes.

    ``ValueError`` maps to 422, ``PermissionError`` to 403 and
    ``FileNotFoundError`` to 404. FastAPI/Starlette exception handlers must
    return a ``Response`` object, so each ``APIErrorHandler`` payload is
    wrapped in a ``JSONResponse`` with the matching HTTP status.
    """
    from fastapi.responses import JSONResponse

    @app.exception_handler(ValueError)
    async def value_error_handler(request, exc):
        return JSONResponse(
            status_code=422,
            content=APIErrorHandler.handle_validation_error([{"message": str(exc)}]),
        )

    @app.exception_handler(PermissionError)
    async def permission_error_handler(request, exc):
        return JSONResponse(
            status_code=403,
            content=APIErrorHandler.handle_authorization_error(str(exc)),
        )

    @app.exception_handler(FileNotFoundError)
    async def not_found_handler(request, exc):
        return JSONResponse(
            status_code=404,
            content=APIErrorHandler.handle_not_found(),
        )

40.8.4 API文档生成

python
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field

@dataclass
class OpenAPIEndpoint:
    """Description of one HTTP operation to be emitted into an OpenAPI document."""

    # URL path, e.g. "/users/{id}".
    path: str
    # HTTP method name.
    method: str
    summary: str = ""
    description: str = ""
    # Grouping tags; each instance gets its own list.
    tags: List[str] = field(default_factory=list)
    # Parameter objects, stored as plain dicts.
    parameters: List[Dict] = field(default_factory=list)
    # Request body object, or None when the operation takes no body.
    request_body: Optional[Dict] = None
    # Status code -> response object.
    responses: Dict[str, Dict] = field(default_factory=dict)


class OpenAPIGenerator:
    """OpenAPI document generator.

    Collects endpoints and component schemas. :meth:`generate` renders them as
    an OpenAPI 3.0.0 dict; :meth:`generate_markdown` renders a Markdown summary.
    """

    def __init__(self, title: str, version: str, description: str = ""):
        self.title = title
        self.version = version
        self.description = description
        self._endpoints: List["OpenAPIEndpoint"] = []
        self._schemas: Dict[str, Dict] = {}

    def add_endpoint(self, endpoint: "OpenAPIEndpoint"):
        """Append *endpoint*.

        In :meth:`generate`, a later endpoint with the same path and method
        replaces an earlier one.
        """
        self._endpoints.append(endpoint)

    def add_schema(self, name: str, schema: Dict):
        """Register *schema* under ``components.schemas[name]`` (overwrites)."""
        self._schemas[name] = schema

    def generate(self) -> Dict:
        """Return the OpenAPI document as a plain dict.

        The result is a deep copy of the registered endpoint data and schemas,
        so callers may post-process or mutate it without corrupting this
        generator's state (or the endpoint objects it was built from).
        """
        import copy

        paths: Dict[str, Dict] = {}
        for endpoint in self._endpoints:
            # OpenAPI names the per-method entry under a path an "Operation Object".
            operation = {
                "summary": endpoint.summary,
                "description": endpoint.description,
                "tags": endpoint.tags,
                "parameters": endpoint.parameters,
                "responses": endpoint.responses,
            }
            if endpoint.request_body:
                operation["requestBody"] = endpoint.request_body
            paths.setdefault(endpoint.path, {})[endpoint.method.lower()] = operation

        spec = {
            "openapi": "3.0.0",
            "info": {
                "title": self.title,
                "version": self.version,
                "description": self.description,
            },
            "paths": paths,
            "components": {
                "schemas": self._schemas,
            },
        }
        # Detach the returned structure from internal/endpoint-owned containers.
        return copy.deepcopy(spec)

    @staticmethod
    def _md_cell(value: Any) -> str:
        """Render *value* as a Markdown table cell.

        ``None`` becomes an empty cell. Line breaks are flattened to spaces and
        ``|`` is escaped so the value cannot split the table row.
        """
        text = "" if value is None else str(value)
        return " ".join(text.splitlines()).replace("|", "\\|")

    def generate_markdown(self) -> str:
        """Generate the documentation in Markdown format."""
        lines = [
            f"# {self.title}",
            f"\n版本: {self.version}",
            f"\n{self.description}\n",
            "---\n",
            "## API端点\n",
        ]

        for endpoint in self._endpoints:
            lines.append(f"### {endpoint.method.upper()} {endpoint.path}")
            lines.append(f"\n{endpoint.summary}\n")

            if endpoint.parameters:
                lines.append("**参数:**\n")
                lines.append("| 名称 | 位置 | 类型 | 必需 | 描述 |")
                lines.append("|------|------|------|------|------|")
                for param in endpoint.parameters:
                    # A parameter may carry "schema": None; treat it as absent.
                    schema = param.get("schema") or {}
                    cells = [
                        param.get("name"),
                        param.get("in"),
                        schema.get("type", "any"),
                        "是" if param.get("required") else "否",
                        param.get("description", ""),
                    ]
                    lines.append(
                        "| " + " | ".join(self._md_cell(c) for c in cells) + " |"
                    )
                lines.append("")

            if endpoint.responses:
                lines.append("**响应:**\n")
                for code, response in endpoint.responses.items():
                    lines.append(f"- `{code}`: {response.get('description', '')}")
                lines.append("")

        return "\n".join(lines)

40.9 本章小结

本章详细介绍了Python API开发进阶的核心概念和实践:

  1. RESTful设计:资源设计、URL规范、状态码、版本控制
  2. FastAPI:依赖注入、请求验证、响应模型
  3. Django REST Framework:序列化器、视图集、路由器
  4. 认证授权:JWT认证、OAuth2、权限控制
  5. GraphQL:Schema设计、Query、Mutation
  6. API文档:OpenAPI规范、Swagger UI、ReDoc
  7. API限流:限流策略、缓存
  8. 错误处理:统一异常处理、日志记录、监控告警

练习题

  1. 实现一个完整的RESTful API,包含CRUD操作和认证
  2. 开发一个GraphQL API,支持复杂查询和分页
  3. 实现一个API限流系统,支持多种限流策略
  4. 开发一个API文档自动生成工具
  5. 实现一个API网关,支持路由、认证和限流

扩展阅读

Python技术丛书 - 江苏省宿城中等专业学校