1 /*
2  * Multi-precision integer library
3  * ESP-IDF hardware accelerated parts based on mbedTLS implementation
4  *
5  * SPDX-FileCopyrightText: The Mbed TLS Contributors
6  *
7  * SPDX-License-Identifier: Apache-2.0
8  *
9  * SPDX-FileContributor: 2016-2023 Espressif Systems (Shanghai) CO LTD
10  */
11 #include <stdio.h>
12 #include <string.h>
13 #include <malloc.h>
14 #include <limits.h>
15 #include <assert.h>
16 #include <stdlib.h>
17 #include <sys/param.h>
18 
19 #include "esp_system.h"
20 #include "esp_log.h"
21 #include "esp_attr.h"
22 #include "esp_intr_alloc.h"
23 #if CONFIG_PM_ENABLE
24 #include "esp_pm.h"
25 #endif
26 
27 #include "freertos/FreeRTOS.h"
28 #include "freertos/semphr.h"
29 
30 #include "soc/hwcrypto_periph.h"
31 #include "soc/periph_defs.h"
32 #include "soc/soc_caps.h"
33 
34 #include "bignum_impl.h"
35 
36 #include <mbedtls/bignum.h>
37 
38 
39 /* Some implementation notes:
40  *
41  * - Naming convention x_words, y_words, z_words for number of words (limbs) used in a particular
42  *   bignum. This number may be less than the size of the bignum
43  *
 * - Naming convention hw_words for the hardware length of the operation. This number may be rounded up
 *   for targets that require this (e.g. ESP32), and may be larger than any of the numbers
 *   involved in the calculation.
47  *
48  * - Timing behaviour of these functions will depend on the length of the inputs. This is fundamentally
49  *   the same constraint as the software mbedTLS implementations, and relies on the same
50  *   countermeasures (exponent blinding, etc) which are used in mbedTLS.
51  */
52 
/* Log tag passed to ESP_LOGx() in this file. Marked unused because, in the code
   visible here, it is only referenced from the CONFIG_MBEDTLS_MPI_USE_INTERRUPT
   section and would otherwise trigger a warning when that option is off. */
static const __attribute__((unused)) char *TAG = "bignum";

#define ciL    (sizeof(mbedtls_mpi_uint))         /* chars in limb  */
#define biL    (ciL << 3)                         /* bits  in limb  */

#if defined(CONFIG_MBEDTLS_MPI_USE_INTERRUPT)
/* Binary semaphore given by esp_mpi_complete_isr() and taken by
   esp_mpi_wait_intr(); created on first use in esp_mpi_isr_initialise(). */
static SemaphoreHandle_t op_complete_sem;
#if defined(CONFIG_PM_ENABLE)
/* Power-management locks created and acquired in esp_mpi_isr_initialise() and
   released in esp_mpi_wait_intr(): one requests ESP_PM_CPU_FREQ_MAX, the other
   ESP_PM_NO_LIGHT_SLEEP. */
static esp_pm_lock_handle_t s_pm_cpu_lock;
static esp_pm_lock_handle_t s_pm_sleep_lock;
#endif
64 
esp_mpi_complete_isr(void * arg)65 static IRAM_ATTR void esp_mpi_complete_isr(void *arg)
66 {
67     BaseType_t higher_woken;
68     esp_mpi_interrupt_clear();
69 
70     xSemaphoreGiveFromISR(op_complete_sem, &higher_woken);
71     if (higher_woken) {
72         portYIELD_FROM_ISR();
73     }
74 }
75 
76 
esp_mpi_isr_initialise(void)77 static esp_err_t esp_mpi_isr_initialise(void)
78 {
79     esp_mpi_interrupt_clear();
80     esp_mpi_interrupt_enable(true);
81     if (op_complete_sem == NULL) {
82         static StaticSemaphore_t op_sem_buf;
83         op_complete_sem = xSemaphoreCreateBinaryStatic(&op_sem_buf);
84         if (op_complete_sem == NULL) {
85             ESP_LOGE(TAG, "Failed to create intr semaphore");
86             return ESP_FAIL;
87         }
88 
89         esp_err_t ret;
90         ret = esp_intr_alloc(ETS_RSA_INTR_SOURCE, 0, esp_mpi_complete_isr, NULL, NULL);
91         if (ret != ESP_OK) {
92             ESP_LOGE(TAG, "Failed to allocate RSA interrupt %d", ret);
93 
94             // This should be treated as fatal error as this API would mostly
95             // be invoked within mbedTLS interface. There is no way for the system
96             // to proceed if the MPI interrupt allocation fails here.
97             abort();
98         }
99     }
100 
101     /* MPI is clocked proportionally to CPU clock, take power management lock */
102 #ifdef CONFIG_PM_ENABLE
103     if (s_pm_cpu_lock == NULL) {
104         if (esp_pm_lock_create(ESP_PM_NO_LIGHT_SLEEP, 0, "mpi_sleep", &s_pm_sleep_lock) != ESP_OK) {
105             ESP_LOGE(TAG, "Failed to create PM sleep lock");
106             return ESP_FAIL;
107         }
108         if (esp_pm_lock_create(ESP_PM_CPU_FREQ_MAX, 0, "mpi_cpu", &s_pm_cpu_lock) != ESP_OK) {
109             ESP_LOGE(TAG, "Failed to create PM CPU lock");
110             return ESP_FAIL;
111         }
112     }
113     esp_pm_lock_acquire(s_pm_cpu_lock);
114     esp_pm_lock_acquire(s_pm_sleep_lock);
115 #endif
116 
117     return ESP_OK;
118 }
119 
esp_mpi_wait_intr(void)120 static int esp_mpi_wait_intr(void)
121 {
122     if (!xSemaphoreTake(op_complete_sem, 2000 / portTICK_PERIOD_MS)) {
123         ESP_LOGE("MPI", "Timed out waiting for completion of MPI Interrupt");
124         return -1;
125     }
126 
127 #ifdef CONFIG_PM_ENABLE
128     esp_pm_lock_release(s_pm_cpu_lock);
129     esp_pm_lock_release(s_pm_sleep_lock);
130 #endif  // CONFIG_PM_ENABLE
131 
132     esp_mpi_interrupt_enable(false);
133 
134     return 0;
135 }
136 
137 #endif // CONFIG_MBEDTLS_MPI_USE_INTERRUPT
138 
/* Number of 32-bit words needed to hold the given number of bits
 * (rounds up; 0 bits need 0 words).
 */
static inline size_t bits_to_words(size_t bits)
{
    const size_t bits_per_word = 32;

    return (bits + (bits_per_word - 1)) / bits_per_word;
}
145 
146 /* Return the number of words actually used to represent an mpi
147    number.
148 */
149 #if defined(MBEDTLS_MPI_EXP_MOD_ALT) || defined(MBEDTLS_MPI_EXP_MOD_ALT_FALLBACK)
mpi_words(const mbedtls_mpi * mpi)150 static size_t mpi_words(const mbedtls_mpi *mpi)
151 {
152     for (size_t i = mpi->MBEDTLS_PRIVATE(n); i > 0; i--) {
153         if (mpi->MBEDTLS_PRIVATE(p[i - 1]) != 0) {
154             return i;
155         }
156     }
157     return 0;
158 }
159 
160 #endif //(MBEDTLS_MPI_EXP_MOD_ALT || MBEDTLS_MPI_EXP_MOD_ALT_FALLBACK)
161 
162 /**
163  *
164  * There is a need for the value of integer N' such that B^-1(B-1)-N^-1N'=1,
165  * where B^-1(B-1) mod N=1. Actually, only the least significant part of
166  * N' is needed, hence the definition N0'=N' mod b. We reproduce below the
167  * simple algorithm from an article by Dusse and Kaliski to efficiently
168  * find N0' from N0 and b
169  */
modular_inverse(const mbedtls_mpi * M)170 static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
171 {
172     int i;
173     uint64_t t = 1;
174     uint64_t two_2_i_minus_1 = 2;   /* 2^(i-1) */
175     uint64_t two_2_i = 4;           /* 2^i */
176     uint64_t N = M->MBEDTLS_PRIVATE(p[0]);
177 
178     for (i = 2; i <= 32; i++) {
179         if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
180             t += two_2_i_minus_1;
181         }
182 
183         two_2_i_minus_1 <<= 1;
184         two_2_i <<= 1;
185     }
186 
187     return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
188 }
189 
190 /* Calculate Rinv = RR^2 mod M, where:
191  *
192  *  R = b^n where b = 2^32, n=num_words,
193  *  R = 2^N (where N=num_bits)
194  *  RR = R^2 = 2^(2*N) (where N=num_bits=num_words*32)
195  *
196  * This calculation is computationally expensive (mbedtls_mpi_mod_mpi)
197  * so caller should cache the result where possible.
198  *
199  * DO NOT call this function while holding esp_mpi_enable_hardware_hw_op().
200  *
201  */
calculate_rinv(mbedtls_mpi * Rinv,const mbedtls_mpi * M,int num_words)202 static int calculate_rinv(mbedtls_mpi *Rinv, const mbedtls_mpi *M, int num_words)
203 {
204     int ret;
205     size_t num_bits = num_words * 32;
206     mbedtls_mpi RR;
207     mbedtls_mpi_init(&RR);
208     MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
209     MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(Rinv, &RR, M));
210 
211 cleanup:
212     mbedtls_mpi_free(&RR);
213 
214     return ret;
215 }
216 
217 
218 
219 
220 
221 
222 /* Z = (X * Y) mod M
223 
224    Not an mbedTLS function
225 */
esp_mpi_mul_mpi_mod(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M)226 int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M)
227 {
228     int ret = 0;
229 
230     size_t x_bits = mbedtls_mpi_bitlen(X);
231     size_t y_bits = mbedtls_mpi_bitlen(Y);
232     size_t m_bits = mbedtls_mpi_bitlen(M);
233     size_t z_bits = MIN(m_bits, x_bits + y_bits);
234     size_t x_words = bits_to_words(x_bits);
235     size_t y_words = bits_to_words(y_bits);
236     size_t m_words = bits_to_words(m_bits);
237     size_t z_words = bits_to_words(z_bits);
238     size_t hw_words = esp_mpi_hardware_words(MAX(x_words, MAX(y_words, m_words))); /* longest operand */
239     mbedtls_mpi Rinv;
240     mbedtls_mpi_uint Mprime;
241 
242     /* Calculate and load the first stage montgomery multiplication */
243     mbedtls_mpi_init(&Rinv);
244     MBEDTLS_MPI_CHK(calculate_rinv(&Rinv, M, hw_words));
245     Mprime = modular_inverse(M);
246 
247     esp_mpi_enable_hardware_hw_op();
248     /* Load and start a (X * Y) mod M calculation */
249     esp_mpi_mul_mpi_mod_hw_op(X, Y, M, &Rinv, Mprime, hw_words);
250 
251     MBEDTLS_MPI_CHK(mbedtls_mpi_grow(Z, z_words));
252 
253     esp_mpi_read_result_hw_op(Z, z_words);
254     Z->MBEDTLS_PRIVATE(s) = X->MBEDTLS_PRIVATE(s) * Y->MBEDTLS_PRIVATE(s);
255 
256 cleanup:
257     mbedtls_mpi_free(&Rinv);
258     esp_mpi_disable_hardware_hw_op();
259 
260     return ret;
261 }
262 
263 #if defined(MBEDTLS_MPI_EXP_MOD_ALT) || defined(MBEDTLS_MPI_EXP_MOD_ALT_FALLBACK)
264 
265 #ifdef ESP_MPI_USE_MONT_EXP
266 /*
267  * Return the most significant one-bit.
268  */
mbedtls_mpi_msb(const mbedtls_mpi * X)269 static size_t mbedtls_mpi_msb( const mbedtls_mpi *X )
270 {
271     int i, j;
272     if (X != NULL && X->MBEDTLS_PRIVATE(n) != 0) {
273         for (i = X->MBEDTLS_PRIVATE(n) - 1; i >= 0; i--) {
274             if (X->MBEDTLS_PRIVATE(p[i]) != 0) {
275                 for (j = biL - 1; j >= 0; j--) {
276                     if ((X->MBEDTLS_PRIVATE(p[i]) & (1 << j)) != 0) {
277                         return (i * biL) + j;
278                     }
279                 }
280             }
281         }
282     }
283     return 0;
284 }
285 
286 /*
287  * Montgomery exponentiation: Z = X ^ Y mod M  (HAC 14.94)
288  */
mpi_montgomery_exp_calc(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M,mbedtls_mpi * Rinv,size_t hw_words,mbedtls_mpi_uint Mprime)289 static int mpi_montgomery_exp_calc( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M,
290                                     mbedtls_mpi *Rinv,
291                                     size_t hw_words,
292                                     mbedtls_mpi_uint Mprime )
293 {
294     int ret = 0;
295     mbedtls_mpi X_, one;
296 
297     mbedtls_mpi_init(&X_);
298     mbedtls_mpi_init(&one);
299     if ( ( ( ret = mbedtls_mpi_grow(&one, hw_words) ) != 0 ) ||
300             ( ( ret = mbedtls_mpi_set_bit(&one, 0, 1) )  != 0 ) ) {
301         goto cleanup2;
302     }
303 
304     // Algorithm from HAC 14.94
305     {
306         // 0 determine t (highest bit set in y)
307         int t = mbedtls_mpi_msb(Y);
308 
309         esp_mpi_enable_hardware_hw_op();
310 
311         // 1.1 x_ = mont(x, R^2 mod m)
312         //        = mont(x, rb)
313         MBEDTLS_MPI_CHK( esp_mont_hw_op(&X_, X, Rinv, M, Mprime, hw_words, false) );
314 
315         // 1.2 z = R mod m
316         // now z = R mod m = Mont (R^2 mod m, 1) mod M (as Mont(x) = X&R^-1 mod M)
317         MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Rinv, &one, M, Mprime, hw_words, true) );
318 
319         // 2 for i from t down to 0
320         for (int i = t; i >= 0; i--) {
321             // 2.1 z = mont(z,z)
322             if (i != t) { // skip on the first iteration as is still unity
323                 MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Z, Z, M, Mprime, hw_words, true) );
324             }
325 
326             // 2.2 if y[i] = 1 then z = mont(A, x_)
327             if (mbedtls_mpi_get_bit(Y, i)) {
328                 MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Z, &X_, M, Mprime, hw_words, true) );
329             }
330         }
331 
332         // 3 z = Mont(z, 1)
333         MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Z, &one, M, Mprime, hw_words, true) );
334     }
335 
336 cleanup:
337     esp_mpi_disable_hardware_hw_op();
338 
339 cleanup2:
340     mbedtls_mpi_free(&X_);
341     mbedtls_mpi_free(&one);
342     return ret;
343 }
344 
345 #endif //USE_MONT_EXPONENATIATION
346 
347 /*
348  * Z = X ^ Y mod M
349  *
350  * _Rinv is optional pre-calculated version of Rinv (via calculate_rinv()).
351  *
352  * (See RSA Accelerator section in Technical Reference for more about Mprime, Rinv)
353  *
354  */
esp_mpi_exp_mod(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M,mbedtls_mpi * _Rinv)355 static int esp_mpi_exp_mod( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M, mbedtls_mpi *_Rinv )
356 {
357     int ret = 0;
358 
359     mbedtls_mpi Rinv_new; /* used if _Rinv == NULL */
360     mbedtls_mpi *Rinv;    /* points to _Rinv (if not NULL) othwerwise &RR_new */
361     mbedtls_mpi_uint Mprime;
362 
363     size_t x_words = mpi_words(X);
364     size_t y_words = mpi_words(Y);
365     size_t m_words = mpi_words(M);
366 
367     /* "all numbers must be the same length", so choose longest number
368        as cardinal length of operation...
369     */
370     size_t num_words = esp_mpi_hardware_words(MAX(m_words, MAX(x_words, y_words)));
371 
372     if (num_words * 32 > SOC_RSA_MAX_BIT_LEN) {
373         return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
374     }
375 
376     if (mbedtls_mpi_cmp_int(M, 0) <= 0 || (M->MBEDTLS_PRIVATE(p[0]) & 1) == 0) {
377         return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
378     }
379 
380     if (mbedtls_mpi_cmp_int(Y, 0) < 0) {
381         return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
382     }
383 
384     if (mbedtls_mpi_cmp_int(Y, 0) == 0) {
385         return mbedtls_mpi_lset(Z, 1);
386     }
387 
388     /* Determine RR pointer, either _RR for cached value
389        or local RR_new */
390     if (_Rinv == NULL) {
391         mbedtls_mpi_init(&Rinv_new);
392         Rinv = &Rinv_new;
393     } else {
394         Rinv = _Rinv;
395     }
396     if (Rinv->MBEDTLS_PRIVATE(p) == NULL) {
397         MBEDTLS_MPI_CHK(calculate_rinv(Rinv, M, num_words));
398     }
399 
400     Mprime = modular_inverse(M);
401 
402     // Montgomery exponentiation: Z = X ^ Y mod M  (HAC 14.94)
403 #ifdef ESP_MPI_USE_MONT_EXP
404     ret = mpi_montgomery_exp_calc(Z, X, Y, M, Rinv, num_words, Mprime) ;
405     MBEDTLS_MPI_CHK(ret);
406 #else
407     esp_mpi_enable_hardware_hw_op();
408 
409 #if defined (CONFIG_MBEDTLS_MPI_USE_INTERRUPT)
410     if (esp_mpi_isr_initialise() != ESP_OK) {
411         ret = -1;
412         esp_mpi_disable_hardware_hw_op();
413         goto cleanup;
414     }
415 #endif
416 
417     esp_mpi_exp_mpi_mod_hw_op(X, Y, M, Rinv, Mprime, num_words);
418     ret = mbedtls_mpi_grow(Z, m_words);
419     if (ret != 0) {
420         esp_mpi_disable_hardware_hw_op();
421         goto cleanup;
422     }
423 
424 #if defined(CONFIG_MBEDTLS_MPI_USE_INTERRUPT)
425     ret = esp_mpi_wait_intr();
426     if (ret != 0) {
427         esp_mpi_disable_hardware_hw_op();
428         goto cleanup;
429     }
430 #endif //CONFIG_MBEDTLS_MPI_USE_INTERRUPT
431 
432     esp_mpi_read_result_hw_op(Z, m_words);
433     esp_mpi_disable_hardware_hw_op();
434 #endif
435 
436     // Compensate for negative X
437     if (X->MBEDTLS_PRIVATE(s) == -1 && (Y->MBEDTLS_PRIVATE(p[0]) & 1) != 0) {
438         Z->MBEDTLS_PRIVATE(s) = -1;
439         MBEDTLS_MPI_CHK(mbedtls_mpi_add_mpi(Z, M, Z));
440     } else {
441         Z->MBEDTLS_PRIVATE(s) = 1;
442     }
443 
444 cleanup:
445     if (_Rinv == NULL) {
446         mbedtls_mpi_free(&Rinv_new);
447     }
448     return ret;
449 }
450 
451 #endif /* (MBEDTLS_MPI_EXP_MOD_ALT || MBEDTLS_MPI_EXP_MOD_ALT_FALLBACK) */
452 
453 /*
454  * Sliding-window exponentiation: X = A^E mod N  (HAC 14.85)
455  */
mbedtls_mpi_exp_mod(mbedtls_mpi * X,const mbedtls_mpi * A,const mbedtls_mpi * E,const mbedtls_mpi * N,mbedtls_mpi * _RR)456 int mbedtls_mpi_exp_mod( mbedtls_mpi *X, const mbedtls_mpi *A,
457                          const mbedtls_mpi *E, const mbedtls_mpi *N,
458                          mbedtls_mpi *_RR )
459 {
460     int ret;
461 #if defined(MBEDTLS_MPI_EXP_MOD_ALT_FALLBACK)
462     /* Try hardware API first and then fallback to software */
463     ret = esp_mpi_exp_mod( X, A, E, N, _RR );
464     if( ret == MBEDTLS_ERR_MPI_NOT_ACCEPTABLE ) {
465         ret = mbedtls_mpi_exp_mod_soft( X, A, E, N, _RR );
466     }
467 #else
468     /* Hardware approach */
469     ret = esp_mpi_exp_mod( X, A, E, N, _RR );
470 #endif
471     /* Note: For software only approach, it gets handled in mbedTLS library.
472     This file is not part of build objects for that case */
473 
474     return ret;
475 }
476 
477 #if defined(MBEDTLS_MPI_MUL_MPI_ALT) /* MBEDTLS_MPI_MUL_MPI_ALT */
478 
479 static int mpi_mult_mpi_failover_mod_mult( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t z_words);
480 static int mpi_mult_mpi_overlong(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t y_words, size_t z_words);
481 
482 /* Z = X * Y */
mbedtls_mpi_mul_mpi(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y)483 int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y )
484 {
485     int ret = 0;
486     size_t x_bits = mbedtls_mpi_bitlen(X);
487     size_t y_bits = mbedtls_mpi_bitlen(Y);
488     size_t x_words = bits_to_words(x_bits);
489     size_t y_words = bits_to_words(y_bits);
490     size_t z_words = bits_to_words(x_bits + y_bits);
491     size_t hw_words = esp_mpi_hardware_words(MAX(x_words, y_words)); // length of one operand in hardware
492 
493     /* Short-circuit eval if either argument is 0 or 1.
494 
495        This is needed as the mpi modular division
496        argument will sometimes call in here when one
497        argument is too large for the hardware unit, but the other
498        argument is zero or one.
499     */
500     if (x_bits == 0 || y_bits == 0) {
501         mbedtls_mpi_lset(Z, 0);
502         return 0;
503     }
504     if (x_bits == 1) {
505         ret = mbedtls_mpi_copy(Z, Y);
506         Z->MBEDTLS_PRIVATE(s) *= X->MBEDTLS_PRIVATE(s);
507         return ret;
508     }
509     if (y_bits == 1) {
510         ret = mbedtls_mpi_copy(Z, X);
511         Z->MBEDTLS_PRIVATE(s) *= Y->MBEDTLS_PRIVATE(s);
512         return ret;
513     }
514 
515     /* Grow Z to result size early, avoid interim allocations */
516     MBEDTLS_MPI_CHK( mbedtls_mpi_grow(Z, z_words) );
517 
518     /* If either factor is over 2048 bits, we can't use the standard hardware multiplier
519        (it assumes result is double longest factor, and result is max 4096 bits.)
520 
521        However, we can fail over to mod_mult for up to 4096 bits of result (modulo
522        multiplication doesn't have the same restriction, so result is simply the
523        number of bits in X plus number of bits in in Y.)
524     */
525     if (hw_words * 32 > SOC_RSA_MAX_BIT_LEN/2) {
526         if (z_words * 32 <= SOC_RSA_MAX_BIT_LEN) {
527             /* Note: it's possible to use mpi_mult_mpi_overlong
528                for this case as well, but it's very slightly
529                slower and requires a memory allocation.
530             */
531             return mpi_mult_mpi_failover_mod_mult(Z, X, Y, z_words);
532         } else {
533             /* Still too long for the hardware unit... */
534             if (y_words > x_words) {
535                 return mpi_mult_mpi_overlong(Z, X, Y, y_words, z_words);
536             } else {
537                 return mpi_mult_mpi_overlong(Z, Y, X, x_words, z_words);
538             }
539         }
540     }
541 
542     /* Otherwise, we can use the (faster) multiply hardware unit */
543     esp_mpi_enable_hardware_hw_op();
544 
545     esp_mpi_mul_mpi_hw_op(X, Y, hw_words);
546     esp_mpi_read_result_hw_op(Z, z_words);
547 
548     esp_mpi_disable_hardware_hw_op();
549 
550     Z->MBEDTLS_PRIVATE(s) = X->MBEDTLS_PRIVATE(s) * Y->MBEDTLS_PRIVATE(s);
551 
552 cleanup:
553     return ret;
554 }
555 
/* X = A * b
 *
 * Wraps the single limb b in a temporary, positive, one-limb MPI that lives
 * on the stack and forwards to mbedtls_mpi_mul_mpi(), whose return value is
 * passed through unchanged.
 */
int mbedtls_mpi_mul_int( mbedtls_mpi *X, const mbedtls_mpi *A, mbedtls_mpi_uint b )
{
    mbedtls_mpi_uint limb[1] = { b };
    mbedtls_mpi multiplier = {
        .MBEDTLS_PRIVATE(s) = 1,
        .MBEDTLS_PRIVATE(n) = 1,
        .MBEDTLS_PRIVATE(p) = limb
    };

    return mbedtls_mpi_mul_mpi( X, A, &multiplier );
}
568 
569 /* Deal with the case when X & Y are too long for the hardware unit, by splitting one operand
570    into two halves.
571 
572    Y must be the longer operand
573 
574    Slice Y into Yp, Ypp such that:
575    Yp = lower 'b' bits of Y
576    Ypp = upper 'b' bits of Y (right shifted)
577 
578    Such that
579    Z = X * Y
580    Z = X * (Yp + Ypp<<b)
581    Z = (X * Yp) + (X * Ypp<<b)
582 
583    Note that this function may recurse multiple times, if both X & Y
584    are too long for the hardware multiplication unit.
585 */
static int mpi_mult_mpi_overlong(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t y_words, size_t z_words)
{
    /* The split is made on a limb (32 bit word) boundary rather than on a bit */
    const size_t low_words = y_words / 2;
    /* Lower limbs of Y; shares Y's limb array instead of copying it */
    const mbedtls_mpi y_low = {
        .MBEDTLS_PRIVATE(p) = Y->MBEDTLS_PRIVATE(p),
        .MBEDTLS_PRIVATE(n) = low_words,
        .MBEDTLS_PRIVATE(s) = Y->MBEDTLS_PRIVATE(s)
    };
    /* Upper limbs of Y, i.e. Y shifted right by low_words limbs (also shared) */
    const mbedtls_mpi y_high = {
        .MBEDTLS_PRIVATE(p) = Y->MBEDTLS_PRIVATE(p) + low_words,
        .MBEDTLS_PRIVATE(n) = y_words - low_words,
        .MBEDTLS_PRIVATE(s) = Y->MBEDTLS_PRIVATE(s)
    };
    mbedtls_mpi low_product;
    int ret;

    mbedtls_mpi_init(&low_product);

    /* low_product = X * y_low (kept in a temporary until Z is ready) */
    ret = mbedtls_mpi_mul_mpi(&low_product, X, &y_low);

    /* Z = X * y_high */
    if (ret == 0) {
        ret = mbedtls_mpi_mul_mpi(Z, X, &y_high);
    }

    /* Z = Z << b, with b = low_words * 32 bits */
    if (ret == 0) {
        ret = mbedtls_mpi_shift_l(Z, low_words * 32);
    }

    /* Z += low_product */
    if (ret == 0) {
        ret = mbedtls_mpi_add_mpi(Z, Z, &low_product);
    }

    mbedtls_mpi_free(&low_product);

    return ret;
}
623 
624 /* Special-case of mbedtls_mpi_mult_mpi(), where we use hardware montgomery mod
625    multiplication to calculate an mbedtls_mpi_mult_mpi result where either
626    A or B are >2048 bits so can't use the standard multiplication method.
627 
628    Result (number of words, based on A bits + B bits) must still be less than 4096 bits.
629 
630    This case is simpler than the general case modulo multiply of
631    esp_mpi_mul_mpi_mod() because we can control the other arguments:
632 
633    * Modulus is chosen with M=(2^num_bits - 1) (ie M=R-1), so output
634    * Mprime and Rinv are therefore predictable as follows:
635    isn't actually modulo anything.
636    Mprime 1
637    Rinv 1
638 
639    (See RSA Accelerator section in Technical Reference for more about Mprime, Rinv)
640 */
641 
mpi_mult_mpi_failover_mod_mult(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,size_t z_words)642 static int mpi_mult_mpi_failover_mod_mult( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t z_words)
643 {
644     int ret;
645     size_t hw_words = esp_mpi_hardware_words(z_words);
646 
647     esp_mpi_enable_hardware_hw_op();
648 
649     esp_mpi_mult_mpi_failover_mod_mult_hw_op(X, Y, hw_words );
650     MBEDTLS_MPI_CHK( mbedtls_mpi_grow(Z, hw_words) );
651     esp_mpi_read_result_hw_op(Z, hw_words);
652 
653     Z->MBEDTLS_PRIVATE(s) = X->MBEDTLS_PRIVATE(s) * Y->MBEDTLS_PRIVATE(s);
654     /*
655      * Relevant: https://github.com/espressif/esp-idf/issues/11850
656      * If the first condition fails then most likely hardware peripheral
657      * has produced an incorrect result for MPI operation. This can
658      * happen if data fed to the peripheral register was incorrect.
659      *
660      * z_words is calculated as the worst-case possible size of the result
661      * MPI Z. The difference between z_words and the actual words taken by
662      * the MPI result (mpi_words(Z)) can be a maximum of 1 word.
663      * The value z_bits (actual bits taken by the MPI result) is calculated
664      * as x_bits + y_bits bits, however, in some cases, z_bits can be
665      * x_bits + y_bits - 1 bits (see example below).
666      * 0b1111 * 0b1111 = 0b11100001 -> 8 bits
667      * 0b1000 * 0b1000 = 0b01000000 -> 7 bits.
668      * The code rounds up to the nearest word size, so the maximum difference
669      * could be of only 1 word. The second condition handles this.
670      */
671     assert((z_words >= mpi_words(Z)) && (z_words - mpi_words(Z) <= (size_t)1));
672 cleanup:
673     esp_mpi_disable_hardware_hw_op();
674     return ret;
675 }
676 
677 #endif /* MBEDTLS_MPI_MUL_MPI_ALT */
678