1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "ble_secure.hpp"
30 
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32 
33 #include <openthread/platform/ble.h>
34 #include "common/locator_getters.hpp"
35 #include "common/log.hpp"
36 #include "common/tlvs.hpp"
37 #include "instance/instance.hpp"
38 #include "meshcop/secure_transport.hpp"
39 
40 using namespace ot;
41 
42 /**
43  * @file
44  *   This file implements the secure Ble agent.
45  */
46 
47 namespace ot {
48 namespace Ble {
49 
50 RegisterLogModule("BleSecure");
51 
BleSecure(Instance & aInstance)52 BleSecure::BleSecure(Instance &aInstance)
53     : InstanceLocator(aInstance)
54     , mTls(aInstance, false, false)
55     , mTcatAgent(aInstance)
56     , mTlvMode(false)
57     , mReceivedMessage(nullptr)
58     , mSendMessage(nullptr)
59     , mTransmitTask(aInstance)
60     , mBleState(kStopped)
61     , mMtuSize(kInitialMtuSize)
62 {
63 }
64 
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)65 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
66 {
67     Error    error             = kErrorNone;
68     uint16_t advertisementLen  = 0;
69     uint8_t *advertisementData = nullptr;
70 
71     VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
72 
73     mConnectCallback.Set(aConnectHandler, aContext);
74     mReceiveCallback.Set(aReceiveHandler, aContext);
75     mTlvMode = aTlvMode;
76     mMtuSize = kInitialMtuSize;
77 
78     SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
79 
80     SuccessOrExit(error = otPlatBleGetAdvertisementBuffer(&GetInstance(), &advertisementData));
81     SuccessOrExit(error = mTcatAgent.GetAdvertisementData(advertisementLen, advertisementData));
82     VerifyOrExit(advertisementData != nullptr, error = kErrorFailed);
83     SuccessOrExit(error = otPlatBleGapAdvSetData(&GetInstance(), advertisementData, advertisementLen));
84     SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
85 
86     SuccessOrExit(error = mTls.Open(&BleSecure::HandleTlsReceive, &BleSecure::HandleTlsConnected, this));
87     SuccessOrExit(error = mTls.Bind(HandleTransport, this));
88 
89 exit:
90     if (error == kErrorNone)
91     {
92         mBleState = kAdvertising;
93     }
94     return error;
95 }
96 
TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)97 Error BleSecure::TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)
98 {
99     Error error;
100 
101     VerifyOrExit(mBleState != kStopped, error = kErrorInvalidState);
102 
103     error = mTcatAgent.Start(mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
104 
105 exit:
106     return error;
107 }
108 
Stop(void)109 void BleSecure::Stop(void)
110 {
111     VerifyOrExit(mBleState != kStopped);
112     SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
113     SuccessOrExit(otPlatBleDisable(&GetInstance()));
114     mBleState = kStopped;
115     mMtuSize  = kInitialMtuSize;
116 
117     if (mTcatAgent.IsEnabled())
118     {
119         mTcatAgent.Stop();
120     }
121 
122     mTls.Close();
123 
124     mTransmitQueue.DequeueAndFreeAll();
125 
126     mConnectCallback.Clear();
127     mReceiveCallback.Clear();
128 
129     FreeMessage(mReceivedMessage);
130     mReceivedMessage = nullptr;
131     FreeMessage(mSendMessage);
132     mSendMessage = nullptr;
133 
134 exit:
135     return;
136 }
137 
Connect(void)138 Error BleSecure::Connect(void)
139 {
140     Ip6::SockAddr sockaddr;
141     Error         error;
142 
143     VerifyOrExit(mBleState == kConnected, error = kErrorInvalidState);
144 
145     error = mTls.Connect(sockaddr);
146 
147 exit:
148     return error;
149 }
150 
Disconnect(void)151 void BleSecure::Disconnect(void)
152 {
153     if (mTls.IsConnected())
154     {
155         mTls.Disconnect();
156     }
157 
158     if (mBleState == kConnected)
159     {
160         mBleState = kAdvertising;
161         IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
162     }
163 
164     mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
165 }
166 
SetPsk(const MeshCoP::JoinerPskd & aPskd)167 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
168 {
169     static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
170                       static_cast<uint16_t>(MeshCoP::SecureTransport::kPskMaxLength),
171                   "The maximum length of TLS PSK is smaller than joiner PSKd");
172 
173     SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
174 }
175 
#if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
/**
 * Copies the peer's certificate, Base64 encoded, into @p aPeerCert.
 *
 * The value at @p aCertLength is passed to the TLS layer as the buffer size, and the same pointer
 * is passed for the output length. This is presumably an in/out parameter; confirm in SecureTransport.
 *
 * @retval kErrorInvalidArgs  @p aCertLength is `nullptr`.
 * @returns Otherwise, the result of the TLS layer's `GetPeerCertificateBase64()`.
 */
Error BleSecure::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength)
{
    Error error = kErrorInvalidArgs;

    if (aCertLength != nullptr)
    {
        error = mTls.GetPeerCertificateBase64(aPeerCert, aCertLength, *aCertLength);
    }

    return error;
}
#endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
189 
SendMessage(ot::Message & aMessage)190 Error BleSecure::SendMessage(ot::Message &aMessage)
191 {
192     Error error = kErrorNone;
193 
194     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
195     if (mSendMessage == nullptr)
196     {
197         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
198         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
199     }
200     SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
201     SuccessOrExit(error = Flush());
202 
203 exit:
204     aMessage.Free();
205     return error;
206 }
207 
Send(uint8_t * aBuf,uint16_t aLength)208 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
209 {
210     Error error = kErrorNone;
211 
212     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
213     if (mSendMessage == nullptr)
214     {
215         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
216         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
217     }
218     SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
219 
220 exit:
221     return error;
222 }
223 
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)224 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
225 {
226     Error error = kErrorNone;
227     if (aLength > Tlv::kBaseTlvMaxLength)
228     {
229         ot::ExtendedTlv tlv;
230 
231         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
232         tlv.SetLength(aLength);
233         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
234     }
235     else
236     {
237         ot::Tlv tlv;
238 
239         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
240         tlv.SetLength((uint8_t)aLength);
241         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
242     }
243 
244     error = Send(aBuf, aLength);
245 exit:
246     return error;
247 }
248 
Flush(void)249 Error BleSecure::Flush(void)
250 {
251     Error error = kErrorNone;
252 
253     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
254     VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
255 
256     mTransmitQueue.Enqueue(*mSendMessage);
257     mTransmitTask.Post();
258 
259     mSendMessage = nullptr;
260 
261 exit:
262     return error;
263 }
264 
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)265 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
266 {
267     ot::Message     *message = nullptr;
268     Ip6::MessageInfo messageInfo;
269     Error            error = kErrorNone;
270 
271     if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
272     {
273         error = kErrorNoBufs;
274         ExitNow();
275     }
276     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
277 
278     // Cannot call Receive(..) directly because Setup(..) and mState are private
279     mTls.HandleReceive(*message, messageInfo);
280 
281 exit:
282     FreeMessage(message);
283     return error;
284 }
285 
HandleBleConnected(uint16_t aConnectionId)286 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
287 {
288     OT_UNUSED_VARIABLE(aConnectionId);
289 
290     mBleState = kConnected;
291 
292     IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
293 
294     mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
295 }
296 
HandleBleDisconnected(uint16_t aConnectionId)297 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
298 {
299     OT_UNUSED_VARIABLE(aConnectionId);
300 
301     mBleState = kAdvertising;
302     mMtuSize  = kInitialMtuSize;
303 
304     Disconnect(); // Stop TLS connection
305 }
306 
HandleBleMtuUpdate(uint16_t aMtu)307 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
308 {
309     Error error = kErrorNone;
310 
311     if (aMtu <= OT_BLE_ATT_MTU_MAX)
312     {
313         mMtuSize = aMtu;
314     }
315     else
316     {
317         mMtuSize = OT_BLE_ATT_MTU_MAX;
318         error    = kErrorInvalidArgs;
319     }
320 
321     return error;
322 }
323 
HandleTlsConnected(void * aContext,bool aConnected)324 void BleSecure::HandleTlsConnected(void *aContext, bool aConnected)
325 {
326     return static_cast<BleSecure *>(aContext)->HandleTlsConnected(aConnected);
327 }
328 
HandleTlsConnected(bool aConnected)329 void BleSecure::HandleTlsConnected(bool aConnected)
330 {
331     if (aConnected)
332     {
333         if (mReceivedMessage == nullptr)
334         {
335             mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
336         }
337 
338         if (mTcatAgent.IsEnabled())
339         {
340             Error err = mTcatAgent.Connected(mTls);
341 
342             if (err != kErrorNone)
343             {
344                 mTls.Close();
345                 LogWarn("Rejected TCAT Commissioner, error: %s", ErrorToString(err));
346                 ExitNow();
347             }
348         }
349     }
350     else
351     {
352         FreeMessage(mReceivedMessage);
353         mReceivedMessage = nullptr;
354 
355         if (mTcatAgent.IsEnabled())
356         {
357             mTcatAgent.Disconnected();
358         }
359     }
360 
361     mConnectCallback.InvokeIfSet(&GetInstance(), aConnected, true);
362 
363 exit:
364     return;
365 }
366 
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)367 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
368 {
369     return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
370 }
371 
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)372 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
373 {
374     VerifyOrExit(mReceivedMessage != nullptr);
375 
376     if (!mTlvMode)
377     {
378         SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
379         mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
380         IgnoreReturnValue(mReceivedMessage->SetLength(0));
381     }
382     else
383     {
384         ot::Tlv  tlv;
385         uint32_t requiredBytes = sizeof(Tlv);
386         uint32_t offset;
387 
388         while (aLength > 0)
389         {
390             if (mReceivedMessage->GetLength() < requiredBytes)
391             {
392                 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
393 
394                 if (missingBytes > aLength)
395                 {
396                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
397                     break;
398                 }
399                 else
400                 {
401                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
402                     aLength -= missingBytes;
403                     aBuf += missingBytes;
404                 }
405             }
406 
407             IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
408 
409             if (tlv.IsExtended())
410             {
411                 ot::ExtendedTlv extTlv;
412                 requiredBytes = sizeof(extTlv);
413 
414                 if (mReceivedMessage->GetLength() < requiredBytes)
415                 {
416                     continue;
417                 }
418 
419                 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
420                 requiredBytes = extTlv.GetSize();
421                 offset        = sizeof(extTlv);
422             }
423             else
424             {
425                 requiredBytes = tlv.GetSize();
426                 offset        = sizeof(tlv);
427             }
428 
429             if (mReceivedMessage->GetLength() < requiredBytes)
430             {
431                 continue;
432             }
433 
434             // TLV fully loaded
435 
436             if (mTcatAgent.IsEnabled())
437             {
438                 ot::Message *message;
439                 Error        error = kErrorNone;
440 
441                 message = Get<MessagePool>().Allocate(Message::kTypeBle);
442                 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
443 
444                 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
445                 if (message->GetLength() != 0)
446                 {
447                     IgnoreReturnValue(SendMessage(*message));
448                 }
449 
450                 if (error == kErrorAbort)
451                 {
452                     Disconnect();
453                     Stop();
454                     ExitNow();
455                 }
456             }
457             else
458             {
459                 mReceivedMessage->SetOffset((uint16_t)offset);
460                 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
461                                              OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
462             }
463 
464             SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
465             requiredBytes = sizeof(Tlv);
466         }
467     }
468 
469 exit:
470     return;
471 }
472 
HandleTransmit(void)473 void BleSecure::HandleTransmit(void)
474 {
475     Error        error   = kErrorNone;
476     ot::Message *message = mTransmitQueue.GetHead();
477 
478     VerifyOrExit(message != nullptr);
479     mTransmitQueue.Dequeue(*message);
480 
481     if (mTransmitQueue.GetHead() != nullptr)
482     {
483         mTransmitTask.Post();
484     }
485 
486     SuccessOrExit(error = mTls.Send(*message, message->GetLength()));
487 
488 exit:
489     if (error != kErrorNone)
490     {
491         LogNote("Transmit: %s", ErrorToString(error));
492         message->Free();
493     }
494     else
495     {
496         LogDebg("Transmit: %s", ErrorToString(error));
497     }
498 }
499 
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)500 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
501 {
502     OT_UNUSED_VARIABLE(aMessageInfo);
503     return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
504 }
505 
HandleTransport(ot::Message & aMessage)506 Error BleSecure::HandleTransport(ot::Message &aMessage)
507 {
508     otBleRadioPacket packet;
509     uint16_t         len    = aMessage.GetLength();
510     uint16_t         offset = 0;
511     Error            error  = kErrorNone;
512 
513     while (len > 0)
514     {
515         if (len <= mMtuSize - kGattOverhead)
516         {
517             packet.mLength = len;
518         }
519         else
520         {
521             packet.mLength = mMtuSize - kGattOverhead;
522         }
523 
524         if (packet.mLength > kPacketBufferSize)
525         {
526             packet.mLength = kPacketBufferSize;
527         }
528 
529         IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
530         packet.mValue = mPacketBuffer;
531         packet.mPower = OT_BLE_DEFAULT_POWER;
532 
533         SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
534 
535         len -= packet.mLength;
536         offset += packet.mLength;
537     }
538 
539     aMessage.Free();
540 exit:
541     return error;
542 }
543 
544 } // namespace Ble
545 } // namespace ot
546 
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)547 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
548 {
549     OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
550 
551     VerifyOrExit(aPacket != nullptr);
552     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
553 exit:
554     return;
555 }
556 
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)557 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
558 {
559     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
560 }
561 
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)562 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
563 {
564     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
565 }
566 
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)567 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
568 {
569     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
570 }
571 
572 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
573