[COMMIT seastar master] websocket: add a basic implementation of websocket frame handling

1 view
Skip to first unread message

Commit Bot

<bot@cloudius-systems.com>
unread,
Jun 23, 2022, 8:25:46 AM
to seastar-dev@googlegroups.com, Andrzej Stalke
From: Andrzej Stalke <38095405...@users.noreply.github.com>
Committer: Piotr Sarna <sa...@scylladb.com>
Branch: master

websocket: add a basic implementation of websocket frame handling

This commit adds a basic implementation of WebSocket data frame
handling. The implementation is incomplete and subject to change.
Related document:
- https://datatracker.ietf.org/doc/html/rfc6455

---
diff --git a/demos/websocket_demo.cc b/demos/websocket_demo.cc
--- a/demos/websocket_demo.cc
+++ b/demos/websocket_demo.cc
@@ -37,6 +37,9 @@ int main(int argc, char** argv) {
app.run(argc, argv, [] () -> seastar::future<> {
return async([] {
websocket::server ws;
+ // Echo subprotocol: write every received chunk back to the client until
+ // the input stream reports end-of-stream. The lambda must match handler_t,
+ // i.e. take the connection's input and output streams.
+ ws.register_handler("echo", [] (input_stream<char>& in, output_stream<char>& out) {
+ return repeat([&in, &out] {
+ return in.read().then([&out] (temporary_buffer<char> buf) {
+ if (buf.empty()) {
+ return make_ready_future<stop_iteration>(stop_iteration::yes);
+ }
+ return out.write(std::move(buf)).then([&out] {
+ return out.flush();
+ }).then([] {
+ return stop_iteration::no;
+ });
+ });
+ });
+ });
auto d = defer([&ws] () noexcept {
ws.stop().get();
});
diff --git a/include/seastar/websocket/server.hh b/include/seastar/websocket/server.hh
--- a/include/seastar/websocket/server.hh
+++ b/include/seastar/websocket/server.hh
@@ -21,6 +21,9 @@

#pragma once

+#include <map>
+#include <functional>
+
#include <seastar/http/request_parser.hh>
#include <seastar/core/seastar.hh>
#include <seastar/core/sstring.hh>
@@ -31,6 +34,8 @@

namespace seastar::experimental::websocket {

+// Handler for one WebSocket subprotocol. It is called with the connection's
+// input stream (unmasked payload received from the client) and output stream
+// (data to be framed and sent back); the connection waits for the returned
+// future before finishing its read loop.
+using handler_t = std::function<future<>(input_stream<char>&, output_stream<char>&)>;
+
class server;
struct reply {
//TODO: implement
@@ -39,19 +44,162 @@ struct reply {
/*!
* \brief an error in handling a WebSocket connection
*/
-class exception : std::exception {
+class exception : public std::exception {
std::string _msg;
public:
exception(std::string_view msg) : _msg(msg) {}
- const char* what() const noexcept {
+ // Returns the message given at construction.
+ const char* what() const noexcept override {
return _msg.c_str();
}
};

+// Fixed two-byte prefix of a WebSocket frame (RFC 6455, section 5.2):
+// FIN/RSV1-3/opcode in the first byte, MASK/7-bit payload length in the second.
+struct frame_header {
+ // Bit positions of the flags inside the first (FIN..RSV3) and second
+ // (MASK) header bytes.
+ static constexpr uint8_t FIN = 7;
+ static constexpr uint8_t RSV1 = 6;
+ static constexpr uint8_t RSV2 = 5;
+ static constexpr uint8_t RSV3 = 4;
+ static constexpr uint8_t MASKED = 7;
+
+ uint8_t fin : 1;
+ uint8_t rsv1 : 1;
+ uint8_t rsv2 : 1;
+ uint8_t rsv3 : 1;
+ uint8_t opcode : 4;
+ uint8_t masked : 1;
+ uint8_t length : 7;
+
+ // Decodes the header from `input`, which must point at (at least) the
+ // first two bytes of a frame. The bytes are read as unsigned so the
+ // shifts below never operate on a sign-extended char.
+ explicit frame_header(const char* input) {
+ const auto b0 = static_cast<uint8_t>(input[0]);
+ const auto b1 = static_cast<uint8_t>(input[1]);
+ this->fin = (b0 >> FIN) & 1;
+ this->rsv1 = (b0 >> RSV1) & 1;
+ this->rsv2 = (b0 >> RSV2) & 1;
+ this->rsv3 = (b0 >> RSV3) & 1;
+ this->opcode = b0 & 0b1111;
+ this->masked = (b1 >> MASKED) & 1;
+ this->length = b1 & 0b1111111;
+ }
+ // Returns the number of header bytes that follow the first two: the
+ // 4-byte masking key, plus a 2- or 8-byte extended payload length when
+ // the 7-bit length field is 126 or 127 respectively.
+ uint64_t get_rest_of_header_length() const {
+ size_t next_read_length = sizeof(uint32_t); // Masking key
+ if (length == 126) {
+ next_read_length += sizeof(uint16_t);
+ } else if (length == 127) {
+ next_read_length += sizeof(uint64_t);
+ }
+ return next_read_length;
+ }
+ uint8_t get_fin() const {return fin;}
+ uint8_t get_rsv1() const {return rsv1;}
+ uint8_t get_rsv2() const {return rsv2;}
+ uint8_t get_rsv3() const {return rsv3;}
+ uint8_t get_opcode() const {return opcode;}
+ uint8_t get_masked() const {return masked;}
+ uint8_t get_length() const {return length;}
+
+ // True for the opcodes defined by RFC 6455 section 5.2: 0x0-0x2
+ // (continuation, text, binary) and 0x8-0xA (close, ping, pong).
+ // 0x3-0x7 and 0xB-0xF are reserved.
+ bool is_opcode_known() const {
+ //https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
+ return opcode <= 0xA && !(opcode < 0x8 && opcode > 0x2);
+ }
+};
+
+
+
+// Incremental parser for client-to-server WebSocket frames, used as an
+// input_stream<char>::consume() consumer. Each time it stops consumption
+// with a valid state, result() holds the next unmasked chunk of payload
+// (possibly only part of a frame).
+class websocket_parser {
+ enum class parsing_state : uint8_t {
+ flags_and_payload_data,
+ payload_length_and_mask,
+ payload
+ };
+ enum class connection_state : uint8_t {
+ valid,
+ closed,
+ error
+ };
+ using consumption_result_t = consumption_result<char>;
+ using buff_t = temporary_buffer<char>;
+ // What parser is currently doing.
+ parsing_state _state;
+ // State of connection - can be valid, closed or should be closed
+ // due to error.
+ connection_state _cstate;
+ // Header bytes that arrived split across input chunks.
+ sstring _buffer;
+ // Two-byte prefix of the frame currently being parsed.
+ std::unique_ptr<frame_header> _header;
+ // Payload bytes of the current frame not consumed yet.
+ uint64_t _payload_length;
+ // Masking key of the current frame; key byte 0 is the most significant.
+ uint32_t _masking_key;
+ // Unmasked payload produced by the last call that stopped consumption.
+ buff_t _result;
+
+ static future<consumption_result_t> dont_stop() {
+ return make_ready_future<consumption_result_t>(continue_consuming{});
+ }
+ static future<consumption_result_t> stop(buff_t data) {
+ return make_ready_future<consumption_result_t>(stop_consuming(std::move(data)));
+ }
+
+ // Unmasks the first n bytes of p in place: byte i is XORed with key
+ // byte (i % 4), counting from the start of p.
+ void remove_mask(buff_t& p, size_t n) {
+ char *payload = p.get_write();
+ for (uint64_t i = 0, j = 0; i < n; ++i, j = (j + 1) % 4) {
+ payload[i] ^= static_cast<char>(((_masking_key << (j * 8)) >> 24));
+ }
+ }
+public:
+ websocket_parser() : _state(parsing_state::flags_and_payload_data),
+ _cstate(connection_state::valid),
+ _payload_length(0),
+ _masking_key(0) {}
+ future<consumption_result_t> operator()(temporary_buffer<char> data);
+ // True while neither end-of-stream nor a protocol error has been seen.
+ bool is_valid() const { return _cstate == connection_state::valid; }
+ // True once end-of-stream was observed on the underlying stream.
+ bool eof() const { return _cstate == connection_state::closed; }
+ // Moves out the payload chunk produced by the last stop.
+ buff_t result() { return std::move(_result); }
+};
+
/*!
* \brief a WebSocket connection
*/
class connection : public boost::intrusive::list_base_hook<> {
+ using buff_t = temporary_buffer<char>;
+
+ /*!
+ * \brief Implementation of connection's data source.
+ *
+ * Feeds the handler's input_stream from a queue of buffers pushed by the
+ * frame-reading side; an empty buffer acts as end-of-stream.
+ */
+ class connection_source_impl final : public data_source_impl {
+ queue<buff_t>* data;
+
+ public:
+ explicit connection_source_impl(queue<buff_t>* data) : data(data) {}
+
+ // Waits for the next buffer pushed by the reading side.
+ virtual future<buff_t> get() override {
+ return data->pop_eventually();
+ }
+
+ // Aborts the queue, passing the exception to any waiting reader or writer.
+ virtual future<> close() override {
+ data->abort(std::make_exception_ptr(exception("Connection closed")));
+ return make_ready_future<>();
+ }
+ };
+
+ /*!
+ * \brief Implementation of connection's data sink.
+ *
+ * Every packet written through the handler's output_stream is queued as a
+ * single buffer; response_loop() pops those buffers and sends each one.
+ */
+ class connection_sink_impl final : public data_sink_impl {
+ queue<buff_t>* data;
+ public:
+ explicit connection_sink_impl(queue<buff_t>* data) : data(data) {}
+
+ virtual future<> put(net::packet d) override {
+ if (d.len() == 0) {
+ // Nothing to send, and frag(0) would not exist.
+ return make_ready_future<>();
+ }
+ // A packet may consist of several fragments; merge them first so the
+ // whole payload is queued rather than only the first fragment.
+ d.linearize();
+ net::fragment f = d.frag(0);
+ return data->push_eventually(temporary_buffer<char>{f.base, f.size});
+ }
+
+ // NOTE(review): this reports the queue capacity (a count of buffers) as
+ // a size in bytes; confirm that is the intended output buffer size.
+ size_t buffer_size() const noexcept override {
+ return data->max_size();
+ }
+
+ // Aborts the queue, passing the exception to any waiting reader or writer.
+ virtual future<> close() override {
+ data->abort(std::make_exception_ptr(exception("Connection closed")));
+ return make_ready_future<>();
+ }
+ };
+
+ static const size_t PIPE_SIZE = 512;
server& _server;
connected_socket _fd;
input_stream<char> _read_buf;
@@ -60,6 +208,15 @@ class connection : public boost::intrusive::list_base_hook<> {
std::unique_ptr<reply> _resp;
queue<std::unique_ptr<reply>> _replies{10};
bool _done = false;
+
+ websocket_parser _websocket_parser;
+ queue <temporary_buffer<char>> _input_buffer;
+ input_stream<char> _input;
+
+ queue <temporary_buffer<char>> _output_buffer;
+ output_stream<char> _output;
+ sstring _subprotocol;
+ handler_t _handler;
public:
/*!
* \param server owning \ref server
@@ -70,7 +227,13 @@ public:
, _fd(std::move(fd))
, _read_buf(_fd.input())
, _write_buf(_fd.output())
+ , _input_buffer{PIPE_SIZE}
+ , _output_buffer{PIPE_SIZE}
{
+ _input = input_stream<char>{data_source{
+ std::make_unique<connection_source_impl>(&_input_buffer)}};
+ _output = output_stream<char>{data_sink{
+ std::make_unique<connection_sink_impl>(&_output_buffer)}};
on_new_connection();
}
~connection();
@@ -83,13 +246,18 @@ public:
* \brief close the socket
*/
void shutdown();
-
+
protected:
future<> read_loop();
future<> read_one();
future<> read_http_upgrade_request();
future<> response_loop();
void on_new_connection();
+ /*!
+ * \brief Packs buff in websocket data frame and sends it to the client.
+ */
+ future<> send_data(temporary_buffer<char>&& buff);
+
};

/*!
@@ -102,6 +270,7 @@ class server {
std::vector<server_socket> _listeners;
gate _task_gate;
boost::intrusive::list<connection> _connections;
+ std::map<std::string, handler_t> _handlers;
public:
/*!
* \brief listen for a WebSocket connection on given address
@@ -120,6 +289,10 @@ public:
*/
future<> stop();

+ /*!
+ * \brief returns true if a handler is registered for the subprotocol
+ * \p name (matched against the client's Sec-WebSocket-Protocol header)
+ */
+ bool is_handler_registered(std::string const& name);
+
+ /*!
+ * \brief registers \p handler for the subprotocol \p name, replacing any
+ * handler previously registered under the same name
+ */
+ void register_handler(std::string&& name, handler_t handler);
+
friend class connection;
protected:
void do_accepts(int which);
diff --git a/src/websocket/server.cc b/src/websocket/server.cc
--- a/src/websocket/server.cc
+++ b/src/websocket/server.cc
@@ -24,6 +24,7 @@
#include <cryptopp/sha.h>
#include <cryptopp/filters.h>
#include <cryptopp/base64.h>
+#include <seastar/core/scattered_message.hh>

#ifndef CRYPTOPP_NO_GLOBAL_BYTE
namespace CryptoPP {
@@ -110,6 +111,7 @@ future<> connection::process() {
} catch (...) {
wlogger.debug("Read exception encountered: {}", std::current_exception());
}
+
try {
std::get<1>(joined).get();
} catch (...) {
@@ -139,13 +141,27 @@ future<> connection::read_http_upgrade_request() {
}
std::unique_ptr<httpd::request> req = _http_parser.get_parsed_request();
if (_http_parser.failed()) {
- throw websocket::exception("Incorrect upgrade request");
+ return make_exception_future<>(websocket::exception("Incorrect upgrade request"));
}

sstring upgrade_header = req->get_header("Upgrade");
if (upgrade_header != "websocket") {
- throw websocket::exception("Upgrade header missing");
+ return make_exception_future<>(websocket::exception("Upgrade header missing"));
}
+
+ // The requested subprotocol selects the handler for this connection.
+ // Every rejection is reported as websocket::exception (not a bare string)
+ // so callers can catch it as std::exception.
+ sstring subprotocol = req->get_header("Sec-WebSocket-Protocol");
+ if (subprotocol.empty()) {
+ return make_exception_future<>(websocket::exception("Subprotocol header missing."));
+ }
+
+ if (!_server.is_handler_registered(subprotocol)) {
+ return make_exception_future<>(websocket::exception("Subprotocol not supported."));
+ }
+ this->_handler = this->_server._handlers[subprotocol];
+ this->_subprotocol = subprotocol;
+ wlogger.debug("Sec-WebSocket-Protocol: {}", subprotocol);
+
sstring sec_key = req->get_header("Sec-Websocket-Key");
sstring sec_version = req->get_header("Sec-Websocket-Version");

@@ -166,40 +182,175 @@ future<> connection::read_http_upgrade_request() {
});
}

+// Feeds one chunk of socket data through the frame-parsing state machine.
+// Returns continue_consuming while more bytes are needed, and stops
+// (handing any unconsumed bytes back to the stream) once a chunk of
+// unmasked payload is available through result(), on EOF, or on a
+// protocol error.
+future<websocket_parser::consumption_result_t> websocket_parser::operator()(
+ temporary_buffer<char> data) {
+ if (data.size() == 0) {
+ // EOF
+ _cstate = connection_state::closed;
+ return websocket_parser::stop(std::move(data));
+ }
+ if (_state == parsing_state::flags_and_payload_data) {
+ if (_buffer.length() + data.size() >= 2) {
+ if (_buffer.length() < 2) {
+ size_t hlen = _buffer.length();
+ _buffer.append(data.get(), 2 - hlen);
+ data.trim_front(2 - hlen);
+ _header = std::make_unique<frame_header>(_buffer.data());
+ _buffer = {};
+
+ // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
+ // We must close the connection if data isn't masked.
+ if ((!_header->masked) ||
+ // RSVX must be 0
+ (_header->rsv1 | _header->rsv2 | _header->rsv3) ||
+ // Opcode must be known.
+ (!_header->is_opcode_known())) {
+ _cstate = connection_state::error;
+ return websocket_parser::stop(std::move(data));
+ }
+ }
+ _state = parsing_state::payload_length_and_mask;
+ } else {
+ _buffer.append(data.get(), data.size());
+ return websocket_parser::dont_stop();
+ }
+ }
+ if (_state == parsing_state::payload_length_and_mask) {
+ size_t const required_bytes = _header->get_rest_of_header_length();
+ if (_buffer.length() + data.size() >= required_bytes) {
+ if (_buffer.length() < required_bytes) {
+ size_t hlen = _buffer.length();
+ _buffer.append(data.get(), required_bytes - hlen);
+ data.trim_front(required_bytes - hlen);
+
+ // Decode big-endian fields byte by byte: they are not necessarily
+ // aligned inside _buffer, so reading them through a
+ // uint16_t/uint32_t/uint64_t pointer would be undefined behavior.
+ char const *input = _buffer.data();
+ auto decode_be = [input] (size_t offset, size_t width) {
+ uint64_t value = 0;
+ for (size_t i = 0; i < width; ++i) {
+ value = (value << 8) | static_cast<uint8_t>(input[offset + i]);
+ }
+ return value;
+ };
+ _payload_length = _header->length;
+ size_t offset = 0;
+ if (_header->length == 126) {
+ _payload_length = decode_be(offset, sizeof(uint16_t));
+ offset += sizeof(uint16_t);
+ } else if (_header->length == 127) {
+ _payload_length = decode_be(offset, sizeof(uint64_t));
+ offset += sizeof(uint64_t);
+ }
+ _masking_key = static_cast<uint32_t>(decode_be(offset, sizeof(uint32_t)));
+ _buffer = {};
+ }
+ _state = parsing_state::payload;
+ } else {
+ _buffer.append(data.get(), data.size());
+ return websocket_parser::dont_stop();
+ }
+ }
+ if (_state == parsing_state::payload) {
+ if (data.empty() && _payload_length != 0) {
+ // The header used up the whole chunk: wait for payload bytes instead
+ // of producing an empty result (an empty buffer means end-of-stream).
+ return websocket_parser::dont_stop();
+ }
+ if (_payload_length > data.size()) {
+ // Partial payload: unmask what arrived and keep the frame open.
+ size_t const chunk = data.size();
+ _payload_length -= chunk;
+ remove_mask(data, chunk);
+ // remove_mask() always starts at key byte 0, so rotate the key by the
+ // number of bytes consumed; the next chunk then starts at key byte
+ // (offset % 4) as RFC 6455 section 5.3 requires.
+ if (unsigned const shift = (chunk % 4) * 8; shift != 0) {
+ _masking_key = (_masking_key << shift) | (_masking_key >> (32 - shift));
+ }
+ _result = std::move(data);
+ return websocket_parser::stop(buff_t(0));
+ } else {
+ // The rest of the frame is in this chunk: copy exactly the payload and
+ // hand the remaining bytes (the next frame) back to the stream.
+ _result = buff_t(data.get(), _payload_length);
+ remove_mask(_result, _payload_length);
+ data.trim_front(_payload_length);
+ _payload_length = 0;
+ _state = parsing_state::flags_and_payload_data;
+ return websocket_parser::stop(std::move(data));
+ }
+ }
+ _cstate = connection_state::error;
+ return websocket_parser::stop(std::move(data));
+}
+
future<> connection::read_one() {
- return _read_buf.read().then([this] (temporary_buffer<char> buf) {
- if (buf.empty()) {
+ // Parses frames from the socket and forwards their unmasked payload to
+ // the handler through _input_buffer.
+ return _read_buf.consume(_websocket_parser).then([this] () mutable {
+ if (_websocket_parser.is_valid()) {
+ auto payload = _websocket_parser.result();
+ // An empty buffer in _input_buffer is read by the handler's
+ // input_stream as end-of-stream, so never forward an empty payload.
+ if (payload.empty()) {
+ return make_ready_future<>();
+ }
+ // FIXME: implement error handling
+ return _input_buffer.push_eventually(std::move(payload));
+ } else if (_websocket_parser.eof()) {
_done = true;
+ // Propagate end-of-stream so the handler can finish.
+ return _input_buffer.push_eventually(temporary_buffer<char>());
}
- //FIXME: implement
- wlogger.info("Received: {}", buf.get());
+ wlogger.debug("Reading from socket has failed.");
+ _done = true;
+ // The connection is unusable: wake the handler with end-of-stream.
+ return _input_buffer.push_eventually(temporary_buffer<char>());
});
}

future<> connection::read_loop() {
return read_http_upgrade_request().then([this] {
- return do_until([this] {return _done;}, [this] {
- return read_one();
+ // Run the subprotocol handler and the frame-reading loop side by side;
+ // the handler consumes _input, which read_one() feeds.
+ return when_all(
+ _handler(_input, _output),
+ do_until([this] {return _done;}, [this] {return read_one();})
+ ).then([] (std::tuple<future<>, future<>> joined) {
+ // Inspect both results so that neither failure goes unobserved.
+ try {
+ std::get<0>(joined).get();
+ } catch (...) {
+ wlogger.debug("Handler exception encountered: {}",
+ std::current_exception());
+ }
+ try {
+ std::get<1>(joined).get();
+ } catch (...) {
+ wlogger.debug("Read exception encountered: {}",
+ std::current_exception());
+ }
});
}).then_wrapped([this] (future<> f) {
+ // Reached on success and on failure (e.g. a rejected upgrade request),
+ // so the reply sentinel is always queued and _read_buf always closed.
if (f.failed()) {
wlogger.error("Read failed: {}", f.get_exception());
}
return _replies.push_eventually({});
}).finally([this] {
return _read_buf.close();
});
}

+future<> connection::send_data(temporary_buffer<char>&& buff) {
+ // First byte: FIN set, text opcode (0x81). The second byte never has the
+ // mask bit set, and the payload length uses the shortest encoding.
+ sstring header("\x81", 1);
+ const size_t payload_size = buff.size();
+ if (payload_size < 126) {
+ // Fits into the 7-bit length field directly.
+ const char short_length = static_cast<char>(payload_size & 0x7F);
+ header.append(&short_length, 1);
+ } else if (payload_size <= std::numeric_limits<uint16_t>::max()) {
+ // 126 marker followed by a 16-bit big-endian length.
+ const uint16_t be_length = htobe16(static_cast<uint16_t>(payload_size));
+ header.append("\x7e", 1);
+ header.append(reinterpret_cast<const char*>(&be_length), sizeof(be_length));
+ } else {
+ // 127 marker followed by a 64-bit big-endian length.
+ const uint64_t be_length = htobe64(static_cast<uint64_t>(payload_size));
+ header.append("\x7f", 1);
+ header.append(reinterpret_cast<const char*>(&be_length), sizeof(be_length));
+ }
+
+ // Header and payload go out together as one scattered message, then flush.
+ scattered_message<char> frame;
+ frame.append(std::move(header));
+ frame.append(std::move(buff));
+ return _write_buf.write(std::move(frame)).then([this] { return _write_buf.flush(); });
+}
+
future<> connection::response_loop() {
- // FIXME: implement
- return make_ready_future<>();
+ // Sends every buffer the handler wrote to _output (queued in _output_buffer
+ // by connection_sink_impl::put) to the client via send_data().
+ // NOTE(review): _done is only re-checked after pop_eventually() resolves;
+ // if nothing more is queued once reading stops, this loop seems to wait
+ // until the queue is aborted (e.g. by closing _output) — confirm.
+ return do_until([this] {return _done;}, [this] {
+ // FIXME: implement error handling
+ return _output_buffer.pop_eventually().then([this] (
+ temporary_buffer<char> buf) {
+ return send_data(std::move(buf));
+ });
+ });
}

void connection::shutdown() {
wlogger.debug("Shutting down");
_fd.shutdown_input();
_fd.shutdown_output();
+ // Close the handler-facing streams; their source/sink close() methods
+ // abort _input_buffer/_output_buffer so pending queue operations fail.
+ // NOTE(review): .get() blocks and is only valid inside a seastar::thread
+ // (or on an already-resolved future); confirm every caller runs in one.
+ when_all(_input.close(), _output.close()).discard_result().get();
+}
+
+bool server::is_handler_registered(std::string const& name) {
+ // A subprotocol is supported iff a handler was registered under that name.
+ return _handlers.count(name) != 0;
+}
+
+void server::register_handler(std::string&& name, handler_t handler) {
+ // Both arguments are owned by this call, so move them into the map rather
+ // than copying. A handler already registered under `name` is replaced.
+ _handlers.insert_or_assign(std::move(name), std::move(handler));
}

}
diff --git a/tests/unit/websocket_test.cc b/tests/unit/websocket_test.cc
--- a/tests/unit/websocket_test.cc
+++ b/tests/unit/websocket_test.cc
@@ -20,6 +20,7 @@ SEASTAR_TEST_CASE(test_websocket_handshake) {
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
"Sec-WebSocket-Version: 13\r\n"
+ "Sec-WebSocket-Protocol: echo\r\n"
"\r\n";
loopback_connection_factory factory;
loopback_socket_impl lsi(factory);
@@ -31,6 +32,10 @@ SEASTAR_TEST_CASE(test_websocket_handshake) {
auto output = sock.output();

websocket::server dummy;
+ dummy.register_handler("echo", [] (input_stream<char>& in,
+ output_stream<char>& out) {
+ return make_ready_future<>();
+ });
websocket::connection conn(dummy, acceptor.get0().connection);
future<> serve = conn.process();

Reply all
Reply to author
Forward
0 new messages