1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements common methods for manipulating MLE TLVs.
32  */
33 
34 #include "tlvs.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/message.hpp"
39 
40 namespace ot {
41 
GetSize(void) const42 uint32_t Tlv::GetSize(void) const
43 {
44     return IsExtended() ? sizeof(ExtendedTlv) + As<ExtendedTlv>(this)->GetLength() : sizeof(Tlv) + GetLength();
45 }
46 
GetValue(void)47 uint8_t *Tlv::GetValue(void)
48 {
49     return reinterpret_cast<uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
50 }
51 
GetValue(void) const52 const uint8_t *Tlv::GetValue(void) const
53 {
54     return reinterpret_cast<const uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
55 }
56 
AppendTo(Message & aMessage) const57 Error Tlv::AppendTo(Message &aMessage) const { return aMessage.AppendBytes(this, static_cast<uint16_t>(GetSize())); }
58 
FindTlv(const Message & aMessage,uint8_t aType,uint16_t aMaxSize,Tlv & aTlv)59 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, uint16_t aMaxSize, Tlv &aTlv)
60 {
61     uint16_t offset;
62 
63     return FindTlv(aMessage, aType, aMaxSize, aTlv, offset);
64 }
65 
FindTlv(const Message & aMessage,uint8_t aType,uint16_t aMaxSize,Tlv & aTlv,uint16_t & aOffset)66 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, uint16_t aMaxSize, Tlv &aTlv, uint16_t &aOffset)
67 {
68     Error      error;
69     ParsedInfo info;
70 
71     SuccessOrExit(error = info.FindIn(aMessage, aType));
72 
73     info.mTlvOffsetRange.ShrinkLength(aMaxSize);
74     aMessage.ReadBytes(info.mTlvOffsetRange, &aTlv);
75     aOffset = info.mTlvOffsetRange.GetOffset();
76 
77 exit:
78     return error;
79 }
80 
FindTlvValueOffsetRange(const Message & aMessage,uint8_t aType,OffsetRange & aOffsetRange)81 Error Tlv::FindTlvValueOffsetRange(const Message &aMessage, uint8_t aType, OffsetRange &aOffsetRange)
82 {
83     Error      error;
84     ParsedInfo info;
85 
86     SuccessOrExit(error = info.FindIn(aMessage, aType));
87     aOffsetRange = info.mValueOffsetRange;
88 
89 exit:
90     return error;
91 }
92 
ParseFrom(const Message & aMessage,uint16_t aOffset)93 Error Tlv::ParsedInfo::ParseFrom(const Message &aMessage, uint16_t aOffset)
94 {
95     OffsetRange offsetRange;
96 
97     offsetRange.InitFromRange(aOffset, aMessage.GetLength());
98     return ParseFrom(aMessage, offsetRange);
99 }
100 
ParseFrom(const Message & aMessage,const OffsetRange & aOffsetRange)101 Error Tlv::ParsedInfo::ParseFrom(const Message &aMessage, const OffsetRange &aOffsetRange)
102 {
103     Error       error;
104     Tlv         tlv;
105     ExtendedTlv extTlv;
106     uint32_t    headerSize;
107     uint32_t    size;
108 
109     SuccessOrExit(error = aMessage.Read(aOffsetRange, tlv));
110 
111     mType = tlv.GetType();
112 
113     if (!tlv.IsExtended())
114     {
115         mIsExtended = false;
116         headerSize  = sizeof(Tlv);
117         size        = headerSize + tlv.GetLength();
118     }
119     else
120     {
121         SuccessOrExit(error = aMessage.Read(aOffsetRange, extTlv));
122 
123         mIsExtended = true;
124         headerSize  = sizeof(ExtendedTlv);
125         size        = headerSize + extTlv.GetLength();
126     }
127 
128     mTlvOffsetRange = aOffsetRange;
129     VerifyOrExit(mTlvOffsetRange.Contains(size), error = kErrorParse);
130     mTlvOffsetRange.ShrinkLength(static_cast<uint16_t>(size));
131 
132     VerifyOrExit(mTlvOffsetRange.GetEndOffset() <= aMessage.GetLength(), error = kErrorParse);
133 
134     mValueOffsetRange = mTlvOffsetRange;
135     mValueOffsetRange.AdvanceOffset(headerSize);
136 
137 exit:
138     return error;
139 }
140 
FindIn(const Message & aMessage,uint8_t aType)141 Error Tlv::ParsedInfo::FindIn(const Message &aMessage, uint8_t aType)
142 {
143     Error       error = kErrorNotFound;
144     OffsetRange offsetRange;
145 
146     offsetRange.InitFromMessageOffsetToEnd(aMessage);
147 
148     while (true)
149     {
150         SuccessOrExit(ParseFrom(aMessage, offsetRange));
151 
152         if (mType == aType)
153         {
154             error = kErrorNone;
155             ExitNow();
156         }
157 
158         offsetRange.AdvanceOffset(mTlvOffsetRange.GetLength());
159     }
160 
161 exit:
162     return error;
163 }
164 
ReadStringTlv(const Message & aMessage,uint16_t aOffset,uint8_t aMaxStringLength,char * aValue)165 Error Tlv::ReadStringTlv(const Message &aMessage, uint16_t aOffset, uint8_t aMaxStringLength, char *aValue)
166 {
167     Error      error = kErrorNone;
168     ParsedInfo info;
169 
170     SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
171 
172     info.mValueOffsetRange.ShrinkLength(aMaxStringLength);
173     aMessage.ReadBytes(info.mValueOffsetRange, aValue);
174     aValue[info.mValueOffsetRange.GetLength()] = kNullChar;
175 
176 exit:
177     return error;
178 }
179 
ReadUintTlv(const Message & aMessage,uint16_t aOffset,UintType & aValue)180 template <typename UintType> Error Tlv::ReadUintTlv(const Message &aMessage, uint16_t aOffset, UintType &aValue)
181 {
182     Error error;
183 
184     SuccessOrExit(error = ReadTlvValue(aMessage, aOffset, &aValue, sizeof(aValue)));
185     aValue = BigEndian::HostSwap<UintType>(aValue);
186 
187 exit:
188     return error;
189 }
190 
191 // Explicit instantiations of `ReadUintTlv<>()`
192 template Error Tlv::ReadUintTlv<uint8_t>(const Message &aMessage, uint16_t aOffset, uint8_t &aValue);
193 template Error Tlv::ReadUintTlv<uint16_t>(const Message &aMessage, uint16_t aOffset, uint16_t &aValue);
194 template Error Tlv::ReadUintTlv<uint32_t>(const Message &aMessage, uint16_t aOffset, uint32_t &aValue);
195 
ReadTlvValue(const Message & aMessage,uint16_t aOffset,void * aValue,uint8_t aMinLength)196 Error Tlv::ReadTlvValue(const Message &aMessage, uint16_t aOffset, void *aValue, uint8_t aMinLength)
197 {
198     Error      error;
199     ParsedInfo info;
200 
201     SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
202 
203     VerifyOrExit(info.mValueOffsetRange.Contains(aMinLength), error = kErrorParse);
204     info.mValueOffsetRange.ShrinkLength(aMinLength);
205 
206     aMessage.ReadBytes(info.mValueOffsetRange, aValue);
207 
208 exit:
209     return error;
210 }
211 
FindStringTlv(const Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,char * aValue)212 Error Tlv::FindStringTlv(const Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, char *aValue)
213 {
214     Error      error;
215     ParsedInfo info;
216 
217     SuccessOrExit(error = info.FindIn(aMessage, aType));
218     error = ReadStringTlv(aMessage, info.mTlvOffsetRange.GetOffset(), aMaxStringLength, aValue);
219 
220 exit:
221     return error;
222 }
223 
FindUintTlv(const Message & aMessage,uint8_t aType,UintType & aValue)224 template <typename UintType> Error Tlv::FindUintTlv(const Message &aMessage, uint8_t aType, UintType &aValue)
225 {
226     Error      error;
227     ParsedInfo info;
228 
229     SuccessOrExit(error = info.FindIn(aMessage, aType));
230     error = ReadUintTlv<UintType>(aMessage, info.mTlvOffsetRange.GetOffset(), aValue);
231 
232 exit:
233     return error;
234 }
235 
236 // Explicit instantiations of `FindUintTlv<>()`
237 template Error Tlv::FindUintTlv<uint8_t>(const Message &aMessage, uint8_t aType, uint8_t &aValue);
238 template Error Tlv::FindUintTlv<uint16_t>(const Message &aMessage, uint8_t aType, uint16_t &aValue);
239 template Error Tlv::FindUintTlv<uint32_t>(const Message &aMessage, uint8_t aType, uint32_t &aValue);
240 
FindTlv(const Message & aMessage,uint8_t aType,void * aValue,uint16_t aLength)241 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, void *aValue, uint16_t aLength)
242 {
243     Error       error;
244     OffsetRange offsetRange;
245 
246     SuccessOrExit(error = FindTlvValueOffsetRange(aMessage, aType, offsetRange));
247     error = aMessage.Read(offsetRange, aValue, aLength);
248 
249 exit:
250     return error;
251 }
252 
AppendStringTlv(Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,const char * aValue)253 Error Tlv::AppendStringTlv(Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, const char *aValue)
254 {
255     uint16_t length = (aValue == nullptr) ? 0 : StringLength(aValue, aMaxStringLength);
256 
257     return AppendTlv(aMessage, aType, aValue, static_cast<uint8_t>(length));
258 }
259 
AppendUintTlv(Message & aMessage,uint8_t aType,UintType aValue)260 template <typename UintType> Error Tlv::AppendUintTlv(Message &aMessage, uint8_t aType, UintType aValue)
261 {
262     UintType value = BigEndian::HostSwap<UintType>(aValue);
263 
264     return AppendTlv(aMessage, aType, &value, sizeof(UintType));
265 }
266 
267 // Explicit instantiations of `AppendUintTlv<>()`
268 template Error Tlv::AppendUintTlv<uint8_t>(Message &aMessage, uint8_t aType, uint8_t aValue);
269 template Error Tlv::AppendUintTlv<uint16_t>(Message &aMessage, uint8_t aType, uint16_t aValue);
270 template Error Tlv::AppendUintTlv<uint32_t>(Message &aMessage, uint8_t aType, uint32_t aValue);
271 
AppendTlv(Message & aMessage,uint8_t aType,const void * aValue,uint16_t aLength)272 Error Tlv::AppendTlv(Message &aMessage, uint8_t aType, const void *aValue, uint16_t aLength)
273 {
274     Error       error = kErrorNone;
275     ExtendedTlv extTlv;
276     Tlv         tlv;
277 
278     if (aLength > kBaseTlvMaxLength)
279     {
280         extTlv.SetType(aType);
281         extTlv.SetLength(aLength);
282         SuccessOrExit(error = aMessage.Append(extTlv));
283     }
284     else
285     {
286         tlv.SetType(aType);
287         tlv.SetLength(static_cast<uint8_t>(aLength));
288         SuccessOrExit(error = aMessage.Append(tlv));
289     }
290 
291     VerifyOrExit(aLength > 0);
292     error = aMessage.AppendBytes(aValue, aLength);
293 
294 exit:
295     return error;
296 }
297 
FindTlv(const void * aTlvsStart,uint16_t aTlvsLength,uint8_t aType)298 const Tlv *Tlv::FindTlv(const void *aTlvsStart, uint16_t aTlvsLength, uint8_t aType)
299 {
300     const Tlv *tlv;
301     const Tlv *end = reinterpret_cast<const Tlv *>(reinterpret_cast<const uint8_t *>(aTlvsStart) + aTlvsLength);
302 
303     for (tlv = reinterpret_cast<const Tlv *>(aTlvsStart); tlv < end; tlv = tlv->GetNext())
304     {
305         VerifyOrExit((tlv + 1) <= end, tlv = nullptr);
306 
307         if (tlv->IsExtended())
308         {
309             VerifyOrExit((As<ExtendedTlv>(tlv) + 1) <= As<ExtendedTlv>(end), tlv = nullptr);
310         }
311 
312         VerifyOrExit(tlv->GetNext() <= end, tlv = nullptr);
313 
314         if (tlv->GetType() == aType)
315         {
316             ExitNow();
317         }
318     }
319 
320     tlv = nullptr;
321 
322 exit:
323     return tlv;
324 }
325 
326 } // namespace ot
327