1 /*
2  *  Copyright (c) 2023, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the TCAT Agent service.
32  */
33 
34 #include "tcat_agent.hpp"
35 
36 #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
37 
38 #include <stdio.h>
39 
40 #include "common/array.hpp"
41 #include "common/code_utils.hpp"
42 #include "common/debug.hpp"
43 #include "common/encoding.hpp"
44 #include "common/locator_getters.hpp"
45 #include "common/string.hpp"
46 #include "instance/instance.hpp"
47 #include "radio/radio.hpp"
48 #include "thread/thread_netif.hpp"
49 #include "thread/uri_paths.hpp"
50 #include "utils/otns.hpp"
51 
52 namespace ot {
53 namespace MeshCoP {
54 
// Registers "TcatAgent" as this file's log module name; presumably used as the tag on the
// LogInfo()/LogWarn() output produced below.
RegisterLogModule("TcatAgent");
56 
IsValid(void) const57 bool TcatAgent::VendorInfo::IsValid(void) const
58 {
59     return mProvisioningUrl == nullptr || IsValidUtf8String(mProvisioningUrl) || mPskdString != nullptr;
60 }
61 
TcatAgent(Instance & aInstance)62 TcatAgent::TcatAgent(Instance &aInstance)
63     : InstanceLocator(aInstance)
64     , mVendorInfo(nullptr)
65     , mCurrentApplicationProtocol(kApplicationProtocolNone)
66     , mState(kStateDisabled)
67     , mAlreadyCommissioned(false)
68     , mCommissionerHasNetworkName(false)
69     , mCommissionerHasDomainName(false)
70     , mCommissionerHasExtendedPanId(false)
71 {
72     mJoinerPskd.Clear();
73     mCurrentServiceName[0] = 0;
74 }
75 
Start(const TcatAgent::VendorInfo & aVendorInfo,AppDataReceiveCallback aAppDataReceiveCallback,JoinCallback aHandler,void * aContext)76 Error TcatAgent::Start(const TcatAgent::VendorInfo &aVendorInfo,
77                        AppDataReceiveCallback       aAppDataReceiveCallback,
78                        JoinCallback                 aHandler,
79                        void                        *aContext)
80 {
81     Error error = kErrorNone;
82 
83     LogInfo("Starting");
84 
85     VerifyOrExit(aVendorInfo.IsValid(), error = kErrorInvalidArgs);
86     SuccessOrExit(error = mJoinerPskd.SetFrom(aVendorInfo.mPskdString));
87 
88     mAppDataReceiveCallback.Set(aAppDataReceiveCallback, aContext);
89     mJoinCallback.Set(aHandler, aContext);
90 
91     mVendorInfo                 = &aVendorInfo;
92     mCurrentApplicationProtocol = kApplicationProtocolNone;
93     mState                      = kStateEnabled;
94     mAlreadyCommissioned        = false;
95 
96 exit:
97     LogError("start TCAT agent", error);
98     return error;
99 }
100 
Stop(void)101 void TcatAgent::Stop(void)
102 {
103     mCurrentApplicationProtocol = kApplicationProtocolNone;
104     mState                      = kStateDisabled;
105     mAlreadyCommissioned        = false;
106     mAppDataReceiveCallback.Clear();
107     mJoinCallback.Clear();
108     LogInfo("TCAT agent stopped");
109 }
110 
Connected(MeshCoP::SecureTransport & aTlsContext)111 Error TcatAgent::Connected(MeshCoP::SecureTransport &aTlsContext)
112 {
113     size_t len;
114     Error  error;
115 
116     VerifyOrExit(IsEnabled(), error = kErrorInvalidState);
117     len = sizeof(mCommissionerAuthorizationField);
118     SuccessOrExit(
119         error = aTlsContext.GetThreadAttributeFromPeerCertificate(
120             kCertificateAuthorizationField, reinterpret_cast<uint8_t *>(&mCommissionerAuthorizationField), &len));
121     VerifyOrExit(len == sizeof(mCommissionerAuthorizationField), error = kErrorParse);
122     VerifyOrExit((mCommissionerAuthorizationField.mHeader & kCommissionerFlag) == 1, error = kErrorParse);
123 
124     len = sizeof(mDeviceAuthorizationField);
125     SuccessOrExit(error = aTlsContext.GetThreadAttributeFromOwnCertificate(
126                       kCertificateAuthorizationField, reinterpret_cast<uint8_t *>(&mDeviceAuthorizationField), &len));
127     VerifyOrExit(len == sizeof(mDeviceAuthorizationField), error = kErrorParse);
128     VerifyOrExit((mDeviceAuthorizationField.mHeader & kCommissionerFlag) == 0, error = kErrorParse);
129 
130     mCommissionerHasDomainName    = false;
131     mCommissionerHasNetworkName   = false;
132     mCommissionerHasExtendedPanId = false;
133 
134     len = sizeof(mCommissionerDomainName) - 1;
135     if (aTlsContext.GetThreadAttributeFromPeerCertificate(
136             kCertificateDomainName, reinterpret_cast<uint8_t *>(&mCommissionerDomainName), &len) == kErrorNone)
137     {
138         mCommissionerDomainName.m8[len] = '\0';
139         mCommissionerHasDomainName      = true;
140     }
141 
142     len = sizeof(mCommissionerNetworkName) - 1;
143     if (aTlsContext.GetThreadAttributeFromPeerCertificate(
144             kCertificateNetworkName, reinterpret_cast<uint8_t *>(&mCommissionerNetworkName), &len) == kErrorNone)
145     {
146         mCommissionerNetworkName.m8[len] = '\0';
147         mCommissionerHasNetworkName      = true;
148     }
149 
150     len = sizeof(mCommissionerExtendedPanId);
151     if (aTlsContext.GetThreadAttributeFromPeerCertificate(
152             kCertificateExtendedPanId, reinterpret_cast<uint8_t *>(&mCommissionerExtendedPanId), &len) == kErrorNone)
153     {
154         if (len == sizeof(mCommissionerExtendedPanId))
155         {
156             mCommissionerHasExtendedPanId = true;
157         }
158     }
159 
160     mCurrentApplicationProtocol = kApplicationProtocolNone;
161     mCurrentServiceName[0]      = 0;
162     mState                      = kStateConnected;
163     mAlreadyCommissioned        = Get<ActiveDatasetManager>().IsCommissioned();
164     LogInfo("TCAT agent connected");
165 
166 exit:
167     return error;
168 }
169 
Disconnected(void)170 void TcatAgent::Disconnected(void)
171 {
172     mCurrentApplicationProtocol = kApplicationProtocolNone;
173     mAlreadyCommissioned        = false;
174 
175     if (mState != kStateDisabled)
176     {
177         mState = kStateEnabled;
178     }
179 
180     LogInfo("TCAT agent disconnected");
181 }
182 
CheckCommandClassAuthorizationFlags(CommandClassFlags aCommissionerCommandClassFlags,CommandClassFlags aDeviceCommandClassFlags,Dataset * aDataset) const183 bool TcatAgent::CheckCommandClassAuthorizationFlags(CommandClassFlags aCommissionerCommandClassFlags,
184                                                     CommandClassFlags aDeviceCommandClassFlags,
185                                                     Dataset          *aDataset) const
186 {
187     bool authorized                     = false;
188     bool additionalDeviceRequirementMet = false;
189     bool domainNamesMatch               = false;
190     bool networkNamesMatch              = false;
191     bool extendedPanIdsMatch            = false;
192 
193     VerifyOrExit(IsConnected());
194     VerifyOrExit(aCommissionerCommandClassFlags & kAccessFlag);
195 
196     if (aDeviceCommandClassFlags & kAccessFlag)
197     {
198         additionalDeviceRequirementMet = true;
199     }
200 
201     if (aDeviceCommandClassFlags & kPskdFlag)
202     {
203         additionalDeviceRequirementMet = true;
204     }
205 
206     if (aDeviceCommandClassFlags & kPskcFlag)
207     {
208         additionalDeviceRequirementMet = true;
209     }
210 
211     if (mCommissionerHasNetworkName || mCommissionerHasExtendedPanId)
212     {
213         Dataset::Info datasetInfo;
214         Error         datasetError = kErrorNone;
215 
216         if (aDataset == nullptr)
217         {
218             datasetError = Get<ActiveDatasetManager>().Read(datasetInfo);
219         }
220         else
221         {
222             aDataset->ConvertTo(datasetInfo);
223         }
224 
225         if (datasetError == kErrorNone)
226         {
227             if (datasetInfo.IsNetworkNamePresent() && mCommissionerHasNetworkName &&
228                 (datasetInfo.GetNetworkName() == mCommissionerNetworkName))
229             {
230                 networkNamesMatch = true;
231             }
232 
233             if (datasetInfo.IsExtendedPanIdPresent() && mCommissionerHasExtendedPanId &&
234                 (datasetInfo.GetExtendedPanId() == mCommissionerExtendedPanId))
235             {
236                 extendedPanIdsMatch = true;
237             }
238         }
239     }
240 
241     if (!networkNamesMatch)
242     {
243         VerifyOrExit((aCommissionerCommandClassFlags & kNetworkNameFlag) == 0);
244     }
245     else if (aDeviceCommandClassFlags & kNetworkNameFlag)
246     {
247         additionalDeviceRequirementMet = true;
248     }
249 
250     if (!extendedPanIdsMatch)
251     {
252         VerifyOrExit((aCommissionerCommandClassFlags & kExtendedPanIdFlag) == 0);
253     }
254     else if (aDeviceCommandClassFlags & kExtendedPanIdFlag)
255     {
256         additionalDeviceRequirementMet = true;
257     }
258 
259 #if (OPENTHREAD_CONFIG_THREAD_VERSION >= OT_THREAD_VERSION_1_2)
260     VerifyOrExit((aCommissionerCommandClassFlags & kThreadDomainFlag) == 0);
261 #endif
262 
263     if (!domainNamesMatch)
264     {
265         VerifyOrExit((aCommissionerCommandClassFlags & kThreadDomainFlag) == 0);
266     }
267     else if (aDeviceCommandClassFlags & kThreadDomainFlag)
268     {
269         additionalDeviceRequirementMet = true;
270     }
271 
272     if (additionalDeviceRequirementMet)
273     {
274         authorized = true;
275     }
276 
277 exit:
278     return authorized;
279 }
280 
IsCommandClassAuthorized(CommandClass aCommandClass) const281 bool TcatAgent::IsCommandClassAuthorized(CommandClass aCommandClass) const
282 {
283     bool authorized = false;
284 
285     switch (aCommandClass)
286     {
287     case kGeneral:
288         authorized = true;
289         break;
290 
291     case kCommissioning:
292         authorized = CheckCommandClassAuthorizationFlags(mCommissionerAuthorizationField.mCommissioningFlags,
293                                                          mDeviceAuthorizationField.mCommissioningFlags, nullptr);
294         break;
295 
296     case kExtraction:
297         authorized = CheckCommandClassAuthorizationFlags(mCommissionerAuthorizationField.mExtractionFlags,
298                                                          mDeviceAuthorizationField.mExtractionFlags, nullptr);
299         break;
300 
301     case kTlvDecommissioning:
302         authorized = CheckCommandClassAuthorizationFlags(mCommissionerAuthorizationField.mDecommissioningFlags,
303                                                          mDeviceAuthorizationField.mDecommissioningFlags, nullptr);
304         break;
305 
306     case kApplication:
307         authorized = CheckCommandClassAuthorizationFlags(mCommissionerAuthorizationField.mApplicationFlags,
308                                                          mDeviceAuthorizationField.mApplicationFlags, nullptr);
309         break;
310 
311     case kInvalid:
312         authorized = false;
313         break;
314     }
315 
316     return authorized;
317 }
318 
GetCommandClass(uint8_t aTlvType) const319 TcatAgent::CommandClass TcatAgent::GetCommandClass(uint8_t aTlvType) const
320 {
321     static constexpr int kGeneralTlvs            = 0x1F;
322     static constexpr int kCommissioningTlvs      = 0x3F;
323     static constexpr int kExtractionTlvs         = 0x5F;
324     static constexpr int kTlvDecommissioningTlvs = 0x7F;
325     static constexpr int kApplicationTlvs        = 0x9F;
326 
327     if (aTlvType <= kGeneralTlvs)
328     {
329         return kGeneral;
330     }
331     else if (aTlvType <= kCommissioningTlvs)
332     {
333         return kCommissioning;
334     }
335     else if (aTlvType <= kExtractionTlvs)
336     {
337         return kExtraction;
338     }
339     else if (aTlvType <= kTlvDecommissioningTlvs)
340     {
341         return kTlvDecommissioning;
342     }
343     else if (aTlvType <= kApplicationTlvs)
344     {
345         return kApplication;
346     }
347     else
348     {
349         return kInvalid;
350     }
351 }
352 
CanProcessTlv(uint8_t aTlvType) const353 bool TcatAgent::CanProcessTlv(uint8_t aTlvType) const
354 {
355     CommandClass tlvCommandClass = GetCommandClass(aTlvType);
356     return IsCommandClassAuthorized(tlvCommandClass);
357 }
358 
HandleSingleTlv(const Message & aIncommingMessage,Message & aOutgoingMessage)359 Error TcatAgent::HandleSingleTlv(const Message &aIncommingMessage, Message &aOutgoingMessage)
360 {
361     Error    error = kErrorParse;
362     ot::Tlv  tlv;
363     uint16_t offset = aIncommingMessage.GetOffset();
364     uint16_t length;
365     bool     response = false;
366 
367     VerifyOrExit(IsConnected(), error = kErrorInvalidState);
368     SuccessOrExit(error = aIncommingMessage.Read(offset, tlv));
369 
370     if (tlv.IsExtended())
371     {
372         ot::ExtendedTlv extTlv;
373         SuccessOrExit(error = aIncommingMessage.Read(offset, extTlv));
374         length = extTlv.GetLength();
375         offset += sizeof(ot::ExtendedTlv);
376     }
377     else
378     {
379         length = tlv.GetLength();
380         offset += sizeof(ot::Tlv);
381     }
382 
383     if (!CanProcessTlv(tlv.GetType()))
384     {
385         error = kErrorRejected;
386     }
387     else
388     {
389         switch (tlv.GetType())
390         {
391         case kTlvDisconnect:
392             error = kErrorAbort;
393             break;
394 
395         case kTlvSetActiveOperationalDataset:
396             error = HandleSetActiveOperationalDataset(aIncommingMessage, offset, length);
397             break;
398 
399         case kTlvStartThreadInterface:
400             error = HandleStartThreadInterface();
401             break;
402 
403         case kTlvStopThreadInterface:
404             error = otThreadSetEnabled(&GetInstance(), false);
405             break;
406 
407         case kTlvSendApplicationData:
408             LogInfo("Application data len:%d, offset:%d", length, offset);
409             mAppDataReceiveCallback.InvokeIfSet(&GetInstance(), &aIncommingMessage, offset,
410                                                 MapEnum(mCurrentApplicationProtocol), mCurrentServiceName);
411             response = true;
412             error    = kErrorNone;
413             break;
414 
415         default:
416             error = kErrorInvalidCommand;
417         }
418     }
419 
420     if (!response)
421     {
422         StatusCode statusCode;
423 
424         switch (error)
425         {
426         case kErrorNone:
427             statusCode = kStatusSuccess;
428             break;
429 
430         case kErrorInvalidState:
431             statusCode = kStatusUndefined;
432             break;
433 
434         case kErrorParse:
435             statusCode = kStatusParseError;
436             break;
437 
438         case kErrorInvalidCommand:
439             statusCode = kStatusUnsupported;
440             break;
441 
442         case kErrorRejected:
443             statusCode = kStatusUnauthorized;
444             break;
445 
446         case kErrorNotImplemented:
447             statusCode = kStatusUnsupported;
448             break;
449 
450         default:
451             statusCode = kStatusGeneralError;
452             break;
453         }
454 
455         SuccessOrExit(error = ot::Tlv::Append<ResponseWithStatusTlv>(aOutgoingMessage, statusCode));
456     }
457 
458 exit:
459     return error;
460 }
461 
HandleSetActiveOperationalDataset(const Message & aIncommingMessage,uint16_t aOffset,uint16_t aLength)462 Error TcatAgent::HandleSetActiveOperationalDataset(const Message &aIncommingMessage, uint16_t aOffset, uint16_t aLength)
463 {
464     Dataset                  dataset;
465     otOperationalDatasetTlvs datasetTlvs;
466     Error                    error;
467 
468     SuccessOrExit(error = dataset.ReadFromMessage(aIncommingMessage, aOffset, aLength));
469 
470     if (!CheckCommandClassAuthorizationFlags(mCommissionerAuthorizationField.mApplicationFlags,
471                                              mDeviceAuthorizationField.mApplicationFlags, &dataset))
472     {
473         error = kErrorRejected;
474         ExitNow();
475     }
476 
477     dataset.ConvertTo(datasetTlvs);
478     error = Get<ActiveDatasetManager>().Save(datasetTlvs);
479 
480 exit:
481     return error;
482 }
483 
HandleStartThreadInterface(void)484 Error TcatAgent::HandleStartThreadInterface(void)
485 {
486     Error         error;
487     Dataset::Info datasetInfo;
488 
489     VerifyOrExit(Get<ActiveDatasetManager>().Read(datasetInfo) == kErrorNone, error = kErrorInvalidState);
490     VerifyOrExit(datasetInfo.IsNetworkKeyPresent(), error = kErrorInvalidState);
491 
492 #if OPENTHREAD_CONFIG_LINK_RAW_ENABLE
493     VerifyOrExit(!Get<Mac::LinkRaw>().IsEnabled(), error = kErrorInvalidState);
494 #endif
495 
496     Get<ThreadNetif>().Up();
497     error = Get<Mle::MleRouter>().Start();
498 
499 exit:
500     return error;
501 }
502 
503 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_WARN)
LogError(const char * aActionText,Error aError)504 void TcatAgent::LogError(const char *aActionText, Error aError)
505 {
506     if (aError != kErrorNone)
507     {
508         LogWarn("Failed to %s: %s", aActionText, ErrorToString(aError));
509     }
510 }
511 #endif
512 
513 } // namespace MeshCoP
514 } // namespace ot
515 
516 #endif // OPENTHREAD_CONFIG_BLE_TCAT_ENABLE
517