1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 #ifndef _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_ 21 #define _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_ 1 22 23 #include <cstdlib> 24 #include <iostream> 25 #include <sstream> 26 27 #include <openssl/sha.h> 28 29 #include <thrift/config.h> 30 #include <thrift/protocol/TProtocol.h> 31 #include <thrift/transport/TSocket.h> 32 #include <thrift/transport/THttpServer.h> 33 #if defined(_MSC_VER) || defined(__MINGW32__) 34 #include <Shlwapi.h> 35 #define THRIFT_strncasecmp(str1, str2, len) _strnicmp(str1, str2, len) 36 #define THRIFT_strcasestr(haystack, needle) StrStrIA(haystack, needle) 37 #else 38 #define THRIFT_strncasecmp(str1, str2, len) strncasecmp(str1, str2, len) 39 #define THRIFT_strcasestr(haystack, needle) strcasestr(haystack, needle) 40 #endif 41 #if defined(__CYGWIN__) 42 #include <alloca.h> 43 #endif 44 45 using std::string; 46 47 namespace apache { 48 namespace thrift { 49 namespace transport { 50 51 std::string base64Encode(unsigned char* data, int length); 52 53 template <bool binary> 54 class TWebSocketServer : public THttpServer { 55 public: 56 TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr) THttpServer(transport,config)57 : THttpServer(transport, config) { 58 resetHandshake(); 59 } 60 61 ~TWebSocketServer() override = default; 62 readAll_virt(uint8_t * buf,uint32_t len)63 uint32_t readAll_virt(uint8_t* buf, uint32_t len) override { 64 // If we do not have a good handshake, the client will attempt one. 65 if (!handshakeComplete()) { 66 resetHandshake(); 67 THttpServer::read(buf, len); 68 // If we did not get everything we expected, the handshake failed 69 // and we need to send a 400 response back. 70 if (!handshakeComplete()) { 71 sendBadRequest(); 72 return 0; 73 } 74 // Otherwise, send back the 101 response. 75 THttpServer::flush(); 76 } 77 78 uint32_t want = len; 79 auto have = readBuffer_.available_read(); 80 81 // If we have some data in the buffer, copy it out and return it. 82 // We have to return it without attempting to read more, since we aren't 83 // guaranteed that the underlying transport actually has more data, so 84 // attempting to read from it could block. 85 if (have > 0 && have >= want) { 86 return readBuffer_.read(buf, want); 87 } 88 89 // Read another frame. 90 if (!readFrame()) { 91 // EOF. No frame available. 92 return 0; 93 } 94 95 // Hand over whatever we have. 96 uint32_t give = (std::min)(want, readBuffer_.available_read()); 97 return readBuffer_.read(buf, give); 98 } 99 flush()100 void flush() override { 101 resetConsumedMessageSize(); 102 writeFrameHeader(); 103 uint8_t* buffer; 104 uint32_t length; 105 writeBuffer_.getBuffer(&buffer, &length); 106 transport_->write(buffer, length); 107 transport_->flush(); 108 writeBuffer_.resetBuffer(); 109 } 110 111 protected: getHeader(uint32_t len)112 std::string getHeader(uint32_t len) override { 113 THRIFT_UNUSED_VARIABLE(len); 114 std::ostringstream h; 115 h << "HTTP/1.1 101 Switching Protocols" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF 116 << "Upgrade: websocket" << CRLF << "Connection: Upgrade" << CRLF 117 << "Sec-WebSocket-Accept: " << acceptKey_ << CRLF << CRLF; 118 return h.str(); 119 } 120 parseHeader(char * header)121 void parseHeader(char* header) override { 122 char* colon = strchr(header, ':'); 123 if (colon == nullptr) { 124 return; 125 } 126 size_t sz = colon - header; 127 char* value = colon + 1; 128 129 if (THRIFT_strncasecmp(header, "Upgrade", sz) == 0) { 130 if (THRIFT_strcasestr(value, "websocket") != nullptr) { 131 upgrade_ = true; 132 } 133 } else if (THRIFT_strncasecmp(header, "Connection", sz) == 0) { 134 if (THRIFT_strcasestr(value, "Upgrade") != nullptr) { 135 connection_ = true; 136 } 137 } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Key", sz) == 0) { 138 std::string toHash = value + 1; 139 toHash += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 140 unsigned char hash[20]; 141 SHA1((const unsigned char*)toHash.c_str(), toHash.length(), hash); 142 acceptKey_ = base64Encode(hash, 20); 143 secWebSocketKey_ = true; 144 } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Version", sz) == 0) { 145 if (THRIFT_strcasestr(value, "13") != nullptr) { 146 secWebSocketVersion_ = true; 147 } 148 } 149 } 150 parseStatusLine(char * status)151 bool parseStatusLine(char* status) override { 152 char* method = status; 153 154 char* path = strchr(method, ' '); 155 if (path == nullptr) { 156 throw TTransportException(string("Bad Status: ") + status); 157 } 158 159 *path = '\0'; 160 while (*(++path) == ' ') { 161 }; 162 163 char* http = strchr(path, ' '); 164 if (http == nullptr) { 165 throw TTransportException(string("Bad Status: ") + status); 166 } 167 *http = '\0'; 168 169 if (strcmp(method, "GET") == 0) { 170 // GET method ok, looking for content. 171 return true; 172 } 173 throw TTransportException(string("Bad Status (unsupported method): ") + status); 174 } 175 176 private: 177 enum class CloseCode : uint16_t { 178 NormalClosure = 1000, 179 GoingAway = 1001, 180 ProtocolError = 1002, 181 UnsupportedDataType = 1003, 182 NoStatusCode = 1005, 183 AbnormalClosure = 1006, 184 InvalidData = 1007, 185 PolicyViolation = 1008, 186 MessageTooBig = 1009, 187 ExtensionExpected = 1010, 188 UnexpectedError = 1011, 189 NotSecure = 1015 190 }; 191 192 enum class Opcode : uint8_t { 193 Continuation = 0x0, 194 Text = 0x1, 195 Binary = 0x2, 196 Close = 0x8, 197 Ping = 0x9, 198 Pong = 0xA 199 }; 200 failConnection(CloseCode reason)201 void failConnection(CloseCode reason) { 202 writeFrameHeader(Opcode::Close); 203 auto buffer = htons(static_cast<uint16_t>(reason)); 204 transport_->write(reinterpret_cast<const uint8_t*>(&buffer), 2); 205 transport_->flush(); 206 transport_->close(); 207 } 208 handshakeComplete()209 bool handshakeComplete() { 210 return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_; 211 } 212 pong()213 void pong() { 214 writeFrameHeader(Opcode::Pong); 215 uint8_t* buffer; 216 uint32_t size; 217 readBuffer_.getBuffer(&buffer, &size); 218 transport_->write(buffer, size); 219 transport_->flush(); 220 } 221 readFrame()222 bool readFrame() { 223 uint8_t headerBuffer[8]; 224 225 auto read = transport_->read(headerBuffer, 2); 226 if (read < 2) { 227 return false; 228 } 229 // Since Thrift has its own message end marker and we read frame by frame, 230 // it doesn't really matter if the frame is marked as FIN. 231 // Capture it only for debugging only. 232 auto fin = (headerBuffer[0] & 0x80) != 0; 233 THRIFT_UNUSED_VARIABLE(fin); 234 235 // RSV1, RSV2, RSV3 236 if ((headerBuffer[0] & 0x70) != 0) { 237 failConnection(CloseCode::ProtocolError); 238 throw TTransportException(TTransportException::CORRUPTED_DATA, 239 "Reserved bits must be zeroes"); 240 } 241 242 auto opcode = (Opcode)(headerBuffer[0] & 0x0F); 243 244 // Mask 245 if ((headerBuffer[1] & 0x80) == 0) { 246 failConnection(CloseCode::ProtocolError); 247 throw TTransportException(TTransportException::CORRUPTED_DATA, 248 "Messages from the client must be masked"); 249 } 250 251 // Read the length 252 uint64_t payloadLength = headerBuffer[1] & 0x7F; 253 if (payloadLength == 126) { 254 read = transport_->read(headerBuffer, 2); 255 if (read < 2) { 256 return false; 257 } 258 payloadLength = ntohs(*reinterpret_cast<uint16_t*>(headerBuffer)); 259 } else if (payloadLength == 127) { 260 read = transport_->read(headerBuffer, 8); 261 if (read < 8) { 262 return false; 263 } 264 payloadLength = THRIFT_ntohll(*reinterpret_cast<uint64_t*>(headerBuffer)); 265 if ((payloadLength & 0x8000000000000000) != 0) { 266 failConnection(CloseCode::ProtocolError); 267 throw TTransportException( 268 TTransportException::CORRUPTED_DATA, 269 "The most significant bit of the payload length must be zero"); 270 } 271 } 272 273 // size_t is smaller than a ulong on a 32-bit system 274 if (payloadLength > UINT32_MAX) { 275 failConnection(CloseCode::MessageTooBig); 276 return false; 277 } 278 279 auto length = static_cast<uint32_t>(payloadLength); 280 281 if (length > 0) { 282 // Read the masking key 283 read = transport_->read(headerBuffer, 4); 284 if (read < 4) { 285 return false; 286 } 287 288 readBuffer_.resetBuffer(length); 289 uint8_t* buffer = readBuffer_.getWritePtr(length); 290 read = transport_->read(buffer, length); 291 readBuffer_.wroteBytes(read); 292 if (read < length) { 293 return false; 294 } 295 296 // Unmask the data 297 for (size_t i = 0; i < length; i++) { 298 buffer[i] ^= headerBuffer[i % 4]; 299 } 300 301 T_DEBUG("FIN=%d, Opcode=%X, length=%d, payload=%s", fin, opcode, length, 302 binary ? readBuffer_.toHexString() : cast(string) readBuffer_); 303 } 304 305 switch (opcode) { 306 case Opcode::Close: 307 if (length >= 2) { 308 uint8_t buffer[2]; 309 readBuffer_.read(buffer, 2); 310 CloseCode closeCode = static_cast<CloseCode>(ntohs(*reinterpret_cast<uint16_t*>(buffer))); 311 THRIFT_UNUSED_VARIABLE(closeCode); 312 string closeReason = readBuffer_.readAsString(length - 2); 313 T_DEBUG("Connection closed: %d %s", closeCode, closeReason); 314 } 315 transport_->close(); 316 return false; 317 case Opcode::Ping: 318 pong(); 319 return readFrame(); 320 default: 321 return true; 322 } 323 } 324 resetHandshake()325 void resetHandshake() { 326 connection_ = false; 327 secWebSocketKey_ = false; 328 secWebSocketVersion_ = false; 329 upgrade_ = false; 330 } 331 sendBadRequest()332 void sendBadRequest() { 333 std::ostringstream h; 334 h << "HTTP/1.1 400 Bad Request" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF << CRLF; 335 std::string header = h.str(); 336 transport_->write(reinterpret_cast<const uint8_t*>(header.data()), static_cast<uint32_t>(header.length())); 337 transport_->flush(); 338 transport_->close(); 339 } 340 341 void writeFrameHeader(Opcode opcode = Opcode::Continuation) { 342 uint32_t headerSize = 1; 343 uint32_t length = writeBuffer_.available_read(); 344 if (length < 126) { 345 ++headerSize; 346 } else if (length < 65536) { 347 headerSize += 3; 348 } else { 349 headerSize += 9; 350 } 351 // The server does not mask the response 352 353 uint8_t* header = static_cast<uint8_t*>(alloca(headerSize)); 354 if (opcode == Opcode::Continuation) { 355 opcode = binary ? Opcode::Binary : Opcode::Text; 356 } 357 header[0] = static_cast<uint8_t>(opcode) | 0x80; 358 if (length < 126) { 359 header[1] = static_cast<uint8_t>(length); 360 } else if (length < 65536) { 361 header[1] = 126; 362 *reinterpret_cast<uint16_t*>(header + 2) = htons(length); 363 } else { 364 header[1] = 127; 365 *reinterpret_cast<uint64_t*>(header + 2) = THRIFT_htonll(length); 366 } 367 368 transport_->write(header, headerSize); 369 } 370 371 // Add constant here to avoid a linker error on Windows 372 constexpr static const char* CRLF = "\r\n"; 373 std::string acceptKey_; 374 bool connection_; 375 bool secWebSocketKey_; 376 bool secWebSocketVersion_; 377 bool upgrade_; 378 }; 379 380 /** 381 * Wraps a transport into binary WebSocket protocol 382 */ 383 class TBinaryWebSocketServerTransportFactory : public TTransportFactory { 384 public: 385 TBinaryWebSocketServerTransportFactory() = default; 386 387 ~TBinaryWebSocketServerTransportFactory() override = default; 388 389 /** 390 * Wraps the transport into a buffered one. 391 */ getTransport(std::shared_ptr<TTransport> trans)392 std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override { 393 return std::shared_ptr<TTransport>(new TWebSocketServer<true>(trans)); 394 } 395 }; 396 397 /** 398 * Wraps a transport into text WebSocket protocol 399 */ 400 class TTextWebSocketServerTransportFactory : public TTransportFactory { 401 public: 402 TTextWebSocketServerTransportFactory() = default; 403 404 ~TTextWebSocketServerTransportFactory() override = default; 405 406 /** 407 * Wraps the transport into a buffered one. 408 */ getTransport(std::shared_ptr<TTransport> trans)409 std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override { 410 return std::shared_ptr<TTransport>(new TWebSocketServer<false>(trans)); 411 } 412 }; 413 } // namespace transport 414 } // namespace thrift 415 } // namespace apache 416 #endif 417