This patch adds support for coroutine::parallel_for_each
that can be used in coroutines instead of
co_await seastar::parallel_for_each.
This implementation minimizes memory allocation
by deferring the allocation of a vector of futures
to wait on, till any of the futures is unavailable,
similar to the legacy parallel_for_each implementation.
Also it waits on the unavailable futures from back to front,
similar to the legacy implementation, to minimize
the number of callbacks required.
The advantage over the legacy implementation is
that the parallel_for_each object itself is a
seastar task that is used to wait on unavailable
futures with no need to allocate a parallel_for_each_state.
This is possible since coroutine::parallel_for_each
is defined as [[nodiscard]] and the caller must co_await
for it, making sure it remains valid until all the futures
it's waiting on are resolved.
Test: coroutines_test(debug, release)
w/ clang++ 12.0.1, g++ 11.2.1, c++20
Signed-off-by: Benny Halevy <
bha...@scylladb.com>
---
demos/coroutines_demo.cc | 11 +-
doc/tutorial.md | 17 ++
.../seastar/coroutine/parallel_for_each.hh | 165 ++++++++++++++++++
tests/unit/coroutines_test.cc | 97 ++++++++++
4 files changed, 289 insertions(+), 1 deletion(-)
create mode 100644 include/seastar/coroutine/parallel_for_each.hh
In v2:
- Improved commit log message
- Use std::ranges::range
- Use parallel_for_each object as continuation
- Allocate _futures only if there's an unavailable future
- Wait for _futures in reverse order
- Tested with both g++ and clang++
- Added unit test for empty range, subrange, and exception throwing.
index 00000000..6e811a83
--- /dev/null
+++ b/include/seastar/coroutine/parallel_for_each.hh
@@ -0,0 +1,165 @@
+/*
+ * This file is open source software, licensed to you under the terms
+ * of the Apache License, Version 2.0 (the "License"). See the NOTICE file
+ * distributed with this work for additional information regarding copyright
+ * ownership. You may not use this file except in compliance with the License.
+ *
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * Copyright (C) 2022-present ScyllaDB
+ */
+
+#pragma once
+
+#include <ranges>
+
+/// Awaitable that invokes a functor on every element of a range and
+/// completes once every future returned by those invocations has resolved.
+///
+/// The functor is invoked for all elements by the constructor, even if an
+/// earlier invocation failed. If any invocation failed, one of the
+/// exceptions (the one observed last) is rethrown when the object is awaited.
+///
+/// The object registers itself (`this`) as the continuation of the futures
+/// it waits on, so it must stay alive, at the same address, until the
+/// awaiting coroutine is resumed.
+template <typename Func>
+class [[nodiscard("must co_await an parallel_for_each() object")]] parallel_for_each final : public continuation_base<> {
+    using coroutine_handle_t = SEASTAR_INTERNAL_COROUTINE_NAMESPACE::coroutine_handle<void>;
+
+    // The user-supplied functor, invoked once per element by the constructor.
+    Func _func;
+    // Futures that were not yet available when the constructor examined them.
+    // Stays empty if every invocation completed immediately.
+    std::vector<future<>> _futures;
+    // The most recently observed failure, if any; rethrown by await_resume().
+    std::exception_ptr _ex;
+    // Handle of the suspended coroutine that awaits this object.
+    coroutine_handle_t _when_ready;
+    // Promise of the awaiting coroutine, reported by waiting_task().
+    task* _waiting_task = nullptr;
+
+    // Consume futures in reverse order.
+    // Since futures at the front are expected
+    // to become ready before futures at the back,
+    // therefore it is less likely we will have
+    // to wait on them, after the back futures
+    // become available.
+    //
+    // Return true iff all futures were consumed.
+    bool consume_next() noexcept {
+        while (!_futures.empty()) {
+            auto& fut = _futures.back();
+            if (!fut.available()) {
+                return false;
+            }
+            if (fut.failed()) {
+                // A later failure overwrites an earlier one.
+                _ex = fut.get_exception();
+            }
+            _futures.pop_back();
+        }
+        return true;
+    }
+
+    // Detach the last pending future from _futures and register `this`
+    // as its continuation, so run_and_dispose() runs when it resolves.
+    void set_callback() noexcept {
+        auto fut = std::move(_futures.back());
+        _futures.pop_back();
+        // To reuse `this` as continuation_base<>
+        // we must reset _state, to allow setting
+        // it again.
+        this->_state.~future_state();
+        new (&this->_state) future_state();
+        seastar::internal::set_callback(fut, this);
+    }
+
+    // Resume the awaiting coroutine if nothing is left to wait for,
+    // otherwise wait on the next unavailable future.
+    void resume_or_set_callback() noexcept {
+        if (__builtin_expect(consume_next(), false)) {
+            _when_ready.resume();
+        } else {
+            set_callback();
+        }
+    }
+
+public:
+    // clang 13.0.1 doesn't support subrange,
+    // so also provide an Iterator/Sentinel based constructor.
+    // See https://github.com/llvm/llvm-project/issues/46091
+    //
+    // Invokes `func` on every element of [begin, end). Futures that are
+    // already available are consumed on the spot; the rest are stored
+    // to be waited on when the object is awaited.
+    template <typename Iterator, typename Sentinel>
+    requires (std::same_as<Sentinel, Iterator> || std::sentinel_for<Sentinel, Iterator>)
+    explicit parallel_for_each(Iterator begin, Sentinel end, Func&& func)
+        : _func(std::move(func))
+    {
+        for (auto it = begin; it != end; ++it) {
+            auto fut = futurize_invoke(_func, *it);
+            if (fut.available()) {
+                if (fut.failed()) {
+                    _ex = fut.get_exception();
+                }
+                continue;
+            }
+            // Allocate lazily: only once the first unavailable future shows up.
+            if (_futures.empty()) {
+                // std::distance() requires both arguments to have the same
+                // type, and on a single-pass iterator it may consume the
+                // elements we still have to visit. Pre-size the vector only
+                // when the distance is O(1) or the iterator is multi-pass;
+                // otherwise let the vector grow on demand.
+                if constexpr (std::sized_sentinel_for<Sentinel, Iterator> || std::forward_iterator<Iterator>) {
+                    _futures.reserve(std::ranges::distance(it, end));
+                }
+            }
+            _futures.push_back(std::move(fut));
+        }
+    }
+
+    // Range-based constructor; delegates to the Iterator/Sentinel one.
+    template <std::ranges::range Range>
+    requires std::invocable<Func, std::ranges::range_value_t<Range>>
+    explicit parallel_for_each(const Range& range, Func&& func)
+        : parallel_for_each(std::ranges::begin(range), std::ranges::end(range), std::forward<Func>(func))
+    { }
+
+    // Ready iff there is no pending future left to wait for.
+    // A stored exception, if any, is rethrown by await_resume(), which
+    // the coroutine machinery calls right after this returns true.
+    bool await_ready() const noexcept {
+        return _futures.empty();
+    }
+
+    // Remember who to resume, then either resume right away (all pending
+    // futures resolved in the meantime) or start waiting on the next one.
+    template<typename T>
+    void await_suspend(SEASTAR_INTERNAL_COROUTINE_NAMESPACE::coroutine_handle<T> h) {
+        _when_ready = h;
+        _waiting_task = &h.promise();
+        resume_or_set_callback();
+    }
+
+    // Rethrows the stored exception, if any invocation failed.
+    void await_resume() const {
+        if (_ex) {
+            std::rethrow_exception(_ex);
+        }
+    }
+
+    // Called when the future registered by set_callback() resolves:
+    // record its failure, if any, and proceed to the remaining futures.
+    virtual void run_and_dispose() noexcept override {
+        if (__builtin_expect(this->_state.failed(), false)) {
+            _ex = std::move(this->_state).get_exception();
+        }
+        resume_or_set_callback();
+    }
+
+    virtual task* waiting_task() noexcept override {
+        return _waiting_task;
+    }
+};
+
+}
\ No newline at end of file
diff --git a/tests/unit/coroutines_test.cc b/tests/unit/coroutines_test.cc
index 23de9ac4..5fb4e3df 100644
--- a/tests/unit/coroutines_test.cc
+++ b/tests/unit/coroutines_test.cc
@@ -20,11 +20,13 @@
*/
#include <exception>
+#include <numeric>
#include <seastar/core/future-util.hh>
#include <seastar/testing/test_case.hh>
#include <seastar/core/sleep.hh>
#include <seastar/util/later.hh>
+#include <seastar/testing/random.hh>
using namespace seastar;
using namespace std::chrono_literals;
@@ -41,6 +43,7 @@ SEASTAR_TEST_CASE(test_coroutines_not_compiled_in) {
#include <seastar/coroutine/all.hh>
#include <seastar/coroutine/maybe_yield.hh>
#include <seastar/coroutine/switch_to.hh>
+#include <seastar/coroutine/parallel_for_each.hh>
namespace {
@@ -414,4 +417,98 @@ SEASTAR_TEST_CASE(generator)
#endif
+// An empty input range must complete without ever invoking the functor.
+SEASTAR_TEST_CASE(test_parallel_for_each_empty) {
+    const std::vector<int> values{};
+    int count = 0;
+
+    // If awaiting an empty range never resumed us, the test would hang here.
+    co_await coroutine::parallel_for_each(values, [&count] (int) {
+        ++count;
+    });
+    BOOST_REQUIRE_EQUAL(count, 0);
+}
+
+// An exception thrown by the functor must reach the awaiter, and the
+// functor must still be invoked for every element of the range.
+SEASTAR_TEST_CASE(test_parallel_for_each_exception) {
+    std::array<int, 5> values = { 10, 2, 1, 4, 8 };
+    int count = 0;
+    auto& eng = testing::local_random_engine;
+    auto dist = std::uniform_int_distribution<unsigned>();
+    // Draws a random position in `values` at which the functor will throw.
+    const auto pick_throw_index = [&] () -> int {
+        return dist(eng) % values.size();
+    };
+    int throw_at = pick_throw_index();
+
+    BOOST_TEST_MESSAGE(fmt::format("Will throw at value #{}/{}", throw_at, values.size()));
+
+    // Case 1: the functor throws synchronously.
+    auto f0 = coroutine::parallel_for_each(values, [&count, &throw_at] (int) {
+        const bool should_throw = (count++ == throw_at);
+        if (should_throw) {
+            BOOST_TEST_MESSAGE("throw");
+            throw std::runtime_error("test");
+        }
+    });
+    // The exception must be propagated to whoever awaits the object...
+    BOOST_REQUIRE_THROW(co_await std::move(f0), std::runtime_error);
+    // ...and the functor must have run on all values regardless.
+    BOOST_REQUIRE_EQUAL(count, values.size());
+
+    // Case 2: the functor is a coroutine that throws after sleeping,
+    // so the failure arrives through a future that was not ready.
+    count = 0;
+    throw_at = pick_throw_index();
+    BOOST_TEST_MESSAGE(fmt::format("Will throw at value #{}/{}", throw_at, values.size()));
+
+    auto f1 = coroutine::parallel_for_each(values, [&count, &throw_at] (int x) -> future<> {
+        co_await sleep(std::chrono::milliseconds(x));
+        const bool should_throw = (count++ == throw_at);
+        if (should_throw) {
+            throw std::runtime_error("test");
+        }
+    });
+    BOOST_REQUIRE_THROW(co_await std::move(f1), std::runtime_error);
+    BOOST_REQUIRE_EQUAL(count, values.size());
+}
+
+SEASTAR_TEST_CASE(test_parallel_for_each) {
+ std::vector<int> values = { 3, 1, 4 };
+ int sum_of_squares = 0;
+
+ int expected = std::accumulate(values.begin(), values.end(), 0, [] (int sum, int x) {
+ return sum + x * x;
+ });
+
+ // Test all-ready futures
+ co_await coroutine::parallel_for_each(values, [&sum_of_squares] (int x) {
+ sum_of_squares += x * x;
+ });
+ BOOST_REQUIRE_EQUAL(sum_of_squares, expected);
+
+ // Test non-ready futures
+ sum_of_squares = 0;
+ co_await coroutine::parallel_for_each(values, [&sum_of_squares] (int x) -> future<> {
+ if (x > 1) {
+ co_await sleep(std::chrono::milliseconds(x));
+ }
+ sum_of_squares += x * x;
+ });
+ BOOST_REQUIRE_EQUAL(sum_of_squares, expected);
+
+ // Test legacy subrange
+ sum_of_squares = 0;
+ co_await coroutine::parallel_for_each(values.begin(), values.end() - 1, [&sum_of_squares] (int x) -> future<> {
+ if (x > 1) {
+ co_await sleep(std::chrono::milliseconds(x));
+ }
+ sum_of_squares += x * x;
+ });
+ BOOST_REQUIRE_EQUAL(sum_of_squares, 10);
+
+ // clang 13.0.1 doesn't support subrange
+ // so provide also a Iterator/Sentinel based constructor.
+ // See https://github.com/llvm/llvm-project/issues/46091
+#ifndef __clang__
+ // Test std::ranges::subrange
+ sum_of_squares = 0;
+ co_await coroutine::parallel_for_each(std::ranges::subrange(values.begin(), values.end() - 1), [&sum_of_squares] (int x) -> future<> {
+ if (x > 1) {
+ co_await sleep(std::chrono::milliseconds(x));
+ }
+ sum_of_squares += x * x;
+ });
+ BOOST_REQUIRE_EQUAL(sum_of_squares, 10);
+#endif