Thread Pool

发布于 1 个月前·
3 人看过
·
约 4 分钟
·103 行代码
#pragma once
#include <atomic>
#include <functional>
#include <future>
#include <iostream>
#include <thread>
#include <tuple>
#include <vector>

#include "SafeQueue.hpp"

// 任务类型:最底层的无参 void 函数
using Task = std::function<void()>;

class ThreadPool {
private:
    SafeQueue<Task> m_queue;             // 任务队列
    std::vector<std::thread> m_threads;  // 线程组
    std::atomic<bool> m_running;         // 运行标志

public:
    // 构造函数
    // max_queue_size: 允许传递队列最大容量,默认 1000
    ThreadPool(int num_threads, int max_queue_size = 1000) : m_running(true), m_queue(max_queue_size) {
        for (int i = 0; i < num_threads; ++i) {
            m_threads.emplace_back([this, i] {
                // std::cout << "🧵 线程 " << i << " 启动..." << std::endl;
                while (m_running) {
                    try {
                        Task task;
                        // 这里会阻塞,直到有任务或者收到退出信号
                        m_queue.wait_pop(task);

                        if (task) {
                            task();  // 执行任务
                        }
                    } catch (const std::exception &e) {
                        std::cerr << "❌ 线程异常: " << e.what() << std::endl;
                    }
                }
            });
        }
    }

    // 禁止拷贝和移动 (Mutex不可移动)
    ThreadPool(const ThreadPool &) = delete;
    ThreadPool &operator=(const ThreadPool &) = delete;
    ThreadPool(ThreadPool &&) = delete;
    ThreadPool &operator=(ThreadPool &&) = delete;

    // 析构函数:确保程序退出时回收所有线程
    ~ThreadPool() { stop(); }

    // 支持任意函数、任意参数、支持获取返回值
    template <typename F, typename... Args>
    auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
        // 1. 推导返回值类型
        using ReturnType = decltype(f(args...));

        // 2. 将函数和参数打包
        // 使用 lambda + tuple 捕获所有参数,替代 std::bind
        // 这是一个 shared_ptr,因为 packaged_task 不可拷贝,只能移动,但在 lambda 里我们需要共享它
        auto task = std::make_shared<std::packaged_task<ReturnType()>>(
            [func = std::forward<F>(f), args = std::make_tuple(std::forward<Args>(args)...)]() mutable {
                // std::apply 会把 tuple 展开作为参数传给 func
                return std::apply(std::move(func), std::move(args));
            });

        // 3. 获取 future (给主线程用的“提货单”)
        std::future<ReturnType> result = task->get_future();

        // 4. 封装成 void() 扔进队列
        // 只有当队列不满时,这里才会返回;否则主线程会阻塞在这里(实现背压)
        m_queue.push([task]() { (*task)(); });

        return result;
    }

    // 停止线程池
    void stop() {
        // 使用 call_once 或者 atomic 标志位保证只执行一次
        bool expected = true;
        if (!m_running.compare_exchange_strong(expected, false)) {
            return;  // 已经停止过了
        }

        // 唤醒所有等待中的线程(推送空任务作为毒药丸)
        // 推送数量 = 线程数量,确保每个线程都能拿一个毒药丸吃掉然后下班
        for (size_t i = 0; i < m_threads.size(); ++i) {
            // 这里我们推送一个空的 lambda,worker 线程里 if(task) 会判断为假,然后循环判断 m_running 为假,退出
            // 但为了保险,worker 里最好还是正常执行,只是我们这里 Task() 构造的是空的 std::function
            // 为了更稳健,我们可以推送一个空操作
            m_queue.push([]() {});
        }

        // 等待所有工人下班
        for (std::thread &thread : m_threads) {
            if (thread.joinable()) {
                thread.join();
            }
        }
    }
};
snippet
cpp
$ cd ..