1 /*
2  * Minimal code for RSA support from LibTomMath 0.41
3  * http://libtom.org/
4  * http://libtom.org/files/ltm-0.41.tar.bz2
5  * This library was released in public domain by Tom St Denis.
6  *
7  * The combination in this file may not use all of the optimized algorithms
8  * from LibTomMath and may be considerably slower than LibTomMath with its
9  * default settings. The main purpose of having this version here is to make it
10  * easier to build the bignum.c wrapper without having to install and build an
11  * external library.
12  *
13  * If CONFIG_INTERNAL_LIBTOMMATH is defined, bignum.c includes this
14  * libtommath.c file instead of using the external LibTomMath library.
15  */
16 #include "os.h"
17 #include "stdarg.h"
18 
/* Fall back to 8 bits per char when nothing has defined CHAR_BIT yet. */
#ifndef CHAR_BIT
#define CHAR_BIT 8
#endif

/*
 * LibTomMath module selection: modules that are always enabled.
 */
#define BN_MP_INVMOD_C
#define BN_MP_INVMOD_SLOW_C
#define BN_S_MP_MUL_DIGS_C
#define BN_S_MP_SQR_C
/* Note: #undef in tommath_superclass.h; this would require
 * BN_MP_EXPTMOD_FAST_C instead */
#define BN_S_MP_EXPTMOD_C
/* Note: #undef in tommath_superclass.h; this would require other than
 * mp_reduce */
#define BN_S_MP_MUL_HIGH_DIGS_C

#ifndef LTM_FAST

/* Default build (LTM_FAST not defined) */
#define BN_MP_DIV_SMALL
#define BN_MP_INIT_MULTI_C
#define BN_MP_CLEAR_MULTI_C
#define BN_MP_ABS_C

#else /* LTM_FAST */

/* Use faster div at the cost of about 1 kB */
#define BN_MP_MUL_D_C

/* Include faster exptmod (Montgomery) at the cost of about 2.5 kB in code */
#define BN_MP_EXPTMOD_FAST_C
#define BN_MP_MONTGOMERY_SETUP_C
#define BN_FAST_MP_MONTGOMERY_REDUCE_C
#define BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
#define BN_MP_MUL_2_C

/* Include faster sqr at the cost of about 0.5 kB in code */
#define BN_FAST_S_MP_SQR_C

#endif /* LTM_FAST */

/* Current uses do not require support for a negative exponent in exptmod, so
 * about 1.5 kB can be saved by leaving out invmod. */
#define LTM_NO_NEG_EXP
58 
59 /* from tommath.h */
60 
/* Smaller of two values. Note: each argument may be evaluated twice. */
#ifndef MIN
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#endif

/* Larger of two values. Note: each argument may be evaluated twice. */
#ifndef MAX
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#endif

/* Cast prefix for converting an allocator result to "pointer to x" */
#define OPT_CAST(x) (x *)
70 
71 typedef unsigned long mp_digit;
72 typedef u64 mp_word;
73 
74 #define DIGIT_BIT          28
75 #define MP_28BIT
76 
77 
78 #define XMALLOC  os_malloc
79 #define XFREE    os_free
80 #define XREALLOC os_realloc
81 
82 
83 #define MP_MASK          ((((mp_digit)1)<<((mp_digit)DIGIT_BIT))-((mp_digit)1))
84 
85 #define MP_LT        -1   /* less than */
86 #define MP_EQ         0   /* equal to */
87 #define MP_GT         1   /* greater than */
88 
89 #define MP_ZPOS       0   /* positive integer */
90 #define MP_NEG        1   /* negative */
91 
92 #define MP_OKAY       0   /* ok result */
93 #define MP_MEM        -2  /* out of mem */
94 #define MP_VAL        -3  /* invalid input */
95 
96 #define MP_YES        1   /* yes response */
97 #define MP_NO         0   /* no response */
98 
99 typedef int           mp_err;
100 
101 /* define this to use lower memory usage routines (exptmods mostly) */
102 #define MP_LOW_MEM
103 
104 /* default precision */
105 #ifndef MP_PREC
106    #ifndef MP_LOW_MEM
107       #define MP_PREC                 32     /* default digits of precision */
108    #else
109       #define MP_PREC                 8      /* default digits of precision */
110    #endif
111 #endif
112 
113 /* size of comba arrays, should be at least 2 * 2**(BITS_PER_WORD - BITS_PER_DIGIT*2) */
114 #define MP_WARRAY               (1 << (sizeof(mp_word) * CHAR_BIT - 2 * DIGIT_BIT + 1))
115 
116 /* the infamous mp_int structure */
117 typedef struct  {
118     int used, alloc, sign;
119     mp_digit *dp;
120 } mp_int;
121 
122 
123 /* ---> Basic Manipulations <--- */
124 #define mp_iszero(a) (((a)->used == 0) ? MP_YES : MP_NO)
125 #define mp_iseven(a) (((a)->used > 0 && (((a)->dp[0] & 1) == 0)) ? MP_YES : MP_NO)
126 #define mp_isodd(a)  (((a)->used > 0 && (((a)->dp[0] & 1) == 1)) ? MP_YES : MP_NO)
127 
128 
129 /* prototypes for copied functions */
130 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
131 static int s_mp_exptmod(mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode);
132 static int s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs);
133 static int s_mp_sqr(mp_int * a, mp_int * b);
134 static int s_mp_mul_high_digs(mp_int * a, mp_int * b, mp_int * c, int digs);
135 
136 static int fast_s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs);
137 
138 #ifdef BN_MP_INIT_MULTI_C
139 static int mp_init_multi(mp_int *mp, ...);
140 #endif
141 #ifdef BN_MP_CLEAR_MULTI_C
142 static void mp_clear_multi(mp_int *mp, ...);
143 #endif
144 static int mp_lshd(mp_int * a, int b);
145 static void mp_set(mp_int * a, mp_digit b);
146 static void mp_clamp(mp_int * a);
147 static void mp_exch(mp_int * a, mp_int * b);
148 static void mp_rshd(mp_int * a, int b);
149 static void mp_zero(mp_int * a);
150 static int mp_mod_2d(mp_int * a, int b, mp_int * c);
151 static int mp_div_2d(mp_int * a, int b, mp_int * c, mp_int * d);
152 static int mp_init_copy(mp_int * a, mp_int * b);
153 static int mp_mul_2d(mp_int * a, int b, mp_int * c);
154 #ifndef LTM_NO_NEG_EXP
155 static int mp_div_2(mp_int * a, mp_int * b);
156 static int mp_invmod(mp_int * a, mp_int * b, mp_int * c);
157 static int mp_invmod_slow(mp_int * a, mp_int * b, mp_int * c);
158 #endif /* LTM_NO_NEG_EXP */
159 static int mp_copy(mp_int * a, mp_int * b);
160 static int mp_count_bits(mp_int * a);
161 static int mp_div(mp_int * a, mp_int * b, mp_int * c, mp_int * d);
162 static int mp_mod(mp_int * a, mp_int * b, mp_int * c);
163 static int mp_grow(mp_int * a, int size);
164 static int mp_cmp_mag(mp_int * a, mp_int * b);
165 #ifdef BN_MP_ABS_C
166 static int mp_abs(mp_int * a, mp_int * b);
167 #endif
168 static int mp_sqr(mp_int * a, mp_int * b);
169 static int mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d);
170 static int mp_reduce_2k_setup_l(mp_int *a, mp_int *d);
171 static int mp_2expt(mp_int * a, int b);
172 static int mp_reduce_setup(mp_int * a, mp_int * b);
173 static int mp_reduce(mp_int * x, mp_int * m, mp_int * mu);
174 static int mp_init_size(mp_int * a, int size);
175 #ifdef BN_MP_EXPTMOD_FAST_C
176 static int mp_exptmod_fast (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode);
177 #endif /* BN_MP_EXPTMOD_FAST_C */
178 #ifdef BN_FAST_S_MP_SQR_C
179 static int fast_s_mp_sqr (mp_int * a, mp_int * b);
180 #endif /* BN_FAST_S_MP_SQR_C */
181 #ifdef BN_MP_MUL_D_C
182 static int mp_mul_d (mp_int * a, mp_digit b, mp_int * c);
183 #endif /* BN_MP_MUL_D_C */
184 
185 
186 
187 /* functions from bn_<func name>.c */
188 
189 
/* reverse the first len bytes of s in place (used for radix code);
 * nothing happens when len is less than 2
 */
static void
bn_reverse (unsigned char *s, int len)
{
  int     head, tail;

  /* swap the outermost pair and walk both ends towards the middle */
  for (head = 0, tail = len - 1; head < tail; head++, tail--) {
    unsigned char saved = s[head];

    s[head] = s[tail];
    s[tail] = saved;
  }
}
207 
208 
209 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
210 static int
s_mp_add(mp_int * a,mp_int * b,mp_int * c)211 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
212 {
213   mp_int *x;
214   int     olduse, res, min, max;
215 
216   /* find sizes, we let |a| <= |b| which means we have to sort
217    * them.  "x" will point to the input with the most digits
218    */
219   if (a->used > b->used) {
220     min = b->used;
221     max = a->used;
222     x = a;
223   } else {
224     min = a->used;
225     max = b->used;
226     x = b;
227   }
228 
229   /* init result */
230   if (c->alloc < max + 1) {
231     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
232       return res;
233     }
234   }
235 
236   /* get old used digit count and set new one */
237   olduse = c->used;
238   c->used = max + 1;
239 
240   {
241     register mp_digit u, *tmpa, *tmpb, *tmpc;
242     register int i;
243 
244     /* alias for digit pointers */
245 
246     /* first input */
247     tmpa = a->dp;
248 
249     /* second input */
250     tmpb = b->dp;
251 
252     /* destination */
253     tmpc = c->dp;
254 
255     /* zero the carry */
256     u = 0;
257     for (i = 0; i < min; i++) {
258       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
259       *tmpc = *tmpa++ + *tmpb++ + u;
260 
261       /* U = carry bit of T[i] */
262       u = *tmpc >> ((mp_digit)DIGIT_BIT);
263 
264       /* take away carry bit from T[i] */
265       *tmpc++ &= MP_MASK;
266     }
267 
268     /* now copy higher words if any, that is in A+B
269      * if A or B has more digits add those in
270      */
271     if (min != max) {
272       for (; i < max; i++) {
273         /* T[i] = X[i] + U */
274         *tmpc = x->dp[i] + u;
275 
276         /* U = carry bit of T[i] */
277         u = *tmpc >> ((mp_digit)DIGIT_BIT);
278 
279         /* take away carry bit from T[i] */
280         *tmpc++ &= MP_MASK;
281       }
282     }
283 
284     /* add carry */
285     *tmpc++ = u;
286 
287     /* clear digits above oldused */
288     for (i = c->used; i < olduse; i++) {
289       *tmpc++ = 0;
290     }
291   }
292 
293   mp_clamp (c);
294   return MP_OKAY;
295 }
296 
297 
298 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
299 static int
s_mp_sub(mp_int * a,mp_int * b,mp_int * c)300 s_mp_sub (mp_int * a, mp_int * b, mp_int * c)
301 {
302   int     olduse, res, min, max;
303 
304   /* find sizes */
305   min = b->used;
306   max = a->used;
307 
308   /* init result */
309   if (c->alloc < max) {
310     if ((res = mp_grow (c, max)) != MP_OKAY) {
311       return res;
312     }
313   }
314   olduse = c->used;
315   c->used = max;
316 
317   {
318     register mp_digit u, *tmpa, *tmpb, *tmpc;
319     register int i;
320 
321     /* alias for digit pointers */
322     tmpa = a->dp;
323     tmpb = b->dp;
324     tmpc = c->dp;
325 
326     /* set carry to zero */
327     u = 0;
328     for (i = 0; i < min; i++) {
329       /* T[i] = A[i] - B[i] - U */
330       *tmpc = *tmpa++ - *tmpb++ - u;
331 
332       /* U = carry bit of T[i]
333        * Note this saves performing an AND operation since
334        * if a carry does occur it will propagate all the way to the
335        * MSB.  As a result a single shift is enough to get the carry
336        */
337       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
338 
339       /* Clear carry from T[i] */
340       *tmpc++ &= MP_MASK;
341     }
342 
343     /* now copy higher words if any, e.g. if A has more digits than B  */
344     for (; i < max; i++) {
345       /* T[i] = A[i] - U */
346       *tmpc = *tmpa++ - u;
347 
348       /* U = carry bit of T[i] */
349       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
350 
351       /* Clear carry from T[i] */
352       *tmpc++ &= MP_MASK;
353     }
354 
355     /* clear digits above used (since we may not have grown result above) */
356     for (i = c->used; i < olduse; i++) {
357       *tmpc++ = 0;
358     }
359   }
360 
361   mp_clamp (c);
362   return MP_OKAY;
363 }
364 
365 
366 /* init a new mp_int */
367 static int
mp_init(mp_int * a)368 mp_init (mp_int * a)
369 {
370   int i;
371 
372   /* allocate memory required and clear it */
373   a->dp = OPT_CAST(mp_digit) XMALLOC (sizeof (mp_digit) * MP_PREC);
374   if (a->dp == NULL) {
375     return MP_MEM;
376   }
377 
378   /* set the digits to zero */
379   for (i = 0; i < MP_PREC; i++) {
380       a->dp[i] = 0;
381   }
382 
383   /* set the used to zero, allocated digits to the default precision
384    * and sign to positive */
385   a->used  = 0;
386   a->alloc = MP_PREC;
387   a->sign  = MP_ZPOS;
388 
389   return MP_OKAY;
390 }
391 
392 
393 /* clear one (frees)  */
394 static void
mp_clear(mp_int * a)395 mp_clear (mp_int * a)
396 {
397   int i;
398 
399   /* only do anything if a hasn't been freed previously */
400   if (a->dp != NULL) {
401     /* first zero the digits */
402     for (i = 0; i < a->used; i++) {
403         a->dp[i] = 0;
404     }
405 
406     /* free ram */
407     XFREE(a->dp);
408 
409     /* reset members to make debugging easier */
410     a->dp    = NULL;
411     a->alloc = a->used = 0;
412     a->sign  = MP_ZPOS;
413   }
414 }
415 
416 
417 /* high level addition (handles signs) */
418 static int
mp_add(mp_int * a,mp_int * b,mp_int * c)419 mp_add (mp_int * a, mp_int * b, mp_int * c)
420 {
421   int     sa, sb, res;
422 
423   /* get sign of both inputs */
424   sa = a->sign;
425   sb = b->sign;
426 
427   /* handle two cases, not four */
428   if (sa == sb) {
429     /* both positive or both negative */
430     /* add their magnitudes, copy the sign */
431     c->sign = sa;
432     res = s_mp_add (a, b, c);
433   } else {
434     /* one positive, the other negative */
435     /* subtract the one with the greater magnitude from */
436     /* the one of the lesser magnitude.  The result gets */
437     /* the sign of the one with the greater magnitude. */
438     if (mp_cmp_mag (a, b) == MP_LT) {
439       c->sign = sb;
440       res = s_mp_sub (b, a, c);
441     } else {
442       c->sign = sa;
443       res = s_mp_sub (a, b, c);
444     }
445   }
446   return res;
447 }
448 
449 
450 /* high level subtraction (handles signs) */
451 static int
mp_sub(mp_int * a,mp_int * b,mp_int * c)452 mp_sub (mp_int * a, mp_int * b, mp_int * c)
453 {
454   int     sa, sb, res;
455 
456   sa = a->sign;
457   sb = b->sign;
458 
459   if (sa != sb) {
460     /* subtract a negative from a positive, OR */
461     /* subtract a positive from a negative. */
462     /* In either case, ADD their magnitudes, */
463     /* and use the sign of the first number. */
464     c->sign = sa;
465     res = s_mp_add (a, b, c);
466   } else {
467     /* subtract a positive from a positive, OR */
468     /* subtract a negative from a negative. */
469     /* First, take the difference between their */
470     /* magnitudes, then... */
471     if (mp_cmp_mag (a, b) != MP_LT) {
472       /* Copy the sign from the first */
473       c->sign = sa;
474       /* The first has a larger or equal magnitude */
475       res = s_mp_sub (a, b, c);
476     } else {
477       /* The result has the *opposite* sign from */
478       /* the first number. */
479       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
480       /* The second has a larger magnitude */
481       res = s_mp_sub (b, a, c);
482     }
483   }
484   return res;
485 }
486 
487 
488 /* high level multiplication (handles sign) */
489 static int
mp_mul(mp_int * a,mp_int * b,mp_int * c)490 mp_mul (mp_int * a, mp_int * b, mp_int * c)
491 {
492   int     res, neg;
493   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
494 
495   /* use Toom-Cook? */
496 #ifdef BN_MP_TOOM_MUL_C
497   if (MIN (a->used, b->used) >= TOOM_MUL_CUTOFF) {
498     res = mp_toom_mul(a, b, c);
499   } else
500 #endif
501 #ifdef BN_MP_KARATSUBA_MUL_C
502   /* use Karatsuba? */
503   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
504     res = mp_karatsuba_mul (a, b, c);
505   } else
506 #endif
507   {
508     /* can we use the fast multiplier?
509      *
510      * The fast multiplier can be used if the output will
511      * have less than MP_WARRAY digits and the number of
512      * digits won't affect carry propagation
513      */
514 #ifdef BN_FAST_S_MP_MUL_DIGS_C
515     int     digs = a->used + b->used + 1;
516 
517     if ((digs < MP_WARRAY) &&
518         MIN(a->used, b->used) <=
519         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
520       res = fast_s_mp_mul_digs (a, b, c, digs);
521     } else
522 #endif
523 #ifdef BN_S_MP_MUL_DIGS_C
524       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
525 #else
526 #error mp_mul could fail
527       res = MP_VAL;
528 #endif
529 
530   }
531   c->sign = (c->used > 0) ? neg : MP_ZPOS;
532   return res;
533 }
534 
535 
536 /* d = a * b (mod c) */
537 static int
mp_mulmod(mp_int * a,mp_int * b,mp_int * c,mp_int * d)538 mp_mulmod (mp_int * a, mp_int * b, mp_int * c, mp_int * d)
539 {
540   int     res;
541   mp_int  t;
542 
543   if ((res = mp_init (&t)) != MP_OKAY) {
544     return res;
545   }
546 
547   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
548     mp_clear (&t);
549     return res;
550   }
551   res = mp_mod (&t, c, d);
552   mp_clear (&t);
553   return res;
554 }
555 
556 
557 /* c = a mod b, 0 <= c < b */
558 static int
mp_mod(mp_int * a,mp_int * b,mp_int * c)559 mp_mod (mp_int * a, mp_int * b, mp_int * c)
560 {
561   mp_int  t;
562   int     res;
563 
564   if ((res = mp_init (&t)) != MP_OKAY) {
565     return res;
566   }
567 
568   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
569     mp_clear (&t);
570     return res;
571   }
572 
573   if (t.sign != b->sign) {
574     res = mp_add (b, &t, c);
575   } else {
576     res = MP_OKAY;
577     mp_exch (&t, c);
578   }
579 
580   mp_clear (&t);
581   return res;
582 }
583 
584 
585 /* this is a shell function that calls either the normal or Montgomery
586  * exptmod functions.  Originally the call to the montgomery code was
587  * embedded in the normal function but that wasted a lot of stack space
588  * for nothing (since 99% of the time the Montgomery code would be called)
589  */
590 static int
mp_exptmod(mp_int * G,mp_int * X,mp_int * P,mp_int * Y)591 mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y)
592 {
593 #if defined(BN_MP_DR_IS_MODULUS_C)||defined(BN_MP_REDUCE_IS_2K_C)||defined(BN_MP_EXPTMOD_FAST_C)
594   int dr = 0;
595 #endif
596 
597   /* modulus P must be positive */
598   if (P->sign == MP_NEG) {
599      return MP_VAL;
600   }
601 
602   /* if exponent X is negative we have to recurse */
603   if (X->sign == MP_NEG) {
604 #ifdef LTM_NO_NEG_EXP
605         return MP_VAL;
606 #else /* LTM_NO_NEG_EXP */
607 #ifdef BN_MP_INVMOD_C
608      mp_int tmpG, tmpX;
609      int err;
610 
611      /* first compute 1/G mod P */
612      if ((err = mp_init(&tmpG)) != MP_OKAY) {
613         return err;
614      }
615      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
616         mp_clear(&tmpG);
617         return err;
618      }
619 
620      /* now get |X| */
621      if ((err = mp_init(&tmpX)) != MP_OKAY) {
622         mp_clear(&tmpG);
623         return err;
624      }
625      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
626         mp_clear_multi(&tmpG, &tmpX, NULL);
627         return err;
628      }
629 
630      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
631      err = mp_exptmod(&tmpG, &tmpX, P, Y);
632      mp_clear_multi(&tmpG, &tmpX, NULL);
633      return err;
634 #else
635 #error mp_exptmod would always fail
636      /* no invmod */
637      return MP_VAL;
638 #endif
639 #endif /* LTM_NO_NEG_EXP */
640   }
641 
642 /* modified diminished radix reduction */
643 #if defined(BN_MP_REDUCE_IS_2K_L_C) && defined(BN_MP_REDUCE_2K_L_C) && defined(BN_S_MP_EXPTMOD_C)
644   if (mp_reduce_is_2k_l(P) == MP_YES) {
645      return s_mp_exptmod(G, X, P, Y, 1);
646   }
647 #endif
648 
649 #ifdef BN_MP_DR_IS_MODULUS_C
650   /* is it a DR modulus? */
651   dr = mp_dr_is_modulus(P);
652 #endif
653 
654 #ifdef BN_MP_REDUCE_IS_2K_C
655   /* if not, is it a unrestricted DR modulus? */
656   if (dr == 0) {
657      dr = mp_reduce_is_2k(P) << 1;
658   }
659 #endif
660 
661   /* if the modulus is odd or dr != 0 use the montgomery method */
662 #ifdef BN_MP_EXPTMOD_FAST_C
663   if (mp_isodd (P) == 1 || dr !=  0) {
664     return mp_exptmod_fast (G, X, P, Y, dr);
665   } else {
666 #endif
667 #ifdef BN_S_MP_EXPTMOD_C
668     /* otherwise use the generic Barrett reduction technique */
669     return s_mp_exptmod (G, X, P, Y, 0);
670 #else
671 #error mp_exptmod could fail
672     /* no exptmod for evens */
673     return MP_VAL;
674 #endif
675 #ifdef BN_MP_EXPTMOD_FAST_C
676   }
677 #endif
678 }
679 
680 
681 /* compare two ints (signed)*/
682 static int
mp_cmp(mp_int * a,mp_int * b)683 mp_cmp (mp_int * a, mp_int * b)
684 {
685   /* compare based on sign */
686   if (a->sign != b->sign) {
687      if (a->sign == MP_NEG) {
688         return MP_LT;
689      } else {
690         return MP_GT;
691      }
692   }
693 
694   /* compare digits */
695   if (a->sign == MP_NEG) {
696      /* if negative compare opposite direction */
697      return mp_cmp_mag(b, a);
698   } else {
699      return mp_cmp_mag(a, b);
700   }
701 }
702 
703 
704 /* compare a digit */
705 static int
mp_cmp_d(mp_int * a,mp_digit b)706 mp_cmp_d(mp_int * a, mp_digit b)
707 {
708   /* compare based on sign */
709   if (a->sign == MP_NEG) {
710     return MP_LT;
711   }
712 
713   /* compare based on magnitude */
714   if (a->used > 1) {
715     return MP_GT;
716   }
717 
718   /* compare the only digit of a to b */
719   if (a->dp[0] > b) {
720     return MP_GT;
721   } else if (a->dp[0] < b) {
722     return MP_LT;
723   } else {
724     return MP_EQ;
725   }
726 }
727 
728 
729 #ifndef LTM_NO_NEG_EXP
730 /* hac 14.61, pp608 */
731 static int
mp_invmod(mp_int * a,mp_int * b,mp_int * c)732 mp_invmod (mp_int * a, mp_int * b, mp_int * c)
733 {
734   /* b cannot be negative */
735   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
736     return MP_VAL;
737   }
738 
739 #ifdef BN_FAST_MP_INVMOD_C
740   /* if the modulus is odd we can use a faster routine instead */
741   if (mp_isodd (b) == 1) {
742     return fast_mp_invmod (a, b, c);
743   }
744 #endif
745 
746 #ifdef BN_MP_INVMOD_SLOW_C
747   return mp_invmod_slow(a, b, c);
748 #endif
749 
750 #ifndef BN_FAST_MP_INVMOD_C
751 #ifndef BN_MP_INVMOD_SLOW_C
752 #error mp_invmod would always fail
753 #endif
754 #endif
755   return MP_VAL;
756 }
757 #endif /* LTM_NO_NEG_EXP */
758 
759 
760 /* get the size for an unsigned equivalent */
761 static int
mp_unsigned_bin_size(mp_int * a)762 mp_unsigned_bin_size (mp_int * a)
763 {
764   int     size = mp_count_bits (a);
765   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
766 }
767 
768 
769 #ifndef LTM_NO_NEG_EXP
770 /* hac 14.61, pp608 */
771 static int
mp_invmod_slow(mp_int * a,mp_int * b,mp_int * c)772 mp_invmod_slow (mp_int * a, mp_int * b, mp_int * c)
773 {
774   mp_int  x, y, u, v, A, B, C, D;
775   int     res;
776 
777   /* b cannot be negative */
778   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
779     return MP_VAL;
780   }
781 
782   /* init temps */
783   if ((res = mp_init_multi(&x, &y, &u, &v,
784                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
785      return res;
786   }
787 
788   /* x = a, y = b */
789   if ((res = mp_mod(a, b, &x)) != MP_OKAY) {
790       goto LBL_ERR;
791   }
792   if ((res = mp_copy (b, &y)) != MP_OKAY) {
793     goto LBL_ERR;
794   }
795 
796   /* 2. [modified] if x,y are both even then return an error! */
797   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
798     res = MP_VAL;
799     goto LBL_ERR;
800   }
801 
802   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
803   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
804     goto LBL_ERR;
805   }
806   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
807     goto LBL_ERR;
808   }
809   mp_set (&A, 1);
810   mp_set (&D, 1);
811 
812 top:
813   /* 4.  while u is even do */
814   while (mp_iseven (&u) == 1) {
815     /* 4.1 u = u/2 */
816     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
817       goto LBL_ERR;
818     }
819     /* 4.2 if A or B is odd then */
820     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
821       /* A = (A+y)/2, B = (B-x)/2 */
822       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
823          goto LBL_ERR;
824       }
825       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
826          goto LBL_ERR;
827       }
828     }
829     /* A = A/2, B = B/2 */
830     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
831       goto LBL_ERR;
832     }
833     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
834       goto LBL_ERR;
835     }
836   }
837 
838   /* 5.  while v is even do */
839   while (mp_iseven (&v) == 1) {
840     /* 5.1 v = v/2 */
841     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
842       goto LBL_ERR;
843     }
844     /* 5.2 if C or D is odd then */
845     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
846       /* C = (C+y)/2, D = (D-x)/2 */
847       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
848          goto LBL_ERR;
849       }
850       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
851          goto LBL_ERR;
852       }
853     }
854     /* C = C/2, D = D/2 */
855     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
856       goto LBL_ERR;
857     }
858     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
859       goto LBL_ERR;
860     }
861   }
862 
863   /* 6.  if u >= v then */
864   if (mp_cmp (&u, &v) != MP_LT) {
865     /* u = u - v, A = A - C, B = B - D */
866     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
867       goto LBL_ERR;
868     }
869 
870     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
871       goto LBL_ERR;
872     }
873 
874     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
875       goto LBL_ERR;
876     }
877   } else {
878     /* v - v - u, C = C - A, D = D - B */
879     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
880       goto LBL_ERR;
881     }
882 
883     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
884       goto LBL_ERR;
885     }
886 
887     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
888       goto LBL_ERR;
889     }
890   }
891 
892   /* if not zero goto step 4 */
893   if (mp_iszero (&u) == 0)
894     goto top;
895 
896   /* now a = C, b = D, gcd == g*v */
897 
898   /* if v != 1 then there is no inverse */
899   if (mp_cmp_d (&v, 1) != MP_EQ) {
900     res = MP_VAL;
901     goto LBL_ERR;
902   }
903 
904   /* if its too low */
905   while (mp_cmp_d(&C, 0) == MP_LT) {
906       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
907          goto LBL_ERR;
908       }
909   }
910 
911   /* too big */
912   while (mp_cmp_mag(&C, b) != MP_LT) {
913       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
914          goto LBL_ERR;
915       }
916   }
917 
918   /* C is now the inverse */
919   mp_exch (&C, c);
920   res = MP_OKAY;
921 LBL_ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
922   return res;
923 }
924 #endif /* LTM_NO_NEG_EXP */
925 
926 
927 /* compare maginitude of two ints (unsigned) */
928 static int
mp_cmp_mag(mp_int * a,mp_int * b)929 mp_cmp_mag (mp_int * a, mp_int * b)
930 {
931   int     n;
932   mp_digit *tmpa, *tmpb;
933 
934   /* compare based on # of non-zero digits */
935   if (a->used > b->used) {
936     return MP_GT;
937   }
938 
939   if (a->used < b->used) {
940     return MP_LT;
941   }
942 
943   /* alias for a */
944   tmpa = a->dp + (a->used - 1);
945 
946   /* alias for b */
947   tmpb = b->dp + (a->used - 1);
948 
949   /* compare based on digits  */
950   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
951     if (*tmpa > *tmpb) {
952       return MP_GT;
953     }
954 
955     if (*tmpa < *tmpb) {
956       return MP_LT;
957     }
958   }
959   return MP_EQ;
960 }
961 
962 
963 /* reads a unsigned char array, assumes the msb is stored first [big endian] */
964 static int
mp_read_unsigned_bin(mp_int * a,const unsigned char * b,int c)965 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
966 {
967   int     res;
968 
969   /* make sure there are at least two digits */
970   if (a->alloc < 2) {
971      if ((res = mp_grow(a, 2)) != MP_OKAY) {
972         return res;
973      }
974   }
975 
976   /* zero the int */
977   mp_zero (a);
978 
979   /* read the bytes in */
980   while (c-- > 0) {
981     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
982       return res;
983     }
984 
985 #ifndef MP_8BIT
986       a->dp[0] |= *b++;
987       a->used += 1;
988 #else
989       a->dp[0] = (*b & MP_MASK);
990       a->dp[1] |= ((*b++ >> 7U) & 1);
991       a->used += 2;
992 #endif
993   }
994   mp_clamp (a);
995   return MP_OKAY;
996 }
997 
998 
999 /* store in unsigned [big endian] format */
1000 static int
mp_to_unsigned_bin(mp_int * a,unsigned char * b)1001 mp_to_unsigned_bin (mp_int * a, unsigned char *b)
1002 {
1003   int     x, res;
1004   mp_int  t;
1005 
1006   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
1007     return res;
1008   }
1009 
1010   x = 0;
1011   while (mp_iszero (&t) == 0) {
1012 #ifndef MP_8BIT
1013       b[x++] = (unsigned char) (t.dp[0] & 255);
1014 #else
1015       b[x++] = (unsigned char) (t.dp[0] | ((t.dp[1] & 0x01) << 7));
1016 #endif
1017     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
1018       mp_clear (&t);
1019       return res;
1020     }
1021   }
1022   bn_reverse (b, x);
1023   mp_clear (&t);
1024   return MP_OKAY;
1025 }
1026 
1027 
1028 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1029 static int
mp_div_2d(mp_int * a,int b,mp_int * c,mp_int * d)1030 mp_div_2d (mp_int * a, int b, mp_int * c, mp_int * d)
1031 {
1032   mp_digit D, r, rr;
1033   int     x, res;
1034   mp_int  t;
1035 
1036 
1037   /* if the shift count is <= 0 then we do no work */
1038   if (b <= 0) {
1039     res = mp_copy (a, c);
1040     if (d != NULL) {
1041       mp_zero (d);
1042     }
1043     return res;
1044   }
1045 
1046   if ((res = mp_init (&t)) != MP_OKAY) {
1047     return res;
1048   }
1049 
1050   /* get the remainder */
1051   if (d != NULL) {
1052     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1053       mp_clear (&t);
1054       return res;
1055     }
1056   }
1057 
1058   /* copy */
1059   if ((res = mp_copy (a, c)) != MP_OKAY) {
1060     mp_clear (&t);
1061     return res;
1062   }
1063 
1064   /* shift by as many digits in the bit count */
1065   if (b >= (int)DIGIT_BIT) {
1066     mp_rshd (c, b / DIGIT_BIT);
1067   }
1068 
1069   /* shift any bit count < DIGIT_BIT */
1070   D = (mp_digit) (b % DIGIT_BIT);
1071   if (D != 0) {
1072     register mp_digit *tmpc, mask, shift;
1073 
1074     /* mask */
1075     mask = (((mp_digit)1) << D) - 1;
1076 
1077     /* shift for lsb */
1078     shift = DIGIT_BIT - D;
1079 
1080     /* alias */
1081     tmpc = c->dp + (c->used - 1);
1082 
1083     /* carry */
1084     r = 0;
1085     for (x = c->used - 1; x >= 0; x--) {
1086       /* get the lower  bits of this word in a temp */
1087       rr = *tmpc & mask;
1088 
1089       /* shift the current word and mix in the carry bits from the previous word */
1090       *tmpc = (*tmpc >> D) | (r << shift);
1091       --tmpc;
1092 
1093       /* set the carry to the carry bits of the current word found above */
1094       r = rr;
1095     }
1096   }
1097   mp_clamp (c);
1098   if (d != NULL) {
1099     mp_exch (&t, d);
1100   }
1101   mp_clear (&t);
1102   return MP_OKAY;
1103 }
1104 
1105 
1106 static int
mp_init_copy(mp_int * a,mp_int * b)1107 mp_init_copy (mp_int * a, mp_int * b)
1108 {
1109   int     res;
1110 
1111   if ((res = mp_init (a)) != MP_OKAY) {
1112     return res;
1113   }
1114   return mp_copy (b, a);
1115 }
1116 
1117 
1118 /* set to zero */
1119 static void
mp_zero(mp_int * a)1120 mp_zero (mp_int * a)
1121 {
1122   int       n;
1123   mp_digit *tmp;
1124 
1125   a->sign = MP_ZPOS;
1126   a->used = 0;
1127 
1128   tmp = a->dp;
1129   for (n = 0; n < a->alloc; n++) {
1130      *tmp++ = 0;
1131   }
1132 }
1133 
1134 
1135 /* copy, b = a */
1136 static int
mp_copy(mp_int * a,mp_int * b)1137 mp_copy (mp_int * a, mp_int * b)
1138 {
1139   int     res, n;
1140 
1141   /* if dst == src do nothing */
1142   if (a == b) {
1143     return MP_OKAY;
1144   }
1145 
1146   /* grow dest */
1147   if (b->alloc < a->used) {
1148      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1149         return res;
1150      }
1151   }
1152 
1153   /* zero b and copy the parameters over */
1154   {
1155     register mp_digit *tmpa, *tmpb;
1156 
1157     /* pointer aliases */
1158 
1159     /* source */
1160     tmpa = a->dp;
1161 
1162     /* destination */
1163     tmpb = b->dp;
1164 
1165     /* copy all the digits */
1166     for (n = 0; n < a->used; n++) {
1167       *tmpb++ = *tmpa++;
1168     }
1169 
1170     /* clear high digits */
1171     for (; n < b->used; n++) {
1172       *tmpb++ = 0;
1173     }
1174   }
1175 
1176   /* copy used count and sign */
1177   b->used = a->used;
1178   b->sign = a->sign;
1179   return MP_OKAY;
1180 }
1181 
1182 
1183 /* shift right a certain amount of digits */
1184 static void
mp_rshd(mp_int * a,int b)1185 mp_rshd (mp_int * a, int b)
1186 {
1187   int     x;
1188 
1189   /* if b <= 0 then ignore it */
1190   if (b <= 0) {
1191     return;
1192   }
1193 
1194   /* if b > used then simply zero it and return */
1195   if (a->used <= b) {
1196     mp_zero (a);
1197     return;
1198   }
1199 
1200   {
1201     register mp_digit *bottom, *top;
1202 
1203     /* shift the digits down */
1204 
1205     /* bottom */
1206     bottom = a->dp;
1207 
1208     /* top [offset into digits] */
1209     top = a->dp + b;
1210 
1211     /* this is implemented as a sliding window where
1212      * the window is b-digits long and digits from
1213      * the top of the window are copied to the bottom
1214      *
1215      * e.g.
1216 
1217      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
1218                  /\                   |      ---->
1219                   \-------------------/      ---->
1220      */
1221     for (x = 0; x < (a->used - b); x++) {
1222       *bottom++ = *top++;
1223     }
1224 
1225     /* zero the top digits */
1226     for (; x < a->used; x++) {
1227       *bottom++ = 0;
1228     }
1229   }
1230 
1231   /* remove excess digits */
1232   a->used -= b;
1233 }
1234 
1235 
1236 /* swap the elements of two integers, for cases where you can't simply swap the
1237  * mp_int pointers around
1238  */
1239 static void
mp_exch(mp_int * a,mp_int * b)1240 mp_exch (mp_int * a, mp_int * b)
1241 {
1242   mp_int  t;
1243 
1244   t  = *a;
1245   *a = *b;
1246   *b = t;
1247 }
1248 
1249 
1250 /* trim unused digits
1251  *
1252  * This is used to ensure that leading zero digits are
1253  * trimed and the leading "used" digit will be non-zero
1254  * Typically very fast.  Also fixes the sign if there
1255  * are no more leading digits
1256  */
1257 static void
mp_clamp(mp_int * a)1258 mp_clamp (mp_int * a)
1259 {
1260   /* decrease used while the most significant digit is
1261    * zero.
1262    */
1263   while (a->used > 0 && a->dp[a->used - 1] == 0) {
1264     --(a->used);
1265   }
1266 
1267   /* reset the sign flag if used == 0 */
1268   if (a->used == 0) {
1269     a->sign = MP_ZPOS;
1270   }
1271 }
1272 
1273 
1274 /* grow as required */
1275 static int
mp_grow(mp_int * a,int size)1276 mp_grow (mp_int * a, int size)
1277 {
1278   int     i;
1279   mp_digit *tmp;
1280 
1281   /* if the alloc size is smaller alloc more ram */
1282   if (a->alloc < size) {
1283     /* ensure there are always at least MP_PREC digits extra on top */
1284     size += (MP_PREC * 2) - (size % MP_PREC);
1285 
1286     /* reallocate the array a->dp
1287      *
1288      * We store the return in a temporary variable
1289      * in case the operation failed we don't want
1290      * to overwrite the dp member of a.
1291      */
1292     tmp = OPT_CAST(mp_digit) XREALLOC (a->dp, sizeof (mp_digit) * size);
1293     if (tmp == NULL) {
1294       /* reallocation failed but "a" is still valid [can be freed] */
1295       return MP_MEM;
1296     }
1297 
1298     /* reallocation succeeded so set a->dp */
1299     a->dp = tmp;
1300 
1301     /* zero excess digits */
1302     i        = a->alloc;
1303     a->alloc = size;
1304     for (; i < a->alloc; i++) {
1305       a->dp[i] = 0;
1306     }
1307   }
1308   return MP_OKAY;
1309 }
1310 
1311 
#ifdef BN_MP_ABS_C
/* b = |a|
 *
 * Copies "a" into "b" (skipped when both are the same object) and then
 * forces the sign of "b" to MP_ZPOS.  Returns MP_OKAY, or the error
 * reported by mp_copy.
 */
static int
mp_abs (mp_int * a, mp_int * b)
{
  if (a != b) {
    int     err = mp_copy (a, b);

    if (err != MP_OKAY) {
      return err;
    }
  }

  /* the magnitude is already in place, only the sign needs fixing */
  b->sign = MP_ZPOS;
  return MP_OKAY;
}
#endif
1335 
1336 
1337 /* set to a digit */
1338 static void
mp_set(mp_int * a,mp_digit b)1339 mp_set (mp_int * a, mp_digit b)
1340 {
1341   mp_zero (a);
1342   a->dp[0] = b & MP_MASK;
1343   a->used  = (a->dp[0] != 0) ? 1 : 0;
1344 }
1345 
1346 
1347 #ifndef LTM_NO_NEG_EXP
1348 /* b = a/2 */
1349 static int
mp_div_2(mp_int * a,mp_int * b)1350 mp_div_2(mp_int * a, mp_int * b)
1351 {
1352   int     x, res, oldused;
1353 
1354   /* copy */
1355   if (b->alloc < a->used) {
1356     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1357       return res;
1358     }
1359   }
1360 
1361   oldused = b->used;
1362   b->used = a->used;
1363   {
1364     register mp_digit r, rr, *tmpa, *tmpb;
1365 
1366     /* source alias */
1367     tmpa = a->dp + b->used - 1;
1368 
1369     /* dest alias */
1370     tmpb = b->dp + b->used - 1;
1371 
1372     /* carry */
1373     r = 0;
1374     for (x = b->used - 1; x >= 0; x--) {
1375       /* get the carry for the next iteration */
1376       rr = *tmpa & 1;
1377 
1378       /* shift the current digit, add in carry and store */
1379       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
1380 
1381       /* forward carry to next iteration */
1382       r = rr;
1383     }
1384 
1385     /* zero excess digits */
1386     tmpb = b->dp + b->used;
1387     for (x = b->used; x < oldused; x++) {
1388       *tmpb++ = 0;
1389     }
1390   }
1391   b->sign = a->sign;
1392   mp_clamp (b);
1393   return MP_OKAY;
1394 }
1395 #endif /* LTM_NO_NEG_EXP */
1396 
1397 
1398 /* shift left by a certain bit count */
1399 static int
mp_mul_2d(mp_int * a,int b,mp_int * c)1400 mp_mul_2d (mp_int * a, int b, mp_int * c)
1401 {
1402   mp_digit d;
1403   int      res;
1404 
1405   /* copy */
1406   if (a != c) {
1407      if ((res = mp_copy (a, c)) != MP_OKAY) {
1408        return res;
1409      }
1410   }
1411 
1412   if (c->alloc < (int)(c->used + b/DIGIT_BIT + 1)) {
1413      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1414        return res;
1415      }
1416   }
1417 
1418   /* shift by as many digits in the bit count */
1419   if (b >= (int)DIGIT_BIT) {
1420     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1421       return res;
1422     }
1423   }
1424 
1425   /* shift any bit count < DIGIT_BIT */
1426   d = (mp_digit) (b % DIGIT_BIT);
1427   if (d != 0) {
1428     register mp_digit *tmpc, shift, mask, r, rr;
1429     register int x;
1430 
1431     /* bitmask for carries */
1432     mask = (((mp_digit)1) << d) - 1;
1433 
1434     /* shift for msbs */
1435     shift = DIGIT_BIT - d;
1436 
1437     /* alias */
1438     tmpc = c->dp;
1439 
1440     /* carry */
1441     r    = 0;
1442     for (x = 0; x < c->used; x++) {
1443       /* get the higher bits of the current word */
1444       rr = (*tmpc >> shift) & mask;
1445 
1446       /* shift the current word and OR in the carry */
1447       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1448       ++tmpc;
1449 
1450       /* set the carry to the carry bits of the current word */
1451       r = rr;
1452     }
1453 
1454     /* set final carry */
1455     if (r != 0) {
1456        c->dp[(c->used)++] = r;
1457     }
1458   }
1459   mp_clamp (c);
1460   return MP_OKAY;
1461 }
1462 
1463 
#ifdef BN_MP_INIT_MULTI_C
/* Initialise every mp_int in a NULL-terminated argument list.
 *
 * mp:  first mp_int to initialise
 * ...: further mp_int pointers, terminated by NULL
 *
 * Returns MP_OKAY when every mp_init call succeeded.  If one fails, the
 * entries initialised so far are passed to mp_clear and MP_MEM is returned.
 */
static int
mp_init_multi(mp_int *mp, ...)
{
    mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
    int n = 0;                 /* Number of ok inits */
    mp_int* cur_arg = mp;
    va_list args;

    va_start(args, mp);        /* init args to next argument from caller */
    while (cur_arg != NULL) {
        if (mp_init(cur_arg) != MP_OKAY) {
            /* Oops - error! Back-track and mp_clear what we already
               succeeded in init-ing, then return error.
            */
            va_list clean_args;

            /* "args" is intentionally not ended here: the single va_end
               after the loop closes it on every path.  Ending it here as
               well would invoke va_end twice for one va_start, which is
               undefined behaviour.
            */

            /* walk the list a second time with a fresh va_list */
            cur_arg = mp;
            va_start(clean_args, mp);
            while (n--) {
                mp_clear(cur_arg);
                cur_arg = va_arg(clean_args, mp_int*);
            }
            va_end(clean_args);
            res = MP_MEM;
            break;
        }
        n++;
        cur_arg = va_arg(args, mp_int*);
    }
    va_end(args);
    return res;                /* Assumed ok, if error flagged above. */
}
#endif
1502 
1503 
#ifdef BN_MP_CLEAR_MULTI_C
/* Call mp_clear on every mp_int in a NULL-terminated argument list,
 * starting with "mp" itself.
 */
static void
mp_clear_multi(mp_int *mp, ...)
{
    va_list rest;
    mp_int *cur;

    va_start(rest, mp);
    for (cur = mp; cur != NULL; cur = va_arg(rest, mp_int*)) {
        mp_clear(cur);
    }
    va_end(rest);
}
#endif
1518 
1519 
1520 /* shift left a certain amount of digits */
1521 static int
mp_lshd(mp_int * a,int b)1522 mp_lshd (mp_int * a, int b)
1523 {
1524   int     x, res;
1525 
1526   /* if its less than zero return */
1527   if (b <= 0) {
1528     return MP_OKAY;
1529   }
1530 
1531   /* grow to fit the new digits */
1532   if (a->alloc < a->used + b) {
1533      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1534        return res;
1535      }
1536   }
1537 
1538   {
1539     register mp_digit *top, *bottom;
1540 
1541     /* increment the used by the shift amount then copy upwards */
1542     a->used += b;
1543 
1544     /* top */
1545     top = a->dp + a->used - 1;
1546 
1547     /* base */
1548     bottom = a->dp + a->used - 1 - b;
1549 
1550     /* much like mp_rshd this is implemented using a sliding window
1551      * except the window goes the otherway around.  Copying from
1552      * the bottom to the top.  see bn_mp_rshd.c for more info.
1553      */
1554     for (x = a->used - 1; x >= b; x--) {
1555       *top-- = *bottom--;
1556     }
1557 
1558     /* zero the lower digits */
1559     top = a->dp;
1560     for (x = 0; x < b; x++) {
1561       *top++ = 0;
1562     }
1563   }
1564   return MP_OKAY;
1565 }
1566 
1567 
1568 /* returns the number of bits in an int */
1569 static int
mp_count_bits(mp_int * a)1570 mp_count_bits (mp_int * a)
1571 {
1572   int     r;
1573   mp_digit q;
1574 
1575   /* shortcut */
1576   if (a->used == 0) {
1577     return 0;
1578   }
1579 
1580   /* get number of digits and add that */
1581   r = (a->used - 1) * DIGIT_BIT;
1582 
1583   /* take the last digit and count the bits in it */
1584   q = a->dp[a->used - 1];
1585   while (q > ((mp_digit) 0)) {
1586     ++r;
1587     q >>= ((mp_digit) 1);
1588   }
1589   return r;
1590 }
1591 
1592 
1593 /* calc a value mod 2**b */
1594 static int
mp_mod_2d(mp_int * a,int b,mp_int * c)1595 mp_mod_2d (mp_int * a, int b, mp_int * c)
1596 {
1597   int     x, res;
1598 
1599   /* if b is <= 0 then zero the int */
1600   if (b <= 0) {
1601     mp_zero (c);
1602     return MP_OKAY;
1603   }
1604 
1605   /* if the modulus is larger than the value than return */
1606   if (b >= (int) (a->used * DIGIT_BIT)) {
1607     res = mp_copy (a, c);
1608     return res;
1609   }
1610 
1611   /* copy */
1612   if ((res = mp_copy (a, c)) != MP_OKAY) {
1613     return res;
1614   }
1615 
1616   /* zero digits above the last digit of the modulus */
1617   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1618     c->dp[x] = 0;
1619   }
1620   /* clear the digit that is not completely outside/inside the modulus */
1621   c->dp[b / DIGIT_BIT] &=
1622     (mp_digit) ((((mp_digit) 1) << (((mp_digit) b) % DIGIT_BIT)) - ((mp_digit) 1));
1623   mp_clamp (c);
1624   return MP_OKAY;
1625 }
1626 
1627 
1628 #ifdef BN_MP_DIV_SMALL
1629 
1630 /* slower bit-bang division... also smaller */
1631 static int
mp_div(mp_int * a,mp_int * b,mp_int * c,mp_int * d)1632 mp_div(mp_int * a, mp_int * b, mp_int * c, mp_int * d)
1633 {
1634    mp_int ta, tb, tq, q;
1635    int    res, n, n2;
1636 
1637   /* is divisor zero ? */
1638   if (mp_iszero (b) == 1) {
1639     return MP_VAL;
1640   }
1641 
1642   /* if a < b then q=0, r = a */
1643   if (mp_cmp_mag (a, b) == MP_LT) {
1644     if (d != NULL) {
1645       res = mp_copy (a, d);
1646     } else {
1647       res = MP_OKAY;
1648     }
1649     if (c != NULL) {
1650       mp_zero (c);
1651     }
1652     return res;
1653   }
1654 
1655   /* init our temps */
1656   if ((res = mp_init_multi(&ta, &tb, &tq, &q, NULL) != MP_OKAY)) {
1657      return res;
1658   }
1659 
1660 
1661   mp_set(&tq, 1);
1662   n = mp_count_bits(a) - mp_count_bits(b);
1663   if (((res = mp_abs(a, &ta)) != MP_OKAY) ||
1664       ((res = mp_abs(b, &tb)) != MP_OKAY) ||
1665       ((res = mp_mul_2d(&tb, n, &tb)) != MP_OKAY) ||
1666       ((res = mp_mul_2d(&tq, n, &tq)) != MP_OKAY)) {
1667       goto LBL_ERR;
1668   }
1669 
1670   while (n-- >= 0) {
1671      if (mp_cmp(&tb, &ta) != MP_GT) {
1672         if (((res = mp_sub(&ta, &tb, &ta)) != MP_OKAY) ||
1673             ((res = mp_add(&q, &tq, &q)) != MP_OKAY)) {
1674            goto LBL_ERR;
1675         }
1676      }
1677      if (((res = mp_div_2d(&tb, 1, &tb, NULL)) != MP_OKAY) ||
1678          ((res = mp_div_2d(&tq, 1, &tq, NULL)) != MP_OKAY)) {
1679            goto LBL_ERR;
1680      }
1681   }
1682 
1683   /* now q == quotient and ta == remainder */
1684   n  = a->sign;
1685   n2 = (a->sign == b->sign ? MP_ZPOS : MP_NEG);
1686   if (c != NULL) {
1687      mp_exch(c, &q);
1688      c->sign  = (mp_iszero(c) == MP_YES) ? MP_ZPOS : n2;
1689   }
1690   if (d != NULL) {
1691      mp_exch(d, &ta);
1692      d->sign = (mp_iszero(d) == MP_YES) ? MP_ZPOS : n;
1693   }
1694 LBL_ERR:
1695    mp_clear_multi(&ta, &tb, &tq, &q, NULL);
1696    return res;
1697 }
1698 
1699 #else
1700 
1701 /* integer signed division.
1702  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1703  * HAC pp.598 Algorithm 14.20
1704  *
1705  * Note that the description in HAC is horribly
1706  * incomplete.  For example, it doesn't consider
1707  * the case where digits are removed from 'x' in
1708  * the inner loop.  It also doesn't consider the
1709  * case that y has fewer than three digits, etc..
1710  *
1711  * The overall algorithm is as described as
1712  * 14.20 from HAC but fixed to treat these cases.
1713 */
1714 static int
mp_div(mp_int * a,mp_int * b,mp_int * c,mp_int * d)1715 mp_div (mp_int * a, mp_int * b, mp_int * c, mp_int * d)
1716 {
1717   mp_int  q, x, y, t1, t2;
1718   int     res, n, t, i, norm, neg;
1719 
1720   /* is divisor zero ? */
1721   if (mp_iszero (b) == 1) {
1722     return MP_VAL;
1723   }
1724 
1725   /* if a < b then q=0, r = a */
1726   if (mp_cmp_mag (a, b) == MP_LT) {
1727     if (d != NULL) {
1728       res = mp_copy (a, d);
1729     } else {
1730       res = MP_OKAY;
1731     }
1732     if (c != NULL) {
1733       mp_zero (c);
1734     }
1735     return res;
1736   }
1737 
1738   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1739     return res;
1740   }
1741   q.used = a->used + 2;
1742 
1743   if ((res = mp_init (&t1)) != MP_OKAY) {
1744     goto LBL_Q;
1745   }
1746 
1747   if ((res = mp_init (&t2)) != MP_OKAY) {
1748     goto LBL_T1;
1749   }
1750 
1751   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1752     goto LBL_T2;
1753   }
1754 
1755   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1756     goto LBL_X;
1757   }
1758 
1759   /* fix the sign */
1760   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1761   x.sign = y.sign = MP_ZPOS;
1762 
1763   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1764   norm = mp_count_bits(&y) % DIGIT_BIT;
1765   if (norm < (int)(DIGIT_BIT-1)) {
1766      norm = (DIGIT_BIT-1) - norm;
1767      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1768        goto LBL_Y;
1769      }
1770      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1771        goto LBL_Y;
1772      }
1773   } else {
1774      norm = 0;
1775   }
1776 
1777   /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1778   n = x.used - 1;
1779   t = y.used - 1;
1780 
1781   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1782   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1783     goto LBL_Y;
1784   }
1785 
1786   while (mp_cmp (&x, &y) != MP_LT) {
1787     ++(q.dp[n - t]);
1788     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1789       goto LBL_Y;
1790     }
1791   }
1792 
1793   /* reset y by shifting it back down */
1794   mp_rshd (&y, n - t);
1795 
1796   /* step 3. for i from n down to (t + 1) */
1797   for (i = n; i >= (t + 1); i--) {
1798     if (i > x.used) {
1799       continue;
1800     }
1801 
1802     /* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1803      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1804     if (x.dp[i] == y.dp[t]) {
1805       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1806     } else {
1807       mp_word tmp;
1808       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1809       tmp |= ((mp_word) x.dp[i - 1]);
1810       tmp /= ((mp_word) y.dp[t]);
1811       if (tmp > (mp_word) MP_MASK)
1812         tmp = MP_MASK;
1813       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1814     }
1815 
1816     /* while (q{i-t-1} * (yt * b + y{t-1})) >
1817              xi * b**2 + xi-1 * b + xi-2
1818 
1819        do q{i-t-1} -= 1;
1820     */
1821     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1822     do {
1823       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1824 
1825       /* find left hand */
1826       mp_zero (&t1);
1827       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1828       t1.dp[1] = y.dp[t];
1829       t1.used = 2;
1830       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1831         goto LBL_Y;
1832       }
1833 
1834       /* find right hand */
1835       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1836       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1837       t2.dp[2] = x.dp[i];
1838       t2.used = 3;
1839     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1840 
1841     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1842     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1843       goto LBL_Y;
1844     }
1845 
1846     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1847       goto LBL_Y;
1848     }
1849 
1850     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1851       goto LBL_Y;
1852     }
1853 
1854     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1855     if (x.sign == MP_NEG) {
1856       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1857         goto LBL_Y;
1858       }
1859       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1860         goto LBL_Y;
1861       }
1862       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1863         goto LBL_Y;
1864       }
1865 
1866       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1867     }
1868   }
1869 
1870   /* now q is the quotient and x is the remainder
1871    * [which we have to normalize]
1872    */
1873 
1874   /* get sign before writing to c */
1875   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1876 
1877   if (c != NULL) {
1878     mp_clamp (&q);
1879     mp_exch (&q, c);
1880     c->sign = neg;
1881   }
1882 
1883   if (d != NULL) {
1884     mp_div_2d (&x, norm, &x, NULL);
1885     mp_exch (&x, d);
1886   }
1887 
1888   res = MP_OKAY;
1889 
1890 LBL_Y:mp_clear (&y);
1891 LBL_X:mp_clear (&x);
1892 LBL_T2:mp_clear (&t2);
1893 LBL_T1:mp_clear (&t1);
1894 LBL_Q:mp_clear (&q);
1895   return res;
1896 }
1897 
1898 #endif
1899 
1900 
/* Number of entries in the table of precomputed powers used by
 * s_mp_exptmod.  It has to hold 2**winsize entries for the largest window
 * size the function can pick: 5 with MP_LOW_MEM, 8 otherwise.
 */
#ifdef MP_LOW_MEM
   #define TAB_SIZE 32
#else
   #define TAB_SIZE 256
#endif

/* Sliding-window modular exponentiation; per the table comment below the
 * result is Y = G**X mod P.
 *
 * G:       base
 * X:       exponent (only its digits are read; the sign is not examined
 *          here)
 * P:       modulus
 * Y:       receives the result
 * redmode: 0 selects mp_reduce_setup/mp_reduce (Barrett), any other value
 *          selects mp_reduce_2k_setup_l/mp_reduce_2k_l
 *
 * Returns MP_OKAY on success or the error code of the helper that failed.
 * Every temporary initialised here is cleared before returning.
 */
static int
s_mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
{
  mp_int  M[TAB_SIZE], res, mu;
  mp_digit buf;
  int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
  int (*redux)(mp_int*,mp_int*,mp_int*);

  /* find window size: larger exponents get a wider window */
  x = mp_count_bits (X);
  if (x <= 7) {
    winsize = 2;
  } else if (x <= 36) {
    winsize = 3;
  } else if (x <= 140) {
    winsize = 4;
  } else if (x <= 450) {
    winsize = 5;
  } else if (x <= 1303) {
    winsize = 6;
  } else if (x <= 3529) {
    winsize = 7;
  } else {
    winsize = 8;
  }

#ifdef MP_LOW_MEM
    /* keep 2**winsize within the smaller TAB_SIZE */
    if (winsize > 5) {
       winsize = 5;
    }
#endif

  /* init M array */
  /* init first cell */
  if ((err = mp_init(&M[1])) != MP_OKAY) {
     return err;
  }

  /* now init the second half of the array; on failure undo the entries
   * initialised so far, plus M[1]
   */
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    if ((err = mp_init(&M[x])) != MP_OKAY) {
      for (y = 1<<(winsize-1); y < x; y++) {
        mp_clear (&M[y]);
      }
      mp_clear(&M[1]);
      return err;
    }
  }

  /* create mu, used for Barrett reduction */
  if ((err = mp_init (&mu)) != MP_OKAY) {
    goto LBL_M;
  }

  /* pick the reduction routine and compute its setup value into mu */
  if (redmode == 0) {
     if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
        goto LBL_MU;
     }
     redux = mp_reduce;
  } else {
     if ((err = mp_reduce_2k_setup_l (P, &mu)) != MP_OKAY) {
        goto LBL_MU;
     }
     redux = mp_reduce_2k_l;
  }

  /* create M table
   *
   * The M table contains powers of the base,
   * e.g. M[x] = G**x mod P
   *
   * The first half of the table is not
   * computed though, except for M[1]
   */
  if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
    goto LBL_MU;
  }

  /* compute the value at M[1<<(winsize-1)] by squaring
   * M[1] (winsize-1) times
   */
  if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
    goto LBL_MU;
  }

  for (x = 0; x < (winsize - 1); x++) {
    /* square it */
    if ((err = mp_sqr (&M[1 << (winsize - 1)],
                       &M[1 << (winsize - 1)])) != MP_OKAY) {
      goto LBL_MU;
    }

    /* reduce modulo P */
    if ((err = redux (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
      goto LBL_MU;
    }
  }

  /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
   * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
   */
  for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
    if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
      goto LBL_MU;
    }
    if ((err = redux (&M[x], P, &mu)) != MP_OKAY) {
      goto LBL_MU;
    }
  }

  /* setup result */
  if ((err = mp_init (&res)) != MP_OKAY) {
    goto LBL_MU;
  }
  mp_set (&res, 1);

  /* set initial mode and bit cnt
   *
   * mode:   0 = still skipping leading zero bits, 1 = between windows,
   *         2 = collecting bits into a window
   * bitcnt: bits left in buf, plus one (it is pre-decremented below)
   * buf:    current exponent digit; the next bit is taken from bit
   *         DIGIT_BIT - 1
   * digidx: index of the next exponent digit to load
   * bitcpy: number of bits collected in the current window
   * bitbuf: the window bits collected so far
   */
  mode   = 0;
  bitcnt = 1;
  buf    = 0;
  digidx = X->used - 1;
  bitcpy = 0;
  bitbuf = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset the bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = (int) DIGIT_BIT;
    }

    /* grab the next msb from the exponent */
    y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
    buf <<= (mp_digit)1;

    /* if the bit is zero and mode == 0 then we ignore it
     * These represent the leading zero bits before the first 1 bit
     * in the exponent.  Technically this opt is not required but it
     * does lower the # of trivial squaring/reductions used
     */
    if (mode == 0 && y == 0) {
      continue;
    }

    /* if the bit is zero and mode == 1 then we square */
    if (mode == 1 && y == 0) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, &mu)) != MP_OKAY) {
        goto LBL_RES;
      }
      continue;
    }

    /* else we add it to the window (most significant window bit first) */
    bitbuf |= (y << (winsize - ++bitcpy));
    mode    = 2;

    if (bitcpy == winsize) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < winsize; x++) {
        if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, &mu)) != MP_OKAY) {
          goto LBL_RES;
        }
      }

      /* then multiply by the precomputed power for this window */
      if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, &mu)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
      mode   = 1;
    }
  }

  /* if bits remain then square/multiply: the exponent ended in the middle
   * of a window, so its bits are processed one at a time
   */
  if (mode == 2 && bitcpy > 0) {
    /* square then multiply if the bit is set */
    for (x = 0; x < bitcpy; x++) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, &mu)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* shift the next collected bit up into position winsize and test it */
      bitbuf <<= 1;
      if ((bitbuf & (1 << winsize)) != 0) {
        /* then multiply */
        if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, &mu)) != MP_OKAY) {
          goto LBL_RES;
        }
      }
    }
  }

  /* hand the result to the caller; res now holds Y's previous contents,
   * which the cleanup below clears
   */
  mp_exch (&res, Y);
  err = MP_OKAY;
LBL_RES:mp_clear (&res);
LBL_MU:mp_clear (&mu);
LBL_M:
  mp_clear(&M[1]);
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    mp_clear (&M[x]);
  }
  return err;
}
2133 
2134 
/* computes b = a*a
 *
 * Dispatches to whichever squaring routines were compiled in.  Each
 * optional routine is guarded by its BN_..._C macro and the branches are
 * chained through "else" across the preprocessor blocks, so the first
 * enabled routine whose size condition holds is used.  The sign of b is
 * always set to MP_ZPOS.  Returns the result code of the routine that was
 * called.
 */
static int
mp_sqr (mp_int * a, mp_int * b)
{
  int     res;

#ifdef BN_MP_TOOM_SQR_C
  /* use Toom-Cook? */
  if (a->used >= TOOM_SQR_CUTOFF) {
    res = mp_toom_sqr(a, b);
  /* Karatsuba? */
  } else
#endif
#ifdef BN_MP_KARATSUBA_SQR_C
if (a->used >= KARATSUBA_SQR_CUTOFF) {
    res = mp_karatsuba_sqr (a, b);
  } else
#endif
  {
#ifdef BN_FAST_S_MP_SQR_C
    /* can we use the fast comba multiplier?  Only when the a->used * 2 + 1
     * digits involved stay below MP_WARRAY and a->used stays below the
     * bound derived from the spare bits of mp_word
     */
    if ((a->used * 2 + 1) < MP_WARRAY &&
         a->used <
         (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
      res = fast_s_mp_sqr (a, b);
    } else
#endif
#ifdef BN_S_MP_SQR_C
      /* baseline squaring; also the target of the dangling "else" above */
      res = s_mp_sqr (a, b);
#else
      /* no usable squaring routine was configured: refuse to build */
#error mp_sqr could fail
      res = MP_VAL;
#endif
  }
  b->sign = MP_ZPOS;
  return res;
}
2172 
2173 
2174 /* reduces a modulo n where n is of the form 2**p - d
2175    This differs from reduce_2k since "d" can be larger
2176    than a single digit.
2177 */
2178 static int
mp_reduce_2k_l(mp_int * a,mp_int * n,mp_int * d)2179 mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d)
2180 {
2181    mp_int q;
2182    int    p, res;
2183 
2184    if ((res = mp_init(&q)) != MP_OKAY) {
2185       return res;
2186    }
2187 
2188    p = mp_count_bits(n);
2189 top:
2190    /* q = a/2**p, a = a mod 2**p */
2191    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
2192       goto ERR;
2193    }
2194 
2195    /* q = q * d */
2196    if ((res = mp_mul(&q, d, &q)) != MP_OKAY) {
2197       goto ERR;
2198    }
2199 
2200    /* a = a + q */
2201    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
2202       goto ERR;
2203    }
2204 
2205    if (mp_cmp_mag(a, n) != MP_LT) {
2206       s_mp_sub(a, n, a);
2207       goto top;
2208    }
2209 
2210 ERR:
2211    mp_clear(&q);
2212    return res;
2213 }
2214 
2215 
2216 /* determines the setup value */
2217 static int
mp_reduce_2k_setup_l(mp_int * a,mp_int * d)2218 mp_reduce_2k_setup_l(mp_int *a, mp_int *d)
2219 {
2220    int    res;
2221    mp_int tmp;
2222 
2223    if ((res = mp_init(&tmp)) != MP_OKAY) {
2224       return res;
2225    }
2226 
2227    if ((res = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
2228       goto ERR;
2229    }
2230 
2231    if ((res = s_mp_sub(&tmp, a, d)) != MP_OKAY) {
2232       goto ERR;
2233    }
2234 
2235 ERR:
2236    mp_clear(&tmp);
2237    return res;
2238 }
2239 
2240 
2241 /* computes a = 2**b
2242  *
2243  * Simple algorithm which zeroes the int, grows it then just sets one bit
2244  * as required.
2245  */
2246 static int
mp_2expt(mp_int * a,int b)2247 mp_2expt (mp_int * a, int b)
2248 {
2249   int     res;
2250 
2251   /* zero a as per default */
2252   mp_zero (a);
2253 
2254   /* grow a to accommodate the single bit */
2255   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
2256     return res;
2257   }
2258 
2259   /* set the used count of where the bit will go */
2260   a->used = b / DIGIT_BIT + 1;
2261 
2262   /* put the single bit in its place */
2263   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
2264 
2265   return MP_OKAY;
2266 }
2267 
2268 
2269 /* pre-calculate the value required for Barrett reduction
2270  * For a given modulus "b" it calulates the value required in "a"
2271  */
2272 static int
mp_reduce_setup(mp_int * a,mp_int * b)2273 mp_reduce_setup (mp_int * a, mp_int * b)
2274 {
2275   int     res;
2276 
2277   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
2278     return res;
2279   }
2280   return mp_div (a, b, a, NULL);
2281 }
2282 
2283 
2284 /* reduces x mod m, assumes 0 < x < m**2, mu is
2285  * precomputed via mp_reduce_setup.
2286  * From HAC pp.604 Algorithm 14.42
2287  */
2288 static int
mp_reduce(mp_int * x,mp_int * m,mp_int * mu)2289 mp_reduce (mp_int * x, mp_int * m, mp_int * mu)
2290 {
2291   mp_int  q;
2292   int     res, um = m->used;
2293 
2294   /* q = x */
2295   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
2296     return res;
2297   }
2298 
2299   /* q1 = x / b**(k-1)  */
2300   mp_rshd (&q, um - 1);
2301 
2302   /* according to HAC this optimization is ok */
2303   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
2304     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
2305       goto CLEANUP;
2306     }
2307   } else {
2308 #ifdef BN_S_MP_MUL_HIGH_DIGS_C
2309     if ((res = s_mp_mul_high_digs (&q, mu, &q, um)) != MP_OKAY) {
2310       goto CLEANUP;
2311     }
2312 #elif defined(BN_FAST_S_MP_MUL_HIGH_DIGS_C)
2313     if ((res = fast_s_mp_mul_high_digs (&q, mu, &q, um)) != MP_OKAY) {
2314       goto CLEANUP;
2315     }
2316 #else
2317     {
2318 #error mp_reduce would always fail
2319       res = MP_VAL;
2320       goto CLEANUP;
2321     }
2322 #endif
2323   }
2324 
2325   /* q3 = q2 / b**(k+1) */
2326   mp_rshd (&q, um + 1);
2327 
2328   /* x = x mod b**(k+1), quick (no division) */
2329   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
2330     goto CLEANUP;
2331   }
2332 
2333   /* q = q * m mod b**(k+1), quick (no division) */
2334   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
2335     goto CLEANUP;
2336   }
2337 
2338   /* x = x - q */
2339   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
2340     goto CLEANUP;
2341   }
2342 
2343   /* If x < 0, add b**(k+1) to it */
2344   if (mp_cmp_d (x, 0) == MP_LT) {
2345     mp_set (&q, 1);
2346     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY) {
2347       goto CLEANUP;
2348     }
2349     if ((res = mp_add (x, &q, x)) != MP_OKAY) {
2350       goto CLEANUP;
2351     }
2352   }
2353 
2354   /* Back off if it's too big */
2355   while (mp_cmp (x, m) != MP_LT) {
2356     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
2357       goto CLEANUP;
2358     }
2359   }
2360 
2361 CLEANUP:
2362   mp_clear (&q);
2363 
2364   return res;
2365 }
2366 
2367 
2368 /* multiplies |a| * |b| and only computes up to digs digits of result
2369  * HAC pp. 595, Algorithm 14.12  Modified so you can control how
2370  * many digits of output are created.
2371  */
2372 static int
s_mp_mul_digs(mp_int * a,mp_int * b,mp_int * c,int digs)2373 s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
2374 {
2375   mp_int  t;
2376   int     res, pa, pb, ix, iy;
2377   mp_digit u;
2378   mp_word r;
2379   mp_digit tmpx, *tmpt, *tmpy;
2380 
2381   /* can we use the fast multiplier? */
2382   if (((digs) < MP_WARRAY) &&
2383       MIN (a->used, b->used) <
2384           (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2385     return fast_s_mp_mul_digs (a, b, c, digs);
2386   }
2387 
2388   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
2389     return res;
2390   }
2391   t.used = digs;
2392 
2393   /* compute the digits of the product directly */
2394   pa = a->used;
2395   for (ix = 0; ix < pa; ix++) {
2396     /* set the carry to zero */
2397     u = 0;
2398 
2399     /* limit ourselves to making digs digits of output */
2400     pb = MIN (b->used, digs - ix);
2401 
2402     /* setup some aliases */
2403     /* copy of the digit from a used within the nested loop */
2404     tmpx = a->dp[ix];
2405 
2406     /* an alias for the destination shifted ix places */
2407     tmpt = t.dp + ix;
2408 
2409     /* an alias for the digits of b */
2410     tmpy = b->dp;
2411 
2412     /* compute the columns of the output and propagate the carry */
2413     for (iy = 0; iy < pb; iy++) {
2414       /* compute the column as a mp_word */
2415       r       = ((mp_word)*tmpt) +
2416                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
2417                 ((mp_word) u);
2418 
2419       /* the new column is the lower part of the result */
2420       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2421 
2422       /* get the carry word from the result */
2423       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2424     }
2425     /* set carry if it is placed below digs */
2426     if (ix + iy < digs) {
2427       *tmpt = u;
2428     }
2429   }
2430 
2431   mp_clamp (&t);
2432   mp_exch (&t, c);
2433 
2434   mp_clear (&t);
2435   return MP_OKAY;
2436 }
2437 
2438 
2439 /* Fast (comba) multiplier
2440  *
2441  * This is the fast column-array [comba] multiplier.  It is
2442  * designed to compute the columns of the product first
2443  * then handle the carries afterwards.  This has the effect
2444  * of making the nested loops that compute the columns very
2445  * simple and schedulable on super-scalar processors.
2446  *
2447  * This has been modified to produce a variable number of
2448  * digits of output so if say only a half-product is required
2449  * you don't have to compute the upper half (a feature
2450  * required for fast Barrett reduction).
2451  *
2452  * Based on Algorithm 14.12 on pp.595 of HAC.
2453  *
2454  */
2455 static int
fast_s_mp_mul_digs(mp_int * a,mp_int * b,mp_int * c,int digs)2456 fast_s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
2457 {
2458   int     olduse, res, pa, ix, iz;
2459   mp_digit W[MP_WARRAY];
2460   register mp_word  _W;
2461 
2462   /* grow the destination as required */
2463   if (c->alloc < digs) {
2464     if ((res = mp_grow (c, digs)) != MP_OKAY) {
2465       return res;
2466     }
2467   }
2468 
2469   /* number of output digits to produce */
2470   pa = MIN(digs, a->used + b->used);
2471 
2472   /* clear the carry */
2473   _W = 0;
2474   for (ix = 0; ix < pa; ix++) {
2475       int      tx, ty;
2476       int      iy;
2477       mp_digit *tmpx, *tmpy;
2478 
2479       /* get offsets into the two bignums */
2480       ty = MIN(b->used-1, ix);
2481       tx = ix - ty;
2482 
2483       /* setup temp aliases */
2484       tmpx = a->dp + tx;
2485       tmpy = b->dp + ty;
2486 
2487       /* this is the number of times the loop will iterrate, essentially
2488          while (tx++ < a->used && ty-- >= 0) { ... }
2489        */
2490       iy = MIN(a->used-tx, ty+1);
2491 
2492       /* execute loop */
2493       for (iz = 0; iz < iy; ++iz) {
2494          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
2495 
2496       }
2497 
2498       /* store term */
2499       W[ix] = ((mp_digit)_W) & MP_MASK;
2500 
2501       /* make next carry */
2502       _W = _W >> ((mp_word)DIGIT_BIT);
2503  }
2504 
2505   /* setup dest */
2506   olduse  = c->used;
2507   c->used = pa;
2508 
2509   {
2510     register mp_digit *tmpc;
2511     tmpc = c->dp;
2512     for (ix = 0; ix < pa+1; ix++) {
2513       /* now extract the previous digit [below the carry] */
2514       *tmpc++ = W[ix];
2515     }
2516 
2517     /* clear unused digits [that existed in the old copy of c] */
2518     for (; ix < olduse; ix++) {
2519       *tmpc++ = 0;
2520     }
2521   }
2522   mp_clamp (c);
2523   return MP_OKAY;
2524 }
2525 
2526 
2527 /* init an mp_init for a given size */
2528 static int
mp_init_size(mp_int * a,int size)2529 mp_init_size (mp_int * a, int size)
2530 {
2531   int x;
2532 
2533   /* pad size so there are always extra digits */
2534   size += (MP_PREC * 2) - (size % MP_PREC);
2535 
2536   /* alloc mem */
2537   a->dp = OPT_CAST(mp_digit) XMALLOC (sizeof (mp_digit) * size);
2538   if (a->dp == NULL) {
2539     return MP_MEM;
2540   }
2541 
2542   /* set the members */
2543   a->used  = 0;
2544   a->alloc = size;
2545   a->sign  = MP_ZPOS;
2546 
2547   /* zero the digits */
2548   for (x = 0; x < size; x++) {
2549       a->dp[x] = 0;
2550   }
2551 
2552   return MP_OKAY;
2553 }
2554 
2555 
2556 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
2557 static int
s_mp_sqr(mp_int * a,mp_int * b)2558 s_mp_sqr (mp_int * a, mp_int * b)
2559 {
2560   mp_int  t;
2561   int     res, ix, iy, pa;
2562   mp_word r;
2563   mp_digit u, tmpx, *tmpt;
2564 
2565   pa = a->used;
2566   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
2567     return res;
2568   }
2569 
2570   /* default used is maximum possible size */
2571   t.used = 2*pa + 1;
2572 
2573   for (ix = 0; ix < pa; ix++) {
2574     /* first calculate the digit at 2*ix */
2575     /* calculate double precision result */
2576     r = ((mp_word) t.dp[2*ix]) +
2577         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
2578 
2579     /* store lower part in result */
2580     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
2581 
2582     /* get the carry */
2583     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2584 
2585     /* left hand side of A[ix] * A[iy] */
2586     tmpx        = a->dp[ix];
2587 
2588     /* alias for where to store the results */
2589     tmpt        = t.dp + (2*ix + 1);
2590 
2591     for (iy = ix + 1; iy < pa; iy++) {
2592       /* first calculate the product */
2593       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
2594 
2595       /* now calculate the double precision result, note we use
2596        * addition instead of *2 since it's easier to optimize
2597        */
2598       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
2599 
2600       /* store lower part */
2601       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2602 
2603       /* get carry */
2604       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2605     }
2606     /* propagate upwards */
2607     while (u != ((mp_digit) 0)) {
2608       r       = ((mp_word) *tmpt) + ((mp_word) u);
2609       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2610       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2611     }
2612   }
2613 
2614   mp_clamp (&t);
2615   mp_exch (&t, b);
2616   mp_clear (&t);
2617   return MP_OKAY;
2618 }
2619 
2620 
2621 /* multiplies |a| * |b| and does not compute the lower digs digits
2622  * [meant to get the higher part of the product]
2623  */
2624 static int
s_mp_mul_high_digs(mp_int * a,mp_int * b,mp_int * c,int digs)2625 s_mp_mul_high_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
2626 {
2627   mp_int  t;
2628   int     res, pa, pb, ix, iy;
2629   mp_digit u;
2630   mp_word r;
2631   mp_digit tmpx, *tmpt, *tmpy;
2632 
2633   /* can we use the fast multiplier? */
2634 #ifdef BN_FAST_S_MP_MUL_HIGH_DIGS_C
2635   if (((a->used + b->used + 1) < MP_WARRAY)
2636       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2637     return fast_s_mp_mul_high_digs (a, b, c, digs);
2638   }
2639 #endif
2640 
2641   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
2642     return res;
2643   }
2644   t.used = a->used + b->used + 1;
2645 
2646   pa = a->used;
2647   pb = b->used;
2648   for (ix = 0; ix < pa; ix++) {
2649     /* clear the carry */
2650     u = 0;
2651 
2652     /* left hand side of A[ix] * B[iy] */
2653     tmpx = a->dp[ix];
2654 
2655     /* alias to the address of where the digits will be stored */
2656     tmpt = &(t.dp[digs]);
2657 
2658     /* alias for where to read the right hand side from */
2659     tmpy = b->dp + (digs - ix);
2660 
2661     for (iy = digs - ix; iy < pb; iy++) {
2662       /* calculate the double precision result */
2663       r       = ((mp_word)*tmpt) +
2664                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
2665                 ((mp_word) u);
2666 
2667       /* get the lower part */
2668       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
2669 
2670       /* carry the carry */
2671       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2672     }
2673     *tmpt = u;
2674   }
2675   mp_clamp (&t);
2676   mp_exch (&t, c);
2677   mp_clear (&t);
2678   return MP_OKAY;
2679 }
2680 
2681 
#ifdef BN_MP_MONTGOMERY_SETUP_C
/* Sets up Montgomery reduction for the modulus n.
 *
 * Stores in *rho the value -1/n mod 2**DIGIT_BIT, computed from the least
 * significant digit of n.  Returns MP_VAL if that digit is even (the
 * inverse does not exist then), MP_OKAY otherwise.
 */
static int
mp_montgomery_setup (mp_int * n, mp_digit * rho)
{
  mp_digit x, b;

/* fast inversion mod 2**k
 *
 * Based on the fact that
 *
 * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
 *                    =>  2*X*A - X*X*A*A = 1
 *                    =>  2*(1) - (1)     = 1
 */
  b = n->dp[0];

  if ((b & 1) == 0) {
    return MP_VAL;
  }

  /* each step below doubles the number of correct low bits of x */
  x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
  x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
#if !defined(MP_8BIT)
  x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
#endif
#if defined(MP_64BIT) || !(defined(MP_8BIT) || defined(MP_16BIT))
  x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
#endif
#ifdef MP_64BIT
  x *= 2 - b * x;               /* here x*a==1 mod 2**64 */
#endif

  /* rho = -1/n mod 2**DIGIT_BIT
   *
   * The difference is narrowed to mp_digit rather than unsigned long so
   * that no bits are lost on targets where a digit is wider than
   * unsigned long.
   */
  *rho = (mp_digit)(((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;

  return MP_OKAY;
}
#endif
2721 
2722 
#ifdef BN_FAST_MP_MONTGOMERY_REDUCE_C
/* computes xR**-1 == x (mod N) via Montgomery Reduction
 *
 * This is an optimized implementation of montgomery_reduce
 * which uses the comba method to quickly calculate the columns of the
 * reduction.
 *
 * x is reduced in place, n is the modulus and rho is the value computed
 * by mp_montgomery_setup() for n.  The columns are staged in an on-stack
 * array of MP_WARRAY words; MP_VAL is returned when the digits of x or the
 * 2*n->used + 1 working columns would not fit into it.  Otherwise the
 * result is MP_OKAY or the error from mp_grow()/s_mp_sub().
 *
 * Based on Algorithm 14.32 on pp.601 of HAC.
*/
int
fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
{
  int     ix, res, olduse;
  mp_word W[MP_WARRAY];

  /* refuse inputs that would overrun W[] */
  if ((x->used > (int) MP_WARRAY) ||
      ((n->used * 2 + 1) > (int) MP_WARRAY)) {
    return MP_VAL;
  }

  /* get old used count */
  olduse = x->used;

  /* grow x as required */
  if (x->alloc < n->used + 1) {
    if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
      return res;
    }
  }

  /* first we have to get the digits of the input into
   * an array of double precision words W[...]
   */
  {
    mp_word  *wp;
    mp_digit *tmpx;

    /* alias for the W[] array */
    wp   = W;

    /* alias for the digits of x */
    tmpx = x->dp;

    /* copy the digits of x into W[0..x->used-1] */
    for (ix = 0; ix < x->used; ix++) {
      *wp++ = *tmpx++;
    }

    /* zero the high words of W[x->used..n->used*2] */
    for (; ix < n->used * 2 + 1; ix++) {
      *wp++ = 0;
    }
  }

  /* now we proceed to zero successive digits
   * from the least significant upwards
   */
  for (ix = 0; ix < n->used; ix++) {
    /* mu = ai * m' mod b
     *
     * We avoid a double precision multiplication (which isn't required)
     * by casting the value down to a mp_digit.  Note this requires
     * that W[ix-1] have  the carry cleared (see after the inner loop)
     */
    mp_digit mu;
    mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);

    /* a = a + mu * m * b**i
     *
     * This is computed in place and on the fly.  The multiplication
     * by b**i is handled by offsetting which columns the results
     * are added to.
     *
     * Note the comba method normally doesn't handle carries in the
     * inner loop.  In this case we fix the carry from the previous
     * column since the Montgomery reduction requires digits of the
     * result (so far) [see above] to work.  This is
     * handled by fixing up one carry after the inner loop.  The
     * carry fixups are done in order so after these loops the
     * first n->used words of W[] have the carries fixed
     */
    {
      int       iy;
      mp_digit *tmpn;
      mp_word  *wp;

      /* alias for the digits of the modulus */
      tmpn = n->dp;

      /* Alias for the columns set by an offset of ix */
      wp = W + ix;

      /* inner loop */
      for (iy = 0; iy < n->used; iy++) {
          *wp++ += ((mp_word)mu) * ((mp_word)*tmpn++);
      }
    }

    /* now fix carry for next digit, W[ix+1] */
    W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
  }

  /* now we have to propagate the carries and
   * shift the words downward [all those least
   * significant digits we zeroed].
   */
  {
    mp_digit *tmpx;
    mp_word  *wp, *prev;

    /* now fix rest of carries */

    /* alias for current word */
    prev = W + ix;

    /* alias for next word, where the carry goes */
    wp = W + ++ix;

    /* Stop at W[2*n->used], the highest column that was initialized
     * above and the highest one copied out below; the next word was
     * never written and may not even exist.
     */
    for (; ix < n->used * 2 + 1; ix++) {
      *wp++ += *prev++ >> ((mp_word) DIGIT_BIT);
    }

    /* copy out, A = A/b**n
     *
     * The result is A/b**n but instead of converting from an
     * array of mp_word to mp_digit than calling mp_rshd
     * we just copy them in the right order
     */

    /* alias for destination word */
    tmpx = x->dp;

    /* alias for shifted double precision result */
    wp = W + n->used;

    for (ix = 0; ix < n->used + 1; ix++) {
      *tmpx++ = (mp_digit)(*wp++ & ((mp_word) MP_MASK));
    }

    /* zero oldused digits, if the input x was larger than
     * n->used+1 we'll have to clear the digits
     */
    for (; ix < olduse; ix++) {
      *tmpx++ = 0;
    }
  }

  /* set the max used and clamp */
  x->used = n->used + 1;
  mp_clamp (x);

  /* if A >= m then A = A - m */
  if (mp_cmp_mag (x, n) != MP_LT) {
    return s_mp_sub (x, n, x);
  }
  return MP_OKAY;
}
#endif
2876 
2877 
#ifdef BN_MP_MUL_2_C
/* b = a*2
 *
 * Shifts every digit of a left by one bit, carrying the top bit of each
 * digit into the next one, and gives b the sign of a.  Returns MP_OKAY or
 * the error from mp_grow().
 */
static int
mp_mul_2(mp_int * a, mp_int * b)
{
  int      i, err, prev_used;
  mp_digit carry, digit;

  /* the result may need one digit more than a */
  if (b->alloc < a->used + 1) {
    err = mp_grow (b, a->used + 1);
    if (err != MP_OKAY) {
      return err;
    }
  }

  prev_used = b->used;
  b->used   = a->used;

  carry = 0;
  for (i = 0; i < a->used; i++) {
    /* fetch the source digit once; a and b may be the same object */
    digit = a->dp[i];

    /* shifted digit plus the bit carried out of the previous digit */
    b->dp[i] = ((digit << ((mp_digit)1)) | carry) & MP_MASK;

    /* the top bit of this digit is the carry for the next one */
    carry = digit >> ((mp_digit)(DIGIT_BIT - 1));
  }

  /* a leftover carry becomes a new leading digit, which is always 1 */
  if (carry != 0) {
    b->dp[i] = 1;
    ++(b->used);
  }

  /* wipe digits of the old value of b that were not overwritten */
  for (i = b->used; i < prev_used; i++) {
    b->dp[i] = 0;
  }

  b->sign = a->sign;
  return MP_OKAY;
}
#endif
2941 
2942 
#ifdef BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
/*
 * Computes the Montgomery normalization value for the modulus b and stores
 * it in a.
 *
 * a is first set to a power of two just under the leading bit of b (or to
 * 1 when b has at most one digit), which needs no reduction.  It is then
 * doubled repeatedly, subtracting b whenever the result is not smaller
 * than b.  Starting that close to b saves a lot of multiple precision
 * shifting.
 *
 * Returns MP_OKAY or the error from mp_2expt(), mp_mul_2() or s_mp_sub().
 */
static int
mp_montgomery_calc_normalization (mp_int * a, mp_int * b)
{
  int     step, top_bits, err;

  /* how many bits of last digit does b use */
  top_bits = mp_count_bits (b) % DIGIT_BIT;

  if (b->used > 1) {
     err = mp_2expt (a, (b->used - 1) * DIGIT_BIT + top_bits - 1);
     if (err != MP_OKAY) {
        return err;
     }
  } else {
     mp_set(a, 1);
     top_bits = 1;
  }

  /* double a, reducing modulo b along the way */
  step = top_bits - 1;
  while (step < (int)DIGIT_BIT) {
    err = mp_mul_2 (a, a);
    if (err != MP_OKAY) {
      return err;
    }
    if (mp_cmp_mag (a, b) != MP_LT) {
      err = s_mp_sub (a, b, a);
      if (err != MP_OKAY) {
        return err;
      }
    }
    step++;
  }

  return MP_OKAY;
}
#endif
2983 
2984 
#ifdef BN_MP_EXPTMOD_FAST_C
/* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
 *
 * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
 * The value of k changes based on the size of the exponent.
 *
 * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
 *
 * redmode selects the reduction: 0 uses Montgomery reduction, 1 uses
 * diminished radix reduction, and any other value uses the 2**k - b
 * reduction.  A mode whose support code is not compiled in yields MP_VAL.
 *
 * Returns MP_OKAY on success (the result is exchanged into Y), otherwise
 * the error code of the step that failed.
 */

static int
mp_exptmod_fast (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
{
  mp_int  M[TAB_SIZE], res;
  mp_digit buf, mp;
  int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;

  /* use a pointer to the reduction algorithm.  This allows us to use
   * one of many reduction algorithms without modding the guts of
   * the code with if statements everywhere.
   */
  int     (*redux)(mp_int*,mp_int*,mp_digit);

  /* find window size: larger exponents get a wider window */
  x = mp_count_bits (X);
  if (x <= 7) {
    winsize = 2;
  } else if (x <= 36) {
    winsize = 3;
  } else if (x <= 140) {
    winsize = 4;
  } else if (x <= 450) {
    winsize = 5;
  } else if (x <= 1303) {
    winsize = 6;
  } else if (x <= 3529) {
    winsize = 7;
  } else {
    winsize = 8;
  }

#ifdef MP_LOW_MEM
  /* cap the window (and thereby the table) in low memory builds */
  if (winsize > 5) {
     winsize = 5;
  }
#endif

  /* init M array */
  /* init first cell */
  if ((err = mp_init(&M[1])) != MP_OKAY) {
     return err;
  }

  /* now init the second half of the array */
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    if ((err = mp_init(&M[x])) != MP_OKAY) {
      /* undo the cells initialized so far before bailing out */
      for (y = 1<<(winsize-1); y < x; y++) {
        mp_clear (&M[y]);
      }
      mp_clear(&M[1]);
      return err;
    }
  }

  /* determine and setup reduction code */
  if (redmode == 0) {
#ifdef BN_MP_MONTGOMERY_SETUP_C
     /* now setup montgomery  */
     if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
        goto LBL_M;
     }
#else
     err = MP_VAL;
     goto LBL_M;
#endif

     /* automatically pick the comba one if available (saves quite a few calls/ifs) */
#ifdef BN_FAST_MP_MONTGOMERY_REDUCE_C
     if (((P->used * 2 + 1) < MP_WARRAY) &&
          P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
        redux = fast_mp_montgomery_reduce;
     } else
#endif
     {
#ifdef BN_MP_MONTGOMERY_REDUCE_C
        /* use slower baseline Montgomery method */
        redux = mp_montgomery_reduce;
#else
        err = MP_VAL;
        goto LBL_M;
#endif
     }
  } else if (redmode == 1) {
#if defined(BN_MP_DR_SETUP_C) && defined(BN_MP_DR_REDUCE_C)
     /* setup DR reduction for moduli of the form B**k - b */
     mp_dr_setup(P, &mp);
     redux = mp_dr_reduce;
#else
     err = MP_VAL;
     goto LBL_M;
#endif
  } else {
#if defined(BN_MP_REDUCE_2K_SETUP_C) && defined(BN_MP_REDUCE_2K_C)
     /* setup DR reduction for moduli of the form 2**k - b */
     if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
        goto LBL_M;
     }
     redux = mp_reduce_2k;
#else
     err = MP_VAL;
     goto LBL_M;
#endif
  }

  /* setup result */
  if ((err = mp_init (&res)) != MP_OKAY) {
    goto LBL_M;
  }

  /* create M table
   *
   * M[x] holds the x-th power of the base, reduced with the selected
   * reduction (for redmode == 0 the entries carry the Montgomery factor).
   *
   * The first half of the table is not computed, except for M[1]
   */

  if (redmode == 0) {
#ifdef BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
     /* now we need R mod m */
     if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
       goto LBL_RES;
     }
#else
     err = MP_VAL;
     goto LBL_RES;
#endif

     /* now set M[1] to G * R mod m */
     if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
       goto LBL_RES;
     }
  } else {
     /* no Montgomery factor: start from 1 and reduce the base directly */
     mp_set(&res, 1);
     if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
        goto LBL_RES;
     }
  }

  /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
  if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
    goto LBL_RES;
  }

  for (x = 0; x < (winsize - 1); x++) {
    if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
      goto LBL_RES;
    }
    if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
      goto LBL_RES;
    }
  }

  /* create upper table: each entry is the previous one times M[1] */
  for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
    if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
      goto LBL_RES;
    }
    if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
      goto LBL_RES;
    }
  }

  /* set initial mode and bit cnt
   *
   * mode 0: still skipping leading zero bits of the exponent
   * mode 1: between windows, a zero bit costs one squaring
   * mode 2: currently collecting bits into the window
   */
  mode   = 0;
  bitcnt = 1;
  buf    = 0;
  digidx = X->used - 1;
  bitcpy = 0;
  bitbuf = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits so break */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = (int)DIGIT_BIT;
    }

    /* grab the next msb from the exponent */
    y     = (mp_digit)(buf >> (DIGIT_BIT - 1)) & 1;
    buf <<= (mp_digit)1;

    /* if the bit is zero and mode == 0 then we ignore it
     * These represent the leading zero bits before the first 1 bit
     * in the exponent.  Technically this opt is not required but it
     * does lower the # of trivial squaring/reductions used
     */
    if (mode == 0 && y == 0) {
      continue;
    }

    /* if the bit is zero and mode == 1 then we square */
    if (mode == 1 && y == 0) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto LBL_RES;
      }
      continue;
    }

    /* else we add it to the window */
    bitbuf |= (y << (winsize - ++bitcpy));
    mode    = 2;

    if (bitcpy == winsize) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < winsize; x++) {
        if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, mp)) != MP_OKAY) {
          goto LBL_RES;
        }
      }

      /* then multiply */
      if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
      mode   = 1;
    }
  }

  /* if bits remain then square/multiply */
  if (mode == 2 && bitcpy > 0) {
    /* square then multiply if the bit is set */
    for (x = 0; x < bitcpy; x++) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* get next bit of the window */
      bitbuf <<= 1;
      if ((bitbuf & (1 << winsize)) != 0) {
        /* then multiply */
        if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, mp)) != MP_OKAY) {
          goto LBL_RES;
        }
      }
    }
  }

  if (redmode == 0) {
     /* fixup result if Montgomery reduction is used
      * recall that any value in a Montgomery system is
      * actually multiplied by R mod n.  So we have
      * to reduce one more time to cancel out the factor
      * of R.
      */
     if ((err = redux(&res, P, mp)) != MP_OKAY) {
       goto LBL_RES;
     }
  }

  /* swap res with Y */
  mp_exch (&res, Y);
  err = MP_OKAY;
  /* cleanup: LBL_RES also releases res, LBL_M only the table cells */
LBL_RES:mp_clear (&res);
LBL_M:
  mp_clear(&M[1]);
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    mp_clear (&M[x]);
  }
  return err;
}
#endif
3280 
3281 
#ifdef BN_FAST_S_MP_SQR_C
/* Comba squaring, b = a*a.
 *
 * Works like the comba multiplier, except that within a column the two
 * digit offsets are never allowed to meet: only the products of distinct
 * digits are summed, that sum is doubled, and the square of the middle
 * digit is then added to the even columns.
 *
 * The columns are staged in an on-stack array of MP_WARRAY digits.
 * Returns MP_OKAY or the error from mp_grow().
 */

static int
fast_s_mp_sqr (mp_int * a, mp_int * b)
{
  int       prev_used, err, ndigits, col, k;
  mp_digit  W[MP_WARRAY];
  mp_word   carry;

  /* the square has twice as many digits as a; make room in b */
  ndigits = a->used + a->used;
  if (b->alloc < ndigits) {
    err = mp_grow (b, ndigits);
    if (err != MP_OKAY) {
      return err;
    }
  }

  /* produce the output columns one by one */
  carry = 0;
  for (col = 0; col < ndigits; col++) {
      int      lo, hi, count;
      mp_word  sum;

      /* each column starts with an empty sum */
      sum = 0;

      /* offsets of the outermost digit pair of this column */
      hi = MIN(a->used-1, col);
      lo = col - hi;

      /* number of digit pairs the column has, essentially
         while (lo++ < a->used && hi-- >= 0) { ... }
       */
      count = MIN(a->used-lo, hi+1);

      /* the offsets approach each other two steps at a time and must not
       * meet, so only (rounded) half of the distance is walked
       */
      count = MIN(count, (hi-lo+1)>>1);

      /* sum the products of distinct digits */
      for (k = 0; k < count; k++) {
         sum += ((mp_word)a->dp[lo + k])*((mp_word)a->dp[hi - k]);
      }

      /* double the inner product and add carry */
      sum = sum + sum + carry;

      /* even columns have the square term in them */
      if ((col&1) == 0) {
         sum += ((mp_word)a->dp[col>>1])*((mp_word)a->dp[col>>1]);
      }

      /* low part is the digit of this column */
      W[col] = (mp_digit)(sum & MP_MASK);

      /* high part is carried into the next column */
      carry = sum >> ((mp_word)DIGIT_BIT);
  }

  /* setup dest */
  prev_used = b->used;
  b->used   = ndigits;

  /* copy the staged columns into b */
  for (col = 0; col < ndigits; col++) {
    b->dp[col] = W[col] & MP_MASK;
  }

  /* clear digits of the old value of b that were not overwritten */
  for (; col < prev_used; col++) {
    b->dp[col] = 0;
  }

  mp_clamp (b);
  return MP_OKAY;
}
#endif
3377 
3378 
#ifdef BN_MP_MUL_D_C
/* c = a * b for a single digit b
 *
 * c receives the sign of a.  Returns MP_OKAY or the error from mp_grow().
 */
static int
mp_mul_d (mp_int * a, mp_digit b, mp_int * c)
{
  mp_digit carry;
  mp_word  acc;
  int      i, err, prev_used;

  /* the product can be one digit longer than a */
  if (c->alloc < a->used + 1) {
    err = mp_grow (c, a->used + 1);
    if (err != MP_OKAY) {
      return err;
    }
  }

  /* remember how many digits the old value of c had */
  prev_used = c->used;

  /* set the sign */
  c->sign = a->sign;

  /* multiply digit by digit, passing the high part on as carry */
  carry = 0;
  for (i = 0; i < a->used; i++) {
    /* product of this digit plus the incoming carry */
    acc      = ((mp_word) carry) + ((mp_word) a->dp[i]) * ((mp_word) b);

    /* low part is the output digit */
    c->dp[i] = (mp_digit) (acc & ((mp_word) MP_MASK));

    /* high part goes into the next iteration */
    carry    = (mp_digit) (acc >> ((mp_word) DIGIT_BIT));
  }

  /* the last carry [if any] becomes the top digit */
  c->dp[i] = carry;

  /* now zero digits above the top */
  for (i = i + 1; i < prev_used; i++) {
    c->dp[i] = 0;
  }

  /* set used count */
  c->used = a->used + 1;
  mp_clamp(c);

  return MP_OKAY;
}
#endif
3438