// coro/coroasyncdebug.cpp
//
// The following code example is taken from the book
// C++20 - The Complete Guide by Nicolai M. Josuttis, Leanpub, 2021
// The code is licensed under a Creative Commons Attribution 4.0 International License.

// raw code

#include <iostream>
#include <list>
#include <string>      // for std::string and std::to_string()
#include <utility>     // for std::exchange()
#include <functional>  // for std::function
#include <coroutine>
#include <atomic>      // for std::atomic
#include <thread>
#include <stop_token>  // for std::stop_token
#include <mutex>
#include <condition_variable>
#include <semaphore>
#include <functional>
#include <cassert>
#include <cstdlib>     // for std::exit()
#include <syncstream> // for std::osyncstream
#include <chrono>
using namespace std::literals;


// Compile-time switch that selects what coutDebug() returns:
// true: enable debug output
// false: disable debug output
constexpr bool syncOutDebug = true;

inline auto syncOut(std::ostream& strm = std::cout) {
  return std::osyncstream{strm};
}

inline auto coutDebug() {
  if constexpr (syncOutDebug) {
    return std::osyncstream{std::cout};
  }
  else {
    struct devnullbuf : public std::streambuf {
      // basic output primitive without any print statement:
      int_type overflow (int_type c) {
        return c;
      }
    };
    static devnullbuf devnull;
    return std::ostream{&devnull};
  }
}


class CoroPool;

class [[nodiscard]] CoroPoolTask
{
  friend class CoroPool;
public:
  struct promise_type;
  using CoroHdl = std::coroutine_handle<promise_type>;
private:
  CoroHdl hdl;
  int coroId = -1;
public:
  struct promise_type {
    inline static std::atomic<int> maxId = 0;
    int id = ++maxId;
    CoroPool* poolPtr = nullptr;   // if not null, lifetime is controlled by pool
    CoroHdl contHdl = nullptr;     // coro that awaits this coro
  
    CoroPoolTask get_return_object() noexcept {
      return CoroPoolTask{CoroHdl::from_promise(*this)};
    }
    auto initial_suspend() const noexcept { return std::suspend_always{}; }
    void unhandled_exception() noexcept { std::exit(1); }
    void return_void() noexcept {}

    auto final_suspend() const noexcept {
      struct FinalAwaiter {
        bool await_ready() const noexcept { return false; }
        std::coroutine_handle<> await_suspend(CoroHdl h) noexcept {
          if (h.promise().contHdl) {
            coutDebug() << "= coro" << h.promise().id << " finally suspended"
                        << " => continue with coro" << h.promise().contHdl.promise().id << std::endl;
            return h.promise().contHdl;    // resume continuation
          }
          else {
            coutDebug() << "= coro" << h.promise().id << " finally suspended" << std::endl;
            return std::noop_coroutine();  // no continuation
          }  
        }
        void await_resume() noexcept {
          coutDebug() << "= await_resume()" << std::endl;
        }
      };
      return FinalAwaiter{};  // AFTER suspended, resume continuation if there is one
    }
  };

  explicit CoroPoolTask(CoroHdl handle)
   : hdl{handle}, coroId{hdl.promise().id} {
    assert(hdl);
    coutDebug() << "= CoroPoolTask() for coro" << hdl.promise().id << std::endl;
  }
  ~CoroPoolTask() {
    if (hdl && !hdl.promise().poolPtr) {
      // task was not passed to pool:
      coutDebug() << "= DESTROY coro" << hdl.promise().id << " by task (not by pool)"
                  << " on thread " << std::this_thread::get_id() << std::endl;
      hdl.destroy();
    }
  }
  CoroPoolTask(const CoroPoolTask&) = delete;
  CoroPoolTask& operator= (const CoroPoolTask&) = delete;
  CoroPoolTask(CoroPoolTask&& t)
   : hdl{t.hdl} {
      coutDebug() << "= MOVE coro" << hdl.promise().id << std::endl;
      t.hdl = nullptr;
  }
  CoroPoolTask& operator= (CoroPoolTask&&) = delete;

  // Awaiter for: co_await task()
  // - queues the new coro in the pool
  // - sets the calling coro as continuation
  struct CoAwaitAwaiter {
    CoroHdl newHdl;
    bool await_ready() const noexcept { return false; }
    void await_suspend(CoroHdl awaitingHdl) noexcept;  // see below
    void await_resume() noexcept {}
  };
  auto operator co_await() noexcept {
    coutDebug() << "= OPERATOR co_await() for coro" << hdl.promise().id << std::endl;
    coutDebug() << "= POOL takes ownership of coro" << hdl.promise().id << std::endl;
    return CoAwaitAwaiter{std::exchange(hdl, nullptr)}; // pool takes ownership of hdl
  }
};

class CoroPool
{
private:
  std::list<std::jthread> threads;         // list of threads
  std::list<CoroPoolTask::CoroHdl> coros;  // queue of scheduled coros
  std::mutex corosMx;
  std::condition_variable_any corosCV;
  std::atomic<int> numCoros = 0;           // counter for all coros owned by the pool

public:
  explicit CoroPool(int num) {
    // start pool with num threads:
    for (int i = 0; i < num; ++i) {
      std::jthread worker_thread{[this](std::stop_token st) {
                                   threadLoop(st);
                                 }};
      threads.push_back(std::move(worker_thread));
    }
  }

  ~CoroPool() {
    for (auto& t : threads) {  // request stop for all threads
      t.request_stop();
    }
    for (auto& t : threads) {  // wait for end of all threads
      t.join();
    }
    for (auto& c : coros) {    // destroy remaining coros
      coutDebug() << "= DESTROY coro" << c.promise().id << " by POOL destructor" << std::endl;
      c.destroy();
    }
  }

  CoroPool(CoroPool&) = delete;
  CoroPool& operator=(CoroPool&) = delete;

  void runTask(CoroPoolTask&& coroTask) noexcept {
    coutDebug() << "= RUN task coro" << coroTask.hdl.promise().id
                << " by pool" << std::endl;
    coutDebug() << "= POOL takes ownership of coro" << coroTask.hdl.promise().id << std::endl;
    runCoro(std::exchange(coroTask.hdl, nullptr)); // pool takes ownership of hdl
  }

  // runCoro(): let pool run (and control lifetime of) coroutine
  // called from:
  // - pool.runTask(CoroPoolTask)
  // - co_await task()
  void runCoro(CoroPoolTask::CoroHdl coro) noexcept {
    coutDebug() << "== runCoro() for coro" << coro.promise().id
                << " = INCR numCoros" << std::endl;
    ++numCoros;
    coro.promise().poolPtr = this;  // disables destroy in CoroPoolTask
    {
      std::scoped_lock lock(corosMx);
      coros.push_front(coro);       // queue coro
      corosCV.notify_one();         // and let one thread resume it
    }
  }

  void threadLoop(std::stop_token st) {
    while (!st.stop_requested()) {
      // get next coro task from the queue:
      CoroPoolTask::CoroHdl coro;
      {
        coutDebug() << "      --->>> WAIT by " << std::this_thread::get_id() << std::endl;
        std::unique_lock lock(corosMx);
        if (!corosCV.wait(lock, st, [&] {
                                      return !coros.empty();
                                    })) {
          coutDebug() << "      --->>> STOP for " << std::this_thread::get_id() << std::endl;
          return;  // stop requested
        }
        coro = coros.back();
        coros.pop_back();
      }

      // resume it:
      auto id = coro.promise().id;
      coutDebug() << "      ***>>> RESUME coro" << id
                  << " by " << std::this_thread::get_id() << std::endl;
      coro.resume();  // RESUME
      if (coro.done()) {
        coutDebug() << "      ***<<< SUSPEND coro" << id
                    << " DONE by " << std::this_thread::get_id() << std::endl;
      }
      else {
        coutDebug() << "      ***<<< SUSPEND coro" << id
                    << " by " << std::this_thread::get_id() << std::endl;
      }

      // NOTE: The coro initially resumed in this thread might NOT be the coro finally called.
      // If a coroMain co_await's a coroSub, then
      // the thread that finally resumed coroSub resumes coroMain as its continuation.
      // => After this resumption this coro and SOME continuations MIGHT be done
      std::function<void(CoroPoolTask::CoroHdl)> destroyDone;
      destroyDone = [&destroyDone, this](auto hdl) {
                      // coutDebug() << "= destroyDone() for coro" << hdl.promise().id << std::endl;
                      if (hdl && hdl.done()) {
                        auto nextHdl = hdl.promise().contHdl;
                        coutDebug() << "= DESTROY coro" << hdl.promise().id << " by POOL" << std::endl;
                        hdl.destroy();         // destroy done handle
                        --numCoros;            // adjust total number of coros
                        destroyDone(nextHdl);  // do it for all done continuations
                      }
                    };
      destroyDone(coro);      // recursively destroy done coros
      numCoros.notify_all();  // wake-up any waiting waitUntilNoCoros()

      // sleep a bit to force that another thread is used next:
      std::this_thread::sleep_for(std::chrono::milliseconds{100});
    }
  }

  void waitUntilNoCoros() {
    int num = numCoros.load();
    coutDebug() << "\nwaitUntilNoCoros(): " << num << " left" << std::endl;
    while (num > 0) {
      numCoros.wait(num);  // wait for notification that numCoros changed the value
      num = numCoros.load();
      coutDebug() << "\nwaitUntilNoCoros(): " << num << " left" << std::endl;
    }
  }

  void syncWait(CoroPoolTask&& task) {
    std::binary_semaphore taskDone{0};
    auto makeWaitingTask = [&]() -> CoroPoolTask {
      coutDebug() << "  waitingTask(): co_await task" << std::endl;
      co_await task;
      coutDebug() << "  waitingTask(): co_await SignalDone{}" << std::endl;
      struct SignalDone {
        std::binary_semaphore& taskDoneRef;
        bool await_ready() { return false; }
        bool await_suspend(std::coroutine_handle<>) {
          taskDoneRef.release();  // signal task is done
          return false;           // do not suspend at all
        }
        void await_resume() { }
      };
      co_await SignalDone{taskDone};
      coutDebug() << "  waitingTask(): done" << std::endl;
    };
    coutDebug() << "\nsyncWait(): makeWaitingTask()" << std::endl;
    auto t = makeWaitingTask();
    coutDebug() << "\nsyncWait(): runTask()" << std::endl;
    runTask(std::move(t));
    taskDone.acquire();
    coutDebug() << "\nsyncWait(): done" << std::endl;
    corosCV.notify_one();       // let one thread resume it
  }
};

// CoroPoolTask awaiter for: co_await task()
// - queues the new coro in the pool
// - sets the calling coro as continuation
// Called when a coroutine (awaitingHdl) co_await's a task:
// - registers awaitingHdl as the continuation of the awaited coroutine (newHdl)
// - hands newHdl to the pool that controls awaitingHdl
void CoroPoolTask::CoAwaitAwaiter::await_suspend(CoroHdl awaitingHdl) noexcept
{
  assert(newHdl);
  auto& awaitedPromise = newHdl.promise();
  auto& callerPromise = awaitingHdl.promise();

  coutDebug() << "= RUN coro" << awaitedPromise.id << " via co_await "
              << "(suspend coro" << callerPromise.id << " for it)" << std::endl;

  assert(callerPromise.poolPtr);     // the awaiting coro has to run in a pool
  assert(!awaitedPromise.contHdl);   // the awaited coro has no continuation yet

  awaitedPromise.contHdl = awaitingHdl;    // continue with the caller at the end
  callerPromise.poolPtr->runCoro(newHdl);  // queue the awaited coro
}


// Coroutine that writes a single line with id, msg, and the id of the
// executing thread via syncOut(), and then ends.
CoroPoolTask print(std::string id, std::string msg)
{
  {
    auto out = syncOut();
    out << "    > " << id <<  " print: " << msg;
    out << "   on thread: " << std::this_thread::get_id() << std::endl;
  }  // the stream object ends here, before the coroutine returns
  co_return;
}

// Coroutine that reports its start (with the id of the executing thread)
// and its end via syncOut().
// The nested co_await calls of print() are currently commented out.
CoroPoolTask runAsync(std::string id)
{
  const auto threadId = std::this_thread::get_id();
  syncOut() << "===== " << id << " start        on thread: " << threadId << std::endl;

/*
co_await print(id + "a", "start");
syncOut() << "===== " << id << " resume on thread " << std::this_thread::get_id() << std::endl;

co_await print(id + "b", "end ");
syncOut() << "===== " << id << " resume on thread " << std::this_thread::get_id() << std::endl;
*/

  {
    auto out = syncOut();
    out << "===== " << id << " done" << std::endl;
  }
  co_return;
}

int main()
{
  // init pool of coroutine threads:
  syncOut() << "**** main() on thread " << std::this_thread::get_id() << std::endl;
  CoroPool pool{4};

  // start main coroutine and run it in coroutine pool:
  syncOut() << "runTask(runAsync(0))" << std::endl;
  CoroPoolTask t0 = runAsync("0");
  pool.runTask(std::move(t0));

  syncOut() << "\n**** t1 = runAsync(1)" << std::endl;
  auto t1 = runAsync("1");
  syncOut() << "\n**** syncWait(t1)" << std::endl;
  pool.syncWait(std::move(t1));

  // start other coroutines:
  for (int i = 0; i < 3; ++i) {
    syncOut() << "runTask(runAsync(" << i + 2 << "))" << std::endl;
    CoroPoolTask t = runAsync(std::to_string(i + 2));
    pool.runTask(std::move(t));
  }

  // wait until all coroutines are done:
  syncOut() << "\n**** waitUntilNoCoros()" << std::endl;
  pool.waitUntilNoCoros();

  syncOut() << "\n**** main() done" << std::endl;
}