1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_
21 #define _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_ 1
22 
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
26 
27 #include <openssl/sha.h>
28 
29 #include <thrift/config.h>
30 #include <thrift/protocol/TProtocol.h>
31 #include <thrift/transport/TSocket.h>
32 #include <thrift/transport/THttpServer.h>
33 #if defined(_MSC_VER) || defined(__MINGW32__)
34 #include <Shlwapi.h>
35 #define THRIFT_strncasecmp(str1, str2, len) _strnicmp(str1, str2, len)
36 #define THRIFT_strcasestr(haystack, needle) StrStrIA(haystack, needle)
37 #else
38 #define THRIFT_strncasecmp(str1, str2, len) strncasecmp(str1, str2, len)
39 #define THRIFT_strcasestr(haystack, needle) strcasestr(haystack, needle)
40 #endif
41 #if defined(__CYGWIN__)
42 #include <alloca.h>
43 #endif
44 
45 using std::string;
46 
47 namespace apache {
48 namespace thrift {
49 namespace transport {
50 
/**
 * Base64 encoder used by TWebSocketServer::parseHeader to encode the 20-byte
 * SHA-1 digest that is sent back as the Sec-WebSocket-Accept header value.
 * Only declared here; the definition is not part of this header.
 * NOTE(review): assumed to return standard (padded) base64 as required by
 * RFC 6455 section 4.2.2 - confirm against the definition.
 */
std::string base64Encode(unsigned char* data, int length);
52 
53 template <bool binary>
54 class TWebSocketServer : public THttpServer {
55 public:
56   TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
THttpServer(transport,config)57     : THttpServer(transport, config) {
58       resetHandshake();
59   }
60 
61   ~TWebSocketServer() override = default;
62 
readAll_virt(uint8_t * buf,uint32_t len)63   uint32_t readAll_virt(uint8_t* buf, uint32_t len) override {
64     // If we do not have a good handshake, the client will attempt one.
65     if (!handshakeComplete()) {
66       resetHandshake();
67       THttpServer::read(buf, len);
68       // If we did not get everything we expected, the handshake failed
69       // and we need to send a 400 response back.
70       if (!handshakeComplete()) {
71         sendBadRequest();
72         return 0;
73       }
74       // Otherwise, send back the 101 response.
75       THttpServer::flush();
76     }
77 
78     uint32_t want = len;
79     auto have = readBuffer_.available_read();
80 
81     // If we have some data in the buffer, copy it out and return it.
82     // We have to return it without attempting to read more, since we aren't
83     // guaranteed that the underlying transport actually has more data, so
84     // attempting to read from it could block.
85     if (have > 0 && have >= want) {
86       return readBuffer_.read(buf, want);
87     }
88 
89     // Read another frame.
90     if (!readFrame()) {
91       // EOF.  No frame available.
92       return 0;
93     }
94 
95     // Hand over whatever we have.
96     uint32_t give = (std::min)(want, readBuffer_.available_read());
97     return readBuffer_.read(buf, give);
98   }
99 
flush()100   void flush() override {
101     resetConsumedMessageSize();
102     writeFrameHeader();
103     uint8_t* buffer;
104     uint32_t length;
105     writeBuffer_.getBuffer(&buffer, &length);
106     transport_->write(buffer, length);
107     transport_->flush();
108     writeBuffer_.resetBuffer();
109   }
110 
111 protected:
getHeader(uint32_t len)112   std::string getHeader(uint32_t len) override {
113     THRIFT_UNUSED_VARIABLE(len);
114     std::ostringstream h;
115     h << "HTTP/1.1 101 Switching Protocols" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF
116       << "Upgrade: websocket" << CRLF << "Connection: Upgrade" << CRLF
117       << "Sec-WebSocket-Accept: " << acceptKey_ << CRLF << CRLF;
118     return h.str();
119   }
120 
parseHeader(char * header)121   void parseHeader(char* header) override {
122     char* colon = strchr(header, ':');
123     if (colon == nullptr) {
124       return;
125     }
126     size_t sz = colon - header;
127     char* value = colon + 1;
128 
129     if (THRIFT_strncasecmp(header, "Upgrade", sz) == 0) {
130       if (THRIFT_strcasestr(value, "websocket") != nullptr) {
131         upgrade_ = true;
132       }
133     } else if (THRIFT_strncasecmp(header, "Connection", sz) == 0) {
134       if (THRIFT_strcasestr(value, "Upgrade") != nullptr) {
135         connection_ = true;
136       }
137     } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Key", sz) == 0) {
138       std::string toHash = value + 1;
139       toHash += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
140       unsigned char hash[20];
141       SHA1((const unsigned char*)toHash.c_str(), toHash.length(), hash);
142       acceptKey_ = base64Encode(hash, 20);
143       secWebSocketKey_ = true;
144     } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Version", sz) == 0) {
145       if (THRIFT_strcasestr(value, "13") != nullptr) {
146         secWebSocketVersion_ = true;
147       }
148     }
149   }
150 
parseStatusLine(char * status)151   bool parseStatusLine(char* status) override {
152     char* method = status;
153 
154     char* path = strchr(method, ' ');
155     if (path == nullptr) {
156       throw TTransportException(string("Bad Status: ") + status);
157     }
158 
159     *path = '\0';
160     while (*(++path) == ' ') {
161     };
162 
163     char* http = strchr(path, ' ');
164     if (http == nullptr) {
165       throw TTransportException(string("Bad Status: ") + status);
166     }
167     *http = '\0';
168 
169     if (strcmp(method, "GET") == 0) {
170       // GET method ok, looking for content.
171       return true;
172     }
173     throw TTransportException(string("Bad Status (unsupported method): ") + status);
174   }
175 
176 private:
177   enum class CloseCode : uint16_t {
178     NormalClosure = 1000,
179     GoingAway = 1001,
180     ProtocolError = 1002,
181     UnsupportedDataType = 1003,
182     NoStatusCode = 1005,
183     AbnormalClosure = 1006,
184     InvalidData = 1007,
185     PolicyViolation = 1008,
186     MessageTooBig = 1009,
187     ExtensionExpected = 1010,
188     UnexpectedError = 1011,
189     NotSecure = 1015
190   };
191 
192   enum class Opcode : uint8_t {
193     Continuation = 0x0,
194     Text = 0x1,
195     Binary = 0x2,
196     Close = 0x8,
197     Ping = 0x9,
198     Pong = 0xA
199   };
200 
failConnection(CloseCode reason)201   void failConnection(CloseCode reason) {
202     writeFrameHeader(Opcode::Close);
203     auto buffer = htons(static_cast<uint16_t>(reason));
204     transport_->write(reinterpret_cast<const uint8_t*>(&buffer), 2);
205     transport_->flush();
206     transport_->close();
207   }
208 
handshakeComplete()209   bool handshakeComplete() {
210     return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
211   }
212 
pong()213   void pong() {
214     writeFrameHeader(Opcode::Pong);
215     uint8_t* buffer;
216     uint32_t size;
217     readBuffer_.getBuffer(&buffer, &size);
218     transport_->write(buffer, size);
219     transport_->flush();
220   }
221 
readFrame()222   bool readFrame() {
223     uint8_t headerBuffer[8];
224 
225     auto read = transport_->read(headerBuffer, 2);
226     if (read < 2) {
227       return false;
228     }
229     // Since Thrift has its own message end marker and we read frame by frame,
230     // it doesn't really matter if the frame is marked as FIN.
231     // Capture it only for debugging only.
232     auto fin = (headerBuffer[0] & 0x80) != 0;
233     THRIFT_UNUSED_VARIABLE(fin);
234 
235     // RSV1, RSV2, RSV3
236     if ((headerBuffer[0] & 0x70) != 0) {
237       failConnection(CloseCode::ProtocolError);
238       throw TTransportException(TTransportException::CORRUPTED_DATA,
239                                 "Reserved bits must be zeroes");
240     }
241 
242     auto opcode = (Opcode)(headerBuffer[0] & 0x0F);
243 
244     // Mask
245     if ((headerBuffer[1] & 0x80) == 0) {
246       failConnection(CloseCode::ProtocolError);
247       throw TTransportException(TTransportException::CORRUPTED_DATA,
248                                 "Messages from the client must be masked");
249     }
250 
251     // Read the length
252     uint64_t payloadLength = headerBuffer[1] & 0x7F;
253     if (payloadLength == 126) {
254       read = transport_->read(headerBuffer, 2);
255       if (read < 2) {
256         return false;
257       }
258       payloadLength = ntohs(*reinterpret_cast<uint16_t*>(headerBuffer));
259     } else if (payloadLength == 127) {
260       read = transport_->read(headerBuffer, 8);
261       if (read < 8) {
262         return false;
263       }
264       payloadLength = THRIFT_ntohll(*reinterpret_cast<uint64_t*>(headerBuffer));
265       if ((payloadLength & 0x8000000000000000) != 0) {
266         failConnection(CloseCode::ProtocolError);
267         throw TTransportException(
268             TTransportException::CORRUPTED_DATA,
269             "The most significant bit of the payload length must be zero");
270       }
271     }
272 
273     // size_t is smaller than a ulong on a 32-bit system
274     if (payloadLength > UINT32_MAX) {
275       failConnection(CloseCode::MessageTooBig);
276       return false;
277     }
278 
279     auto length = static_cast<uint32_t>(payloadLength);
280 
281     if (length > 0) {
282       // Read the masking key
283       read = transport_->read(headerBuffer, 4);
284       if (read < 4) {
285         return false;
286       }
287 
288       readBuffer_.resetBuffer(length);
289       uint8_t* buffer = readBuffer_.getWritePtr(length);
290       read = transport_->read(buffer, length);
291       readBuffer_.wroteBytes(read);
292       if (read < length) {
293         return false;
294       }
295 
296       // Unmask the data
297       for (size_t i = 0; i < length; i++) {
298         buffer[i] ^= headerBuffer[i % 4];
299       }
300 
301       T_DEBUG("FIN=%d, Opcode=%X, length=%d, payload=%s", fin, opcode, length,
302               binary ? readBuffer_.toHexString() : cast(string) readBuffer_);
303     }
304 
305     switch (opcode) {
306     case Opcode::Close:
307       if (length >= 2) {
308         uint8_t buffer[2];
309         readBuffer_.read(buffer, 2);
310         CloseCode closeCode = static_cast<CloseCode>(ntohs(*reinterpret_cast<uint16_t*>(buffer)));
311         THRIFT_UNUSED_VARIABLE(closeCode);
312         string closeReason = readBuffer_.readAsString(length - 2);
313         T_DEBUG("Connection closed: %d %s", closeCode, closeReason);
314       }
315       transport_->close();
316       return false;
317     case Opcode::Ping:
318       pong();
319       return readFrame();
320     default:
321       return true;
322     }
323   }
324 
resetHandshake()325   void resetHandshake() {
326     connection_ = false;
327     secWebSocketKey_ = false;
328     secWebSocketVersion_ = false;
329     upgrade_ = false;
330   }
331 
sendBadRequest()332   void sendBadRequest() {
333     std::ostringstream h;
334     h << "HTTP/1.1 400 Bad Request" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF << CRLF;
335     std::string header = h.str();
336     transport_->write(reinterpret_cast<const uint8_t*>(header.data()), static_cast<uint32_t>(header.length()));
337     transport_->flush();
338     transport_->close();
339   }
340 
341   void writeFrameHeader(Opcode opcode = Opcode::Continuation) {
342     uint32_t headerSize = 1;
343     uint32_t length = writeBuffer_.available_read();
344     if (length < 126) {
345       ++headerSize;
346     } else if (length < 65536) {
347       headerSize += 3;
348     } else {
349       headerSize += 9;
350     }
351     // The server does not mask the response
352 
353     uint8_t* header = static_cast<uint8_t*>(alloca(headerSize));
354     if (opcode == Opcode::Continuation) {
355       opcode = binary ? Opcode::Binary : Opcode::Text;
356     }
357     header[0] = static_cast<uint8_t>(opcode) | 0x80;
358     if (length < 126) {
359       header[1] = static_cast<uint8_t>(length);
360     } else if (length < 65536) {
361       header[1] = 126;
362       *reinterpret_cast<uint16_t*>(header + 2) = htons(length);
363     } else {
364       header[1] = 127;
365       *reinterpret_cast<uint64_t*>(header + 2) = THRIFT_htonll(length);
366     }
367 
368     transport_->write(header, headerSize);
369   }
370 
371   // Add constant here to avoid a linker error on Windows
372   constexpr static const char* CRLF = "\r\n";
373   std::string acceptKey_;
374   bool connection_;
375   bool secWebSocketKey_;
376   bool secWebSocketVersion_;
377   bool upgrade_;
378 };
379 
380 /**
381  * Wraps a transport into binary WebSocket protocol
382  */
383 class TBinaryWebSocketServerTransportFactory : public TTransportFactory {
384 public:
385   TBinaryWebSocketServerTransportFactory() = default;
386 
387   ~TBinaryWebSocketServerTransportFactory() override = default;
388 
389   /**
390    * Wraps the transport into a buffered one.
391    */
getTransport(std::shared_ptr<TTransport> trans)392   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
393     return std::shared_ptr<TTransport>(new TWebSocketServer<true>(trans));
394   }
395 };
396 
397 /**
398  * Wraps a transport into text WebSocket protocol
399  */
400 class TTextWebSocketServerTransportFactory : public TTransportFactory {
401 public:
402   TTextWebSocketServerTransportFactory() = default;
403 
404   ~TTextWebSocketServerTransportFactory() override = default;
405 
406   /**
407    * Wraps the transport into a buffered one.
408    */
getTransport(std::shared_ptr<TTransport> trans)409   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
410     return std::shared_ptr<TTransport>(new TWebSocketServer<false>(trans));
411   }
412 };
413 } // namespace transport
414 } // namespace thrift
415 } // namespace apache
416 #endif
417