1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements common methods for manipulating MLE TLVs.
32  */
33 
34 #include "tlvs.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/message.hpp"
39 
40 namespace ot {
41 
GetSize(void) const42 uint32_t Tlv::GetSize(void) const
43 {
44     return IsExtended() ? sizeof(ExtendedTlv) + As<ExtendedTlv>(this)->GetLength() : sizeof(Tlv) + GetLength();
45 }
46 
GetValue(void)47 uint8_t *Tlv::GetValue(void)
48 {
49     return reinterpret_cast<uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
50 }
51 
GetValue(void) const52 const uint8_t *Tlv::GetValue(void) const
53 {
54     return reinterpret_cast<const uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
55 }
56 
AppendTo(Message & aMessage) const57 Error Tlv::AppendTo(Message &aMessage) const { return aMessage.AppendBytes(this, static_cast<uint16_t>(GetSize())); }
58 
FindTlv(const Message & aMessage,uint8_t aType,uint16_t aMaxSize,Tlv & aTlv)59 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, uint16_t aMaxSize, Tlv &aTlv)
60 {
61     uint16_t offset;
62 
63     return FindTlv(aMessage, aType, aMaxSize, aTlv, offset);
64 }
65 
FindTlv(const Message & aMessage,uint8_t aType,uint16_t aMaxSize,Tlv & aTlv,uint16_t & aOffset)66 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, uint16_t aMaxSize, Tlv &aTlv, uint16_t &aOffset)
67 {
68     Error      error;
69     ParsedInfo info;
70 
71     SuccessOrExit(error = info.FindIn(aMessage, aType));
72 
73     if (aMaxSize > info.mSize)
74     {
75         aMaxSize = info.mSize;
76     }
77 
78     aMessage.ReadBytes(info.mOffset, &aTlv, aMaxSize);
79     aOffset = info.mOffset;
80 
81 exit:
82     return error;
83 }
FindTlvValueOffset(const Message & aMessage,uint8_t aType,uint16_t & aValueOffset,uint16_t & aLength)84 Error Tlv::FindTlvValueOffset(const Message &aMessage, uint8_t aType, uint16_t &aValueOffset, uint16_t &aLength)
85 {
86     Error      error;
87     ParsedInfo info;
88 
89     SuccessOrExit(error = info.FindIn(aMessage, aType));
90 
91     aValueOffset = info.mValueOffset;
92     aLength      = info.mLength;
93 
94 exit:
95     return error;
96 }
97 
FindTlvValueStartEndOffsets(const Message & aMessage,uint8_t aType,uint16_t & aValueStartOffset,uint16_t & aValueEndOffset)98 Error Tlv::FindTlvValueStartEndOffsets(const Message &aMessage,
99                                        uint8_t        aType,
100                                        uint16_t      &aValueStartOffset,
101                                        uint16_t      &aValueEndOffset)
102 {
103     Error      error;
104     ParsedInfo info;
105 
106     SuccessOrExit(error = info.FindIn(aMessage, aType));
107 
108     aValueStartOffset = info.mValueOffset;
109     aValueEndOffset   = info.mValueOffset + info.mLength;
110 
111 exit:
112     return error;
113 }
114 
ParseFrom(const Message & aMessage,uint16_t aOffset)115 Error Tlv::ParsedInfo::ParseFrom(const Message &aMessage, uint16_t aOffset)
116 {
117     // This method reads and parses the TLV info from `aMessage` at
118     // `aOffset`. This can be used independent of whether the TLV is
119     // extended or not. It validates that the entire TLV can be read
120     // from `aMessage`.  Returns `kErrorNone` when successfully parsed,
121     // otherwise `kErrorParse`.
122 
123     Error       error;
124     Tlv         tlv;
125     ExtendedTlv extTlv;
126     uint16_t    headerSize;
127 
128     SuccessOrExit(error = aMessage.Read(aOffset, tlv));
129 
130     if (!tlv.IsExtended())
131     {
132         mType      = tlv.GetType();
133         mLength    = tlv.GetLength();
134         headerSize = sizeof(Tlv);
135     }
136     else
137     {
138         SuccessOrExit(error = aMessage.Read(aOffset, extTlv));
139 
140         mType      = extTlv.GetType();
141         mLength    = extTlv.GetLength();
142         headerSize = sizeof(ExtendedTlv);
143     }
144 
145     // We know that we could successfully read `tlv` or `extTlv`
146     // (`headerSize` bytes) from the message, so the calculation of the
147     // remaining length as `aMessage.GetLength() - aOffset - headerSize`
148     // cannot underflow.
149 
150     VerifyOrExit(mLength <= aMessage.GetLength() - aOffset - headerSize, error = kErrorParse);
151 
152     // Now that we know the entire TLV is contained within the
153     // `aMessage`, we can safely calculate `mValueOffset` and `mSize`
154     // as `uint16_t` and know that there will be no overflow.
155 
156     mType        = tlv.GetType();
157     mOffset      = aOffset;
158     mValueOffset = aOffset + headerSize;
159     mSize        = mLength + headerSize;
160 
161 exit:
162     return error;
163 }
164 
FindIn(const Message & aMessage,uint8_t aType)165 Error Tlv::ParsedInfo::FindIn(const Message &aMessage, uint8_t aType)
166 {
167     // This  method searches within `aMessage` starting from
168     // `aMessage.GetOffset()` for a TLV type `aType` and parsed its
169     // info and validates that the entire TLV can be read from
170     // `aMessage`. Returns `kErrorNone` when found, otherwise
171     // `kErrorNotFound`.
172 
173     Error    error  = kErrorNotFound;
174     uint16_t offset = aMessage.GetOffset();
175 
176     while (true)
177     {
178         SuccessOrExit(ParseFrom(aMessage, offset));
179 
180         if (mType == aType)
181         {
182             error = kErrorNone;
183             ExitNow();
184         }
185 
186         // `ParseFrom()` already validated that `offset + mSize` is
187         // less than `aMessage.GetLength()` and therefore we can not
188         // have an overflow here.
189 
190         offset += mSize;
191     }
192 
193 exit:
194     return error;
195 }
196 
ReadStringTlv(const Message & aMessage,uint16_t aOffset,uint8_t aMaxStringLength,char * aValue)197 Error Tlv::ReadStringTlv(const Message &aMessage, uint16_t aOffset, uint8_t aMaxStringLength, char *aValue)
198 {
199     Error      error = kErrorNone;
200     ParsedInfo info;
201     uint16_t   length;
202 
203     SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
204 
205     length = Min(info.mLength, static_cast<uint16_t>(aMaxStringLength));
206 
207     aMessage.ReadBytes(info.mValueOffset, aValue, length);
208     aValue[length] = '\0';
209 
210 exit:
211     return error;
212 }
213 
ReadUintTlv(const Message & aMessage,uint16_t aOffset,UintType & aValue)214 template <typename UintType> Error Tlv::ReadUintTlv(const Message &aMessage, uint16_t aOffset, UintType &aValue)
215 {
216     Error error;
217 
218     SuccessOrExit(error = ReadTlvValue(aMessage, aOffset, &aValue, sizeof(aValue)));
219     aValue = BigEndian::HostSwap<UintType>(aValue);
220 
221 exit:
222     return error;
223 }
224 
225 // Explicit instantiations of `ReadUintTlv<>()`
226 template Error Tlv::ReadUintTlv<uint8_t>(const Message &aMessage, uint16_t aOffset, uint8_t &aValue);
227 template Error Tlv::ReadUintTlv<uint16_t>(const Message &aMessage, uint16_t aOffset, uint16_t &aValue);
228 template Error Tlv::ReadUintTlv<uint32_t>(const Message &aMessage, uint16_t aOffset, uint32_t &aValue);
229 
ReadTlvValue(const Message & aMessage,uint16_t aOffset,void * aValue,uint8_t aMinLength)230 Error Tlv::ReadTlvValue(const Message &aMessage, uint16_t aOffset, void *aValue, uint8_t aMinLength)
231 {
232     Error      error;
233     ParsedInfo info;
234 
235     SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
236 
237     VerifyOrExit(info.mLength >= aMinLength, error = kErrorParse);
238 
239     aMessage.ReadBytes(info.mValueOffset, aValue, aMinLength);
240 
241 exit:
242     return error;
243 }
244 
FindStringTlv(const Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,char * aValue)245 Error Tlv::FindStringTlv(const Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, char *aValue)
246 {
247     Error      error;
248     ParsedInfo info;
249 
250     SuccessOrExit(error = info.FindIn(aMessage, aType));
251     error = ReadStringTlv(aMessage, info.mOffset, aMaxStringLength, aValue);
252 
253 exit:
254     return error;
255 }
256 
FindUintTlv(const Message & aMessage,uint8_t aType,UintType & aValue)257 template <typename UintType> Error Tlv::FindUintTlv(const Message &aMessage, uint8_t aType, UintType &aValue)
258 {
259     Error      error;
260     ParsedInfo info;
261 
262     SuccessOrExit(error = info.FindIn(aMessage, aType));
263     error = ReadUintTlv<UintType>(aMessage, info.mOffset, aValue);
264 
265 exit:
266     return error;
267 }
268 
269 // Explicit instantiations of `FindUintTlv<>()`
270 template Error Tlv::FindUintTlv<uint8_t>(const Message &aMessage, uint8_t aType, uint8_t &aValue);
271 template Error Tlv::FindUintTlv<uint16_t>(const Message &aMessage, uint8_t aType, uint16_t &aValue);
272 template Error Tlv::FindUintTlv<uint32_t>(const Message &aMessage, uint8_t aType, uint32_t &aValue);
273 
FindTlv(const Message & aMessage,uint8_t aType,void * aValue,uint16_t aLength)274 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, void *aValue, uint16_t aLength)
275 {
276     Error    error;
277     uint16_t offset;
278     uint16_t length;
279 
280     SuccessOrExit(error = FindTlvValueOffset(aMessage, aType, offset, length));
281     VerifyOrExit(length >= aLength, error = kErrorParse);
282     aMessage.ReadBytes(offset, aValue, aLength);
283 
284 exit:
285     return error;
286 }
287 
AppendStringTlv(Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,const char * aValue)288 Error Tlv::AppendStringTlv(Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, const char *aValue)
289 {
290     uint16_t length = (aValue == nullptr) ? 0 : StringLength(aValue, aMaxStringLength);
291 
292     return AppendTlv(aMessage, aType, aValue, static_cast<uint8_t>(length));
293 }
294 
AppendUintTlv(Message & aMessage,uint8_t aType,UintType aValue)295 template <typename UintType> Error Tlv::AppendUintTlv(Message &aMessage, uint8_t aType, UintType aValue)
296 {
297     UintType value = BigEndian::HostSwap<UintType>(aValue);
298 
299     return AppendTlv(aMessage, aType, &value, sizeof(UintType));
300 }
301 
302 // Explicit instantiations of `AppendUintTlv<>()`
303 template Error Tlv::AppendUintTlv<uint8_t>(Message &aMessage, uint8_t aType, uint8_t aValue);
304 template Error Tlv::AppendUintTlv<uint16_t>(Message &aMessage, uint8_t aType, uint16_t aValue);
305 template Error Tlv::AppendUintTlv<uint32_t>(Message &aMessage, uint8_t aType, uint32_t aValue);
306 
AppendTlv(Message & aMessage,uint8_t aType,const void * aValue,uint8_t aLength)307 Error Tlv::AppendTlv(Message &aMessage, uint8_t aType, const void *aValue, uint8_t aLength)
308 {
309     Error error = kErrorNone;
310     Tlv   tlv;
311 
312     OT_ASSERT(aLength <= Tlv::kBaseTlvMaxLength);
313 
314     tlv.SetType(aType);
315     tlv.SetLength(aLength);
316     SuccessOrExit(error = aMessage.Append(tlv));
317 
318     VerifyOrExit(aLength > 0);
319     error = aMessage.AppendBytes(aValue, aLength);
320 
321 exit:
322     return error;
323 }
324 
FindTlv(const void * aTlvsStart,uint16_t aTlvsLength,uint8_t aType)325 const Tlv *Tlv::FindTlv(const void *aTlvsStart, uint16_t aTlvsLength, uint8_t aType)
326 {
327     const Tlv *tlv;
328     const Tlv *end = reinterpret_cast<const Tlv *>(reinterpret_cast<const uint8_t *>(aTlvsStart) + aTlvsLength);
329 
330     for (tlv = reinterpret_cast<const Tlv *>(aTlvsStart); tlv < end; tlv = tlv->GetNext())
331     {
332         VerifyOrExit((tlv + 1) <= end, tlv = nullptr);
333 
334         if (tlv->IsExtended())
335         {
336             VerifyOrExit((As<ExtendedTlv>(tlv) + 1) <= As<ExtendedTlv>(end), tlv = nullptr);
337         }
338 
339         VerifyOrExit(tlv->GetNext() <= end, tlv = nullptr);
340 
341         if (tlv->GetType() == aType)
342         {
343             ExitNow();
344         }
345     }
346 
347     tlv = nullptr;
348 
349 exit:
350     return tlv;
351 }
352 
353 } // namespace ot
354