1 /*
2  *  SSL session cache implementation
3  *
4  *  Copyright The Mbed TLS Contributors
5  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6  */
/*
 * These session callbacks use a simple singly linked list
 * to store and retrieve the session information.
 */
11 
12 #include "common.h"
13 
14 #if defined(MBEDTLS_SSL_CACHE_C)
15 
16 #include "mbedtls/platform.h"
17 
18 #include "mbedtls/ssl_cache.h"
19 #include "ssl_misc.h"
20 #include "mbedtls/error.h"
21 
22 #include <string.h>
23 
/*
 * Initialise a session cache context.
 *
 * The whole structure is cleared first, which leaves the entry chain empty.
 * The compile-time default timeout and maximum number of entries are then
 * applied. With threading enabled, the mutex is initialised last so that the
 * memset does not wipe it.
 */
void mbedtls_ssl_cache_init(mbedtls_ssl_cache_context *cache)
{
    memset(cache, 0, sizeof(*cache));

    cache->max_entries = MBEDTLS_SSL_CACHE_DEFAULT_MAX_ENTRIES;
    cache->timeout = MBEDTLS_SSL_CACHE_DEFAULT_TIMEOUT;

#if defined(MBEDTLS_THREADING_C)
    mbedtls_mutex_init(&cache->mutex);
#endif
}
35 
36 MBEDTLS_CHECK_RETURN_CRITICAL
ssl_cache_find_entry(mbedtls_ssl_cache_context * cache,unsigned char const * session_id,size_t session_id_len,mbedtls_ssl_cache_entry ** dst)37 static int ssl_cache_find_entry(mbedtls_ssl_cache_context *cache,
38                                 unsigned char const *session_id,
39                                 size_t session_id_len,
40                                 mbedtls_ssl_cache_entry **dst)
41 {
42     int ret = MBEDTLS_ERR_SSL_CACHE_ENTRY_NOT_FOUND;
43 #if defined(MBEDTLS_HAVE_TIME)
44     mbedtls_time_t t = mbedtls_time(NULL);
45 #endif
46     mbedtls_ssl_cache_entry *cur;
47 
48     for (cur = cache->chain; cur != NULL; cur = cur->next) {
49 #if defined(MBEDTLS_HAVE_TIME)
50         if (cache->timeout != 0 &&
51             (int) (t - cur->timestamp) > cache->timeout) {
52             continue;
53         }
54 #endif
55 
56         if (session_id_len != cur->session_id_len ||
57             memcmp(session_id, cur->session_id,
58                    cur->session_id_len) != 0) {
59             continue;
60         }
61 
62         break;
63     }
64 
65     if (cur != NULL) {
66         *dst = cur;
67         ret = 0;
68     }
69 
70     return ret;
71 }
72 
73 
/*
 * Session-cache "get" callback.
 *
 * Looks up session_id in the cache passed as `data`. If a live entry
 * exists, its stored serialized session is loaded into *session.
 *
 * Returns:
 * - 0 on success;
 * - MBEDTLS_ERR_SSL_CACHE_ENTRY_NOT_FOUND when no live entry matches;
 * - the error from mbedtls_ssl_session_load() if loading fails;
 * - a threading error if the cache mutex cannot be locked or unlocked.
 */
int mbedtls_ssl_cache_get(void *data,
                          unsigned char const *session_id,
                          size_t session_id_len,
                          mbedtls_ssl_session *session)
{
    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
    mbedtls_ssl_cache_entry *entry = NULL;
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;

#if defined(MBEDTLS_THREADING_C)
    ret = mbedtls_mutex_lock(&cache->mutex);
    if (ret != 0) {
        return ret;
    }
#endif

    ret = ssl_cache_find_entry(cache, session_id, session_id_len, &entry);
    if (ret == 0) {
        ret = mbedtls_ssl_session_load(session,
                                       entry->session,
                                       entry->session_len);
    }

#if defined(MBEDTLS_THREADING_C)
    if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
    }
#endif

    return ret;
}
112 
113 /* zeroize a cache entry */
ssl_cache_entry_zeroize(mbedtls_ssl_cache_entry * entry)114 static void ssl_cache_entry_zeroize(mbedtls_ssl_cache_entry *entry)
115 {
116     if (entry == NULL) {
117         return;
118     }
119 
120     /* zeroize and free session structure */
121     if (entry->session != NULL) {
122         mbedtls_zeroize_and_free(entry->session, entry->session_len);
123     }
124 
125     /* zeroize the whole entry structure */
126     mbedtls_platform_zeroize(entry, sizeof(mbedtls_ssl_cache_entry));
127 }
128 
129 MBEDTLS_CHECK_RETURN_CRITICAL
ssl_cache_pick_writing_slot(mbedtls_ssl_cache_context * cache,unsigned char const * session_id,size_t session_id_len,mbedtls_ssl_cache_entry ** dst)130 static int ssl_cache_pick_writing_slot(mbedtls_ssl_cache_context *cache,
131                                        unsigned char const *session_id,
132                                        size_t session_id_len,
133                                        mbedtls_ssl_cache_entry **dst)
134 {
135 #if defined(MBEDTLS_HAVE_TIME)
136     mbedtls_time_t t = mbedtls_time(NULL), oldest = 0;
137 #endif /* MBEDTLS_HAVE_TIME */
138 
139     mbedtls_ssl_cache_entry *old = NULL;
140     int count = 0;
141     mbedtls_ssl_cache_entry *cur, *last;
142 
143     /* Check 1: Is there already an entry with the given session ID?
144      *
145      * If yes, overwrite it.
146      *
147      * If not, `count` will hold the size of the session cache
148      * at the end of this loop, and `last` will point to the last
149      * entry, both of which will be used later. */
150 
151     last = NULL;
152     for (cur = cache->chain; cur != NULL; cur = cur->next) {
153         count++;
154         if (session_id_len == cur->session_id_len &&
155             memcmp(session_id, cur->session_id, cur->session_id_len) == 0) {
156             goto found;
157         }
158         last = cur;
159     }
160 
161     /* Check 2: Is there an outdated entry in the cache?
162      *
163      * If so, overwrite it.
164      *
165      * If not, remember the oldest entry in `old` for later.
166      */
167 
168 #if defined(MBEDTLS_HAVE_TIME)
169     for (cur = cache->chain; cur != NULL; cur = cur->next) {
170         if (cache->timeout != 0 &&
171             (int) (t - cur->timestamp) > cache->timeout) {
172             goto found;
173         }
174 
175         if (oldest == 0 || cur->timestamp < oldest) {
176             oldest = cur->timestamp;
177             old = cur;
178         }
179     }
180 #endif /* MBEDTLS_HAVE_TIME */
181 
182     /* Check 3: Is there free space in the cache? */
183 
184     if (count < cache->max_entries) {
185         /* Create new entry */
186         cur = mbedtls_calloc(1, sizeof(mbedtls_ssl_cache_entry));
187         if (cur == NULL) {
188             return MBEDTLS_ERR_SSL_ALLOC_FAILED;
189         }
190 
191         /* Append to the end of the linked list. */
192         if (last == NULL) {
193             cache->chain = cur;
194         } else {
195             last->next = cur;
196         }
197 
198         goto found;
199     }
200 
201     /* Last resort: The cache is full and doesn't contain any outdated
202      * elements. In this case, we evict the oldest one, judged by timestamp
203      * (if present) or cache-order. */
204 
205 #if defined(MBEDTLS_HAVE_TIME)
206     if (old == NULL) {
207         /* This should only happen on an ill-configured cache
208          * with max_entries == 0. */
209         return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
210     }
211 #else /* MBEDTLS_HAVE_TIME */
212     /* Reuse first entry in chain, but move to last place. */
213     if (cache->chain == NULL) {
214         /* This should never happen */
215         return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
216     }
217 
218     old = cache->chain;
219     cache->chain = old->next;
220     old->next = NULL;
221     last->next = old;
222 #endif /* MBEDTLS_HAVE_TIME */
223 
224     /* Now `old` points to the oldest entry to be overwritten. */
225     cur = old;
226 
227 found:
228 
229     /* If we're reusing an entry, free it first. */
230     if (cur->session != NULL) {
231         /* `ssl_cache_entry_zeroize` would break the chain,
232          * so we reuse `old` to record `next` temporarily. */
233         old = cur->next;
234         ssl_cache_entry_zeroize(cur);
235         cur->next = old;
236     }
237 
238 #if defined(MBEDTLS_HAVE_TIME)
239     cur->timestamp = t;
240 #endif
241 
242     *dst = cur;
243     return 0;
244 }
245 
mbedtls_ssl_cache_set(void * data,unsigned char const * session_id,size_t session_id_len,const mbedtls_ssl_session * session)246 int mbedtls_ssl_cache_set(void *data,
247                           unsigned char const *session_id,
248                           size_t session_id_len,
249                           const mbedtls_ssl_session *session)
250 {
251     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
252     mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
253     mbedtls_ssl_cache_entry *cur;
254 
255     size_t session_serialized_len = 0;
256     unsigned char *session_serialized = NULL;
257 
258 #if defined(MBEDTLS_THREADING_C)
259     if ((ret = mbedtls_mutex_lock(&cache->mutex)) != 0) {
260         return ret;
261     }
262 #endif
263 
264     ret = ssl_cache_pick_writing_slot(cache,
265                                       session_id, session_id_len,
266                                       &cur);
267     if (ret != 0) {
268         goto exit;
269     }
270 
271     /* Check how much space we need to serialize the session
272      * and allocate a sufficiently large buffer. */
273     ret = mbedtls_ssl_session_save(session, NULL, 0, &session_serialized_len);
274     if (ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) {
275         goto exit;
276     }
277 
278     session_serialized = mbedtls_calloc(1, session_serialized_len);
279     if (session_serialized == NULL) {
280         ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
281         goto exit;
282     }
283 
284     /* Now serialize the session into the allocated buffer. */
285     ret = mbedtls_ssl_session_save(session,
286                                    session_serialized,
287                                    session_serialized_len,
288                                    &session_serialized_len);
289     if (ret != 0) {
290         goto exit;
291     }
292 
293     if (session_id_len > sizeof(cur->session_id)) {
294         ret = MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
295         goto exit;
296     }
297     cur->session_id_len = session_id_len;
298     memcpy(cur->session_id, session_id, session_id_len);
299 
300     cur->session = session_serialized;
301     cur->session_len = session_serialized_len;
302     session_serialized = NULL;
303 
304     ret = 0;
305 
306 exit:
307 #if defined(MBEDTLS_THREADING_C)
308     if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
309         ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
310     }
311 #endif
312 
313     if (session_serialized != NULL) {
314         mbedtls_zeroize_and_free(session_serialized, session_serialized_len);
315         session_serialized = NULL;
316     }
317 
318     return ret;
319 }
320 
/*
 * Session-cache "remove" callback.
 *
 * Unlinks the live entry for session_id from the cache passed as `data`,
 * then wipes and frees it. Finding no live entry (absent or expired) is
 * not an error and returns 0. A threading error is returned if the cache
 * mutex cannot be locked or unlocked.
 */
int mbedtls_ssl_cache_remove(void *data,
                             unsigned char const *session_id,
                             size_t session_id_len)
{
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
    mbedtls_ssl_cache_entry *victim = NULL;

#if defined(MBEDTLS_THREADING_C)
    ret = mbedtls_mutex_lock(&cache->mutex);
    if (ret != 0) {
        return ret;
    }
#endif

    if (ssl_cache_find_entry(cache, session_id, session_id_len, &victim) == 0) {
        /* Walk the links themselves so that removing the head and removing
         * an interior entry share one code path. */
        mbedtls_ssl_cache_entry **link = &cache->chain;
        while (*link != NULL && *link != victim) {
            link = &(*link)->next;
        }
        if (*link != NULL) {
            *link = victim->next;
        }

        ssl_cache_entry_zeroize(victim);
        mbedtls_free(victim);
    }
    ret = 0;

#if defined(MBEDTLS_THREADING_C)
    if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
    }
#endif

    return ret;
}
369 
#if defined(MBEDTLS_HAVE_TIME)
/*
 * Set how long cache entries stay usable, in the units of mbedtls_time().
 * An entry is treated as expired once the difference between the current
 * time and its timestamp exceeds this value. 0 disables expiry, and
 * negative values are clamped to 0.
 */
void mbedtls_ssl_cache_set_timeout(mbedtls_ssl_cache_context *cache, int timeout)
{
    cache->timeout = (timeout < 0) ? 0 : timeout;
}
#endif /* MBEDTLS_HAVE_TIME */
380 
/*
 * Set the maximum number of entries the cache may allocate. Negative
 * values are clamped to 0.
 *
 * Only the limit is updated; existing entries are not freed. A lower limit
 * makes later stores reuse or evict entries instead of allocating new ones.
 */
void mbedtls_ssl_cache_set_max_entries(mbedtls_ssl_cache_context *cache, int max)
{
    cache->max_entries = (max > 0) ? max : 0;
}
389 
/*
 * Release all memory held by the cache. Each entry is wiped and then freed,
 * the mutex is freed when threading is enabled, and the chain is left
 * empty. The context structure itself is not freed.
 */
void mbedtls_ssl_cache_free(mbedtls_ssl_cache_context *cache)
{
    mbedtls_ssl_cache_entry *entry = cache->chain;

    while (entry != NULL) {
        /* Read the link first: zeroizing the entry clears it. */
        mbedtls_ssl_cache_entry *const following = entry->next;

        ssl_cache_entry_zeroize(entry);
        mbedtls_free(entry);
        entry = following;
    }

#if defined(MBEDTLS_THREADING_C)
    mbedtls_mutex_free(&cache->mutex);
#endif
    cache->chain = NULL;
}
409 
410 #endif /* MBEDTLS_SSL_CACHE_C */
411