1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "ble_secure.hpp"
30 
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32 
33 #include <openthread/platform/ble.h>
34 #include "common/locator_getters.hpp"
35 #include "common/log.hpp"
36 #include "common/tlvs.hpp"
37 #include "instance/instance.hpp"
38 #include "meshcop/secure_transport.hpp"
39 
40 using namespace ot;
41 
42 /**
43  * @file
 *   This file implements the BLE secure agent, which runs a TLS session over a BLE GATT link (used by TCAT).
45  */
46 
47 namespace ot {
48 namespace Ble {
49 
50 RegisterLogModule("BleSecure");
51 
// Constructs the BLE secure agent in the stopped state (`mBleState = kStopped`) with no receive or
// send buffers allocated. This constructor itself makes no otPlatBle* calls; the platform BLE
// stack is only enabled by Start().
BleSecure::BleSecure(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mTls(aInstance, false, false) // NOTE(review): meaning of the two `false` flags is defined by SecureTransport's constructor — confirm there.
    , mTcatAgent(aInstance)
    , mTlvMode(false)
    , mReceivedMessage(nullptr) // Allocated lazily when a TLS session comes up.
    , mSendMessage(nullptr)     // Allocated lazily by Send()/SendMessage().
    , mTransmitTask(aInstance)
    , mBleState(kStopped)
    , mMtuSize(kInitialMtuSize)
{
}
64 
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)65 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
66 {
67     Error error = kErrorNone;
68 
69     VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
70 
71     mConnectCallback.Set(aConnectHandler, aContext);
72     mReceiveCallback.Set(aReceiveHandler, aContext);
73     mTlvMode = aTlvMode;
74     mMtuSize = kInitialMtuSize;
75 
76     SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
77     SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
78     SuccessOrExit(error = mTls.Open(&BleSecure::HandleTlsReceive, &BleSecure::HandleTlsConnected, this));
79     SuccessOrExit(error = mTls.Bind(HandleTransport, this));
80 
81 exit:
82     if (error == kErrorNone)
83     {
84         mBleState = kAdvertising;
85     }
86     return error;
87 }
88 
TcatStart(const MeshCoP::TcatAgent::VendorInfo & aVendorInfo,MeshCoP::TcatAgent::JoinCallback aJoinHandler)89 Error BleSecure::TcatStart(const MeshCoP::TcatAgent::VendorInfo &aVendorInfo,
90                            MeshCoP::TcatAgent::JoinCallback      aJoinHandler)
91 {
92     return mTcatAgent.Start(aVendorInfo, mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
93 }
94 
Stop(void)95 void BleSecure::Stop(void)
96 {
97     VerifyOrExit(mBleState != kStopped);
98     SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
99     SuccessOrExit(otPlatBleDisable(&GetInstance()));
100     mBleState = kStopped;
101     mMtuSize  = kInitialMtuSize;
102 
103     if (mTcatAgent.IsEnabled())
104     {
105         mTcatAgent.Stop();
106     }
107 
108     mTls.Close();
109 
110     mTransmitQueue.DequeueAndFreeAll();
111 
112     mConnectCallback.Clear();
113     mReceiveCallback.Clear();
114 
115     FreeMessage(mReceivedMessage);
116     mReceivedMessage = nullptr;
117     FreeMessage(mSendMessage);
118     mSendMessage = nullptr;
119 
120 exit:
121     return;
122 }
123 
Connect(void)124 Error BleSecure::Connect(void)
125 {
126     Ip6::SockAddr sockaddr;
127 
128     return mTls.Connect(sockaddr);
129 }
130 
Disconnect(void)131 void BleSecure::Disconnect(void)
132 {
133     if (mTls.IsConnected())
134     {
135         mTls.Disconnect();
136     }
137 
138     if (mBleState == kConnected)
139     {
140         IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
141     }
142 }
143 
SetPsk(const MeshCoP::JoinerPskd & aPskd)144 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
145 {
146     static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
147                       static_cast<uint16_t>(MeshCoP::SecureTransport::kPskMaxLength),
148                   "The maximum length of TLS PSK is smaller than joiner PSKd");
149 
150     SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
151 }
152 
153 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength)154 Error BleSecure::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength)
155 {
156     Error error;
157 
158     VerifyOrExit(aCertLength != nullptr, error = kErrorInvalidArgs);
159 
160     error = mTls.GetPeerCertificateBase64(aPeerCert, aCertLength, *aCertLength);
161 
162 exit:
163     return error;
164 }
165 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
166 
SendMessage(ot::Message & aMessage)167 Error BleSecure::SendMessage(ot::Message &aMessage)
168 {
169     Error error = kErrorNone;
170 
171     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
172     if (mSendMessage == nullptr)
173     {
174         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
175         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
176     }
177     SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
178     SuccessOrExit(error = Flush());
179 
180 exit:
181     aMessage.Free();
182     return error;
183 }
184 
Send(uint8_t * aBuf,uint16_t aLength)185 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
186 {
187     Error error = kErrorNone;
188 
189     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
190     if (mSendMessage == nullptr)
191     {
192         mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
193         VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
194     }
195     SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
196 
197 exit:
198     return error;
199 }
200 
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)201 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
202 {
203     Error error = kErrorNone;
204     if (aLength > Tlv::kBaseTlvMaxLength)
205     {
206         ot::ExtendedTlv tlv;
207 
208         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
209         tlv.SetLength(aLength);
210         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
211     }
212     else
213     {
214         ot::Tlv tlv;
215 
216         tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
217         tlv.SetLength((uint8_t)aLength);
218         SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
219     }
220 
221     error = Send(aBuf, aLength);
222 exit:
223     return error;
224 }
225 
Flush(void)226 Error BleSecure::Flush(void)
227 {
228     Error error = kErrorNone;
229 
230     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
231     VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
232 
233     mTransmitQueue.Enqueue(*mSendMessage);
234     mTransmitTask.Post();
235 
236     mSendMessage = nullptr;
237 
238 exit:
239     return error;
240 }
241 
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)242 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
243 {
244     ot::Message     *message = nullptr;
245     Ip6::MessageInfo messageInfo;
246     Error            error = kErrorNone;
247 
248     if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
249     {
250         error = kErrorNoBufs;
251         ExitNow();
252     }
253     SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
254 
255     // Cannot call Receive(..) directly because Setup(..) and mState are private
256     mTls.HandleReceive(*message, messageInfo);
257 
258 exit:
259     FreeMessage(message);
260     return error;
261 }
262 
HandleBleConnected(uint16_t aConnectionId)263 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
264 {
265     OT_UNUSED_VARIABLE(aConnectionId);
266 
267     mBleState = kConnected;
268 
269     IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
270 
271     mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
272 }
273 
HandleBleDisconnected(uint16_t aConnectionId)274 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
275 {
276     OT_UNUSED_VARIABLE(aConnectionId);
277 
278     mBleState = kAdvertising;
279     mMtuSize  = kInitialMtuSize;
280 
281     if (IsConnected())
282     {
283         Disconnect(); // Stop TLS connection
284     }
285 
286     mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
287 }
288 
HandleBleMtuUpdate(uint16_t aMtu)289 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
290 {
291     Error error = kErrorNone;
292 
293     if (aMtu <= OT_BLE_ATT_MTU_MAX)
294     {
295         mMtuSize = aMtu;
296     }
297     else
298     {
299         mMtuSize = OT_BLE_ATT_MTU_MAX;
300         error    = kErrorInvalidArgs;
301     }
302 
303     return error;
304 }
305 
HandleTlsConnected(void * aContext,bool aConnected)306 void BleSecure::HandleTlsConnected(void *aContext, bool aConnected)
307 {
308     return static_cast<BleSecure *>(aContext)->HandleTlsConnected(aConnected);
309 }
310 
HandleTlsConnected(bool aConnected)311 void BleSecure::HandleTlsConnected(bool aConnected)
312 {
313     if (aConnected)
314     {
315         if (mReceivedMessage == nullptr)
316         {
317             mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
318         }
319 
320         if (mTcatAgent.IsEnabled())
321         {
322             IgnoreReturnValue(mTcatAgent.Connected(mTls));
323         }
324     }
325     else
326     {
327         FreeMessage(mReceivedMessage);
328         mReceivedMessage = nullptr;
329 
330         if (mTcatAgent.IsEnabled())
331         {
332             mTcatAgent.Disconnected();
333         }
334     }
335 
336     mConnectCallback.InvokeIfSet(&GetInstance(), aConnected, true);
337 }
338 
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)339 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
340 {
341     return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
342 }
343 
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)344 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
345 {
346     VerifyOrExit(mReceivedMessage != nullptr);
347 
348     if (!mTlvMode)
349     {
350         SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
351         mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
352         IgnoreReturnValue(mReceivedMessage->SetLength(0));
353     }
354     else
355     {
356         ot::Tlv  tlv;
357         uint32_t requiredBytes = sizeof(Tlv);
358         uint32_t offset;
359 
360         while (aLength > 0)
361         {
362             if (mReceivedMessage->GetLength() < requiredBytes)
363             {
364                 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
365 
366                 if (missingBytes > aLength)
367                 {
368                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
369                     break;
370                 }
371                 else
372                 {
373                     SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
374                     aLength -= missingBytes;
375                     aBuf += missingBytes;
376                 }
377             }
378 
379             IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
380 
381             if (tlv.IsExtended())
382             {
383                 ot::ExtendedTlv extTlv;
384                 requiredBytes = sizeof(extTlv);
385 
386                 if (mReceivedMessage->GetLength() < requiredBytes)
387                 {
388                     continue;
389                 }
390 
391                 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
392                 requiredBytes = extTlv.GetSize();
393                 offset        = sizeof(extTlv);
394             }
395             else
396             {
397                 requiredBytes = tlv.GetSize();
398                 offset        = sizeof(tlv);
399             }
400 
401             if (mReceivedMessage->GetLength() < requiredBytes)
402             {
403                 continue;
404             }
405 
406             // TLV fully loaded
407 
408             if (mTcatAgent.IsEnabled())
409             {
410                 ot::Message *message;
411                 Error        error = kErrorNone;
412 
413                 message = Get<MessagePool>().Allocate(Message::kTypeBle);
414                 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
415 
416                 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
417                 if (message->GetLength() != 0)
418                 {
419                     IgnoreReturnValue(SendMessage(*message));
420                 }
421 
422                 if (error == kErrorAbort)
423                 {
424                     Disconnect();
425                     Stop();
426                     ExitNow();
427                 }
428             }
429             else
430             {
431                 mReceivedMessage->SetOffset((uint16_t)offset);
432                 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
433                                              OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
434             }
435 
436             SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
437             requiredBytes = sizeof(Tlv);
438         }
439     }
440 
441 exit:
442     return;
443 }
444 
HandleTransmit(void)445 void BleSecure::HandleTransmit(void)
446 {
447     Error        error   = kErrorNone;
448     ot::Message *message = mTransmitQueue.GetHead();
449 
450     VerifyOrExit(message != nullptr);
451     mTransmitQueue.Dequeue(*message);
452 
453     if (mTransmitQueue.GetHead() != nullptr)
454     {
455         mTransmitTask.Post();
456     }
457 
458     SuccessOrExit(error = mTls.Send(*message, message->GetLength()));
459 
460 exit:
461     if (error != kErrorNone)
462     {
463         LogNote("Transmit: %s", ErrorToString(error));
464         message->Free();
465     }
466     else
467     {
468         LogDebg("Transmit: %s", ErrorToString(error));
469     }
470 }
471 
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)472 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
473 {
474     OT_UNUSED_VARIABLE(aMessageInfo);
475     return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
476 }
477 
HandleTransport(ot::Message & aMessage)478 Error BleSecure::HandleTransport(ot::Message &aMessage)
479 {
480     otBleRadioPacket packet;
481     uint16_t         len    = aMessage.GetLength();
482     uint16_t         offset = 0;
483     Error            error  = kErrorNone;
484 
485     while (len > 0)
486     {
487         if (len <= mMtuSize - kGattOverhead)
488         {
489             packet.mLength = len;
490         }
491         else
492         {
493             packet.mLength = mMtuSize - kGattOverhead;
494         }
495 
496         if (packet.mLength > kPacketBufferSize)
497         {
498             packet.mLength = kPacketBufferSize;
499         }
500 
501         IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
502         packet.mValue = mPacketBuffer;
503         packet.mPower = OT_BLE_DEFAULT_POWER;
504 
505         SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
506 
507         len -= packet.mLength;
508         offset += packet.mLength;
509     }
510 
511     aMessage.Free();
512 exit:
513     return error;
514 }
515 
516 } // namespace Ble
517 } // namespace ot
518 
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)519 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
520 {
521     OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
522 
523     VerifyOrExit(aPacket != nullptr);
524     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
525 exit:
526     return;
527 }
528 
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)529 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
530 {
531     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
532 }
533 
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)534 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
535 {
536     AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
537 }
538 
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)539 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
540 {
541     IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
542 }
543 
544 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
545