发布于 

C++ Coroutine: 通用异步任务 Task

本文使用协程实现了一个通用的异步任务执行类 Task,支持设置回调函数并将在 Task 完成后执行回调。

Task、TaskPromise 和 TaskAwaiter 覆盖到了大部分的协程执行过程,把这几个类的实现理解了那基本上就可以说已经理解了 c++ 协程的工作方式。

阅读下面这段代码的方式建议通过 main 函数开始,对照运行结果一点一点来看。

Source Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#include <iostream>
#include <coroutine>
#include <functional>
#include <exception>
#include <optional>
#include <atomic>
#include <utility>
#include <chrono>
#include <thread>
#include <list>
#include <condition_variable>
#include <algorithm>
#include <atomic>

#define debug(info) std::cout << __LINE__ << " " << __func__ << ": " << info << std::endl;

template<typename Task>
struct Result
{
// 初始化为默认值
explicit Result() = default;

// 当 Task 正常返回时用结果初始化 Result
explicit Result(Task &&value) : _value(value)
{
debug("Construct from value");
}

// 当 Task 抛异常时用异常初始化 Result
explicit Result(std::exception_ptr &&exception_ptr) : _exception_ptr(exception_ptr)
{
debug("Construct from exception_ptr");
}

// 读取结果,有异常则抛出异常
Task get_or_throw()
{
if (_exception_ptr)
{
std::rethrow_exception(_exception_ptr);
}
return _value;
}

private:
Task _value{};
std::exception_ptr _exception_ptr;
};

template<template<typename> class Task, typename R>
struct TaskAwaiter {
explicit TaskAwaiter(Task<R> &&task) noexcept
: task(std::move(task))
{
debug("Construct from task");
}

TaskAwaiter(TaskAwaiter &&completion) noexcept
: task(std::exchange(completion.task, {}))
{
debug("Construct from completion");
}

TaskAwaiter(TaskAwaiter &) = delete;

TaskAwaiter &operator=(TaskAwaiter &) = delete;

constexpr bool await_ready() const noexcept
{
debug("");
return false;
}

void await_suspend(std::coroutine_handle<> handle) noexcept
{
debug("");
std::cout << handle.address() << std::endl;
// 当 task 执行完之后调用 resume
task.finally([handle]() {
handle.resume();
});
}

// 协程恢复执行时,被等待的 Task 已经执行完,调用 get_result 来获取结果
R await_resume() noexcept
{
debug("");
return task.get_result();
}

private:
Task<R> task;
};

template<template<typename> class Task, typename ResultType>
struct TaskPromise
{
auto initial_suspend() noexcept
{
debug("");
return std::suspend_never{};
}
auto final_suspend() noexcept
{
debug("");
return std::suspend_always{};
}

Task<ResultType> get_return_object()
{
return Task{std::coroutine_handle<TaskPromise>::from_promise(*this)};
}

void unhandled_exception()
{
debug("");
std::lock_guard lock(completion_lock);
result = Result<ResultType>(std::current_exception());
completion.notify_all();
// 调用回调
notify_callbacks();
}

void return_value(ResultType value)
{
debug("");
std::lock_guard lock(completion_lock);
result = Result<ResultType>(std::move(value));
completion.notify_all();
// 调用回调
notify_callbacks();
}

ResultType get_result()
{
debug("from TaskPromise");
// 如果 result 没有值,说明协程还没有运行完,等待值被写入再返回
std::unique_lock lock(completion_lock);
if (!result.has_value()) {
debug("hasn't value now");
// 等待写入值之后调用 notify_all
completion.wait(lock);
} else {
debug("already has value now");
}
// 如果有值,则直接返回(或者抛出异常)
return result->get_or_throw();
}

void on_completed(std::function<void(Result<ResultType>)> &&func)
{
debug("");
std::unique_lock lock(completion_lock);
// 加锁判断 result
if (result.has_value()) {
debug("already has value");
// result 已经有值
auto value = result.value();
// 解锁之后再调用 func
lock.unlock();
func(value);
} else {
debug("waiting for execution");
// 否则添加回调函数,等待调用
completion_callbacks.push_back(func);
}
}

// 注意这里的模板参数
template<typename _ResultType>
auto await_transform(Task<_ResultType> &&task)
{
debug("");
return TaskAwaiter<Task, _ResultType>{std::move(task)};
}

private:

// 回调列表,我们允许对同一个 Task 添加多个回调
std::list<std::function<void(Result<ResultType>)>> completion_callbacks;

void notify_callbacks()
{
debug("");
auto value = result.value();
for (auto &callback : completion_callbacks) {
debug("call callback function from completion_callbacks");
callback(value);
}
// 调用完成,清空回调
completion_callbacks.clear();
}

// 使用 std::optional 可以区分协程是否执行完成
std::optional<Result<ResultType>> result;

std::mutex completion_lock;
std::condition_variable completion;
};

template<typename ResultType>
struct Task
{

// 声明 promise_type 为 TaskPromise 类型
using promise_type = TaskPromise<Task, ResultType>;

ResultType get_result()
{
debug("from Task");
return handle.promise().get_result();
}

Task &then(std::function<void(ResultType)> &&func)
{
debug("task id = " + std::to_string(task_id));
std::cout << handle.address() << std::endl;
handle.promise().on_completed([func](auto result) {
try {
func(result.get_or_throw());
} catch (std::exception &e) {
// 忽略异常
}
});
return *this;
}

Task &catching(std::function<void(std::exception &)> &&func)
{
debug("task id = " + std::to_string(task_id));
handle.promise().on_completed([func](auto result) {
try {
// 忽略返回值
result.get_or_throw();
} catch (std::exception &e) {
func(e);
}
});
return *this;
}

Task &finally(std::function<void()> &&func)
{
debug("task id = " + std::to_string(task_id));
std::cout << handle.address() << std::endl;
handle.promise().on_completed([func](auto result) { func(); });
return *this;
}

explicit Task(std::coroutine_handle<promise_type> handle) noexcept: handle(handle), task_id(cnt)
{
++cnt;
debug("Construct from handle, task id = " + std::to_string(task_id));
std::cout << handle.address() << std::endl;
}

Task(Task &&task) noexcept: handle(std::exchange(task.handle, {})), task_id(cnt)
{
++cnt;
debug("Move Construct from task, task id = " + std::to_string(task.task_id) << "->" << std::to_string(task_id));
std::cout << handle.address() << std::endl;
}

Task(Task &) = delete;

Task &operator=(Task &) = delete;

~Task() {
if (handle) handle.destroy();
}

private:
std::coroutine_handle<promise_type> handle;
static std::atomic<int> cnt;
const int task_id;
};

template <typename T>
std::atomic<int> Task<T>::cnt = 0;

Task<int> simple_task2() {
debug("task 2 start ...");
using namespace std::chrono_literals;
std::this_thread::sleep_for(1s);
debug("task 2 returns after 1s.");
co_return 2;
}

Task<int> simple_task3() {
debug("in task 3 start ...");
using namespace std::chrono_literals;
std::this_thread::sleep_for(2s);
debug("task 3 returns after 2s.");
co_return 3;
}

Task<int> simple_task() {
debug("task start ...");
auto result2 = co_await simple_task2();
debug("returns from task2: " + std::to_string(result2));
auto result3 = co_await simple_task3();
debug("returns from task3: " + std::to_string(result3));
co_return 1 + result2 + result3;
}

int main() {
auto simpleTask = simple_task();
std::cout << "======================" << std::endl;
simpleTask.then([](int i) {
debug("simple task end: " + std::to_string(i));
}).catching([](std::exception &e) {
debug("error occurred" + std::string{e.what()});
});
std::cout << "======================" << std::endl;
try {
auto i = simpleTask.get_result();
debug("simple task end from get: " + std::to_string(i));
} catch (std::exception &e) {
debug("error: " + std::string{e.what()});
}
return 0;
}

运行结果分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
ASM generation compiler returned: 0
Execution build compiler returned: 0
Program returned: 0

// 这里通过 auto simpleTask = simple_task(); 构建第一个 Task1 0x5562ea17feb0
250 Task: Construct from handle, task id = 0
0x5562ea17feb0
// 构建协程类第一步 initial_suspend
100 initial_suspend:
// 这里因为 suspend_never 所以继往下走
295 simple_task: task start ...
// 通过 auto result2 = co_await simple_task2(); 构建第二个 Task2 0x5562ea180fe0
250 Task: Construct from handle, task id = 1
0x5562ea180fe0
100 initial_suspend:
279 simple_task2: task 2 start ...
282 simple_task2: task 2 returns after 1s.
// 一直走到 co_return 2; 进入 TaskPromise::return_value
126 return_value:
// 通过 result = Result<ResultType>(std::move(value)); 用 value(2) 构建了一个 Result
26 Result: Construct from value
// 中间调了 completion.notify_all(); 但是这里没有 wait
// 继续在 TaskPromise::return_value 中 notify_callbacks();
181 notify_callbacks:
// co_return 之后这个协程 0x5562ea180fe0 就结束了 (还没释放,因为 finial 是 suspend_always)
105 final_suspend:
// 这里把 co_return 的结果给到 Task1 的 TaskPromise
170 await_transform:
// await_transform 返回一个 TaskAwaiter, 把刚刚的 Task2 存起来了(移动拷贝)
257 Task: Move Construct from task, task id = 1->2
0x5562ea180fe0
55 TaskAwaiter: Construct from task
// 这里先判断 ready, 直接返回 false
70 await_ready:
// false 的情况下需要 suspend (注意这里传入的参数 handle 是 Task1 的 handle)
76 await_suspend:
0x5562ea17feb0
// 这里 task.finally([handle]() { 就是刚刚存起来的 Task2
// 传了一个回调去 handle.resume();
241 finally: task id = 2
0x5562ea180fe0
// handle.promise().on_completed([func](auto result) { func(); });
149 on_completed:
// 这里判断了一下 result 有没有值, 那么这里因为是 Task2 的 Promise
// 所以 result 在刚刚 return_value 填充上了
153 on_completed: already has value
// 这里是通过刚刚的回调调到的 Task1 Promise 的 handle.resume()
87 await_resume:
// 这里 return task.get_result(); 这里的 task 是刚刚传到 TaskAwaiter 的 Task2
207 get_result: from Task
136 get_result: from TaskPromise
// 这里还是一样的 result 是 value(2)
144 get_result: already has value now
// 这里因为 TaskAwaiter 在 finally 中 resume 了 Task1 的 handle 所以继续进行
297 simple_task: returns from task2: 2
// 通过 auto result3 = co_await simple_task3(); 构建 Task3 0x5562ea180fe0
250 Task: Construct from handle, task id = 3
0x5562ea180fe0
// 后面 Task3 执行和返回的过程大部分都一样的
100 initial_suspend:
287 simple_task3: in task 3 start ...
290 simple_task3: task 3 returns after 2s.
126 return_value:
26 Result: Construct from value
181 notify_callbacks:
105 final_suspend:
170 await_transform:
257 Task: Move Construct from task, task id = 3->4
0x5562ea180fe0
55 TaskAwaiter: Construct from task
70 await_ready:
76 await_suspend:
0x5562ea17feb0
241 finally: task id = 4
0x5562ea180fe0
149 on_completed:
153 on_completed: already has value
87 await_resume:
207 get_result: from Task
136 get_result: from TaskPromise
144 get_result: already has value now
// 一直到这里同样 Task1 通过 Task3 的回调被 resume 继续执行
299 simple_task: returns from task3: 3
// 这里 co_return 1 + result2 + result3; 调到 return_value
126 return_value:
// 这里和前面一样同样构建了 Result
26 Result: Construct from value
181 notify_callbacks:
// 这里 Task 也走完了, suspend
105 final_suspend:
======================
// 刚刚所有的过程只在声明了 simpleTask 之后就执行完了
// 接着调用 simpleTask.then()
213 then: task id = 0
0x5562ea17feb0
// 那这里同样之前已经有值了
149 on_completed:
153 on_completed: already has value
// 这里通过 on_completed 调回到 lambda 回调 debug("simple task end: " + std::to_string(i));
306 operator(): simple task end: 6
// 这里 catch 了一下异常,但是因为没有异常所以没有触发 debug("error: " + std::string{e.what()});
227 catching: task id = 0
149 on_completed:
153 on_completed: already has value
======================
// try-catch 里的逻辑也类似
207 get_result: from Task
136 get_result: from TaskPromise
144 get_result: already has value now
313 main: simple task end from get: 6

References