[COMMIT seastar master] Merge 'rpc: add APIs for connection dropping' from Kamil Braun

2 views
Skip to first unread message

Commit Bot

<bot@cloudius-systems.com>
unread,
May 25, 2023, 6:29:39 AM (5/25/23)
to seastar-dev@googlegroups.com, Pavel Emelyanov
From: Pavel Emelyanov <xe...@scylladb.com>
Committer: Pavel Emelyanov <xe...@scylladb.com>
Branch: master

Merge 'rpc: add APIs for connection dropping' from Kamil Braun

Introduce some new functionality into Seastar RPC that allows storing references to connections from RPC handlers and dropping connections using those references.

The connection information is accessible from the `rpc::client_info` struct that handlers have access to.
In order to drop the connection, one can use `rpc::server::abort_connection`.

Those APIs will be used in Scylla to ban nodes that were removed from the cluster: https://github.com/scylladb/scylladb/pull/13850

Closes #1663

* https://github.com/scylladb/seastar:
rpc: rpc_types: make `connection_id` a class
tests: rpc_test: simple test for connection aborting
rpc: introduce `server::abort_connection(connection_id)`
rpc: remove `connection::_server` field
rpc: add `server&` and `connection_id` to `client_info`
rpc: rpc_types: move `connection_id` definition before `client_info`

---
diff --git a/include/seastar/rpc/rpc.hh b/include/seastar/rpc/rpc.hh
--- a/include/seastar/rpc/rpc.hh
+++ b/include/seastar/rpc/rpc.hh
@@ -550,7 +550,6 @@ private:

public:
class connection : public rpc::connection, public enable_shared_from_this<connection> {
- server& _server;
client_info _info;
connection_id _parent_id = invalid_connection_id;
std::optional<isolation_config> _isolation_config;
@@ -581,19 +580,22 @@ public:
// Resources will be released when this goes out of scope
future<resource_permit> wait_for_resources(size_t memory_consumed, std::optional<rpc_clock_type::time_point> timeout) {
if (timeout) {
- return get_units(_server._resources_available, memory_consumed, *timeout);
+ return get_units(get_server()._resources_available, memory_consumed, *timeout);
} else {
- return get_units(_server._resources_available, memory_consumed);
+ return get_units(get_server()._resources_available, memory_consumed);
}
}
size_t estimate_request_size(size_t serialized_size) {
- return rpc::estimate_request_size(_server._limits, serialized_size);
+ return rpc::estimate_request_size(get_server()._limits, serialized_size);
}
size_t max_request_size() const {
- return _server._limits.max_memory;
+ return get_server()._limits.max_memory;
}
server& get_server() {
- return _server;
+ return _info.server;
+ }
+ const server& get_server() const {
+ return _info.server;
}
future<> deregister_this_stream();
};
@@ -621,6 +623,14 @@ public:
f(*c.second);
}
}
+ /**
+ * Abort the given connection, causing it to stop receiving any further messages.
+ * It's safe to abort a connection from an RPC handler running on that connection.
+ * Does nothing if there is no connection with the given ID on this server.
+ *
+ * @param id the ID of the connection to abort.
+ */
+ void abort_connection(connection_id id);
gate& reply_gate() {
return _reply_gate;
}
diff --git a/include/seastar/rpc/rpc_impl.hh b/include/seastar/rpc/rpc_impl.hh
--- a/include/seastar/rpc/rpc_impl.hh
+++ b/include/seastar/rpc/rpc_impl.hh
@@ -176,15 +176,15 @@ maybe_add_time_point(do_want_time_point, opt_time_point& otp, std::tuple<In...>&
inline sstring serialize_connection_id(const connection_id& id) {
sstring p = uninitialized_string(sizeof(id));
auto c = p.data();
- write_le(c, id.id);
+ write_le(c, id.id());
return p;
}

inline connection_id deserialize_connection_id(const sstring& s) {
- connection_id id;
+ using id_type = decltype(connection_id{0}.id());
auto p = s.c_str();
- id.id = read_le<decltype(id.id)>(p);
- return id;
+ auto id = read_le<id_type>(p);
+ return connection_id{id};
}

template <bool IsSmartPtr>
diff --git a/include/seastar/rpc/rpc_types.hh b/include/seastar/rpc/rpc_types.hh
--- a/include/seastar/rpc/rpc_types.hh
+++ b/include/seastar/rpc/rpc_types.hh
@@ -55,9 +55,41 @@ struct stats {
counter_type timeout = 0;
};

+class connection_id {
+ uint64_t _id;
+
+public:
+ uint64_t id() const {
+ return _id;
+ }
+ bool operator==(const connection_id& o) const {
+ return _id == o._id;
+ }
+ explicit operator bool() const {
+ return shard() != 0xffff;
+ }
+ size_t shard() const {
+ return size_t(_id & 0xffff);
+ }
+ constexpr static connection_id make_invalid_id(uint64_t _id = 0) {
+ return make_id(_id, 0xffff);
+ }
+ constexpr static connection_id make_id(uint64_t _id, uint16_t shard) {
+ return {_id << 16 | shard};
+ }
+ constexpr connection_id(uint64_t id) : _id(id) {}
+};
+
+constexpr connection_id invalid_connection_id = connection_id::make_invalid_id();
+
+std::ostream& operator<<(std::ostream&, const connection_id&);
+
+class server;

struct client_info {
socket_address addr;
+ rpc::server& server;
+ connection_id conn_id;
std::unordered_map<sstring, boost::any> user_data;
template <typename T>
void attach_auxiliary(const sstring& key, T&& object) {
@@ -258,29 +290,6 @@ public:

class connection;

-struct connection_id {
- uint64_t id;
- bool operator==(const connection_id& o) const {
- return id == o.id;
- }
- explicit operator bool() const {
- return shard() != 0xffff;
- }
- size_t shard() const {
- return size_t(id & 0xffff);
- }
- constexpr static connection_id make_invalid_id(uint64_t id = 0) {
- return make_id(id, 0xffff);
- }
- constexpr static connection_id make_id(uint64_t id, uint16_t shard) {
- return {id << 16 | shard};
- }
-};
-
-constexpr connection_id invalid_connection_id = connection_id::make_invalid_id();
-
-std::ostream& operator<<(std::ostream&, const connection_id&);
-
using xshard_connection_ptr = lw_shared_ptr<foreign_ptr<shared_ptr<connection>>>;
constexpr size_t max_queued_stream_buffers = 50;
constexpr size_t max_stream_buffers_memory = 100 * 1024;
@@ -390,7 +399,7 @@ template<>
struct hash<seastar::rpc::connection_id> {
size_t operator()(const seastar::rpc::connection_id& id) const {
size_t h = 0;
- boost::hash_combine(h, std::hash<uint64_t>{}(id.id));
+ boost::hash_combine(h, std::hash<uint64_t>{}(id.id()));
return h;
}
};
diff --git a/src/rpc/rpc.cc b/src/rpc/rpc.cc
--- a/src/rpc/rpc.cc
+++ b/src/rpc/rpc.cc
@@ -861,8 +861,8 @@ namespace rpc {
switch (id) {
// supported features go here
case protocol_features::COMPRESS: {
- if (_server._options.compressor_factory) {
- _compressor = _server._options.compressor_factory->negotiate(e.second, true);
+ if (get_server()._options.compressor_factory) {
+ _compressor = get_server()._options.compressor_factory->negotiate(e.second, true);
if (_compressor) {
ret[protocol_features::COMPRESS] = _compressor->name();
}
@@ -874,20 +874,20 @@ namespace rpc {
ret[protocol_features::TIMEOUT] = "";
break;
case protocol_features::STREAM_PARENT: {
- if (!_server._options.streaming_domain) {
+ if (!get_server()._options.streaming_domain) {
f = f.then([] {
return make_exception_future<>(std::runtime_error("streaming is not configured for the server"));
});
} else {
_parent_id = deserialize_connection_id(e.second);
_is_stream = true;
// remove stream connection from rpc connection list
- _server._conns.erase(get_connection_id());
+ get_server()._conns.erase(get_connection_id());
f = f.then([this, c = shared_from_this()] () mutable {
return smp::submit_to(_parent_id.shard(), [this, c = make_foreign(static_pointer_cast<rpc::connection>(c))] () mutable {
- auto sit = _servers.find(*_server._options.streaming_domain);
+ auto sit = _servers.find(*get_server()._options.streaming_domain);
if (sit == _servers.end()) {
- throw std::logic_error(format("Shard {:d} does not have server with streaming domain {}", this_shard_id(), *_server._options.streaming_domain).c_str());
+ throw std::logic_error(format("Shard {:d} does not have server with streaming domain {}", this_shard_id(), *get_server()._options.streaming_domain).c_str());
}
auto s = sit->second;
auto it = s->_conns.find(_parent_id);
@@ -918,7 +918,7 @@ namespace rpc {

auto visitor = isolation_function_visitor(isolation_cookie);
f = f.then([visitor = std::move(visitor), this] () mutable {
- return std::visit(visitor, _server._limits.isolate_connection).then([this] (isolation_config conf) {
+ return std::visit(visitor, get_server()._limits.isolate_connection).then([this] (isolation_config conf) {
_isolation_config = conf;
});
});
@@ -930,7 +930,7 @@ namespace rpc {
;
}
}
- if (_server._options.streaming_domain) {
+ if (get_server()._options.streaming_domain) {
ret[protocol_features::CONNECTION_ID] = serialize_connection_id(_id);
}
return f.then([ret = std::move(ret)] {
@@ -1017,7 +1017,7 @@ future<> server::connection::send_unknown_verb_reply(std::optional<rpc_clock_typ
try {
// Send asynchronously.
// This is safe since connection::stop() will wait for background work.
- (void)with_gate(_server._reply_gate, [this, timeout, msg_id, data = std::move(data), permit = std::move(permit)] () mutable {
+ (void)with_gate(get_server()._reply_gate, [this, timeout, msg_id, data = std::move(data), permit = std::move(permit)] () mutable {
// workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83268
auto c = shared_from_this();
return respond(-msg_id, std::move(data), timeout).then([c = std::move(c), permit = std::move(permit)] {});
@@ -1048,7 +1048,7 @@ future<> server::connection::send_unknown_verb_reply(std::optional<rpc_clock_typ
if (expire && *expire) {
timeout = relative_timeout_to_absolute(std::chrono::milliseconds(*expire));
}
- auto h = _server._proto->get_handler(type);
+ auto h = get_server()._proto->get_handler(type);
if (!h) {
return send_unknown_verb_reply(timeout, msg_id, type);
}
@@ -1059,7 +1059,7 @@ future<> server::connection::send_unknown_verb_reply(std::optional<rpc_clock_typ
return with_scheduling_group(sg, [this, timeout, msg_id, h, data = std::move(data.value())] () mutable {
return h->func(shared_from_this(), timeout, msg_id, std::move(data)).finally([this, h] {
// If anything between get_handler() and here throws, we leak put_handler
- _server._proto->put_handler(h);
+ get_server()._proto->put_handler(h);
});
});
}
@@ -1078,7 +1078,7 @@ future<> server::connection::send_unknown_verb_reply(std::optional<rpc_clock_typ
_stream_queue.abort(std::make_exception_ptr(stream_closed()));
return stop_send_loop(ep).then_wrapped([this] (future<> f) {
f.ignore_ready_future();
- _server._conns.erase(get_connection_id());
+ get_server()._conns.erase(get_connection_id());
if (is_stream()) {
return deregister_this_stream();
} else {
@@ -1093,16 +1093,16 @@ future<> server::connection::send_unknown_verb_reply(std::optional<rpc_clock_typ
}

server::connection::connection(server& s, connected_socket&& fd, socket_address&& addr, const logger& l, void* serializer, connection_id id)
- : rpc::connection(std::move(fd), l, serializer, id), _server(s) {
- _info.addr = std::move(addr);
+ : rpc::connection(std::move(fd), l, serializer, id)
+ , _info{.addr{std::move(addr)}, .server{s}, .conn_id{id}} {
}

future<> server::connection::deregister_this_stream() {
- if (!_server._options.streaming_domain) {
+ if (!get_server()._options.streaming_domain) {
return make_ready_future<>();
}
return smp::submit_to(_parent_id.shard(), [this] () mutable {
- auto sit = server::_servers.find(*_server._options.streaming_domain);
+ auto sit = server::_servers.find(*get_server()._options.streaming_domain);
if (sit != server::_servers.end()) {
auto s = sit->second;
auto it = s->_conns.find(_parent_id);
@@ -1184,8 +1184,21 @@ future<> server::connection::send_unknown_verb_reply(std::optional<rpc_clock_typ
).discard_result();
}

+ void server::abort_connection(connection_id id) {
+ auto it = _conns.find(id);
+ if (it == _conns.end()) {
+ return;
+ }
+ try {
+ it->second->abort();
+ } catch (...) {
+ log_exception(*it->second, log_level::error,
+ "fail to shutdown connection on user request", std::current_exception());
+ }
+ }
+
std::ostream& operator<<(std::ostream& os, const connection_id& id) {
- fmt::print(os, "{:x}", id.id);
+ fmt::print(os, "{:x}", id.id());
return os;
}

diff --git a/tests/unit/rpc_test.cc b/tests/unit/rpc_test.cc
--- a/tests/unit/rpc_test.cc
+++ b/tests/unit/rpc_test.cc
@@ -1444,21 +1444,23 @@ SEASTAR_TEST_CASE(test_connection_id_format) {
static_assert(std::is_same_v<decltype(rpc::tuple(1U, 1L)), rpc::tuple<unsigned, long>>, "rpc::tuple deduction guid not working");

SEASTAR_TEST_CASE(test_client_info) {
- rpc::client_info info;
- const rpc::client_info& const_info = *const_cast<rpc::client_info*>(&info);
+ return rpc_test_env<>::do_with(rpc_test_config(), [] (rpc_test_env<>& env) {
+ rpc::client_info info{.server{env.server()}, .conn_id{0}};
+ const rpc::client_info& const_info = *const_cast<rpc::client_info*>(&info);

- info.attach_auxiliary("key", 0);
- BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary<int>("key"), 0);
- info.retrieve_auxiliary<int>("key") = 1;
- BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary<int>("key"), 1);
+ info.attach_auxiliary("key", 0);
+ BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary<int>("key"), 0);
+ info.retrieve_auxiliary<int>("key") = 1;
+ BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary<int>("key"), 1);

- BOOST_REQUIRE_EQUAL(info.retrieve_auxiliary_opt<int>("key"), &info.retrieve_auxiliary<int>("key"));
- BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary_opt<int>("key"), &const_info.retrieve_auxiliary<int>("key"));
+ BOOST_REQUIRE_EQUAL(info.retrieve_auxiliary_opt<int>("key"), &info.retrieve_auxiliary<int>("key"));
+ BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary_opt<int>("key"), &const_info.retrieve_auxiliary<int>("key"));

- BOOST_REQUIRE_EQUAL(info.retrieve_auxiliary_opt<int>("missing"), nullptr);
- BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary_opt<int>("missing"), nullptr);
+ BOOST_REQUIRE_EQUAL(info.retrieve_auxiliary_opt<int>("missing"), nullptr);
+ BOOST_REQUIRE_EQUAL(const_info.retrieve_auxiliary_opt<int>("missing"), nullptr);

- return make_ready_future<>();
+ return make_ready_future<>();
+ });
}

void send_messages_and_check_timeouts(rpc_test_env<>& env, test_rpc_proto::client& cln) {
@@ -1498,3 +1500,25 @@ SEASTAR_TEST_CASE(test_rpc_send_timeout_on_connect) {
send_messages_and_check_timeouts(env, cln);
});
}
+
+SEASTAR_TEST_CASE(test_rpc_abort_connection) {
+ return rpc_test_env<>::do_with_thread(rpc_test_config(), [] (rpc_test_env<>& env) {
+ test_rpc_proto::client c1(env.proto(), {}, env.make_socket(), ipv4_addr());
+ int arrived = 0;
+ env.register_handler(1, [&arrived] (rpc::client_info& cinfo, int x) {
+ BOOST_REQUIRE_EQUAL(x, arrived++);
+ if (arrived == 2) {
+ cinfo.server.abort_connection(cinfo.conn_id);
+ }
+ // The third message won't arrive because we abort the connection.
+
+ return 0;
+ }).get();
+ auto f = env.proto().make_client<int (int)>(1);
+ BOOST_REQUIRE_EQUAL(f(c1, 0).get0(), 0);
+ BOOST_REQUIRE_THROW(f(c1, 1).get0(), rpc::closed_error);
+ BOOST_REQUIRE_THROW(f(c1, 2).get0(), rpc::closed_error);
+ BOOST_REQUIRE_EQUAL(arrived, 2);
+ c1.stop().get0();
+ });
+}
Reply all
Reply to author
Forward
0 new messages