Skip to content

第28章 数据库编程

学习目标

完成本章学习后,你将能够:

  1. 理解数据库基础概念:关系型数据库原理、SQL语言、数据库设计范式
  2. 掌握SQLite数据库操作:创建数据库、执行SQL语句、事务管理
  3. 使用SQLAlchemy ORM:模型定义、查询构建、关系映射
  4. 实现数据库连接池:连接管理、性能优化、资源释放
  5. 处理数据库迁移:Alembic迁移工具、版本控制、回滚策略
  6. 掌握高级查询技术:复杂查询、聚合函数、子查询、连接查询
  7. 实现数据库安全实践:参数化查询、SQL注入防护、权限管理
  8. 构建数据库应用架构:Repository模式、Unit of Work、数据访问层设计

28.1 数据库基础概念

28.1.1 关系型数据库概述

关系型数据库(Relational Database)是基于关系模型的数据库,数据以表格形式存储,表格之间通过关系连接。

┌─────────────────────────────────────────────────────────────────────┐
│                      关系型数据库结构                                 │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌─────────────────┐         ┌─────────────────┐                   │
│  │    users 表      │         │   orders 表      │                   │
│  ├─────────────────┤         ├─────────────────┤                   │
│  │ id (PK)         │◄────────│ user_id (FK)    │                   │
│  │ username        │    1:N  │ id (PK)         │                   │
│  │ email           │         │ order_date      │                   │
│  │ created_at      │         │ total_amount    │                   │
│  └─────────────────┘         │ status          │                   │
│                              └─────────────────┘                   │
│                                      │                              │
│                                      │ N:1                          │
│                                      ▼                              │
│                              ┌─────────────────┐                   │
│                              │  products 表     │                   │
│                              ├─────────────────┤                   │
│                              │ id (PK)         │                   │
│                              │ name            │                   │
│                              │ price           │                   │
│                              │ category        │                   │
│                              └─────────────────┘                   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

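PK = Primary Key (主键)
FK = Foreign Key (外键)

注:示意图把 orders 与 products 简化成了直接关联;在后文的实际表结构中,两者是多对多关系,通过 order_items 关联表(包含 order_id、product_id 两个外键)来实现。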

28.1.2 SQL语言基础

SQL(Structured Query Language)是操作关系型数据库的标准语言:

| 类别 | 命令 | 用途 |
| --- | --- | --- |
| DDL | CREATE, ALTER, DROP | 定义数据库结构 |
| DML | INSERT, UPDATE, DELETE, SELECT | 操作数据 |
| DCL | GRANT, REVOKE | 控制访问权限 |
| TCL | COMMIT, ROLLBACK, SAVEPOINT | 事务控制 |

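28.1.3 数据库设计范式

范式用于消除数据冗余和更新异常:第一范式(1NF)要求每个字段都是不可再分的原子值;第二范式(2NF)要求非主键字段完全依赖于主键;第三范式(3NF)要求非主键字段之间不存在传递依赖。下面的数据类按这一思路拆分实体:订单明细被单独拆成 OrderItem,通过 order_id、product_id 引用订单和商品,而不是把商品信息重复存放在订单里。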

python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass
class User:
    """Plain data record for one row of the ``users`` table."""

    id: int  # primary key
    username: str
    email: str
    created_at: datetime


@dataclass
class Product:
    """Plain data record for one row of the ``products`` table."""

    id: int  # primary key
    name: str
    price: float
    category: str
    stock: int  # presumably the number of units on hand -- confirm against callers


@dataclass
class Order:
    """Plain data record for one row of the ``orders`` table."""

    id: int  # primary key
    user_id: int  # foreign key to the owning user's id
    order_date: datetime
    status: str
    total_amount: float


@dataclass
class OrderItem:
    """Plain data record for one line of an order (one product and its quantity)."""

    id: int
    order_id: int  # id of the order this line belongs to
    product_id: int  # id of the product being ordered
    quantity: int
    unit_price: float

28.2 SQLite数据库操作

28.2.1 SQLite简介

SQLite是一个轻量级的嵌入式数据库,无需服务器,适合小型应用和原型开发。

python
import sqlite3
from pathlib import Path
from typing import Any, Optional
import logging

# Module-level logger named after this module; the classes in this file log through it.
logger = logging.getLogger(__name__)


class SQLiteManager:
    """Small convenience wrapper around one lazily opened ``sqlite3`` connection.

    The connection and a single shared cursor are created on first use.
    Rows come back as ``sqlite3.Row`` objects and foreign-key enforcement is
    switched on for every new connection.

    Used as a context manager, the manager rolls back when the block raises
    and always closes the connection on exit.  It does NOT commit on a clean
    exit: pass ``commit=True`` to ``execute``/``executemany`` or commit
    through ``connection`` explicitly.

    Note: every helper runs on the same shared cursor, so executing a new
    statement discards rows of the previous one that were not fetched yet.
    """

    def __init__(self, db_path: str | Path = ":memory:"):
        """Remember where the database lives; nothing is opened yet.

        Args:
            db_path: File path of the database, or ``":memory:"`` (default)
                for a private in-memory database.
        """
        # Keep the in-memory marker as a plain string; wrap real paths in Path.
        self.db_path = Path(db_path) if db_path != ":memory:" else db_path
        self._connection: Optional[sqlite3.Connection] = None
        self._cursor: Optional[sqlite3.Cursor] = None

    @property
    def connection(self) -> sqlite3.Connection:
        """Return the open connection, creating and configuring it on first access."""
        if self._connection is None:
            self._connection = sqlite3.connect(
                self.db_path,
                detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
            )
            self._connection.row_factory = sqlite3.Row
            # SQLite leaves foreign-key enforcement off unless asked per connection.
            self._connection.execute("PRAGMA foreign_keys = ON")
            logger.info("Connected to database: %s", self.db_path)
        return self._connection

    @property
    def cursor(self) -> sqlite3.Cursor:
        """Return the shared cursor, creating it (and the connection) on first access."""
        if self._cursor is None:
            self._cursor = self.connection.cursor()
        return self._cursor

    def execute(
        self,
        query: str,
        params: tuple | dict = (),
        commit: bool = False
    ) -> sqlite3.Cursor:
        """Run one parameterised statement on the shared cursor.

        Args:
            query: SQL text with ``?`` or ``:name`` placeholders.
            params: Values bound to the placeholders.
            commit: Commit the connection right after the statement.

        Returns:
            The shared cursor, positioned on the statement's result.
        """
        cursor = self.cursor.execute(query, params)
        if commit:
            self.connection.commit()
        return cursor

    def executemany(
        self,
        query: str,
        params_list: list[tuple | dict],
        commit: bool = False
    ) -> sqlite3.Cursor:
        """Run the same statement once per entry of ``params_list``.

        Args:
            query: SQL text with placeholders.
            params_list: One parameter set per execution.
            commit: Commit the connection after the whole batch.

        Returns:
            The shared cursor.
        """
        cursor = self.cursor.executemany(query, params_list)
        if commit:
            self.connection.commit()
        return cursor

    def fetchone(self, query: str, params: tuple | dict = ()) -> Optional[sqlite3.Row]:
        """Execute ``query`` and return its first row, or ``None`` if there is none."""
        return self.execute(query, params).fetchone()

    def fetchall(self, query: str, params: tuple | dict = ()) -> list[sqlite3.Row]:
        """Execute ``query`` and return all of its rows as a list."""
        return self.execute(query, params).fetchall()

    def fetchvalue(self, query: str, params: tuple | dict = ()) -> Any:
        """Execute ``query`` and return the first column of the first row, or ``None``."""
        row = self.fetchone(query, params)
        return row[0] if row else None

    def close(self) -> None:
        """Close the cursor and the connection if they were opened; safe to call twice."""
        if self._cursor is not None:
            self._cursor.close()
            self._cursor = None
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.info("Database connection closed")

    def __enter__(self) -> "SQLiteManager":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Roll back on error, then always release the connection."""
        try:
            # Look at the private attribute on purpose: the ``connection``
            # property would open a brand-new connection (and possibly create
            # the database file) just to roll it back when nothing was used.
            if exc_type is not None and self._connection is not None:
                self._connection.rollback()
        finally:
            # Close even when the rollback itself fails.
            self.close()

    def create_tables(self) -> None:
        """Create the users/products/orders/order_items tables if missing, then commit."""
        self.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL UNIQUE,
                email TEXT NOT NULL UNIQUE,
                password_hash TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        self.execute("""
            CREATE TABLE IF NOT EXISTS products (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                description TEXT,
                price REAL NOT NULL CHECK(price >= 0),
                category TEXT NOT NULL,
                stock INTEGER DEFAULT 0 CHECK(stock >= 0),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        self.execute("""
            CREATE TABLE IF NOT EXISTS orders (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                status TEXT DEFAULT 'pending',
                total_amount REAL DEFAULT 0,
                shipping_address TEXT,
                FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
            )
        """)

        self.execute("""
            CREATE TABLE IF NOT EXISTS order_items (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                order_id INTEGER NOT NULL,
                product_id INTEGER NOT NULL,
                quantity INTEGER NOT NULL CHECK(quantity > 0),
                unit_price REAL NOT NULL CHECK(unit_price >= 0),
                FOREIGN KEY (order_id) REFERENCES orders(id) ON DELETE CASCADE,
                FOREIGN KEY (product_id) REFERENCES products(id) ON DELETE RESTRICT
            )
        """)

        self.connection.commit()
        logger.info("Tables created successfully")

28.2.2 CRUD操作实现

python
from dataclasses import asdict
from datetime import datetime
import hashlib
import secrets


class UserRepository:
    """CRUD access to the ``users`` table plus salted PBKDF2 password handling.

    Every write commits immediately (``commit=True``), so these methods are
    not meant to be combined into a larger caller-managed transaction.
    """

    # Columns callers may change through ``update``; anything else is ignored.
    _UPDATABLE_FIELDS = frozenset({"username", "email"})
    # PBKDF2-HMAC-SHA256 round count; must stay in sync with stored hashes.
    _PBKDF2_ITERATIONS = 100000

    def __init__(self, db: SQLiteManager):
        self.db = db

    def create(self, username: str, email: str, password: str) -> int:
        """Insert a user with a freshly salted password hash.

        Returns:
            The id of the new row.
        """
        password_hash = self._hash_password(password)
        cursor = self.db.execute(
            """
            INSERT INTO users (username, email, password_hash)
            VALUES (?, ?, ?)
            """,
            (username, email, password_hash),
            commit=True
        )
        return cursor.lastrowid

    def find_by_id(self, user_id: int) -> Optional[dict]:
        """Return the user's public columns (no password hash), or ``None``."""
        row = self.db.fetchone(
            "SELECT id, username, email, created_at FROM users WHERE id = ?",
            (user_id,)
        )
        return dict(row) if row else None

    def find_by_username(self, username: str) -> Optional[dict]:
        """Return the full row, including ``password_hash``, or ``None``."""
        row = self.db.fetchone(
            "SELECT * FROM users WHERE username = ?",
            (username,)
        )
        return dict(row) if row else None

    def find_all(self, limit: int = 100, offset: int = 0) -> list[dict]:
        """Return one page of users, newest first, without password hashes."""
        rows = self.db.fetchall(
            "SELECT id, username, email, created_at FROM users ORDER BY created_at DESC LIMIT ? OFFSET ?",
            (limit, offset)
        )
        return [dict(row) for row in rows]

    def update(self, user_id: int, **fields) -> bool:
        """Update the whitelisted columns of one user.

        Args:
            user_id: Id of the row to change.
            **fields: New values; only ``username`` and ``email`` are applied.

        Returns:
            ``True`` if a row was changed, ``False`` if no allowed field was
            given or the user does not exist.
        """
        updates = {k: v for k, v in fields.items() if k in self._UPDATABLE_FIELDS}
        if not updates:
            return False

        # Column names come only from the whitelist above, so interpolating
        # them is safe; the values themselves stay parameterised.
        assignments = [f"{column} = ?" for column in updates]
        # Let SQLite stamp the row: CURRENT_TIMESTAMP is UTC, matching the
        # column defaults, whereas a local datetime.now() would mix time zones.
        assignments.append("updated_at = CURRENT_TIMESTAMP")

        cursor = self.db.execute(
            f"UPDATE users SET {', '.join(assignments)} WHERE id = ?",
            (*updates.values(), user_id),
            commit=True
        )
        return cursor.rowcount > 0

    def delete(self, user_id: int) -> bool:
        """Delete one user; returns whether a row was removed."""
        cursor = self.db.execute(
            "DELETE FROM users WHERE id = ?",
            (user_id,),
            commit=True
        )
        return cursor.rowcount > 0

    def verify_password(self, username: str, password: str) -> bool:
        """Return ``True`` only if the user exists and ``password`` matches."""
        row = self.db.fetchone(
            "SELECT password_hash FROM users WHERE username = ?",
            (username,)
        )
        if not row:
            return False
        return self._verify_password(password, row["password_hash"])

    @classmethod
    def _derive_key(cls, password: str, salt: str) -> str:
        """Return the hex PBKDF2-HMAC-SHA256 digest of ``password`` for ``salt``."""
        return hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt.encode("utf-8"),
            cls._PBKDF2_ITERATIONS
        ).hex()

    @staticmethod
    def _hash_password(password: str) -> str:
        """Hash ``password`` with a new random salt; result is ``"<salt>:<hex digest>"``."""
        salt = secrets.token_hex(16)
        return f"{salt}:{UserRepository._derive_key(password, salt)}"

    @staticmethod
    def _verify_password(password: str, stored_hash: str) -> bool:
        """Check ``password`` against a ``"<salt>:<hex digest>"`` value.

        A stored value that does not have that shape simply fails
        verification instead of raising.
        """
        salt, separator, expected = stored_hash.partition(":")
        if not (separator and salt and expected):
            return False
        actual = UserRepository._derive_key(password, salt)
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(actual.encode("utf-8"), expected.encode("utf-8"))


class ProductRepository:
    """CRUD, search and stock bookkeeping for the ``products`` table."""

    def __init__(self, db: SQLiteManager):
        self.db = db

    def create(
        self,
        name: str,
        price: float,
        category: str,
        description: str = "",
        stock: int = 0
    ) -> int:
        """Insert a product and commit; returns the id of the new row."""
        cursor = self.db.execute(
            """
            INSERT INTO products (name, description, price, category, stock)
            VALUES (?, ?, ?, ?, ?)
            """,
            (name, description, price, category, stock),
            commit=True
        )
        return cursor.lastrowid

    def find_by_id(self, product_id: int) -> Optional[dict]:
        """Return the product row as a dict, or ``None`` if it does not exist."""
        row = self.db.fetchone(
            "SELECT * FROM products WHERE id = ?",
            (product_id,)
        )
        return dict(row) if row else None

    def find_by_category(self, category: str) -> list[dict]:
        """Return every product of ``category``, ordered by name."""
        rows = self.db.fetchall(
            "SELECT * FROM products WHERE category = ? ORDER BY name",
            (category,)
        )
        return [dict(row) for row in rows]

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE metacharacters so ``text`` matches literally (escape char: backslash)."""
        # The backslash itself must be doubled first, before it is used as the escape.
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def search(self, keyword: str) -> list[dict]:
        """Return products whose name or description contains ``keyword``.

        The keyword is matched literally: ``%`` and ``_`` typed by the user
        are escaped instead of acting as wildcards.
        """
        pattern = f"%{self._escape_like(keyword)}%"
        rows = self.db.fetchall(
            """
            SELECT * FROM products
            WHERE name LIKE ? ESCAPE '\\' OR description LIKE ? ESCAPE '\\'
            ORDER BY name
            """,
            (pattern, pattern)
        )
        return [dict(row) for row in rows]

    def update_stock(self, product_id: int, quantity: int, commit: bool = True) -> bool:
        """Add ``quantity`` (may be negative) to a product's stock.

        The change is applied only if the resulting stock stays non-negative.

        Args:
            product_id: Product to adjust.
            quantity: Signed delta.
            commit: Commit immediately (default, the historical behaviour).
                Pass ``False`` to make the change part of a transaction the
                caller commits or rolls back.

        Returns:
            ``True`` if the row was updated, ``False`` if the product is
            missing or the stock would drop below zero.
        """
        cursor = self.db.execute(
            """
            UPDATE products SET stock = stock + ?
            WHERE id = ? AND stock + ? >= 0
            """,
            (quantity, product_id, quantity),
            commit=commit
        )
        return cursor.rowcount > 0

    def list_all(
        self,
        category: Optional[str] = None,
        min_price: Optional[float] = None,
        max_price: Optional[float] = None,
        in_stock_only: bool = False,
        limit: int = 100,
        offset: int = 0
    ) -> list[dict]:
        """Return one page of products, newest first, narrowed by optional filters.

        Args:
            category: Exact category to keep (ignored when empty/``None``).
            min_price: Inclusive lower price bound.
            max_price: Inclusive upper price bound.
            in_stock_only: Keep only products with ``stock > 0``.
            limit: Page size.
            offset: Rows to skip.
        """
        conditions: list[str] = []
        params: list = []

        if category:
            conditions.append("category = ?")
            params.append(category)
        if min_price is not None:
            conditions.append("price >= ?")
            params.append(min_price)
        if max_price is not None:
            conditions.append("price <= ?")
            params.append(max_price)
        if in_stock_only:
            conditions.append("stock > 0")

        # Only fixed SQL fragments are interpolated; all values are bound.
        where_clause = " AND ".join(conditions) if conditions else "1=1"
        params.extend([limit, offset])

        rows = self.db.fetchall(
            f"""
            SELECT * FROM products
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
            """,
            tuple(params)
        )
        return [dict(row) for row in rows]

28.2.3 事务管理

python
from contextlib import contextmanager
from typing import Generator, Callable, TypeVar, ParamSpec

# Typing helpers for execute_in_transaction: P stands for the wrapped
# callable's parameters, T for its return type.
P = ParamSpec("P")
T = TypeVar("T")


class TransactionManager:
    def __init__(self, db: SQLiteManager):
        self.db = db

    @contextmanager
    def transaction(self) -> Generator[sqlite3.Connection, None, None]:
        conn = self.db.connection
        try:
            yield conn
            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.error(f"Transaction rolled back: {e}")
            raise

    def execute_in_transaction(
        self,
        func: Callable[P, T],
        *args: P.args,
        **kwargs: P.kwargs
    ) -> T:
        with self.transaction():
            return func(*args, **kwargs)


class OrderService:
    """Order workflow (create, cancel, read) on top of the repositories.

    ``create_order`` and ``cancel_order`` each run as one atomic unit: all
    writes go through ``self.db.execute`` without intermediate commits and
    are committed or rolled back together by the connection's context
    manager.
    """

    def __init__(self, db: SQLiteManager):
        self.db = db
        self.tx = TransactionManager(db)
        self.user_repo = UserRepository(db)
        self.product_repo = ProductRepository(db)

    def _build_order_line(self, item: dict) -> dict:
        """Validate one requested item and snapshot the product's current price."""
        product = self.product_repo.find_by_id(item["product_id"])
        if not product:
            raise ValueError(f"Product {item['product_id']} not found")

        if product["stock"] < item["quantity"]:
            raise ValueError(
                f"Insufficient stock for {product['name']}. "
                f"Available: {product['stock']}, Requested: {item['quantity']}"
            )

        return {
            "product_id": item["product_id"],
            "quantity": item["quantity"],
            "unit_price": product["price"]
        }

    def create_order(
        self,
        user_id: int,
        items: list[dict],
        shipping_address: str
    ) -> int:
        """Create an order, its line items, and reserve the stock atomically.

        Args:
            user_id: Id of the ordering user.
            items: Dicts with ``product_id`` and ``quantity`` keys.
            shipping_address: Free-text delivery address.

        Returns:
            The id of the new order.

        Raises:
            ValueError: Unknown user or product, or not enough stock
                (including when several lines for the same product together
                exceed the available stock).  Nothing is persisted then.
        """
        # ``with connection`` commits on success and rolls back on any exception.
        with self.db.connection:
            user = self.user_repo.find_by_id(user_id)
            if not user:
                raise ValueError(f"User {user_id} not found")

            order_items = [self._build_order_line(item) for item in items]
            total_amount = 0.0
            for line in order_items:
                total_amount += line["unit_price"] * line["quantity"]

            cursor = self.db.execute(
                """
                INSERT INTO orders (user_id, total_amount, shipping_address, status)
                VALUES (?, ?, ?, 'pending')
                """,
                (user_id, total_amount, shipping_address)
            )
            order_id = cursor.lastrowid

            for line in order_items:
                self.db.execute(
                    """
                    INSERT INTO order_items (order_id, product_id, quantity, unit_price)
                    VALUES (?, ?, ?, ?)
                    """,
                    (order_id, line["product_id"], line["quantity"], line["unit_price"])
                )

                # Decrement without committing so the whole order stays one
                # transaction; the stock guard in the WHERE clause re-checks
                # availability at write time.
                reserved = self.db.execute(
                    """
                    UPDATE products SET stock = stock - ?
                    WHERE id = ? AND stock >= ?
                    """,
                    (line["quantity"], line["product_id"], line["quantity"])
                )
                if reserved.rowcount == 0:
                    raise ValueError(
                        f"Insufficient stock for product {line['product_id']}"
                    )

        logger.info("Order %s created for user %s", order_id, user_id)
        return order_id

    def cancel_order(self, order_id: int) -> bool:
        """Cancel an order and return its items to stock atomically.

        Returns:
            ``True`` once the order is cancelled (also when it already was).

        Raises:
            ValueError: The order does not exist or was already delivered.
        """
        with self.db.connection:
            order = self.db.fetchone(
                "SELECT * FROM orders WHERE id = ?",
                (order_id,)
            )
            if not order:
                raise ValueError(f"Order {order_id} not found")

            # Already cancelled: nothing to restore, report success.
            if order["status"] == "cancelled":
                return True

            if order["status"] == "delivered":
                raise ValueError("Cannot cancel a delivered order")

            items = self.db.fetchall(
                "SELECT product_id, quantity FROM order_items WHERE order_id = ?",
                (order_id,)
            )

            for item in items:
                # No commit here: restock and status change succeed or fail together.
                self.db.execute(
                    "UPDATE products SET stock = stock + ? WHERE id = ?",
                    (item["quantity"], item["product_id"])
                )

            self.db.execute(
                "UPDATE orders SET status = 'cancelled' WHERE id = ?",
                (order_id,)
            )

        logger.info("Order %s cancelled", order_id)
        return True

    def get_order_details(self, order_id: int) -> Optional[dict]:
        """Return ``{"order": ..., "items": [...]}`` for one order, or ``None``.

        The order dict carries the buyer's username and email; each item
        dict carries the product's name and category.
        """
        order = self.db.fetchone(
            """
            SELECT o.*, u.username, u.email
            FROM orders o
            JOIN users u ON o.user_id = u.id
            WHERE o.id = ?
            """,
            (order_id,)
        )
        if not order:
            return None

        items = self.db.fetchall(
            """
            SELECT oi.*, p.name as product_name, p.category
            FROM order_items oi
            JOIN products p ON oi.product_id = p.id
            WHERE oi.order_id = ?
            """,
            (order_id,)
        )

        return {
            "order": dict(order),
            "items": [dict(item) for item in items]
        }

28.3 SQLAlchemy ORM

28.3.1 SQLAlchemy简介

SQLAlchemy是Python最流行的ORM框架,提供了高级的ORM和低级的SQL表达式两种使用方式。

python
from contextlib import contextmanager
from datetime import datetime
from typing import Iterator, List, Optional

from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, ForeignKey, Text, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship, sessionmaker

# Declarative base class: every ORM model below subclasses it, which registers
# the model's table on Base.metadata.
# NOTE(review): sqlalchemy.ext.declarative is the legacy import location of
# declarative_base; newer SQLAlchemy releases expose it from sqlalchemy.orm --
# confirm the targeted version before changing the import.
Base = declarative_base()


class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # username and email are required, unique and indexed.
    username = Column(String(50), unique=True, nullable=False, index=True)
    email = Column(String(100), unique=True, nullable=False, index=True)
    password_hash = Column(String(128), nullable=False)
    is_active = Column(Boolean, default=True)
    is_admin = Column(Boolean, default=False)
    # Both timestamps default to the naive UTC time at insert; updated_at is
    # refreshed by SQLAlchemy on every UPDATE issued through the ORM.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # One user has many orders; the cascade deletes a user's orders with the
    # user and deletes orders removed from this collection.
    orders = relationship("Order", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        """Return a short debug representation with id and username."""
        return f"<User(id={self.id}, username='{self.username}')>"


class Product(Base):
    """ORM model mapped to the ``products`` table."""

    __tablename__ = "products"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(100), nullable=False)
    description = Column(Text)
    price = Column(Float, nullable=False)
    # category is indexed because the repositories filter and group by it.
    category = Column(String(50), nullable=False, index=True)
    stock = Column(Integer, default=0)
    is_available = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)

    # Order lines that reference this product (no cascade configured).
    order_items = relationship("OrderItem", back_populates="product")

    def __repr__(self) -> str:
        """Return a short debug representation with id, name and price."""
        return f"<Product(id={self.id}, name='{self.name}', price={self.price})>"


class Order(Base):
    """ORM model mapped to the ``orders`` table."""

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning user; the foreign key is declared with ON DELETE CASCADE.
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    order_date = Column(DateTime, default=datetime.utcnow)
    status = Column(String(20), default="pending")
    # Stored total; calculate_total() recomputes it from the items instead.
    total_amount = Column(Float, default=0.0)
    shipping_address = Column(Text)

    user = relationship("User", back_populates="orders")
    # Line items are owned by the order and deleted together with it.
    items = relationship("OrderItem", back_populates="order", cascade="all, delete-orphan")

    def __repr__(self) -> str:
        """Return a short debug representation with id, user_id and status."""
        return f"<Order(id={self.id}, user_id={self.user_id}, status='{self.status}')>"

    def calculate_total(self) -> float:
        """Return the sum of ``quantity * unit_price`` over the loaded items.

        Does not write the result back to ``total_amount``.
        """
        return sum(item.quantity * item.unit_price for item in self.items)


class OrderItem(Base):
    """ORM model mapped to the ``order_items`` table: one product line of an order."""

    __tablename__ = "order_items"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Deleting the order cascades to its lines; a product that is still
    # referenced by a line is protected by ON DELETE RESTRICT.
    order_id = Column(Integer, ForeignKey("orders.id", ondelete="CASCADE"), nullable=False)
    product_id = Column(Integer, ForeignKey("products.id", ondelete="RESTRICT"), nullable=False)
    quantity = Column(Integer, nullable=False)
    # Price per unit stored on the line itself, separate from Product.price.
    unit_price = Column(Float, nullable=False)

    order = relationship("Order", back_populates="items")
    product = relationship("Product", back_populates="order_items")

    def __repr__(self) -> str:
        """Return a short debug representation with order, product and quantity."""
        return f"<OrderItem(order_id={self.order_id}, product_id={self.product_id}, qty={self.quantity})>"


class Database:
    """Owns the SQLAlchemy engine and the session factory for the application."""

    def __init__(self, database_url: str = "sqlite:///app.db", echo: bool = False):
        """Create the engine and a session factory bound to it.

        Args:
            database_url: SQLAlchemy connection URL.
            echo: Passed to ``create_engine``; enables SQL statement logging.
        """
        self.engine = create_engine(database_url, echo=echo)
        # Sessions neither autocommit nor autoflush: callers decide when to
        # flush and commit.
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine
        )

    def create_tables(self) -> None:
        """Create every table registered on ``Base.metadata``."""
        Base.metadata.create_all(self.engine)

    def drop_tables(self) -> None:
        """Drop every table registered on ``Base.metadata``."""
        Base.metadata.drop_all(self.engine)

    def get_session(self) -> Session:
        """Return a new session from the factory; the caller must close it."""
        return self.SessionLocal()

    @staticmethod
    @contextmanager
    def session_context(session: Session) -> Iterator[Session]:
        """Context manager giving ``session`` commit/rollback/close handling.

        Usage: ``with Database.session_context(db.get_session()) as s: ...``

        Commits when the block finishes normally, rolls back and re-raises on
        any ``Exception``, and closes the session in every case.  Without the
        ``contextmanager`` decorator this was a bare generator function and
        could not be used in a ``with`` statement.
        """
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

28.3.2 查询构建

python
from sqlalchemy import select, update, delete, func, and_, or_, not_
from sqlalchemy.orm import joinedload, selectinload, subqueryload
from typing import TypeVar, Generic, Type, Any

# Type variable for the ORM model class a repository works with; bound to Base
# so only declarative models are accepted.
ModelType = TypeVar("ModelType", bound=Base)


class BaseRepository(Generic[ModelType]):
    """Generic CRUD operations for one ORM model on a given session.

    Every mutating method commits the session itself.
    """

    def __init__(self, model: Type[ModelType], session: Session):
        self.session = session
        self.model = model

    def create(self, **kwargs) -> ModelType:
        """Build a model instance from ``kwargs``, persist it, and return it refreshed."""
        record = self.model(**kwargs)
        db = self.session
        db.add(record)
        db.commit()
        # Reload so database-generated values (e.g. the id) are populated.
        db.refresh(record)
        return record

    def find_by_id(self, id: int) -> Optional[ModelType]:
        """Look the instance up by primary key; ``None`` if it is missing."""
        return self.session.get(self.model, id)

    def find_all(self, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return up to ``limit`` instances after skipping ``skip`` rows."""
        page_query = select(self.model).offset(skip).limit(limit)
        return list(self.session.scalars(page_query))

    def update(self, id: int, **kwargs) -> Optional[ModelType]:
        """Assign ``kwargs`` as attributes on the instance and commit.

        Returns the refreshed instance, or the falsy lookup result unchanged
        when nothing was found.
        """
        target = self.find_by_id(id)
        if not target:
            return target
        for attr_name, new_value in kwargs.items():
            setattr(target, attr_name, new_value)
        self.session.commit()
        self.session.refresh(target)
        return target

    def delete(self, id: int) -> bool:
        """Delete the instance with this id; ``False`` if there was none."""
        doomed = self.find_by_id(id)
        if not doomed:
            return False
        self.session.delete(doomed)
        self.session.commit()
        return True

    def count(self) -> int:
        """Return the number of rows of the model's table."""
        total = self.session.scalar(select(func.count()).select_from(self.model))
        return total or 0

    def exists(self, id: int) -> bool:
        """Tell whether an instance with this primary key exists."""
        found = self.find_by_id(id)
        return found is not None


class UserRepositorySQLAlchemy(BaseRepository[User]):
    """User-specific lookups on top of the generic repository."""

    def __init__(self, session: Session):
        super().__init__(User, session)

    def find_by_username(self, username: str) -> Optional[User]:
        """Return the user with exactly this username, if any."""
        by_name = select(User).where(User.username == username)
        return self.session.scalar(by_name)

    def find_by_email(self, email: str) -> Optional[User]:
        """Return the user with exactly this email address, if any."""
        by_email = select(User).where(User.email == email)
        return self.session.scalar(by_email)

    def search_users(
        self,
        keyword: Optional[str] = None,
        is_active: Optional[bool] = None,
        is_admin: Optional[bool] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[User]:
        """Return one page of users, newest first, matching the given filters.

        ``keyword`` is matched case-insensitively as a substring of the
        username or email; ``is_active``/``is_admin`` filter only when they
        are not ``None``.
        """
        filters = []

        if keyword:
            pattern = f"%{keyword}%"
            filters.append(
                or_(User.username.ilike(pattern), User.email.ilike(pattern))
            )
        if is_active is not None:
            filters.append(User.is_active == is_active)
        if is_admin is not None:
            filters.append(User.is_admin == is_admin)

        query = select(User)
        if filters:
            query = query.where(and_(*filters))
        query = query.offset(skip).limit(limit)
        query = query.order_by(User.created_at.desc())

        return list(self.session.scalars(query))

    def get_user_with_orders(self, user_id: int) -> Optional[User]:
        """Return the user with ``orders`` eagerly loaded via a second SELECT."""
        query = select(User).options(selectinload(User.orders))
        query = query.where(User.id == user_id)
        return self.session.scalar(query)


class ProductRepositorySQLAlchemy(BaseRepository[Product]):
    """Product-specific queries on top of the generic repository."""

    def __init__(self, session: Session):
        super().__init__(Product, session)

    def find_by_category(self, category: str) -> List[Product]:
        """Return all products of ``category``, ordered by name."""
        stmt = select(Product).where(Product.category == category).order_by(Product.name)
        return list(self.session.scalars(stmt))

    def find_available_products(self, skip: int = 0, limit: int = 100) -> List[Product]:
        """Return one page of products flagged available with stock left, newest first."""
        stmt = (
            select(Product)
            # ``.is_(True)`` is the idiomatic boolean test (``== True`` trips linters).
            .where(and_(Product.is_available.is_(True), Product.stock > 0))
            .offset(skip)
            .limit(limit)
            .order_by(Product.created_at.desc())
        )
        return list(self.session.scalars(stmt))

    def search_products(
        self,
        keyword: str,
        category: Optional[str] = None,
        min_price: Optional[float] = None,
        max_price: Optional[float] = None
    ) -> List[Product]:
        """Return products whose name or description contains ``keyword``.

        The match is case-insensitive; the optional category and inclusive
        price bounds narrow the result, which is ordered by name.
        """
        pattern = f"%{keyword}%"
        conditions = [
            or_(
                Product.name.ilike(pattern),
                Product.description.ilike(pattern)
            )
        ]

        if category:
            conditions.append(Product.category == category)
        if min_price is not None:
            conditions.append(Product.price >= min_price)
        if max_price is not None:
            conditions.append(Product.price <= max_price)

        stmt = select(Product).where(and_(*conditions)).order_by(Product.name)
        return list(self.session.scalars(stmt))

    def get_categories(self) -> List[str]:
        """Return the distinct category names in alphabetical order."""
        stmt = select(Product.category).distinct().order_by(Product.category)
        return list(self.session.scalars(stmt))

    def get_stock_statistics(self) -> List[dict]:
        """Return per-category aggregates as a list of dicts.

        Each dict has the keys ``category``, ``product_count``,
        ``total_stock`` (0 when the sum is NULL) and ``avg_price`` (rounded
        to two decimals, 0 when absent).  The return annotation previously
        said ``dict`` although a list has always been returned.
        """
        stmt = select(
            Product.category,
            func.count(Product.id).label("product_count"),
            func.sum(Product.stock).label("total_stock"),
            func.avg(Product.price).label("avg_price")
        ).group_by(Product.category)

        results = self.session.execute(stmt).all()
        return [
            {
                "category": row.category,
                "product_count": row.product_count,
                "total_stock": row.total_stock or 0,
                "avg_price": round(row.avg_price, 2) if row.avg_price else 0
            }
            for row in results
        ]


class OrderRepositorySQLAlchemy(BaseRepository[Order]):
    """Order-specific queries and sales reports on top of the generic repository."""

    def __init__(self, session: Session):
        super().__init__(Order, session)

    def find_by_user(self, user_id: int, skip: int = 0, limit: int = 100) -> List[Order]:
        """Return one page of a user's orders, newest first, with items preloaded."""
        query = select(Order).where(Order.user_id == user_id)
        query = query.options(selectinload(Order.items))
        query = query.offset(skip).limit(limit)
        query = query.order_by(Order.order_date.desc())
        return list(self.session.scalars(query))

    def find_by_status(self, status: str) -> List[Order]:
        """Return all orders in ``status``, oldest first, with the user joined in."""
        query = select(Order).where(Order.status == status)
        query = query.options(joinedload(Order.user))
        query = query.order_by(Order.order_date)
        return list(self.session.scalars(query))

    def get_order_with_details(self, order_id: int) -> Optional[Order]:
        """Return one order with its user, items and each item's product preloaded."""
        eager_user = joinedload(Order.user)
        eager_items = selectinload(Order.items).joinedload(OrderItem.product)
        query = select(Order).options(eager_user, eager_items)
        query = query.where(Order.id == order_id)
        return self.session.scalar(query)

    def get_daily_sales(self, start_date: datetime, end_date: datetime) -> List[dict]:
        """Return per-day order count and revenue between the two dates (inclusive).

        Cancelled orders are excluded.  Each dict has the keys ``date``,
        ``order_count`` and ``total_revenue`` (0 when the sum is NULL).
        """
        columns = (
            func.date(Order.order_date).label("date"),
            func.count(Order.id).label("order_count"),
            func.sum(Order.total_amount).label("total_revenue"),
        )
        criteria = and_(
            Order.order_date >= start_date,
            Order.order_date <= end_date,
            Order.status != "cancelled"
        )
        query = (
            select(*columns)
            .where(criteria)
            .group_by(func.date(Order.order_date))
            .order_by(func.date(Order.order_date))
        )

        return [
            {
                "date": day.date,
                "order_count": day.order_count,
                "total_revenue": day.total_revenue or 0
            }
            for day in self.session.execute(query).all()
        ]

    def get_top_customers(self, limit: int = 10) -> List[dict]:
        """Return the ``limit`` users with the highest non-cancelled order total.

        Each dict has the keys ``user_id``, ``username``, ``email``,
        ``order_count`` and ``total_spent`` (0 when the sum is NULL).
        """
        columns = (
            User.id,
            User.username,
            User.email,
            func.count(Order.id).label("order_count"),
            func.sum(Order.total_amount).label("total_spent"),
        )
        query = (
            select(*columns)
            .join(Order, User.id == Order.user_id)
            .where(Order.status != "cancelled")
            .group_by(User.id)
            .order_by(func.sum(Order.total_amount).desc())
            .limit(limit)
        )

        return [
            {
                "user_id": customer.id,
                "username": customer.username,
                "email": customer.email,
                "order_count": customer.order_count,
                "total_spent": customer.total_spent or 0
            }
            for customer in self.session.execute(query).all()
        ]

28.3.3 高级查询技术

python
from sqlalchemy import case, literal_column, over
from sqlalchemy.sql import functions as sql_func


class AdvancedQueries:
    """Examples of richer SELECT shapes executed through one ORM session.

    Every method builds a statement, executes it on ``self.session`` and
    returns the rows as a list of plain dicts.
    """

    def __init__(self, session: Session):
        """Store the session used to execute every example query."""
        self.session = session

    def complex_join_query(self) -> List[dict]:
        """Return delivered orders whose line items total more than 100.

        Joins orders to their user and items, groups per order/user, and
        sorts by the computed total in descending order.
        """
        # The same aggregate feeds the SELECT list, HAVING and ORDER BY.
        order_total = func.sum(OrderItem.quantity * OrderItem.unit_price)

        stmt = (
            select(
                Order.id,
                Order.order_date,
                User.username,
                User.email,
                func.count(OrderItem.id).label("item_count"),
                order_total.label("total"),
            )
            .join(User, Order.user_id == User.id)
            .join(OrderItem, Order.id == OrderItem.order_id)
            .where(Order.status == "delivered")
            .group_by(Order.id, User.id)
            .having(order_total > 100)
            .order_by(order_total.desc())
        )

        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

    def subquery_example(self) -> List[dict]:
        """Return products priced above the average price of their category."""
        avg_price_subquery = (
            select(
                Product.category,
                func.avg(Product.price).label("avg_price"),
            )
            .group_by(Product.category)
            .subquery()
        )

        stmt = (
            select(
                Product.name,
                Product.category,
                Product.price,
                avg_price_subquery.c.avg_price,
            )
            .join(
                avg_price_subquery,
                Product.category == avg_price_subquery.c.category,
            )
            .where(Product.price > avg_price_subquery.c.avg_price)
            .order_by(Product.category, Product.price.desc())
        )

        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

    def window_function_example(self) -> List[dict]:
        """Return every product with its price rank inside its category."""
        rank_window = func.rank().over(
            partition_by=Product.category,
            order_by=Product.price.desc(),
        )

        stmt = (
            select(
                Product.name,
                Product.category,
                Product.price,
                rank_window.label("price_rank"),
            )
            .select_from(Product)
        )

        results = self.session.execute(stmt).all()
        return [
            {
                "name": row.name,
                "category": row.category,
                "price": row.price,
                "price_rank": row.price_rank,
            }
            for row in results
        ]

    def case_expression_example(self) -> List[dict]:
        """Return products labelled with a price tier via a CASE expression.

        Tiers: below 50 "Budget", below 100 "Mid-range", below 500
        "Premium", otherwise "Luxury".
        """
        price_category = case(
            (Product.price < 50, "Budget"),
            (Product.price < 100, "Mid-range"),
            (Product.price < 500, "Premium"),
            else_="Luxury",
        )

        stmt = (
            select(
                Product.name,
                Product.price,
                Product.category,
                price_category.label("price_tier"),
            )
            .order_by(Product.price)
        )

        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

    def cte_example(self) -> List[dict]:
        """Return monthly delivered-order totals with month-over-month growth.

        A CTE aggregates totals per ``%Y-%m`` month through the ``strftime``
        SQL function; the outer query adds the previous month's total via
        the ``LAG`` window function and the difference between the two.
        """
        monthly_sales = (
            select(
                func.strftime("%Y-%m", Order.order_date).label("month"),
                func.sum(Order.total_amount).label("monthly_total"),
            )
            .where(Order.status == "delivered")
            .group_by(func.strftime("%Y-%m", Order.order_date))
            .cte("monthly_sales")
        )

        # ``sqlalchemy.sql.functions`` exposes no ``lag`` attribute; window
        # functions without a dedicated class are generated through ``func``.
        previous_total = func.lag(monthly_sales.c.monthly_total).over(
            order_by=monthly_sales.c.month
        )

        stmt = (
            select(
                monthly_sales.c.month,
                monthly_sales.c.monthly_total,
                previous_total.label("previous_month"),
                (monthly_sales.c.monthly_total - previous_total).label("growth"),
            )
            .order_by(monthly_sales.c.month)
        )

        results = self.session.execute(stmt).all()
        return [dict(row._mapping) for row in results]

28.4 数据库连接池

28.4.1 连接池配置

python
from sqlalchemy import create_engine, event
from sqlalchemy.pool import QueuePool, NullPool
import logging

logger = logging.getLogger(__name__)


class ConnectionPoolManager:
    """Lazily create and observe a SQLAlchemy engine backed by ``QueuePool``.

    The engine is not built until the ``engine`` property is first read.
    """

    def __init__(
        self,
        database_url: str,
        pool_size: int = 5,
        max_overflow: int = 10,
        pool_timeout: int = 30,
        pool_recycle: int = 3600,
        echo: bool = False
    ):
        """Remember the pool settings; no connection is opened here.

        Args:
            database_url: URL handed to ``create_engine``.
            pool_size: Forwarded as ``pool_size``.
            max_overflow: Forwarded as ``max_overflow``.
            pool_timeout: Forwarded as ``pool_timeout``.
            pool_recycle: Forwarded as ``pool_recycle``.
            echo: Forwarded as ``echo``.
        """
        self.database_url = database_url
        self.pool_size = pool_size
        self.max_overflow = max_overflow
        self.pool_timeout = pool_timeout
        self.pool_recycle = pool_recycle
        self.echo = echo
        self._engine = None

    @property
    def engine(self):
        """Return the engine, creating it and its listeners on first access."""
        if self._engine is None:
            self._engine = create_engine(
                self.database_url,
                poolclass=QueuePool,
                pool_size=self.pool_size,
                max_overflow=self.max_overflow,
                pool_timeout=self.pool_timeout,
                pool_recycle=self.pool_recycle,
                echo=self.echo,
                echo_pool=True,
                # create_engine() names this option ``pool_pre_ping``; a bare
                # ``pre_ping`` keyword is rejected as an invalid argument.
                pool_pre_ping=True,
            )
            self._setup_events()
        return self._engine

    def _setup_events(self):
        """Attach logging listeners for pool lifecycle events to the engine."""

        @event.listens_for(self._engine, "connect")
        def on_connect(dbapi_conn, connection_record):
            logger.debug("Connection created: %s", connection_record.info)

        @event.listens_for(self._engine, "checkout")
        def on_checkout(dbapi_conn, connection_record, connection_proxy):
            logger.debug("Connection checked out from pool")

        @event.listens_for(self._engine, "checkin")
        def on_checkin(dbapi_conn, connection_record):
            logger.debug("Connection returned to pool")

        # "checkout" fires on every successful checkout, so it must not log a
        # failure. "invalidate" is the pool event raised when a connection is
        # discarded.
        @event.listens_for(self._engine, "invalidate")
        def on_invalidate(dbapi_conn, connection_record, exception):
            logger.warning("Connection invalidated: %s", exception)

    def get_pool_status(self) -> dict:
        """Return a snapshot of the pool counters.

        Goes through the ``engine`` property so the call also works before
        the engine has been created.
        """
        pool = self.engine.pool
        return {
            "pool_size": pool.size(),
            "checked_in": pool.checkedin(),
            "checked_out": pool.checkedout(),
            "overflow": pool.overflow(),
            # Only reported when the pool implementation offers the counter.
            "invalid": pool.invalidatedcount() if hasattr(pool, "invalidatedcount") else 0
        }

    def dispose(self):
        """Dispose the engine's connection pool if an engine was created."""
        if self._engine:
            self._engine.dispose()
            logger.info("Connection pool disposed")


class DatabaseConfig:
    """Ready-made database URLs and builders for server-based databases."""

    SQLITE_MEMORY = "sqlite:///:memory:"
    SQLITE_FILE = "sqlite:///app.db"

    @staticmethod
    def postgresql(
        host: str = "localhost",
        port: int = 5432,
        database: str = "app",
        user: str = "postgres",
        password: str = ""
    ) -> str:
        """Build a ``postgresql://`` URL.

        ``user`` and ``password`` are percent-encoded so characters such as
        ``@``, ``/`` or ``:`` cannot be mistaken for URL delimiters.
        """
        from urllib.parse import quote

        credentials = f"{quote(user, safe='')}:{quote(password, safe='')}"
        return f"postgresql://{credentials}@{host}:{port}/{database}"

    @staticmethod
    def mysql(
        host: str = "localhost",
        port: int = 3306,
        database: str = "app",
        user: str = "root",
        password: str = ""
    ) -> str:
        """Build a ``mysql+pymysql://`` URL with percent-encoded credentials."""
        from urllib.parse import quote

        credentials = f"{quote(user, safe='')}:{quote(password, safe='')}"
        return f"mysql+pymysql://{credentials}@{host}:{port}/{database}"


class DatabaseManager:
    """Process-wide access point to one ``ConnectionPoolManager``.

    All state lives on the class; the methods are classmethods, so no
    instance is required to use them.
    """

    _instance = None
    _pool_manager = None
    # Session factory built once per initialised engine, not per session.
    _session_factory = None

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def initialize(
        cls,
        database_url: str,
        pool_size: int = 5,
        max_overflow: int = 10
    ) -> None:
        """Create the pool manager for ``database_url``.

        A previously initialised pool is disposed first so that calling
        this method again does not leave the old pool's connections open.
        """
        cls.dispose()
        cls._pool_manager = ConnectionPoolManager(
            database_url=database_url,
            pool_size=pool_size,
            max_overflow=max_overflow
        )

    @classmethod
    def get_engine(cls):
        """Return the managed engine.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if cls._pool_manager is None:
            raise RuntimeError("DatabaseManager not initialized")
        return cls._pool_manager.engine

    @classmethod
    def get_session(cls) -> Session:
        """Return a new session bound to the managed engine.

        Raises:
            RuntimeError: If ``initialize`` has not been called.
        """
        if cls._session_factory is None:
            cls._session_factory = sessionmaker(bind=cls.get_engine())
        return cls._session_factory()

    @classmethod
    def get_pool_status(cls) -> dict:
        """Return pool counters, or an ``error`` entry when uninitialised."""
        if cls._pool_manager is None:
            return {"error": "Pool not initialized"}
        return cls._pool_manager.get_pool_status()

    @classmethod
    def dispose(cls) -> None:
        """Dispose the pool and forget it along with the cached session factory."""
        if cls._pool_manager:
            cls._pool_manager.dispose()
            cls._pool_manager = None
        cls._session_factory = None

28.4.2 异步数据库操作

python
import asyncio
from typing import AsyncIterator

from sqlalchemy import select
from sqlalchemy.ext.asyncio import (
    create_async_engine,
    AsyncSession,
    async_sessionmaker
)


class AsyncDatabase:
    """Own an async engine and a session factory bound to it."""

    def __init__(self, database_url: str = "sqlite+aiosqlite:///app.db"):
        """Create the async engine and an ``async_sessionmaker`` for it.

        Sessions are created with ``expire_on_commit=False``.
        """
        self.engine = create_async_engine(database_url, echo=False)
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    async def create_tables(self):
        """Run ``Base.metadata.create_all`` inside an engine-level transaction."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def get_session(self) -> AsyncIterator[AsyncSession]:
        """Yield one session and close it when the generator is finalised.

        This is an async generator, so it must be iterated (or used as a
        generator-based dependency); awaiting it does not return a session.
        """
        async with self.async_session() as session:
            yield session

    async def dispose(self) -> None:
        """Dispose the engine, closing its pooled connections."""
        await self.engine.dispose()


class AsyncUserRepository:
    """Async CRUD operations for ``User`` rows on a single ``AsyncSession``."""

    def __init__(self, session: AsyncSession):
        """Keep the session that every operation runs on."""
        self.session = session

    async def _single_match(self, criterion) -> Optional[User]:
        """Return the one user matching ``criterion``, or ``None``."""
        query = select(User).where(criterion)
        outcome = await self.session.execute(query)
        return outcome.scalar_one_or_none()

    async def create(self, username: str, email: str, password_hash: str) -> User:
        """Insert a user, commit, refresh it from the database and return it."""
        new_user = User(
            username=username,
            email=email,
            password_hash=password_hash,
        )
        self.session.add(new_user)
        await self.session.commit()
        await self.session.refresh(new_user)
        return new_user

    async def find_by_id(self, user_id: int) -> Optional[User]:
        """Return the user whose id equals ``user_id``, or ``None``."""
        return await self._single_match(User.id == user_id)

    async def find_by_username(self, username: str) -> Optional[User]:
        """Return the user whose username equals ``username``, or ``None``."""
        return await self._single_match(User.username == username)

    async def find_all(self, skip: int = 0, limit: int = 100) -> List[User]:
        """Return up to ``limit`` users after skipping the first ``skip``."""
        page_query = select(User).offset(skip).limit(limit)
        outcome = await self.session.execute(page_query)
        return list(outcome.scalars().all())

    async def update(self, user_id: int, **kwargs) -> Optional[User]:
        """Set the given attributes on a user and commit.

        Returns the refreshed user, or the falsy lookup result when no
        user with ``user_id`` exists.
        """
        user = await self.find_by_id(user_id)
        if not user:
            return user

        for attribute, new_value in kwargs.items():
            setattr(user, attribute, new_value)
        await self.session.commit()
        await self.session.refresh(user)
        return user

    async def delete(self, user_id: int) -> bool:
        """Delete a user and commit; return whether a user was found."""
        user = await self.find_by_id(user_id)
        if not user:
            return False

        await self.session.delete(user)
        await self.session.commit()
        return True


async def async_main():
    """Demo: create the tables, insert one user and read it back.

    The engine is disposed in a ``finally`` block so pooled connections are
    closed before the event loop shuts down, even if a step fails.
    """
    db = AsyncDatabase()
    try:
        await db.create_tables()

        async with db.async_session() as session:
            user_repo = AsyncUserRepository(session)

            user = await user_repo.create(
                username="testuser",
                email="test@example.com",
                password_hash="hashed_password"
            )
            print(f"Created user: {user}")

            found = await user_repo.find_by_username("testuser")
            print(f"Found user: {found}")

            users = await user_repo.find_all()
            print(f"All users: {users}")
    finally:
        await db.engine.dispose()


# Run the async demo only when this module is executed as a script.
if __name__ == "__main__":
    asyncio.run(async_main())

28.5 数据库迁移

28.5.1 Alembic配置

python
from alembic.config import Config
from alembic import command
from pathlib import Path


class MigrationManager:
    """Drive Alembic commands from Python using an in-memory ``Config``."""

    def __init__(self, database_url: str, migrations_dir: str = "migrations"):
        """Remember the target database and migrations directory.

        Args:
            database_url: Value stored as the ``sqlalchemy.url`` option.
            migrations_dir: Directory used as Alembic's ``script_location``.
        """
        self.database_url = database_url
        self.migrations_dir = Path(migrations_dir)
        self.config = self._create_config()

    @staticmethod
    def _escape_percent(value: str) -> str:
        """Double ``%`` so ConfigParser interpolation keeps it literal."""
        return value.replace("%", "%%")

    def _create_config(self) -> Config:
        """Build the Alembic ``Config`` holding the URL and script location.

        ``set_main_option`` passes values through ConfigParser, which treats
        ``%`` as an interpolation marker. URLs with percent-encoded
        credentials would otherwise raise, so raw ``%`` is escaped.
        """
        config = Config()
        config.set_main_option(
            "sqlalchemy.url", self._escape_percent(self.database_url)
        )
        config.set_main_option(
            "script_location", self._escape_percent(str(self.migrations_dir))
        )
        return config

    def init_migrations(self) -> None:
        """Run ``alembic init`` unless the migrations directory already exists."""
        if not self.migrations_dir.exists():
            command.init(self.config, str(self.migrations_dir))
            print(f"Initialized migrations directory: {self.migrations_dir}")
        else:
            print(f"Migrations directory already exists: {self.migrations_dir}")

    def create_migration(self, message: str = "auto migration") -> None:
        """Create an autogenerated revision labelled with ``message``."""
        command.revision(self.config, autogenerate=True, message=message)
        print(f"Created migration: {message}")

    def upgrade(self, revision: str = "head") -> None:
        """Upgrade the database to ``revision`` (default: ``head``)."""
        command.upgrade(self.config, revision)
        print(f"Upgraded database to revision: {revision}")

    def downgrade(self, revision: str = "-1") -> None:
        """Downgrade the database to ``revision`` (default: one step back)."""
        command.downgrade(self.config, revision)
        print(f"Downgraded database to revision: {revision}")

    def current(self) -> None:
        """Run Alembic's ``current`` command."""
        command.current(self.config)

    def history(self) -> None:
        """Run Alembic's ``history`` command."""
        command.history(self.config)


def alembic_env_py_template():
    """Return the source text of an Alembic ``env.py`` script.

    The returned script appends the grandparent directory of the script
    file to ``sys.path``, imports ``Base`` from ``models`` to obtain the
    target metadata, defines offline and online migration runners, and
    picks one according to ``context.is_offline_mode()``.
    """
    return '''
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))

from models import Base

config = context.config

if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = Base.metadata

def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
'''

28.5.2 迁移脚本示例

python
from alembic import op
import sqlalchemy as sa
from datetime import datetime


def upgrade():
    """Create the ``users``, ``products``, ``orders`` and ``order_items`` tables.

    Referenced tables are created before the tables whose foreign keys
    point at them. Secondary indexes are added on ``users.email``,
    ``users.username`` and ``products.category``.
    """

    def id_column():
        # Integer surrogate key shared by every table in this revision.
        return sa.Column('id', sa.Integer(), autoincrement=True, nullable=False)

    def timestamp_column(column_name):
        return sa.Column(column_name, sa.DateTime(), nullable=True, default=datetime.utcnow)

    def flag_column(column_name, initial):
        return sa.Column(column_name, sa.Boolean(), nullable=True, default=initial)

    # --- users ---------------------------------------------------------
    op.create_table(
        'users',
        id_column(),
        sa.Column('username', sa.String(length=50), nullable=False),
        sa.Column('email', sa.String(length=100), nullable=False),
        sa.Column('password_hash', sa.String(length=128), nullable=False),
        flag_column('is_active', True),
        flag_column('is_admin', False),
        timestamp_column('created_at'),
        timestamp_column('updated_at'),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('username'),
        sa.UniqueConstraint('email'),
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)

    # --- products ------------------------------------------------------
    op.create_table(
        'products',
        id_column(),
        sa.Column('name', sa.String(length=100), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('price', sa.Float(), nullable=False),
        sa.Column('category', sa.String(length=50), nullable=False),
        sa.Column('stock', sa.Integer(), nullable=True, default=0),
        flag_column('is_available', True),
        timestamp_column('created_at'),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_products_category'), 'products', ['category'], unique=False)

    # --- orders (references users) -------------------------------------
    op.create_table(
        'orders',
        id_column(),
        sa.Column('user_id', sa.Integer(), nullable=False),
        timestamp_column('order_date'),
        sa.Column('status', sa.String(length=20), nullable=True, default='pending'),
        sa.Column('total_amount', sa.Float(), nullable=True, default=0.0),
        sa.Column('shipping_address', sa.Text(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
    )

    # --- order_items (references orders and products) -------------------
    op.create_table(
        'order_items',
        id_column(),
        sa.Column('order_id', sa.Integer(), nullable=False),
        sa.Column('product_id', sa.Integer(), nullable=False),
        sa.Column('quantity', sa.Integer(), nullable=False),
        sa.Column('unit_price', sa.Float(), nullable=False),
        sa.ForeignKeyConstraint(['order_id'], ['orders.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['product_id'], ['products.id'], ondelete='RESTRICT'),
        sa.PrimaryKeyConstraint('id'),
    )


def downgrade():
    """Drop everything ``upgrade`` created, dependants first."""
    # Tables holding foreign keys go before the tables they reference.
    for dependant_table in ('order_items', 'orders'):
        op.drop_table(dependant_table)

    op.drop_index(op.f('ix_products_category'), table_name='products')
    op.drop_table('products')

    for users_index in (op.f('ix_users_username'), op.f('ix_users_email')):
        op.drop_index(users_index, table_name='users')
    op.drop_table('users')

28.6 数据库安全实践

28.6.1 SQL注入防护

python
import re
from typing import Any


class SQLInjectionProtector:
    """Heuristic input screening and strict SQL identifier validation.

    The pattern blacklist is a defence-in-depth check only; values should
    still be passed to the database as bound parameters.
    """

    DANGEROUS_PATTERNS = [
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|UNION|ALTER|CREATE|TRUNCATE)\b)",
        r"(--|#|/\*|\*/)",
        r"(;|\||`)",
        r"(\b(OR|AND)\b\s+\d+\s*=\s*\d+)",
        # Word boundaries keep ordinary words that merely contain these
        # letters ("Richard", "hexagon") from being rejected.
        r"(\b(CONCAT|CHAR|ASCII|HEX|UNHEX)\b)",
    ]

    # Matched with fullmatch(): ``$`` in a match() pattern would also accept
    # an identifier followed by a trailing newline.
    _IDENTIFIER_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")

    @classmethod
    def sanitize_input(cls, value: str) -> str:
        """Return ``value`` stripped of surrounding whitespace if it looks safe.

        Non-string values are returned unchanged.

        Raises:
            ValueError: If ``value`` matches any entry of
                ``DANGEROUS_PATTERNS`` (case-insensitive).
        """
        if not isinstance(value, str):
            return value

        for pattern in cls.DANGEROUS_PATTERNS:
            if re.search(pattern, value, re.IGNORECASE):
                raise ValueError(f"Potentially dangerous input detected: {value}")

        return value.strip()

    @classmethod
    def validate_identifier(cls, identifier: str) -> str:
        """Return ``identifier`` if it is a plain SQL identifier.

        Accepts only a letter or underscore followed by letters, digits or
        underscores, with nothing before or after.

        Raises:
            ValueError: If ``identifier`` is not a string of that shape.
        """
        if not isinstance(identifier, str) or not cls._IDENTIFIER_RE.fullmatch(identifier):
            raise ValueError(f"Invalid identifier: {identifier}")
        return identifier


class SecureQueryBuilder:
    """Fluent builder for parameterised ``SELECT`` statements on one table.

    Identifiers are validated through
    ``SQLInjectionProtector.validate_identifier``; values are emitted as
    ``?`` placeholders and collected in a parameter list.
    """

    _VALID_OPERATORS = frozenset(
        {"=", "!=", "<", ">", "<=", ">=", "LIKE", "IN", "IS"}
    )
    # The only keywords that may follow IS when given as a string.
    _IS_KEYWORDS = frozenset({"NULL", "NOT NULL"})

    def __init__(self, table: str):
        """Start a query against ``table``.

        Raises:
            ValueError: If ``table`` is not a valid identifier.
        """
        self.table = SQLInjectionProtector.validate_identifier(table)
        self._conditions = []
        self._params = []
        self._order_by = None
        self._limit = None
        self._offset = None

    @classmethod
    def _is_operand(cls, value: Any) -> str:
        """Translate an ``IS`` operand into a fixed SQL keyword.

        ``IS`` cannot take a bound parameter here, so only a closed set of
        operands is accepted instead of interpolating the caller's value.
        """
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, str):
            keyword = " ".join(value.upper().split())
            if keyword in cls._IS_KEYWORDS:
                return keyword
        raise ValueError(f"Invalid operand for IS: {value!r}")

    def where(self, column: str, operator: str, value: Any) -> "SecureQueryBuilder":
        """Add a condition; all conditions are combined with ``AND``.

        Args:
            column: Column name; must be a valid identifier.
            operator: One of ``= != < > <= >= LIKE IN IS`` (case-insensitive).
            value: Bound value. ``IN`` needs a list or tuple; ``IS`` accepts
                ``None``, a bool, ``"NULL"`` or ``"NOT NULL"``.

        Raises:
            ValueError: On an invalid column, operator or operand.
        """
        column = SQLInjectionProtector.validate_identifier(column)
        operator = operator.upper()
        if operator not in self._VALID_OPERATORS:
            raise ValueError(f"Invalid operator: {operator}")

        if operator == "IN":
            if not isinstance(value, (list, tuple)):
                raise ValueError("IN operator requires a list or tuple")
            if value:
                placeholders = ", ".join(["?"] * len(value))
                self._conditions.append(f"{column} IN ({placeholders})")
                self._params.extend(value)
            else:
                # "IN ()" is not valid SQL; an empty set matches no row.
                self._conditions.append("1 = 0")
        elif operator == "IS":
            self._conditions.append(f"{column} IS {self._is_operand(value)}")
        else:
            self._conditions.append(f"{column} {operator} ?")
            self._params.append(value)

        return self

    def order_by(self, column: str, direction: str = "ASC") -> "SecureQueryBuilder":
        """Set the ``ORDER BY`` clause, replacing any earlier one.

        Raises:
            ValueError: If ``column`` is invalid or ``direction`` is neither
                ``ASC`` nor ``DESC`` (case-insensitive).
        """
        column = SQLInjectionProtector.validate_identifier(column)
        direction = direction.upper()
        if direction not in ("ASC", "DESC"):
            raise ValueError(f"Invalid direction: {direction}")
        self._order_by = f"ORDER BY {column} {direction}"
        return self

    def limit(self, limit: int) -> "SecureQueryBuilder":
        """Set ``LIMIT``.

        Raises:
            ValueError: If ``limit`` is not a non-negative integer. ``bool``
                is rejected although it subclasses ``int``.
        """
        if isinstance(limit, bool) or not isinstance(limit, int) or limit < 0:
            raise ValueError("Limit must be a non-negative integer")
        self._limit = f"LIMIT {limit}"
        return self

    def offset(self, offset: int) -> "SecureQueryBuilder":
        """Set ``OFFSET``.

        Raises:
            ValueError: If ``offset`` is not a non-negative integer. ``bool``
                is rejected although it subclasses ``int``.
        """
        if isinstance(offset, bool) or not isinstance(offset, int) or offset < 0:
            raise ValueError("Offset must be a non-negative integer")
        self._offset = f"OFFSET {offset}"
        return self

    def build_select(self, columns: list[str] = None) -> tuple[str, list]:
        """Return ``(sql, params)`` for a ``SELECT`` over the table.

        Args:
            columns: Column names to select; ``None`` or empty selects ``*``.

        Returns:
            The SQL text and a copy of the bound parameters, so callers
            cannot alter the builder's state through the returned list.
        """
        if columns:
            columns = [SQLInjectionProtector.validate_identifier(c) for c in columns]
            select_clause = ", ".join(columns)
        else:
            select_clause = "*"

        query = f"SELECT {select_clause} FROM {self.table}"

        if self._conditions:
            query += " WHERE " + " AND ".join(self._conditions)

        # Trailing clauses are appended in SQL order, skipping unset ones.
        for clause in (self._order_by, self._limit, self._offset):
            if clause:
                query += f" {clause}"

        return query, list(self._params)

    def build_count(self) -> tuple[str, list]:
        """Return ``(sql, params)`` for ``SELECT COUNT(*)`` with the same filters.

        Ordering, limit and offset are not applied.
        """
        query = f"SELECT COUNT(*) FROM {self.table}"

        if self._conditions:
            query += " WHERE " + " AND ".join(self._conditions)

        return query, list(self._params)


def demonstrate_secure_queries():
    """Build a sample filtered, ordered, limited query on ``users`` and print it."""
    builder = SecureQueryBuilder("users")

    # Each step returns the builder, so the chain can be written stepwise.
    filtered = builder.where("is_active", "=", True)
    filtered = filtered.where("created_at", ">", "2024-01-01")
    newest_first = filtered.order_by("created_at", "DESC").limit(10)

    query, params = newest_first.build_select(["id", "username", "email"])

    print(f"Query: {query}")
    print(f"Params: {params}")

28.6.2 数据加密

python
import os
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from typing import Optional


class FieldEncryption:
    """Encrypt and decrypt text fields with a Fernet key."""

    def __init__(self, encryption_key: Optional[bytes] = None):
        """Use ``encryption_key`` or, when it is falsy, generate a new key."""
        self.key = encryption_key if encryption_key else Fernet.generate_key()
        self.cipher = Fernet(self.key)

    @classmethod
    def derive_key_from_password(cls, password: str, salt: Optional[bytes] = None) -> tuple["FieldEncryption", bytes]:
        """Build an instance whose key is derived from ``password``.

        Uses PBKDF2-HMAC-SHA256 with 480000 iterations and a 32-byte output.

        Args:
            password: Password the key is derived from.
            salt: Salt to use; 16 random bytes are generated when ``None``.

        Returns:
            The new instance and the salt that was used.
        """
        used_salt = os.urandom(16) if salt is None else salt

        key_deriver = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=used_salt,
            iterations=480000,
        )
        raw_key = key_deriver.derive(password.encode())
        return cls(base64.urlsafe_b64encode(raw_key)), used_salt

    def encrypt(self, plaintext: str) -> str:
        """Return the encrypted, URL-safe base64 text for ``plaintext``.

        Falsy input (empty string, ``None``) is returned unchanged.
        """
        if not plaintext:
            return plaintext
        token = self.cipher.encrypt(plaintext.encode())
        return base64.urlsafe_b64encode(token).decode()

    def decrypt(self, ciphertext: str) -> str:
        """Reverse ``encrypt``; falsy input is returned unchanged."""
        if not ciphertext:
            return ciphertext
        token = base64.urlsafe_b64decode(ciphertext.encode())
        return self.cipher.decrypt(token).decode()


class EncryptedField:
    """Descriptor that keeps an attribute encrypted in the instance dict.

    The ciphertext is stored under ``<attribute name>_encrypted``; reads
    decrypt it on the fly.
    """

    def __init__(self, encryption: FieldEncryption):
        """Remember the encryption helper used for reads and writes."""
        self.encryption = encryption

    def __set_name__(self, owner, name):
        """Record the attribute name this descriptor was assigned to."""
        self.name = name

    def __get__(self, obj, objtype=None):
        """Return the decrypted value, or ``None`` when nothing truthy is stored.

        Accessed on the class itself, the descriptor is returned.
        """
        if obj is None:
            return self

        stored = vars(obj).get(self.name + "_encrypted")
        return self.encryption.decrypt(stored) if stored else None

    def __set__(self, obj, value):
        """Store the encrypted form of ``value``; falsy values store ``None``."""
        vars(obj)[self.name + "_encrypted"] = (
            self.encryption.encrypt(value) if value else None
        )


class SecureUser:
    """Example user whose ``credit_card`` and ``ssn`` are stored encrypted."""

    # NOTE(review): no key is passed, so FieldEncryption generates a new key
    # each time this class body runs; ciphertext from an earlier run cannot
    # be decrypted with it. Confirm this is intended for the demo only.
    encryption = FieldEncryption()
    
    id: int
    username: str
    email: str
    # Both descriptors share the class-level encryption helper above.
    credit_card: EncryptedField = EncryptedField(encryption)
    ssn: EncryptedField = EncryptedField(encryption)

    def __init__(self, username: str, email: str):
        """Set the plain-text ``username`` and ``email`` attributes."""
        self.username = username
        self.email = email

28.7 Repository模式与数据访问层

28.7.1 Repository模式实现

python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Optional, List, Protocol
from dataclasses import dataclass

# Entity type handled by a repository.
T = TypeVar("T")
# Identifier type accepted by a repository's find_by_id.
ID = TypeVar("ID")


class Repository(Protocol[T, ID]):
    """Structural interface for repositories of entity ``T`` keyed by ``ID``."""

    # Look up one entity by identifier; None when there is no match.
    def find_by_id(self, id: ID) -> Optional[T]: ...
    # Return all entities.
    def find_all(self) -> List[T]: ...
    # Persist an entity and return it.
    def save(self, entity: T) -> T: ...
    # Remove an entity.
    def delete(self, entity: T) -> None: ...


@dataclass
class Specification:
    """Mutable bundle of query criteria built through chained calls.

    Every modifier changes this instance in place and returns it, so calls
    can be chained.
    """

    conditions: List[str]
    params: dict
    order_by: Optional[str] = None
    limit: Optional[int] = None
    offset: Optional[int] = None

    @classmethod
    def where(cls, condition: str, **params) -> "Specification":
        """Start a specification with one condition and its parameters."""
        return cls([condition], params)

    def and_where(self, condition: str, **params) -> "Specification":
        """Append another condition and merge its parameters in."""
        self.conditions += [condition]
        self.params.update(**params)
        return self

    def order_by_clause(self, clause: str) -> "Specification":
        """Set the ordering clause."""
        self.order_by = clause
        return self

    def limit_rows(self, limit: int) -> "Specification":
        """Set the maximum number of rows."""
        self.limit = limit
        return self

    def offset_rows(self, offset: int) -> "Specification":
        """Set the number of rows to skip."""
        self.offset = offset
        return self


class UnitOfWork:
    """Context manager that commits a session on success and rolls back on error.

    The same instance may be used for several consecutive ``with`` blocks;
    each block is committed independently.
    """

    def __init__(self, session: "Session"):
        """Wrap ``session``; nothing is committed yet."""
        self.session = session
        self._committed = False

    def commit(self) -> None:
        """Commit the session and mark the current block as committed."""
        self.session.commit()
        self._committed = True

    def rollback(self) -> None:
        """Roll the session back."""
        self.session.rollback()

    def __enter__(self) -> "UnitOfWork":
        """Begin a new unit of work."""
        # Reset the flag: without this, a commit in an earlier block would
        # suppress the automatic commit of every later block.
        self._committed = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Roll back if the block raised; otherwise commit unless already done.

        If the automatic commit itself fails, the session is rolled back
        and the commit error is re-raised.
        """
        if exc_type is not None:
            self.rollback()
            return

        if not self._committed:
            try:
                self.commit()
            except Exception:
                self.rollback()
                raise


class ServiceLayer:
    """Bundle a session, a unit of work and repository accessors."""

    def __init__(self, session: Session):
        """Keep ``session`` and create a ``UnitOfWork`` around it."""
        self.session = session
        self.uow = UnitOfWork(session)

    # Each accessor below constructs a new repository on every access, all
    # bound to the same session.
    @property
    def users(self) -> UserRepositorySQLAlchemy:
        """Return a user repository bound to this layer's session."""
        return UserRepositorySQLAlchemy(self.session)

    @property
    def products(self) -> ProductRepositorySQLAlchemy:
        """Return a product repository bound to this layer's session."""
        return ProductRepositorySQLAlchemy(self.session)

    @property
    def orders(self) -> OrderRepositorySQLAlchemy:
        """Return an order repository bound to this layer's session."""
        return OrderRepositorySQLAlchemy(self.session)

28.7.2 完整应用示例

python
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, EmailStr
from typing import List, Optional
from contextlib import contextmanager
import uvicorn

# FastAPI application object on which the routes below are registered.
app = FastAPI(title="E-Commerce API")
# HTTP bearer-token scheme; get_current_user depends on it for credentials.
security = HTTPBearer()


class UserCreate(BaseModel):
    """Request body of ``POST /users``."""

    username: str
    email: EmailStr
    # Plain-text password; create_user hashes it before storing.
    password: str


class UserResponse(BaseModel):
    """Response body for the user endpoints (``/users`` and ``/users/me``)."""

    id: int
    username: str
    email: str
    is_active: bool
    created_at: datetime

    class Config:
        # Presumably lets the model be populated from ORM object attributes,
        # as the routes return ORM users directly; verify against pydantic.
        from_attributes = True


class ProductCreate(BaseModel):
    """Fields describing a new product.

    No route in this section consumes this model.
    """

    name: str
    description: Optional[str] = None
    price: float
    category: str
    stock: int = 0


class ProductResponse(BaseModel):
    """Item type of the list returned by ``GET /products``."""

    id: int
    name: str
    description: Optional[str]
    price: float
    category: str
    stock: int

    class Config:
        # Presumably lets the model be populated from ORM object attributes;
        # verify against pydantic.
        from_attributes = True


class OrderItemCreate(BaseModel):
    """One requested line of an order: a product and how many units of it."""

    product_id: int
    # NOTE(review): no lower bound is declared here, so zero or negative
    # quantities pass model validation.
    quantity: int


class OrderCreate(BaseModel):
    """Request body of ``POST /orders``."""

    # Requested lines; the model itself does not require the list to be non-empty.
    items: List[OrderItemCreate]
    shipping_address: str


class OrderResponse(BaseModel):
    """Response body for the order endpoints."""

    id: int
    user_id: int
    order_date: datetime
    status: str
    total_amount: float

    class Config:
        # Presumably lets the model be populated from ORM object attributes;
        # verify against pydantic.
        from_attributes = True


def get_db():
    """Yield a session from ``DatabaseManager`` and always close it afterwards.

    Written as a generator so it can serve as a FastAPI dependency.
    """
    session = DatabaseManager.get_session()
    try:
        yield session
    finally:
        # Runs whether the consumer finished normally or raised.
        session.close()


def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
) -> User:
    """Resolve the bearer token to a user.

    NOTE(review): the token is interpreted as a raw user id, so anyone who
    knows an id can authenticate as that user. Confirm this is demo-only
    and replace it with verified tokens before real use.

    Raises:
        HTTPException: 401 when the token is not an integer or no user has
            that id.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials"
    )

    try:
        user_id = int(credentials.credentials)
    except ValueError as exc:
        # A malformed token is an authentication failure, not a server error.
        raise unauthorized from exc

    user = UserRepositorySQLAlchemy(db).find_by_id(user_id)
    if not user:
        raise unauthorized
    return user


@app.post("/users", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def create_user(user_data: UserCreate, db: Session = Depends(get_db)):
    """Register a user after checking that username and email are unused.

    Raises:
        HTTPException: 400 when the username, or otherwise the email, is
            already taken. The username is checked first.
    """
    repo = UserRepositorySQLAlchemy(db)

    username_taken = repo.find_by_username(user_data.username)
    if username_taken:
        raise HTTPException(status_code=400, detail="Username already exists")

    email_taken = repo.find_by_email(user_data.email)
    if email_taken:
        raise HTTPException(status_code=400, detail="Email already exists")

    # Only the hash of the submitted password is handed to the repository.
    return repo.create(
        username=user_data.username,
        email=user_data.email,
        password_hash=UserRepository._hash_password(user_data.password),
    )


@app.get("/users/me", response_model=UserResponse)
def get_current_user_info(current_user: User = Depends(get_current_user)):
    """Return the user resolved by the ``get_current_user`` dependency."""
    return current_user


@app.get("/products", response_model=List[ProductResponse])
def list_products(
    category: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    """List products matching the optional filters, one page at a time.

    Args:
        category: Forwarded to the repository search as ``category``.
        min_price: Forwarded as ``min_price``.
        max_price: Forwarded as ``max_price``.
        skip: Number of matching products to skip; must be >= 0.
        limit: Maximum number of products to return; must be >= 0.

    Raises:
        HTTPException: 400 when ``skip`` or ``limit`` is negative.
    """
    # Negative values would be treated as from-the-end indexes by the slice
    # below and silently return the wrong page.
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=400,
            detail="skip and limit must be non-negative"
        )

    product_repo = ProductRepositorySQLAlchemy(db)
    products = product_repo.search_products(
        keyword="",
        category=category,
        min_price=min_price,
        max_price=max_price
    )
    return products[skip:skip + limit]


@app.post("/orders", response_model=OrderResponse, status_code=status.HTTP_201_CREATED)
def create_order(
    order_data: OrderCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create an order for the current user and reserve stock for its lines.

    All lines are validated before anything is written. The order, its
    ``OrderItem`` rows and the stock decrements are then committed together.

    Raises:
        HTTPException: 400 for an empty order, a non-positive quantity or
            insufficient stock; 404 for an unknown product.
    """
    if not order_data.items:
        raise HTTPException(status_code=400, detail="Order must contain at least one item")

    product_repo = ProductRepositorySQLAlchemy(db)
    order_repo = OrderRepositorySQLAlchemy(db)

    total_amount = 0.0
    # Cumulative quantity per product, so repeated lines for the same
    # product are checked against stock together rather than one by one.
    requested_by_product = {}
    order_lines = []  # (product, quantity) pairs in request order

    for item in order_data.items:
        if item.quantity <= 0:
            raise HTTPException(
                status_code=400,
                detail=f"Quantity for product {item.product_id} must be positive"
            )

        product = product_repo.find_by_id(item.product_id)
        if not product:
            raise HTTPException(status_code=404, detail=f"Product {item.product_id} not found")

        requested = requested_by_product.get(item.product_id, 0) + item.quantity
        requested_by_product[item.product_id] = requested
        if product.stock < requested:
            raise HTTPException(status_code=400, detail=f"Insufficient stock for {product.name}")

        total_amount += product.price * item.quantity
        order_lines.append((product, item.quantity))

    try:
        order = order_repo.create(
            user_id=current_user.id,
            total_amount=total_amount,
            shipping_address=order_data.shipping_address,
            status="pending"
        )

        for product, quantity in order_lines:
            # Persist the line itself; previously the collected lines were
            # only used to decrement stock and no order item was stored.
            # Assumes OrderItem accepts these column names as keyword
            # arguments - TODO confirm against the model definition.
            db.add(
                OrderItem(
                    order_id=order.id,
                    product_id=product.id,
                    quantity=quantity,
                    unit_price=product.price
                )
            )
            product.stock -= quantity

        db.commit()
    except Exception:
        # Leave the session clean if any write fails part-way through.
        db.rollback()
        raise

    return order


@app.get("/orders/{order_id}", response_model=OrderResponse)
def get_order(
    order_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return one order, visible only to its owner or to an admin.

    Raises:
        HTTPException: 404 when the order does not exist; 403 when the
            caller neither owns the order nor is an admin.
    """
    order = OrderRepositorySQLAlchemy(db).get_order_with_details(order_id)
    if not order:
        raise HTTPException(status_code=404, detail="Order not found")

    owns_order = order.user_id == current_user.id
    # The admin flag is only consulted when the caller is not the owner.
    if not (owns_order or current_user.is_admin):
        raise HTTPException(status_code=403, detail="Access denied")

    return order


# Script entry point: set up the database, create the tables, start the server.
if __name__ == "__main__":
    DatabaseManager.initialize("sqlite:///ecommerce.db")
    # Engine objects have no create_tables(); table creation is done by
    # MetaData.create_all(), which takes the engine to run against.
    Base.metadata.create_all(DatabaseManager.get_engine())
    uvicorn.run(app, host="0.0.0.0", port=8000)

28.8 本章小结

本章详细介绍了Python数据库编程的核心概念和实践:

  1. 数据库基础:理解关系型数据库原理、SQL语言和设计范式
  2. SQLite操作:使用sqlite3模块进行轻量级数据库操作
  3. SQLAlchemy ORM:模型定义、关系映射、查询构建
  4. 连接池管理:配置、监控和优化数据库连接
  5. 异步数据库:使用asyncio和SQLAlchemy进行异步操作
  6. 数据库迁移:使用Alembic管理数据库版本
  7. 安全实践:SQL注入防护、数据加密、权限管理
  8. 架构模式:Repository模式、Unit of Work、服务层设计

练习题

  1. 创建一个图书管理系统,使用SQLite存储图书、作者和借阅记录
  2. 使用SQLAlchemy实现一个博客系统的数据模型
  3. 实现一个支持分页、排序和过滤的通用查询构建器
  4. 使用异步SQLAlchemy实现一个高并发的API服务
  5. 设计一个支持多租户的数据库架构

扩展阅读

Python技术丛书 - 江苏省宿城中等专业学校