1 /*
2  *  Copyright (c) 2016-2017, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
 *   This file implements MeshCoP TLV helper functions.
32  */
33 
34 #include "meshcop_tlvs.hpp"
35 
36 #include "common/const_cast.hpp"
37 #include "common/debug.hpp"
38 #include "common/num_utils.hpp"
39 #include "common/numeric_limits.hpp"
40 #include "common/string.hpp"
41 #include "meshcop/meshcop.hpp"
42 
43 namespace ot {
44 namespace MeshCoP {
45 
IsValid(const Tlv & aTlv)46 bool Tlv::IsValid(const Tlv &aTlv)
47 {
48     bool    isValid   = true;
49     uint8_t minLength = 0;
50 
51     switch (aTlv.GetType())
52     {
53     case Tlv::kPanId:
54         minLength = sizeof(PanIdTlv::UintValueType);
55         break;
56     case Tlv::kExtendedPanId:
57         minLength = sizeof(ExtendedPanIdTlv::ValueType);
58         break;
59     case Tlv::kPskc:
60         minLength = sizeof(PskcTlv::ValueType);
61         break;
62     case Tlv::kNetworkKey:
63         minLength = sizeof(NetworkKeyTlv::ValueType);
64         break;
65     case Tlv::kMeshLocalPrefix:
66         minLength = sizeof(MeshLocalPrefixTlv::ValueType);
67         break;
68     case Tlv::kChannel:
69         VerifyOrExit(aTlv.GetLength() >= sizeof(ChannelTlvValue), isValid = false);
70         isValid = aTlv.ReadValueAs<ChannelTlv>().IsValid();
71         break;
72     case Tlv::kNetworkName:
73         isValid = As<NetworkNameTlv>(aTlv).IsValid();
74         break;
75 
76     case Tlv::kSecurityPolicy:
77         isValid = As<SecurityPolicyTlv>(aTlv).IsValid();
78         break;
79 
80     case Tlv::kChannelMask:
81         isValid = As<ChannelMaskTlv>(aTlv).IsValid();
82         break;
83 
84     default:
85         break;
86     }
87 
88     if (minLength > 0)
89     {
90         isValid = (aTlv.GetLength() >= minLength);
91     }
92 
93 exit:
94     return isValid;
95 }
96 
GetNetworkName(void) const97 NameData NetworkNameTlv::GetNetworkName(void) const
98 {
99     uint8_t len = GetLength();
100 
101     if (len > sizeof(mNetworkName))
102     {
103         len = sizeof(mNetworkName);
104     }
105 
106     return NameData(mNetworkName, len);
107 }
108 
SetNetworkName(const NameData & aNameData)109 void NetworkNameTlv::SetNetworkName(const NameData &aNameData)
110 {
111     uint8_t len;
112 
113     len = aNameData.CopyTo(mNetworkName, sizeof(mNetworkName));
114     SetLength(len);
115 }
116 
IsValid(void) const117 bool NetworkNameTlv::IsValid(void) const { return IsValidUtf8String(mNetworkName, GetLength()); }
118 
CopyTo(SteeringData & aSteeringData) const119 void SteeringDataTlv::CopyTo(SteeringData &aSteeringData) const
120 {
121     aSteeringData.Init(GetSteeringDataLength());
122     memcpy(aSteeringData.GetData(), mSteeringData, GetSteeringDataLength());
123 }
124 
IsValid(void) const125 bool SecurityPolicyTlv::IsValid(void) const
126 {
127     return GetLength() >= sizeof(mRotationTime) && GetFlagsLength() >= kThread11FlagsLength;
128 }
129 
GetSecurityPolicy(void) const130 SecurityPolicy SecurityPolicyTlv::GetSecurityPolicy(void) const
131 {
132     SecurityPolicy securityPolicy;
133     uint8_t        length = Min(static_cast<uint8_t>(sizeof(mFlags)), GetFlagsLength());
134 
135     securityPolicy.mRotationTime = GetRotationTime();
136     securityPolicy.SetFlags(mFlags, length);
137 
138     return securityPolicy;
139 }
140 
SetSecurityPolicy(const SecurityPolicy & aSecurityPolicy)141 void SecurityPolicyTlv::SetSecurityPolicy(const SecurityPolicy &aSecurityPolicy)
142 {
143     SetRotationTime(aSecurityPolicy.mRotationTime);
144     aSecurityPolicy.GetFlags(mFlags, sizeof(mFlags));
145 }
146 
StateToString(State aState)147 const char *StateTlv::StateToString(State aState)
148 {
149     static const char *const kStateStrings[] = {
150         "Pending", // (0) kPending,
151         "Accept",  // (1) kAccept
152         "Reject",  // (2) kReject,
153     };
154 
155     static_assert(0 == kPending, "kPending value is incorrect");
156     static_assert(1 == kAccept, "kAccept value is incorrect");
157 
158     return aState == kReject ? kStateStrings[2] : kStateStrings[aState];
159 }
160 
IsValid(void) const161 bool ChannelMaskTlv::IsValid(void) const
162 {
163     uint32_t channelMask;
164 
165     return (ReadChannelMask(channelMask) == kErrorNone);
166 }
167 
ReadChannelMask(uint32_t & aChannelMask) const168 Error ChannelMaskTlv::ReadChannelMask(uint32_t &aChannelMask) const
169 {
170     EntriesData entriesData;
171 
172     entriesData.Clear();
173     entriesData.mData   = &mEntriesStart;
174     entriesData.mLength = GetLength();
175 
176     return entriesData.Parse(aChannelMask);
177 }
178 
FindIn(const Message & aMessage,uint32_t & aChannelMask)179 Error ChannelMaskTlv::FindIn(const Message &aMessage, uint32_t &aChannelMask)
180 {
181     Error       error;
182     EntriesData entriesData;
183 
184     entriesData.Clear();
185     entriesData.mMessage = &aMessage;
186 
187     SuccessOrExit(error = FindTlvValueOffset(aMessage, Tlv::kChannelMask, entriesData.mOffset, entriesData.mLength));
188     error = entriesData.Parse(aChannelMask);
189 
190 exit:
191     return error;
192 }
193 
/**
 * Validates and parses Channel Mask TLV entries.
 *
 * The masks of all channel pages supported by the radio are combined into
 * `aChannelMask`.
 *
 * The entries are read from `mMessage` starting at `mOffset` when `mMessage`
 * is non-null, and from the buffer `mData` otherwise. `mLength` gives the
 * total number of bytes for all entries.
 *
 * This method consumes its input. As entries are parsed, `mLength` is
 * decremented and `mOffset` (or `mData`) is advanced.
 *
 * Entries for channel pages the radio does not support are skipped. Their
 * bytes are consumed but add nothing to the mask.
 *
 * @param[out] aChannelMask  The combined channel mask. It is cleared to zero
 *                           first. On failure it may hold a partial result
 *                           from entries parsed before the error.
 *
 * @retval kErrorNone   All entries were parsed successfully.
 * @retval kErrorParse  One of the following:
 *                      - `mLength` is zero.
 *                      - An entry is truncated.
 *                      - An entry for a supported channel page does not
 *                        use a `kMaskLength`-byte mask.
 */
Error ChannelMaskTlv::EntriesData::Parse(uint32_t &aChannelMask)
{
    // Validates and parses the Channel Mask TLV entries for each
    // channel page and if successful updates `aChannelMask` to
    // return the combined mask for all channel pages supported by
    // radio. The entries can be either contained in `mMessage` from
    // `mOffset` (when `mMessage` is non-null) or be in a buffer
    // `mData`. `mLength` gives the number of bytes for all entries.

    Error        error = kErrorParse;
    Entry        readEntry;
    const Entry *entry;
    uint16_t     size;

    aChannelMask = 0;

    VerifyOrExit(mLength > 0); // At least one entry.

    while (mLength > 0)
    {
        // The remaining bytes must be more than just an entry header.
        VerifyOrExit(mLength > kEntryHeaderSize);

        if (mMessage != nullptr)
        {
            // We first read the entry's header only and after
            // validating the entry and that the entry's channel page
            // is supported by radio, we read the full `Entry`.
            //
            // NOTE(review): the `ReadBytes()` return value is ignored.
            // This assumes the caller (e.g. `FindIn()`) has already
            // verified that `mLength` bytes from `mOffset` lie within
            // the message.

            mMessage->ReadBytes(mOffset, &readEntry, kEntryHeaderSize);
            entry = &readEntry;
        }
        else
        {
            entry = reinterpret_cast<const Entry *>(mData);
        }

        // Total size of this entry: header plus its declared mask length.
        size = kEntryHeaderSize + entry->GetMaskLength();

        // Reject an entry whose declared mask runs past the remaining data.
        VerifyOrExit(size <= mLength);

        if (Radio::SupportsChannelPage(entry->GetChannelPage()))
        {
            // Currently supported channel pages all use `uint32_t`
            // channel mask.

            VerifyOrExit(entry->GetMaskLength() == kMaskLength);

            if (mMessage != nullptr)
            {
                // Header validated; now read the full entry (header and mask).
                IgnoreError(mMessage->Read(mOffset, readEntry));
            }

            // Keep only the bits also present in the radio's mask for this page.
            aChannelMask |= (entry->GetMask() & Radio::ChannelMaskForPage(entry->GetChannelPage()));
        }

        // Advance past this entry, whether it was used or skipped.
        mLength -= size;

        if (mMessage != nullptr)
        {
            mOffset += size;
        }
        else
        {
            mData += size;
        }
    }

    error = kErrorNone;

exit:
    return error;
}
266 
PrepareValue(Value & aValue,uint32_t aChannelMask)267 void ChannelMaskTlv::PrepareValue(Value &aValue, uint32_t aChannelMask)
268 {
269     Entry *entry = reinterpret_cast<Entry *>(aValue.mData);
270 
271     aValue.mLength = 0;
272 
273     for (uint8_t page : Radio::kSupportedChannelPages)
274     {
275         uint32_t mask = (Radio::ChannelMaskForPage(page) & aChannelMask);
276 
277         if (mask != 0)
278         {
279             entry->SetChannelPage(page);
280             entry->SetMaskLength(kMaskLength);
281             entry->SetMask(mask);
282 
283             aValue.mLength += sizeof(Entry);
284             entry++;
285         }
286     }
287 }
288 
AppendTo(Message & aMessage,uint32_t aChannelMask)289 Error ChannelMaskTlv::AppendTo(Message &aMessage, uint32_t aChannelMask)
290 {
291     Value value;
292 
293     PrepareValue(value, aChannelMask);
294     return Tlv::Append<ChannelMaskTlv>(aMessage, value.mData, value.mLength);
295 }
296 
297 } // namespace MeshCoP
298 } // namespace ot
299