muduo 库解析之十四:ThreadPool

源码

ThreadPool.h

#pragma once

#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "Thread.h"
#include "Mutex.h"
#include "Condition.h"
#include "NonCopyable.h"

namespace muduo
{
    /// Fixed-size pool of worker threads that execute queued Tasks in FIFO order.
    ///
    /// Typical use: configure with set_max_queue_size() / set_thread_init_callback(),
    /// call start(n), submit work with run(), then stop(). The destructor calls
    /// stop() if the pool is still running.
    class ThreadPool : public NonCopyable
    {
    public:
        /// Unit of work executed by the pool.
        typedef std::function<void(void)> Task;

        /// @param name Prefix for worker thread names; start() appends "1", "2", ...
        explicit ThreadPool(const std::string &name = std::string("ThreadPool"));
        ~ThreadPool();

        // The two setters below must be called before start(); they take no lock.
        /// Bounds the task queue so that run() blocks while it is full.
        /// 0 (the default) means unbounded.
        /// NOTE(review): a negative value converts to a huge size_t, i.e. effectively unbounded.
        void set_max_queue_size(int max_size) { max_queue_size_ = max_size; }
        /// Callback invoked once at the start of every worker thread
        /// (or once in the caller's thread when start(0) is used).
        void set_thread_init_callback(const Task &cb) { thread_init_callback_ = cb; }

        /// Spawns num_threads workers. With 0 workers, run() executes tasks inline.
        void start(int num_threads);
        /// Marks the pool as stopped, wakes every blocked producer/consumer and
        /// joins all workers. Tasks still left in the queue may be dropped.
        void stop();

        const std::string &name() const
        {
            return name_;
        }

        /// Number of tasks currently waiting in the queue (read under the lock).
        size_t queue_size() const;

        /// Runs f inline if the pool has no workers; otherwise enqueues it,
        /// blocking while a bounded queue is full. The task is discarded if the
        /// pool is not running.
        void run(Task f);

    private:
        /// Requires mutex_ to be held.
        bool is_full() const;
        /// Body executed by each worker thread.
        void run_in_thread();
        /// Blocks until a task is available or the pool stops; may return an empty Task.
        Task take();

    private:
        mutable MutexLock mutex_;
        Condition not_empty_;   // signalled when a task is queued, and on stop()
        Condition not_full_;    // signalled when a bounded queue is drained by one, and on stop()
        std::string name_;
        Task thread_init_callback_;
        std::vector<std::unique_ptr<Thread>> threads_;
        std::deque<Task> queue_;
        size_t max_queue_size_; // 0 = unbounded
        // Atomic because worker threads poll it in run_in_thread() without holding
        // mutex_, while stop() writes it from another thread; a plain bool here
        // is a data race.
        std::atomic<bool> running_;
    };
}

ThreadPool.cc

#include "ThreadPool.h"

#include "Exception.h"

namespace muduo
{
    // Only records configuration; no worker thread exists until start().
    // Both condition variables are bound to mutex_, which is declared (and
    // therefore constructed) before them.
    ThreadPool::ThreadPool(const std::string &name)
        : mutex_(),
          not_empty_(mutex_),
          not_full_(mutex_),
          name_(name),
          max_queue_size_(0),
          running_(false)
    {
    }

    // Workers are bound to `this` (see start()), so a still-running pool is
    // stopped, and its threads joined, before the object goes away.
    ThreadPool::~ThreadPool()
    {
        if (!running_)
            return;
        stop();
    }

    // Launches num_threads workers named name_ + "1" ... name_ + "N", each
    // running run_in_thread(). Asserts that the pool has not been started yet.
    // With zero workers the init callback runs here, in the caller's thread,
    // and run() will later execute tasks inline.
    void ThreadPool::start(int num_threads)
    {
        assert(threads_.empty());
        running_ = true;
        threads_.reserve(num_threads);
        for (int idx = 1; idx <= num_threads; ++idx)
        {
            char suffix[32];
            snprintf(suffix, sizeof suffix, "%d", idx);
            threads_.emplace_back(new Thread(std::bind(&ThreadPool::run_in_thread, this), name_ + suffix));
            threads_.back()->start();
        }

        const bool inline_mode = (num_threads == 0);
        if (inline_mode && thread_init_callback_)
            thread_init_callback_();
    }

    void ThreadPool::stop()
    {
        {
            MutexLockGuard lock(mutex_);
            running_ = false;
            not_empty_.notify_all();
            not_full_.notify_all();
        }
        for (auto &thr : threads_)
        {
            thr->join();
        }
    }

    // Number of tasks waiting to be taken. The value is read under mutex_ but
    // can be stale as soon as the lock is released.
    size_t ThreadPool::queue_size() const
    {
        MutexLockGuard guard(mutex_);
        const size_t pending = queue_.size();
        return pending;
    }

    // Submits a task.
    //  - No worker threads: the task runs synchronously in the caller.
    //  - Otherwise it is appended to the queue; the caller blocks while a
    //    bounded queue is full, and the task is dropped if the pool is (or
    //    becomes) stopped.
    void ThreadPool::run(Task task)
    {
        if (threads_.empty())
        {
            task();
            return;
        }

        MutexLockGuard guard(mutex_);
        while (running_ && is_full())
            not_full_.wait();

        if (!running_)
            return;

        assert(!is_full());
        queue_.push_back(std::move(task));
        not_empty_.notify();
    }

    // Called from run_in_thread(). Blocks until a task is queued or the pool is
    // stopped. Returns an empty Task when the pool was stopped with nothing
    // left in the queue.
    ThreadPool::Task ThreadPool::take()
    {
        MutexLockGuard lock(mutex_);
        // Loop (not `if`) so spurious wake-ups re-check the predicate.
        while (queue_.empty() && running_)
        {
            not_empty_.wait();
        }

        Task task;
        if (!queue_.empty())
        {
            // Move rather than copy: the element is popped immediately, and
            // copying a std::function may heap-allocate a copy of its callable.
            task = std::move(queue_.front());
            queue_.pop_front();
            // Only a bounded queue can have producers waiting in run().
            if (max_queue_size_ > 0)
            {
                not_full_.notify();
            }
        }
        return task;
    }

    // True only for a bounded queue (max_queue_size_ != 0) that has reached its
    // limit. The caller must already hold mutex_; assert_locked() checks this.
    bool ThreadPool::is_full() const
    {
        mutex_.assert_locked();
        if (max_queue_size_ == 0)
            return false;
        return queue_.size() >= max_queue_size_;
    }

    // Worker thread body: runs the optional init callback, then keeps taking
    // and executing tasks until running_ becomes false.
    // Exceptions escaping a task are reported on stderr. Exception and
    // std::exception abort the process; anything else is rethrown.
    void ThreadPool::run_in_thread()
    {
        try
        {
            if (thread_init_callback_)
            {
                thread_init_callback_();
            }
            while (running_)
            {
                Task task(take());
                // take() yields an empty Task when woken by stop() with nothing queued.
                if (task)
                {
                    task();
                }
            }
        }
        catch (const Exception &ex)
        {
            fprintf(stderr, "exception caught in ThreadPool %s\n", name_.c_str());
            fprintf(stderr, "exception reason: %s\n", ex.what());
            fprintf(stderr, "exception stack trace: %s\n", ex.stack());
            abort();
        }
        catch (const std::exception &ex)
        {
            fprintf(stderr, "exception caught in ThreadPool %s\n", name_.c_str());
            fprintf(stderr, "exception reason: %s\n", ex.what());
            // Abort like the Exception branch. Letting this worker end silently
            // would shrink the pool, and could leave run() blocked forever on a
            // full bounded queue that no thread drains any more.
            abort();
        }
        catch (...)
        {
            fprintf(stderr, "unknown exception caught in ThreadPool %s\n", name_.c_str());
            throw;
        }
    }
}
原文地址:https://www.cnblogs.com/xiaojianliu/p/14707651.html