1 /*
2  *  SSL session cache implementation
3  *
4  *  Copyright The Mbed TLS Contributors
5  *  SPDX-License-Identifier: Apache-2.0
6  *
7  *  Licensed under the Apache License, Version 2.0 (the "License"); you may
8  *  not use this file except in compliance with the License.
9  *  You may obtain a copy of the License at
10  *
11  *  http://www.apache.org/licenses/LICENSE-2.0
12  *
13  *  Unless required by applicable law or agreed to in writing, software
14  *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15  *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  *  See the License for the specific language governing permissions and
17  *  limitations under the License.
18  */
19 /*
20  * These session callbacks use a simple chained list
21  * to store and retrieve the session information.
22  */
23 
24 #include "common.h"
25 
26 #if defined(MBEDTLS_SSL_CACHE_C)
27 
28 #include "mbedtls/platform.h"
29 
30 #include "mbedtls/ssl_cache.h"
31 #include "ssl_misc.h"
32 
33 #include <string.h>
34 
/*
 * Prepare a cache context for use: every field is cleared (so the entry
 * chain starts out empty), the timeout and maximum entry count are set to
 * their compile-time defaults, and, with MBEDTLS_THREADING_C, the mutex
 * guarding the chain is initialized.
 */
void mbedtls_ssl_cache_init(mbedtls_ssl_cache_context *cache)
{
    /* Wipe the whole context before assigning the non-zero defaults. */
    memset(cache, 0, sizeof(*cache));

    cache->max_entries = MBEDTLS_SSL_CACHE_DEFAULT_MAX_ENTRIES;
    cache->timeout = MBEDTLS_SSL_CACHE_DEFAULT_TIMEOUT;

#if defined(MBEDTLS_THREADING_C)
    mbedtls_mutex_init(&cache->mutex);
#endif
}
46 
47 MBEDTLS_CHECK_RETURN_CRITICAL
ssl_cache_find_entry(mbedtls_ssl_cache_context * cache,unsigned char const * session_id,size_t session_id_len,mbedtls_ssl_cache_entry ** dst)48 static int ssl_cache_find_entry(mbedtls_ssl_cache_context *cache,
49                                 unsigned char const *session_id,
50                                 size_t session_id_len,
51                                 mbedtls_ssl_cache_entry **dst)
52 {
53     int ret = 1;
54 #if defined(MBEDTLS_HAVE_TIME)
55     mbedtls_time_t t = mbedtls_time(NULL);
56 #endif
57     mbedtls_ssl_cache_entry *cur;
58 
59     for (cur = cache->chain; cur != NULL; cur = cur->next) {
60 #if defined(MBEDTLS_HAVE_TIME)
61         if (cache->timeout != 0 &&
62             (int) (t - cur->timestamp) > cache->timeout) {
63             continue;
64         }
65 #endif
66 
67         if (session_id_len != cur->session_id_len ||
68             memcmp(session_id, cur->session_id,
69                    cur->session_id_len) != 0) {
70             continue;
71         }
72 
73         break;
74     }
75 
76     if (cur != NULL) {
77         *dst = cur;
78         ret = 0;
79     }
80 
81     return ret;
82 }
83 
84 
/*
 * Session-cache "get" callback: find a live entry for `session_id` and
 * deserialize its stored session into `session`.
 *
 * Returns 0 on success, 1 when no matching (non-expired) entry exists,
 * the error from mbedtls_ssl_session_load() if deserialization fails, and
 * (with MBEDTLS_THREADING_C) the mutex-lock error or
 * MBEDTLS_ERR_THREADING_MUTEX_ERROR if unlocking fails.
 */
int mbedtls_ssl_cache_get(void *data,
                          unsigned char const *session_id,
                          size_t session_id_len,
                          mbedtls_ssl_session *session)
{
    mbedtls_ssl_cache_context *ctx = (mbedtls_ssl_cache_context *) data;
    mbedtls_ssl_cache_entry *hit = NULL;
    int rc;

#if defined(MBEDTLS_THREADING_C)
    rc = mbedtls_mutex_lock(&ctx->mutex);
    if (rc != 0) {
        return rc;
    }
#endif

    rc = ssl_cache_find_entry(ctx, session_id, session_id_len, &hit);
    if (rc == 0) {
        /* Entry located: rebuild the caller's session from its blob. */
        rc = mbedtls_ssl_session_load(session, hit->session, hit->session_len);
    }

#if defined(MBEDTLS_THREADING_C)
    if (mbedtls_mutex_unlock(&ctx->mutex) != 0) {
        rc = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
    }
#endif

    return rc;
}
123 
124 /* zeroize a cache entry */
ssl_cache_entry_zeroize(mbedtls_ssl_cache_entry * entry)125 static void ssl_cache_entry_zeroize(mbedtls_ssl_cache_entry *entry)
126 {
127     if (entry == NULL) {
128         return;
129     }
130 
131     /* zeroize and free session structure */
132     if (entry->session != NULL) {
133         mbedtls_platform_zeroize(entry->session, entry->session_len);
134         mbedtls_free(entry->session);
135     }
136 
137     /* zeroize the whole entry structure */
138     mbedtls_platform_zeroize(entry, sizeof(mbedtls_ssl_cache_entry));
139 }
140 
141 MBEDTLS_CHECK_RETURN_CRITICAL
ssl_cache_pick_writing_slot(mbedtls_ssl_cache_context * cache,unsigned char const * session_id,size_t session_id_len,mbedtls_ssl_cache_entry ** dst)142 static int ssl_cache_pick_writing_slot(mbedtls_ssl_cache_context *cache,
143                                        unsigned char const *session_id,
144                                        size_t session_id_len,
145                                        mbedtls_ssl_cache_entry **dst)
146 {
147 #if defined(MBEDTLS_HAVE_TIME)
148     mbedtls_time_t t = mbedtls_time(NULL), oldest = 0;
149 #endif /* MBEDTLS_HAVE_TIME */
150 
151     mbedtls_ssl_cache_entry *old = NULL;
152     int count = 0;
153     mbedtls_ssl_cache_entry *cur, *last;
154 
155     /* Check 1: Is there already an entry with the given session ID?
156      *
157      * If yes, overwrite it.
158      *
159      * If not, `count` will hold the size of the session cache
160      * at the end of this loop, and `last` will point to the last
161      * entry, both of which will be used later. */
162 
163     last = NULL;
164     for (cur = cache->chain; cur != NULL; cur = cur->next) {
165         count++;
166         if (session_id_len == cur->session_id_len &&
167             memcmp(session_id, cur->session_id, cur->session_id_len) == 0) {
168             goto found;
169         }
170         last = cur;
171     }
172 
173     /* Check 2: Is there an outdated entry in the cache?
174      *
175      * If so, overwrite it.
176      *
177      * If not, remember the oldest entry in `old` for later.
178      */
179 
180 #if defined(MBEDTLS_HAVE_TIME)
181     for (cur = cache->chain; cur != NULL; cur = cur->next) {
182         if (cache->timeout != 0 &&
183             (int) (t - cur->timestamp) > cache->timeout) {
184             goto found;
185         }
186 
187         if (oldest == 0 || cur->timestamp < oldest) {
188             oldest = cur->timestamp;
189             old = cur;
190         }
191     }
192 #endif /* MBEDTLS_HAVE_TIME */
193 
194     /* Check 3: Is there free space in the cache? */
195 
196     if (count < cache->max_entries) {
197         /* Create new entry */
198         cur = mbedtls_calloc(1, sizeof(mbedtls_ssl_cache_entry));
199         if (cur == NULL) {
200             return 1;
201         }
202 
203         /* Append to the end of the linked list. */
204         if (last == NULL) {
205             cache->chain = cur;
206         } else {
207             last->next = cur;
208         }
209 
210         goto found;
211     }
212 
213     /* Last resort: The cache is full and doesn't contain any outdated
214      * elements. In this case, we evict the oldest one, judged by timestamp
215      * (if present) or cache-order. */
216 
217 #if defined(MBEDTLS_HAVE_TIME)
218     if (old == NULL) {
219         /* This should only happen on an ill-configured cache
220          * with max_entries == 0. */
221         return 1;
222     }
223 #else /* MBEDTLS_HAVE_TIME */
224     /* Reuse first entry in chain, but move to last place. */
225     if (cache->chain == NULL) {
226         return 1;
227     }
228 
229     old = cache->chain;
230     cache->chain = old->next;
231     old->next = NULL;
232     last->next = old;
233 #endif /* MBEDTLS_HAVE_TIME */
234 
235     /* Now `old` points to the oldest entry to be overwritten. */
236     cur = old;
237 
238 found:
239 
240     /* If we're reusing an entry, free it first. */
241     if (cur->session != NULL) {
242         /* `ssl_cache_entry_zeroize` would break the chain,
243          * so we reuse `old` to record `next` temporarily. */
244         old = cur->next;
245         ssl_cache_entry_zeroize(cur);
246         cur->next = old;
247     }
248 
249 #if defined(MBEDTLS_HAVE_TIME)
250     cur->timestamp = t;
251 #endif
252 
253     *dst = cur;
254     return 0;
255 }
256 
mbedtls_ssl_cache_set(void * data,unsigned char const * session_id,size_t session_id_len,const mbedtls_ssl_session * session)257 int mbedtls_ssl_cache_set(void *data,
258                           unsigned char const *session_id,
259                           size_t session_id_len,
260                           const mbedtls_ssl_session *session)
261 {
262     int ret = 1;
263     mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
264     mbedtls_ssl_cache_entry *cur;
265 
266     size_t session_serialized_len;
267     unsigned char *session_serialized = NULL;
268 
269 #if defined(MBEDTLS_THREADING_C)
270     if ((ret = mbedtls_mutex_lock(&cache->mutex)) != 0) {
271         return ret;
272     }
273 #endif
274 
275     ret = ssl_cache_pick_writing_slot(cache,
276                                       session_id, session_id_len,
277                                       &cur);
278     if (ret != 0) {
279         goto exit;
280     }
281 
282     /* Check how much space we need to serialize the session
283      * and allocate a sufficiently large buffer. */
284     ret = mbedtls_ssl_session_save(session, NULL, 0, &session_serialized_len);
285     if (ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL) {
286         ret = 1;
287         goto exit;
288     }
289 
290     session_serialized = mbedtls_calloc(1, session_serialized_len);
291     if (session_serialized == NULL) {
292         ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
293         goto exit;
294     }
295 
296     /* Now serialize the session into the allocated buffer. */
297     ret = mbedtls_ssl_session_save(session,
298                                    session_serialized,
299                                    session_serialized_len,
300                                    &session_serialized_len);
301     if (ret != 0) {
302         goto exit;
303     }
304 
305     if (session_id_len > sizeof(cur->session_id)) {
306         ret = 1;
307         goto exit;
308     }
309     cur->session_id_len = session_id_len;
310     memcpy(cur->session_id, session_id, session_id_len);
311 
312     cur->session = session_serialized;
313     cur->session_len = session_serialized_len;
314     session_serialized = NULL;
315 
316     ret = 0;
317 
318 exit:
319 #if defined(MBEDTLS_THREADING_C)
320     if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
321         ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
322     }
323 #endif
324 
325     if (session_serialized != NULL) {
326         mbedtls_platform_zeroize(session_serialized, session_serialized_len);
327         mbedtls_free(session_serialized);
328         session_serialized = NULL;
329     }
330 
331     return ret;
332 }
333 
/*
 * Session-cache "remove" callback: unlink and destroy the live entry whose
 * ID matches `session_id`. The entry's serialized session is wiped before
 * the memory is released.
 *
 * Returns 0 both when an entry was removed and when none was found (an
 * absent or expired entry is not an error). With MBEDTLS_THREADING_C, a
 * mutex-lock failure is returned as-is, and an unlock failure yields
 * MBEDTLS_ERR_THREADING_MUTEX_ERROR.
 */
int mbedtls_ssl_cache_remove(void *data,
                             unsigned char const *session_id,
                             size_t session_id_len)
{
    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
    mbedtls_ssl_cache_entry *victim;
    mbedtls_ssl_cache_entry **link;
    int ret;

#if defined(MBEDTLS_THREADING_C)
    ret = mbedtls_mutex_lock(&cache->mutex);
    if (ret != 0) {
        return ret;
    }
#endif

    if (ssl_cache_find_entry(cache, session_id, session_id_len, &victim) == 0) {
        /* Walk the links (starting at the chain head) until the one that
         * points at the victim, then splice the victim out. */
        for (link = &cache->chain; *link != NULL; link = &(*link)->next) {
            if (*link == victim) {
                *link = victim->next;
                break;
            }
        }

        ssl_cache_entry_zeroize(victim);
        mbedtls_free(victim);
    }

    ret = 0;

#if defined(MBEDTLS_THREADING_C)
    if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
    }
#endif

    return ret;
}
382 
#if defined(MBEDTLS_HAVE_TIME)
/*
 * Set the lifetime, in seconds of mbedtls_time(), after which cached
 * entries are treated as expired. Negative values are clamped to 0, and a
 * timeout of 0 disables the expiry checks performed by the lookup and
 * slot-selection code.
 */
void mbedtls_ssl_cache_set_timeout(mbedtls_ssl_cache_context *cache, int timeout)
{
    cache->timeout = (timeout < 0) ? 0 : timeout;
}
#endif /* MBEDTLS_HAVE_TIME */
393 
/*
 * Set the maximum number of entries the cache may allocate. Negative
 * values are clamped to 0.
 */
void mbedtls_ssl_cache_set_max_entries(mbedtls_ssl_cache_context *cache, int max)
{
    cache->max_entries = (max < 0) ? 0 : max;
}
402 
/*
 * Release every entry held by the cache. Each entry's serialized session
 * is wiped and freed, the entry itself is zeroized and freed, the mutex is
 * freed when MBEDTLS_THREADING_C is enabled, and the chain is left empty.
 * A NULL context is accepted and ignored.
 */
void mbedtls_ssl_cache_free(mbedtls_ssl_cache_context *cache)
{
    mbedtls_ssl_cache_entry *cur, *prv;

    /* Tolerate NULL so that cleanup paths can call this unconditionally. */
    if (cache == NULL) {
        return;
    }

    cur = cache->chain;

    while (cur != NULL) {
        prv = cur;
        /* Read the link before the entry is wiped and freed. */
        cur = cur->next;

        ssl_cache_entry_zeroize(prv);
        mbedtls_free(prv);
    }

#if defined(MBEDTLS_THREADING_C)
    mbedtls_mutex_free(&cache->mutex);
#endif
    cache->chain = NULL;
}
422 
423 #endif /* MBEDTLS_SSL_CACHE_C */
424