Thread Pool
发布于 1 个月前·
3 人看过
·约 4 分钟
·103 行代码#pragma once
#include <atomic>
#include <functional>
#include <future>
#include <iostream>
#include <thread>
#include <tuple>
#include <vector>
#include "SafeQueue.hpp"
// 任务类型:最底层的无参 void 函数
using Task = std::function<void()>;
class ThreadPool {
private:
SafeQueue<Task> m_queue; // 任务队列
std::vector<std::thread> m_threads; // 线程组
std::atomic<bool> m_running; // 运行标志
public:
// 构造函数
// max_queue_size: 允许传递队列最大容量,默认 1000
ThreadPool(int num_threads, int max_queue_size = 1000) : m_running(true), m_queue(max_queue_size) {
for (int i = 0; i < num_threads; ++i) {
m_threads.emplace_back([this, i] {
// std::cout << "🧵 线程 " << i << " 启动..." << std::endl;
while (m_running) {
try {
Task task;
// 这里会阻塞,直到有任务或者收到退出信号
m_queue.wait_pop(task);
if (task) {
task(); // 执行任务
}
} catch (const std::exception &e) {
std::cerr << "❌ 线程异常: " << e.what() << std::endl;
}
}
});
}
}
// 禁止拷贝和移动 (Mutex不可移动)
ThreadPool(const ThreadPool &) = delete;
ThreadPool &operator=(const ThreadPool &) = delete;
ThreadPool(ThreadPool &&) = delete;
ThreadPool &operator=(ThreadPool &&) = delete;
// 析构函数:确保程序退出时回收所有线程
~ThreadPool() { stop(); }
// 支持任意函数、任意参数、支持获取返回值
template <typename F, typename... Args>
auto submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
// 1. 推导返回值类型
using ReturnType = decltype(f(args...));
// 2. 将函数和参数打包
// 使用 lambda + tuple 捕获所有参数,替代 std::bind
// 这是一个 shared_ptr,因为 packaged_task 不可拷贝,只能移动,但在 lambda 里我们需要共享它
auto task = std::make_shared<std::packaged_task<ReturnType()>>(
[func = std::forward<F>(f), args = std::make_tuple(std::forward<Args>(args)...)]() mutable {
// std::apply 会把 tuple 展开作为参数传给 func
return std::apply(std::move(func), std::move(args));
});
// 3. 获取 future (给主线程用的“提货单”)
std::future<ReturnType> result = task->get_future();
// 4. 封装成 void() 扔进队列
// 只有当队列不满时,这里才会返回;否则主线程会阻塞在这里(实现背压)
m_queue.push([task]() { (*task)(); });
return result;
}
// 停止线程池
void stop() {
// 使用 call_once 或者 atomic 标志位保证只执行一次
bool expected = true;
if (!m_running.compare_exchange_strong(expected, false)) {
return; // 已经停止过了
}
// 唤醒所有等待中的线程(推送空任务作为毒药丸)
// 推送数量 = 线程数量,确保每个线程都能拿一个毒药丸吃掉然后下班
for (size_t i = 0; i < m_threads.size(); ++i) {
// 这里我们推送一个空的 lambda,worker 线程里 if(task) 会判断为假,然后循环判断 m_running 为假,退出
// 但为了保险,worker 里最好还是正常执行,只是我们这里 Task() 构造的是空的 std::function
// 为了更稳健,我们可以推送一个空操作
m_queue.push([]() {});
}
// 等待所有工人下班
for (std::thread &thread : m_threads) {
if (thread.joinable()) {
thread.join();
}
}
}
};