1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements common methods for manipulating MeshCoP Datasets.
32  *
33  */
34 
35 #include "dataset.hpp"
36 
37 #include <stdio.h>
38 
39 #include "common/code_utils.hpp"
40 #include "common/encoding.hpp"
41 #include "common/locator_getters.hpp"
42 #include "common/log.hpp"
43 #include "instance/instance.hpp"
44 #include "mac/mac_types.hpp"
45 #include "meshcop/meshcop_tlvs.hpp"
46 #include "meshcop/timestamp.hpp"
47 #include "thread/mle_tlvs.hpp"
48 
49 namespace ot {
50 namespace MeshCoP {
51 
// Registers "Dataset" as the log module name for this file (presumably picked up by the
// logging helpers from "common/log.hpp").
RegisterLogModule("Dataset");
53 
/**
 * Fills this `Info` with a randomly generated set of dataset parameters.
 *
 * The generated values are:
 *  - Active Timestamp: 1 second, 0 ticks, not authoritative.
 *  - Channel: random, chosen from (preferred & supported) channels, or from the supported
 *    channels when that intersection is empty.
 *  - Channel Mask: the MAC supported channel mask.
 *  - PAN ID: random.
 *  - Security Policy: default.
 *  - Network Key, PSKc, Extended PAN ID: random.
 *  - Mesh-Local Prefix: random ULA.
 *  - Network Name: "<kNetworkNameInit>-<PAN ID as 4 hex digits>".
 *
 * @param[in] aInstance  Instance used to query the MAC supported channel mask and the radio
 *                       preferred channel mask.
 *
 * @returns kErrorNone on success, otherwise the first error reported by one of the random
 *          generators. On error `Clear()` has already run and the component-present flags at
 *          the end of this function are not set by this call.
 */
Error Dataset::Info::GenerateRandom(Instance &aInstance)
{
    Error            error;
    Mac::ChannelMask supportedChannels = aInstance.Get<Mac::Mac>().GetSupportedChannelMask();
    Mac::ChannelMask preferredChannels(aInstance.Get<Radio>().GetPreferredChannelMask());
    StringWriter     nameWriter(mNetworkName.m8, sizeof(mNetworkName));

    // If the preferred channel mask is not empty, select a random
    // channel from it, otherwise choose one from the supported
    // channel mask.

    preferredChannels.Intersect(supportedChannels);

    if (preferredChannels.IsEmpty())
    {
        preferredChannels = supportedChannels;
    }

    // NOTE(review): `nameWriter` was constructed before `Clear()`. This relies on `Clear()`
    // leaving `mNetworkName` in a state the writer still treats as empty when `Append()` is
    // called below. Confirm if `Clear()` or `StringWriter` change.
    Clear();

    mActiveTimestamp.mSeconds       = 1;
    mActiveTimestamp.mTicks         = 0;
    mActiveTimestamp.mAuthoritative = false;
    mChannel                        = preferredChannels.ChooseRandomChannel();
    mChannelMask                    = supportedChannels.GetMask();
    mPanId                          = Mac::GenerateRandomPanId();
    AsCoreType(&mSecurityPolicy).SetToDefault();

    // Security material and identifiers. Any failure aborts before the components are
    // marked present.
    SuccessOrExit(error = AsCoreType(&mNetworkKey).GenerateRandom());
    SuccessOrExit(error = AsCoreType(&mPskc).GenerateRandom());
    SuccessOrExit(error = Random::Crypto::Fill(mExtendedPanId));
    SuccessOrExit(error = AsCoreType(&mMeshLocalPrefix).GenerateRandomUla());

    // The network name embeds the PAN ID chosen above.
    nameWriter.Append("%s-%04x", NetworkName::kNetworkNameInit, mPanId);

    mComponents.mIsActiveTimestampPresent = true;
    mComponents.mIsNetworkKeyPresent      = true;
    mComponents.mIsNetworkNamePresent     = true;
    mComponents.mIsExtendedPanIdPresent   = true;
    mComponents.mIsMeshLocalPrefixPresent = true;
    mComponents.mIsPanIdPresent           = true;
    mComponents.mIsChannelPresent         = true;
    mComponents.mIsPskcPresent            = true;
    mComponents.mIsSecurityPolicyPresent  = true;
    mComponents.mIsChannelMaskPresent     = true;

exit:
    return error;
}
103 
Dataset(void)104 Dataset::Dataset(void)
105     : mLength(0)
106     , mUpdateTime(0)
107 {
108     ClearAllBytes(mTlvs);
109 }
110 
ValidateTlvs(void) const111 Error Dataset::ValidateTlvs(void) const
112 {
113     Error      error = kErrorParse;
114     const Tlv *end   = GetTlvsEnd();
115     uint16_t   validatedLength;
116 
117     VerifyOrExit(mLength <= kMaxLength);
118 
119     for (const Tlv *tlv = GetTlvsStart(); tlv < end; tlv = tlv->GetNext())
120     {
121         VerifyOrExit(!tlv->IsExtended() && ((tlv + 1) <= end) && (tlv->GetNext() <= end));
122         VerifyOrExit(IsTlvValid(*tlv));
123 
124         // Ensure there are no duplicate TLVs.
125         validatedLength = static_cast<uint16_t>(reinterpret_cast<const uint8_t *>(tlv) - mTlvs);
126         VerifyOrExit(Tlv::FindTlv(mTlvs, validatedLength, tlv->GetType()) == nullptr);
127     }
128 
129     error = kErrorNone;
130 
131 exit:
132     return error;
133 }
134 
IsTlvValid(const Tlv & aTlv)135 bool Dataset::IsTlvValid(const Tlv &aTlv)
136 {
137     bool    isValid   = true;
138     uint8_t minLength = 0;
139 
140     switch (aTlv.GetType())
141     {
142     case Tlv::kPanId:
143         minLength = sizeof(PanIdTlv::UintValueType);
144         break;
145     case Tlv::kExtendedPanId:
146         minLength = sizeof(ExtendedPanIdTlv::ValueType);
147         break;
148     case Tlv::kPskc:
149         minLength = sizeof(PskcTlv::ValueType);
150         break;
151     case Tlv::kNetworkKey:
152         minLength = sizeof(NetworkKeyTlv::ValueType);
153         break;
154     case Tlv::kMeshLocalPrefix:
155         minLength = sizeof(MeshLocalPrefixTlv::ValueType);
156         break;
157     case Tlv::kChannel:
158         VerifyOrExit(aTlv.GetLength() >= sizeof(ChannelTlvValue), isValid = false);
159         isValid = aTlv.ReadValueAs<ChannelTlv>().IsValid();
160         break;
161     case Tlv::kNetworkName:
162         isValid = As<NetworkNameTlv>(aTlv).IsValid();
163         break;
164 
165     case Tlv::kSecurityPolicy:
166         isValid = As<SecurityPolicyTlv>(aTlv).IsValid();
167         break;
168 
169     case Tlv::kChannelMask:
170         isValid = As<ChannelMaskTlv>(aTlv).IsValid();
171         break;
172 
173     default:
174         break;
175     }
176 
177     if (minLength > 0)
178     {
179         isValid = (aTlv.GetLength() >= minLength);
180     }
181 
182 exit:
183     return isValid;
184 }
185 
ContainsAllTlvs(const Tlv::Type aTlvTypes[],uint8_t aLength) const186 bool Dataset::ContainsAllTlvs(const Tlv::Type aTlvTypes[], uint8_t aLength) const
187 {
188     bool containsAll = true;
189 
190     for (uint8_t index = 0; index < aLength; index++)
191     {
192         if (!ContainsTlv(aTlvTypes[index]))
193         {
194             containsAll = false;
195             break;
196         }
197     }
198 
199     return containsAll;
200 }
201 
ContainsAllRequiredTlvsFor(Type aType) const202 bool Dataset::ContainsAllRequiredTlvsFor(Type aType) const
203 {
204     static const Tlv::Type kDatasetTlvs[] = {
205         Tlv::kActiveTimestamp,
206         Tlv::kChannel,
207         Tlv::kChannelMask,
208         Tlv::kExtendedPanId,
209         Tlv::kMeshLocalPrefix,
210         Tlv::kNetworkKey,
211         Tlv::kNetworkName,
212         Tlv::kPanId,
213         Tlv::kPskc,
214         Tlv::kSecurityPolicy,
215         // The last two TLVs are for Pending Dataset
216         Tlv::kPendingTimestamp,
217         Tlv::kDelayTimer,
218     };
219 
220     uint8_t length = sizeof(kDatasetTlvs);
221 
222     if (aType == kActive)
223     {
224         length -= 2;
225     }
226 
227     return ContainsAllTlvs(kDatasetTlvs, length);
228 }
229 
FindTlv(Tlv::Type aType) const230 const Tlv *Dataset::FindTlv(Tlv::Type aType) const { return As<Tlv>(Tlv::FindTlv(mTlvs, mLength, aType)); }
231 
ConvertTo(Info & aDatasetInfo) const232 void Dataset::ConvertTo(Info &aDatasetInfo) const
233 {
234     aDatasetInfo.Clear();
235 
236     for (const Tlv *cur = GetTlvsStart(); cur < GetTlvsEnd(); cur = cur->GetNext())
237     {
238         switch (cur->GetType())
239         {
240         case Tlv::kActiveTimestamp:
241             aDatasetInfo.Set<kActiveTimestamp>(cur->ReadValueAs<ActiveTimestampTlv>());
242             break;
243 
244         case Tlv::kChannel:
245             aDatasetInfo.Set<kChannel>(cur->ReadValueAs<ChannelTlv>().GetChannel());
246             break;
247 
248         case Tlv::kChannelMask:
249         {
250             uint32_t mask;
251 
252             if (As<ChannelMaskTlv>(cur)->ReadChannelMask(mask) == kErrorNone)
253             {
254                 aDatasetInfo.Set<kChannelMask>(mask);
255             }
256 
257             break;
258         }
259 
260         case Tlv::kDelayTimer:
261             aDatasetInfo.Set<kDelay>(cur->ReadValueAs<DelayTimerTlv>());
262             break;
263 
264         case Tlv::kExtendedPanId:
265             aDatasetInfo.Set<kExtendedPanId>(cur->ReadValueAs<ExtendedPanIdTlv>());
266             break;
267 
268         case Tlv::kMeshLocalPrefix:
269             aDatasetInfo.Set<kMeshLocalPrefix>(cur->ReadValueAs<MeshLocalPrefixTlv>());
270             break;
271 
272         case Tlv::kNetworkKey:
273             aDatasetInfo.Set<kNetworkKey>(cur->ReadValueAs<NetworkKeyTlv>());
274             break;
275 
276         case Tlv::kNetworkName:
277             IgnoreError(aDatasetInfo.Update<kNetworkName>().Set(As<NetworkNameTlv>(cur)->GetNetworkName()));
278             break;
279 
280         case Tlv::kPanId:
281             aDatasetInfo.Set<kPanId>(cur->ReadValueAs<PanIdTlv>());
282             break;
283 
284         case Tlv::kPendingTimestamp:
285             aDatasetInfo.Set<kPendingTimestamp>(cur->ReadValueAs<PendingTimestampTlv>());
286             break;
287 
288         case Tlv::kPskc:
289             aDatasetInfo.Set<kPskc>(cur->ReadValueAs<PskcTlv>());
290             break;
291 
292         case Tlv::kSecurityPolicy:
293             aDatasetInfo.Set<kSecurityPolicy>(As<SecurityPolicyTlv>(cur)->GetSecurityPolicy());
294             break;
295 
296         default:
297             break;
298         }
299     }
300 }
301 
ConvertTo(Tlvs & aTlvs) const302 void Dataset::ConvertTo(Tlvs &aTlvs) const
303 {
304     memcpy(aTlvs.mTlvs, mTlvs, mLength);
305     aTlvs.mLength = static_cast<uint8_t>(mLength);
306 }
307 
SetFrom(const Dataset & aDataset)308 void Dataset::SetFrom(const Dataset &aDataset)
309 {
310     memcpy(mTlvs, aDataset.mTlvs, aDataset.mLength);
311     mLength     = aDataset.mLength;
312     mUpdateTime = aDataset.GetUpdateTime();
313 }
314 
SetFrom(const Tlvs & aTlvs)315 Error Dataset::SetFrom(const Tlvs &aTlvs) { return SetFrom(aTlvs.mTlvs, aTlvs.mLength); }
316 
SetFrom(const uint8_t * aTlvs,uint8_t aLength)317 Error Dataset::SetFrom(const uint8_t *aTlvs, uint8_t aLength)
318 {
319     Error error = kErrorNone;
320 
321     VerifyOrExit(aLength <= kMaxLength, error = kErrorInvalidArgs);
322 
323     mLength = aLength;
324     memcpy(mTlvs, aTlvs, mLength);
325 
326     mUpdateTime = TimerMilli::GetNow();
327 
328 exit:
329     return error;
330 }
331 
SetFrom(const Info & aDatasetInfo)332 void Dataset::SetFrom(const Info &aDatasetInfo)
333 {
334     Clear();
335     IgnoreError(WriteTlvsFrom(aDatasetInfo));
336 
337     // `mUpdateTime` is already set by `WriteTlvsFrom()`.
338 }
339 
SetFrom(const Message & aMessage,const OffsetRange & aOffsetRange)340 Error Dataset::SetFrom(const Message &aMessage, const OffsetRange &aOffsetRange)
341 {
342     Error error = kErrorNone;
343 
344     VerifyOrExit(aOffsetRange.GetLength() <= kMaxLength, error = kErrorInvalidArgs);
345 
346     SuccessOrExit(error = aMessage.Read(aOffsetRange, mTlvs, aOffsetRange.GetLength()));
347     mLength = static_cast<uint8_t>(aOffsetRange.GetLength());
348 
349     mUpdateTime = TimerMilli::GetNow();
350 
351 exit:
352     return error;
353 }
354 
/**
 * Writes a TLV with the given type and value, replacing any existing TLV of the same type.
 *
 * Any existing TLV of @p aType is removed, and the new TLV is always appended at the end of
 * the TLV sequence. Replacing a TLV therefore moves it to the end.
 *
 * @param[in] aType    The TLV type to write.
 * @param[in] aValue   Pointer to the value bytes.
 * @param[in] aLength  Number of value bytes to copy from @p aValue.
 *
 * @retval kErrorNone    The TLV was written and `mUpdateTime` was set to now.
 * @retval kErrorNoBufs  The new TLV does not fit, even counting the space an existing TLV of
 *                       the same type would free. The Dataset is left unchanged.
 */
Error Dataset::WriteTlv(Tlv::Type aType, const void *aValue, uint8_t aLength)
{
    Error    error          = kErrorNone;
    uint16_t bytesAvailable = sizeof(mTlvs) - mLength;
    Tlv     *oldTlv         = FindTlv(aType);
    Tlv     *newTlv;

    // The space that removing the old TLV (header + value) frees counts toward the budget.
    if (oldTlv != nullptr)
    {
        bytesAvailable += sizeof(Tlv) + oldTlv->GetLength();
    }

    VerifyOrExit(sizeof(Tlv) + aLength <= bytesAvailable, error = kErrorNoBufs);

    // Remove first so that `GetTlvsEnd()` below reflects the shrunken sequence.
    // NOTE(review): if `aValue` points into `mTlvs` (e.g. at this Dataset's own TLV), the
    // `memmove()` in `RemoveTlv()` may shift the bytes it refers to before they are copied.
    // This looks unsafe for such callers.
    RemoveTlv(oldTlv);

    newTlv = GetTlvsEnd();
    mLength += sizeof(Tlv) + aLength;

    newTlv->SetType(aType);
    newTlv->SetLength(aLength);
    memcpy(newTlv->GetValue(), aValue, aLength);

    mUpdateTime = TimerMilli::GetNow();

exit:
    return error;
}
383 
WriteTlv(const Tlv & aTlv)384 Error Dataset::WriteTlv(const Tlv &aTlv) { return WriteTlv(aTlv.GetType(), aTlv.GetValue(), aTlv.GetLength()); }
385 
WriteTlvsFrom(const Dataset & aDataset)386 Error Dataset::WriteTlvsFrom(const Dataset &aDataset)
387 {
388     Error error;
389 
390     SuccessOrExit(error = aDataset.ValidateTlvs());
391 
392     for (const Tlv *tlv = aDataset.GetTlvsStart(); tlv < aDataset.GetTlvsEnd(); tlv = tlv->GetNext())
393     {
394         SuccessOrExit(error = WriteTlv(*tlv));
395     }
396 
397 exit:
398     return error;
399 }
400 
WriteTlvsFrom(const uint8_t * aTlvs,uint8_t aLength)401 Error Dataset::WriteTlvsFrom(const uint8_t *aTlvs, uint8_t aLength)
402 {
403     Error   error;
404     Dataset dataset;
405 
406     SuccessOrExit(error = dataset.SetFrom(aTlvs, aLength));
407     error = WriteTlvsFrom(dataset);
408 
409 exit:
410     return error;
411 }
412 
/**
 * Writes a TLV for every component present in @p aDatasetInfo.
 *
 * Components are written in a fixed order: Active Timestamp, Pending Timestamp, Delay Timer,
 * Channel, Channel Mask, Extended PAN ID, Mesh-Local Prefix, Network Key, Network Name,
 * PAN ID, PSKc, Security Policy. Components that are absent are skipped. Each write replaces
 * an existing TLV of the same type (presumably via `WriteTlv()`), so this order also
 * determines the resulting TLV byte layout.
 *
 * @param[in] aDatasetInfo  The dataset information to encode.
 *
 * @returns kErrorNone on success, or the first write error. TLVs written before the failure
 *          are kept.
 */
Error Dataset::WriteTlvsFrom(const Dataset::Info &aDatasetInfo)
{
    Error error = kErrorNone;

    if (aDatasetInfo.IsPresent<kActiveTimestamp>())
    {
        Timestamp activeTimestamp;

        aDatasetInfo.Get<kActiveTimestamp>(activeTimestamp);
        SuccessOrExit(error = Write<ActiveTimestampTlv>(activeTimestamp));
    }

    if (aDatasetInfo.IsPresent<kPendingTimestamp>())
    {
        Timestamp pendingTimestamp;

        aDatasetInfo.Get<kPendingTimestamp>(pendingTimestamp);
        SuccessOrExit(error = Write<PendingTimestampTlv>(pendingTimestamp));
    }

    if (aDatasetInfo.IsPresent<kDelay>())
    {
        SuccessOrExit(error = Write<DelayTimerTlv>(aDatasetInfo.Get<kDelay>()));
    }

    if (aDatasetInfo.IsPresent<kChannel>())
    {
        ChannelTlvValue channelValue;

        channelValue.SetChannelAndPage(aDatasetInfo.Get<kChannel>());
        SuccessOrExit(error = Write<ChannelTlv>(channelValue));
    }

    if (aDatasetInfo.IsPresent<kChannelMask>())
    {
        // The Channel Mask TLV has a variable-length value, so it is encoded into a
        // `ChannelMaskTlv::Value` buffer and written with an explicit length.
        ChannelMaskTlv::Value value;

        ChannelMaskTlv::PrepareValue(value, aDatasetInfo.Get<kChannelMask>());
        SuccessOrExit(error = WriteTlv(Tlv::kChannelMask, value.mData, value.mLength));
    }

    if (aDatasetInfo.IsPresent<kExtendedPanId>())
    {
        SuccessOrExit(error = Write<ExtendedPanIdTlv>(aDatasetInfo.Get<kExtendedPanId>()));
    }

    if (aDatasetInfo.IsPresent<kMeshLocalPrefix>())
    {
        SuccessOrExit(error = Write<MeshLocalPrefixTlv>(aDatasetInfo.Get<kMeshLocalPrefix>()));
    }

    if (aDatasetInfo.IsPresent<kNetworkKey>())
    {
        SuccessOrExit(error = Write<NetworkKeyTlv>(aDatasetInfo.Get<kNetworkKey>()));
    }

    if (aDatasetInfo.IsPresent<kNetworkName>())
    {
        // Write only the name's data bytes, using the length reported by `NameData`.
        NameData nameData = aDatasetInfo.Get<kNetworkName>().GetAsData();

        SuccessOrExit(error = WriteTlv(Tlv::kNetworkName, nameData.GetBuffer(), nameData.GetLength()));
    }

    if (aDatasetInfo.IsPresent<kPanId>())
    {
        SuccessOrExit(error = Write<PanIdTlv>(aDatasetInfo.Get<kPanId>()));
    }

    if (aDatasetInfo.IsPresent<kPskc>())
    {
        SuccessOrExit(error = Write<PskcTlv>(aDatasetInfo.Get<kPskc>()));
    }

    if (aDatasetInfo.IsPresent<kSecurityPolicy>())
    {
        // The Security Policy TLV is built as a full TLV object and then copied in.
        SecurityPolicyTlv tlv;

        tlv.Init();
        tlv.SetSecurityPolicy(aDatasetInfo.Get<kSecurityPolicy>());
        SuccessOrExit(error = WriteTlv(tlv));
    }

exit:
    return error;
}
498 
AppendTlvsFrom(const uint8_t * aTlvs,uint8_t aLength)499 Error Dataset::AppendTlvsFrom(const uint8_t *aTlvs, uint8_t aLength)
500 {
501     Error    error     = kErrorNone;
502     uint16_t newLength = mLength;
503 
504     newLength += aLength;
505     VerifyOrExit(newLength <= kMaxLength, error = kErrorNoBufs);
506 
507     memcpy(mTlvs + mLength, aTlvs, aLength);
508     mLength += aLength;
509 
510 exit:
511     return error;
512 }
513 
RemoveTlv(Tlv::Type aType)514 void Dataset::RemoveTlv(Tlv::Type aType) { RemoveTlv(FindTlv(aType)); }
515 
RemoveTlv(Tlv * aTlv)516 void Dataset::RemoveTlv(Tlv *aTlv)
517 {
518     if (aTlv != nullptr)
519     {
520         uint8_t *start  = reinterpret_cast<uint8_t *>(aTlv);
521         uint16_t length = sizeof(Tlv) + aTlv->GetLength();
522 
523         memmove(start, start + length, mLength - (static_cast<uint8_t>(start - mTlvs) + length));
524         mLength -= length;
525     }
526 }
527 
ReadTimestamp(Type aType,Timestamp & aTimestamp) const528 Error Dataset::ReadTimestamp(Type aType, Timestamp &aTimestamp) const
529 {
530     Error      error = kErrorNone;
531     const Tlv *tlv   = FindTlv(TimestampTlvFor(aType));
532 
533     VerifyOrExit(tlv != nullptr, error = kErrorNotFound);
534 
535     // Since both `ActiveTimestampTlv` and `PendingTimestampTlv` use
536     // `Timestamp` as their TLV value format, we can safely use
537     // `ReadValueAs<ActiveTimestampTlv>()` for both.
538 
539     aTimestamp = tlv->ReadValueAs<ActiveTimestampTlv>();
540 
541 exit:
542     return error;
543 }
544 
WriteTimestamp(Type aType,const Timestamp & aTimestamp)545 Error Dataset::WriteTimestamp(Type aType, const Timestamp &aTimestamp)
546 {
547     return WriteTlv(TimestampTlvFor(aType), &aTimestamp, sizeof(Timestamp));
548 }
549 
RemoveTimestamp(Type aType)550 void Dataset::RemoveTimestamp(Type aType) { RemoveTlv(TimestampTlvFor(aType)); }
551 
IsSubsetOf(const Dataset & aOther) const552 bool Dataset::IsSubsetOf(const Dataset &aOther) const
553 {
554     bool isSubset = false;
555 
556     for (const Tlv *tlv = GetTlvsStart(); tlv < GetTlvsEnd(); tlv = tlv->GetNext())
557     {
558         const Tlv *otherTlv;
559 
560         if ((tlv->GetType() == Tlv::kActiveTimestamp) || (tlv->GetType() == Tlv::kPendingTimestamp) ||
561             (tlv->GetType() == Tlv::kDelayTimer))
562         {
563             continue;
564         }
565 
566         otherTlv = aOther.FindTlv(tlv->GetType());
567         VerifyOrExit(otherTlv != nullptr);
568         VerifyOrExit(memcmp(tlv, otherTlv, tlv->GetSize()) == 0);
569     }
570 
571     isSubset = true;
572 
573 exit:
574     return isSubset;
575 }
576 
TypeToString(Type aType)577 const char *Dataset::TypeToString(Type aType) { return (aType == kActive) ? "Active" : "Pending"; }
578 
579 } // namespace MeshCoP
580 } // namespace ot
581