1 /* Run the boot image. */
2 
3 #include <assert.h>
4 #include <setjmp.h>
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8 #include <bootutil/bootutil.h>
9 #include <bootutil/image.h>
10 
11 #include <flash_map_backend/flash_map_backend.h>
12 
13 #include "../../../boot/bootutil/src/bootutil_priv.h"
14 #include "bootsim.h"
15 
16 #ifdef MCUBOOT_ENCRYPT_RSA
17 #include "mbedtls/rsa.h"
18 #include "mbedtls/asn1.h"
19 #endif
20 
21 #ifdef MCUBOOT_ENCRYPT_KW
22 #include "mbedtls/nist_kw.h"
23 #endif
24 
25 #define BOOT_LOG_LEVEL BOOT_LOG_LEVEL_ERROR
26 #include <bootutil/bootutil_log.h>
27 #include "bootutil/crypto/common.h"
28 
/* Element count of a true array `arr`; not meaningful for a pointer. */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
30 
/*
 * Hooks implemented outside this file (presumably by the simulator
 * runtime -- confirm against the code that links with this file).
 */

/* Flash layout table consulted by the flash_area_* functions in this file. */
struct area_desc;
extern struct area_desc *sim_get_flash_areas(void);
extern void sim_set_flash_areas(struct area_desc *areas);
extern void sim_reset_flash_areas(void);

/* Per-run simulator context; the structure is defined later in this file. */
struct sim_context;
extern struct sim_context *sim_get_context(void);
extern void sim_set_context(struct sim_context *ctx);
extern void sim_reset_context(void);

/* Raw flash device operations, addressed by device id and device offset. */
extern int sim_flash_erase(uint8_t flash_id, uint32_t offset, uint32_t size);
extern int sim_flash_read(uint8_t flash_id, uint32_t offset,
                          uint8_t *dest, uint32_t size);
extern int sim_flash_write(uint8_t flash_id, uint32_t offset,
                           const uint8_t *src, uint32_t size);
extern uint32_t sim_flash_align(uint8_t flash_id);
extern uint8_t sim_flash_erased_val(uint8_t flash_id);
48 
/*
 * State shared between invoke_boot_go() and the flash/assert hooks during
 * one simulated boot.
 *
 * NOTE(review): the field order is deliberately left as-is; the layout is
 * presumably mirrored by code outside this file -- confirm before changing.
 */
struct sim_context {
    /* Decremented by every flash write/erase; hitting 0 triggers a longjmp. */
    int flash_counter;
    /* Incremented each time a flash hook bails out through boot_jmpbuf. */
    int jumped;
    /* Failed assertions counted while c_catch_asserts is non-zero. */
    uint8_t c_asserts;
    /* Non-zero: sim_assert() counts failures instead of calling assert(). */
    uint8_t c_catch_asserts;
    /* Jump target armed with setjmp() in invoke_boot_go(). */
    jmp_buf boot_jmpbuf;
};
56 
57 #ifdef MCUBOOT_ENCRYPT_RSA
58 static int
parse_pubkey(mbedtls_rsa_context * ctx,uint8_t ** p,uint8_t * end)59 parse_pubkey(mbedtls_rsa_context *ctx, uint8_t **p, uint8_t *end)
60 {
61     int rc;
62     size_t len;
63 
64     if ((rc = mbedtls_asn1_get_tag(p, end, &len,
65                     MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE)) != 0) {
66         return -1;
67     }
68 
69     if (*p + len != end) {
70         return -2;
71     }
72 
73     if ((rc = mbedtls_asn1_get_tag(p, end, &len,
74                     MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE)) != 0) {
75         return -3;
76     }
77 
78     *p += len;
79 
80     if ((rc = mbedtls_asn1_get_tag(p, end, &len, MBEDTLS_ASN1_BIT_STRING)) != 0) {
81         return -4;
82     }
83 
84     if (**p != MBEDTLS_ASN1_PRIMITIVE) {
85         return -5;
86     }
87 
88     *p += 1;
89 
90     if ((rc = mbedtls_asn1_get_tag(p, end, &len,
91                     MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE)) != 0) {
92         return -6;
93     }
94 
95     if (mbedtls_asn1_get_mpi(p, end, &ctx->MBEDTLS_CONTEXT_MEMBER(N)) != 0) {
96         return -7;
97     }
98 
99     if (mbedtls_asn1_get_mpi(p, end, &ctx->MBEDTLS_CONTEXT_MEMBER(E)) != 0) {
100         return -8;
101     }
102 
103     ctx->MBEDTLS_CONTEXT_MEMBER(len) = mbedtls_mpi_size(&ctx->MBEDTLS_CONTEXT_MEMBER(N));
104 
105     if (*p != end) {
106         return -9;
107     }
108 
109     if (mbedtls_rsa_check_pubkey(ctx) != 0) {
110         return -10;
111     }
112 
113     return 0;
114 }
115 
/*
 * Deterministic stand-in for an RNG callback: fills `output` with the
 * repeating byte pattern 0x00, 0x01, ..., 0xff, 0x00, ... so that results
 * are reproducible from run to run.
 *
 * p_rng   ignored.
 * output  destination buffer of at least `len` bytes.
 * len     number of bytes to produce.
 *
 * Always returns 0.
 */
static int
fake_rng(void *p_rng, unsigned char *output, size_t len)
{
    size_t i;

    (void)p_rng;
    for (i = 0; i < len; i++) {
        /*
         * Reduce to a byte explicitly in an unsigned type; narrowing through
         * plain char is implementation-defined where char is signed.
         */
        output[i] = (unsigned char)(i & 0xffu);
    }

    return 0;
}
128 #endif
129 
/*
 * Prototype of the mbedTLS allocator hook, declared here directly.  The
 * functions below call it with calloc/free before using mbedTLS.
 */
int mbedtls_platform_set_calloc_free(void *(*calloc_func)(size_t, size_t),
                                     void (*free_func)(void *));
132 
/*
 * Encrypt `seckey` with RSA-OAEP (SHA-256) under a DER-encoded public key.
 *
 * pubkey, pubkey_len  DER public key, parsed by parse_pubkey().
 * seckey, seckey_len  plaintext to encrypt.
 * encbuf              receives the ciphertext; NOTE(review): presumably has
 *                     to hold one RSA modulus worth of bytes -- confirm
 *                     against the mbedTLS documentation.
 *
 * Returns 0 on success, otherwise the non-zero code from parse_pubkey() or
 * from mbedTLS.  When MCUBOOT_ENCRYPT_RSA is not defined the function does
 * nothing and returns 0.
 */
int rsa_oaep_encrypt_(const uint8_t *pubkey, unsigned pubkey_len,
                      const uint8_t *seckey, unsigned seckey_len,
                      uint8_t *encbuf)
{
#ifndef MCUBOOT_ENCRYPT_RSA
    /* RSA key encryption is not built in: ignore all arguments. */
    (void)pubkey;
    (void)pubkey_len;
    (void)seckey;
    (void)seckey_len;
    (void)encbuf;
    return 0;
#else
    mbedtls_rsa_context rsa;
    /* parse_pubkey() takes a non-const cursor, hence the cast. */
    uint8_t *cursor = (uint8_t *)pubkey;
    uint8_t *der_end = cursor + pubkey_len;
    int rc;

    mbedtls_platform_set_calloc_free(calloc, free);

#if MBEDTLS_VERSION_NUMBER >= 0x03000000
    mbedtls_rsa_init(&rsa);
    mbedtls_rsa_set_padding(&rsa, MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256);
#else
    mbedtls_rsa_init(&rsa, MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256);
#endif

    rc = parse_pubkey(&rsa, &cursor, der_end);
    if (rc == 0) {
        /* fake_rng keeps the OAEP padding reproducible. */
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
        rc = mbedtls_rsa_rsaes_oaep_encrypt(&rsa, fake_rng, NULL,
                NULL, 0, seckey_len, seckey, encbuf);
#else
        rc = mbedtls_rsa_rsaes_oaep_encrypt(&rsa, fake_rng, NULL, MBEDTLS_RSA_PUBLIC,
                NULL, 0, seckey_len, seckey, encbuf);
#endif
    }

    /* The context is released on every path, success or failure. */
    mbedtls_rsa_free(&rsa);
    return rc;
#endif
}
184 
/*
 * Wrap `seckey` under the key-encryption key `kek` using mbedTLS NIST key
 * wrap (AES, MBEDTLS_KW_MODE_KW).
 *
 * With MCUBOOT_AES_256: 256-bit kek, 32 input bytes, 40-byte output limit.
 * Otherwise:            128-bit kek, 16 input bytes, 24-byte output limit.
 *
 * kek     key-encryption key of the size above.
 * seckey  key material to wrap.
 * encbuf  receives the wrapped key.
 *
 * Returns 0 on success or the non-zero mbedTLS error.  When
 * MCUBOOT_ENCRYPT_KW is not defined the function does nothing and returns 0.
 */
int kw_encrypt_(const uint8_t *kek, const uint8_t *seckey, uint8_t *encbuf)
{
#ifndef MCUBOOT_ENCRYPT_KW
    /* Key wrap is not built in: ignore all arguments. */
    (void)kek;
    (void)seckey;
    (void)encbuf;
    return 0;
#else
#ifdef MCUBOOT_AES_256
    const int kek_bits = 256;
    const int wrapped_cap = 40;
    const int plain_len = 32;
#else
    const int kek_bits = 128;
    const int wrapped_cap = 24;
    const int plain_len = 16;
#endif
    mbedtls_nist_kw_context wrap_ctx;
    size_t wrapped_len;
    int rc;

    mbedtls_platform_set_calloc_free(calloc, free);

    mbedtls_nist_kw_init(&wrap_ctx);

    rc = mbedtls_nist_kw_setkey(&wrap_ctx, MBEDTLS_CIPHER_ID_AES, kek, kek_bits, 1);
    if (rc == 0) {
        rc = mbedtls_nist_kw_wrap(&wrap_ctx, MBEDTLS_KW_MODE_KW, seckey, plain_len,
                encbuf, &wrapped_len, wrapped_cap);
    }

    /* The context is released on every path, success or failure. */
    mbedtls_nist_kw_free(&wrap_ctx);
    return rc;
#endif
}
224 
flash_area_align(const struct flash_area * area)225 uint32_t flash_area_align(const struct flash_area *area)
226 {
227     return sim_flash_align(area->fa_device_id);
228 }
229 
flash_area_erased_val(const struct flash_area * area)230 uint8_t flash_area_erased_val(const struct flash_area *area)
231 {
232     return sim_flash_erased_val(area->fa_device_id);
233 }
234 
235 struct area {
236     struct flash_area whole;
237     struct flash_area *areas;
238     uint32_t num_areas;
239     uint8_t id;
240 };
241 
242 struct area_desc {
243     struct area slots[16];
244     uint32_t num_slots;
245 };
246 
invoke_boot_go(struct sim_context * ctx,struct area_desc * adesc,struct boot_rsp * rsp,int image_id)247 int invoke_boot_go(struct sim_context *ctx, struct area_desc *adesc,
248                    struct boot_rsp *rsp, int image_id)
249 {
250     int res;
251     struct boot_loader_state *state;
252 
253 #if defined(MCUBOOT_SIGN_RSA) || \
254     (defined(MCUBOOT_SIGN_EC256) && defined(MCUBOOT_USE_MBED_TLS)) ||\
255     (defined(MCUBOOT_ENCRYPT_EC256) && defined(MCUBOOT_USE_MBED_TLS)) ||\
256     (defined(MCUBOOT_ENCRYPT_X25519) && defined(MCUBOOT_USE_MBED_TLS))
257     mbedtls_platform_set_calloc_free(calloc, free);
258 #endif
259 
260     state = malloc(sizeof(struct boot_loader_state));
261 
262     sim_set_flash_areas(adesc);
263     sim_set_context(ctx);
264 
265     if (setjmp(ctx->boot_jmpbuf) == 0) {
266         boot_state_clear(state);
267 
268 #if BOOT_IMAGE_NUMBER > 1
269         if (image_id >= 0) {
270             memset(state->img_mask, 1, sizeof(state->img_mask));
271             state->img_mask[image_id] = 0;
272         }
273 #else
274         (void) image_id;
275 #endif /* BOOT_IMAGE_NUMBER > 1 */
276 
277         res = context_boot_go(state, rsp);
278         sim_reset_flash_areas();
279         sim_reset_context();
280         free(state);
281         /* printf("boot_go off: %d (0x%08x)\n", res, rsp.br_image_off); */
282         return res;
283     } else {
284         sim_reset_flash_areas();
285         sim_reset_context();
286         free(state);
287         return -0x13579;
288     }
289 }
290 
/*
 * Thin allocation wrapper: requests `size` bytes from malloc() and returns
 * whatever malloc() returns (NULL on failure).
 */
void *os_malloc(size_t size)
{
    void *mem = malloc(size);

    return mem;
}
296 
flash_area_id_from_multi_image_slot(int image_index,int slot)297 int flash_area_id_from_multi_image_slot(int image_index, int slot)
298 {
299     switch (slot) {
300     case 0: return FLASH_AREA_IMAGE_PRIMARY(image_index);
301     case 1: return FLASH_AREA_IMAGE_SECONDARY(image_index);
302     case 2: return FLASH_AREA_IMAGE_SCRATCH;
303     }
304 
305     printf("Image flash area ID not found\n");
306     return -1; /* flash_area_open will fail on that */
307 }
308 
flash_area_open(uint8_t id,const struct flash_area ** area)309 int flash_area_open(uint8_t id, const struct flash_area **area)
310 {
311     uint32_t i;
312     struct area_desc *flash_areas;
313 
314     flash_areas = sim_get_flash_areas();
315     for (i = 0; i < flash_areas->num_slots; i++) {
316         if (flash_areas->slots[i].id == id)
317             break;
318     }
319     if (i == flash_areas->num_slots) {
320         printf("Unsupported area\n");
321         abort();
322     }
323 
324     /* Unsure if this is right, just returning the first area. */
325     *area = &flash_areas->slots[i].whole;
326     return 0;
327 }
328 
/*
 * Counterpart of flash_area_open().  Intentionally a no-op: `area` is
 * ignored and nothing is released.
 */
void flash_area_close(const struct flash_area *area)
{
    (void)area;
}
333 
/*
 * Read/write/erase. Offsets are relative to the beginning of the flash area.
 */
flash_area_read(const struct flash_area * area,uint32_t off,void * dst,uint32_t len)337 int flash_area_read(const struct flash_area *area, uint32_t off, void *dst,
338                     uint32_t len)
339 {
340     BOOT_LOG_SIM("%s: area=%d, off=%x, len=%x",
341                  __func__, area->fa_id, off, len);
342     return sim_flash_read(area->fa_device_id, area->fa_off + off, dst, len);
343 }
344 
flash_area_write(const struct flash_area * area,uint32_t off,const void * src,uint32_t len)345 int flash_area_write(const struct flash_area *area, uint32_t off, const void *src,
346                      uint32_t len)
347 {
348     BOOT_LOG_SIM("%s: area=%d, off=%x, len=%x", __func__,
349                  area->fa_id, off, len);
350     struct sim_context *ctx = sim_get_context();
351     if (--(ctx->flash_counter) == 0) {
352         ctx->jumped++;
353         longjmp(ctx->boot_jmpbuf, 1);
354     }
355     return sim_flash_write(area->fa_device_id, area->fa_off + off, src, len);
356 }
357 
flash_area_erase(const struct flash_area * area,uint32_t off,uint32_t len)358 int flash_area_erase(const struct flash_area *area, uint32_t off, uint32_t len)
359 {
360     BOOT_LOG_SIM("%s: area=%d, off=%x, len=%x", __func__,
361                  area->fa_id, off, len);
362     struct sim_context *ctx = sim_get_context();
363     if (--(ctx->flash_counter) == 0) {
364         ctx->jumped++;
365         longjmp(ctx->boot_jmpbuf, 1);
366     }
367     return sim_flash_erase(area->fa_device_id, area->fa_off + off, len);
368 }
369 
flash_area_to_sectors(int idx,int * cnt,struct flash_area * ret)370 int flash_area_to_sectors(int idx, int *cnt, struct flash_area *ret)
371 {
372     uint32_t i;
373     struct area *slot;
374     struct area_desc *flash_areas;
375 
376     flash_areas = sim_get_flash_areas();
377     for (i = 0; i < flash_areas->num_slots; i++) {
378         if (flash_areas->slots[i].id == idx)
379             break;
380     }
381     if (i == flash_areas->num_slots) {
382         printf("Unsupported area\n");
383         abort();
384     }
385 
386     slot = &flash_areas->slots[i];
387 
388     if (slot->num_areas > (uint32_t)*cnt) {
389         printf("Too many areas in slot\n");
390         abort();
391     }
392 
393     *cnt = slot->num_areas;
394     memcpy(ret, slot->areas, slot->num_areas * sizeof(struct flash_area));
395 
396     return 0;
397 }
398 
flash_area_get_sectors(int fa_id,uint32_t * count,struct flash_sector * sectors)399 int flash_area_get_sectors(int fa_id, uint32_t *count,
400                            struct flash_sector *sectors)
401 {
402     uint32_t i;
403     struct area *slot;
404     struct area_desc *flash_areas;
405 
406     flash_areas = sim_get_flash_areas();
407     for (i = 0; i < flash_areas->num_slots; i++) {
408         if (flash_areas->slots[i].id == fa_id)
409             break;
410     }
411     if (i == flash_areas->num_slots) {
412         printf("Unsupported area\n");
413         abort();
414     }
415 
416     slot = &flash_areas->slots[i];
417 
418     if (slot->num_areas > *count) {
419         printf("Too many areas in slot\n");
420         abort();
421     }
422 
423     for (i = 0; i < slot->num_areas; i++) {
424         sectors[i].fs_off = slot->areas[i].fa_off -
425             slot->whole.fa_off;
426         sectors[i].fs_size = slot->areas[i].fa_size;
427     }
428     *count = slot->num_areas;
429 
430     return 0;
431 }
432 
/*
 * Inverse of the slot-to-area mapping for image `image_index`:
 * returns 0 for the primary area ID and 1 for the secondary area ID.
 *
 * Any other `area_id` prints "Unsupported image area ID" and aborts.
 */
int flash_area_id_to_multi_image_slot(int image_index, int area_id)
{
    int slot = -1;

    if (area_id == FLASH_AREA_IMAGE_PRIMARY(image_index)) {
        slot = 0;
    } else if (area_id == FLASH_AREA_IMAGE_SECONDARY(image_index)) {
        slot = 1;
    }

    if (slot < 0) {
        printf("Unsupported image area ID\n");
        abort();
    }

    return slot;
}
445 
/*
 * Single-image convenience wrapper: resolves `slot` for image index 0 via
 * flash_area_id_from_multi_image_slot() and returns its result.
 */
int flash_area_id_from_image_slot(int slot)
{
    const int first_image = 0;

    return flash_area_id_from_multi_image_slot(first_image, slot);
}
450 
flash_area_sector_from_off(uint32_t off,struct flash_sector * sector)451 int flash_area_sector_from_off(uint32_t off, struct flash_sector *sector)
452 {
453     uint32_t i, sec_off, sec_size;
454     struct area *slot;
455     struct area_desc *flash_areas;
456 
457     flash_areas = sim_get_flash_areas();
458     for (i = 0; i < flash_areas->num_slots; i++) {
459         if (flash_areas->slots[i].id == FLASH_AREA_ID(image_0))
460             break;
461     }
462 
463     if (i == flash_areas->num_slots) {
464         printf("Unsupported area\n");
465         abort();
466     }
467 
468     slot = &flash_areas->slots[i];
469 
470     for (i = 0; i < slot->num_areas; i++) {
471         sec_off = slot->areas[i].fa_off - slot->whole.fa_off;
472         sec_size = slot->areas[i].fa_size;
473 
474         if (off >= sec_off && off < (sec_off + sec_size)) {
475             sector->fs_off = sec_off;
476             sector->fs_size = sec_size;
477             break;
478         }
479     }
480 
481     return (i < slot->num_areas) ? 0 : -1;
482 }
483 
flash_area_get_sector(const struct flash_area * fa,uint32_t off,struct flash_sector * sector)484 int flash_area_get_sector(const struct flash_area *fa, uint32_t off,
485                           struct flash_sector *sector)
486 {
487     uint32_t i, sec_off, sec_size;
488     struct area *slot;
489     struct area_desc *flash_areas;
490 
491     flash_areas = sim_get_flash_areas();
492     for (i = 0; i < flash_areas->num_slots; i++) {
493         if (&flash_areas->slots[i].whole == fa)
494             break;
495     }
496 
497     if (i == flash_areas->num_slots) {
498         printf("Unsupported area\n");
499         abort();
500     }
501 
502     slot = &flash_areas->slots[i];
503 
504     for (i = 0; i < slot->num_areas; i++) {
505         sec_off = slot->areas[i].fa_off - slot->whole.fa_off;
506         sec_size = slot->areas[i].fa_size;
507 
508         if (off >= sec_off && off < (sec_off + sec_size)) {
509             sector->fs_off = sec_off;
510             sector->fs_size = sec_size;
511             break;
512         }
513     }
514 
515     return (i < slot->num_areas) ? 0 : -1;
516 }
517 
sim_assert(int x,const char * assertion,const char * file,unsigned int line,const char * function)518 void sim_assert(int x, const char *assertion, const char *file, unsigned int line, const char *function)
519 {
520     if (!(x)) {
521         struct sim_context *ctx = sim_get_context();
522         if (ctx->c_catch_asserts) {
523             ctx->c_asserts++;
524         } else {
525             BOOT_LOG_ERR("%s:%d: %s: Assertion `%s' failed.", file, line, function, assertion);
526 
527             /* NOTE: if the assert below is triggered, the place where it was originally
528              * asserted is printed by the message above...
529              */
530             assert(x);
531         }
532     }
533 }
534 
boot_max_align(void)535 uint32_t boot_max_align(void)
536 {
537     return BOOT_MAX_ALIGN;
538 }
539 
boot_magic_sz(void)540 uint32_t boot_magic_sz(void)
541 {
542     return BOOT_MAGIC_ALIGN_SIZE;
543 }
544