# Author: Johan Hanssen Seferidis
# License: MIT

import sys
import struct
from base64 import b64encode
from hashlib import sha1
import logging
from socket import error as SocketError
import errno

if sys.version_info[0] < 3:
    from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler
else:
    from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler

logger = logging.getLogger(__name__)
logging.basicConfig()

# WebSocket base framing protocol (RFC 6455, section 5.2):
#
#  0                   1                   2                   3
#  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
# +-+-+-+-+-------+-+-------------+-------------------------------+
# |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
# |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
# |N|V|V|V|       |S|             |   (if payload len==126/127)   |
# | |1|2|3|       |K|             |                               |
# +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
# |     Extended payload length continued, if payload len == 127  |
# + - - - - - - - - - - - - - - - +-------------------------------+
# |                               |Masking-key, if MASK set to 1  |
# +-------------------------------+-------------------------------+
# | Masking-key (continued)       |          Payload Data         |
# +-------------------------------- - - - - - - - - - - - - - - - +
# :                     Payload Data continued ...                :
# + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
# |                     Payload Data continued ...                |
# +---------------------------------------------------------------+

FIN = 0x80                # byte 1: final-fragment flag
OPCODE = 0x0f             # byte 1: opcode bits
MASKED = 0x80             # byte 2: MASK flag
PAYLOAD_LEN = 0x7f        # byte 2: 7-bit payload length
PAYLOAD_LEN_EXT16 = 0x7e  # 7-bit length value 126: a 16-bit length follows
PAYLOAD_LEN_EXT64 = 0x7f  # 7-bit length value 127: a 64-bit length follows

OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE_CONN = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA


# -------------------------------- API ---------------------------------

class API():
    """Public callback/sending API mixed into WebsocketServer.

    The three callback methods are no-ops by default; replace them with the
    ``set_fn_*`` setters. Each callback receives the client dict (see
    WebsocketServer) and the server itself.
    """

    def run_forever(self):
        """Serve clients until interrupted.

        Ctrl-C (KeyboardInterrupt) closes the server socket. Any other
        exception is logged with its traceback and the process exits with
        status 1.
        """
        try:
            logger.info("Listening on port %d for clients..", self.port)
            self.serve_forever()
        except KeyboardInterrupt:
            self.server_close()
            logger.info("Server terminated.")
        except Exception as e:
            logger.error(str(e), exc_info=True)
            # sys.exit rather than the site-provided builtin exit(), which is
            # missing when Python runs with -S or is embedded.
            sys.exit(1)

    def new_client(self, client, server):
        pass

    def client_left(self, client, server):
        pass

    def message_received(self, client, server, message):
        pass

    def set_fn_new_client(self, fn):
        self.new_client = fn

    def set_fn_client_left(self, fn):
        self.client_left = fn

    def set_fn_message_received(self, fn):
        self.message_received = fn

    def send_message(self, client, msg):
        self._unicast_(client, msg)

    def send_message_to_all(self, msg):
        self._multicast_(msg)


# ------------------------- Implementation -----------------------------

class WebsocketServer(ThreadingMixIn, TCPServer, API):
    """
    A websocket server waiting for clients to connect.

    Args:
        port(int): Port to bind to. The port actually bound is stored in
            ``self.port``, which matters when 0 is passed.
        host(str): Hostname or IP to listen for connections. By default 127.0.0.1
            is being used. To accept connections from any client, you should use
            0.0.0.0.
        loglevel: Logging level from logging module to use for logging. By default
            warnings and errors are being logged.

    Properties:
        clients(list): A list of the clients connected to this server
            instance. A client is a dictionary like below.
            {
             'id'      : id,
             'handler' : handler,
             'address' : (addr, port)
            }
    """

    allow_reuse_address = True
    daemon_threads = True  # comment to keep threads alive until finished

    # NOTE(review): incremented from concurrent handler threads without a
    # lock, so ids are not guaranteed unique under simultaneous connects.
    id_counter = 0

    def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING):
        logger.setLevel(loglevel)
        # Per-instance list. A class-level list would be shared by every
        # WebsocketServer created in the process.
        self.clients = []
        TCPServer.__init__(self, (host, port), WebSocketHandler)
        self.port = self.socket.getsockname()[1]

    def _message_received_(self, handler, msg):
        self.message_received(self.handler_to_client(handler), self, msg)

    def _ping_received_(self, handler, msg):
        handler.send_pong(msg)

    def _pong_received_(self, handler, msg):
        pass

    def _new_client_(self, handler):
        """Register a handler that finished the handshake and fire new_client."""
        self.id_counter += 1
        client = {
            'id': self.id_counter,
            'handler': handler,
            'address': handler.client_address
        }
        self.clients.append(client)
        self.new_client(client, self)

    def _client_left_(self, handler):
        """Fire client_left for a registered handler and unregister it."""
        client = self.handler_to_client(handler)
        if client is None:
            # The connection never completed the handshake, so it was never
            # registered and new_client was never fired for it.
            return
        self.client_left(client, self)
        if client in self.clients:
            self.clients.remove(client)

    def _unicast_(self, to_client, msg):
        to_client['handler'].send_message(msg)

    def _multicast_(self, msg):
        # Iterate over a snapshot: handler threads may add or remove clients
        # while we are sending.
        for client in list(self.clients):
            self._unicast_(client, msg)

    def handler_to_client(self, handler):
        """Return the client dict owning ``handler``, or None if not registered."""
        for client in self.clients:
            if client['handler'] == handler:
                return client
        return None


class WebSocketHandler(StreamRequestHandler):
    """Per-connection handler: performs the handshake, then reads frames."""

    def __init__(self, socket, addr, server):
        self.server = server
        StreamRequestHandler.__init__(self, socket, addr, server)

    def setup(self):
        StreamRequestHandler.setup(self)
        self.keep_alive = True
        self.handshake_done = False
        self.valid_client = False

    def handle(self):
        """Perform the handshake, then read frames until keep_alive is cleared."""
        while self.keep_alive:
            if not self.handshake_done:
                self.handshake()
            elif self.valid_client:
                self.read_next_message()

    def read_bytes(self, num):
        """Read ``num`` bytes (fewer only at end of stream) as a sequence of ints."""
        # python3 gives ordinal of byte directly
        data = self.rfile.read(num)
        if sys.version_info[0] < 3:
            return map(ord, data)
        return data

    def read_next_message(self):
        """Read one frame from the client and dispatch it.

        Text, ping and pong frames are unmasked, decoded as UTF-8 and passed to
        the matching ``server._*_received_`` callback. Continuation and binary
        frames are read and discarded because they are unsupported. The
        following all set ``keep_alive`` to False so ``handle`` stops looping:

        - a close frame,
        - an unmasked frame,
        - an unknown opcode,
        - a truncated frame,
        - a non-UTF-8 payload,
        - a dropped connection.
        """
        try:
            b1, b2 = self.read_bytes(2)
        except SocketError as e:  # to be replaced with ConnectionResetError for py3
            if e.errno == errno.ECONNRESET:
                logger.info("Client closed connection.")
            else:
                logger.warning("Socket error while reading from client: %s", e)
            self.keep_alive = False
            return
        except ValueError:
            # Fewer than 2 bytes were available: the client closed the stream.
            logger.info("Client closed connection.")
            self.keep_alive = False
            return

        opcode = b1 & OPCODE
        masked = b2 & MASKED
        payload_length = b2 & PAYLOAD_LEN

        if opcode == OPCODE_CLOSE_CONN:
            logger.info("Client asked to close connection.")
            self.keep_alive = False
            return
        if not masked:
            logger.warning("Client must always be masked.")
            self.keep_alive = False
            return

        # opcode_handler None means "consume the frame but do not dispatch it".
        if opcode == OPCODE_CONTINUATION:
            logger.warning("Continuation frames are not supported.")
            opcode_handler = None
        elif opcode == OPCODE_BINARY:
            logger.warning("Binary frames are not supported.")
            opcode_handler = None
        elif opcode == OPCODE_TEXT:
            opcode_handler = self.server._message_received_
        elif opcode == OPCODE_PING:
            opcode_handler = self.server._ping_received_
        elif opcode == OPCODE_PONG:
            opcode_handler = self.server._pong_received_
        else:
            logger.warning("Unknown opcode %#x.", opcode)
            self.keep_alive = False
            return

        try:
            if payload_length == 126:
                payload_length = struct.unpack(">H", self.rfile.read(2))[0]
            elif payload_length == 127:
                payload_length = struct.unpack(">Q", self.rfile.read(8))[0]
        except struct.error:
            logger.info("Client closed connection in the middle of a frame.")
            self.keep_alive = False
            return

        # The length, mask and payload are always read, even for frames that
        # are discarded. Otherwise the next read would parse payload bytes
        # as a frame header.
        # NOTE(review): payload_length is client-controlled and unbounded.
        masks = self.read_bytes(4)
        payload = self.read_bytes(payload_length)
        if len(masks) < 4 or len(payload) < payload_length:
            logger.info("Client closed connection in the middle of a frame.")
            self.keep_alive = False
            return

        if opcode_handler is None:
            return

        message_bytes = bytearray(
            byte ^ masks[i % 4] for i, byte in enumerate(payload))
        try:
            message = message_bytes.decode('utf8')
        except UnicodeDecodeError:
            logger.warning("Client sent a payload that is not valid UTF-8.")
            self.keep_alive = False
            return
        opcode_handler(self, message)

    def send_message(self, message):
        self.send_text(message)

    def send_pong(self, message):
        self.send_text(message, OPCODE_PONG)

    def send_text(self, message, opcode=OPCODE_TEXT):
        """
        Send ``message`` to the client as a single, unfragmented frame.

        Important: Fragmented(=continuation) messages are not supported since
        their usage cases are limited - when we don't know the payload length.

        Args:
            message: ``str`` (``str``/``unicode`` on Python 2), or ``bytes``
                holding valid UTF-8.
            opcode: Frame opcode; OPCODE_TEXT by default, OPCODE_PONG for pongs.

        Returns:
            False if the message was rejected: wrong type, or not
            decodable/encodable as UTF-8. None once the frame was sent.

        Raises:
            Exception: if the encoded payload needs more than a 64-bit length.
        """
        # Validate message
        if isinstance(message, bytes):
            # this is slower but ensures we have UTF-8
            message = try_decode_UTF8(message)
            # Compare against False explicitly: b'' legitimately decodes to ''.
            if message is False:
                logger.warning("Can\'t send message, message is not valid UTF-8")
                return False
        elif sys.version_info < (3, 0) and isinstance(message, (str, unicode)):  # noqa: F821
            pass
        elif isinstance(message, str):
            pass
        else:
            logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s', type(message))
            return False

        payload = encode_to_UTF8(message)
        if payload is False:
            # encode_to_UTF8 has already logged why (e.g. a lone surrogate).
            return False
        payload_length = len(payload)

        header = bytearray()
        header.append(FIN | opcode)
        if payload_length <= 125:
            # Normal payload: length fits in the 7-bit field
            header.append(payload_length)
        elif payload_length <= 0xFFFF:
            # Extended payload: 16-bit length
            header.append(PAYLOAD_LEN_EXT16)
            header.extend(struct.pack(">H", payload_length))
        elif payload_length < 2 ** 64:
            # Huge extended payload: 64-bit length
            header.append(PAYLOAD_LEN_EXT64)
            header.extend(struct.pack(">Q", payload_length))
        else:
            raise Exception("Message is too big. Consider breaking it into chunks.")

        # sendall: a plain send() may transmit only part of a large frame.
        self.request.sendall(header + payload)

    def read_http_headers(self):
        """Read the HTTP upgrade request and return its headers as a dict.

        Header names are lower-cased. Names and values are stripped.

        Raises:
            ValueError: if the request line is not a GET (including the empty
                read of a client that disconnected), or if a header line is
                malformed or cannot be decoded.
        """
        headers = {}
        # first line should be HTTP GET
        http_get = self.rfile.readline().decode().strip()
        if not http_get.upper().startswith('GET'):
            raise ValueError("Expected an HTTP GET request, got %r" % http_get)
        # remaining lines are headers, terminated by an empty line
        while True:
            header = self.rfile.readline().decode().strip()
            if not header:
                break
            head, sep, value = header.partition(':')
            if not sep:
                raise ValueError("Malformed HTTP header line: %r" % header)
            headers[head.lower().strip()] = value.strip()
        return headers

    def handshake(self):
        """Validate the upgrade request, reply with 101 and register the client.

        On any validation failure, sets ``keep_alive`` to False without
        replying.
        """
        try:
            headers = self.read_http_headers()
        except ValueError as e:
            logger.warning("Client sent an invalid handshake: %s", e)
            self.keep_alive = False
            return

        if headers.get('upgrade', '').lower() != 'websocket':
            self.keep_alive = False
            return

        if 'sec-websocket-key' not in headers:
            logger.warning("Client tried to connect but was missing a key")
            self.keep_alive = False
            return
        key = headers['sec-websocket-key']

        response = self.make_handshake_response(key)
        self.request.sendall(response.encode())
        self.handshake_done = True
        self.valid_client = True
        self.server._new_client_(self)

    @classmethod
    def make_handshake_response(cls, key):
        """Build the HTTP 101 response accepting the given Sec-WebSocket-Key."""
        return (
            'HTTP/1.1 101 Switching Protocols\r\n'
            'Upgrade: websocket\r\n'
            'Connection: Upgrade\r\n'
            'Sec-WebSocket-Accept: %s\r\n'
            '\r\n' % cls.calculate_response_key(key)
        )

    @classmethod
    def calculate_response_key(cls, key):
        """Return base64(SHA-1(key + RFC 6455 GUID)) as an ASCII str."""
        GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        digest = sha1(key.encode() + GUID.encode()).digest()
        response_key = b64encode(digest).strip()
        return response_key.decode('ASCII')

    def finish(self):
        """Unregister the client, then flush and close the stream wrappers."""
        try:
            self.server._client_left_(self)
        finally:
            # StreamRequestHandler.finish flushes wfile and closes rfile/wfile.
            # Skipping it would leave those file objects open.
            StreamRequestHandler.finish(self)


def encode_to_UTF8(data):
    """Return ``data`` encoded as UTF-8 bytes, or False (logged) if it cannot be."""
    try:
        return data.encode('UTF-8')
    except UnicodeEncodeError as e:
        logger.error("Could not encode data to UTF-8 -- %s" % e)
        return False


def try_decode_UTF8(data):
    """Return ``data`` decoded from UTF-8, or False if it is not valid UTF-8.

    Callers must compare the result with ``is False``: empty input decodes to
    an empty (falsy) string.
    """
    try:
        return data.decode('utf-8')
    except UnicodeDecodeError:
        return False