1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the necessary hooks for mbedTLS.
32  */
33 
34 #include "secure_transport.hpp"
35 
36 #include <mbedtls/debug.h>
37 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
38 #include <mbedtls/pem.h>
39 #endif
40 
41 #include <openthread/platform/radio.h>
42 
43 #include "common/as_core_type.hpp"
44 #include "common/clearable.hpp"
45 #include "common/code_utils.hpp"
46 #include "common/debug.hpp"
47 #include "common/encoding.hpp"
48 #include "common/locator_getters.hpp"
49 #include "common/log.hpp"
50 #include "common/timer.hpp"
51 #include "crypto/mbedtls.hpp"
52 #include "crypto/sha256.hpp"
53 #include "instance/instance.hpp"
54 #include "thread/thread_netif.hpp"
55 
56 #if OPENTHREAD_CONFIG_SECURE_TRANSPORT_ENABLE
57 
58 namespace ot {
59 namespace MeshCoP {
60 
61 RegisterLogModule("SecTransport");
62 
63 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
64 const uint16_t SecureTransport::sGroups[] = {MBEDTLS_SSL_IANA_TLS_GROUP_SECP256R1, MBEDTLS_SSL_IANA_TLS_GROUP_NONE};
65 #else
66 const mbedtls_ecp_group_id SecureTransport::sCurves[] = {MBEDTLS_ECP_DP_SECP256R1, MBEDTLS_ECP_DP_NONE};
67 #endif
68 
69 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
70 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
71 const uint16_t SecureTransport::sSignatures[] = {MBEDTLS_TLS1_3_SIG_ECDSA_SECP256R1_SHA256, MBEDTLS_TLS1_3_SIG_NONE};
72 #else
73 const int SecureTransport::sHashes[] = {MBEDTLS_MD_SHA256, MBEDTLS_MD_NONE};
74 #endif
75 #endif
76 
SecureTransport(Instance & aInstance,bool aLayerTwoSecurity,bool aDatagramTransport)77 SecureTransport::SecureTransport(Instance &aInstance, bool aLayerTwoSecurity, bool aDatagramTransport)
78     : InstanceLocator(aInstance)
79     , mState(kStateClosed)
80     , mPskLength(0)
81     , mVerifyPeerCertificate(true)
82     , mTimer(aInstance, SecureTransport::HandleTimer, this)
83     , mTimerIntermediate(0)
84     , mTimerSet(false)
85     , mLayerTwoSecurity(aLayerTwoSecurity)
86     , mDatagramTransport(aDatagramTransport)
87     , mMaxConnectionAttempts(0)
88     , mRemainingConnectionAttempts(0)
89     , mReceiveMessage(nullptr)
90     , mSocket(aInstance, *this)
91     , mMessageSubType(Message::kSubTypeNone)
92     , mMessageDefaultSubType(Message::kSubTypeNone)
93 {
94 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
95 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
96     mPreSharedKey         = nullptr;
97     mPreSharedKeyIdentity = nullptr;
98     mPreSharedKeyIdLength = 0;
99     mPreSharedKeyLength   = 0;
100 #endif
101 
102 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
103     mCaChainSrc       = nullptr;
104     mCaChainLength    = 0;
105     mOwnCertSrc       = nullptr;
106     mOwnCertLength    = 0;
107     mPrivateKeySrc    = nullptr;
108     mPrivateKeyLength = 0;
109     ClearAllBytes(mCaChain);
110     ClearAllBytes(mOwnCert);
111     ClearAllBytes(mPrivateKey);
112 #endif
113 #endif
114 
115     ClearAllBytes(mCipherSuites);
116     ClearAllBytes(mPsk);
117     ClearAllBytes(mSsl);
118     ClearAllBytes(mConf);
119 
120 #ifdef MBEDTLS_SSL_COOKIE_C
121     ClearAllBytes(mCookieCtx);
122 #endif
123 }
124 
FreeMbedtls(void)125 void SecureTransport::FreeMbedtls(void)
126 {
127 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
128     if (mDatagramTransport)
129     {
130         mbedtls_ssl_cookie_free(&mCookieCtx);
131     }
132 #endif
133 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
134 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
135     mbedtls_x509_crt_free(&mCaChain);
136     mbedtls_x509_crt_free(&mOwnCert);
137     mbedtls_pk_free(&mPrivateKey);
138 #endif
139 #endif
140     mbedtls_ssl_config_free(&mConf);
141     mbedtls_ssl_free(&mSsl);
142 }
143 
SetState(State aState)144 void SecureTransport::SetState(State aState)
145 {
146     VerifyOrExit(mState != aState);
147 
148     LogInfo("State: %s -> %s", StateToString(mState), StateToString(aState));
149     mState = aState;
150 
151 exit:
152     return;
153 }
154 
Open(ReceiveHandler aReceiveHandler,ConnectedHandler aConnectedHandler,void * aContext)155 Error SecureTransport::Open(ReceiveHandler aReceiveHandler, ConnectedHandler aConnectedHandler, void *aContext)
156 {
157     Error error;
158 
159     VerifyOrExit(IsStateClosed(), error = kErrorAlready);
160 
161     SuccessOrExit(error = mSocket.Open());
162 
163     mConnectedCallback.Set(aConnectedHandler, aContext);
164     mReceiveCallback.Set(aReceiveHandler, aContext);
165 
166     mRemainingConnectionAttempts = mMaxConnectionAttempts;
167 
168     SetState(kStateOpen);
169 
170 exit:
171     return error;
172 }
173 
SetMaxConnectionAttempts(uint16_t aMaxAttempts,AutoCloseCallback aCallback,void * aContext)174 Error SecureTransport::SetMaxConnectionAttempts(uint16_t aMaxAttempts, AutoCloseCallback aCallback, void *aContext)
175 {
176     Error error = kErrorNone;
177 
178     VerifyOrExit(IsStateClosed(), error = kErrorInvalidState);
179 
180     mMaxConnectionAttempts = aMaxAttempts;
181     mAutoCloseCallback.Set(aCallback, aContext);
182 
183 exit:
184     return error;
185 }
186 
Connect(const Ip6::SockAddr & aSockAddr)187 Error SecureTransport::Connect(const Ip6::SockAddr &aSockAddr)
188 {
189     Error error;
190 
191     VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
192 
193     if (mRemainingConnectionAttempts > 0)
194     {
195         mRemainingConnectionAttempts--;
196     }
197 
198     mMessageInfo.SetPeerAddr(aSockAddr.GetAddress());
199     mMessageInfo.SetPeerPort(aSockAddr.mPort);
200 
201     error = Setup(true);
202 
203 exit:
204     return error;
205 }
206 
HandleReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)207 void SecureTransport::HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
208 {
209     VerifyOrExit(!IsStateClosed());
210 
211     if (IsStateOpen())
212     {
213         if (mRemainingConnectionAttempts > 0)
214         {
215             mRemainingConnectionAttempts--;
216         }
217 
218         mMessageInfo.SetPeerAddr(aMessageInfo.GetPeerAddr());
219         mMessageInfo.SetPeerPort(aMessageInfo.GetPeerPort());
220         mMessageInfo.SetIsHostInterface(aMessageInfo.IsHostInterface());
221 
222         mMessageInfo.SetSockAddr(aMessageInfo.GetSockAddr());
223         mMessageInfo.SetSockPort(aMessageInfo.GetSockPort());
224 
225         SuccessOrExit(Setup(false));
226     }
227     else
228     {
229         // Once DTLS session is started, communicate only with a single peer.
230         VerifyOrExit((mMessageInfo.GetPeerAddr() == aMessageInfo.GetPeerAddr()) &&
231                      (mMessageInfo.GetPeerPort() == aMessageInfo.GetPeerPort()));
232     }
233 
234 #ifdef MBEDTLS_SSL_SRV_C
235     if (IsStateConnecting())
236     {
237         IgnoreError(SetClientId(mMessageInfo.GetPeerAddr().mFields.m8, sizeof(mMessageInfo.GetPeerAddr().mFields)));
238     }
239 #endif
240 
241     Receive(aMessage);
242 
243 exit:
244     return;
245 }
246 
GetUdpPort(void) const247 uint16_t SecureTransport::GetUdpPort(void) const { return mSocket.GetSockName().GetPort(); }
248 
Bind(uint16_t aPort)249 Error SecureTransport::Bind(uint16_t aPort)
250 {
251     Error error;
252 
253     VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
254     VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);
255 
256     SuccessOrExit(error = mSocket.Bind(aPort, Ip6::kNetifUnspecified));
257 
258 exit:
259     return error;
260 }
261 
Bind(TransportCallback aCallback,void * aContext)262 Error SecureTransport::Bind(TransportCallback aCallback, void *aContext)
263 {
264     Error error = kErrorNone;
265 
266     VerifyOrExit(IsStateOpen(), error = kErrorInvalidState);
267     VerifyOrExit(!mSocket.IsBound(), error = kErrorAlready);
268     VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready);
269 
270     mTransportCallback.Set(aCallback, aContext);
271 
272 exit:
273     return error;
274 }
275 
Setup(bool aClient)276 Error SecureTransport::Setup(bool aClient)
277 {
278     int rval;
279 
280     // do not handle new connection before guard time expired
281     VerifyOrExit(IsStateOpen(), rval = MBEDTLS_ERR_SSL_TIMEOUT);
282 
283     SetState(kStateInitializing);
284 
285     mbedtls_ssl_init(&mSsl);
286     mbedtls_ssl_config_init(&mConf);
287 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
288 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
289     mbedtls_x509_crt_init(&mCaChain);
290     mbedtls_x509_crt_init(&mOwnCert);
291     mbedtls_pk_init(&mPrivateKey);
292 #endif
293 #endif
294 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
295     if (mDatagramTransport)
296     {
297         mbedtls_ssl_cookie_init(&mCookieCtx);
298     }
299 #endif
300 
301     rval = mbedtls_ssl_config_defaults(
302         &mConf, aClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
303         mDatagramTransport ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
304     VerifyOrExit(rval == 0);
305 
306 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
307     if (mVerifyPeerCertificate && (mCipherSuites[0] == MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ||
308                                    mCipherSuites[0] == MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256))
309     {
310         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
311     }
312     else
313     {
314         mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_NONE);
315     }
316 #else
317     OT_UNUSED_VARIABLE(mVerifyPeerCertificate);
318 #endif
319 
320     mbedtls_ssl_conf_rng(&mConf, Crypto::MbedTls::CryptoSecurePrng, nullptr);
321 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
322     mbedtls_ssl_conf_min_tls_version(&mConf, MBEDTLS_SSL_VERSION_TLS1_2);
323     mbedtls_ssl_conf_max_tls_version(&mConf, MBEDTLS_SSL_VERSION_TLS1_2);
324 #else
325     mbedtls_ssl_conf_min_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
326     mbedtls_ssl_conf_max_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3);
327 #endif
328 
329     OT_ASSERT(mCipherSuites[1] == 0);
330     mbedtls_ssl_conf_ciphersuites(&mConf, mCipherSuites);
331     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
332     {
333 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
334         mbedtls_ssl_conf_groups(&mConf, sGroups);
335 #else
336         mbedtls_ssl_conf_curves(&mConf, sCurves);
337 #endif
338 #if defined(MBEDTLS_KEY_EXCHANGE__WITH_CERT__ENABLED) || defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
339 #if (MBEDTLS_VERSION_NUMBER >= 0x03020000)
340         mbedtls_ssl_conf_sig_algs(&mConf, sSignatures);
341 #else
342         mbedtls_ssl_conf_sig_hashes(&mConf, sHashes);
343 #endif
344 #endif
345     }
346 
347 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
348     mbedtls_ssl_set_export_keys_cb(&mSsl, HandleMbedtlsExportKeys, this);
349 #else
350     mbedtls_ssl_conf_export_keys_cb(&mConf, HandleMbedtlsExportKeys, this);
351 #endif
352 
353     mbedtls_ssl_conf_handshake_timeout(&mConf, 8000, 60000);
354     mbedtls_ssl_conf_dbg(&mConf, HandleMbedtlsDebug, this);
355 
356 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C)
357     if (!aClient && mDatagramTransport)
358     {
359         rval = mbedtls_ssl_cookie_setup(&mCookieCtx, Crypto::MbedTls::CryptoSecurePrng, nullptr);
360         VerifyOrExit(rval == 0);
361 
362         mbedtls_ssl_conf_dtls_cookies(&mConf, mbedtls_ssl_cookie_write, mbedtls_ssl_cookie_check, &mCookieCtx);
363     }
364 #endif
365 
366     rval = mbedtls_ssl_setup(&mSsl, &mConf);
367     VerifyOrExit(rval == 0);
368 
369     mbedtls_ssl_set_bio(&mSsl, this, &SecureTransport::HandleMbedtlsTransmit, HandleMbedtlsReceive, nullptr);
370 
371     if (mDatagramTransport)
372     {
373         mbedtls_ssl_set_timer_cb(&mSsl, this, &SecureTransport::HandleMbedtlsSetTimer, HandleMbedtlsGetTimer);
374     }
375 
376     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
377     {
378         rval = mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
379     }
380 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
381     else
382     {
383         rval = SetApplicationSecureKeys();
384     }
385 #endif
386     VerifyOrExit(rval == 0);
387 
388     mReceiveMessage = nullptr;
389     mMessageSubType = Message::kSubTypeNone;
390 
391     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
392     {
393         LogInfo("DTLS started");
394     }
395 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
396     else
397     {
398         LogInfo("Application Secure (D)TLS started");
399     }
400 #endif
401 
402     SetState(kStateConnecting);
403 
404     Process();
405 
406 exit:
407     if (IsStateInitializing() && (rval != 0))
408     {
409         if ((mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0))
410         {
411             Close();
412             mAutoCloseCallback.InvokeIfSet();
413         }
414         else
415         {
416             SetState(kStateOpen);
417             FreeMbedtls();
418         }
419     }
420 
421     return Crypto::MbedTls::MapError(rval);
422 }
423 
424 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
SetApplicationSecureKeys(void)425 int SecureTransport::SetApplicationSecureKeys(void)
426 {
427     int rval = 0;
428 
429     switch (mCipherSuites[0])
430     {
431     case MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8:
432     case MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
433 
434 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
435         if (mCaChainSrc != nullptr)
436         {
437             rval = mbedtls_x509_crt_parse(&mCaChain, static_cast<const unsigned char *>(mCaChainSrc),
438                                           static_cast<size_t>(mCaChainLength));
439             VerifyOrExit(rval == 0);
440             mbedtls_ssl_conf_ca_chain(&mConf, &mCaChain, nullptr);
441         }
442 
443         if (mOwnCertSrc != nullptr && mPrivateKeySrc != nullptr)
444         {
445             rval = mbedtls_x509_crt_parse(&mOwnCert, static_cast<const unsigned char *>(mOwnCertSrc),
446                                           static_cast<size_t>(mOwnCertLength));
447             VerifyOrExit(rval == 0);
448 
449 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
450             rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
451                                         static_cast<size_t>(mPrivateKeyLength), nullptr, 0,
452                                         Crypto::MbedTls::CryptoSecurePrng, nullptr);
453 #else
454             rval = mbedtls_pk_parse_key(&mPrivateKey, static_cast<const unsigned char *>(mPrivateKeySrc),
455                                         static_cast<size_t>(mPrivateKeyLength), nullptr, 0);
456 #endif
457             VerifyOrExit(rval == 0);
458             rval = mbedtls_ssl_conf_own_cert(&mConf, &mOwnCert, &mPrivateKey);
459             VerifyOrExit(rval == 0);
460         }
461 #endif
462         break;
463 
464     case MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8:
465 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
466         rval = mbedtls_ssl_conf_psk(&mConf, static_cast<const unsigned char *>(mPreSharedKey), mPreSharedKeyLength,
467                                     static_cast<const unsigned char *>(mPreSharedKeyIdentity), mPreSharedKeyIdLength);
468         VerifyOrExit(rval == 0);
469 #endif
470         break;
471 
472     default:
473         LogCrit("Application Coap Secure: Not supported cipher.");
474         rval = MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
475         ExitNow();
476         break;
477     }
478 
479 exit:
480     return rval;
481 }
482 
483 #endif // OPENTHREAD_CONFIG_TLS_API_ENABLE
484 
Close(void)485 void SecureTransport::Close(void)
486 {
487     Disconnect(kDisconnectedLocalClosed);
488 
489     SetState(kStateClosed);
490     mTimerSet = false;
491     mTransportCallback.Clear();
492 
493     IgnoreError(mSocket.Close());
494     mTimer.Stop();
495 }
496 
Disconnect(void)497 void SecureTransport::Disconnect(void) { Disconnect(kDisconnectedLocalClosed); }
498 
Disconnect(ConnectEvent aEvent)499 void SecureTransport::Disconnect(ConnectEvent aEvent)
500 {
501     VerifyOrExit(IsStateConnectingOrConnected());
502 
503     mbedtls_ssl_close_notify(&mSsl);
504     SetState(kStateCloseNotify);
505     mConnectEvent = aEvent;
506     mTimer.Start(kGuardTimeNewConnectionMilli);
507 
508     mMessageInfo.Clear();
509 
510     FreeMbedtls();
511 
512 exit:
513     return;
514 }
515 
SetPsk(const uint8_t * aPsk,uint8_t aPskLength)516 Error SecureTransport::SetPsk(const uint8_t *aPsk, uint8_t aPskLength)
517 {
518     Error error = kErrorNone;
519 
520     VerifyOrExit(aPskLength <= sizeof(mPsk), error = kErrorInvalidArgs);
521 
522     memcpy(mPsk, aPsk, aPskLength);
523     mPskLength       = aPskLength;
524     mCipherSuites[0] = MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8;
525     mCipherSuites[1] = 0;
526 
527 exit:
528     return error;
529 }
530 
531 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
532 #ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
533 
SetCertificate(const uint8_t * aX509Certificate,uint32_t aX509CertLength,const uint8_t * aPrivateKey,uint32_t aPrivateKeyLength)534 void SecureTransport::SetCertificate(const uint8_t *aX509Certificate,
535                                      uint32_t       aX509CertLength,
536                                      const uint8_t *aPrivateKey,
537                                      uint32_t       aPrivateKeyLength)
538 {
539     OT_ASSERT(aX509CertLength > 0);
540     OT_ASSERT(aX509Certificate != nullptr);
541 
542     OT_ASSERT(aPrivateKeyLength > 0);
543     OT_ASSERT(aPrivateKey != nullptr);
544 
545     mOwnCertSrc       = aX509Certificate;
546     mOwnCertLength    = aX509CertLength;
547     mPrivateKeySrc    = aPrivateKey;
548     mPrivateKeyLength = aPrivateKeyLength;
549 
550     if (mDatagramTransport)
551     {
552         mCipherSuites[0] = MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8;
553     }
554     else
555     {
556         mCipherSuites[0] = MBEDTLS_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256;
557     }
558 
559     mCipherSuites[1] = 0;
560 }
561 
SetCaCertificateChain(const uint8_t * aX509CaCertificateChain,uint32_t aX509CaCertChainLength)562 void SecureTransport::SetCaCertificateChain(const uint8_t *aX509CaCertificateChain, uint32_t aX509CaCertChainLength)
563 {
564     OT_ASSERT(aX509CaCertChainLength > 0);
565     OT_ASSERT(aX509CaCertificateChain != nullptr);
566 
567     mCaChainSrc    = aX509CaCertificateChain;
568     mCaChainLength = aX509CaCertChainLength;
569 }
570 
571 #endif // MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
572 
573 #ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
SetPreSharedKey(const uint8_t * aPsk,uint16_t aPskLength,const uint8_t * aPskIdentity,uint16_t aPskIdLength)574 void SecureTransport::SetPreSharedKey(const uint8_t *aPsk,
575                                       uint16_t       aPskLength,
576                                       const uint8_t *aPskIdentity,
577                                       uint16_t       aPskIdLength)
578 {
579     OT_ASSERT(aPsk != nullptr);
580     OT_ASSERT(aPskIdentity != nullptr);
581     OT_ASSERT(aPskLength > 0);
582     OT_ASSERT(aPskIdLength > 0);
583 
584     mPreSharedKey         = aPsk;
585     mPreSharedKeyLength   = aPskLength;
586     mPreSharedKeyIdentity = aPskIdentity;
587     mPreSharedKeyIdLength = aPskIdLength;
588 
589     mCipherSuites[0] = MBEDTLS_TLS_PSK_WITH_AES_128_CCM_8;
590     mCipherSuites[1] = 0;
591 }
592 #endif
593 
594 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength,size_t aCertBufferSize)595 Error SecureTransport::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength, size_t aCertBufferSize)
596 {
597     Error error = kErrorNone;
598 
599     VerifyOrExit(IsStateConnected(), error = kErrorInvalidState);
600 
601 #if (MBEDTLS_VERSION_NUMBER >= 0x03010000)
602     VerifyOrExit(mbedtls_base64_encode(aPeerCert, aCertBufferSize, aCertLength,
603                                        mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.p,
604                                        mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->raw.len) == 0,
605                  error = kErrorNoBufs);
606 #else
607     VerifyOrExit(
608         mbedtls_base64_encode(
609             aPeerCert, aCertBufferSize, aCertLength,
610             mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(p),
611             mSsl.MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(peer_cert)->MBEDTLS_PRIVATE(raw).MBEDTLS_PRIVATE(len)) == 0,
612         error = kErrorNoBufs);
613 #endif
614 
615 exit:
616     return error;
617 }
618 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
619 
620 #if defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerSubjectAttributeByOid(const char * aOid,size_t aOidLength,uint8_t * aAttributeBuffer,size_t * aAttributeLength,int * aAsn1Type)621 Error SecureTransport::GetPeerSubjectAttributeByOid(const char *aOid,
622                                                     size_t      aOidLength,
623                                                     uint8_t    *aAttributeBuffer,
624                                                     size_t     *aAttributeLength,
625                                                     int        *aAsn1Type)
626 {
627     Error                          error = kErrorNone;
628     const mbedtls_asn1_named_data *data;
629     size_t                         length;
630     size_t                         attributeBufferSize;
631     mbedtls_x509_crt              *peerCert = const_cast<mbedtls_x509_crt *>(mbedtls_ssl_get_peer_cert(&mSsl));
632 
633     VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs);
634     attributeBufferSize = *aAttributeLength;
635     *aAttributeLength   = 0;
636 
637     VerifyOrExit(aAttributeBuffer != nullptr, error = kErrorNoBufs);
638     VerifyOrExit(peerCert != nullptr, error = kErrorInvalidState);
639 
640     data = mbedtls_asn1_find_named_data(&peerCert->subject, aOid, aOidLength);
641     VerifyOrExit(data != nullptr, error = kErrorNotFound);
642 
643     length = data->val.len;
644     VerifyOrExit(length <= attributeBufferSize, error = kErrorNoBufs);
645     *aAttributeLength = length;
646 
647     if (aAsn1Type != nullptr)
648     {
649         *aAsn1Type = data->val.tag;
650     }
651 
652     memcpy(aAttributeBuffer, data->val.p, length);
653 
654 exit:
655     return error;
656 }
657 
GetThreadAttributeFromPeerCertificate(int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)658 Error SecureTransport::GetThreadAttributeFromPeerCertificate(int      aThreadOidDescriptor,
659                                                              uint8_t *aAttributeBuffer,
660                                                              size_t  *aAttributeLength)
661 {
662     const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&mSsl);
663 
664     return GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength);
665 }
666 
667 #endif // defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
668 
GetThreadAttributeFromOwnCertificate(int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)669 Error SecureTransport::GetThreadAttributeFromOwnCertificate(int      aThreadOidDescriptor,
670                                                             uint8_t *aAttributeBuffer,
671                                                             size_t  *aAttributeLength)
672 {
673     const mbedtls_x509_crt *cert = &mOwnCert;
674 
675     return GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength);
676 }
677 
GetThreadAttributeFromCertificate(const mbedtls_x509_crt * aCert,int aThreadOidDescriptor,uint8_t * aAttributeBuffer,size_t * aAttributeLength)678 Error SecureTransport::GetThreadAttributeFromCertificate(const mbedtls_x509_crt *aCert,
679                                                          int                     aThreadOidDescriptor,
680                                                          uint8_t                *aAttributeBuffer,
681                                                          size_t                 *aAttributeLength)
682 {
683     Error            error  = kErrorNotFound;
684     char             oid[9] = {0x2B, 0x06, 0x01, 0x04, 0x01, static_cast<char>(0x82), static_cast<char>(0xDF),
685                                0x2A, 0x00}; // 1.3.6.1.4.1.44970.0
686     mbedtls_x509_buf v3_ext;
687     unsigned char   *p, *end, *endExtData;
688     size_t           len;
689     size_t           attributeBufferSize;
690     mbedtls_x509_buf extnOid;
691     int              ret, isCritical;
692 
693     VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs);
694     attributeBufferSize = *aAttributeLength;
695     *aAttributeLength   = 0;
696 
697     VerifyOrExit(aCert != nullptr, error = kErrorInvalidState);
698     v3_ext = aCert->v3_ext;
699     p      = v3_ext.p;
700     VerifyOrExit(p != nullptr, error = kErrorInvalidState);
701     end = p + v3_ext.len;
702     VerifyOrExit(mbedtls_asn1_get_tag(&p, end, &len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE) == 0,
703                  error = kErrorParse);
704     VerifyOrExit(end == p + len, error = kErrorParse);
705 
706     VerifyOrExit(aThreadOidDescriptor < 128, error = kErrorNotImplemented);
707     oid[sizeof(oid) - 1] = static_cast<char>(aThreadOidDescriptor);
708 
709     while (p < end)
710     {
711         isCritical = 0;
712         VerifyOrExit(mbedtls_asn1_get_tag(&p, end, &len, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE) == 0,
713                      error = kErrorParse);
714         endExtData = p + len;
715 
716         // Get extension ID
717         VerifyOrExit(mbedtls_asn1_get_tag(&p, endExtData, &extnOid.len, MBEDTLS_ASN1_OID) == 0, error = kErrorParse);
718         extnOid.tag = MBEDTLS_ASN1_OID;
719         extnOid.p   = p;
720         p += extnOid.len;
721 
722         // Get optional critical
723         ret = mbedtls_asn1_get_bool(&p, endExtData, &isCritical);
724         VerifyOrExit(ret == 0 || ret == MBEDTLS_ERR_ASN1_UNEXPECTED_TAG, error = kErrorParse);
725 
726         // Data must be octet string type, see https://datatracker.ietf.org/doc/html/rfc5280#section-4.1
727         VerifyOrExit(mbedtls_asn1_get_tag(&p, endExtData, &len, MBEDTLS_ASN1_OCTET_STRING) == 0, error = kErrorParse);
728         VerifyOrExit(endExtData == p + len, error = kErrorParse);
729 
730         // TODO: extensions with isCritical == 1 that are unknown should lead to rejection of the entire cert.
731         if (extnOid.len == sizeof(oid) && memcmp(extnOid.p, oid, sizeof(oid)) == 0)
732         {
733             // per RFC 5280, octet string must contain ASN.1 Type Length Value octets
734             VerifyOrExit(len >= 2, error = kErrorParse);
735             VerifyOrExit(*(p + 1) == len - 2, error = kErrorParse); // check TLV Length, not Type.
736             *aAttributeLength = len - 2; // strip the ASN.1 Type Length bytes from embedded TLV
737 
738             if (aAttributeBuffer != nullptr)
739             {
740                 VerifyOrExit(*aAttributeLength <= attributeBufferSize, error = kErrorNoBufs);
741                 memcpy(aAttributeBuffer, p + 2, *aAttributeLength);
742             }
743 
744             error = kErrorNone;
745             break;
746         }
747         p += len;
748     }
749 
750 exit:
751     return error;
752 }
753 
754 #endif // OPENTHREAD_CONFIG_TLS_API_ENABLE
755 
756 #ifdef MBEDTLS_SSL_SRV_C
SetClientId(const uint8_t * aClientId,uint8_t aLength)757 Error SecureTransport::SetClientId(const uint8_t *aClientId, uint8_t aLength)
758 {
759     int rval = mbedtls_ssl_set_client_transport_id(&mSsl, aClientId, aLength);
760     return Crypto::MbedTls::MapError(rval);
761 }
762 #endif
763 
Send(Message & aMessage,uint16_t aLength)764 Error SecureTransport::Send(Message &aMessage, uint16_t aLength)
765 {
766     Error   error = kErrorNone;
767     uint8_t buffer[kApplicationDataMaxLength];
768 
769     VerifyOrExit(aLength <= kApplicationDataMaxLength, error = kErrorNoBufs);
770 
771     // Store message specific sub type.
772     if (aMessage.GetSubType() != Message::kSubTypeNone)
773     {
774         mMessageSubType = aMessage.GetSubType();
775     }
776 
777     aMessage.ReadBytes(0, buffer, aLength);
778 
779     SuccessOrExit(error = Crypto::MbedTls::MapError(mbedtls_ssl_write(&mSsl, buffer, aLength)));
780 
781     aMessage.Free();
782 
783 exit:
784     return error;
785 }
786 
Receive(Message & aMessage)787 void SecureTransport::Receive(Message &aMessage)
788 {
789     mReceiveMessage = &aMessage;
790 
791     Process();
792 
793     mReceiveMessage = nullptr;
794 }
795 
HandleMbedtlsTransmit(void * aContext,const unsigned char * aBuf,size_t aLength)796 int SecureTransport::HandleMbedtlsTransmit(void *aContext, const unsigned char *aBuf, size_t aLength)
797 {
798     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsTransmit(aBuf, aLength);
799 }
800 
HandleMbedtlsTransmit(const unsigned char * aBuf,size_t aLength)801 int SecureTransport::HandleMbedtlsTransmit(const unsigned char *aBuf, size_t aLength)
802 {
803     Error error;
804     int   rval = 0;
805 
806     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
807     {
808         LogDebg("HandleMbedtlsTransmit DTLS");
809     }
810 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
811     else
812     {
813         LogDebg("HandleMbedtlsTransmit TLS");
814     }
815 #endif
816 
817     error = HandleSecureTransportSend(aBuf, static_cast<uint16_t>(aLength), mMessageSubType);
818 
819     // Restore default sub type.
820     mMessageSubType = mMessageDefaultSubType;
821 
822     switch (error)
823     {
824     case kErrorNone:
825         rval = static_cast<int>(aLength);
826         break;
827 
828     case kErrorNoBufs:
829         rval = MBEDTLS_ERR_SSL_WANT_WRITE;
830         break;
831 
832     default:
833         LogWarn("HandleMbedtlsTransmit: %s error", ErrorToString(error));
834         rval = MBEDTLS_ERR_NET_SEND_FAILED;
835         break;
836     }
837 
838     return rval;
839 }
840 
HandleMbedtlsReceive(void * aContext,unsigned char * aBuf,size_t aLength)841 int SecureTransport::HandleMbedtlsReceive(void *aContext, unsigned char *aBuf, size_t aLength)
842 {
843     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsReceive(aBuf, aLength);
844 }
845 
HandleMbedtlsReceive(unsigned char * aBuf,size_t aLength)846 int SecureTransport::HandleMbedtlsReceive(unsigned char *aBuf, size_t aLength)
847 {
848     int rval;
849 
850     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
851     {
852         LogDebg("HandleMbedtlsReceive DTLS");
853     }
854 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
855     else
856     {
857         LogDebg("HandleMbedtlsReceive TLS");
858     }
859 #endif
860 
861     VerifyOrExit(mReceiveMessage != nullptr && (rval = mReceiveMessage->GetLength() - mReceiveMessage->GetOffset()) > 0,
862                  rval = MBEDTLS_ERR_SSL_WANT_READ);
863 
864     if (aLength > static_cast<size_t>(rval))
865     {
866         aLength = static_cast<size_t>(rval);
867     }
868 
869     rval = mReceiveMessage->ReadBytes(mReceiveMessage->GetOffset(), aBuf, static_cast<uint16_t>(aLength));
870     mReceiveMessage->MoveOffset(rval);
871 
872 exit:
873     return rval;
874 }
875 
HandleMbedtlsGetTimer(void * aContext)876 int SecureTransport::HandleMbedtlsGetTimer(void *aContext)
877 {
878     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsGetTimer();
879 }
880 
HandleMbedtlsGetTimer(void)881 int SecureTransport::HandleMbedtlsGetTimer(void)
882 {
883     int rval;
884 
885     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
886     {
887         LogDebg("HandleMbedtlsGetTimer");
888     }
889 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
890     else
891     {
892         LogDebg("HandleMbedtlsGetTimer");
893     }
894 #endif
895 
896     if (!mTimerSet)
897     {
898         rval = -1;
899     }
900     else if (!mTimer.IsRunning())
901     {
902         rval = 2;
903     }
904     else if (mTimerIntermediate <= TimerMilli::GetNow())
905     {
906         rval = 1;
907     }
908     else
909     {
910         rval = 0;
911     }
912 
913     return rval;
914 }
915 
HandleMbedtlsSetTimer(void * aContext,uint32_t aIntermediate,uint32_t aFinish)916 void SecureTransport::HandleMbedtlsSetTimer(void *aContext, uint32_t aIntermediate, uint32_t aFinish)
917 {
918     static_cast<SecureTransport *>(aContext)->HandleMbedtlsSetTimer(aIntermediate, aFinish);
919 }
920 
HandleMbedtlsSetTimer(uint32_t aIntermediate,uint32_t aFinish)921 void SecureTransport::HandleMbedtlsSetTimer(uint32_t aIntermediate, uint32_t aFinish)
922 {
923     if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
924     {
925         LogDebg("SetTimer DTLS");
926     }
927 #if OPENTHREAD_CONFIG_TLS_API_ENABLE
928     else
929     {
930         LogDebg("SetTimer TLS");
931     }
932 #endif
933 
934     if (aFinish == 0)
935     {
936         mTimerSet = false;
937         mTimer.Stop();
938     }
939     else
940     {
941         mTimerSet = true;
942         mTimer.Start(aFinish);
943         mTimerIntermediate = TimerMilli::GetNow() + aIntermediate;
944     }
945 }
946 
947 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
948 
HandleMbedtlsExportKeys(void * aContext,mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)949 void SecureTransport::HandleMbedtlsExportKeys(void                       *aContext,
950                                               mbedtls_ssl_key_export_type aType,
951                                               const unsigned char        *aMasterSecret,
952                                               size_t                      aMasterSecretLen,
953                                               const unsigned char         aClientRandom[32],
954                                               const unsigned char         aServerRandom[32],
955                                               mbedtls_tls_prf_types       aTlsPrfType)
956 {
957     static_cast<SecureTransport *>(aContext)->HandleMbedtlsExportKeys(aType, aMasterSecret, aMasterSecretLen,
958                                                                       aClientRandom, aServerRandom, aTlsPrfType);
959 }
960 
HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,const unsigned char * aMasterSecret,size_t aMasterSecretLen,const unsigned char aClientRandom[32],const unsigned char aServerRandom[32],mbedtls_tls_prf_types aTlsPrfType)961 void SecureTransport::HandleMbedtlsExportKeys(mbedtls_ssl_key_export_type aType,
962                                               const unsigned char        *aMasterSecret,
963                                               size_t                      aMasterSecretLen,
964                                               const unsigned char         aClientRandom[32],
965                                               const unsigned char         aServerRandom[32],
966                                               mbedtls_tls_prf_types       aTlsPrfType)
967 {
968     Crypto::Sha256::Hash kek;
969     Crypto::Sha256       sha256;
970     unsigned char        keyBlock[kSecureTransportKeyBlockSize];
971     unsigned char        randBytes[2 * kSecureTransportRandomBufferSize];
972 
973     VerifyOrExit(mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8);
974     VerifyOrExit(aType == MBEDTLS_SSL_KEY_EXPORT_TLS12_MASTER_SECRET);
975 
976     memcpy(randBytes, aServerRandom, kSecureTransportRandomBufferSize);
977     memcpy(randBytes + kSecureTransportRandomBufferSize, aClientRandom, kSecureTransportRandomBufferSize);
978 
979     // Retrieve the Key block from Master secret
980     mbedtls_ssl_tls_prf(aTlsPrfType, aMasterSecret, aMasterSecretLen, "key expansion", randBytes, sizeof(randBytes),
981                         keyBlock, sizeof(keyBlock));
982 
983     sha256.Start();
984     sha256.Update(keyBlock, kSecureTransportKeyBlockSize);
985     sha256.Finish(kek);
986 
987     LogDebg("Generated KEK");
988     Get<KeyManager>().SetKek(kek.GetBytes());
989 
990 exit:
991     return;
992 }
993 
994 #else
995 
HandleMbedtlsExportKeys(void * aContext,const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)996 int SecureTransport::HandleMbedtlsExportKeys(void                *aContext,
997                                              const unsigned char *aMasterSecret,
998                                              const unsigned char *aKeyBlock,
999                                              size_t               aMacLength,
1000                                              size_t               aKeyLength,
1001                                              size_t               aIvLength)
1002 {
1003     return static_cast<SecureTransport *>(aContext)->HandleMbedtlsExportKeys(aMasterSecret, aKeyBlock, aMacLength,
1004                                                                              aKeyLength, aIvLength);
1005 }
1006 
HandleMbedtlsExportKeys(const unsigned char * aMasterSecret,const unsigned char * aKeyBlock,size_t aMacLength,size_t aKeyLength,size_t aIvLength)1007 int SecureTransport::HandleMbedtlsExportKeys(const unsigned char *aMasterSecret,
1008                                              const unsigned char *aKeyBlock,
1009                                              size_t               aMacLength,
1010                                              size_t               aKeyLength,
1011                                              size_t               aIvLength)
1012 {
1013     OT_UNUSED_VARIABLE(aMasterSecret);
1014 
1015     Crypto::Sha256::Hash kek;
1016     Crypto::Sha256       sha256;
1017 
1018     VerifyOrExit(mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8);
1019 
1020     sha256.Start();
1021     sha256.Update(aKeyBlock, 2 * static_cast<uint16_t>(aMacLength + aKeyLength + aIvLength));
1022     sha256.Finish(kek);
1023 
1024     LogDebg("Generated KEK");
1025     Get<KeyManager>().SetKek(kek.GetBytes());
1026 
1027 exit:
1028     return 0;
1029 }
1030 
1031 #endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000)
1032 
HandleTimer(Timer & aTimer)1033 void SecureTransport::HandleTimer(Timer &aTimer)
1034 {
1035     static_cast<SecureTransport *>(static_cast<TimerMilliContext &>(aTimer).GetContext())->HandleTimer();
1036 }
1037 
HandleTimer(void)1038 void SecureTransport::HandleTimer(void)
1039 {
1040     if (IsStateConnectingOrConnected())
1041     {
1042         Process();
1043     }
1044     else if (IsStateCloseNotify())
1045     {
1046         if ((mMaxConnectionAttempts > 0) && (mRemainingConnectionAttempts == 0))
1047         {
1048             Close();
1049             mConnectEvent = kDisconnectedMaxAttempts;
1050             mAutoCloseCallback.InvokeIfSet();
1051         }
1052         else
1053         {
1054             SetState(kStateOpen);
1055             mTimer.Stop();
1056         }
1057         mConnectedCallback.InvokeIfSet(mConnectEvent);
1058     }
1059 }
1060 
/**
 * Drives the mbedTLS session until it cannot make further progress right now.
 *
 * While connecting, each loop iteration advances the handshake. Once connected, each iteration
 * reads one chunk of decrypted application data and passes it to the receive callback. mbedTLS
 * error codes decide whether the session is disconnected (via `Disconnect()`) or reset so that a
 * new handshake can start.
 */
void SecureTransport::Process(void)
{
    uint8_t      buf[OPENTHREAD_CONFIG_DTLS_MAX_CONTENT_LEN];
    bool         shouldDisconnect = false;
    int          rval;
    ConnectEvent event; // Only read when `shouldDisconnect` is true; every such path assigns it first.

    while (IsStateConnectingOrConnected())
    {
        if (IsStateConnecting())
        {
            rval = mbedtls_ssl_handshake(&mSsl);

            // Switch to connected as soon as mbedTLS reports the handshake is over.
            if (mSsl.MBEDTLS_PRIVATE(state) == MBEDTLS_SSL_HANDSHAKE_OVER)
            {
                SetState(kStateConnected);
                mConnectEvent = kConnected;
                mConnectedCallback.InvokeIfSet(mConnectEvent);
            }
        }
        else
        {
            rval = mbedtls_ssl_read(&mSsl, buf, sizeof(buf));
        }

        if (rval > 0)
        {
            // Application data was produced: deliver it and keep looping.
            mReceiveCallback.InvokeIfSet(buf, static_cast<uint16_t>(rval));
        }
        else if (rval == 0 || rval == MBEDTLS_ERR_SSL_WANT_READ || rval == MBEDTLS_ERR_SSL_WANT_WRITE)
        {
            // Nothing more can be done right now; stop looping.
            break;
        }
        else
        {
            switch (rval)
            {
            case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
                // Peer closed the session: answer with our own close_notify and disconnect.
                mbedtls_ssl_close_notify(&mSsl);
                event = kDisconnectedPeerClosed;
                ExitNow(shouldDisconnect = true);
                OT_UNREACHABLE_CODE(break);

            case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
                // DTLS cookie exchange: mbedTLS requires a session reset (done below) before the
                // client retries with the cookie.
                break;

            case MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
                // Peer sent a fatal alert: send close_notify and disconnect with an error.
                mbedtls_ssl_close_notify(&mSsl);
                event = kDisconnectedError;
                ExitNow(shouldDisconnect = true);
                OT_UNREACHABLE_CODE(break);

            case MBEDTLS_ERR_SSL_INVALID_MAC:
                // During the handshake a bad MAC is fatal: send bad_record_mac and disconnect.
                // After the handshake, fall through to the session reset below.
                if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
                {
                    mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
                                                   MBEDTLS_SSL_ALERT_MSG_BAD_RECORD_MAC);
                    event = kDisconnectedError;
                    ExitNow(shouldDisconnect = true);
                }

                break;

            default:
                // Any other error during the handshake: send handshake_failure and disconnect.
                // After the handshake, fall through to the session reset below.
                if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
                {
                    mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
                                                   MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
                    event = kDisconnectedError;
                    ExitNow(shouldDisconnect = true);
                }

                break;
            }

            // Errors that did not disconnect: reset the mbedTLS session and, for EC-JPAKE (DTLS),
            // set the PSK as the handshake password again for the next handshake.
            // NOTE(review): the OpenThread state (e.g. kStateConnected) is not changed here; confirm intended.
            mbedtls_ssl_session_reset(&mSsl);
            if (mCipherSuites[0] == MBEDTLS_TLS_ECJPAKE_WITH_AES_128_CCM_8)
            {
                mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
            }
            break;
        }
    }

exit:

    if (shouldDisconnect)
    {
        Disconnect(event);
    }
}
1152 
HandleMbedtlsDebug(void * aContext,int aLevel,const char * aFile,int aLine,const char * aStr)1153 void SecureTransport::HandleMbedtlsDebug(void *aContext, int aLevel, const char *aFile, int aLine, const char *aStr)
1154 {
1155     static_cast<SecureTransport *>(aContext)->HandleMbedtlsDebug(aLevel, aFile, aLine, aStr);
1156 }
1157 
HandleMbedtlsDebug(int aLevel,const char * aFile,int aLine,const char * aStr)1158 void SecureTransport::HandleMbedtlsDebug(int aLevel, const char *aFile, int aLine, const char *aStr)
1159 {
1160     OT_UNUSED_VARIABLE(aStr);
1161     OT_UNUSED_VARIABLE(aFile);
1162     OT_UNUSED_VARIABLE(aLine);
1163 
1164     switch (aLevel)
1165     {
1166     case 1:
1167         LogCrit("[%u] %s", mSocket.GetSockName().mPort, aStr);
1168         break;
1169 
1170     case 2:
1171         LogWarn("[%u] %s", mSocket.GetSockName().mPort, aStr);
1172         break;
1173 
1174     case 3:
1175         LogInfo("[%u] %s", mSocket.GetSockName().mPort, aStr);
1176         break;
1177 
1178     case 4:
1179     default:
1180         LogDebg("[%u] %s", mSocket.GetSockName().mPort, aStr);
1181         break;
1182     }
1183 }
1184 
HandleSecureTransportSend(const uint8_t * aBuf,uint16_t aLength,Message::SubType aMessageSubType)1185 Error SecureTransport::HandleSecureTransportSend(const uint8_t   *aBuf,
1186                                                  uint16_t         aLength,
1187                                                  Message::SubType aMessageSubType)
1188 {
1189     Error        error   = kErrorNone;
1190     ot::Message *message = nullptr;
1191 
1192     VerifyOrExit((message = mSocket.NewMessage()) != nullptr, error = kErrorNoBufs);
1193     message->SetSubType(aMessageSubType);
1194     message->SetLinkSecurityEnabled(mLayerTwoSecurity);
1195 
1196     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
1197 
1198     // Set message sub type in case Joiner Finalize Response is appended to the message.
1199     if (aMessageSubType != Message::kSubTypeNone)
1200     {
1201         message->SetSubType(aMessageSubType);
1202     }
1203 
1204     if (mTransportCallback.IsSet())
1205     {
1206         SuccessOrExit(error = mTransportCallback.Invoke(*message, mMessageInfo));
1207     }
1208     else
1209     {
1210         SuccessOrExit(error = mSocket.SendTo(*message, mMessageInfo));
1211     }
1212 
1213 exit:
1214     FreeMessageOnError(message, error);
1215     return error;
1216 }
1217 
1218 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_INFO)
1219 
StateToString(State aState)1220 const char *SecureTransport::StateToString(State aState)
1221 {
1222     static const char *const kStateStrings[] = {
1223         "Closed",       // (0) kStateClosed
1224         "Open",         // (1) kStateOpen
1225         "Initializing", // (2) kStateInitializing
1226         "Connecting",   // (3) kStateConnecting
1227         "Connected",    // (4) kStateConnected
1228         "CloseNotify",  // (5) kStateCloseNotify
1229     };
1230 
1231     static_assert(0 == kStateClosed, "kStateClosed valid is incorrect");
1232     static_assert(1 == kStateOpen, "kStateOpen valid is incorrect");
1233     static_assert(2 == kStateInitializing, "kStateInitializing valid is incorrect");
1234     static_assert(3 == kStateConnecting, "kStateConnecting valid is incorrect");
1235     static_assert(4 == kStateConnected, "kStateConnected valid is incorrect");
1236     static_assert(5 == kStateCloseNotify, "kStateCloseNotify valid is incorrect");
1237 
1238     return kStateStrings[aState];
1239 }
1240 
1241 #endif
1242 
1243 } // namespace MeshCoP
1244 } // namespace ot
1245 
1246 #endif // OPENTHREAD_CONFIG_SECURE_TRANSPORT_ENABLE
1247