1 /*
2 * Copyright (c) 2023, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "ble_secure.hpp"
30
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32
33 #include <openthread/platform/ble.h>
34 #include "common/locator_getters.hpp"
35 #include "common/log.hpp"
36 #include "common/tlvs.hpp"
37 #include "instance/instance.hpp"
38 #include "meshcop/secure_transport.hpp"
39
40 using namespace ot;
41
42 /**
43 * @file
 * This file implements the secure BLE agent (TLS transport over BLE GATT).
45 */
46
47 namespace ot {
48 namespace Ble {
49
50 RegisterLogModule("BleSecure");
51
// Constructs the BLE secure agent in the stopped state (`kStopped`).
//
// The TLS transport is constructed as `mTls(aInstance, false, false)`; the meaning of the two flags is defined by
// `MeshCoP::SecureTransport`. TLV mode is off, no receive/send buffers are allocated yet, and the ATT MTU starts at
// `kInitialMtuSize` until a connection reports the negotiated value.
BleSecure::BleSecure(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mTls(aInstance, false, false)
    , mTcatAgent(aInstance)
    , mTlvMode(false)
    , mReceivedMessage(nullptr)
    , mSendMessage(nullptr)
    , mTransmitTask(aInstance)
    , mBleState(kStopped)
    , mMtuSize(kInitialMtuSize)
{
}
64
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)65 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
66 {
67 Error error = kErrorNone;
68 uint16_t advertisementLen = 0;
69 uint8_t *advertisementData = nullptr;
70
71 VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
72
73 mConnectCallback.Set(aConnectHandler, aContext);
74 mReceiveCallback.Set(aReceiveHandler, aContext);
75 mTlvMode = aTlvMode;
76 mMtuSize = kInitialMtuSize;
77
78 SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
79
80 SuccessOrExit(error = otPlatBleGetAdvertisementBuffer(&GetInstance(), &advertisementData));
81 SuccessOrExit(error = mTcatAgent.GetAdvertisementData(advertisementLen, advertisementData));
82 VerifyOrExit(advertisementData != nullptr, error = kErrorFailed);
83 SuccessOrExit(error = otPlatBleGapAdvSetData(&GetInstance(), advertisementData, advertisementLen));
84 SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
85
86 SuccessOrExit(error = mTls.Open(&BleSecure::HandleTlsReceive, &BleSecure::HandleTlsConnected, this));
87 SuccessOrExit(error = mTls.Bind(HandleTransport, this));
88
89 exit:
90 if (error == kErrorNone)
91 {
92 mBleState = kAdvertising;
93 }
94 return error;
95 }
96
TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)97 Error BleSecure::TcatStart(MeshCoP::TcatAgent::JoinCallback aJoinHandler)
98 {
99 Error error;
100
101 VerifyOrExit(mBleState != kStopped, error = kErrorInvalidState);
102
103 error = mTcatAgent.Start(mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
104
105 exit:
106 return error;
107 }
108
Stop(void)109 void BleSecure::Stop(void)
110 {
111 VerifyOrExit(mBleState != kStopped);
112 SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
113 SuccessOrExit(otPlatBleDisable(&GetInstance()));
114 mBleState = kStopped;
115 mMtuSize = kInitialMtuSize;
116
117 if (mTcatAgent.IsEnabled())
118 {
119 mTcatAgent.Stop();
120 }
121
122 mTls.Close();
123
124 mTransmitQueue.DequeueAndFreeAll();
125
126 mConnectCallback.Clear();
127 mReceiveCallback.Clear();
128
129 FreeMessage(mReceivedMessage);
130 mReceivedMessage = nullptr;
131 FreeMessage(mSendMessage);
132 mSendMessage = nullptr;
133
134 exit:
135 return;
136 }
137
Connect(void)138 Error BleSecure::Connect(void)
139 {
140 Ip6::SockAddr sockaddr;
141 Error error;
142
143 VerifyOrExit(mBleState == kConnected, error = kErrorInvalidState);
144
145 error = mTls.Connect(sockaddr);
146
147 exit:
148 return error;
149 }
150
Disconnect(void)151 void BleSecure::Disconnect(void)
152 {
153 if (mTls.IsConnected())
154 {
155 mTls.Disconnect();
156 }
157
158 if (mBleState == kConnected)
159 {
160 mBleState = kAdvertising;
161 IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
162 }
163
164 mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
165 }
166
SetPsk(const MeshCoP::JoinerPskd & aPskd)167 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
168 {
169 static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
170 static_cast<uint16_t>(MeshCoP::SecureTransport::kPskMaxLength),
171 "The maximum length of TLS PSK is smaller than joiner PSKd");
172
173 SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
174 }
175
176 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength)177 Error BleSecure::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength)
178 {
179 Error error;
180
181 VerifyOrExit(aCertLength != nullptr, error = kErrorInvalidArgs);
182
183 error = mTls.GetPeerCertificateBase64(aPeerCert, aCertLength, *aCertLength);
184
185 exit:
186 return error;
187 }
188 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
189
SendMessage(ot::Message & aMessage)190 Error BleSecure::SendMessage(ot::Message &aMessage)
191 {
192 Error error = kErrorNone;
193
194 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
195 if (mSendMessage == nullptr)
196 {
197 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
198 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
199 }
200 SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
201 SuccessOrExit(error = Flush());
202
203 exit:
204 aMessage.Free();
205 return error;
206 }
207
Send(uint8_t * aBuf,uint16_t aLength)208 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
209 {
210 Error error = kErrorNone;
211
212 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
213 if (mSendMessage == nullptr)
214 {
215 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
216 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
217 }
218 SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
219
220 exit:
221 return error;
222 }
223
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)224 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
225 {
226 Error error = kErrorNone;
227 if (aLength > Tlv::kBaseTlvMaxLength)
228 {
229 ot::ExtendedTlv tlv;
230
231 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
232 tlv.SetLength(aLength);
233 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
234 }
235 else
236 {
237 ot::Tlv tlv;
238
239 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
240 tlv.SetLength((uint8_t)aLength);
241 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
242 }
243
244 error = Send(aBuf, aLength);
245 exit:
246 return error;
247 }
248
Flush(void)249 Error BleSecure::Flush(void)
250 {
251 Error error = kErrorNone;
252
253 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
254 VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
255
256 mTransmitQueue.Enqueue(*mSendMessage);
257 mTransmitTask.Post();
258
259 mSendMessage = nullptr;
260
261 exit:
262 return error;
263 }
264
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)265 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
266 {
267 ot::Message *message = nullptr;
268 Ip6::MessageInfo messageInfo;
269 Error error = kErrorNone;
270
271 if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
272 {
273 error = kErrorNoBufs;
274 ExitNow();
275 }
276 SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
277
278 // Cannot call Receive(..) directly because Setup(..) and mState are private
279 mTls.HandleReceive(*message, messageInfo);
280
281 exit:
282 FreeMessage(message);
283 return error;
284 }
285
HandleBleConnected(uint16_t aConnectionId)286 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
287 {
288 OT_UNUSED_VARIABLE(aConnectionId);
289
290 mBleState = kConnected;
291
292 IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
293
294 mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
295 }
296
HandleBleDisconnected(uint16_t aConnectionId)297 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
298 {
299 OT_UNUSED_VARIABLE(aConnectionId);
300
301 mBleState = kAdvertising;
302 mMtuSize = kInitialMtuSize;
303
304 Disconnect(); // Stop TLS connection
305 }
306
HandleBleMtuUpdate(uint16_t aMtu)307 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
308 {
309 Error error = kErrorNone;
310
311 if (aMtu <= OT_BLE_ATT_MTU_MAX)
312 {
313 mMtuSize = aMtu;
314 }
315 else
316 {
317 mMtuSize = OT_BLE_ATT_MTU_MAX;
318 error = kErrorInvalidArgs;
319 }
320
321 return error;
322 }
323
HandleTlsConnected(void * aContext,bool aConnected)324 void BleSecure::HandleTlsConnected(void *aContext, bool aConnected)
325 {
326 return static_cast<BleSecure *>(aContext)->HandleTlsConnected(aConnected);
327 }
328
// Handles TLS session state changes reported by the secure transport.
//
// On connect:
// - Allocates the receive/reassembly buffer `mReceivedMessage` if it does not exist yet. An allocation failure is
//   not reported; `HandleTlsReceive()` then silently drops incoming data.
// - When TCAT is enabled, lets the TCAT agent vet the connection. If the agent rejects it, the TLS transport is
//   closed, a warning is logged and the connect callback is NOT invoked.
//
// On disconnect: frees `mReceivedMessage` and notifies the TCAT agent (if enabled).
//
// Otherwise, the connect callback is invoked with `(aConnected, true)`.
//
// NOTE(review): on rejection `mTls.Close()` is used rather than `mTls.Disconnect()`, which presumably closes the
// whole transport rather than just this session, and `mReceivedMessage` stays allocated — confirm this is intended.
void BleSecure::HandleTlsConnected(bool aConnected)
{
    if (aConnected)
    {
        if (mReceivedMessage == nullptr)
        {
            mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
        }

        if (mTcatAgent.IsEnabled())
        {
            Error err = mTcatAgent.Connected(mTls);

            if (err != kErrorNone)
            {
                mTls.Close();
                LogWarn("Rejected TCAT Commissioner, error: %s", ErrorToString(err));
                ExitNow();
            }
        }
    }
    else
    {
        FreeMessage(mReceivedMessage);
        mReceivedMessage = nullptr;

        if (mTcatAgent.IsEnabled())
        {
            mTcatAgent.Disconnected();
        }
    }

    mConnectCallback.InvokeIfSet(&GetInstance(), aConnected, true);

exit:
    return;
}
366
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)367 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
368 {
369 return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
370 }
371
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)372 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
373 {
374 VerifyOrExit(mReceivedMessage != nullptr);
375
376 if (!mTlvMode)
377 {
378 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
379 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
380 IgnoreReturnValue(mReceivedMessage->SetLength(0));
381 }
382 else
383 {
384 ot::Tlv tlv;
385 uint32_t requiredBytes = sizeof(Tlv);
386 uint32_t offset;
387
388 while (aLength > 0)
389 {
390 if (mReceivedMessage->GetLength() < requiredBytes)
391 {
392 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
393
394 if (missingBytes > aLength)
395 {
396 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
397 break;
398 }
399 else
400 {
401 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
402 aLength -= missingBytes;
403 aBuf += missingBytes;
404 }
405 }
406
407 IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
408
409 if (tlv.IsExtended())
410 {
411 ot::ExtendedTlv extTlv;
412 requiredBytes = sizeof(extTlv);
413
414 if (mReceivedMessage->GetLength() < requiredBytes)
415 {
416 continue;
417 }
418
419 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
420 requiredBytes = extTlv.GetSize();
421 offset = sizeof(extTlv);
422 }
423 else
424 {
425 requiredBytes = tlv.GetSize();
426 offset = sizeof(tlv);
427 }
428
429 if (mReceivedMessage->GetLength() < requiredBytes)
430 {
431 continue;
432 }
433
434 // TLV fully loaded
435
436 if (mTcatAgent.IsEnabled())
437 {
438 ot::Message *message;
439 Error error = kErrorNone;
440
441 message = Get<MessagePool>().Allocate(Message::kTypeBle);
442 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
443
444 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
445 if (message->GetLength() != 0)
446 {
447 IgnoreReturnValue(SendMessage(*message));
448 }
449
450 if (error == kErrorAbort)
451 {
452 Disconnect();
453 Stop();
454 ExitNow();
455 }
456 }
457 else
458 {
459 mReceivedMessage->SetOffset((uint16_t)offset);
460 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
461 OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
462 }
463
464 SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
465 requiredBytes = sizeof(Tlv);
466 }
467 }
468
469 exit:
470 return;
471 }
472
HandleTransmit(void)473 void BleSecure::HandleTransmit(void)
474 {
475 Error error = kErrorNone;
476 ot::Message *message = mTransmitQueue.GetHead();
477
478 VerifyOrExit(message != nullptr);
479 mTransmitQueue.Dequeue(*message);
480
481 if (mTransmitQueue.GetHead() != nullptr)
482 {
483 mTransmitTask.Post();
484 }
485
486 SuccessOrExit(error = mTls.Send(*message, message->GetLength()));
487
488 exit:
489 if (error != kErrorNone)
490 {
491 LogNote("Transmit: %s", ErrorToString(error));
492 message->Free();
493 }
494 else
495 {
496 LogDebg("Transmit: %s", ErrorToString(error));
497 }
498 }
499
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)500 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
501 {
502 OT_UNUSED_VARIABLE(aMessageInfo);
503 return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
504 }
505
HandleTransport(ot::Message & aMessage)506 Error BleSecure::HandleTransport(ot::Message &aMessage)
507 {
508 otBleRadioPacket packet;
509 uint16_t len = aMessage.GetLength();
510 uint16_t offset = 0;
511 Error error = kErrorNone;
512
513 while (len > 0)
514 {
515 if (len <= mMtuSize - kGattOverhead)
516 {
517 packet.mLength = len;
518 }
519 else
520 {
521 packet.mLength = mMtuSize - kGattOverhead;
522 }
523
524 if (packet.mLength > kPacketBufferSize)
525 {
526 packet.mLength = kPacketBufferSize;
527 }
528
529 IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
530 packet.mValue = mPacketBuffer;
531 packet.mPower = OT_BLE_DEFAULT_POWER;
532
533 SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
534
535 len -= packet.mLength;
536 offset += packet.mLength;
537 }
538
539 aMessage.Free();
540 exit:
541 return error;
542 }
543
544 } // namespace Ble
545 } // namespace ot
546
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)547 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
548 {
549 OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
550
551 VerifyOrExit(aPacket != nullptr);
552 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
553 exit:
554 return;
555 }
556
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)557 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
558 {
559 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
560 }
561
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)562 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
563 {
564 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
565 }
566
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)567 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
568 {
569 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
570 }
571
572 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
573