Skip to content

第38章 任务队列与调度

学习目标

完成本章学习后，你将能够：

  1. 理解任务队列原理：生产者-消费者模式、消息传递、异步处理
  2. 使用Celery：任务定义、队列配置、结果存储、任务监控
  3. 使用RQ（Redis Queue）：简单任务队列、任务优先级、失败重试
  4. 实现定时任务：Celery Beat、APScheduler、Cron表达式
  5. 处理任务结果：结果存储、回调函数、链式任务
  6. 实现任务监控：Flower监控、日志记录、性能指标
  7. 处理任务失败：重试机制、死信队列、错误处理
  8. 构建分布式系统：多Worker、任务路由、负载均衡

38.1 任务队列原理

38.1.1 基本概念

python
from typing import Any, Callable, Dict, List, Optional, TypeVar
from dataclasses import dataclass, field
from enum import Enum
from datetime import datetime, timedelta
import queue
import threading
import time
import json
import uuid


T = TypeVar("T")


class TaskStatus(Enum):
    """Lifecycle states recorded on a task; each value is a lowercase string."""

    PENDING = "pending"  # task exists but has not been put on a queue yet
    QUEUED = "queued"  # task is waiting in a queue
    RUNNING = "running"  # a worker is currently executing the task
    SUCCESS = "success"  # the task's callable returned normally
    FAILURE = "failure"  # the task failed and will not be retried
    RETRY = "retry"  # the task failed and was scheduled to run again
    REVOKED = "revoked"  # NOTE(review): never assigned in the visible code


@dataclass
class Task:
    """A unit of work together with its execution bookkeeping.

    Stores the callable and its arguments plus the mutable state that is
    updated while the task is processed: status, result or error text,
    retry counter and timestamps.
    """

    id: str  # identifier used to look the task up
    name: str  # human-readable label
    func: Callable  # callable to invoke with ``args`` / ``kwargs``
    args: tuple = ()
    kwargs: Dict = field(default_factory=dict)
    status: TaskStatus = TaskStatus.PENDING
    result: Any = None  # value produced by ``func``
    error: Optional[str] = None  # error text of the last failure, if any
    retries: int = 0  # retry counter
    max_retries: int = 3
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    priority: int = 5

    @property
    def duration(self) -> Optional[float]:
        """Seconds between ``started_at`` and ``completed_at``.

        Returns:
            The elapsed time in seconds, or ``None`` while either timestamp
            is still unset.
        """
        if self.started_at and self.completed_at:
            return (self.completed_at - self.started_at).total_seconds()
        return None

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of the task.

        The callable and its arguments are not included. ``result`` is
        rendered with ``str()``; only a result of ``None`` is reported as
        ``None``, so falsy results such as ``0``, ``False`` or ``""`` are
        preserved instead of being dropped.
        """

        def iso(moment: Optional[datetime]) -> Optional[str]:
            # Timestamps are optional; unset ones serialize as None.
            return moment.isoformat() if moment else None

        return {
            "id": self.id,
            "name": self.name,
            "status": self.status.value,
            "result": None if self.result is None else str(self.result),
            "error": self.error,
            "retries": self.retries,
            "created_at": iso(self.created_at),
            "started_at": iso(self.started_at),
            "completed_at": iso(self.completed_at),
            "duration": self.duration,
        }


class SimpleTaskQueue:
    """In-process priority task queue served by daemon worker threads.

    Tasks are stored in a ``queue.PriorityQueue`` as
    ``(priority, task_id, task)`` tuples, so entries with a smaller
    ``priority`` number are taken first. Can be used as a context manager:
    entering starts the workers, leaving stops them.
    """

    def __init__(self, max_workers: int = 4):
        """Create an idle queue; call :meth:`start` to launch the workers."""
        self._queue: queue.PriorityQueue = queue.PriorityQueue()
        self._tasks: Dict[str, Task] = {}
        self._workers: List[threading.Thread] = []
        self._max_workers = max_workers
        self._running = False
        self._lock = threading.Lock()

    def enqueue(
        self,
        func: Callable,
        *args,
        name: Optional[str] = None,
        priority: int = 5,
        max_retries: int = 3,
        **kwargs
    ) -> str:
        """Register a task and put it on the queue.

        Args:
            func: Callable to run; invoked as ``func(*args, **kwargs)``.
            *args: Positional arguments for ``func``.
            name: Optional label; defaults to ``func.__name__`` (or its
                ``repr`` for callables without a ``__name__``).
            priority: Queue priority; smaller numbers are served first.
            max_retries: Number of retries allowed after a failure.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            The generated task id.
        """
        task_id = str(uuid.uuid4())
        task = Task(
            id=task_id,
            name=name or getattr(func, "__name__", repr(func)),
            func=func,
            args=args,
            kwargs=kwargs,
            priority=priority,
            max_retries=max_retries
        )

        with self._lock:
            self._tasks[task_id] = task

        # Mark the task as queued BEFORE publishing it: once it is on the
        # queue a worker may pick it up immediately, and a later assignment
        # here would overwrite the worker's RUNNING/SUCCESS status.
        task.status = TaskStatus.QUEUED
        self._queue.put((priority, task_id, task))

        return task_id

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task registered under ``task_id``, or ``None``."""
        with self._lock:
            return self._tasks.get(task_id)

    def get_result(self, task_id: str) -> Any:
        """Return the task's result if it succeeded, otherwise ``None``."""
        task = self.get_task(task_id)
        if task and task.status == TaskStatus.SUCCESS:
            return task.result
        return None

    def start(self) -> None:
        """Start worker threads; a no-op if the queue is already running."""
        if self._running:
            return
        self._running = True
        # Workers that outlived a previous stop() (still busy with a long
        # task) resume looping, so only top the pool up to max_workers.
        for worker_id in range(len(self._workers), self._max_workers):
            worker = threading.Thread(target=self._worker_loop, args=(worker_id,))
            worker.daemon = True
            worker.start()
            self._workers.append(worker)

    def stop(self) -> None:
        """Ask workers to exit and wait up to one second for each of them."""
        self._running = False
        for worker in self._workers:
            worker.join(timeout=1)
        # Forget finished threads so a later start() can replace them.
        self._workers = [w for w in self._workers if w.is_alive()]

    def _worker_loop(self, worker_id: int) -> None:
        """Take tasks off the queue and run them until the queue is stopped.

        A failing task is re-queued with a lower priority (``priority + 1``)
        until its retry budget is used up, after which it is marked FAILURE.
        """
        while self._running:
            try:
                priority, task_id, task = self._queue.get(timeout=1)
            except queue.Empty:
                # Wake up periodically so a stop() request is noticed.
                continue

            task.status = TaskStatus.RUNNING
            task.started_at = datetime.now()
            # Clear the timestamp of a previous attempt so ``duration`` does
            # not mix values from two different runs.
            task.completed_at = None

            try:
                outcome = task.func(*task.args, **task.kwargs)
            except Exception as exc:
                task.error = str(exc)
                task.retries += 1
                # Record completion before re-queueing: after put() another
                # worker may already be running the next attempt.
                task.completed_at = datetime.now()

                if task.retries <= task.max_retries:
                    task.status = TaskStatus.RETRY
                    self._queue.put((priority + 1, task_id, task))
                else:
                    task.status = TaskStatus.FAILURE
            else:
                task.result = outcome
                task.completed_at = datetime.now()
                # Status is written last so an observer that sees SUCCESS
                # also sees the result and the completion time.
                task.status = TaskStatus.SUCCESS

    def __enter__(self) -> "SimpleTaskQueue":
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop()


class TaskResult:
    """Handle for polling the state and result of one queued task."""

    def __init__(self, task_id: str, queue: SimpleTaskQueue):
        """Bind the handle to ``task_id`` on the given queue."""
        self._task_id = task_id
        self._queue = queue

    @property
    def id(self) -> str:
        """Id of the task this handle refers to."""
        return self._task_id

    @property
    def status(self) -> TaskStatus:
        """Current task status; PENDING when the queue does not know the id."""
        task = self._queue.get_task(self._task_id)
        return task.status if task else TaskStatus.PENDING

    @property
    def result(self) -> Any:
        """Whatever the queue's ``get_result`` returns for this task."""
        return self._queue.get_result(self._task_id)

    @property
    def ready(self) -> bool:
        """True once the task has reached SUCCESS or FAILURE."""
        return self.status in (TaskStatus.SUCCESS, TaskStatus.FAILURE)

    @property
    def successful(self) -> bool:
        """True when the task status is SUCCESS."""
        return self.status == TaskStatus.SUCCESS

    def get(self, timeout: Optional[float] = None) -> Any:
        """Block until the task is ready and return its result.

        Args:
            timeout: Maximum number of seconds to wait. ``None`` waits
                indefinitely; ``0`` is honoured as "do not wait" instead of
                being mistaken for "no timeout".

        Returns:
            The value of :attr:`result` once the task is ready.

        Raises:
            TimeoutError: If the task is not ready within ``timeout`` seconds.
        """
        # A monotonic clock keeps the wait immune to wall-clock adjustments.
        started = time.monotonic()
        while not self.ready:
            if timeout is not None and (time.monotonic() - started) > timeout:
                raise TimeoutError(f"Task {self._task_id} did not complete in time")
            time.sleep(0.1)
        return self.result


def simple_queue_example():
    """Demonstrate the queue: submit one task, then print its id, status and result."""

    def process_data(data: str) -> str:
        # Simulate one second of work before producing the output.
        time.sleep(1)
        return f"Processed: {data}"

    with SimpleTaskQueue(max_workers=2) as tq:
        submitted_id = tq.enqueue(process_data, "Hello World", priority=1)
        print(f"Task ID: {submitted_id}")

        handle = TaskResult(submitted_id, tq)
        current_status = handle.status
        print(f"Status: {current_status}")

        # Wait long enough for the one-second task to have finished.
        time.sleep(2)
        outcome = handle.result
        print(f"Result: {outcome}")

38.1.2 消息队列基础

python
class Message:
    """Envelope carrying a body, headers, an id and a creation timestamp."""

    def __init__(
        self,
        body: Any,
        headers: Dict = None,
        message_id: str = None
    ):
        """Create a message.

        Falsy ``headers`` are replaced by a new empty dict and a falsy
        ``message_id`` by a freshly generated UUID4 string. The timestamp is
        taken from ``datetime.now()`` at construction time.
        """
        self.body = body
        self.headers = headers if headers else {}
        self.message_id = message_id if message_id else str(uuid.uuid4())
        self.timestamp = datetime.now()

    def to_json(self) -> str:
        """Serialize the message to a JSON string (timestamp in ISO format)."""
        payload = {
            "body": self.body,
            "headers": self.headers,
            "message_id": self.message_id,
            "timestamp": self.timestamp.isoformat()
        }
        return json.dumps(payload)

    @classmethod
    def from_json(cls, data: str) -> "Message":
        """Rebuild a message from the string produced by :meth:`to_json`."""
        fields = json.loads(data)
        restored = cls(
            body=fields["body"],
            headers=fields["headers"],
            message_id=fields["message_id"]
        )
        # The constructor stamps "now"; replace it with the stored time.
        restored.timestamp = datetime.fromisoformat(fields["timestamp"])
        return restored


class MessageQueue:
    """Named in-memory queue that hands each message to registered callbacks."""

    def __init__(self, name: str):
        """Create an empty queue called ``name`` with no consumers."""
        self.name = name
        self._queue: queue.Queue = queue.Queue()
        self._consumers: List[Callable] = []

    def publish(self, message: Message) -> None:
        """Append ``message`` to the queue."""
        self._queue.put(message)

    def consume(self, callback: Callable[[Message], None]) -> None:
        """Register ``callback`` to be invoked for every consumed message."""
        self._consumers.append(callback)

    def start_consuming(self) -> None:
        """Dispatch messages to all consumers in an endless loop.

        This call never returns on its own: when the queue stays empty for
        one second the wait simply starts over.
        """
        while True:
            try:
                incoming = self._queue.get(timeout=1)
                for handler in self._consumers:
                    handler(incoming)
            except queue.Empty:
                # Nothing arrived within the timeout; keep polling.
                continue

    def get_message_count(self) -> int:
        """Return the queue's current ``qsize()``."""
        return self._queue.qsize()


class Exchange:
    """Routes published messages to bound queues.

    Three routing modes are supported, selected by ``exchange_type``:

    * ``DIRECT``  - deliver to queues bound with exactly the routing key.
    * ``FANOUT``  - deliver to every bound queue, ignoring the routing key.
    * ``TOPIC``   - deliver to queues whose binding pattern matches the
      routing key (see :meth:`_match_pattern`).
    """

    DIRECT = "direct"
    FANOUT = "fanout"
    TOPIC = "topic"

    def __init__(self, name: str, exchange_type: str = DIRECT):
        """Create an exchange with no bindings."""
        self.name = name
        self.exchange_type = exchange_type
        self._bindings: Dict[str, List["MessageQueue"]] = {}

    def bind(self, queue: "MessageQueue", routing_key: str = "") -> None:
        """Bind ``queue`` to this exchange under ``routing_key``.

        For topic exchanges ``routing_key`` is the binding pattern.
        """
        self._bindings.setdefault(routing_key, []).append(queue)

    def publish(self, message: "Message", routing_key: str = "") -> None:
        """Deliver ``message`` to every queue selected by the exchange type.

        A queue receives the message at most once per publish, even when it
        is bound several times or several of its patterns match.
        """
        if self.exchange_type == self.FANOUT:
            candidates = [q for bound in self._bindings.values() for q in bound]
        elif self.exchange_type == self.TOPIC:
            candidates = [
                q
                for pattern, bound in self._bindings.items()
                if self._match_pattern(pattern, routing_key)
                for q in bound
            ]
        else:
            candidates = self._bindings.get(routing_key, [])

        # Track delivered queues by identity; queue objects need not be hashable.
        delivered = set()
        for target in candidates:
            if id(target) in delivered:
                continue
            delivered.add(id(target))
            target.publish(message)

    def _match_pattern(self, pattern: str, routing_key: str) -> bool:
        """Return True when the dot-separated ``routing_key`` matches ``pattern``.

        ``*`` stands for exactly one word and ``#`` for zero or more words;
        any other pattern word must equal the key word at that position.
        ``#`` may appear anywhere in the pattern, and the words that follow
        it are still checked.
        """
        pattern_parts = pattern.split(".")
        key_parts = routing_key.split(".")

        # reachable[j] is True when the pattern words processed so far can
        # match exactly the first j words of the routing key.
        reachable = [True] + [False] * len(key_parts)
        for part in pattern_parts:
            if part == "#":
                # Zero or more words: every prefix at least as long as a
                # reachable one becomes reachable.
                seen = False
                advanced = []
                for flag in reachable:
                    seen = seen or flag
                    advanced.append(seen)
            else:
                # Exactly one word is consumed, so the empty prefix can no
                # longer be the end of a match.
                advanced = [False] + [
                    reachable[j] and (part == "*" or part == key_parts[j])
                    for j in range(len(key_parts))
                ]
            reachable = advanced

        return reachable[-1]


class MessageBroker:
    """Registry of named queues and exchanges with convenience wiring methods."""

    def __init__(self):
        """Start with no queues and no exchanges declared."""
        self._queues: Dict[str, MessageQueue] = {}
        self._exchanges: Dict[str, Exchange] = {}

    def declare_queue(self, name: str) -> MessageQueue:
        """Return the queue called ``name``, creating it on first use."""
        declared = self._queues.get(name)
        if declared is None:
            declared = self._queues[name] = MessageQueue(name)
        return declared

    def declare_exchange(self, name: str, exchange_type: str = Exchange.DIRECT) -> Exchange:
        """Return the exchange called ``name``, creating it on first use.

        ``exchange_type`` is only used when the exchange is created; an
        already declared exchange is returned unchanged.
        """
        declared = self._exchanges.get(name)
        if declared is None:
            declared = self._exchanges[name] = Exchange(name, exchange_type)
        return declared

    def bind_queue(self, queue_name: str, exchange_name: str, routing_key: str = "") -> None:
        """Bind a declared queue to a declared exchange.

        Does nothing when either name has not been declared.
        """
        target = self._queues.get(queue_name)
        exchange = self._exchanges.get(exchange_name)
        if not target or not exchange:
            return
        exchange.bind(target, routing_key)

    def publish(self, exchange_name: str, routing_key: str, message: Message) -> None:
        """Publish through the named exchange; unknown exchanges are ignored."""
        exchange = self._exchanges.get(exchange_name)
        if not exchange:
            return
        exchange.publish(message, routing_key)

    def consume(self, queue_name: str, callback: Callable[[Message], None]) -> None:
        """Register ``callback`` on the named queue; unknown queues are ignored."""
        target = self._queues.get(queue_name)
        if not target:
            return
        target.consume(callback)

38.2 Celery任务队列

38.2.1 Celery配置

python
from celery import Celery, Task
from celery.result import AsyncResult
from celery.schedules import crontab
from typing import Any, Dict, List, Optional
import time


class CeleryConfig:
    """Settings object loaded into the Celery app via ``config_from_object``."""

    # Message broker and result store: two databases of a local Redis server.
    broker_url = "redis://localhost:6379/0"
    result_backend = "redis://localhost:6379/1"

    # Tasks and results are exchanged as JSON only.
    task_serializer = "json"
    result_serializer = "json"
    accept_content = ["json"]

    timezone = "Asia/Shanghai"
    enable_utc = True

    task_track_started = True
    # Hard and soft limits; presumably seconds (30 and 25 minutes) - confirm
    # against the Celery configuration reference.
    task_time_limit = 30 * 60
    task_soft_time_limit = 25 * 60

    worker_prefetch_multiplier = 4
    worker_max_tasks_per_child = 1000

    # Queue declarations: each queue uses an exchange and routing key named
    # after itself.
    task_default_queue = "default"
    task_queues = {
        "default": {
            "exchange": "default",
            "routing_key": "default"
        },
        "high_priority": {
            "exchange": "high_priority",
            "routing_key": "high_priority"
        },
        "low_priority": {
            "exchange": "low_priority",
            "routing_key": "low_priority"
        },
        "compute": {
            "exchange": "compute",
            "routing_key": "compute"
        },
        "io": {
            "exchange": "io",
            "routing_key": "io"
        }
    }

    # Task-name patterns mapped to the queue they should be sent to.
    task_routes = {
        "tasks.compute.*": {"queue": "compute"},
        "tasks.io.*": {"queue": "io"},
        "tasks.email.*": {"queue": "high_priority"}
    }

    # Per-task overrides, keyed by task name.
    task_annotations = {
        "tasks.compute.heavy_computation": {
            "rate_limit": "10/m"
        }
    }

    # Periodic tasks: entry name -> task name and schedule (a crontab
    # object or a plain number, as in the health check entry).
    beat_schedule = {
        "cleanup-every-hour": {
            "task": "tasks.cleanup.cleanup_expired_data",
            "schedule": crontab(minute=0),
        },
        "send-daily-report": {
            "task": "tasks.reports.send_daily_report",
            "schedule": crontab(hour=8, minute=0),
        },
        "check-health-every-minute": {
            "task": "tasks.monitoring.check_system_health",
            "schedule": 60.0,
        }
    }


# Create the Celery application named "myapp" and load its settings from
# the CeleryConfig class.
app = Celery("myapp")
app.config_from_object(CeleryConfig)


class BaseTask(Task):
    """Task base class whose lifecycle hooks print a one-line report."""

    def on_success(self, retval: Any, task_id: str, args: tuple, kwargs: dict) -> None:
        """Print the task id together with the returned value."""
        report = f"Task {task_id} succeeded with result: {retval}"
        print(report)

    def on_failure(self, exc: Exception, task_id: str, args: tuple, kwargs: dict, einfo) -> None:
        """Print the task id together with the exception that made it fail."""
        report = f"Task {task_id} failed with error: {exc}"
        print(report)

    def on_retry(self, exc: Exception, task_id: str, args: tuple, kwargs: dict, einfo) -> None:
        """Print the task id together with the exception that caused the retry."""
        report = f"Task {task_id} retrying due to: {exc}"
        print(report)


@app.task(base=BaseTask, bind=True, max_retries=3, default_retry_delay=60)
def process_data(self, data: Dict) -> Dict:
    """Wrap ``data`` with a processing timestamp.

    Returns a dict with the input under ``"processed"`` and ``time.time()``
    under ``"timestamp"``. Any exception raised while building it is handed
    to ``self.retry``.
    """
    try:
        return {"processed": data, "timestamp": time.time()}
    except Exception as failure:
        raise self.retry(exc=failure)


@app.task(base=BaseTask, bind=True)
def send_email(self, to: str, subject: str, body: str) -> bool:
    """Simulate sending an email: print recipient and subject, wait, return True.

    ``body`` is accepted but not used by this placeholder implementation.
    """
    announcement = f"Sending email to {to}: {subject}"
    print(announcement)
    # Stand-in for the time a real delivery would take.
    time.sleep(2)
    return True


@app.task(base=BaseTask, bind=True)
def heavy_computation(self, n: int) -> int:
    """Return the sum of the squares of the integers in ``range(n)``."""
    total = 0
    for value in range(n):
        total += value * value
    return total


class CeleryManager:
    """Thin helper around a Celery app for submitting and inspecting tasks."""

    def __init__(self, app: Celery):
        """Remember the Celery application all operations go through."""
        self.app = app

    def submit_task(
        self,
        task_name: str,
        *args,
        queue: Optional[str] = None,
        priority: Optional[int] = None,
        countdown: Optional[float] = None,
        eta: Optional[datetime] = None,
        expires: Optional[float] = None,
        **kwargs
    ) -> AsyncResult:
        """Send a task by name and return its result handle.

        Args:
            task_name: Registered name of the task to run.
            *args: Positional arguments for the task.
            queue, priority, countdown, eta, expires: Delivery options
                forwarded unchanged to ``app.send_task``.
            **kwargs: Keyword arguments for the task.
        """
        return self.app.send_task(
            task_name,
            args=args,
            kwargs=kwargs,
            queue=queue,
            priority=priority,
            countdown=countdown,
            eta=eta,
            expires=expires
        )

    def get_task_result(self, task_id: str) -> AsyncResult:
        """Return a result handle for ``task_id`` bound to this app."""
        return AsyncResult(task_id, app=self.app)

    def revoke_task(self, task_id: str, terminate: bool = False) -> None:
        """Revoke ``task_id``; ``terminate`` is forwarded to ``control.revoke``."""
        self.app.control.revoke(task_id, terminate=terminate)

    def get_active_tasks(self) -> Dict[str, List[Dict]]:
        """Return ``inspect().active()``, or ``{}`` when it yields nothing.

        The inspect API answers with a mapping keyed by worker name and with
        ``None`` when no worker replies; the latter is normalized to an empty
        dict so callers can always iterate the result.
        """
        return self.app.control.inspect().active() or {}

    def get_scheduled_tasks(self) -> Dict[str, List[Dict]]:
        """Return ``inspect().scheduled()``, or ``{}`` when it yields nothing."""
        return self.app.control.inspect().scheduled() or {}

    def get_worker_stats(self) -> Dict:
        """Return ``inspect().stats()``, or ``{}`` when it yields nothing."""
        return self.app.control.inspect().stats() or {}

38.2.2 任务链与工作流

python
from celery import chain, group, chord, signature


class TaskWorkflow:
    """Helpers that compose tasks with Celery's chain / group / chord primitives."""

    def __init__(self, app: Celery):
        """Keep a reference to the Celery application."""
        self.app = app

    def create_chain(self, *tasks) -> Any:
        """Build a ``chain`` from the given signatures and return it."""
        return chain(*tasks)

    def create_group(self, *tasks) -> Any:
        """Build a ``group`` from the given signatures and return it."""
        return group(*tasks)

    def create_chord(self, header_tasks: list, body_task) -> Any:
        """Build a chord from ``header_tasks`` and call it with ``body_task``.

        NOTE(review): unlike create_chain/create_group, the chord object is
        called here and the call's result is returned - presumably this
        dispatches the chord right away; confirm that this is intended.
        """
        header = chord(header_tasks)
        return header(body_task)

    def execute_pipeline(self, data: Dict) -> AsyncResult:
        """Run process_data -> validate_data -> save_data as one chain."""
        steps = (
            process_data.s(data),
            validate_data.s(),
            save_data.s(),
        )
        return chain(*steps)()

    def execute_parallel(self, items: List[Dict]) -> AsyncResult:
        """Run ``process_data`` for every item as one group."""
        fan_out = group(process_data.s(entry) for entry in items)
        return fan_out()

    def execute_map_reduce(self, items: List[Dict]) -> AsyncResult:
        """Process every item as a group, then pass the results to ``aggregate_results``."""
        mappers = group(process_data.s(entry) for entry in items)
        reducer = aggregate_results.s()
        workflow = chord(mappers)
        return workflow(reducer)


@app.task
def validate_data(data: Dict) -> Dict:
    """Return ``data`` unchanged when its ``"processed"`` entry is truthy.

    Raises:
        ValueError: If the entry is missing or falsy.
    """
    processed = data.get("processed")
    if processed:
        return data
    raise ValueError("Data not processed")


@app.task
def save_data(data: Dict) -> str:
    """Print the data and return a ``saved_<id>`` token built from ``id(data)``."""
    print(f"Saving data: {data}")
    token = id(data)
    return f"saved_{token}"


@app.task
def aggregate_results(results: List[Dict]) -> Dict:
    """Bundle ``results`` with their count under ``"total"`` and ``"results"``."""
    summary = {"total": len(results)}
    summary["results"] = results
    return summary


class TaskSignature:
    """Factory helpers for Celery task signatures bound to one app."""

    def __init__(self, app: Celery):
        """Keep a reference to the Celery application."""
        self.app = app

    def create_signature(
        self,
        task_name: str,
        args: tuple = (),
        kwargs: dict = None,
        immutable: bool = False
    ) -> signature:
        """Create a signature for ``task_name``.

        A falsy ``kwargs`` is replaced by an empty dict; ``immutable`` and
        the app are forwarded to ``signature`` unchanged.
        """
        options = {
            "args": args,
            "kwargs": kwargs or {},
            "immutable": immutable,
            "app": self.app,
        }
        return signature(task_name, **options)

    def create_partial(self, task_name: str, *args, **kwargs) -> signature:
        """Create a signature via the app with the given arguments pre-filled."""
        partial_sig = self.app.signature(task_name, args=args, kwargs=kwargs)
        return partial_sig

    def apply_async_with_options(
        self,
        sig: signature,
        queue: str = None,
        priority: int = None,
        retry: bool = True,
        retry_policy: dict = None
    ) -> AsyncResult:
        """Call ``sig.apply_async`` with the given delivery options."""
        delivery = {
            "queue": queue,
            "priority": priority,
            "retry": retry,
            "retry_policy": retry_policy,
        }
        return sig.apply_async(**delivery)

38.3 RQ任务队列

38.3.1 RQ基础

python
import time
from typing import Any, Callable, Dict, List, Optional

from redis import Redis
from rq import Queue, Worker, job
# `from rq import job` binds the rq.job module; rebind the name to the
# decorator that `@job(...)` below expects.
from rq.decorators import job
from rq.exceptions import NoSuchJobError
from rq.job import Job
from rq.registry import FailedJobRegistry, FinishedJobRegistry, StartedJobRegistry


class RQManager:
    """Convenience facade over RQ queues, jobs, registries and workers."""

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Connect to Redis at ``redis_url`` and start with an empty queue cache."""
        self.redis = Redis.from_url(redis_url)
        self._queues: Dict[str, Queue] = {}

    def get_queue(self, name: str = "default") -> Queue:
        """Return the cached ``Queue`` for ``name``, creating it on first use."""
        if name not in self._queues:
            self._queues[name] = Queue(name, connection=self.redis)
        return self._queues[name]

    def enqueue(
        self,
        func: Callable,
        *args,
        queue: str = "default",
        timeout: Optional[int] = None,
        result_ttl: int = 500,
        failure_ttl: Optional[int] = None,
        description: Optional[str] = None,
        **kwargs
    ) -> Job:
        """Enqueue ``func(*args, **kwargs)`` on the named queue.

        Args:
            func: Callable to execute in a worker.
            *args: Positional arguments for ``func``.
            queue: Name of the target queue.
            timeout: Maximum job runtime; passed to RQ as ``job_timeout``.
            result_ttl: How long a successful result is kept.
            failure_ttl: How long a failed job is kept.
            description: Optional human-readable job description.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            The created job.
        """
        q = self.get_queue(queue)
        # RQ's option is named ``job_timeout``; a plain ``timeout`` keyword
        # is not an enqueue option and would be handed to ``func`` instead.
        return q.enqueue(
            func,
            *args,
            job_timeout=timeout,
            result_ttl=result_ttl,
            failure_ttl=failure_ttl,
            description=description,
            **kwargs
        )

    def enqueue_at(
        self,
        datetime_obj: datetime,
        func: Callable,
        *args,
        queue: str = "default",
        **kwargs
    ) -> Job:
        """Schedule ``func`` to be enqueued at ``datetime_obj``.

        Scheduling is provided by ``Queue.enqueue_at``; ``rq.scheduler`` has
        no module-level ``enqueue_at`` function.
        """
        return self.get_queue(queue).enqueue_at(datetime_obj, func, *args, **kwargs)

    def enqueue_in(
        self,
        time_delta: timedelta,
        func: Callable,
        *args,
        queue: str = "default",
        **kwargs
    ) -> Job:
        """Schedule ``func`` to be enqueued after ``time_delta`` has passed.

        Scheduling is provided by ``Queue.enqueue_in``; ``rq.scheduler`` has
        no module-level ``enqueue_in`` function.
        """
        return self.get_queue(queue).enqueue_in(time_delta, func, *args, **kwargs)

    def get_job(self, job_id: str) -> Optional[Job]:
        """Fetch a job by id; returns ``None`` when no such job exists."""
        try:
            return Job.fetch(job_id, connection=self.redis)
        except NoSuchJobError:
            # Only "job not found" maps to None; connection problems and
            # interrupts propagate instead of being silently swallowed.
            return None

    def get_job_status(self, job_id: str) -> str:
        """Return the job's status, or ``"unknown"`` when it cannot be found."""
        found = self.get_job(job_id)
        if found:
            return found.get_status()
        return "unknown"

    def get_job_result(self, job_id: str) -> Any:
        """Return the job's result if it has finished, otherwise ``None``."""
        found = self.get_job(job_id)
        if found and found.is_finished:
            return found.result
        return None

    def cancel_job(self, job_id: str) -> None:
        """Cancel the job if it exists; unknown ids are ignored."""
        found = self.get_job(job_id)
        if found:
            found.cancel()

    def requeue_job(self, job_id: str) -> Optional[Job]:
        """Requeue the job and return the result of ``Job.requeue()``.

        Returns ``None`` when the job does not exist.
        """
        found = self.get_job(job_id)
        if found:
            return found.requeue()
        return None

    def _fetch_registry_jobs(self, registry) -> List[Job]:
        """Load every job listed in ``registry`` that still exists."""
        fetched = Job.fetch_many(registry.get_job_ids(), connection=self.redis)
        # fetch_many yields None for ids whose job data has expired in the
        # meantime; skip those instead of failing the whole listing.
        return [entry for entry in fetched if entry is not None]

    def get_failed_jobs(self, queue: str = "default") -> List[Job]:
        """Return the jobs in the queue's failed-job registry."""
        return self._fetch_registry_jobs(FailedJobRegistry(queue=self.get_queue(queue)))

    def get_finished_jobs(self, queue: str = "default") -> List[Job]:
        """Return the jobs in the queue's finished-job registry."""
        return self._fetch_registry_jobs(FinishedJobRegistry(queue=self.get_queue(queue)))

    def get_queue_length(self, queue: str = "default") -> int:
        """Return the ``count`` of the named queue."""
        return self.get_queue(queue).count

    def get_workers(self) -> List[Worker]:
        """Return all workers registered on this Redis connection."""
        return Worker.all(connection=self.redis)

    def get_worker_stats(self, worker_name: str) -> Dict:
        """Return a summary dict for the named worker, or ``{}`` if not found."""
        for worker in Worker.all(connection=self.redis):
            if worker.name == worker_name:
                return {
                    "name": worker.name,
                    "state": worker.get_state(),
                    "current_job": worker.get_current_job(),
                    "successful_jobs": worker.successful_job_count,
                    "failed_jobs": worker.failed_job_count,
                    "queues": [q.name for q in worker.queues]
                }
        return {}


# `job` must be the decorator from rq.decorators (see the import block);
# the name exported by `from rq import job` is the rq.job module, which is
# not callable.
@job("default", connection=Redis())
def process_file(file_path: str) -> Dict:
    """Simulate processing a file.

    Waits two seconds and returns a dict with the path under ``"file"`` and
    ``"status"`` set to ``"processed"``. The decorator ties the function to
    the ``"default"`` queue.
    """
    time.sleep(2)
    return {"file": file_path, "status": "processed"}


# `job` must be the decorator from rq.decorators (see the import block);
# the name exported by `from rq import job` is the rq.job module, which is
# not callable.
@job("high_priority", connection=Redis())
def send_notification(user_id: int, message: str) -> bool:
    """Print a notification line for ``user_id`` and return True.

    The decorator ties the function to the ``"high_priority"`` queue.
    """
    print(f"Sending notification to user {user_id}: {message}")
    return True


class RQWorkerManager:
    """Creates, starts, counts and stops RQ workers on one Redis connection."""

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Connect to Redis at ``redis_url``."""
        self.redis = Redis.from_url(redis_url)

    def create_worker(
        self,
        queues: Optional[List[str]] = None,
        name: Optional[str] = None
    ) -> Worker:
        """Build a worker listening on the named queues (``["default"]`` if omitted)."""
        queue_names = queues or ["default"]
        queue_objects = [Queue(q, connection=self.redis) for q in queue_names]
        return Worker(
            queue_objects,
            connection=self.redis,
            name=name
        )

    def start_worker(
        self,
        queues: Optional[List[str]] = None,
        name: Optional[str] = None
    ) -> None:
        """Create a worker and run its ``work()`` loop in the calling thread."""
        worker = self.create_worker(queues, name)
        worker.work()

    def get_worker_count(self) -> int:
        """Return the number of workers registered on this connection."""
        return len(Worker.all(connection=self.redis))

    def stop_all_workers(self) -> None:
        """Ask every registered worker to shut down.

        ``Worker.request_stop`` is a signal handler that expects
        ``(signum, frame)`` and only affects the process it runs in, so it
        cannot be used on the worker objects loaded from Redis here. The
        shutdown command is sent through Redis instead.
        """
        # Imported locally so the rest of the module still loads on RQ
        # versions that do not ship rq.command.
        from rq.command import send_shutdown_command

        for worker in Worker.all(connection=self.redis):
            send_shutdown_command(self.redis, worker.name)

38.3.2 RQ任务依赖

python
class RQTaskPipeline:
    """Builds sequential, parallel and callback job arrangements on one RQ queue."""

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Connect to Redis and use a ``Queue`` created with default settings."""
        self.redis = Redis.from_url(redis_url)
        self.queue = Queue(connection=self.redis)

    def execute_pipeline(self, tasks: List[Callable], *args, **kwargs) -> Job:
        """Enqueue ``tasks`` so that each one depends on the one before it.

        Only the first task receives ``*args`` / ``**kwargs``; every later
        task is enqueued without arguments and with ``depends_on`` set to
        its predecessor.

        Returns:
            The job of the last task, or ``None`` when ``tasks`` is empty.
        """
        tail = None

        for step in tasks:
            if tail is not None:
                tail = self.queue.enqueue(step, depends_on=tail)
            else:
                tail = self.queue.enqueue(step, *args, **kwargs)

        return tail

    def execute_parallel(self, tasks: List[tuple]) -> List[Job]:
        """Enqueue several independent jobs and return them in input order.

        Each entry is either a bare callable or a tuple
        ``(func[, arg[, kwargs]])``. The tuple's second element is passed to
        ``func`` as ONE positional argument (it is not unpacked); the third
        element, when present, supplies the keyword arguments.
        """
        submitted = []
        for entry in tasks:
            if not isinstance(entry, tuple):
                func, call_args, call_kwargs = entry, (), {}
            else:
                func = entry[0]
                # Slicing keeps the element wrapped in a one-item tuple and
                # yields an empty tuple when there is no second element.
                call_args = entry[1:2] or ()
                call_kwargs = entry[2] if len(entry) > 2 else {}

            submitted.append(self.queue.enqueue(func, *call_args, **call_kwargs))

        return submitted

    def execute_with_callback(
        self,
        main_task: Callable,
        callback_task: Callable,
        *args,
        **kwargs
    ) -> Job:
        """Enqueue ``main_task`` and a callback job that depends on it.

        ``*args`` / ``**kwargs`` go to ``main_task``; the callback is
        enqueued without arguments.

        Returns:
            The callback job.
        """
        primary = self.queue.enqueue(main_task, *args, **kwargs)
        return self.queue.enqueue(callback_task, depends_on=primary)


class RQRetryHandler:
    """Retry policy that counts attempts in ``job.meta["retry_count"]``."""

    def __init__(self, max_retries: int = 3, retry_intervals: List[int] = None):
        """Store the retry budget and the per-attempt intervals.

        A falsy ``retry_intervals`` falls back to ``[60, 300, 900]``.
        """
        self.max_retries = max_retries
        self.retry_intervals = retry_intervals or [60, 300, 900]

    def should_retry(self, job: Job, exception: Exception) -> bool:
        """Return True while the job's retry count is below ``max_retries``.

        ``exception`` is accepted but does not influence the decision.
        """
        return job.meta.get("retry_count", 0) < self.max_retries

    def get_retry_interval(self, retry_count: int) -> int:
        """Return the interval for ``retry_count``; past the end, the last one."""
        schedule = self.retry_intervals
        position = retry_count if retry_count < len(schedule) else -1
        return schedule[position]

    def handle_failure(self, job: Job, exception: Exception) -> None:
        """Requeue the job if retries remain, otherwise print a failure line."""
        attempts = job.meta.get("retry_count", 0)

        if not self.should_retry(job, exception):
            print(f"Job {job.id} failed after {attempts} retries")
            return

        attempts += 1
        job.meta["retry_count"] = attempts
        job.save()

        # NOTE(review): the interval is looked up but never applied - the
        # job is requeued immediately. Confirm whether a delay was intended.
        interval = self.get_retry_interval(attempts - 1)
        job.requeue()

38.4 定时任务调度

38.4.1 APScheduler

python
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.jobstores.redis import RedisJobStore
from apscheduler.executors.pool import ThreadPoolExecutor, ProcessPoolExecutor
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional


class SchedulerConfig:
    """Scheduler settings held as class attributes.

    The job store and executor objects below are created once, when this
    class body is executed.
    """

    # Job stores by alias: an in-memory store and a Redis-backed store on
    # localhost:6379.
    jobstores = {
        "default": MemoryJobStore(),
        "redis": RedisJobStore(jobs_key="apscheduler.jobs", run_times_key="apscheduler.run_times", host="localhost", port=6379)
    }

    # Executors by alias: a thread pool built with 20 and a process pool
    # built with 5 (presumably the maximum number of workers - confirm).
    executors = {
        "default": ThreadPoolExecutor(20),
        "processpool": ProcessPoolExecutor(5)
    }

    # Defaults applied to jobs that do not override them.
    job_defaults = {
        "coalesce": True,
        "max_instances": 3,
        "misfire_grace_time": 60
    }


class TaskScheduler:
    """Wrapper around an APScheduler ``BackgroundScheduler``.

    Offers helpers for interval, cron and one-off jobs and keeps a local
    ``job id -> job`` map of everything added through this wrapper.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create and configure the scheduler.

        Args:
            config: Options passed to ``scheduler.configure``. When omitted,
                the class-level settings of ``SchedulerConfig`` are used;
                pass ``{}`` to keep APScheduler's own defaults.
        """
        self.scheduler = BackgroundScheduler()
        self._configure(config if config is not None else self._default_config())
        self._jobs: Dict[str, Any] = {}

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        """Collect the public class attributes of ``SchedulerConfig``.

        The settings live on the class, not on instances, so
        ``SchedulerConfig().__dict__`` is always empty; they are read from
        the class namespace instead. The job store and executor objects are
        the ones created in the class body and are therefore shared by every
        scheduler built from the defaults.
        """
        return {
            key: value
            for key, value in vars(SchedulerConfig).items()
            if not key.startswith("_")
        }

    def _configure(self, config: Dict) -> None:
        """Apply ``config`` to the underlying scheduler."""
        self.scheduler.configure(**config)

    def start(self) -> None:
        """Start the underlying scheduler."""
        self.scheduler.start()

    def stop(self, wait: bool = True) -> None:
        """Shut the scheduler down; ``wait`` is forwarded to ``shutdown``."""
        self.scheduler.shutdown(wait=wait)

    def add_interval_job(
        self,
        func: Callable,
        job_id: Optional[str] = None,
        seconds: Optional[int] = None,
        minutes: Optional[int] = None,
        hours: Optional[int] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        **kwargs
    ) -> str:
        """Add a job that runs repeatedly at a fixed interval.

        Args:
            func: Callable to run.
            job_id: Explicit job id, if any.
            seconds, minutes, hours: Interval components; omitted ones
                count as zero.
            start_date, end_date: Optional bounds for the schedule.
            **kwargs: Extra options forwarded to ``scheduler.add_job``.

        Returns:
            The id of the added job.
        """
        # IntervalTrigger builds a timedelta from these values, which
        # rejects None, so unset components must be sent as 0.
        trigger = IntervalTrigger(
            seconds=seconds or 0,
            minutes=minutes or 0,
            hours=hours or 0,
            start_date=start_date,
            end_date=end_date
        )
        return self._register(func, trigger, job_id, kwargs)

    def add_cron_job(
        self,
        func: Callable,
        job_id: Optional[str] = None,
        year: Optional[str] = None,
        month: Optional[str] = None,
        day: Optional[str] = None,
        week: Optional[str] = None,
        day_of_week: Optional[str] = None,
        hour: Optional[str] = None,
        minute: Optional[str] = None,
        second: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        **kwargs
    ) -> str:
        """Add a job that runs on a cron-style schedule.

        The cron fields are forwarded to ``CronTrigger`` unchanged.

        Returns:
            The id of the added job.
        """
        trigger = CronTrigger(
            year=year,
            month=month,
            day=day,
            week=week,
            day_of_week=day_of_week,
            hour=hour,
            minute=minute,
            second=second,
            start_date=start_date,
            end_date=end_date
        )
        return self._register(func, trigger, job_id, kwargs)

    def add_date_job(
        self,
        func: Callable,
        run_date: datetime,
        job_id: Optional[str] = None,
        **kwargs
    ) -> str:
        """Add a job that runs once at ``run_date`` and return its id."""
        return self._register(func, DateTrigger(run_date), job_id, kwargs)

    def _register(
        self,
        func: Callable,
        trigger: Any,
        job_id: Optional[str],
        options: Dict
    ) -> str:
        """Add the job to the scheduler, remember it locally and return its id."""
        job = self.scheduler.add_job(
            func,
            trigger=trigger,
            id=job_id,
            **options
        )
        self._jobs[job.id] = job
        return job.id

    def remove_job(self, job_id: str) -> None:
        """Remove a job added through this wrapper; unknown ids are ignored."""
        if job_id in self._jobs:
            self.scheduler.remove_job(job_id)
            del self._jobs[job_id]

    def pause_job(self, job_id: str) -> None:
        """Pause the job with the given id."""
        self.scheduler.pause_job(job_id)

    def resume_job(self, job_id: str) -> None:
        """Resume the job with the given id."""
        self.scheduler.resume_job(job_id)

    def get_job(self, job_id: str) -> Optional[Any]:
        """Return what the scheduler's ``get_job`` yields for ``job_id``."""
        return self.scheduler.get_job(job_id)

    def get_jobs(self) -> List[Any]:
        """Return all jobs known to the scheduler."""
        return self.scheduler.get_jobs()

    def get_next_run_time(self, job_id: str) -> Optional[datetime]:
        """Return the job's ``next_run_time``, or ``None`` if the job is unknown."""
        job = self.get_job(job_id)
        return job.next_run_time if job else None

    def modify_job(self, job_id: str, **changes) -> None:
        """Apply ``changes`` to the job via the scheduler's ``modify_job``."""
        self.scheduler.modify_job(job_id, **changes)

    def print_jobs(self) -> None:
        """Print id, next run time and trigger of every scheduled job."""
        for job in self.get_jobs():
            print(f"Job ID: {job.id}")
            print(f"  Next run: {job.next_run_time}")
            print(f"  Trigger: {job.trigger}")


class CronExpressionParser:
    """Helpers for five-field cron expressions (minute hour day month day_of_week)."""

    @staticmethod
    def parse(expression: str) -> Dict[str, str]:
        """Split a cron expression into its named fields.

        Fields may be separated by any amount of whitespace.

        Returns:
            A dict with the keys ``minute``, ``hour``, ``day``, ``month``
            and ``day_of_week``.

        Raises:
            ValueError: If the expression does not have exactly five fields.
        """
        parts = expression.split()
        if len(parts) != 5:
            raise ValueError("Invalid cron expression")

        field_names = ("minute", "hour", "day", "month", "day_of_week")
        return dict(zip(field_names, parts))

    @staticmethod
    def to_trigger(expression: str) -> "CronTrigger":
        """Build a ``CronTrigger`` from the parsed fields of ``expression``."""
        parsed = CronExpressionParser.parse(expression)
        return CronTrigger(**parsed)

    @staticmethod
    def describe(expression: str) -> str:
        """Return a human-readable description of a cron expression.

        A handful of common schedules have fixed descriptions; anything else
        is reported as ``"Custom schedule: <expression>"``. Whitespace is
        normalized before the lookup, so every spelling that ``parse``
        accepts is recognised here as well.
        """
        descriptions = {
            "* * * * *": "Every minute",
            "0 * * * *": "Every hour",
            "0 0 * * *": "Every day at midnight",
            "0 9 * * 1-5": "Every weekday at 9 AM",
            "0 0 1 * *": "On the first day of every month",
            "0 0 1 1 *": "Once a year on January 1st"
        }
        normalized = " ".join(expression.split())
        return descriptions.get(normalized, f"Custom schedule: {expression}")


def scheduled_task_example():
    """Demo: start a scheduler, add an interval, a cron and a one-off job, then list them."""
    task_scheduler = TaskScheduler()
    task_scheduler.start()

    def cleanup_task():
        print(f"Running cleanup at {datetime.now()}")

    def report_task():
        print(f"Generating report at {datetime.now()}")

    # Recurring cleanup every 30 seconds.
    task_scheduler.add_interval_job(cleanup_task, seconds=30, job_id="cleanup")

    # Report scheduled with hour=8, minute=0.
    task_scheduler.add_cron_job(report_task, hour=8, minute=0, job_id="daily_report")

    # Single run five minutes from now.
    task_scheduler.add_date_job(
        lambda: print("One-time task executed!"),
        datetime.now() + timedelta(minutes=5),
        job_id="one_time",
    )

    task_scheduler.print_jobs()

38.4.2 任务监控

python
class TaskMonitor:
    """Collect per-job execution statistics and derive simple health metrics.

    Statistics are kept in memory, keyed by job id.  Each entry stores run
    counters, the accumulated duration and the timestamps of the most
    recent run, success and failure.
    """

    def __init__(self, scheduler: "TaskScheduler"):
        """Create a monitor bound to ``scheduler``.

        Args:
            scheduler: Stored on ``self.scheduler``; not otherwise used by
                this class.
        """
        self.scheduler = scheduler
        self._stats: Dict[str, Dict] = {}

    @staticmethod
    def _empty_stats() -> Dict:
        """Return a fresh statistics record for a job with no runs yet."""
        return {
            "total_runs": 0,
            "successful_runs": 0,
            "failed_runs": 0,
            "total_duration": 0,
            "last_run": None,
            "last_success": None,
            "last_failure": None,
        }

    def record_execution(self, job_id: str, success: bool, duration: float) -> None:
        """Record the outcome of one execution of ``job_id``.

        Args:
            job_id: Identifier of the job that ran.
            success: Whether the run succeeded.
            duration: Run time to add to the job's accumulated duration.
        """
        if job_id not in self._stats:
            self._stats[job_id] = self._empty_stats()

        # One timestamp per call, so "last_run" always equals the matching
        # "last_success" / "last_failure" value.
        now = datetime.now()

        stats = self._stats[job_id]
        stats["total_runs"] += 1
        stats["total_duration"] += duration
        stats["last_run"] = now

        if success:
            stats["successful_runs"] += 1
            stats["last_success"] = now
        else:
            stats["failed_runs"] += 1
            stats["last_failure"] = now

    def get_stats(self, job_id: Optional[str] = None) -> Dict:
        """Return a snapshot of the recorded statistics.

        Args:
            job_id: When given (and non-empty), return only that job's
                record, or ``{}`` if the job is unknown.  Otherwise return a
                mapping of every job id to its record.

        Returns:
            Copies of the internal records; mutating the result does not
            affect the monitor.
        """
        if job_id:
            return dict(self._stats.get(job_id, {}))
        return {name: dict(record) for name, record in self._stats.items()}

    def get_success_rate(self, job_id: str) -> float:
        """Return successful runs / total runs, or ``0.0`` when there are no runs."""
        stats = self._stats.get(job_id)
        if not stats or stats["total_runs"] == 0:
            return 0.0
        return stats["successful_runs"] / stats["total_runs"]

    def get_average_duration(self, job_id: str) -> float:
        """Return the mean recorded duration, or ``0.0`` when there are no runs."""
        stats = self._stats.get(job_id)
        if not stats or stats["total_runs"] == 0:
            return 0.0
        return stats["total_duration"] / stats["total_runs"]

    def get_health_status(self, job_id: str) -> str:
        """Classify a job by its success rate.

        Returns:
            ``"unknown"`` when no run has been recorded for the job,
            ``"healthy"`` for a success rate of at least 95%, ``"warning"``
            for at least 80%, and ``"critical"`` otherwise.
        """
        stats = self._stats.get(job_id)
        if not stats or stats["total_runs"] == 0:
            # No data is not the same as a 0% success rate; do not raise a
            # false "critical" for a job that simply has not run yet.
            return "unknown"

        success_rate = self.get_success_rate(job_id)
        if success_rate >= 0.95:
            return "healthy"
        if success_rate >= 0.8:
            return "warning"
        return "critical"


class TaskAlertManager:
    """Turn unhealthy job statuses reported by a monitor into alerts.

    Alerts are plain dicts with the keys ``job_id``, ``level``, ``message``
    and ``timestamp``.  They are kept in an in-memory history and passed to
    every registered handler.
    """

    # Message template per alerting health level; levels that are not
    # listed here produce no alert.
    _MESSAGE_TEMPLATES = {
        "critical": "Job {job_id} has low success rate: {rate:.2%}",
        "warning": "Job {job_id} success rate is declining: {rate:.2%}",
    }

    def __init__(self, monitor: "TaskMonitor"):
        """Create a manager reading job health from ``monitor``."""
        self.monitor = monitor
        self._alerts: List[Dict] = []
        self._handlers: List[Callable] = []

    def add_alert_handler(self, handler: Callable) -> None:
        """Register a callable invoked with each new alert dict."""
        self._handlers.append(handler)

    def check_alerts(self) -> List[Dict]:
        """Evaluate every monitored job and raise alerts for unhealthy ones.

        Returns:
            The alerts created by this call, in job iteration order.

        Raises:
            Whatever a handler raises; handler exceptions propagate to the
            caller.  All new alerts are stored in the history before any
            handler runs, so a failing handler cannot cause alerts to be
            lost.
        """
        new_alerts = []

        for job_id in self.monitor.get_stats():
            level = self.monitor.get_health_status(job_id)
            template = self._MESSAGE_TEMPLATES.get(level)
            if template is None:
                continue
            new_alerts.append({
                "job_id": job_id,
                "level": level,
                "message": template.format(
                    job_id=job_id,
                    rate=self.monitor.get_success_rate(job_id),
                ),
                "timestamp": datetime.now(),
            })

        # Record first, notify second: the history must be complete even if
        # a handler raises part-way through the dispatch.
        self._alerts.extend(new_alerts)
        for alert in new_alerts:
            for handler in self._handlers:
                handler(alert)

        return new_alerts

    def get_alerts(self, level: Optional[str] = None) -> List[Dict]:
        """Return stored alerts, optionally restricted to one ``level``.

        The returned list is a new list; the alert dicts themselves are the
        stored objects.
        """
        if level:
            return [a for a in self._alerts if a["level"] == level]
        return self._alerts.copy()

    def clear_alerts(self) -> None:
        """Discard the stored alert history."""
        self._alerts.clear()

38.5 知识图谱

38.5.1 任务队列架构体系

┌─────────────────────────────────────────────────────────────────────┐
│                      任务队列系统架构                                 │
├─────────────────────────────────────────────────────────────────────┤
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      生产者层 (Producer)                      │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │ Web应用  │ │ API服务  │ │ 定时任务  │ │ 手动触发  │       │   │
│  │  │ Flask    │ │ FastAPI  │ │ CeleryBeat│ │ CLI命令  │       │   │
│  │  └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘       │   │
│  └───────┼────────────┼────────────┼────────────┼──────────────┘   │
│          │            │            │            │                   │
│          └────────────┴────────────┴────────────┘                   │
│                                │                                    │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      消息代理层 (Broker)                      │   │
│  │  ┌──────────────────────────────────────────────────────┐   │   │
│  │  │                    Redis / RabbitMQ                    │   │   │
│  │  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │   │   │
│  │  │  │ 队列管理  │ │ 消息路由  │ │ 持久化   │ │ 集群支持  │ │   │   │
│  │  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘ │   │   │
│  │  └──────────────────────────────────────────────────────┘   │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                │                                    │
│          ┌─────────────────────┼─────────────────────┐             │
│          ▼                     ▼                     ▼             │
│  ┌──────────────┐      ┌──────────────┐      ┌──────────────┐     │
│  │   Worker 1   │      │   Worker 2   │      │   Worker N   │     │
│  │  ┌────────┐  │      │  ┌────────┐  │      │  ┌────────┐  │     │
│  │  │ 任务执行│  │      │  │ 任务执行│  │      │  │ 任务执行│  │     │
│  │  │ 进程池  │  │      │  │ 协程池  │  │      │  │ 线程池  │  │     │
│  │  └────────┘  │      │  └────────┘  │      │  └────────┘  │     │
│  └───────┬──────┘      └───────┬──────┘      └───────┬──────┘     │
│          │                     │                     │             │
│          └─────────────────────┼─────────────────────┘             │
│                                ▼                                    │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                      结果存储层 (Backend)                     │   │
│  │  ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐       │   │
│  │  │  Redis   │ │  RPC     │ │ Database │ │  文件    │       │   │
│  │  │ 结果缓存  │ │ 远程调用 │ │ 持久存储  │ │ 日志存储 │       │   │
│  │  └──────────┘ └──────────┘ └──────────┘ └──────────┘       │   │
│  └─────────────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘

38.5.2 任务生命周期

┌─────────────────────────────────────────────────────────────────────┐
│                        任务生命周期流程                               │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│   ┌──────────┐                                                      │
│   │ 创建任务  │                                                      │
│   │ PENDING  │                                                      │
│   └────┬─────┘                                                      │
│        │                                                            │
│        ▼                                                            │
│   ┌──────────┐    ┌──────────┐    ┌──────────┐                    │
│   │ 发送队列  │───▶│ 等待执行  │───▶│ 被Worker │                    │
│   │ QUEUED   │    │ SCHEDULED│    │ 获取     │                    │
│   └──────────┘    └──────────┘    └────┬─────┘                    │
│                                        │                           │
│                          ┌─────────────┴─────────────┐             │
│                          ▼                           ▼             │
│                    ┌──────────┐               ┌──────────┐        │
│                    │ 开始执行  │               │ 任务撤销  │        │
│                    │ RUNNING  │               │ REVOKED  │        │
│                    └────┬─────┘               └──────────┘        │
│                         │                                         │
│              ┌──────────┼──────────┐                              │
│              ▼          ▼          ▼                              │
│        ┌──────────┐ ┌──────────┐ ┌──────────┐                    │
│        │ 执行成功  │ │ 执行失败  │ │ 执行超时  │                    │
│        │ SUCCESS  │ │ FAILURE  │ │ TIMEOUT  │                    │
│        └────┬─────┘ └────┬─────┘ └────┬─────┘                    │
│             │            │            │                           │
│             │            ▼            │                           │
│             │      ┌──────────┐      │                           │
│             │      │ 重试判断  │      │                           │
│             │      └────┬─────┘      │                           │
│             │           │            │                           │
│             │    ┌──────┴──────┐     │                           │
│             │    ▼             ▼     │                           │
│             │ ┌────────┐ ┌────────┐ │                           │
│             │ │ 重试   │ │ 放弃   │ │                           │
│             │ │ RETRY  │ │ FAILURE│ │                           │
│             │ └───┬────┘ └────────┘ │                           │
│             │     │                 │                           │
│             │     └─────────────────┘                           │
│             │                                                     │
│             ▼                                                     │
│      ┌──────────────┐                                            │
│      │   结果存储    │                                            │
│      │   Backend    │                                            │
│      └──────────────┘                                            │
│                                                                   │
└─────────────────────────────────────────────────────────────────────┘

38.6 技术选型指南

38.6.1 任务队列框架选型

框架 | 适用场景 | 复杂度 | 性能 | 功能丰富度 | 推荐指数
Celery企业级分布式任务★★★★★★★★★★
RQ简单任务队列★★★☆☆★★★★☆
Dramatiq现代任务队列★★★★☆★★★★☆
Huey轻量级任务队列★★★☆☆★★★☆☆
Dask计算密集型任务极高★★★★★★★★★☆

38.6.2 消息代理选型

消息代理 | 吞吐量 | 持久化 | 协议支持 | 运维复杂度 | 推荐场景
Redis极高可选自定义高性能、简单场景
RabbitMQ原生AMQP企业级、复杂路由
Kafka极高原生自定义大数据、流处理
SQS原生AWS极低云原生应用

38.6.3 定时任务方案选型

方案 | 精度 | 分布式支持 | 持久化 | 动态调度 | 适用场景
Celery Beat秒级分布式定时任务
APScheduler秒级单机定时任务
Cron分钟级系统级定时任务
Systemd Timer秒级Linux系统服务

38.7 常见问题与解决方案

38.7.1 任务重试策略

python
from celery import Celery
from celery.exceptions import Retry
import time

# Celery application named 'tasks'; the @app.task decorators below register
# their tasks on it.  No broker or result backend is passed here.
app = Celery('tasks')

class RetryStrategy:
    """Example Celery tasks illustrating different retry strategies."""

    @staticmethod
    @app.task(bind=True, max_retries=5, default_retry_delay=60)
    def task_with_fixed_delay(self, data):
        """Retry with a fixed delay.

        Any exception from ``process_data`` triggers ``self.retry``; the
        delay comes from the task's ``default_retry_delay`` option.
        """
        try:
            return process_data(data)
        except Exception as exc:
            raise self.retry(exc=exc)

    @staticmethod
    @app.task(bind=True, max_retries=5)
    def task_with_exponential_backoff(self, data):
        """Retry with exponential backoff.

        The countdown is ``2 ** retries`` seconds, capped at 300 seconds.
        """
        try:
            return process_data(data)
        except Exception as exc:
            retry_count = self.request.retries
            backoff = min(2 ** retry_count, 300)  # cap the delay at 300 seconds
            raise self.retry(exc=exc, countdown=backoff)

    @staticmethod
    @app.task(bind=True, autoretry_for=(ConnectionError, TimeoutError),
              retry_kwargs={'max_retries': 5, 'countdown': 5})
    def task_with_autoretry(self, url):
        """Retry automatically for the exception types listed in ``autoretry_for``."""
        return fetch_url(url)

    @staticmethod
    @app.task(bind=True, max_retries=3)
    def task_with_custom_logic(self, data):
        """Retry connection errors only; let business errors fail at once.

        Raises:
            ValueError: Re-raised unchanged, never retried.
            ConnectionError: Re-raised once the retry budget is used up.
        """
        try:
            return process_data(data)
        except ValueError:
            # Business errors are not retried; a bare ``raise`` re-raises
            # the active exception as-is.
            raise
        except ConnectionError as exc:
            # Compare against the task's own ``max_retries`` option rather
            # than repeating the literal from the decorator.
            if self.request.retries < self.max_retries:
                raise self.retry(exc=exc, countdown=60)
            raise  # retry budget exhausted: propagate the original error


class CircuitBreakerTask:
    """Small circuit breaker driven by a failure counter.

    ``state`` is one of ``'closed'``, ``'open'`` or ``'half_open'``.  The
    counter is reset by every recorded success.
    """

    def __init__(self, failure_threshold=5, recovery_timeout=60):
        """Start closed with no failures recorded.

        Args:
            failure_threshold: Failure count at which the breaker opens.
            recovery_timeout: Seconds after the last failure before an open
                breaker lets a call through again.
        """
        self.failure_count = 0
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.last_failure_time = None
        self.state = 'closed'  # closed, open, half_open

    def can_execute(self):
        """Return whether a call may proceed right now.

        An open breaker moves to ``'half_open'`` once more than
        ``recovery_timeout`` seconds have passed since the last failure.
        """
        # Only the open state can refuse a call.
        if self.state != 'open':
            return True

        seconds_since_failure = time.time() - self.last_failure_time
        if seconds_since_failure > self.recovery_timeout:
            self.state = 'half_open'
            return True
        return False

    def record_success(self):
        """Reset the failure counter and close the breaker."""
        self.failure_count = 0
        self.state = 'closed'

    def record_failure(self):
        """Count a failure and open the breaker once the threshold is reached."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count < self.failure_threshold:
            return
        self.state = 'open'


def process_data(data):
    """Placeholder processing step: return ``data`` unchanged."""
    return data

def fetch_url(url):
    """Placeholder fetch step: return ``url`` unchanged; no request is made."""
    return url

38.7.2 任务超时处理

python
import signal
from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded

class TimeoutHandler:
    """Example Celery tasks showing three ways to bound a task's run time."""

    @staticmethod
    @app.task(bind=True, time_limit=30, soft_time_limit=25)
    def task_with_timeout(self, data):
        """Run with Celery's soft (25 s) and hard (30 s) time limits.

        Returns:
            The result of ``long_running_process``, or the fallback payload
            from the matching timeout handler.
        """
        try:
            return long_running_process(data)
        except SoftTimeLimitExceeded:
            return handle_soft_timeout(data)
        except TimeLimitExceeded:
            # NOTE(review): Celery documents the hard time limit as
            # terminating the process that executes the task, so this branch
            # is presumably never reached from inside the task -- confirm,
            # and drop it if so.
            return handle_hard_timeout(data)

    @staticmethod
    @app.task(bind=True)
    def task_with_manual_timeout(self, data, timeout=30):
        """Bound the run time by waiting on a worker thread.

        Args:
            data: Payload passed to ``long_running_process``.
            timeout: Seconds to wait for the worker thread.

        Returns:
            The worker's result, or a ``{'status': 'timeout', ...}`` dict if
            the thread is still running after ``timeout`` seconds.

        Raises:
            Exception: Whatever ``long_running_process`` raised in the thread.
        """
        import threading

        outcome = {}

        def worker():
            try:
                outcome['result'] = long_running_process(data)
            except Exception as e:
                outcome['error'] = e

        # A thread cannot be cancelled after the timeout; marking it as a
        # daemon at least keeps an overrunning worker from blocking
        # interpreter shutdown.
        thread = threading.Thread(target=worker, daemon=True)
        thread.start()
        thread.join(timeout=timeout)

        if thread.is_alive():
            return {'status': 'timeout', 'message': f'Task exceeded {timeout}s'}

        # Test for presence rather than truthiness so that an exception
        # object that evaluates as false is still re-raised.
        if 'error' in outcome:
            raise outcome['error']

        return outcome.get('result')

    @staticmethod
    @app.task(bind=True)
    def task_with_signal_timeout(self, data, timeout=30):
        """Bound the run time with ``SIGALRM`` (Unix only).

        Raises:
            TimeoutError: If the alarm fires before the work finishes.
        """
        def timeout_handler(signum, frame):
            raise TimeoutError(f'Task timed out after {timeout} seconds')

        # Keep the handler that was installed before, so it can be put back
        # afterwards instead of leaving ours in place.
        previous_handler = signal.signal(signal.SIGALRM, timeout_handler)

        try:
            signal.alarm(timeout)
            result = long_running_process(data)
        finally:
            signal.alarm(0)  # cancel any pending alarm
            # ``None`` means the earlier handler was not installed from
            # Python and cannot be passed back to signal.signal().
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

        return result


def long_running_process(data):
    """Placeholder for slow work: return ``data`` unchanged."""
    return data

def handle_soft_timeout(data):
    """Build the payload that marks ``data`` as a soft timeout."""
    payload = dict(status='soft_timeout', data=data)
    return payload

def handle_hard_timeout(data):
    """Build the payload that marks ``data`` as a hard timeout."""
    payload = dict(status='hard_timeout', data=data)
    return payload

38.7.3 任务优先级与路由

python
from celery import Celery

# Celery application named 'tasks'; the @app.task decorator below registers
# its task on it.  No broker or result backend is passed here.
app = Celery('tasks')

class TaskRouter:
    """Route tasks to queues by matching the task name against glob patterns."""

    # Glob pattern on the task name -> routing options for matching tasks.
    ROUTING_CONFIG = {
        'tasks.compute.*': {'queue': 'compute'},
        'tasks.io.*': {'queue': 'io'},
        'tasks.email.*': {'queue': 'email'},
        'tasks.report.*': {'queue': 'report'},
    }

    @staticmethod
    def route_task(name, args, kwargs, options=None, task=None, **kw):
        """Return the routing options for the task called ``name``.

        The trailing ``options``, ``task`` and ``**kw`` parameters follow the
        router-function signature shown in Celery's routing documentation, so
        the function can be listed in ``task_routes``; they are not used
        here.  Calling it with only ``name, args, kwargs`` still works.

        Args:
            name: Task name to route.
            args: Positional arguments of the task call (unused).
            kwargs: Keyword arguments of the task call (unused).
            options: Execution options supplied by the caller (unused).
            task: Task object supplied by the caller (unused).

        Returns:
            A new dict with the options of the first pattern in
            ``ROUTING_CONFIG`` that matches ``name``, or
            ``{'queue': 'default'}`` when nothing matches.
        """
        from fnmatch import fnmatchcase

        for pattern, route in TaskRouter.ROUTING_CONFIG.items():
            # Case-sensitive glob match; handles '*' anywhere in the pattern,
            # not only at the end.
            if fnmatchcase(name, pattern):
                # Hand out a copy so callers cannot mutate the shared config.
                return dict(route)
        return {'queue': 'default'}


class PriorityTaskManager:
    """Helpers for submitting tasks with a named priority level."""

    # Named priority levels and the numeric value sent with the task.
    PRIORITY_LEVELS = dict(
        critical=0,
        high=3,
        normal=5,
        low=7,
        background=9,
    )

    @staticmethod
    @app.task(bind=True, priority=5)
    def normal_priority_task(self, data):
        """Process ``data`` in a task declared with priority 5."""
        processed = process_data(data)
        return processed

    @staticmethod
    def submit_high_priority(task_func, *args, **kwargs):
        """Submit ``task_func`` asynchronously at the 'high' priority level.

        Returns whatever ``task_func.apply_async`` returns.
        """
        level = PriorityTaskManager.PRIORITY_LEVELS['high']
        return task_func.apply_async(args=args, kwargs=kwargs, priority=level)

    @staticmethod
    def submit_low_priority(task_func, *args, **kwargs):
        """Submit ``task_func`` asynchronously at the 'low' priority level.

        Returns whatever ``task_func.apply_async`` returns.
        """
        level = PriorityTaskManager.PRIORITY_LEVELS['low']
        return task_func.apply_async(args=args, kwargs=kwargs, priority=level)


class TaskDependencyManager:
    """Thin wrappers over Celery's ``chain``, ``group`` and ``chord`` primitives."""

    @staticmethod
    def chain_tasks(*tasks):
        """Combine ``tasks`` into a ``chain`` and submit it asynchronously."""
        from celery import chain
        return chain(*tasks).apply_async()

    @staticmethod
    def parallel_tasks(*tasks):
        """Combine ``tasks`` into a ``group`` and submit it asynchronously."""
        from celery import group
        return group(*tasks).apply_async()

    @staticmethod
    def chord_tasks(header_tasks, callback):
        """Build a ``chord`` from ``header_tasks`` and call it with ``callback``.

        Returns the value produced by that call.
        """
        from celery import chord
        header = chord(header_tasks)
        return header(callback)


def process_data(data):
    """Placeholder processing step: return ``data`` unchanged."""
    return data

38.7.4 任务监控与告警

python
from celery import Celery
from celery.events import EventReceiver
from celery.signals import task_success, task_failure, task_retry
import logging

# Celery application named 'tasks'.  No broker or result backend is passed here.
app = Celery('tasks')

class TaskMonitor:
    """Aggregate task outcomes and derive an overall health status.

    Keeps global counters in ``stats`` plus two append-only lists,
    ``slow_tasks`` and ``failed_tasks``.  The lists are never trimmed, so
    they grow for as long as the monitor lives.
    """

    def __init__(self, slow_threshold=60):
        """Create an empty monitor.

        Args:
            slow_threshold: Runtime in seconds above which a successful task
                is recorded as slow.  Defaults to 60.
        """
        self.stats = {
            'total': 0,
            'success': 0,
            'failure': 0,
            'retry': 0
        }
        self.slow_threshold = slow_threshold
        self.slow_tasks = []
        self.failed_tasks = []
        self.logger = logging.getLogger(__name__)

    def record_success(self, task_id, result, runtime):
        """Count a successful task and note it if it ran slowly.

        Args:
            task_id: Identifier of the task.
            result: Task result (not stored).
            runtime: Run time in seconds, compared with ``slow_threshold``.
        """
        self.stats['total'] += 1
        self.stats['success'] += 1

        if runtime > self.slow_threshold:
            self.slow_tasks.append({
                'task_id': task_id,
                'runtime': runtime,
                'timestamp': time.time()
            })
            # Lazy %-style arguments: the message is only formatted when the
            # record is actually emitted.
            self.logger.warning(
                'Slow task detected: %s took %ss', task_id, runtime
            )

    def record_failure(self, task_id, exception, traceback):
        """Count a failed task and store its id and exception text.

        Args:
            task_id: Identifier of the task.
            exception: The failure; stored as ``str(exception)``.
            traceback: Accepted for signal compatibility; not stored.
        """
        self.stats['total'] += 1
        self.stats['failure'] += 1

        self.failed_tasks.append({
            'task_id': task_id,
            'exception': str(exception),
            'timestamp': time.time()
        })
        self.logger.error('Task failed: %s - %s', task_id, exception)

    def record_retry(self, task_id, reason):
        """Count a retry.  Retries do not contribute to ``stats['total']``."""
        self.stats['retry'] += 1
        self.logger.info('Task retry: %s - %s', task_id, reason)

    def get_health_status(self):
        """Classify overall health by failure rate.

        Returns:
            ``'unknown'`` when nothing has been recorded, ``'healthy'``
            below 1% failures, ``'warning'`` below 10%, else ``'critical'``.
        """
        if self.stats['total'] == 0:
            return 'unknown'

        failure_rate = self.stats['failure'] / self.stats['total']

        if failure_rate < 0.01:
            return 'healthy'
        if failure_rate < 0.1:
            return 'warning'
        return 'critical'


# Module-level monitor instance that the signal handlers below report to.
monitor = TaskMonitor()

@task_success.connect
def on_task_success(sender=None, result=None, **kwargs):
    """Handle the ``task_success`` signal by recording a success.

    Args:
        sender: The task that succeeded; its ``request`` supplies the id.
        result: The task's return value.
        **kwargs: Further signal arguments; ignored.
    """
    request = sender.request if sender else None
    task_id = request.id if request is not None else None
    # NOTE(review): ``runtime`` does not appear among Celery's documented
    # request attributes -- confirm where it is expected to come from.
    # Until then, fall back to 0 so a missing or None value cannot make the
    # handler raise and drop the success record.
    runtime = getattr(request, 'runtime', 0) or 0
    monitor.record_success(task_id, result, runtime)

@task_failure.connect
def on_task_failure(sender=None, exception=None, traceback=None, **kwargs):
    """Handle the ``task_failure`` signal by recording the failure.

    The task id is read from ``sender.request`` when a sender is given,
    otherwise ``None`` is recorded.
    """
    if sender:
        failed_id = sender.request.id
    else:
        failed_id = None
    monitor.record_failure(failed_id, exception, traceback)

@task_retry.connect
def on_task_retry(sender=None, reason=None, **kwargs):
    """Handle the ``task_retry`` signal by recording the retry.

    The task id is read from ``sender.request`` when a sender is given,
    otherwise ``None`` is recorded.
    """
    if sender:
        retried_id = sender.request.id
    else:
        retried_id = None
    monitor.record_retry(retried_id, reason)


class AlertManager:
    """Send alerts to registered handlers when the monitor reports poor health."""

    def __init__(self, monitor: TaskMonitor):
        """Create a manager that reads health from ``monitor``."""
        self.monitor = monitor
        self.alert_handlers = []

    def add_handler(self, handler):
        """Register a callable that receives each alert dict."""
        self.alert_handlers.append(handler)

    def check_and_alert(self):
        """Send one alert if the monitor's status is 'critical' or 'warning'.

        Any other status sends nothing.
        """
        status = self.monitor.get_health_status()

        if status == 'critical':
            level = 'critical'
            message = f"Task failure rate is critical: {self.monitor.stats}"
        elif status == 'warning':
            level = 'warning'
            message = f"Task failure rate is elevated: {self.monitor.stats}"
        else:
            return

        self._send_alert({
            'level': level,
            'message': message,
            'timestamp': time.time()
        })

    def _send_alert(self, alert):
        """Deliver ``alert`` to every handler.

        A handler that raises is logged and does not stop the remaining
        handlers from running.
        """
        for notify in self.alert_handlers:
            try:
                notify(alert)
            except Exception as e:
                logging.error(f"Alert handler failed: {e}")

import time

38.8 本章小结

本章详细介绍了Python任务队列与调度的核心概念和实践:

  1. 任务队列原理:生产者-消费者模式、消息传递
  2. Celery:任务定义、队列配置、任务链、工作流
  3. RQ:简单任务队列、任务依赖、失败处理
  4. 定时任务:APScheduler、Cron表达式、任务调度
  5. 任务监控:执行统计、健康检查、告警管理
  6. 分布式处理:多Worker、任务路由、负载均衡

练习题

  1. 实现一个支持优先级的任务队列系统
  2. 开发一个任务重试机制,支持指数退避
  3. 实现一个任务依赖图执行引擎
  4. 开发一个定时任务管理面板,支持动态添加/删除任务
  5. 实现一个分布式任务调度系统,支持任务分片

扩展阅读

Python技术丛书 - 江苏省宿城中等专业学校