1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements common methods for manipulating MeshCoP Datasets.
32  *
33  */
34 
35 #include "dataset.hpp"
36 
37 #include <stdio.h>
38 
39 #include "common/code_utils.hpp"
40 #include "common/encoding.hpp"
41 #include "common/instance.hpp"
42 #include "common/locator_getters.hpp"
43 #include "common/logging.hpp"
44 #include "mac/mac_types.hpp"
45 #include "meshcop/meshcop_tlvs.hpp"
46 #include "thread/mle_tlvs.hpp"
47 
48 namespace ot {
49 namespace MeshCoP {
50 
GenerateRandom(Instance & aInstance)51 Error Dataset::Info::GenerateRandom(Instance &aInstance)
52 {
53     Error            error;
54     Mac::ChannelMask supportedChannels = aInstance.Get<Mac::Mac>().GetSupportedChannelMask();
55     Mac::ChannelMask preferredChannels(aInstance.Get<Radio>().GetPreferredChannelMask());
56 
57     // If the preferred channel mask is not empty, select a random
58     // channel from it, otherwise choose one from the supported
59     // channel mask.
60 
61     preferredChannels.Intersect(supportedChannels);
62 
63     if (preferredChannels.IsEmpty())
64     {
65         preferredChannels = supportedChannels;
66     }
67 
68     Clear();
69 
70     mActiveTimestamp = 1;
71     mChannel         = preferredChannels.ChooseRandomChannel();
72     mChannelMask     = supportedChannels.GetMask();
73     mPanId           = Mac::GenerateRandomPanId();
74     static_cast<SecurityPolicy &>(mSecurityPolicy).SetToDefault();
75 
76     SuccessOrExit(error = static_cast<NetworkKey &>(mNetworkKey).GenerateRandom());
77     SuccessOrExit(error = static_cast<Pskc &>(mPskc).GenerateRandom());
78     SuccessOrExit(error = Random::Crypto::FillBuffer(mExtendedPanId.m8, sizeof(mExtendedPanId.m8)));
79     SuccessOrExit(error = static_cast<Ip6::NetworkPrefix &>(mMeshLocalPrefix).GenerateRandomUla());
80 
81     snprintf(mNetworkName.m8, sizeof(mNetworkName), "OpenThread-%04x", mPanId);
82 
83     mComponents.mIsActiveTimestampPresent = true;
84     mComponents.mIsNetworkKeyPresent      = true;
85     mComponents.mIsNetworkNamePresent     = true;
86     mComponents.mIsExtendedPanIdPresent   = true;
87     mComponents.mIsMeshLocalPrefixPresent = true;
88     mComponents.mIsPanIdPresent           = true;
89     mComponents.mIsChannelPresent         = true;
90     mComponents.mIsPskcPresent            = true;
91     mComponents.mIsSecurityPolicyPresent  = true;
92     mComponents.mIsChannelMaskPresent     = true;
93 
94 exit:
95     return error;
96 }
97 
IsSubsetOf(const Info & aOther) const98 bool Dataset::Info::IsSubsetOf(const Info &aOther) const
99 {
100     bool isSubset = false;
101 
102     if (IsNetworkKeyPresent())
103     {
104         VerifyOrExit(aOther.IsNetworkKeyPresent() && GetNetworkKey() == aOther.GetNetworkKey());
105     }
106 
107     if (IsNetworkNamePresent())
108     {
109         VerifyOrExit(aOther.IsNetworkNamePresent() && GetNetworkName() == aOther.GetNetworkName());
110     }
111 
112     if (IsExtendedPanIdPresent())
113     {
114         VerifyOrExit(aOther.IsExtendedPanIdPresent() && GetExtendedPanId() == aOther.GetExtendedPanId());
115     }
116 
117     if (IsMeshLocalPrefixPresent())
118     {
119         VerifyOrExit(aOther.IsMeshLocalPrefixPresent() && GetMeshLocalPrefix() == aOther.GetMeshLocalPrefix());
120     }
121 
122     if (IsPanIdPresent())
123     {
124         VerifyOrExit(aOther.IsPanIdPresent() && GetPanId() == aOther.GetPanId());
125     }
126 
127     if (IsChannelPresent())
128     {
129         VerifyOrExit(aOther.IsChannelPresent() && GetChannel() == aOther.GetChannel());
130     }
131 
132     if (IsPskcPresent())
133     {
134         VerifyOrExit(aOther.IsPskcPresent() && GetPskc() == aOther.GetPskc());
135     }
136 
137     if (IsSecurityPolicyPresent())
138     {
139         VerifyOrExit(aOther.IsSecurityPolicyPresent() && GetSecurityPolicy() == aOther.GetSecurityPolicy());
140     }
141 
142     if (IsChannelMaskPresent())
143     {
144         VerifyOrExit(aOther.IsChannelMaskPresent() && GetChannelMask() == aOther.GetChannelMask());
145     }
146 
147     isSubset = true;
148 
149 exit:
150     return isSubset;
151 }
152 
Dataset(void)153 Dataset::Dataset(void)
154     : mUpdateTime(0)
155     , mLength(0)
156 {
157     memset(mTlvs, 0, sizeof(mTlvs));
158 }
159 
Clear(void)160 void Dataset::Clear(void)
161 {
162     mLength = 0;
163 }
164 
IsValid(void) const165 bool Dataset::IsValid(void) const
166 {
167     bool       rval = true;
168     const Tlv *end  = GetTlvsEnd();
169 
170     for (const Tlv *cur = GetTlvsStart(); cur < end; cur = cur->GetNext())
171     {
172         VerifyOrExit(!cur->IsExtended() && (cur + 1) <= end && cur->GetNext() <= end && Tlv::IsValid(*cur),
173                      rval = false);
174     }
175 
176 exit:
177     return rval;
178 }
179 
GetTlv(Tlv::Type aType) const180 const Tlv *Dataset::GetTlv(Tlv::Type aType) const
181 {
182     return Tlv::FindTlv(mTlvs, mLength, aType);
183 }
184 
ConvertTo(Info & aDatasetInfo) const185 void Dataset::ConvertTo(Info &aDatasetInfo) const
186 {
187     aDatasetInfo.Clear();
188 
189     for (const Tlv *cur = GetTlvsStart(); cur < GetTlvsEnd(); cur = cur->GetNext())
190     {
191         switch (cur->GetType())
192         {
193         case Tlv::kActiveTimestamp:
194             aDatasetInfo.SetActiveTimestamp(static_cast<const ActiveTimestampTlv *>(cur)->GetSeconds());
195             break;
196 
197         case Tlv::kChannel:
198             aDatasetInfo.SetChannel(static_cast<const ChannelTlv *>(cur)->GetChannel());
199             break;
200 
201         case Tlv::kChannelMask:
202         {
203             uint32_t mask = static_cast<const ChannelMaskTlv *>(cur)->GetChannelMask();
204 
205             if (mask != 0)
206             {
207                 aDatasetInfo.SetChannelMask(mask);
208             }
209 
210             break;
211         }
212 
213         case Tlv::kDelayTimer:
214             aDatasetInfo.SetDelay(static_cast<const DelayTimerTlv *>(cur)->GetDelayTimer());
215             break;
216 
217         case Tlv::kExtendedPanId:
218             aDatasetInfo.SetExtendedPanId(static_cast<const ExtendedPanIdTlv *>(cur)->GetExtendedPanId());
219             break;
220 
221         case Tlv::kMeshLocalPrefix:
222             aDatasetInfo.SetMeshLocalPrefix(static_cast<const MeshLocalPrefixTlv *>(cur)->GetMeshLocalPrefix());
223             break;
224 
225         case Tlv::kNetworkKey:
226             aDatasetInfo.SetNetworkKey(static_cast<const NetworkKeyTlv *>(cur)->GetNetworkKey());
227             break;
228 
229         case Tlv::kNetworkName:
230             aDatasetInfo.SetNetworkName(static_cast<const NetworkNameTlv *>(cur)->GetNetworkName());
231             break;
232 
233         case Tlv::kPanId:
234             aDatasetInfo.SetPanId(static_cast<const PanIdTlv *>(cur)->GetPanId());
235             break;
236 
237         case Tlv::kPendingTimestamp:
238             aDatasetInfo.SetPendingTimestamp(static_cast<const PendingTimestampTlv *>(cur)->GetSeconds());
239             break;
240 
241         case Tlv::kPskc:
242             aDatasetInfo.SetPskc(static_cast<const PskcTlv *>(cur)->GetPskc());
243             break;
244 
245         case Tlv::kSecurityPolicy:
246         {
247             const SecurityPolicyTlv *tlv = static_cast<const SecurityPolicyTlv *>(cur);
248 
249             aDatasetInfo.SetSecurityPolicy(tlv->GetSecurityPolicy());
250             break;
251         }
252 
253         default:
254             break;
255         }
256     }
257 }
258 
ConvertTo(otOperationalDatasetTlvs & aDataset) const259 void Dataset::ConvertTo(otOperationalDatasetTlvs &aDataset) const
260 {
261     memcpy(aDataset.mTlvs, mTlvs, mLength);
262     aDataset.mLength = static_cast<uint8_t>(mLength);
263 }
264 
Set(Type aType,const Dataset & aDataset)265 void Dataset::Set(Type aType, const Dataset &aDataset)
266 {
267     memcpy(mTlvs, aDataset.mTlvs, aDataset.mLength);
268     mLength = aDataset.mLength;
269 
270     if (aType == kActive)
271     {
272         RemoveTlv(Tlv::kPendingTimestamp);
273         RemoveTlv(Tlv::kDelayTimer);
274     }
275 
276     mUpdateTime = aDataset.GetUpdateTime();
277 }
278 
SetFrom(const otOperationalDatasetTlvs & aDataset)279 void Dataset::SetFrom(const otOperationalDatasetTlvs &aDataset)
280 {
281     mLength = aDataset.mLength;
282     memcpy(mTlvs, aDataset.mTlvs, mLength);
283 }
284 
SetFrom(const Info & aDatasetInfo)285 Error Dataset::SetFrom(const Info &aDatasetInfo)
286 {
287     Error error = kErrorNone;
288 
289     if (aDatasetInfo.IsActiveTimestampPresent())
290     {
291         ActiveTimestampTlv tlv;
292         tlv.Init();
293         tlv.SetSeconds(aDatasetInfo.GetActiveTimestamp());
294         tlv.SetTicks(0);
295         IgnoreError(SetTlv(tlv));
296     }
297 
298     if (aDatasetInfo.IsPendingTimestampPresent())
299     {
300         PendingTimestampTlv tlv;
301         tlv.Init();
302         tlv.SetSeconds(aDatasetInfo.GetPendingTimestamp());
303         tlv.SetTicks(0);
304         IgnoreError(SetTlv(tlv));
305     }
306 
307     if (aDatasetInfo.IsDelayPresent())
308     {
309         IgnoreError(SetTlv(Tlv::kDelayTimer, aDatasetInfo.GetDelay()));
310     }
311 
312     if (aDatasetInfo.IsChannelPresent())
313     {
314         ChannelTlv tlv;
315         tlv.Init();
316         tlv.SetChannel(aDatasetInfo.GetChannel());
317         IgnoreError(SetTlv(tlv));
318     }
319 
320     if (aDatasetInfo.IsChannelMaskPresent())
321     {
322         ChannelMaskTlv tlv;
323         tlv.Init();
324         tlv.SetChannelMask(aDatasetInfo.GetChannelMask());
325         IgnoreError(SetTlv(tlv));
326     }
327 
328     if (aDatasetInfo.IsExtendedPanIdPresent())
329     {
330         IgnoreError(SetTlv(Tlv::kExtendedPanId, aDatasetInfo.GetExtendedPanId()));
331     }
332 
333     if (aDatasetInfo.IsMeshLocalPrefixPresent())
334     {
335         IgnoreError(SetTlv(Tlv::kMeshLocalPrefix, aDatasetInfo.GetMeshLocalPrefix()));
336     }
337 
338     if (aDatasetInfo.IsNetworkKeyPresent())
339     {
340         IgnoreError(SetTlv(Tlv::kNetworkKey, aDatasetInfo.GetNetworkKey()));
341     }
342 
343     if (aDatasetInfo.IsNetworkNamePresent())
344     {
345         Mac::NameData nameData = aDatasetInfo.GetNetworkName().GetAsData();
346 
347         IgnoreError(SetTlv(Tlv::kNetworkName, nameData.GetBuffer(), nameData.GetLength()));
348     }
349 
350     if (aDatasetInfo.IsPanIdPresent())
351     {
352         IgnoreError(SetTlv(Tlv::kPanId, aDatasetInfo.GetPanId()));
353     }
354 
355     if (aDatasetInfo.IsPskcPresent())
356     {
357         IgnoreError(SetTlv(Tlv::kPskc, aDatasetInfo.GetPskc()));
358     }
359 
360     if (aDatasetInfo.IsSecurityPolicyPresent())
361     {
362         SecurityPolicyTlv tlv;
363 
364         tlv.Init();
365         tlv.SetSecurityPolicy(aDatasetInfo.GetSecurityPolicy());
366         IgnoreError(SetTlv(tlv));
367     }
368 
369     mUpdateTime = TimerMilli::GetNow();
370 
371     return error;
372 }
373 
GetTimestamp(Type aType) const374 const Timestamp *Dataset::GetTimestamp(Type aType) const
375 {
376     const Timestamp *timestamp = nullptr;
377 
378     if (aType == kActive)
379     {
380         const ActiveTimestampTlv *tlv = GetTlv<ActiveTimestampTlv>();
381         VerifyOrExit(tlv != nullptr);
382         timestamp = static_cast<const Timestamp *>(tlv);
383     }
384     else
385     {
386         const PendingTimestampTlv *tlv = GetTlv<PendingTimestampTlv>();
387         VerifyOrExit(tlv != nullptr);
388         timestamp = static_cast<const Timestamp *>(tlv);
389     }
390 
391 exit:
392     return timestamp;
393 }
394 
SetTimestamp(Type aType,const Timestamp & aTimestamp)395 void Dataset::SetTimestamp(Type aType, const Timestamp &aTimestamp)
396 {
397     IgnoreError(SetTlv((aType == kActive) ? Tlv::kActiveTimestamp : Tlv::kPendingTimestamp, aTimestamp));
398 }
399 
SetTlv(Tlv::Type aType,const void * aValue,uint8_t aLength)400 Error Dataset::SetTlv(Tlv::Type aType, const void *aValue, uint8_t aLength)
401 {
402     Error    error          = kErrorNone;
403     uint16_t bytesAvailable = sizeof(mTlvs) - mLength;
404     Tlv *    old            = GetTlv(aType);
405     Tlv      tlv;
406 
407     if (old != nullptr)
408     {
409         bytesAvailable += sizeof(Tlv) + old->GetLength();
410     }
411 
412     VerifyOrExit(sizeof(Tlv) + aLength <= bytesAvailable, error = kErrorNoBufs);
413 
414     if (old != nullptr)
415     {
416         RemoveTlv(old);
417     }
418 
419     tlv.SetType(aType);
420     tlv.SetLength(aLength);
421     memcpy(mTlvs + mLength, &tlv, sizeof(Tlv));
422     mLength += sizeof(Tlv);
423 
424     memcpy(mTlvs + mLength, aValue, aLength);
425     mLength += aLength;
426 
427     mUpdateTime = TimerMilli::GetNow();
428 
429 exit:
430     return error;
431 }
432 
SetTlv(const Tlv & aTlv)433 Error Dataset::SetTlv(const Tlv &aTlv)
434 {
435     return SetTlv(aTlv.GetType(), aTlv.GetValue(), aTlv.GetLength());
436 }
437 
Set(const Message & aMessage,uint16_t aOffset,uint8_t aLength)438 Error Dataset::Set(const Message &aMessage, uint16_t aOffset, uint8_t aLength)
439 {
440     Error error = kErrorInvalidArgs;
441 
442     SuccessOrExit(aMessage.Read(aOffset, mTlvs, aLength));
443     mLength = aLength;
444 
445     mUpdateTime = TimerMilli::GetNow();
446     error       = kErrorNone;
447 
448 exit:
449     return error;
450 }
451 
RemoveTlv(Tlv::Type aType)452 void Dataset::RemoveTlv(Tlv::Type aType)
453 {
454     Tlv *tlv;
455 
456     VerifyOrExit((tlv = GetTlv(aType)) != nullptr);
457     RemoveTlv(tlv);
458 
459 exit:
460     return;
461 }
462 
AppendMleDatasetTlv(Type aType,Message & aMessage) const463 Error Dataset::AppendMleDatasetTlv(Type aType, Message &aMessage) const
464 {
465     Error          error = kErrorNone;
466     Mle::Tlv       tlv;
467     Mle::Tlv::Type type;
468 
469     VerifyOrExit(mLength > 0);
470 
471     type = (aType == kActive ? Mle::Tlv::kActiveDataset : Mle::Tlv::kPendingDataset);
472 
473     tlv.SetType(type);
474     tlv.SetLength(static_cast<uint8_t>(mLength) - sizeof(Tlv) - sizeof(Timestamp));
475     SuccessOrExit(error = aMessage.Append(tlv));
476 
477     for (const Tlv *cur = GetTlvsStart(); cur < GetTlvsEnd(); cur = cur->GetNext())
478     {
479         if (((aType == kActive) && (cur->GetType() == Tlv::kActiveTimestamp)) ||
480             ((aType == kPending) && (cur->GetType() == Tlv::kPendingTimestamp)))
481         {
482             ; // skip Active or Pending Timestamp TLV
483         }
484         else if (cur->GetType() == Tlv::kDelayTimer)
485         {
486             uint32_t      elapsed = TimerMilli::GetNow() - mUpdateTime;
487             DelayTimerTlv delayTimer(static_cast<const DelayTimerTlv &>(*cur));
488 
489             if (delayTimer.GetDelayTimer() > elapsed)
490             {
491                 delayTimer.SetDelayTimer(delayTimer.GetDelayTimer() - elapsed);
492             }
493             else
494             {
495                 delayTimer.SetDelayTimer(0);
496             }
497 
498             SuccessOrExit(error = delayTimer.AppendTo(aMessage));
499         }
500         else
501         {
502             SuccessOrExit(error = cur->AppendTo(aMessage));
503         }
504     }
505 
506 exit:
507     return error;
508 }
509 
RemoveTlv(Tlv * aTlv)510 void Dataset::RemoveTlv(Tlv *aTlv)
511 {
512     uint8_t *start  = reinterpret_cast<uint8_t *>(aTlv);
513     uint16_t length = sizeof(Tlv) + aTlv->GetLength();
514 
515     memmove(start, start + length, mLength - (static_cast<uint8_t>(start - mTlvs) + length));
516     mLength -= length;
517 }
518 
ApplyConfiguration(Instance & aInstance,bool * aIsNetworkKeyUpdated) const519 Error Dataset::ApplyConfiguration(Instance &aInstance, bool *aIsNetworkKeyUpdated) const
520 {
521     Mac::Mac &  mac        = aInstance.Get<Mac::Mac>();
522     KeyManager &keyManager = aInstance.Get<KeyManager>();
523     Error       error      = kErrorNone;
524 
525     VerifyOrExit(IsValid(), error = kErrorParse);
526 
527     if (aIsNetworkKeyUpdated)
528     {
529         *aIsNetworkKeyUpdated = false;
530     }
531 
532     for (const Tlv *cur = GetTlvsStart(); cur < GetTlvsEnd(); cur = cur->GetNext())
533     {
534         switch (cur->GetType())
535         {
536         case Tlv::kChannel:
537         {
538             uint8_t channel = static_cast<uint8_t>(static_cast<const ChannelTlv *>(cur)->GetChannel());
539 
540             error = mac.SetPanChannel(channel);
541 
542             if (error != kErrorNone)
543             {
544                 otLogWarnMeshCoP("DatasetManager::ApplyConfiguration() Failed to set channel to %d (%s)", channel,
545                                  ErrorToString(error));
546                 ExitNow();
547             }
548 
549             break;
550         }
551 
552         case Tlv::kPanId:
553             mac.SetPanId(static_cast<const PanIdTlv *>(cur)->GetPanId());
554             break;
555 
556         case Tlv::kExtendedPanId:
557             mac.SetExtendedPanId(static_cast<const ExtendedPanIdTlv *>(cur)->GetExtendedPanId());
558             break;
559 
560         case Tlv::kNetworkName:
561             IgnoreError(mac.SetNetworkName(static_cast<const NetworkNameTlv *>(cur)->GetNetworkName()));
562             break;
563 
564         case Tlv::kNetworkKey:
565         {
566             const NetworkKeyTlv *key = static_cast<const NetworkKeyTlv *>(cur);
567 
568             if (aIsNetworkKeyUpdated && (key->GetNetworkKey() != keyManager.GetNetworkKey()))
569             {
570                 *aIsNetworkKeyUpdated = true;
571             }
572 
573             IgnoreError(keyManager.SetNetworkKey(key->GetNetworkKey()));
574             break;
575         }
576 
577 #if OPENTHREAD_FTD
578 
579         case Tlv::kPskc:
580             keyManager.SetPskc(static_cast<const PskcTlv *>(cur)->GetPskc());
581             break;
582 
583 #endif
584 
585         case Tlv::kMeshLocalPrefix:
586             aInstance.Get<Mle::MleRouter>().SetMeshLocalPrefix(
587                 static_cast<const MeshLocalPrefixTlv *>(cur)->GetMeshLocalPrefix());
588             break;
589 
590         case Tlv::kSecurityPolicy:
591         {
592             const SecurityPolicyTlv *securityPolicy = static_cast<const SecurityPolicyTlv *>(cur);
593             keyManager.SetSecurityPolicy(securityPolicy->GetSecurityPolicy());
594             break;
595         }
596 
597         default:
598             break;
599         }
600     }
601 
602 exit:
603     return error;
604 }
605 
ConvertToActive(void)606 void Dataset::ConvertToActive(void)
607 {
608     RemoveTlv(Tlv::kPendingTimestamp);
609     RemoveTlv(Tlv::kDelayTimer);
610 }
611 
TypeToString(Type aType)612 const char *Dataset::TypeToString(Type aType)
613 {
614     return (aType == kActive) ? "Active" : "Pending";
615 }
616 
617 } // namespace MeshCoP
618 } // namespace ot
619