1 /*
2  * Copyright (c) 2021-2022, Arm Limited. All rights reserved.
3  *
4  * SPDX-License-Identifier: BSD-3-Clause
5  *
6  */
7 
8 #include "crypto.h"
9 
10 #include <stdint.h>
11 #include <string.h>
12 
13 #include "region_defs.h"
14 #include "dx_reg_base_host.h"
15 #include "otp.h"
16 #include "fih.h"
17 #include "cc3xx_aes.h"
18 #include "cc3xx_hash.h"
19 #include "cmsis_compiler.h"
20 
21 #define KEY_DERIVATION_MAX_BUF_SIZE 128
22 
bl1_sha256_init(void)23 fih_int bl1_sha256_init(void)
24 {
25     fih_int fih_rc = FIH_FAILURE;
26 
27     fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_init());
28     if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
29         FIH_RET(FIH_FAILURE);
30     }
31 
32     return FIH_SUCCESS;
33 }
34 
bl1_sha256_finish(uint8_t * hash)35 fih_int bl1_sha256_finish(uint8_t *hash)
36 {
37     fih_int fih_rc = FIH_FAILURE;
38 
39     fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_finish(hash, 32));
40     if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
41         FIH_RET(FIH_FAILURE);
42     }
43 
44     return FIH_SUCCESS;
45 }
46 
bl1_sha256_update(uint8_t * data,size_t data_length)47 fih_int bl1_sha256_update(uint8_t *data, size_t data_length)
48 {
49     size_t idx;
50     fih_int fih_rc = FIH_FAILURE;
51 
52     for (idx = 0; idx + 0x8000 < data_length; idx += 0x8000) {
53         fih_rc = FIH_FAILURE;
54         fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_update(data + idx,
55                                                                        0x8000));
56         if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
57             FIH_RET(FIH_FAILURE);
58         }
59     }
60     if (idx != (data_length - (data_length % 0x8000))) {
61         FIH_RET(FIH_FAILURE);
62     }
63 
64     fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_update(data + idx,
65                                                                    data_length - idx));
66     if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
67         FIH_RET(FIH_FAILURE);
68     }
69 
70     return FIH_SUCCESS;
71 }
72 
bl1_sha256_compute(const uint8_t * data,size_t data_length,uint8_t * hash)73 fih_int bl1_sha256_compute(const uint8_t *data,
74                            size_t data_length,
75                            uint8_t *hash)
76 {
77     fih_int fih_rc = FIH_FAILURE;
78     size_t idx = 0;
79 
80     if (data == NULL || hash == NULL) {
81         FIH_RET(FIH_FAILURE);
82     }
83 
84     fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_init());
85     if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
86         FIH_RET(FIH_FAILURE);
87     }
88 
89     for (idx = 0; idx + 0x8000 < data_length; idx += 0x8000) {
90         fih_rc = FIH_FAILURE;
91         fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_update(data + idx,
92                                                                        0x8000));
93         if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
94             FIH_RET(FIH_FAILURE);
95         }
96     }
97     if (idx != (data_length - (data_length % 0x8000))) {
98         FIH_RET(FIH_FAILURE);
99     }
100 
101     fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_update(data + idx,
102                                                                    data_length - idx));
103     if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
104         FIH_RET(FIH_FAILURE);
105     }
106     fih_rc = fih_int_encode_zero_equality(cc3xx_hash_sha256_finish(hash, 32));
107     if(fih_not_eq(fih_rc, FIH_SUCCESS)) {
108         FIH_RET(FIH_FAILURE);
109     }
110 
111     FIH_RET(FIH_SUCCESS);
112 }
113 
/*
 * Translate a BL1 key identifier into the CC3XX AES key selector.
 *
 * TFM_BL1_KEY_HUK and TFM_BL1_KEY_GUK map directly to the matching CC3XX
 * key IDs and leave key_buf untouched. Every other identifier is read from
 * OTP into key_buf with bl1_otp_read_key() and reported as a user key.
 * NOTE(review): bl1_otp_read_key() is not given key_buf_size; every caller
 * in this file passes a 32-byte buffer — confirm that bounds all OTP keys.
 *
 * Returns 0 on success. If the OTP read fails, the first key_buf_size bytes
 * of key_buf are cleared and the OTP error code is returned unchanged.
 */
static int32_t bl1_key_to_cc3xx_key(enum tfm_bl1_key_id_t key_id,
                                    cc3xx_aes_key_id_t *cc3xx_key_type,
                                    uint8_t *key_buf, size_t key_buf_size)
{
    int32_t otp_err;

    /* Hardware keys: nothing to load into key_buf */
    if (key_id == TFM_BL1_KEY_HUK) {
        *cc3xx_key_type = CC3XX_AES_KEY_ID_HUK;
        return 0;
    }
    if (key_id == TFM_BL1_KEY_GUK) {
        *cc3xx_key_type = CC3XX_AES_KEY_ID_GUK;
        return 0;
    }

    /* Anything else is supplied to the engine as raw key bytes from OTP */
    *cc3xx_key_type = CC3XX_AES_KEY_ID_USER_KEY;
    otp_err = bl1_otp_read_key(key_id, key_buf);
    if (otp_err != 0) {
        /* Don't leave a partially read key behind */
        memset(key_buf, 0, key_buf_size);
    }

    return otp_err;
}
139 
bl1_aes_256_ctr_decrypt(enum tfm_bl1_key_id_t key_id,const uint8_t * key_material,uint8_t * counter,const uint8_t * ciphertext,size_t ciphertext_length,uint8_t * plaintext)140 int32_t bl1_aes_256_ctr_decrypt(enum tfm_bl1_key_id_t key_id,
141                                 const uint8_t *key_material,
142                                 uint8_t *counter,
143                                 const uint8_t *ciphertext,
144                                 size_t ciphertext_length,
145                                 uint8_t *plaintext)
146 {
147     cc3xx_aes_key_id_t cc3xx_key_type;
148     uint8_t  __ALIGNED(4) key_buf[32];
149     int32_t rc = 0;
150     size_t idx = 0;
151     const uint8_t *input_key = key_buf;
152 
153     if (ciphertext_length == 0) {
154         return 0;
155     }
156 
157     if (counter == NULL || ciphertext == NULL || plaintext == NULL) {
158         return -1;
159     }
160 
161     if (key_material == NULL) {
162         rc = bl1_key_to_cc3xx_key(key_id, &cc3xx_key_type, key_buf,
163                                   sizeof(key_buf));
164         if (rc) {
165             return rc;
166         }
167     } else {
168         cc3xx_key_type = CC3XX_AES_KEY_ID_USER_KEY;
169         input_key = key_material;
170     }
171 
172     for (idx = 0; idx + 0x8000 < ciphertext_length; idx += 0x8000) {
173         rc = cc3xx_aes(cc3xx_key_type, input_key, CC3XX_AES_KEYSIZE_256,
174                        ciphertext + idx, 0x8000, counter, plaintext + idx,
175                        CC3XX_AES_DIRECTION_ENCRYPT, CC3XX_AES_MODE_CTR);
176         if (rc != CC3XX_ERR_SUCCESS) {
177             return rc;
178         }
179     }
180 
181     /* Under CTR mode encryption and decryption are the same operation */
182     return cc3xx_aes(cc3xx_key_type, input_key, CC3XX_AES_KEYSIZE_256,
183                      ciphertext + idx, ciphertext_length - idx, counter,
184                      plaintext + idx, CC3XX_AES_DIRECTION_ENCRYPT,
185                      CC3XX_AES_MODE_CTR);
186 }
187 
aes_256_ecb_encrypt(enum tfm_bl1_key_id_t key_id,const uint8_t * plaintext,size_t ciphertext_length,uint8_t * ciphertext)188 static int32_t aes_256_ecb_encrypt(enum tfm_bl1_key_id_t key_id,
189                                    const uint8_t *plaintext,
190                                    size_t ciphertext_length,
191                                    uint8_t *ciphertext)
192 {
193     cc3xx_aes_key_id_t cc3xx_key_type;
194     uint8_t __ALIGNED(4) key_buf[32];
195     int32_t rc = 0;
196 
197     if (ciphertext_length == 0) {
198         return 0;
199     }
200 
201     if (ciphertext == NULL || plaintext == NULL) {
202         return -1;
203     }
204 
205     rc = bl1_key_to_cc3xx_key(key_id, &cc3xx_key_type, key_buf, sizeof(key_buf));
206     if (rc) {
207         return rc;
208     }
209 
210     return cc3xx_aes(cc3xx_key_type, key_buf, CC3XX_AES_KEYSIZE_256, plaintext,
211                      ciphertext_length, NULL, ciphertext,
212                      CC3XX_AES_DIRECTION_ENCRYPT, CC3XX_AES_MODE_ECB);
213 }
214 
215 /* This is a counter-mode KDF complying with NIST SP800-108 where the PRF is a
216  * combined sha256 hash and an ECB-mode AES encryption. ECB is acceptable here
217  * since the input to the PRF is a hash, and the hash input is different every
218  * time because of the counter being part of the input.
219  */
bl1_derive_key(enum tfm_bl1_key_id_t input_key,const uint8_t * label,size_t label_length,const uint8_t * context,size_t context_length,uint8_t * output_key,size_t output_length)220 int32_t bl1_derive_key(enum tfm_bl1_key_id_t input_key, const uint8_t *label,
221                        size_t label_length, const uint8_t *context,
222                        size_t context_length, uint8_t *output_key,
223                        size_t output_length)
224 {
225     uint8_t state[KEY_DERIVATION_MAX_BUF_SIZE];
226     uint8_t state_size = label_length + context_length + sizeof(uint8_t)
227                          + 2 * sizeof(uint32_t);
228     uint8_t state_hash[32];
229     uint32_t L = output_length;
230     uint32_t n = (output_length + sizeof(state_hash) - 1) / sizeof(state_hash);
231     uint32_t i = 1;
232     size_t output_idx = 0;
233     cc3xx_err_t rc;
234 
235     if (output_length == 0) {
236         return 0;
237     }
238 
239     if (label == NULL || label_length == 0 ||
240         context == NULL || context_length == 0 ||
241         output_key == NULL) {
242         return -1;
243     }
244 
245     if (state_size > KEY_DERIVATION_MAX_BUF_SIZE) {
246         return -1;
247     }
248 
249     memcpy(state + sizeof(uint32_t), label, label_length);
250     memset(state + sizeof(uint32_t) + label_length, 0, sizeof(uint8_t));
251     memcpy(state + sizeof(uint32_t) + label_length + sizeof(uint8_t),
252            context, context_length);
253     memcpy(state + sizeof(uint32_t) + label_length + sizeof(uint8_t) + context_length,
254            &L, sizeof(uint32_t));
255 
256     for (i = 1; i < n; i++) {
257         memcpy(state, &i, sizeof(uint32_t));
258 
259         /* Hash the state to make it a constant size */
260         rc = bl1_sha256_compute(state, state_size, state_hash);
261         if (rc != CC3XX_ERR_SUCCESS) {
262             goto err;
263         }
264 
265         /* Encrypt using ECB, which is fine because the state is different every
266          * time and we're hashing it.
267          */
268         rc = aes_256_ecb_encrypt(input_key, state_hash, sizeof(state_hash),
269                                  output_key + output_idx);
270         if (rc != CC3XX_ERR_SUCCESS) {
271             goto err;
272         }
273 
274         output_idx += sizeof(state_hash);
275     }
276 
277     /* For the last block, encrypt into the state buf and then memcpy out how
278      * much we need
279      */
280     memcpy(state, &i, sizeof(uint32_t));
281 
282     rc = bl1_sha256_compute(state, state_size, state_hash);
283     if (rc != CC3XX_ERR_SUCCESS) {
284         goto err;
285     }
286 
287     /* This relies on us being able to have overlapping input and output
288      * pointers.
289      */
290     rc = aes_256_ecb_encrypt(input_key, state_hash, sizeof(state_hash),
291                              state_hash);
292     if (rc != CC3XX_ERR_SUCCESS) {
293         goto err;
294     }
295 
296     memcpy(output_key + output_idx, state_hash, output_length - output_idx);
297     memset(state, 0, sizeof(state));
298     memset(state_hash, 0, sizeof(state_hash));
299 
300     return 0;
301 
302 err:
303     memset(output_key, 0, output_length);
304     memset(state, 0, sizeof(state));
305     memset(state_hash, 0, sizeof(state_hash));
306     return rc;
307 }
308