It is possible to have multiple connections that share the
invalid_connection_id id. When that happens, the server, which tracks
connections in a map keyed by id, no longer has a complete list of connections.
Without a complete list, the future returned by server::stop might
complete while a connection is still running, and the call
_server._conns.erase(get_connection_id());
in server::connection::process might then access freed memory.
Signed-off-by: Rafael Ávila de Espíndola <espi...@scylladb.com>
---
include/seastar/rpc/rpc.hh | 12 +++++++++---
src/rpc/rpc.cc | 26 ++++++++++++++++++--------
2 files changed, 27 insertions(+), 11 deletions(-)
diff --git a/include/seastar/rpc/rpc.hh b/include/seastar/rpc/rpc.hh
index 77471670..606ce9a1 100644
--- a/include/seastar/rpc/rpc.hh
+++ b/include/seastar/rpc/rpc.hh
@@ -515,7 +515,13 @@ class server {
server_socket _ss;
resource_limits _limits;
rpc_semaphore _resources_available;
- std::unordered_map<connection_id, shared_ptr<connection>> _conns;
+
+ // There can be many connections with invalid_connection_id, so keep both a map and a set.
+ std::unordered_map<connection_id, shared_ptr<connection>> _conns_map;
+ std::unordered_set<shared_ptr<connection>> _all_conns;
+ void add_connection(shared_ptr<connection>);
+ void remove_connection(shared_ptr<connection>);
+
promise<> _ss_stopped;
gate _reply_gate;
server_options _options;
@@ -530,8 +536,8 @@ class server {
future<> stop();
template<typename Func>
void foreach_connection(Func&& f) {
- for (auto c : _conns) {
- f(*c.second);
+ for (auto c : _all_conns) {
+ f(*c);
}
}
gate& reply_gate() {
diff --git a/src/rpc/rpc.cc b/src/rpc/rpc.cc
index 575c8361..0926bfb4 100644
--- a/src/rpc/rpc.cc
+++ b/src/rpc/rpc.cc
@@ -725,6 +725,16 @@ namespace rpc {
{}
+ void server::add_connection(shared_ptr<connection> c) {
+ _conns_map.emplace(c->get_connection_id(), c);
+ _all_conns.insert(c);
+ }
+
+ void server::remove_connection(shared_ptr<connection> c) {
+ _conns_map.erase(c->get_connection_id());
+ _all_conns.erase(c);
+ }
+
future<feature_map>
server::connection::negotiate(feature_map requested) {
feature_map ret;
@@ -751,15 +761,15 @@ namespace rpc {
_parent_id = deserialize_connection_id(e.second);
_is_stream = true;
// remove stream connection from rpc connection list
- _server._conns.erase(get_connection_id());
+ _server.remove_connection(shared_from_this());
f = smp::submit_to(_parent_id.shard(), [this, c = make_foreign(static_pointer_cast<rpc::connection>(shared_from_this()))] () mutable {
auto sit = _servers.find(*_server._options.streaming_domain);
if (sit == _servers.end()) {
throw std::logic_error(format("Shard {:d} does not have server with streaming domain {:x}", engine().cpu_id(), *_server._options.streaming_domain).c_str());
}
auto s = sit->second;
- auto it = s->_conns.find(_parent_id);
- if (it == s->_conns.end()) {
+ auto it = s->_conns_map.find(_parent_id);
+ if (it == s->_conns_map.end()) {
throw std::logic_error(format("Unknown parent connection {:d} on shard {:d}", _parent_id, engine().cpu_id()).c_str());
}
auto id = c->get_connection_id();
@@ -908,7 +918,7 @@ namespace rpc {
_stream_queue.abort(std::make_exception_ptr(stream_closed()));
return stop_send_loop().then_wrapped([this] (future<> f) {
f.ignore_ready_future();
- _server._conns.erase(get_connection_id());
+ _server.remove_connection(shared_from_this());
if (is_stream()) {
return deregister_this_stream();
} else {
@@ -935,8 +945,8 @@ namespace rpc {
auto sit = server::_servers.find(*_server._options.streaming_domain);
if (sit != server::_servers.end()) {
auto s = sit->second;
- auto it = s->_conns.find(_parent_id);
- if (it != s->_conns.end()) {
+ auto it = s->_conns_map.find(_parent_id);
+ if (it != s->_conns_map.end()) {
it->second->_streams.erase(get_connection_id());
}
}
@@ -975,7 +985,7 @@ namespace rpc {
id = {_next_client_id++ << 16 | uint16_t(engine().cpu_id())};
}
auto conn = _proto->make_server_connection(*this, std::move(fd), std::move(addr), id);
- _conns.emplace(id, conn);
+ add_connection(conn);
conn->process();
});
}).then_wrapped([this] (future<>&& f){
@@ -995,7 +1005,7 @@ namespace rpc {
_servers.erase(*_options.streaming_domain);
}
return when_all(_ss_stopped.get_future(),
- parallel_for_each(_conns | boost::adaptors::map_values, [] (shared_ptr<connection> conn) {
+ parallel_for_each(_all_conns, [this] (shared_ptr<connection> conn) {
return conn->stop();
}),
_reply_gate.close()
--
2.20.1