1 /*
2  * SPDX-FileCopyrightText: 2015 Joseph Birr-Pixton <jpixton@gmail.com>
3  *
4  * SPDX-License-Identifier: CC0-1.0
5  */
6 
7 /*
8  * fast-pbkdf2 - Optimal PBKDF2-HMAC calculation
9  * Written in 2015 by Joseph Birr-Pixton <jpixton@gmail.com>
10  *
11  * To the extent possible under law, the author(s) have dedicated all
12  * copyright and related and neighboring rights to this software to the
13  * public domain worldwide. This software is distributed without any
14  * warranty.
15  *
16  * You should have received a copy of the CC0 Public Domain Dedication
17  * along with this software. If not, see
18  * <http://creativecommons.org/publicdomain/zero/1.0/>.
19  */
20 #include "utils/common.h"
21 #include "fastpbkdf2.h"
22 
23 #include <assert.h>
24 #include <string.h>
25 #if defined(__GNUC__)
26 #include <endian.h>
27 #endif
28 
29 #include <mbedtls/sha1.h>
30 #include "mbedtls/esp_config.h"
31 #include "utils/wpa_debug.h"
32 
33 /* --- MSVC doesn't support C99 --- */
34 #ifdef _MSC_VER
35 #define restrict
36 #define _Pragma __pragma
37 #endif
38 
39 /* --- Common useful things --- */
#ifndef MIN
/* Smaller of a and b.  The whole expansion is parenthesised so the macro can
 * be used inside a larger expression (e.g. MIN(x, y) + 1).  Each argument
 * may be evaluated twice, so arguments must be free of side effects. */
#define MIN(a, b) (((a) > (b)) ? (b) : (a))
#endif
43 
/* Store the 32-bit value n into out[0..3] in big-endian order (most
 * significant byte first).  out may have any alignment. */
static inline void write32_be(uint32_t n, uint8_t out[4])
{
#if defined(__GNUC__) && defined(__BYTE_ORDER__) && \
    defined(__ORDER_LITTLE_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  /* Little-endian host: byte-swap once, then copy the four bytes out.
   * memcpy is used instead of a store through a casted uint32_t pointer so
   * the write stays well-defined for unaligned byte buffers and under strict
   * aliasing.  The compiler-predefined byte-order macros are tested (with
   * defined() guards) so a missing definition selects the portable path
   * below rather than silently comparing 0 == 0. */
  const uint32_t swapped = __builtin_bswap32(n);
  memcpy(out, &swapped, sizeof swapped);
#else
  /* Portable path: correct on any host byte order. */
  out[0] = (uint8_t)((n >> 24) & 0xff);
  out[1] = (uint8_t)((n >> 16) & 0xff);
  out[2] = (uint8_t)((n >> 8) & 0xff);
  out[3] = (uint8_t)(n & 0xff);
#endif
}
55 
/* Write message-digest padding into the unused tail of block.
 *
 * block is blocksz bytes long and its first `used` bytes already hold data.
 * The tail receives a 0x80 marker byte, zero fill, and - in the final four
 * bytes - the length in bits of a msg-byte message.
 *
 * Only 32 bits of length are written, big-endian; the bytes before them are
 * left zero (so suitable for sha1, sha256, sha512). */
static inline void md_pad(uint8_t *block, size_t blocksz, size_t used, size_t msg)
{
  uint8_t *const pad = block + used;
  uint8_t *const length_field = block + (blocksz - 4);
  const uint32_t msg_bits = (uint32_t) (msg * 8);

  /* Zero everything between the data and the length field, then drop the
   * marker onto the first padding byte. */
  memset(pad, 0, blocksz - used - 4);
  pad[0] = 0x80;

  write32_be(msg_bits, length_field);
}
67 
/* Internal function/type names for hash-specific things.
 *
 * Each macro pastes the hash name into an identifier, e.g. HMAC_CTX(sha1)
 * expands to HMAC_sha1_ctx.  The names are pasted directly (no extra level
 * of macro indirection), so _name itself is not macro-expanded first. */
/* HMAC context type for hash _name. */
#define HMAC_CTX(_name) HMAC_ ## _name ## _ctx
/* HMAC keying function for hash _name. */
#define HMAC_INIT(_name) HMAC_ ## _name ## _init
/* HMAC message-update function for hash _name. */
#define HMAC_UPDATE(_name) HMAC_ ## _name ## _update
/* HMAC finalisation function for hash _name. */
#define HMAC_FINAL(_name) HMAC_ ## _name ## _final

/* PBKDF2 single-block function (the "F" of PBKDF2) for hash _name. */
#define PBKDF2_F(_name) pbkdf2_f_ ## _name
/* Complete PBKDF2 derivation function for hash _name. */
#define PBKDF2(_name) pbkdf2_ ## _name
76 
/* This macro expands to decls for the whole implementation for a given
 * hash function.  Arguments, in the order they are passed, are:
 *
 * _name like 'sha1', added to symbol names
 * _blocksz block size, in bytes
 * _hashsz digest output, in bytes
 * _ctx hash context type
 * _init hash context initialisation function
 *    args: (_ctx *c)
 * _update hash context update function
 *    args: (_ctx *c, const void *data, size_t ndata)
 * _xform hash context raw block update function
 *    args: (_ctx *c, const void *data)
 * _final hash context finish function
 *    args: (_ctx *c, void *out)
 * _xcpy hash context raw copy function (only need copy hash state)
 *    args: (_ctx * restrict out, const _ctx *restrict in)
 * _xtract hash context state extraction
 *    args: (_ctx *restrict c, uint8_t *restrict out)
 * _xxor hash context xor function (only need xor hash state)
 *    args: (_ctx *restrict out, const _ctx *restrict in)
 *
 * The expansion defines, all static inline:
 *   HMAC_CTX(_name)     struct holding the inner and outer hash contexts
 *   HMAC_INIT(_name)    keys both contexts
 *   HMAC_UPDATE(_name)  feeds message bytes to the inner hash
 *   HMAC_FINAL(_name)   finishes inner, feeds it to outer, finishes outer
 *   PBKDF2_F(_name)     computes one _hashsz-byte output block
 *   PBKDF2(_name)       derives nout bytes, one block at a time
 *
 * Values returned by the hash callbacks are not checked.
 *
 * The resulting function is named PBKDF2(_name).
 */
#define DECL_PBKDF2(_name, _blocksz, _hashsz, _ctx,                           \
                    _init, _update, _xform, _final, _xcpy, _xtract, _xxor)    \
  /* HMAC state: one hash context per HMAC pass. */                           \
  typedef struct {                                                            \
    _ctx inner;                                                               \
    _ctx outer;                                                               \
  } HMAC_CTX(_name);                                                          \
                                                                              \
  /* Key ctx: afterwards inner has absorbed one block of (key ^ 0x36) and     \
   * outer one block of (key ^ 0x5c), with the key zero-padded to _blocksz    \
   * bytes (or replaced by its digest first when longer than a block). */     \
  static inline void HMAC_INIT(_name)(HMAC_CTX(_name) *ctx,                   \
                                      const uint8_t *key, size_t nkey)        \
  {                                                                           \
    /* Prepare key: */                                                        \
    uint8_t k[_blocksz];                                                      \
                                                                              \
    /* Shorten long keys.  ctx->inner is only borrowed as scratch for the     \
     * digest here; it is initialised again below. */                         \
    if (nkey > _blocksz)                                                      \
    {                                                                         \
      _init(&ctx->inner);                                                     \
      _update(&ctx->inner, key, nkey);                                        \
      _final(&ctx->inner, k);                                                 \
                                                                              \
      key = k;                                                                \
      nkey = _hashsz;                                                         \
    }                                                                         \
                                                                              \
    /* Standard doesn't cover case where blocksz < hashsz. */                 \
    assert(nkey <= _blocksz);                                                 \
                                                                              \
    /* Right zero-pad short keys.  The copy is skipped when the key was       \
     * hashed above, because it then already lives in k. */                   \
    if (k != key)                                                             \
      memcpy(k, key, nkey);                                                   \
    if (_blocksz > nkey)                                                      \
      memset(k + nkey, 0, _blocksz - nkey);                                   \
                                                                              \
    /* Start inner hash computation */                                        \
    uint8_t blk_inner[_blocksz];                                              \
    uint8_t blk_outer[_blocksz];                                              \
                                                                              \
    for (size_t i = 0; i < _blocksz; i++)                                     \
    {                                                                         \
      blk_inner[i] = 0x36 ^ k[i];                                             \
      blk_outer[i] = 0x5c ^ k[i];                                             \
    }                                                                         \
                                                                              \
    _init(&ctx->inner);                                                       \
    _update(&ctx->inner, blk_inner, sizeof blk_inner);                        \
                                                                              \
    /* And outer. */                                                          \
    _init(&ctx->outer);                                                       \
    _update(&ctx->outer, blk_outer, sizeof blk_outer);                        \
  }                                                                           \
                                                                              \
  /* Feed ndata bytes of message into the inner hash. */                      \
  static inline void HMAC_UPDATE(_name)(HMAC_CTX(_name) *ctx,                 \
                                        const void *data, size_t ndata)       \
  {                                                                           \
    _update(&ctx->inner, data, ndata);                                        \
  }                                                                           \
                                                                              \
  /* Finish the MAC into out.  out first receives the inner digest, which     \
   * is fed to the outer hash and then overwritten by the outer digest. */    \
  static inline void HMAC_FINAL(_name)(HMAC_CTX(_name) *ctx,                  \
                                       uint8_t out[_hashsz])                  \
  {                                                                           \
    _final(&ctx->inner, out);                                                 \
    _update(&ctx->outer, out, _hashsz);                                       \
    _final(&ctx->outer, out);                                                 \
  }                                                                           \
                                                                              \
                                                                              \
  /* --- PBKDF2 --- */                                                        \
  /* Compute output block number `counter`: the XOR of U_1 .. U_iterations,   \
   * written as _hashsz bytes to out.  startctx must already be keyed         \
   * (HMAC_INIT) and is not modified.  At least one iteration is always       \
   * performed, even if iterations is 0. */                                   \
  static inline void PBKDF2_F(_name)(const HMAC_CTX(_name) *startctx,         \
                                     uint32_t counter,                        \
                                     const uint8_t *salt, size_t nsalt,       \
                                     uint32_t iterations,                     \
                                     uint8_t *out)                            \
  {                                                                           \
    uint8_t countbuf[4];                                                      \
    write32_be(counter, countbuf);                                            \
                                                                              \
    /* Prepare loop-invariant padding block.  Ublock[0.._hashsz) will hold    \
     * the previous U; the remainder is padding for a message of              \
     * _blocksz + _hashsz bytes (the keyed block plus one digest), so the     \
     * whole of Ublock can be handed straight to _xform. */                   \
    uint8_t Ublock[_blocksz];                                                 \
    md_pad(Ublock, _blocksz, _hashsz, _blocksz + _hashsz);                    \
                                                                              \
    /* First iteration:                                                       \
     *   U_1 = PRF(P, S || INT_32_BE(i))                                      \
     */                                                                       \
    HMAC_CTX(_name) ctx = *startctx;                                          \
    HMAC_UPDATE(_name)(&ctx, salt, nsalt);                                    \
    HMAC_UPDATE(_name)(&ctx, countbuf, sizeof countbuf);                      \
    HMAC_FINAL(_name)(&ctx, Ublock);                                          \
    /* The running XOR is kept as raw hash state, not as bytes; this relies   \
     * on ctx.outer still holding the final state after HMAC_FINAL. */        \
    _ctx result = ctx.outer;                                                  \
                                                                              \
    /* Subsequent iterations:                                                 \
     *   U_c = PRF(P, U_{c-1})                                                \
     */                                                                       \
    for (uint32_t i = 1; i < iterations; i++)                                 \
    {                                                                         \
      /* Complete inner hash with previous U */                               \
      _xcpy(&ctx.inner, &startctx->inner);                                    \
      _xform(&ctx.inner, Ublock);                                             \
      _xtract(&ctx.inner, Ublock);                                            \
      /* Complete outer hash with inner output */                             \
      _xcpy(&ctx.outer, &startctx->outer);                                    \
      _xform(&ctx.outer, Ublock);                                             \
      _xtract(&ctx.outer, Ublock);                                            \
      _xxor(&result, &ctx.outer);                                             \
    }                                                                         \
                                                                              \
    /* Reform result into output buffer. */                                   \
    _xtract(&result, out);                                                    \
  }                                                                           \
                                                                              \
  /* Derive nout bytes into out from password pw (npw bytes) and salt         \
   * (nsalt bytes).  iterations must be non-zero, out non-NULL and nout       \
   * non-zero; all three are enforced with assert only. */                    \
  static inline void PBKDF2(_name)(const uint8_t *pw, size_t npw,             \
                     const uint8_t *salt, size_t nsalt,                       \
                     uint32_t iterations,                                     \
                     uint8_t *out, size_t nout)                               \
  {                                                                           \
    assert(iterations);                                                       \
    assert(out && nout);                                                      \
                                                                              \
    /* Starting point for inner loop. */                                      \
    HMAC_CTX(_name) ctx;                                                      \
    HMAC_INIT(_name)(&ctx, pw, npw);                                          \
                                                                              \
    /* How many blocks do we need?  This is ceil(nout / _hashsz), computed    \
     * in size_t and without adding to nout, so a large nout is neither       \
     * wrapped nor truncated to 32 bits before the division. */               \
    uint32_t blocks_needed =                                                  \
      (uint32_t)(nout / _hashsz + (nout % _hashsz != 0));                     \
                                                                              \
    for (uint32_t counter = 1; counter <= blocks_needed; counter++)           \
    {                                                                         \
      uint8_t block[_hashsz];                                                 \
      PBKDF2_F(_name)(&ctx, counter, salt, nsalt, iterations, block);         \
                                                                              \
      /* Widen before multiplying so the byte offset is computed in size_t.   \
       * The last block may be only partially copied. */                      \
      size_t offset = (size_t)(counter - 1) * _hashsz;                        \
      size_t taken = MIN(nout - offset, _hashsz);                             \
      memcpy(out + offset, block, taken);                                     \
    }                                                                         \
  }
235 
/* Serialise the five 32-bit SHA-1 state words held in ctx into out[0..19].
 *
 * ctx is only read.  out is a plain byte buffer and is never accessed
 * through a uint32_t pointer, so it needs no particular alignment. */
static inline void sha1_extract(mbedtls_sha1_context *restrict ctx, uint8_t *restrict out)
{
#if defined(MBEDTLS_SHA1_ALT)
#if CONFIG_IDF_TARGET_ESP32
  /* ESP32 stores internal SHA state in BE format similar to software */
  for (size_t i = 0; i < 5; i++) {
    write32_be(ctx->state[i], out + 4 * i);
  }
#else
  /* Other alternative implementations: each word is emitted with the byte
   * layout it has in memory (no byte swap).  memcpy replaces a store through
   * a casted uint32_t pointer, which was undefined for an unaligned or
   * byte-typed destination; the bytes written are the same. */
  for (size_t i = 0; i < 5; i++) {
    const uint32_t word = ctx->state[i];
    memcpy(out + 4 * i, &word, sizeof word);
  }
#endif
#else
  /* Software mbedtls SHA-1: emit each state word big-endian. */
  for (size_t i = 0; i < 5; i++) {
    write32_be(ctx->MBEDTLS_PRIVATE(state)[i], out + 4 * i);
  }
#endif
}
261 
/* Copy just the five hash-state words from in to out.  No other member of
 * out is touched. */
static inline void sha1_cpy(mbedtls_sha1_context *restrict out, const mbedtls_sha1_context *restrict in)
{
  for (size_t word = 0; word < 5; word++) {
#if defined(MBEDTLS_SHA1_ALT)
    out->state[word] = in->state[word];
#else
    out->MBEDTLS_PRIVATE(state)[word] = in->MBEDTLS_PRIVATE(state)[word];
#endif
  }
}
278 
/* XOR the five hash-state words of in into the matching words of out.  No
 * other member of out is touched. */
static inline void sha1_xor(mbedtls_sha1_context *restrict out, const mbedtls_sha1_context *restrict in)
{
  for (size_t word = 0; word < 5; word++) {
#if defined(MBEDTLS_SHA1_ALT)
    out->state[word] ^= in->state[word];
#else
    out->MBEDTLS_PRIVATE(state)[word] ^= in->MBEDTLS_PRIVATE(state)[word];
#endif
  }
}
295 
/* Hash "init" callback passed to DECL_PBKDF2: initialises ctx and starts a
 * new SHA-1 computation in one call.
 *
 * Always returns 0. */
static int mbedtls_sha1_init_start(mbedtls_sha1_context *ctx)
{
  mbedtls_sha1_init(ctx);
  /* NOTE(review): anything mbedtls_sha1_starts() returns is discarded here -
   * confirm that ignoring it is intended. */
  mbedtls_sha1_starts(ctx);
#if defined(CONFIG_IDF_TARGET_ESP32) && defined(MBEDTLS_SHA1_ALT)
  /* Use software mode for esp32 since hardware can't give output more than 20 */
  esp_mbedtls_set_sha1_mode(ctx, ESP_MBEDTLS_SHA1_SOFTWARE);
#endif
  return 0;
}
306 
307 #ifndef MBEDTLS_SHA1_ALT
sha1_finish(mbedtls_sha1_context * ctx,unsigned char output[20])308 static int sha1_finish(mbedtls_sha1_context *ctx,
309                         unsigned char output[20])
310 {
311     int ret = -1;
312     uint32_t used;
313     uint32_t high, low;
314 
315     /*
316      * Add padding: 0x80 then 0x00 until 8 bytes remain for the length
317      */
318     used = ctx->MBEDTLS_PRIVATE(total)[0] & 0x3F;
319 
320     ctx->MBEDTLS_PRIVATE(buffer)[used++] = 0x80;
321 
322     if (used <= 56) {
323         /* Enough room for padding + length in current block */
324         memset(ctx->MBEDTLS_PRIVATE(buffer) + used, 0, 56 - used);
325     } else {
326         /* We'll need an extra block */
327         memset(ctx->MBEDTLS_PRIVATE(buffer) + used, 0, 64 - used);
328 
329         if ((ret = mbedtls_internal_sha1_process(ctx, ctx->MBEDTLS_PRIVATE(buffer))) != 0) {
330             goto exit;
331         }
332 
333         memset(ctx->MBEDTLS_PRIVATE(buffer), 0, 56);
334     }
335 
336     /*
337      * Add message length
338      */
339     high = (ctx->MBEDTLS_PRIVATE(total)[0] >> 29)
340            | (ctx->MBEDTLS_PRIVATE(total)[1] <<  3);
341     low  = (ctx->MBEDTLS_PRIVATE(total)[0] <<  3);
342 
343     write32_be(high, ctx->MBEDTLS_PRIVATE(buffer) + 56);
344     write32_be(low, ctx->MBEDTLS_PRIVATE(buffer) + 60);
345 
346     if ((ret = mbedtls_internal_sha1_process(ctx, ctx->MBEDTLS_PRIVATE(buffer))) != 0) {
347         goto exit;
348     }
349 
350     /*
351      * Output final state
352      */
353     write32_be(ctx->MBEDTLS_PRIVATE(state)[0], output);
354     write32_be(ctx->MBEDTLS_PRIVATE(state)[1], output + 4);
355     write32_be(ctx->MBEDTLS_PRIVATE(state)[2], output + 8);
356     write32_be(ctx->MBEDTLS_PRIVATE(state)[3], output + 12);
357     write32_be(ctx->MBEDTLS_PRIVATE(state)[4], output + 16);
358 
359     ret = 0;
360 
361 exit:
362     return ret;
363 }
364 #endif
365 
366 DECL_PBKDF2(sha1,                           // _name
367             64,                             // _blocksz
368             20,                             // _hashsz
369             mbedtls_sha1_context,           // _ctx
370             mbedtls_sha1_init_start,        // _init
371             mbedtls_sha1_update,            // _update
372             mbedtls_internal_sha1_process,  // _xform
373 #if defined(MBEDTLS_SHA1_ALT)
374             mbedtls_sha1_finish,            // _final
375 #else
376             sha1_finish,                   // _final
377 #endif
378             sha1_cpy,                       // _xcpy
379             sha1_extract,                   // _xtract
380             sha1_xor)                       // _xxor
381 
/* Public entry point: derive nout bytes of key material into out with
 * PBKDF2 using HMAC-SHA1.
 *
 * pw, npw       password bytes and their length
 * salt, nsalt   salt bytes and their length
 * iterations    iteration count; must be non-zero
 * out, nout     output buffer and number of bytes to produce; out must be
 *               non-NULL and nout non-zero
 *
 * The preconditions are checked with assert() only.  Nothing is returned;
 * this simply forwards to the PBKDF2(sha1) instance generated by
 * DECL_PBKDF2. */
void fastpbkdf2_hmac_sha1(const uint8_t *pw, size_t npw,
                          const uint8_t *salt, size_t nsalt,
                          uint32_t iterations,
                          uint8_t *out, size_t nout)
{
  PBKDF2(sha1)(pw, npw, salt, nsalt, iterations, out, nout);
}
389