1# Author: Johan Hanssen Seferidis
2# License: MIT
3
4import sys
5import struct
6from base64 import b64encode
7from hashlib import sha1
8import logging
9from socket import error as SocketError
10import errno
11
12if sys.version_info[0] < 3:
13    from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler
14else:
15    from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
16
# Give the logging system a default configuration as a side effect of
# importing this module, then fetch this module's own logger.
# WebsocketServer.__init__ later sets this logger's level from `loglevel`.
logging.basicConfig()
logger = logging.getLogger(__name__)
19
20'''
21+-+-+-+-+-------+-+-------------+-------------------------------+
22 0                   1                   2                   3
23 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
24+-+-+-+-+-------+-+-------------+-------------------------------+
25|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
26|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
27|N|V|V|V|       |S|             |   (if payload len==126/127)   |
28| |1|2|3|       |K|             |                               |
29+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
30|     Extended payload length continued, if payload len == 127  |
31+ - - - - - - - - - - - - - - - +-------------------------------+
32|                     Payload Data continued ...                |
33+---------------------------------------------------------------+
34'''
35
# Masks for byte 0 of a frame header.
FIN = 0b10000000                # set on the final frame of a message
OPCODE = 0b00001111             # low nibble holds the opcode

# Masks and marker values for byte 1 of a frame header.
MASKED = 0b10000000             # payload is XOR-masked (clients must set it)
PAYLOAD_LEN = 0b01111111        # 7-bit payload length field
PAYLOAD_LEN_EXT16 = 126         # length field value: 16-bit length follows
PAYLOAD_LEN_EXT64 = 127         # length field value: 64-bit length follows

# Frame opcodes.
OPCODE_CONTINUATION = 0
OPCODE_TEXT         = 1
OPCODE_BINARY       = 2
OPCODE_CLOSE_CONN   = 8
OPCODE_PING         = 9
OPCODE_PONG         = 10
49
50
51# -------------------------------- API ---------------------------------
52
class API():
    """Callback-style public interface mixed into WebsocketServer.

    The hooks ``new_client``, ``client_left`` and ``message_received`` do
    nothing by default. Override them in a subclass, or replace them at
    runtime with the ``set_fn_*`` helpers; those store the function as an
    instance attribute, so it is called without ``self``.

    The class this is mixed into must provide ``port``, ``serve_forever``,
    ``server_close``, ``_unicast_`` and ``_multicast_``.
    """

    def run_forever(self):
        """Serve requests until interrupted.

        A KeyboardInterrupt closes the server and returns normally. Any
        other exception is logged with its traceback, the server is closed,
        and the process exits with status 1.
        """
        try:
            logger.info("Listening on port %d for clients..", self.port)
            self.serve_forever()
        except KeyboardInterrupt:
            self.server_close()
            logger.info("Server terminated.")
        except Exception as e:
            logger.error(str(e), exc_info=True)
            # Release the listening socket before terminating.
            self.server_close()
            # sys.exit is always available; the bare exit() builtin is only
            # injected by the `site` module.
            sys.exit(1)

    def new_client(self, client, server):
        """Hook called with the client dict after a successful handshake."""
        pass

    def client_left(self, client, server):
        """Hook called with the client dict when its connection ends."""
        pass

    def message_received(self, client, server, message):
        """Hook called with the decoded text of each incoming text frame."""
        pass

    def set_fn_new_client(self, fn):
        """Use ``fn(client, server)`` as the new_client hook."""
        self.new_client = fn

    def set_fn_client_left(self, fn):
        """Use ``fn(client, server)`` as the client_left hook."""
        self.client_left = fn

    def set_fn_message_received(self, fn):
        """Use ``fn(client, server, message)`` as the message_received hook."""
        self.message_received = fn

    def send_message(self, client, msg):
        """Send ``msg`` to a single client (a dict from ``clients``)."""
        self._unicast_(client, msg)

    def send_message_to_all(self, msg):
        """Send ``msg`` to every connected client."""
        self._multicast_(msg)
89
90
91# ------------------------- Implementation -----------------------------
92
class WebsocketServer(ThreadingMixIn, TCPServer, API):
    """
    A websocket server waiting for clients to connect.

    Args:
        port(int): Port to bind to. The port actually bound is stored in
            ``self.port`` (useful when passing 0).
        host(str): Hostname or IP to listen for connections. By default 127.0.0.1
            is being used. To accept connections from any client, you should use
            0.0.0.0.
        loglevel: Logging level from logging module to use for logging. By default
            warnings and errors are being logged.

    Properties:
        clients(list): A list of connected clients. A client is a dictionary
            like below.
                {
                 'id'      : id,
                 'handler' : handler,
                 'address' : (addr, port)
                }
    """

    allow_reuse_address = True
    daemon_threads = True  # comment to keep threads alive until finished

    # Class-level defaults kept only for backward compatibility. __init__
    # gives every instance its own list and counter: appending to a shared
    # class-level list would make all servers in the process share (and
    # broadcast to) each other's clients.
    clients = []
    id_counter = 0

    def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING):
        logger.setLevel(loglevel)
        self.clients = []
        self.id_counter = 0
        TCPServer.__init__(self, (host, port), WebSocketHandler)
        self.port = self.socket.getsockname()[1]

    def _message_received_(self, handler, msg):
        """Forward a decoded text message to the message_received hook."""
        self.message_received(self.handler_to_client(handler), self, msg)

    def _ping_received_(self, handler, msg):
        """Answer a ping by echoing its payload in a pong frame."""
        handler.send_pong(msg)

    def _pong_received_(self, handler, msg):
        """Pongs are accepted and ignored."""
        pass

    def _new_client_(self, handler):
        """Register a handler that completed the handshake and announce it."""
        self.id_counter += 1
        client = {
            'id': self.id_counter,
            'handler': handler,
            'address': handler.client_address
        }
        self.clients.append(client)
        self.new_client(client, self)

    def _client_left_(self, handler):
        """Announce and unregister the client owned by ``handler``."""
        client = self.handler_to_client(handler)
        if client is None:
            # The connection never completed the handshake, so it was never
            # announced through new_client; don't report it as leaving.
            return
        self.client_left(client, self)
        if client in self.clients:
            self.clients.remove(client)

    def _unicast_(self, to_client, msg):
        to_client['handler'].send_message(msg)

    def _multicast_(self, msg):
        # Iterate over a snapshot: handler threads add and remove clients
        # concurrently, and removing from a list mid-iteration skips entries.
        for client in list(self.clients):
            self._unicast_(client, msg)

    def handler_to_client(self, handler):
        """Return the client dict owned by ``handler``, or None if unknown."""
        for client in list(self.clients):
            if client['handler'] == handler:
                return client
        return None
162
163
class WebSocketHandler(StreamRequestHandler):
    """Serve a single client connection for WebsocketServer.

    The connection first goes through the HTTP upgrade handshake. It is then
    read frame by frame until ``keep_alive`` is cleared. Text, ping and pong
    payloads are decoded as UTF-8 and handed to the server's
    ``_message_received_``, ``_ping_received_`` and ``_pong_received_``
    hooks. Binary and continuation frames are read and discarded with a
    warning. Outgoing data is always sent as one unfragmented frame.
    """

    def __init__(self, socket, addr, server):
        # Store the owning server before the base initializer runs.
        self.server = server
        StreamRequestHandler.__init__(self, socket, addr, server)

    def setup(self):
        StreamRequestHandler.setup(self)
        self.keep_alive = True       # handle() loops while this is truthy
        self.handshake_done = False  # True once the 101 response was sent
        self.valid_client = False    # True once registered with the server

    def handle(self):
        """Run the handshake, then read frames until keep_alive is cleared."""
        while self.keep_alive:
            if not self.handshake_done:
                self.handshake()
            elif self.valid_client:
                self.read_next_message()

    def read_bytes(self, num):
        """Read up to ``num`` bytes and return them as a sequence of ints."""
        # python3 gives ordinal of byte directly
        data = self.rfile.read(num)
        if sys.version_info[0] < 3:
            return map(ord, data)
        else:
            return data

    def read_next_message(self):
        """Read one complete frame from the client and dispatch it.

        Clears ``keep_alive`` when any of these happen:
        - the client disconnected or asked to close;
        - it sent an unmasked frame or an unknown opcode;
        - it sent a payload that is not valid UTF-8.

        The payload of unsupported (binary/continuation) frames is always
        consumed, so the next read starts on a frame boundary.
        """
        try:
            b1, b2 = self.read_bytes(2)
        except SocketError as e:  # to be replaced with ConnectionResetError for py3
            if e.errno == errno.ECONNRESET:
                logger.info("Client closed connection.")
            else:
                logger.warning("Socket error while reading a frame: %s", e)
            self.keep_alive = 0
            return
        except ValueError:
            # Fewer than two bytes could be read (EOF), or the stream is closed.
            logger.info("Client closed connection.")
            self.keep_alive = 0
            return

        # The FIN bit is not inspected: fragmented messages are not supported.
        opcode = b1 & OPCODE
        masked = b2 & MASKED
        payload_length = b2 & PAYLOAD_LEN

        if opcode == OPCODE_CLOSE_CONN:
            logger.info("Client asked to close connection.")
            self.keep_alive = 0
            return
        if not masked:
            logger.warning("Client must always be masked.")
            self.keep_alive = 0
            return

        if opcode == OPCODE_TEXT:
            opcode_handler = self.server._message_received_
        elif opcode == OPCODE_PING:
            opcode_handler = self.server._ping_received_
        elif opcode == OPCODE_PONG:
            opcode_handler = self.server._pong_received_
        elif opcode in (OPCODE_CONTINUATION, OPCODE_BINARY):
            opcode_handler = None  # unsupported: payload is read, then dropped
        else:
            logger.warning("Unknown opcode %#x.", opcode)
            self.keep_alive = 0
            return

        try:
            if payload_length == PAYLOAD_LEN_EXT16:
                payload_length = struct.unpack(">H", self.rfile.read(2))[0]
            elif payload_length == PAYLOAD_LEN_EXT64:
                payload_length = struct.unpack(">Q", self.rfile.read(8))[0]
        except struct.error:
            # The extended length field was cut short: the client hung up.
            logger.info("Client closed connection.")
            self.keep_alive = 0
            return

        # NOTE(review): payload_length is client-controlled and not capped.
        masks = self.read_bytes(4)
        message_bytes = bytearray()
        for message_byte in self.read_bytes(payload_length):
            message_byte ^= masks[len(message_bytes) % 4]
            message_bytes.append(message_byte)

        if opcode_handler is None:
            if opcode == OPCODE_CONTINUATION:
                logger.warning("Continuation frames are not supported.")
            else:
                logger.warning("Binary frames are not supported.")
            return

        try:
            message = message_bytes.decode('utf8')
        except UnicodeDecodeError:
            logger.warning("Client sent a payload that is not valid UTF-8.")
            self.keep_alive = 0
            return
        opcode_handler(self, message)

    def send_message(self, message):
        """Send ``message`` to the client as a text frame."""
        self.send_text(message)

    def send_pong(self, message):
        """Send ``message`` to the client as a pong frame."""
        self.send_text(message, OPCODE_PONG)

    def send_text(self, message, opcode=OPCODE_TEXT):
        """
        Send ``message`` to the client as a single, unfragmented frame.

        Important: Fragmented(=continuation) messages are not supported since
        their usage cases are limited - when we don't know the payload length.

        Args:
            message: text (str/unicode), or bytes that must be valid UTF-8.
            opcode: frame opcode, OPCODE_TEXT by default.

        Returns:
            False if the message was rejected (unsupported type, invalid
            UTF-8, or not encodable); None once the frame has been sent.

        Raises:
            Exception: if the encoded payload is 2**64 bytes or longer.
        """

        # Validate message
        if isinstance(message, bytes):
            message = try_decode_UTF8(message)  # this is slower but ensures we have UTF-8
            # Compare with `is False`: b'' decodes to '', which is falsy but valid.
            if message is False:
                logger.warning("Can\'t send message, message is not valid UTF-8")
                return False
        elif sys.version_info < (3,0) and (isinstance(message, str) or isinstance(message, unicode)):
            pass
        elif isinstance(message, str):
            pass
        else:
            logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s', type(message))
            return False

        payload = encode_to_UTF8(message)
        if payload is False:
            # e.g. text containing lone surrogates; encode_to_UTF8 logged it.
            return False
        payload_length = len(payload)

        header = bytearray()
        header.append(FIN | opcode)
        if payload_length <= 125:
            # Normal payload: the length fits in the 7-bit field.
            header.append(payload_length)
        elif payload_length <= 65535:
            # Extended payload: 16-bit length follows.
            header.append(PAYLOAD_LEN_EXT16)
            header.extend(struct.pack(">H", payload_length))
        elif payload_length < 18446744073709551616:
            # Huge extended payload: 64-bit length follows.
            header.append(PAYLOAD_LEN_EXT64)
            header.extend(struct.pack(">Q", payload_length))
        else:
            raise Exception("Message is too big. Consider breaking it into chunks.")

        # send() may transmit only part of the buffer; sendall() sends it all.
        self.request.sendall(header + payload)

    def read_http_headers(self):
        """Read the request line and headers of the opening HTTP handshake.

        Returns:
            dict mapping lower-cased header names to stripped values. Header
            lines without a ':' are ignored.

        Raises:
            ValueError: if the request line is not a GET request (this also
            covers a client that disconnects before sending anything), or
            if a line is not decodable text.
        """
        headers = {}
        # first line should be HTTP GET
        http_get = self.rfile.readline().decode().strip()
        if not http_get.upper().startswith('GET'):
            raise ValueError("Expected an HTTP GET request, got %r" % http_get)
        # remaining should be headers
        while True:
            header = self.rfile.readline().decode().strip()
            if not header:
                break
            head, sep, value = header.partition(':')
            if not sep:
                continue  # malformed header line: ignore rather than crash
            headers[head.lower().strip()] = value.strip()
        return headers

    def handshake(self):
        """Validate the upgrade request, answer it and register the client.

        On an invalid request, ``keep_alive`` is cleared and nothing is sent.
        """
        try:
            headers = self.read_http_headers()
        except ValueError as e:
            logger.info("Invalid handshake request: %s", e)
            self.keep_alive = False
            return

        if headers.get('upgrade', '').lower() != 'websocket':
            self.keep_alive = False
            return

        try:
            key = headers['sec-websocket-key']
        except KeyError:
            logger.warning("Client tried to connect but was missing a key")
            self.keep_alive = False
            return

        response = self.make_handshake_response(key)
        self.request.sendall(response.encode())
        self.handshake_done = True
        self.valid_client = True
        self.server._new_client_(self)

    @classmethod
    def make_handshake_response(cls, key):
        """Build the '101 Switching Protocols' response for ``key``."""
        return \
          'HTTP/1.1 101 Switching Protocols\r\n'\
          'Upgrade: websocket\r\n'              \
          'Connection: Upgrade\r\n'             \
          'Sec-WebSocket-Accept: %s\r\n'        \
          '\r\n' % cls.calculate_response_key(key)

    @classmethod
    def calculate_response_key(cls, key):
        """Return base64(SHA-1(key + GUID)), the Sec-WebSocket-Accept value."""
        GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        digest = sha1(key.encode() + GUID.encode()).digest()
        response_key = b64encode(digest).strip()
        return response_key.decode('ASCII')

    def finish(self):
        """Tell the server this connection is gone, then let the base class
        flush and close the buffered streams."""
        try:
            self.server._client_left_(self)
        finally:
            StreamRequestHandler.finish(self)
351
352
def encode_to_UTF8(data):
    """Encode text ``data`` to UTF-8 bytes.

    Returns:
        The encoded bytes, or False (after logging the error) if ``data``
        contains characters that cannot be encoded.

    Any other exception, e.g. AttributeError for a non-string argument,
    propagates unchanged with its original traceback.
    """
    try:
        return data.encode('UTF-8')
    except UnicodeEncodeError as e:
        logger.error("Could not encode data to UTF-8 -- %s", e)
        return False
362
363
def try_decode_UTF8(data):
    """Decode UTF-8 ``data`` to text.

    Returns:
        The decoded string, or False if ``data`` is not valid UTF-8. Empty
        input decodes to '', which is also falsy, so callers must test the
        result with ``is False``.

    Any other exception propagates unchanged with its original traceback.
    """
    try:
        return data.decode('utf-8')
    except UnicodeDecodeError:
        return False
371