1# Copyright 2018 Espressif Systems (Shanghai) PTE LTD
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15
16# APIs for interpreting and creating protobuf packets for
17# protocomm endpoint with security type protocomm_security1
18
19from __future__ import print_function
20
21import proto
22import session_pb2
23import utils
24from cryptography.hazmat.backends import default_backend
25from cryptography.hazmat.primitives import hashes, serialization
26from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
27from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
28from future.utils import tobytes
29
30from .security import Security
31
32
33# Enum for state of protocomm_security1 FSM
class security_state:
    """States of the protocomm_security1 client handshake FSM.

    Security1.security1_session advances through these in order, one step
    per call: REQUEST1 -> RESPONSE1_REQUEST2 -> RESPONSE2 -> FINISHED.
    """
    REQUEST1, RESPONSE1_REQUEST2, RESPONSE2, FINISHED = range(4)
39
40
def xor(a, b):
    """XOR two byte strings, zero-padding the shorter one.

    Args:
        a: first operand (``bytes`` or ``bytearray``).
        b: second operand (``bytes`` or ``bytearray``).

    Returns:
        ``bytes`` of length ``max(len(a), len(b))`` whose i-th byte is
        ``a[i] ^ b[i]``; bytes missing from the shorter operand count as 0.
    """
    # Iterating a bytearray yields integers on both Python 2 and 3, so no
    # text decoding is needed to get at the raw byte values.
    lhs = bytearray(a)
    rhs = bytearray(b)
    size = max(len(lhs), len(rhs))
    # Pad both operands to a common length so unequal inputs are handled
    # without indexing past the end of the shorter one.
    lhs.extend(bytearray(size - len(lhs)))
    rhs.extend(bytearray(size - len(rhs)))
    return bytes(bytearray(x ^ y for x, y in zip(lhs, rhs)))
53
54
class Security1(Security):
    """Client side of the protocomm_security1 session handshake.

    The handshake performs an X25519 key exchange with the device. If a
    proof-of-possession (PoP) string is configured, its SHA256 digest is
    XORed into the shared secret. The result keys an AES-CTR cipher whose
    IV is the device random. That single cipher stream is then used to
    verify the peer and to encrypt/decrypt subsequent session data.
    """

    def __init__(self, pop, verbose):
        """Create a security1 session handler.

        Args:
            pop: proof-of-possession string; an empty value skips the PoP
                mixing step. Converted with ``tobytes``.
            verbose: if truthy, intermediate handshake values are printed.
        """
        # Initialize state of the security1 FSM
        self.session_state = security_state.REQUEST1
        self.pop = tobytes(pop)
        self.verbose = verbose
        # Pass the FSM step function to the base Security class
        Security.__init__(self, self.security1_session)

    def security1_session(self, response_data):
        """Advance the handshake FSM by one step.

        Args:
            response_data: the device's reply to the previously returned
                request (unused on the first call).

        Returns:
            The next serialized request (latin-1 decoded) to send to the
            device, or ``None`` once the handshake has completed or when
            called in an unexpected state.

        Raises:
            RuntimeError: if the device's final response fails verification
                or uses an unsupported security scheme.
        """
        if self.session_state == security_state.REQUEST1:
            self.session_state = security_state.RESPONSE1_REQUEST2
            return self.setup0_request()
        if self.session_state == security_state.RESPONSE1_REQUEST2:
            self.session_state = security_state.RESPONSE2
            self.setup0_response(response_data)
            return self.setup1_request()
        if self.session_state == security_state.RESPONSE2:
            self.session_state = security_state.FINISHED
            # setup1_response() signals failure with a non-None return code.
            # Ignoring it would let an unverified session look established.
            status = self.setup1_response(response_data)
            if status is not None:
                raise RuntimeError('Failed to establish security1 session (error %d)' % status)
            return None
        print('Unexpected state')
        return None

    def __generate_key(self):
        """Generate a fresh X25519 key pair for the client.

        Sets ``self.client_private_key`` and ``self.client_public_key``
        (the latter as raw public-key bytes).
        """
        self.client_private_key = X25519PrivateKey.generate()
        try:
            self.client_public_key = self.client_private_key.public_key().public_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PublicFormat.Raw)
        except TypeError:
            # backward compatible call for older cryptography library
            self.client_public_key = self.client_private_key.public_key().public_bytes()

    def _print_verbose(self, data):
        """Print ``data`` wrapped in markers when verbose mode is enabled."""
        if self.verbose:
            print('++++ ' + data + ' ++++')

    def setup0_request(self):
        """Build the SessionCmd0 packet carrying a fresh client public key.

        Returns:
            The serialized ``SessionData`` message, decoded as latin-1.
        """
        setup_req = session_pb2.SessionData()
        setup_req.sec_ver = session_pb2.SecScheme1
        self.__generate_key()
        setup_req.sec1.sc0.client_pubkey = self.client_public_key
        self._print_verbose('Client Public Key:\t' + utils.str_to_hexstr(self.client_public_key.decode('latin-1')))
        return setup_req.SerializeToString().decode('latin-1')

    def setup0_response(self, response_data):
        """Interpret SessionResp0 and set up the session cipher.

        Reads the device public key and device random, derives the shared
        key (optionally XORed with SHA256 of the PoP), and stores an AES-CTR
        encryptor in ``self.cipher``.

        Args:
            response_data: serialized ``SessionData`` from the device.

        Raises:
            SystemExit: with status 1 if the device does not use SecScheme1.
        """
        setup_resp = proto.session_pb2.SessionData()
        setup_resp.ParseFromString(tobytes(response_data))
        self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
        if setup_resp.sec_ver != session_pb2.SecScheme1:
            print('Incorrect sec scheme')
            # Same effect as exit(1), but does not depend on the interactive
            # helper installed by the site module.
            raise SystemExit(1)
        self.device_public_key = setup_resp.sec1.sr0.device_pubkey
        # Device random is the initialization vector
        device_random = setup_resp.sec1.sr0.device_random
        self._print_verbose('Device Public Key:\t' + utils.str_to_hexstr(self.device_public_key.decode('latin-1')))
        self._print_verbose('Device Random:\t' + utils.str_to_hexstr(device_random.decode('latin-1')))

        # Calculate Curve25519 shared key using Client private key and Device public key
        sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
        self._print_verbose('Shared Key:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))

        # If PoP is provided, XOR SHA256 of PoP with the previously
        # calculated Shared Key to form the actual Shared Key
        if len(self.pop) > 0:
            h = hashes.Hash(hashes.SHA256(), backend=default_backend())
            h.update(self.pop)
            digest = h.finalize()
            sharedK = xor(sharedK, digest)
            self._print_verbose('New Shared Key XORed with PoP:\t' + utils.str_to_hexstr(sharedK.decode('latin-1')))
        # Initialize the encryption engine with Shared Key and initialization vector
        cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
        self.cipher = cipher.encryptor()

    def setup1_request(self):
        """Build the SessionCmd1 packet carrying the encrypted device public key.

        Encrypting the device public key consumes keystream from
        ``self.cipher``.

        Returns:
            The serialized ``SessionData`` message, decoded as latin-1.
        """
        setup_req = proto.session_pb2.SessionData()
        setup_req.sec_ver = session_pb2.SecScheme1
        setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
        # Encrypt device public key and attach to the request packet
        client_verify = self.cipher.update(self.device_public_key)
        self._print_verbose('Client Verify:\t' + utils.str_to_hexstr(client_verify.decode('latin-1')))
        setup_req.sec1.sc1.client_verify_data = client_verify
        return setup_req.SerializeToString().decode('latin-1')

    def setup1_response(self, response_data):
        """Interpret SessionResp1 and verify the device.

        The device's verify data is decrypted with ``self.cipher`` and must
        equal the client public key.

        Args:
            response_data: serialized ``SessionData`` from the device.

        Returns:
            ``None`` on success, ``-2`` if the decrypted verify data does not
            match the client public key, ``-1`` if the device uses an
            unsupported security scheme.
        """
        setup_resp = proto.session_pb2.SessionData()
        setup_resp.ParseFromString(tobytes(response_data))
        # Ensure security scheme matches
        if setup_resp.sec_ver != session_pb2.SecScheme1:
            print('Unsupported security protocol')
            return -1
        # Read encrypted device verify string
        device_verify = setup_resp.sec1.sr1.device_verify_data
        self._print_verbose('Device verify:\t' + utils.str_to_hexstr(device_verify.decode('latin-1')))
        # Decrypt the device verify string
        enc_client_pubkey = self.cipher.update(device_verify)
        self._print_verbose('Enc client pubkey:\t ' + utils.str_to_hexstr(enc_client_pubkey.decode('latin-1')))
        # Match decrypted string with client public key
        if enc_client_pubkey != self.client_public_key:
            print('Mismatch in device verify')
            return -2
        return None

    def encrypt_data(self, data):
        """Encrypt ``data`` with the session AES-CTR stream.

        Shares ``self.cipher`` with ``decrypt_data``, so every call advances
        the same keystream.
        """
        return self.cipher.update(tobytes(data))

    def decrypt_data(self, data):
        """Decrypt ``data`` with the session AES-CTR stream.

        In CTR mode decryption is the same keystream XOR as encryption,
        so this shares ``self.cipher`` with ``encrypt_data``.
        """
        return self.cipher.update(tobytes(data))
173