1 /*
2 * Copyright (c) 2023, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "ble_secure.hpp"
30
31 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
32
33 #include <openthread/platform/ble.h>
34 #include "common/locator_getters.hpp"
35 #include "common/log.hpp"
36 #include "common/tlvs.hpp"
37 #include "instance/instance.hpp"
38 #include "meshcop/secure_transport.hpp"
39
40 using namespace ot;
41
42 /**
43 * @file
 *   This file implements the secure BLE agent.
45 */
46
47 namespace ot {
48 namespace Ble {
49
50 RegisterLogModule("BleSecure");
51
BleSecure(Instance & aInstance)52 BleSecure::BleSecure(Instance &aInstance)
53 : InstanceLocator(aInstance)
54 , mTls(aInstance, false, false)
55 , mTcatAgent(aInstance)
56 , mTlvMode(false)
57 , mReceivedMessage(nullptr)
58 , mSendMessage(nullptr)
59 , mTransmitTask(aInstance)
60 , mBleState(kStopped)
61 , mMtuSize(kInitialMtuSize)
62 {
63 }
64
Start(ConnectCallback aConnectHandler,ReceiveCallback aReceiveHandler,bool aTlvMode,void * aContext)65 Error BleSecure::Start(ConnectCallback aConnectHandler, ReceiveCallback aReceiveHandler, bool aTlvMode, void *aContext)
66 {
67 Error error = kErrorNone;
68
69 VerifyOrExit(mBleState == kStopped, error = kErrorAlready);
70
71 mConnectCallback.Set(aConnectHandler, aContext);
72 mReceiveCallback.Set(aReceiveHandler, aContext);
73 mTlvMode = aTlvMode;
74 mMtuSize = kInitialMtuSize;
75
76 SuccessOrExit(error = otPlatBleEnable(&GetInstance()));
77 SuccessOrExit(error = otPlatBleGapAdvStart(&GetInstance(), OT_BLE_ADV_INTERVAL_DEFAULT));
78 SuccessOrExit(error = mTls.Open(&BleSecure::HandleTlsReceive, &BleSecure::HandleTlsConnected, this));
79 SuccessOrExit(error = mTls.Bind(HandleTransport, this));
80
81 exit:
82 if (error == kErrorNone)
83 {
84 mBleState = kAdvertising;
85 }
86 return error;
87 }
88
TcatStart(const MeshCoP::TcatAgent::VendorInfo & aVendorInfo,MeshCoP::TcatAgent::JoinCallback aJoinHandler)89 Error BleSecure::TcatStart(const MeshCoP::TcatAgent::VendorInfo &aVendorInfo,
90 MeshCoP::TcatAgent::JoinCallback aJoinHandler)
91 {
92 return mTcatAgent.Start(aVendorInfo, mReceiveCallback.GetHandler(), aJoinHandler, mReceiveCallback.GetContext());
93 }
94
Stop(void)95 void BleSecure::Stop(void)
96 {
97 VerifyOrExit(mBleState != kStopped);
98 SuccessOrExit(otPlatBleGapAdvStop(&GetInstance()));
99 SuccessOrExit(otPlatBleDisable(&GetInstance()));
100 mBleState = kStopped;
101 mMtuSize = kInitialMtuSize;
102
103 if (mTcatAgent.IsEnabled())
104 {
105 mTcatAgent.Stop();
106 }
107
108 mTls.Close();
109
110 mTransmitQueue.DequeueAndFreeAll();
111
112 mConnectCallback.Clear();
113 mReceiveCallback.Clear();
114
115 FreeMessage(mReceivedMessage);
116 mReceivedMessage = nullptr;
117 FreeMessage(mSendMessage);
118 mSendMessage = nullptr;
119
120 exit:
121 return;
122 }
123
Connect(void)124 Error BleSecure::Connect(void)
125 {
126 Ip6::SockAddr sockaddr;
127
128 return mTls.Connect(sockaddr);
129 }
130
Disconnect(void)131 void BleSecure::Disconnect(void)
132 {
133 if (mTls.IsConnected())
134 {
135 mTls.Disconnect();
136 }
137
138 if (mBleState == kConnected)
139 {
140 IgnoreReturnValue(otPlatBleGapDisconnect(&GetInstance()));
141 }
142 }
143
SetPsk(const MeshCoP::JoinerPskd & aPskd)144 void BleSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
145 {
146 static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
147 static_cast<uint16_t>(MeshCoP::SecureTransport::kPskMaxLength),
148 "The maximum length of TLS PSK is smaller than joiner PSKd");
149
150 SuccessOrAssert(mTls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength()));
151 }
152
153 #if defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
GetPeerCertificateBase64(unsigned char * aPeerCert,size_t * aCertLength)154 Error BleSecure::GetPeerCertificateBase64(unsigned char *aPeerCert, size_t *aCertLength)
155 {
156 Error error;
157
158 VerifyOrExit(aCertLength != nullptr, error = kErrorInvalidArgs);
159
160 error = mTls.GetPeerCertificateBase64(aPeerCert, aCertLength, *aCertLength);
161
162 exit:
163 return error;
164 }
165 #endif // defined(MBEDTLS_BASE64_C) && defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
166
SendMessage(ot::Message & aMessage)167 Error BleSecure::SendMessage(ot::Message &aMessage)
168 {
169 Error error = kErrorNone;
170
171 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
172 if (mSendMessage == nullptr)
173 {
174 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
175 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
176 }
177 SuccessOrExit(error = mSendMessage->AppendBytesFromMessage(aMessage, 0, aMessage.GetLength()));
178 SuccessOrExit(error = Flush());
179
180 exit:
181 aMessage.Free();
182 return error;
183 }
184
Send(uint8_t * aBuf,uint16_t aLength)185 Error BleSecure::Send(uint8_t *aBuf, uint16_t aLength)
186 {
187 Error error = kErrorNone;
188
189 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
190 if (mSendMessage == nullptr)
191 {
192 mSendMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
193 VerifyOrExit(mSendMessage != nullptr, error = kErrorNoBufs);
194 }
195 SuccessOrExit(error = mSendMessage->AppendBytes(aBuf, aLength));
196
197 exit:
198 return error;
199 }
200
SendApplicationTlv(uint8_t * aBuf,uint16_t aLength)201 Error BleSecure::SendApplicationTlv(uint8_t *aBuf, uint16_t aLength)
202 {
203 Error error = kErrorNone;
204 if (aLength > Tlv::kBaseTlvMaxLength)
205 {
206 ot::ExtendedTlv tlv;
207
208 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
209 tlv.SetLength(aLength);
210 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
211 }
212 else
213 {
214 ot::Tlv tlv;
215
216 tlv.SetType(ot::MeshCoP::TcatAgent::kTlvSendApplicationData);
217 tlv.SetLength((uint8_t)aLength);
218 SuccessOrExit(error = Send(reinterpret_cast<uint8_t *>(&tlv), sizeof(tlv)));
219 }
220
221 error = Send(aBuf, aLength);
222 exit:
223 return error;
224 }
225
Flush(void)226 Error BleSecure::Flush(void)
227 {
228 Error error = kErrorNone;
229
230 VerifyOrExit(IsConnected(), error = kErrorInvalidState);
231 VerifyOrExit(mSendMessage->GetLength() != 0, error = kErrorNone);
232
233 mTransmitQueue.Enqueue(*mSendMessage);
234 mTransmitTask.Post();
235
236 mSendMessage = nullptr;
237
238 exit:
239 return error;
240 }
241
HandleBleReceive(uint8_t * aBuf,uint16_t aLength)242 Error BleSecure::HandleBleReceive(uint8_t *aBuf, uint16_t aLength)
243 {
244 ot::Message *message = nullptr;
245 Ip6::MessageInfo messageInfo;
246 Error error = kErrorNone;
247
248 if ((message = Get<MessagePool>().Allocate(Message::kTypeBle, 0)) == nullptr)
249 {
250 error = kErrorNoBufs;
251 ExitNow();
252 }
253 SuccessOrExit(error = message->AppendBytes(aBuf, aLength));
254
255 // Cannot call Receive(..) directly because Setup(..) and mState are private
256 mTls.HandleReceive(*message, messageInfo);
257
258 exit:
259 FreeMessage(message);
260 return error;
261 }
262
HandleBleConnected(uint16_t aConnectionId)263 void BleSecure::HandleBleConnected(uint16_t aConnectionId)
264 {
265 OT_UNUSED_VARIABLE(aConnectionId);
266
267 mBleState = kConnected;
268
269 IgnoreReturnValue(otPlatBleGattMtuGet(&GetInstance(), &mMtuSize));
270
271 mConnectCallback.InvokeIfSet(&GetInstance(), IsConnected(), true);
272 }
273
HandleBleDisconnected(uint16_t aConnectionId)274 void BleSecure::HandleBleDisconnected(uint16_t aConnectionId)
275 {
276 OT_UNUSED_VARIABLE(aConnectionId);
277
278 mBleState = kAdvertising;
279 mMtuSize = kInitialMtuSize;
280
281 if (IsConnected())
282 {
283 Disconnect(); // Stop TLS connection
284 }
285
286 mConnectCallback.InvokeIfSet(&GetInstance(), false, false);
287 }
288
HandleBleMtuUpdate(uint16_t aMtu)289 Error BleSecure::HandleBleMtuUpdate(uint16_t aMtu)
290 {
291 Error error = kErrorNone;
292
293 if (aMtu <= OT_BLE_ATT_MTU_MAX)
294 {
295 mMtuSize = aMtu;
296 }
297 else
298 {
299 mMtuSize = OT_BLE_ATT_MTU_MAX;
300 error = kErrorInvalidArgs;
301 }
302
303 return error;
304 }
305
HandleTlsConnected(void * aContext,bool aConnected)306 void BleSecure::HandleTlsConnected(void *aContext, bool aConnected)
307 {
308 return static_cast<BleSecure *>(aContext)->HandleTlsConnected(aConnected);
309 }
310
HandleTlsConnected(bool aConnected)311 void BleSecure::HandleTlsConnected(bool aConnected)
312 {
313 if (aConnected)
314 {
315 if (mReceivedMessage == nullptr)
316 {
317 mReceivedMessage = Get<MessagePool>().Allocate(Message::kTypeBle);
318 }
319
320 if (mTcatAgent.IsEnabled())
321 {
322 IgnoreReturnValue(mTcatAgent.Connected(mTls));
323 }
324 }
325 else
326 {
327 FreeMessage(mReceivedMessage);
328 mReceivedMessage = nullptr;
329
330 if (mTcatAgent.IsEnabled())
331 {
332 mTcatAgent.Disconnected();
333 }
334 }
335
336 mConnectCallback.InvokeIfSet(&GetInstance(), aConnected, true);
337 }
338
HandleTlsReceive(void * aContext,uint8_t * aBuf,uint16_t aLength)339 void BleSecure::HandleTlsReceive(void *aContext, uint8_t *aBuf, uint16_t aLength)
340 {
341 return static_cast<BleSecure *>(aContext)->HandleTlsReceive(aBuf, aLength);
342 }
343
HandleTlsReceive(uint8_t * aBuf,uint16_t aLength)344 void BleSecure::HandleTlsReceive(uint8_t *aBuf, uint16_t aLength)
345 {
346 VerifyOrExit(mReceivedMessage != nullptr);
347
348 if (!mTlvMode)
349 {
350 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
351 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, 0, OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
352 IgnoreReturnValue(mReceivedMessage->SetLength(0));
353 }
354 else
355 {
356 ot::Tlv tlv;
357 uint32_t requiredBytes = sizeof(Tlv);
358 uint32_t offset;
359
360 while (aLength > 0)
361 {
362 if (mReceivedMessage->GetLength() < requiredBytes)
363 {
364 uint32_t missingBytes = requiredBytes - mReceivedMessage->GetLength();
365
366 if (missingBytes > aLength)
367 {
368 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, aLength));
369 break;
370 }
371 else
372 {
373 SuccessOrExit(mReceivedMessage->AppendBytes(aBuf, (uint16_t)missingBytes));
374 aLength -= missingBytes;
375 aBuf += missingBytes;
376 }
377 }
378
379 IgnoreReturnValue(mReceivedMessage->Read(0, tlv));
380
381 if (tlv.IsExtended())
382 {
383 ot::ExtendedTlv extTlv;
384 requiredBytes = sizeof(extTlv);
385
386 if (mReceivedMessage->GetLength() < requiredBytes)
387 {
388 continue;
389 }
390
391 IgnoreReturnValue(mReceivedMessage->Read(0, extTlv));
392 requiredBytes = extTlv.GetSize();
393 offset = sizeof(extTlv);
394 }
395 else
396 {
397 requiredBytes = tlv.GetSize();
398 offset = sizeof(tlv);
399 }
400
401 if (mReceivedMessage->GetLength() < requiredBytes)
402 {
403 continue;
404 }
405
406 // TLV fully loaded
407
408 if (mTcatAgent.IsEnabled())
409 {
410 ot::Message *message;
411 Error error = kErrorNone;
412
413 message = Get<MessagePool>().Allocate(Message::kTypeBle);
414 VerifyOrExit(message != nullptr, error = kErrorNoBufs);
415
416 error = mTcatAgent.HandleSingleTlv(*mReceivedMessage, *message);
417 if (message->GetLength() != 0)
418 {
419 IgnoreReturnValue(SendMessage(*message));
420 }
421
422 if (error == kErrorAbort)
423 {
424 Disconnect();
425 Stop();
426 ExitNow();
427 }
428 }
429 else
430 {
431 mReceivedMessage->SetOffset((uint16_t)offset);
432 mReceiveCallback.InvokeIfSet(&GetInstance(), mReceivedMessage, (int32_t)offset,
433 OT_TCAT_APPLICATION_PROTOCOL_NONE, "");
434 }
435
436 SuccessOrExit(mReceivedMessage->SetLength(0)); // also sets the offset to 0
437 requiredBytes = sizeof(Tlv);
438 }
439 }
440
441 exit:
442 return;
443 }
444
HandleTransmit(void)445 void BleSecure::HandleTransmit(void)
446 {
447 Error error = kErrorNone;
448 ot::Message *message = mTransmitQueue.GetHead();
449
450 VerifyOrExit(message != nullptr);
451 mTransmitQueue.Dequeue(*message);
452
453 if (mTransmitQueue.GetHead() != nullptr)
454 {
455 mTransmitTask.Post();
456 }
457
458 SuccessOrExit(error = mTls.Send(*message, message->GetLength()));
459
460 exit:
461 if (error != kErrorNone)
462 {
463 LogNote("Transmit: %s", ErrorToString(error));
464 message->Free();
465 }
466 else
467 {
468 LogDebg("Transmit: %s", ErrorToString(error));
469 }
470 }
471
HandleTransport(void * aContext,ot::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)472 Error BleSecure::HandleTransport(void *aContext, ot::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
473 {
474 OT_UNUSED_VARIABLE(aMessageInfo);
475 return static_cast<BleSecure *>(aContext)->HandleTransport(aMessage);
476 }
477
HandleTransport(ot::Message & aMessage)478 Error BleSecure::HandleTransport(ot::Message &aMessage)
479 {
480 otBleRadioPacket packet;
481 uint16_t len = aMessage.GetLength();
482 uint16_t offset = 0;
483 Error error = kErrorNone;
484
485 while (len > 0)
486 {
487 if (len <= mMtuSize - kGattOverhead)
488 {
489 packet.mLength = len;
490 }
491 else
492 {
493 packet.mLength = mMtuSize - kGattOverhead;
494 }
495
496 if (packet.mLength > kPacketBufferSize)
497 {
498 packet.mLength = kPacketBufferSize;
499 }
500
501 IgnoreReturnValue(aMessage.Read(offset, mPacketBuffer, packet.mLength));
502 packet.mValue = mPacketBuffer;
503 packet.mPower = OT_BLE_DEFAULT_POWER;
504
505 SuccessOrExit(error = otPlatBleGattServerIndicate(&GetInstance(), kTxBleHandle, &packet));
506
507 len -= packet.mLength;
508 offset += packet.mLength;
509 }
510
511 aMessage.Free();
512 exit:
513 return error;
514 }
515
516 } // namespace Ble
517 } // namespace ot
518
otPlatBleGattServerOnWriteRequest(otInstance * aInstance,uint16_t aHandle,const otBleRadioPacket * aPacket)519 void otPlatBleGattServerOnWriteRequest(otInstance *aInstance, uint16_t aHandle, const otBleRadioPacket *aPacket)
520 {
521 OT_UNUSED_VARIABLE(aHandle); // Only a single handle is expected for RX
522
523 VerifyOrExit(aPacket != nullptr);
524 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleReceive(aPacket->mValue, aPacket->mLength));
525 exit:
526 return;
527 }
528
otPlatBleGapOnConnected(otInstance * aInstance,uint16_t aConnectionId)529 void otPlatBleGapOnConnected(otInstance *aInstance, uint16_t aConnectionId)
530 {
531 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleConnected(aConnectionId);
532 }
533
otPlatBleGapOnDisconnected(otInstance * aInstance,uint16_t aConnectionId)534 void otPlatBleGapOnDisconnected(otInstance *aInstance, uint16_t aConnectionId)
535 {
536 AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleDisconnected(aConnectionId);
537 }
538
otPlatBleGattOnMtuUpdate(otInstance * aInstance,uint16_t aMtu)539 void otPlatBleGattOnMtuUpdate(otInstance *aInstance, uint16_t aMtu)
540 {
541 IgnoreReturnValue(AsCoreType(aInstance).Get<Ble::BleSecure>().HandleBleMtuUpdate(aMtu));
542 }
543
544 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
545