1 /**
2  * \brief  Multi-precision integer library, ESP32 hardware accelerated parts
3  *
4  *  based on mbedTLS implementation
5  *
6  *  Copyright (C) 2006-2015, ARM Limited, All Rights Reserved
7  *  Additions Copyright (C) 2016, Espressif Systems (Shanghai) PTE Ltd
8  *  SPDX-License-Identifier: Apache-2.0
9  *
10  *  Licensed under the Apache License, Version 2.0 (the "License"); you may
11  *  not use this file except in compliance with the License.
12  *  You may obtain a copy of the License at
13  *
14  *  http://www.apache.org/licenses/LICENSE-2.0
15  *
16  *  Unless required by applicable law or agreed to in writing, software
17  *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
18  *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19  *  See the License for the specific language governing permissions and
20  *  limitations under the License.
21  *
22  */
23 #include <stdio.h>
24 #include <string.h>
25 #include <malloc.h>
26 #include <limits.h>
27 #include <assert.h>
28 #include <stdlib.h>
29 #include <sys/param.h>
30 #include "soc/hwcrypto_periph.h"
31 #include "esp_system.h"
32 #include "esp_log.h"
33 #include "esp_attr.h"
34 #include "bignum_impl.h"
35 #include "soc/soc_caps.h"
36 
37 #include <mbedtls/bignum.h>
38 
39 
40 /* Some implementation notes:
41  *
42  * - Naming convention x_words, y_words, z_words for number of words (limbs) used in a particular
43  *   bignum. This number may be less than the size of the bignum
44  *
 * - Naming convention hw_words for the hardware length of the operation. This number may be rounded up
 *   for targets that require this (e.g. ESP32), and may be larger than any of the numbers
47  *   involved in the calculation.
48  *
49  * - Timing behaviour of these functions will depend on the length of the inputs. This is fundamentally
50  *   the same constraint as the software mbedTLS implementations, and relies on the same
51  *   countermeasures (exponent blinding, etc) which are used in mbedTLS.
52  */
53 
/* Log tag for this module (may be unused depending on build configuration). */
static const char *TAG __attribute__((unused)) = "bignum";

/* Limb geometry of mbedtls_mpi */
#define ciL    (sizeof(mbedtls_mpi_uint))         /* bytes (chars) per limb */
#define biL    (ciL * 8)                          /* bits per limb */
58 
59 
/* Number of 32-bit words needed to hold `bits` bits, i.e. bits / 32 rounded up.
 */
static inline size_t bits_to_words(size_t bits)
{
    const size_t word_bits = 32;
    return (bits + (word_bits - 1)) / word_bits;
}
66 
/* Prototypes for the wrapped exponentiation entry points.
   __wrap_mbedtls_mpi_exp_mod() is defined later in this file.
   __real_mbedtls_mpi_exp_mod() is called by it when the operands are too long for
   the hardware and CONFIG_MBEDTLS_LARGE_KEY_SOFTWARE_MPI is enabled.
   NOTE(review): presumably __real_ resolves to the unwrapped (software)
   mbedtls_mpi_exp_mod via the linker's --wrap option; confirm in the build config.
*/
int __wrap_mbedtls_mpi_exp_mod( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M, mbedtls_mpi *_Rinv );
extern int __real_mbedtls_mpi_exp_mod( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M, mbedtls_mpi *_Rinv );
72 
mpi_words(const mbedtls_mpi * mpi)73 static size_t mpi_words(const mbedtls_mpi *mpi)
74 {
75     for (size_t i = mpi->n; i > 0; i--) {
76         if (mpi->p[i - 1] != 0) {
77             return i;
78         }
79     }
80     return 0;
81 }
82 
83 
84 /**
85  *
86  * There is a need for the value of integer N' such that B^-1(B-1)-N^-1N'=1,
87  * where B^-1(B-1) mod N=1. Actually, only the least significant part of
88  * N' is needed, hence the definition N0'=N' mod b. We reproduce below the
89  * simple algorithm from an article by Dusse and Kaliski to efficiently
90  * find N0' from N0 and b
91  */
modular_inverse(const mbedtls_mpi * M)92 static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
93 {
94     int i;
95     uint64_t t = 1;
96     uint64_t two_2_i_minus_1 = 2;   /* 2^(i-1) */
97     uint64_t two_2_i = 4;           /* 2^i */
98     uint64_t N = M->p[0];
99 
100     for (i = 2; i <= 32; i++) {
101         if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
102             t += two_2_i_minus_1;
103         }
104 
105         two_2_i_minus_1 <<= 1;
106         two_2_i <<= 1;
107     }
108 
109     return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
110 }
111 
112 /* Calculate Rinv = RR^2 mod M, where:
113  *
114  *  R = b^n where b = 2^32, n=num_words,
115  *  R = 2^N (where N=num_bits)
116  *  RR = R^2 = 2^(2*N) (where N=num_bits=num_words*32)
117  *
118  * This calculation is computationally expensive (mbedtls_mpi_mod_mpi)
119  * so caller should cache the result where possible.
120  *
121  * DO NOT call this function while holding esp_mpi_enable_hardware_hw_op().
122  *
123  */
calculate_rinv(mbedtls_mpi * Rinv,const mbedtls_mpi * M,int num_words)124 static int calculate_rinv(mbedtls_mpi *Rinv, const mbedtls_mpi *M, int num_words)
125 {
126     int ret;
127     size_t num_bits = num_words * 32;
128     mbedtls_mpi RR;
129     mbedtls_mpi_init(&RR);
130     MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
131     MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(Rinv, &RR, M));
132 
133 cleanup:
134     mbedtls_mpi_free(&RR);
135 
136     return ret;
137 }
138 
139 
140 
141 
142 
143 
144 /* Z = (X * Y) mod M
145 
146    Not an mbedTLS function
147 */
esp_mpi_mul_mpi_mod(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M)148 int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M)
149 {
150     int ret = 0;
151 
152     size_t x_bits = mbedtls_mpi_bitlen(X);
153     size_t y_bits = mbedtls_mpi_bitlen(Y);
154     size_t m_bits = mbedtls_mpi_bitlen(M);
155     size_t z_bits = MIN(m_bits, x_bits + y_bits);
156     size_t x_words = bits_to_words(x_bits);
157     size_t y_words = bits_to_words(y_bits);
158     size_t m_words = bits_to_words(m_bits);
159     size_t z_words = bits_to_words(z_bits);
160     size_t hw_words = esp_mpi_hardware_words(MAX(x_words, MAX(y_words, m_words))); /* longest operand */
161     mbedtls_mpi Rinv;
162     mbedtls_mpi_uint Mprime;
163 
164     /* Calculate and load the first stage montgomery multiplication */
165     mbedtls_mpi_init(&Rinv);
166     MBEDTLS_MPI_CHK(calculate_rinv(&Rinv, M, hw_words));
167     Mprime = modular_inverse(M);
168 
169     esp_mpi_enable_hardware_hw_op();
170     /* Load and start a (X * Y) mod M calculation */
171     esp_mpi_mul_mpi_mod_hw_op(X, Y, M, &Rinv, Mprime, hw_words);
172 
173     MBEDTLS_MPI_CHK(mbedtls_mpi_grow(Z, z_words));
174 
175     esp_mpi_read_result_hw_op(Z, z_words);
176     Z->s = X->s * Y->s;
177 
178 cleanup:
179     mbedtls_mpi_free(&Rinv);
180     esp_mpi_disable_hardware_hw_op();
181 
182     return ret;
183 }
184 
185 #ifdef ESP_MPI_USE_MONT_EXP
186 /*
187  * Return the most significant one-bit.
188  */
mbedtls_mpi_msb(const mbedtls_mpi * X)189 static size_t mbedtls_mpi_msb( const mbedtls_mpi *X )
190 {
191     int i, j;
192     if (X != NULL && X->n != 0) {
193         for (i = X->n - 1; i >= 0; i--) {
194             if (X->p[i] != 0) {
195                 for (j = biL - 1; j >= 0; j--) {
196                     if ((X->p[i] & (1 << j)) != 0) {
197                         return (i * biL) + j;
198                     }
199                 }
200             }
201         }
202     }
203     return 0;
204 }
205 
206 /*
207  * Montgomery exponentiation: Z = X ^ Y mod M  (HAC 14.94)
208  */
mpi_montgomery_exp_calc(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M,mbedtls_mpi * Rinv,size_t hw_words,mbedtls_mpi_uint Mprime)209 static int mpi_montgomery_exp_calc( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M,
210                                     mbedtls_mpi *Rinv,
211                                     size_t hw_words,
212                                     mbedtls_mpi_uint Mprime )
213 {
214     int ret = 0;
215     mbedtls_mpi X_, one;
216 
217     mbedtls_mpi_init(&X_);
218     mbedtls_mpi_init(&one);
219     if ( ( ( ret = mbedtls_mpi_grow(&one, hw_words) ) != 0 ) ||
220             ( ( ret = mbedtls_mpi_set_bit(&one, 0, 1) )  != 0 ) ) {
221         goto cleanup2;
222     }
223 
224     // Algorithm from HAC 14.94
225     {
226         // 0 determine t (highest bit set in y)
227         int t = mbedtls_mpi_msb(Y);
228 
229         esp_mpi_enable_hardware_hw_op();
230 
231         // 1.1 x_ = mont(x, R^2 mod m)
232         //        = mont(x, rb)
233         MBEDTLS_MPI_CHK( esp_mont_hw_op(&X_, X, Rinv, M, Mprime, hw_words, false) );
234 
235         // 1.2 z = R mod m
236         // now z = R mod m = Mont (R^2 mod m, 1) mod M (as Mont(x) = X&R^-1 mod M)
237         MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Rinv, &one, M, Mprime, hw_words, true) );
238 
239         // 2 for i from t down to 0
240         for (int i = t; i >= 0; i--) {
241             // 2.1 z = mont(z,z)
242             if (i != t) { // skip on the first iteration as is still unity
243                 MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Z, Z, M, Mprime, hw_words, true) );
244             }
245 
246             // 2.2 if y[i] = 1 then z = mont(A, x_)
247             if (mbedtls_mpi_get_bit(Y, i)) {
248                 MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Z, &X_, M, Mprime, hw_words, true) );
249             }
250         }
251 
252         // 3 z = Mont(z, 1)
253         MBEDTLS_MPI_CHK( esp_mont_hw_op(Z, Z, &one, M, Mprime, hw_words, true) );
254     }
255 
256 cleanup:
257     esp_mpi_disable_hardware_hw_op();
258 
259 cleanup2:
260     mbedtls_mpi_free(&X_);
261     mbedtls_mpi_free(&one);
262     return ret;
263 }
264 
#endif // ESP_MPI_USE_MONT_EXP
266 
267 /*
268  * Z = X ^ Y mod M
269  *
270  * _Rinv is optional pre-calculated version of Rinv (via calculate_rinv()).
271  *
272  * (See RSA Accelerator section in Technical Reference for more about Mprime, Rinv)
273  *
274  */
__wrap_mbedtls_mpi_exp_mod(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M,mbedtls_mpi * _Rinv)275 int __wrap_mbedtls_mpi_exp_mod( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M, mbedtls_mpi *_Rinv )
276 {
277     int ret = 0;
278     size_t x_words = mpi_words(X);
279     size_t y_words = mpi_words(Y);
280     size_t m_words = mpi_words(M);
281 
282 
283     /* "all numbers must be the same length", so choose longest number
284        as cardinal length of operation...
285     */
286     size_t num_words = esp_mpi_hardware_words(MAX(m_words, MAX(x_words, y_words)));
287 
288     mbedtls_mpi Rinv_new; /* used if _Rinv == NULL */
289     mbedtls_mpi *Rinv;    /* points to _Rinv (if not NULL) othwerwise &RR_new */
290     mbedtls_mpi_uint Mprime;
291 
292     if (mbedtls_mpi_cmp_int(M, 0) <= 0 || (M->p[0] & 1) == 0) {
293         return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
294     }
295 
296     if (mbedtls_mpi_cmp_int(Y, 0) < 0) {
297         return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
298     }
299 
300     if (mbedtls_mpi_cmp_int(Y, 0) == 0) {
301         return mbedtls_mpi_lset(Z, 1);
302     }
303 
304     if (num_words * 32 > SOC_RSA_MAX_BIT_LEN) {
305 #ifdef CONFIG_MBEDTLS_LARGE_KEY_SOFTWARE_MPI
306         return __real_mbedtls_mpi_exp_mod(Z, X, Y, M, _Rinv);
307 #else
308         return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
309 #endif
310     }
311 
312     /* Determine RR pointer, either _RR for cached value
313        or local RR_new */
314     if (_Rinv == NULL) {
315         mbedtls_mpi_init(&Rinv_new);
316         Rinv = &Rinv_new;
317     } else {
318         Rinv = _Rinv;
319     }
320     if (Rinv->p == NULL) {
321         MBEDTLS_MPI_CHK(calculate_rinv(Rinv, M, num_words));
322     }
323 
324     Mprime = modular_inverse(M);
325 
326     // Montgomery exponentiation: Z = X ^ Y mod M  (HAC 14.94)
327 #ifdef ESP_MPI_USE_MONT_EXP
328     ret = mpi_montgomery_exp_calc(Z, X, Y, M, Rinv, num_words, Mprime) ;
329     MBEDTLS_MPI_CHK(ret);
330 #else
331     esp_mpi_enable_hardware_hw_op();
332 
333     esp_mpi_exp_mpi_mod_hw_op(X, Y, M, Rinv, Mprime, num_words);
334     ret = mbedtls_mpi_grow(Z, m_words);
335     if (ret != 0) {
336         esp_mpi_disable_hardware_hw_op();
337         goto cleanup;
338     }
339     esp_mpi_read_result_hw_op(Z, m_words);
340     esp_mpi_disable_hardware_hw_op();
341 #endif
342 
343     // Compensate for negative X
344     if (X->s == -1 && (Y->p[0] & 1) != 0) {
345         Z->s = -1;
346         MBEDTLS_MPI_CHK(mbedtls_mpi_add_mpi(Z, M, Z));
347     } else {
348         Z->s = 1;
349     }
350 
351 cleanup:
352     if (_Rinv == NULL) {
353         mbedtls_mpi_free(&Rinv_new);
354     }
355     return ret;
356 }
357 
358 #if defined(MBEDTLS_MPI_MUL_MPI_ALT) /* MBEDTLS_MPI_MUL_MPI_ALT */
359 
/* Forward declarations of the fallback paths mbedtls_mpi_mul_mpi() uses when an
   operand is too long for the standard hardware multiplier (defined below). */
static int mpi_mult_mpi_failover_mod_mult( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t z_words);
static int mpi_mult_mpi_overlong(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t y_words, size_t z_words);
362 
363 /* Z = X * Y */
mbedtls_mpi_mul_mpi(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y)364 int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y )
365 {
366     int ret = 0;
367     size_t x_bits = mbedtls_mpi_bitlen(X);
368     size_t y_bits = mbedtls_mpi_bitlen(Y);
369     size_t x_words = bits_to_words(x_bits);
370     size_t y_words = bits_to_words(y_bits);
371     size_t z_words = bits_to_words(x_bits + y_bits);
372     size_t hw_words = esp_mpi_hardware_words(MAX(x_words, y_words)); // length of one operand in hardware
373 
374     /* Short-circuit eval if either argument is 0 or 1.
375 
376        This is needed as the mpi modular division
377        argument will sometimes call in here when one
378        argument is too large for the hardware unit, but the other
379        argument is zero or one.
380     */
381     if (x_bits == 0 || y_bits == 0) {
382         mbedtls_mpi_lset(Z, 0);
383         return 0;
384     }
385     if (x_bits == 1) {
386         ret = mbedtls_mpi_copy(Z, Y);
387         Z->s *= X->s;
388         return ret;
389     }
390     if (y_bits == 1) {
391         ret = mbedtls_mpi_copy(Z, X);
392         Z->s *= Y->s;
393         return ret;
394     }
395 
396     /* Grow Z to result size early, avoid interim allocations */
397     MBEDTLS_MPI_CHK( mbedtls_mpi_grow(Z, z_words) );
398 
399     /* If either factor is over 2048 bits, we can't use the standard hardware multiplier
400        (it assumes result is double longest factor, and result is max 4096 bits.)
401 
402        However, we can fail over to mod_mult for up to 4096 bits of result (modulo
403        multiplication doesn't have the same restriction, so result is simply the
404        number of bits in X plus number of bits in in Y.)
405     */
406     if (hw_words * 32 > SOC_RSA_MAX_BIT_LEN/2) {
407         if (z_words * 32 <= SOC_RSA_MAX_BIT_LEN) {
408             /* Note: it's possible to use mpi_mult_mpi_overlong
409                for this case as well, but it's very slightly
410                slower and requires a memory allocation.
411             */
412             return mpi_mult_mpi_failover_mod_mult(Z, X, Y, z_words);
413         } else {
414             /* Still too long for the hardware unit... */
415             if (y_words > x_words) {
416                 return mpi_mult_mpi_overlong(Z, X, Y, y_words, z_words);
417             } else {
418                 return mpi_mult_mpi_overlong(Z, Y, X, x_words, z_words);
419             }
420         }
421     }
422 
423     /* Otherwise, we can use the (faster) multiply hardware unit */
424     esp_mpi_enable_hardware_hw_op();
425 
426     esp_mpi_mul_mpi_hw_op(X, Y, hw_words);
427     esp_mpi_read_result_hw_op(Z, z_words);
428 
429     esp_mpi_disable_hardware_hw_op();
430 
431     Z->s = X->s * Y->s;
432 
433 cleanup:
434     return ret;
435 }
436 
437 
438 
439 /* Deal with the case when X & Y are too long for the hardware unit, by splitting one operand
440    into two halves.
441 
442    Y must be the longer operand
443 
444    Slice Y into Yp, Ypp such that:
445    Yp = lower 'b' bits of Y
446    Ypp = upper 'b' bits of Y (right shifted)
447 
448    Such that
449    Z = X * Y
450    Z = X * (Yp + Ypp<<b)
451    Z = (X * Yp) + (X * Ypp<<b)
452 
453    Note that this function may recurse multiple times, if both X & Y
454    are too long for the hardware multiplication unit.
455 */
mpi_mult_mpi_overlong(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,size_t y_words,size_t z_words)456 static int mpi_mult_mpi_overlong(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t y_words, size_t z_words)
457 {
458     int ret = 0;
459     mbedtls_mpi Ztemp;
460     /* Rather than slicing in two on bits we slice on limbs (32 bit words) */
461     const size_t words_slice = y_words / 2;
462     /* Yp holds lower bits of Y (declared to reuse Y's array contents to save on copying) */
463     const mbedtls_mpi Yp = {
464         .p = Y->p,
465         .n = words_slice,
466         .s = Y->s
467     };
468     /* Ypp holds upper bits of Y, right shifted (also reuses Y's array contents) */
469     const mbedtls_mpi Ypp = {
470         .p = Y->p + words_slice,
471         .n = y_words - words_slice,
472         .s = Y->s
473     };
474     mbedtls_mpi_init(&Ztemp);
475 
476     /* Get result Ztemp = Yp * X (need temporary variable Ztemp) */
477     MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi(&Ztemp, X, &Yp) );
478 
479     /* Z = Ypp * Y */
480     MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi(Z, X, &Ypp) );
481 
482     /* Z = Z << b */
483     MBEDTLS_MPI_CHK( mbedtls_mpi_shift_l(Z, words_slice * 32) );
484 
485     /* Z += Ztemp */
486     MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi(Z, Z, &Ztemp) );
487 
488 cleanup:
489     mbedtls_mpi_free(&Ztemp);
490 
491     return ret;
492 }
493 
494 /* Special-case of mbedtls_mpi_mult_mpi(), where we use hardware montgomery mod
495    multiplication to calculate an mbedtls_mpi_mult_mpi result where either
496    A or B are >2048 bits so can't use the standard multiplication method.
497 
498    Result (number of words, based on A bits + B bits) must still be less than 4096 bits.
499 
500    This case is simpler than the general case modulo multiply of
501    esp_mpi_mul_mpi_mod() because we can control the other arguments:
502 
503    * Modulus is chosen with M=(2^num_bits - 1) (ie M=R-1), so output
504    * Mprime and Rinv are therefore predictable as follows:
505    isn't actually modulo anything.
506    Mprime 1
507    Rinv 1
508 
509    (See RSA Accelerator section in Technical Reference for more about Mprime, Rinv)
510 */
511 
mpi_mult_mpi_failover_mod_mult(mbedtls_mpi * Z,const mbedtls_mpi * X,const mbedtls_mpi * Y,size_t z_words)512 static int mpi_mult_mpi_failover_mod_mult( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t z_words)
513 {
514     int ret;
515     size_t hw_words = esp_mpi_hardware_words(z_words);
516 
517     esp_mpi_enable_hardware_hw_op();
518 
519     esp_mpi_mult_mpi_failover_mod_mult_hw_op(X, Y, hw_words );
520     MBEDTLS_MPI_CHK( mbedtls_mpi_grow(Z, hw_words) );
521     esp_mpi_read_result_hw_op(Z, hw_words);
522 
523     Z->s = X->s * Y->s;
524 cleanup:
525     esp_mpi_disable_hardware_hw_op();
526     return ret;
527 }
528 
529 #endif /* MBEDTLS_MPI_MUL_MPI_ALT */
530