1 /*
2  * Copyright (c) 2023-2024 Arm Limited. All rights reserved.
3  *
4  * Licensed under the Apache License Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing software
11  * distributed under the License is distributed on an "AS IS" BASIS
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "rse_handshake.h"
18 
19 #include "device_definition.h"
20 #include "mhu.h"
21 #include "tfm_plat_otp.h"
22 #include "rse_key_derivation.h"
23 #include "rse_kmu_slot_ids.h"
24 #include "crypto.h"
25 #include "cc3xx_aes.h"
26 #include "cc3xx_rng.h"
27 #include "log.h"
28 #include "cmsis.h"
29 #include "dpa_hardened_word_copy.h"
30 
31 #include <string.h>
32 
/* ID of the RSE that takes the server role in the handshake. */
#define RSE_SERVER_ID            0

/* Length of one session key IV, in bytes and in 32-bit words. */
#define SESSION_KEY_IV_SIZE      32
#define SESSION_KEY_IV_WORD_SIZE 8

/* Length of one VHUK seed (contribution), in bytes and in 32-bit words. */
#define VHUK_SEED_SIZE           32
#define VHUK_SEED_WORD_SIZE      8
38 
39 uint32_t sending_mhu[RSE_AMOUNT];
40 uint32_t receiving_mhu[RSE_AMOUNT];
41 
/*
 * Message kinds exchanged during the handshake. The *_MSG kinds are built by
 * the clients and the *_REPLY kinds by the server.
 */
enum rse_handshake_msg_type {
    RSE_HANDSHAKE_SESSION_KEY_MSG   = 0,
    RSE_HANDSHAKE_SESSION_KEY_REPLY = 1,
    RSE_HANDSHAKE_VHUK_MSG          = 2,
    RSE_HANDSHAKE_VHUK_REPLY        = 3,
    /* NOTE(review): presumably pins the enum to 32 bits - confirm. */
    RSE_HANDSHAKE_MAX_MSG           = UINT32_MAX,
};
49 
/*
 * Tells rse_handshake_msg_send() / rse_handshake_msg_receive() whether to run
 * the AES-CCM step on the message.
 */
enum rse_handshake_crypt_type {
    RSE_HANDSHAKE_ENCRYPT_MESSAGE      = 0,
    RSE_HANDSHAKE_DONT_ENCRYPT_MESSAGE = 1,
};
54 
55 __PACKED_STRUCT rse_handshake_header {
56     enum rse_handshake_msg_type type;
57     uint32_t rse_id;
58     uint32_t ccm_iv[3];
59 };
60 
61 __PACKED_STRUCT rse_handshake_trailer {
62     uint32_t ccm_tag[4];
63 };
64 
65 struct __attribute__((__packed__)) rse_handshake_msg {
66     struct rse_handshake_header header;
67     __PACKED_UNION {
68         __PACKED_STRUCT {
69             uint32_t session_key_iv[SESSION_KEY_IV_WORD_SIZE];
70         } session_key_msg;
71         __PACKED_STRUCT {
72             uint32_t session_key_ivs[SESSION_KEY_IV_WORD_SIZE * RSE_AMOUNT];
73         } session_key_reply;
74         __PACKED_STRUCT {
75             uint32_t vhuk_contribution[VHUK_SEED_WORD_SIZE];
76         } vhuk_msg;
77         __PACKED_STRUCT {
78             uint32_t vhuk_contributions[VHUK_SEED_WORD_SIZE * RSE_AMOUNT];
79         } vhuk_reply;
80     } body;
81     struct rse_handshake_trailer trailer;
82 };
83 
header_init(struct rse_handshake_msg * msg,enum rse_handshake_msg_type type)84 static int32_t header_init(struct rse_handshake_msg *msg,
85                            enum rse_handshake_msg_type type)
86 {
87     int32_t err;
88 
89     msg->header.type = type;
90 
91     err = tfm_plat_otp_read(PLAT_OTP_ID_RSE_ID,
92                             sizeof(msg->header.rse_id),
93                             (uint8_t*)&msg->header.rse_id);
94     if (err != 0)
95         return err;
96 
97     err = cc3xx_lowlevel_rng_get_random((uint8_t *)&msg->header.ccm_iv,
98                                         sizeof(msg->header.ccm_iv));
99     if (err != 0) {
100         return err;
101     }
102 
103     return 0;
104 }
105 
construct_session_key_msg(struct rse_handshake_msg * msg,uint32_t * session_key_iv)106 static int32_t construct_session_key_msg(struct rse_handshake_msg *msg,
107                                          uint32_t *session_key_iv)
108 {
109     int32_t err;
110 
111     err = header_init(msg, RSE_HANDSHAKE_SESSION_KEY_MSG);
112     if (err) {
113         return err;
114     }
115 
116     dpa_hardened_word_copy(msg->body.session_key_msg.session_key_iv,
117                            session_key_iv, SESSION_KEY_IV_WORD_SIZE);
118 
119     return 0;
120 }
121 
construct_vhuk_msg(struct rse_handshake_msg * msg,uint32_t * vhuk_seed)122 static int32_t construct_vhuk_msg(struct rse_handshake_msg *msg,
123                                   uint32_t *vhuk_seed)
124 {
125     int32_t err;
126 
127     err = header_init(msg, RSE_HANDSHAKE_VHUK_MSG);
128     if (err) {
129         return err;
130     }
131 
132     dpa_hardened_word_copy(msg->body.vhuk_msg.vhuk_contribution,
133                            vhuk_seed, VHUK_SEED_WORD_SIZE);
134 
135     return 0;
136 }
137 
construct_session_key_reply(struct rse_handshake_msg * msg,uint32_t * session_key_ivs)138 static int32_t construct_session_key_reply(struct rse_handshake_msg *msg,
139                                            uint32_t *session_key_ivs)
140 {
141     int32_t err;
142 
143     err = header_init(msg, RSE_HANDSHAKE_SESSION_KEY_REPLY);
144     if (err) {
145         return err;
146     }
147 
148     dpa_hardened_word_copy(msg->body.session_key_reply.session_key_ivs,
149                            session_key_ivs, SESSION_KEY_IV_WORD_SIZE * RSE_AMOUNT);
150 
151     return 0;
152 }
153 
construct_vhuk_reply(struct rse_handshake_msg * msg,uint32_t * vhuk_seeds)154 static int32_t construct_vhuk_reply(struct rse_handshake_msg *msg,
155                                     uint32_t *vhuk_seeds)
156 {
157     int32_t err;
158 
159     err = header_init(msg, RSE_HANDSHAKE_VHUK_REPLY);
160     if (err) {
161         return err;
162     }
163 
164     dpa_hardened_word_copy(msg->body.vhuk_reply.vhuk_contributions,
165                            vhuk_seeds, VHUK_SEED_WORD_SIZE * RSE_AMOUNT);
166 
167     return 0;
168 }
169 
/*
 * Run AES-256-CCM over a handshake message, in place.
 *
 * The key comes from KMU slot RSE_KMU_SLOT_SESSION_KEY_0 and the IV from the
 * message header. The header is processed as authenticated data, the body is
 * transformed in place (input and output buffer are the same), and the tag
 * lives in the trailer.
 *
 * NOTE(review): on decrypt the trailer tag is handed to
 * cc3xx_lowlevel_aes_finish(); presumably the driver checks it and reports a
 * mismatch through its return value - confirm against the driver.
 *
 * Returns 0 on success, otherwise the first non-zero driver status.
 */
static int32_t rse_handshake_msg_crypt(cc3xx_aes_direction_t direction,
                                       struct rse_handshake_msg *msg)
{
    uint8_t *const body = (uint8_t *)&msg->body;
    const size_t body_len = sizeof(msg->body);
    const size_t header_len = sizeof(msg->header);
    int32_t rc;

    rc = cc3xx_lowlevel_aes_init(direction, CC3XX_AES_MODE_CCM,
                                 RSE_KMU_SLOT_SESSION_KEY_0, NULL,
                                 CC3XX_AES_KEYSIZE_256,
                                 (uint32_t *)msg->header.ccm_iv,
                                 sizeof(msg->header.ccm_iv));
    if (rc != 0) {
        return rc;
    }

    /* Lengths are configured before any data is fed to the engine. */
    cc3xx_lowlevel_aes_set_tag_len(sizeof(msg->trailer.ccm_tag));
    cc3xx_lowlevel_aes_set_data_len(body_len, header_len);

    /* The header sits at the start of the message. */
    cc3xx_lowlevel_aes_update_authed_data((uint8_t *)msg, header_len);

    cc3xx_lowlevel_aes_set_output_buffer(body, body_len);

    rc = cc3xx_lowlevel_aes_update(body, body_len);
    if (rc != 0) {
        return rc;
    }

    return cc3xx_lowlevel_aes_finish((uint32_t *)&msg->trailer.ccm_tag, NULL);
}
205 
rse_handshake_msg_send(void * mhu_sender_dev,struct rse_handshake_msg * msg,enum rse_handshake_crypt_type crypt)206 static int32_t rse_handshake_msg_send(void *mhu_sender_dev,
207                                       struct rse_handshake_msg *msg,
208                                       enum rse_handshake_crypt_type crypt)
209 {
210     int32_t err;
211 
212     err = mhu_init_sender(mhu_sender_dev);
213     if (err != MHU_ERR_NONE && err != MHU_ERR_ALREADY_INIT) {
214         return err;
215     }
216 
217     if (crypt == RSE_HANDSHAKE_ENCRYPT_MESSAGE) {
218         err = rse_handshake_msg_crypt(CC3XX_AES_DIRECTION_ENCRYPT, msg);
219         if (err != 0) {
220             return err;
221         }
222     }
223 
224     err = mhu_send_data(mhu_sender_dev,
225                         (uint8_t *)msg,
226                         sizeof(struct rse_handshake_msg));
227     if (err != 0) {
228         return err;
229     }
230 
231     return 0;
232 }
233 
rse_handshake_msg_receive(void * mhu_receiver_dev,struct rse_handshake_msg * msg,enum rse_handshake_crypt_type crypt)234 static int32_t rse_handshake_msg_receive(void *mhu_receiver_dev,
235                                          struct rse_handshake_msg *msg,
236                                          enum rse_handshake_crypt_type crypt)
237 {
238     int32_t err;
239     size_t size;
240 
241     err = mhu_init_receiver(mhu_receiver_dev);
242     if (err != MHU_ERR_NONE && err != MHU_ERR_ALREADY_INIT) {
243         return err;
244     }
245 
246     err = mhu_wait_data(mhu_receiver_dev);
247     if (err != 0) {
248         return err;
249     }
250 
251     size = sizeof(struct rse_handshake_msg);
252     err = mhu_receive_data(mhu_receiver_dev, (uint8_t*)msg, &size);
253     if (err != 0) {
254         return err;
255     }
256 
257     if (crypt == RSE_HANDSHAKE_ENCRYPT_MESSAGE) {
258         err = rse_handshake_msg_crypt(CC3XX_AES_DIRECTION_DECRYPT, msg);
259         if (err != 0) {
260             return err;
261         }
262     }
263 
264     return 0;
265 }
266 
calculate_session_key_client(uint32_t rse_id)267 static int32_t calculate_session_key_client(uint32_t rse_id)
268 {
269     int32_t err;
270     uint32_t session_key_iv[SESSION_KEY_IV_WORD_SIZE];
271     struct rse_handshake_msg msg;
272 
273     /* Calculate our session key */
274     err = cc3xx_lowlevel_rng_get_random((uint8_t *)session_key_iv, SESSION_KEY_IV_SIZE);
275     if (err) {
276         return err;
277     }
278 
279     /* Send our session key IV to the server */
280     err = construct_session_key_msg(&msg, session_key_iv);
281     if (err) {
282         return err;
283     }
284     err = rse_handshake_msg_send(&MHU_RSE_TO_RSE_SENDER_DEVS[sending_mhu[RSE_SERVER_ID]],
285                                  &msg, RSE_HANDSHAKE_DONT_ENCRYPT_MESSAGE);
286     if (err) {
287         return err;
288     }
289 
290     /* Receive back the session key IVs */
291     err = rse_handshake_msg_receive(&MHU_RSE_TO_RSE_RECEIVER_DEVS[receiving_mhu[RSE_SERVER_ID]],
292                                     &msg, RSE_HANDSHAKE_DONT_ENCRYPT_MESSAGE);
293     if (err) {
294         return err;
295     }
296 
297     if (msg.header.type != RSE_HANDSHAKE_SESSION_KEY_REPLY) {
298         return 1;
299     }
300 
301     /* Finally construct the session key */
302     err = rse_derive_session_key((uint8_t *)&msg.body.session_key_reply.session_key_ivs,
303                                  SESSION_KEY_IV_SIZE * RSE_AMOUNT,
304                                  RSE_KMU_SLOT_SESSION_KEY_0);
305 
306     return 0;
307 }
308 
exchange_vhuk_seeds_client(uint32_t rse_id,uint32_t * vhuk_seeds_buf)309 static int32_t exchange_vhuk_seeds_client(uint32_t rse_id, uint32_t *vhuk_seeds_buf)
310 {
311     int32_t err;
312     uint32_t vhuk_seed[VHUK_SEED_WORD_SIZE];
313     struct rse_handshake_msg msg;
314 
315     /* Calculate our VHUK contribution key */
316     err = cc3xx_lowlevel_rng_get_random((uint8_t *)vhuk_seed, VHUK_SEED_SIZE);
317     if (err) {
318         return err;
319     }
320 
321     /* Send our VHUK contribution to the server */
322     err = construct_vhuk_msg(&msg, vhuk_seed);
323     if (err) {
324         return err;
325     }
326     err = rse_handshake_msg_send(&MHU_RSE_TO_RSE_SENDER_DEVS[sending_mhu[RSE_SERVER_ID]],
327                                  &msg, RSE_HANDSHAKE_ENCRYPT_MESSAGE);
328     if (err) {
329         return err;
330     }
331 
332     /* Receive back the VHUK contributions */
333     err = rse_handshake_msg_receive(&MHU_RSE_TO_RSE_RECEIVER_DEVS[receiving_mhu[RSE_SERVER_ID]],
334                                     &msg, RSE_HANDSHAKE_ENCRYPT_MESSAGE);
335     if (err) {
336         return err;
337     }
338 
339     if (msg.header.type != RSE_HANDSHAKE_VHUK_REPLY) {
340         return 1;
341     }
342 
343     dpa_hardened_word_copy(vhuk_seeds_buf, msg.body.vhuk_reply.vhuk_contributions,
344                            VHUK_SEED_WORD_SIZE * RSE_AMOUNT);
345     /* Overwrite our VHUK contribution in the array, in case the sender has
346      * overwritten it.
347      */
348     dpa_hardened_word_copy(vhuk_seeds_buf + VHUK_SEED_WORD_SIZE * rse_id,
349                            vhuk_seed, VHUK_SEED_WORD_SIZE);
350 
351     return 0;
352 }
353 
/*
 * Run the client side of the handshake: first agree the session key with the
 * server, then exchange VHUK seeds over the encrypted channel.
 *
 * On success vhuk_seeds_buf holds the VHUK seeds of all RSEs.
 *
 * Returns 0 on success, otherwise the non-zero status of the failing step.
 */
static int32_t rse_handshake_client(uint32_t rse_id, uint32_t *vhuk_seeds_buf)
{
    int32_t err;

    err = calculate_session_key_client(rse_id);
    if (err) {
        return err;
    }

    err = exchange_vhuk_seeds_client(rse_id, vhuk_seeds_buf);
    if (err) {
        return err;
    }

    /* Explicit success return: the function must not fall off the end. */
    return 0;
}
368 
calculate_session_key_server()369 static int32_t calculate_session_key_server()
370 {
371     uint32_t idx;
372     int32_t err;
373     uint32_t session_key_ivs[SESSION_KEY_IV_WORD_SIZE * RSE_AMOUNT];
374     struct rse_handshake_msg msg;
375 
376     /* Calculate the session key for RSE 0 */
377     err = cc3xx_lowlevel_rng_get_random((uint8_t *)session_key_ivs, SESSION_KEY_IV_SIZE);
378     if (err) {
379         return err;
380     }
381 
382     /* Receive all the other session keys */
383     for (idx = 0; idx < RSE_AMOUNT; idx++) {
384         if (idx == RSE_SERVER_ID) {
385             continue;
386         }
387 
388         memset(&msg, 0, sizeof(msg));
389         err = rse_handshake_msg_receive(&MHU_RSE_TO_RSE_RECEIVER_DEVS[receiving_mhu[idx]],
390                                         &msg, RSE_HANDSHAKE_DONT_ENCRYPT_MESSAGE);
391         if (err != 0) {
392             return err;
393         }
394 
395         if (msg.header.type != RSE_HANDSHAKE_SESSION_KEY_MSG) {
396             return 1;
397         }
398 
399         dpa_hardened_word_copy(session_key_ivs + (SESSION_KEY_IV_WORD_SIZE * idx),
400                                msg.body.session_key_msg.session_key_iv,
401                                SESSION_KEY_IV_WORD_SIZE);
402     }
403 
404     /* Construct the reply */
405     memset(&msg, 0, sizeof(msg));
406     err = construct_session_key_reply(&msg, session_key_ivs);
407     if (err != 0) {
408         return err;
409     }
410 
411     /* Send the session key reply to all other RSEes */
412     for (idx = 0; idx < RSE_AMOUNT; idx++) {
413         if (idx == RSE_SERVER_ID) {
414             continue;
415         }
416 
417         err = rse_handshake_msg_send(&MHU_RSE_TO_RSE_SENDER_DEVS[sending_mhu[idx]],
418                                      &msg, RSE_HANDSHAKE_DONT_ENCRYPT_MESSAGE);
419         if (err != 0) {
420             return err;
421         }
422     }
423 
424     /* Finally derive our own key */
425     err = rse_derive_session_key((uint8_t *)session_key_ivs, sizeof(session_key_ivs),
426                                  RSE_KMU_SLOT_SESSION_KEY_0);
427     return err;
428 }
429 
exchange_vhuk_seeds_server(uint32_t * vhuk_seeds_buf)430 static int32_t exchange_vhuk_seeds_server(uint32_t *vhuk_seeds_buf)
431 {
432     uint32_t idx;
433     int32_t err;
434     struct rse_handshake_msg msg;
435 
436     /* Receive all the other vhuk seeds */
437     for (idx = 0; idx < RSE_AMOUNT; idx++) {
438         if (idx == RSE_SERVER_ID) {
439             continue;
440         }
441 
442         memset(&msg, 0, sizeof(msg));
443         err = rse_handshake_msg_receive(&MHU_RSE_TO_RSE_RECEIVER_DEVS[receiving_mhu[idx]],
444                                         &msg, RSE_HANDSHAKE_ENCRYPT_MESSAGE);
445         if (err != 0) {
446             return err;
447         }
448 
449         if (msg.header.type != RSE_HANDSHAKE_VHUK_MSG) {
450             return 1;
451         }
452 
453         dpa_hardened_word_copy(vhuk_seeds_buf + (SESSION_KEY_IV_WORD_SIZE * idx),
454                                msg.body.vhuk_msg.vhuk_contribution,
455                                VHUK_SEED_WORD_SIZE);
456     }
457 
458     /* Construct the reply */
459     memset(&msg, 0, sizeof(msg));
460     err = construct_vhuk_reply(&msg, vhuk_seeds_buf);
461     if (err != 0) {
462         return err;
463     }
464 
465     /* Send the VUHK reply to all other RSEes */
466     for (idx = 0; idx < RSE_AMOUNT; idx++) {
467         if (idx == RSE_SERVER_ID) {
468             continue;
469         }
470 
471         err = rse_handshake_msg_send(&MHU_RSE_TO_RSE_SENDER_DEVS[sending_mhu[idx]],
472                                      &msg, RSE_HANDSHAKE_ENCRYPT_MESSAGE);
473         if (err != 0) {
474             return err;
475         }
476     }
477 
478     return 0;
479 }
480 
/*
 * Run the server side of the handshake: agree the session key with all other
 * RSEs, then exchange VHUK seeds with them over the encrypted channel.
 *
 * Returns 0 on success, otherwise the non-zero status of the failing step.
 */
static int32_t rse_handshake_server(uint32_t *vhuk_seeds_buf)
{
    int32_t rc = calculate_session_key_server();

    /* The seed exchange needs the session key, so only run it on success. */
    if (rc == 0) {
        rc = exchange_vhuk_seeds_server(vhuk_seeds_buf);
    }

    return rc;
}
497 
rse_handshake(uint32_t * vhuk_seeds_buf)498 int32_t rse_handshake(uint32_t *vhuk_seeds_buf)
499 {
500     uint32_t rse_id;
501     enum tfm_plat_err_t plat_err;
502 
503     plat_err = tfm_plat_otp_read(PLAT_OTP_ID_RSE_TO_RSE_SENDER_ROUTING_TABLE,
504                                  sizeof(sending_mhu), (uint8_t *)sending_mhu);
505     if (plat_err != TFM_PLAT_ERR_SUCCESS) {
506         return 1;
507     }
508 
509     plat_err = tfm_plat_otp_read(PLAT_OTP_ID_RSE_TO_RSE_RECEIVER_ROUTING_TABLE,
510                                  sizeof(receiving_mhu), (uint8_t *)receiving_mhu);
511     if (plat_err != TFM_PLAT_ERR_SUCCESS) {
512         return 1;
513     }
514 
515     plat_err = tfm_plat_otp_read(PLAT_OTP_ID_RSE_ID, sizeof(rse_id),
516                                  (uint8_t*)&rse_id);
517     if (plat_err != TFM_PLAT_ERR_SUCCESS) {
518         return 1;
519     }
520 
521     if (rse_id == RSE_SERVER_ID) {
522 #if RSE_SERVER_ID != 0
523         dpa_hardened_word_copy(vhuk_seeds_buf + VHUK_SEED_WORD_SIZE * rse_id,
524                                vhuk_seeds_buf, VHUK_SEED_WORD_SIZE);
525 #endif /* RSE_SERVER_ID != 0 */
526 
527         return rse_handshake_server(vhuk_seeds_buf);
528     } else {
529         return rse_handshake_client(rse_id, vhuk_seeds_buf);
530     }
531 }
532