Line data Source code
1 : #include "tcp_socket_server.hpp"
2 : #include "workers.hpp"
3 : #include "../util/logger.hpp"
4 : #include "../util/types.hpp"
5 : #include <boost/asio/ssl.hpp>
6 :
7 : namespace thinger::asio {
8 :
9 : // Constructor with io_context providers
10 601 : tcp_socket_server::tcp_socket_server(std::string host,
11 : std::string port,
12 : io_context_provider acceptor_context_provider,
13 : io_context_provider connection_context_provider,
14 : std::set<std::string> allowed_remotes,
15 601 : std::set<std::string> forbidden_remotes)
16 601 : : socket_server_base(std::move(acceptor_context_provider),
17 601 : std::move(connection_context_provider),
18 601 : std::move(allowed_remotes),
19 601 : std::move(forbidden_remotes))
20 601 : , host_(std::move(host))
21 3005 : , port_(std::move(port))
22 : {
23 601 : }
24 :
25 : // Legacy constructor for backward compatibility
26 12 : tcp_socket_server::tcp_socket_server(std::string host,
27 : std::string port,
28 : std::set<std::string> allowed_remotes,
29 12 : std::set<std::string> forbidden_remotes)
30 : : tcp_socket_server(host, port,
31 12 : []() -> boost::asio::io_context& { return get_workers().get_thread_io_context(); },
32 12 : []() -> boost::asio::io_context& { return get_workers().get_next_io_context(); },
33 12 : std::move(allowed_remotes),
34 24 : std::move(forbidden_remotes))
35 : {
36 12 : }
37 :
38 1172 : tcp_socket_server::~tcp_socket_server() {
39 601 : close_acceptor();
40 1172 : }
41 :
42 595 : bool tcp_socket_server::stop() {
43 : // First call base class to set running_ = false
44 595 : socket_server_base::stop();
45 :
46 : // Now close the acceptor
47 595 : close_acceptor();
48 :
49 595 : return true;
50 : }
51 :
52 1196 : void tcp_socket_server::close_acceptor() {
53 1196 : if (acceptor_) {
54 595 : boost::system::error_code ec;
55 595 : acceptor_->close(ec);
56 595 : if (ec) {
57 0 : LOG_WARNING("Error closing TCP acceptor: {}", ec.message());
58 : }
59 595 : acceptor_.reset();
60 : }
61 1196 : }
62 :
63 0 : void tcp_socket_server::set_tcp_no_delay(bool tcp_no_delay) {
64 0 : tcp_no_delay_ = tcp_no_delay;
65 0 : }
66 :
67 52 : void tcp_socket_server::enable_ssl(bool ssl, bool client_certificate) {
68 52 : ssl_enabled_ = ssl;
69 52 : client_certificate_ = client_certificate;
70 52 : }
71 :
72 52 : void tcp_socket_server::set_ssl_context(std::shared_ptr<boost::asio::ssl::context> context) {
73 52 : ssl_context_ = std::move(context);
74 52 : }
75 :
76 52 : void tcp_socket_server::set_sni_callback(sni_callback_type callback) {
77 52 : if (ssl_context_) {
78 52 : SSL_CTX_set_tlsext_servername_callback(ssl_context_->native_handle(), callback);
79 : }
80 52 : }
81 :
82 0 : std::string tcp_socket_server::get_service_name() const {
83 0 : return (ssl_enabled_ ? "ssl_server@" : "tcp_server@") + host_ + ":" + port_;
84 : }
85 :
86 573 : uint16_t tcp_socket_server::local_port() const {
87 573 : return acceptor_ ? acceptor_->local_endpoint().port() : 0;
88 : }
89 :
90 601 : bool tcp_socket_server::create_acceptor() {
91 601 : int num_attempts = 0;
92 :
93 : // Get io_context from provider
94 601 : boost::asio::io_context& io_context = acceptor_context_provider_();
95 :
96 : // Resolve endpoint
97 601 : boost::asio::ip::tcp::endpoint endpoint;
98 : try {
99 601 : boost::asio::ip::tcp::resolver resolver(io_context);
100 601 : auto results = resolver.resolve(host_, port_);
101 598 : if (results.begin() == results.end()) {
102 0 : LOG_ERROR("no endpoints found for {}:{}", host_, port_);
103 0 : return false;
104 : }
105 :
106 598 : auto entry = *results.begin();
107 598 : endpoint = entry.endpoint();
108 604 : } catch (const boost::system::system_error& e) {
109 3 : LOG_ERROR("failed to resolve {}:{} - {}", host_, port_, e.code().message());
110 3 : return false;
111 3 : }
112 :
113 598 : bool success = false;
114 : do {
115 598 : LOG_DEBUG("starting TCP socket acceptor on {}:{}", host_, port_);
116 598 : if (num_attempts > 0) {
117 0 : std::this_thread::sleep_for(std::chrono::seconds(5));
118 : }
119 :
120 598 : acceptor_ = std::make_unique<boost::asio::ip::tcp::acceptor>(io_context);
121 598 : acceptor_->open(endpoint.protocol());
122 598 : acceptor_->set_option(boost::asio::ip::tcp::acceptor::reuse_address(true));
123 :
124 : try {
125 598 : LOG_DEBUG("binding and listening to endpoint: {}:{}",
126 : endpoint.address().to_string(), endpoint.port());
127 598 : acceptor_->bind(endpoint);
128 595 : acceptor_->listen();
129 595 : success = true;
130 3 : } catch (boost::system::system_error& error) {
131 3 : LOG_ERROR("cannot start listening on {}:{}: {}",
132 : host_, port_, error.code().message());
133 : // Reset acceptor if binding failed to avoid inconsistent state
134 3 : acceptor_.reset();
135 3 : if (max_listening_attempts_ >= 0 && num_attempts >= max_listening_attempts_) {
136 0 : return false;
137 : }
138 3 : }
139 598 : num_attempts++;
140 598 : } while (!success && (max_listening_attempts_ < 0 || num_attempts < max_listening_attempts_));
141 :
142 598 : if (success) {
143 595 : LOG_INFO("TCP server is now listening on {}:{}", host_, port_);
144 : }
145 :
146 598 : return success;
147 : }
148 :
149 1417 : void tcp_socket_server::accept_connection() {
150 : // Get next io_context from provider
151 1417 : boost::asio::io_context& io_context = connection_context_provider_();
152 :
153 : // Create socket based on SSL configuration
154 1417 : std::shared_ptr<tcp_socket> sock;
155 1417 : if (ssl_enabled_) {
156 104 : if (!ssl_context_) {
157 0 : LOG_ERROR("SSL enabled but no SSL context configured");
158 0 : return;
159 : }
160 104 : sock = std::make_shared<ssl_socket>("ssl_socket_server", io_context, ssl_context_);
161 : } else {
162 1313 : sock = std::make_shared<tcp_socket>("tcp_socket_server", io_context);
163 : }
164 :
165 1417 : auto& socket = sock->get_socket();
166 :
167 : // Start accepting a connection
168 1417 : acceptor_->async_accept(socket, [sock = std::move(sock), this](const boost::system::error_code& e) mutable {
169 1215 : if (!e) {
170 : // Get remote socket ip
171 828 : auto remote_ip = sock->get_remote_ip();
172 :
173 : // Check if IP is allowed
174 828 : if (!is_remote_allowed(remote_ip)) {
175 0 : sock->close();
176 0 : LOG_WARNING("rejecting connection from: ip: {}, port: {}, secure: {}",
177 : remote_ip, sock->get_local_port(), sock->is_secure());
178 0 : if (running_) accept_connection();
179 0 : return;
180 : }
181 :
182 828 : LOG_INFO("received connection from: ip: {}, port: {}, secure: {}",
183 : remote_ip, sock->get_local_port(), sock->is_secure());
184 :
185 828 : if (tcp_no_delay_) {
186 828 : sock->enable_tcp_no_delay();
187 : }
188 :
189 828 : if (sock->requires_handshake()) {
190 : // Use co_spawn to run the coroutine-based handshake
191 52 : co_spawn(sock->get_io_context(),
192 156 : [this, sock]() -> awaitable<void> {
193 : auto ec = co_await sock->handshake();
194 : if (ec) {
195 : LOG_ERROR("error while handling SSL handshake: {}, remote ip: {}",
196 : ec.message(), sock->get_remote_ip());
197 : co_return;
198 : }
199 : if (handler_) handler_(sock);
200 104 : },
201 : detached);
202 : } else {
203 776 : if (handler_) handler_(std::move(sock));
204 : }
205 :
206 : // Continue accepting connections
207 828 : if (running_) accept_connection();
208 828 : } else {
209 387 : if (e != boost::asio::error::operation_aborted) {
210 0 : LOG_ERROR("cannot accept more connections: {}", e.message());
211 0 : if (running_) {
212 : // Retry after a delay to avoid tight loop on persistent errors
213 : auto timer = std::make_shared<boost::asio::steady_timer>(
214 0 : acceptor_context_provider_(),
215 0 : std::chrono::seconds(1)
216 0 : );
217 0 : timer->async_wait([this, timer](const boost::system::error_code& e) {
218 0 : if (e != boost::asio::error::operation_aborted) {
219 0 : accept_connection();
220 : }
221 0 : });
222 0 : }
223 : } else {
224 387 : LOG_INFO("stop accepting connections");
225 : }
226 : }
227 : });
228 1417 : }
229 :
230 : } // namespace thinger::asio
|