1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
 *   This file implements common methods for manipulating TLVs (including extended TLVs).
32  */
33 
34 #include "tlvs.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/message.hpp"
39 
40 namespace ot {
41 
GetSize(void) const42 uint32_t Tlv::GetSize(void) const
43 {
44     return IsExtended() ? sizeof(ExtendedTlv) + As<ExtendedTlv>(this)->GetLength() : sizeof(Tlv) + GetLength();
45 }
46 
GetValue(void)47 uint8_t *Tlv::GetValue(void)
48 {
49     return reinterpret_cast<uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
50 }
51 
GetValue(void) const52 const uint8_t *Tlv::GetValue(void) const
53 {
54     return reinterpret_cast<const uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
55 }
56 
AppendTo(Message & aMessage) const57 Error Tlv::AppendTo(Message &aMessage) const { return aMessage.AppendBytes(this, static_cast<uint16_t>(GetSize())); }
58 
FindTlv(const Message & aMessage,uint8_t aType,uint16_t aMaxSize,Tlv & aTlv)59 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, uint16_t aMaxSize, Tlv &aTlv)
60 {
61     Error      error;
62     ParsedInfo info;
63 
64     SuccessOrExit(error = info.FindIn(aMessage, aType));
65 
66     if (aMaxSize > info.mSize)
67     {
68         aMaxSize = info.mSize;
69     }
70 
71     aMessage.ReadBytes(info.mOffset, &aTlv, aMaxSize);
72 
73 exit:
74     return error;
75 }
76 
FindTlvOffset(const Message & aMessage,uint8_t aType,uint16_t & aOffset)77 Error Tlv::FindTlvOffset(const Message &aMessage, uint8_t aType, uint16_t &aOffset)
78 {
79     Error      error;
80     ParsedInfo info;
81 
82     SuccessOrExit(error = info.FindIn(aMessage, aType));
83     aOffset = info.mOffset;
84 
85 exit:
86     return error;
87 }
88 
FindTlvValueOffset(const Message & aMessage,uint8_t aType,uint16_t & aValueOffset,uint16_t & aLength)89 Error Tlv::FindTlvValueOffset(const Message &aMessage, uint8_t aType, uint16_t &aValueOffset, uint16_t &aLength)
90 {
91     Error      error;
92     ParsedInfo info;
93 
94     SuccessOrExit(error = info.FindIn(aMessage, aType));
95 
96     aValueOffset = info.mValueOffset;
97     aLength      = info.mLength;
98 
99 exit:
100     return error;
101 }
102 
ParseFrom(const Message & aMessage,uint16_t aOffset)103 Error Tlv::ParsedInfo::ParseFrom(const Message &aMessage, uint16_t aOffset)
104 {
105     // This method reads and parses the TLV info from `aMessage` at
106     // `aOffset`. This can be used independent of whether the TLV is
107     // extended or not. It validates that the entire TLV can be read
108     // from `aMessage`.  Returns `kErrorNone` when successfully parsed,
109     // otherwise `kErrorParse`.
110 
111     Error       error;
112     Tlv         tlv;
113     ExtendedTlv extTlv;
114     uint16_t    headerSize;
115 
116     SuccessOrExit(error = aMessage.Read(aOffset, tlv));
117 
118     if (!tlv.IsExtended())
119     {
120         mType      = tlv.GetType();
121         mLength    = tlv.GetLength();
122         headerSize = sizeof(Tlv);
123     }
124     else
125     {
126         SuccessOrExit(error = aMessage.Read(aOffset, extTlv));
127 
128         mType      = extTlv.GetType();
129         mLength    = extTlv.GetLength();
130         headerSize = sizeof(ExtendedTlv);
131     }
132 
133     // We know that we could successfully read `tlv` or `extTlv`
134     // (`headerSize` bytes) from the message, so the calculation of the
135     // remaining length as `aMessage.GetLength() - aOffset - headerSize`
136     // cannot underflow.
137 
138     VerifyOrExit(mLength <= aMessage.GetLength() - aOffset - headerSize, error = kErrorParse);
139 
140     // Now that we know the entire TLV is contained within the
141     // `aMessage`, we can safely calculate `mValueOffset` and `mSize`
142     // as `uint16_t` and know that there will be no overflow.
143 
144     mType        = tlv.GetType();
145     mOffset      = aOffset;
146     mValueOffset = aOffset + headerSize;
147     mSize        = mLength + headerSize;
148 
149 exit:
150     return error;
151 }
152 
FindIn(const Message & aMessage,uint8_t aType)153 Error Tlv::ParsedInfo::FindIn(const Message &aMessage, uint8_t aType)
154 {
155     // This  method searches within `aMessage` starting from
156     // `aMessage.GetOffset()` for a TLV type `aType` and parsed its
157     // info and validates that the entire TLV can be read from
158     // `aMessage`. Returns `kErrorNone` when found, otherwise
159     // `kErrorNotFound`.
160 
161     Error    error  = kErrorNotFound;
162     uint16_t offset = aMessage.GetOffset();
163 
164     while (true)
165     {
166         SuccessOrExit(ParseFrom(aMessage, offset));
167 
168         if (mType == aType)
169         {
170             error = kErrorNone;
171             ExitNow();
172         }
173 
174         // `ParseFrom()` already validated that `offset + mSize` is
175         // less than `aMessage.GetLength()` and therefore we can not
176         // have an overflow here.
177 
178         offset += mSize;
179     }
180 
181 exit:
182     return error;
183 }
184 
ReadStringTlv(const Message & aMessage,uint16_t aOffset,uint8_t aMaxStringLength,char * aValue)185 Error Tlv::ReadStringTlv(const Message &aMessage, uint16_t aOffset, uint8_t aMaxStringLength, char *aValue)
186 {
187     Error      error = kErrorNone;
188     ParsedInfo info;
189     uint16_t   length;
190 
191     SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
192 
193     length = Min(info.mLength, static_cast<uint16_t>(aMaxStringLength));
194 
195     aMessage.ReadBytes(info.mValueOffset, aValue, length);
196     aValue[length] = '\0';
197 
198 exit:
199     return error;
200 }
201 
ReadUintTlv(const Message & aMessage,uint16_t aOffset,UintType & aValue)202 template <typename UintType> Error Tlv::ReadUintTlv(const Message &aMessage, uint16_t aOffset, UintType &aValue)
203 {
204     Error error;
205 
206     SuccessOrExit(error = ReadTlvValue(aMessage, aOffset, &aValue, sizeof(aValue)));
207     aValue = Encoding::BigEndian::HostSwap<UintType>(aValue);
208 
209 exit:
210     return error;
211 }
212 
213 // Explicit instantiations of `ReadUintTlv<>()`
214 template Error Tlv::ReadUintTlv<uint8_t>(const Message &aMessage, uint16_t aOffset, uint8_t &aValue);
215 template Error Tlv::ReadUintTlv<uint16_t>(const Message &aMessage, uint16_t aOffset, uint16_t &aValue);
216 template Error Tlv::ReadUintTlv<uint32_t>(const Message &aMessage, uint16_t aOffset, uint32_t &aValue);
217 
ReadTlvValue(const Message & aMessage,uint16_t aOffset,void * aValue,uint8_t aMinLength)218 Error Tlv::ReadTlvValue(const Message &aMessage, uint16_t aOffset, void *aValue, uint8_t aMinLength)
219 {
220     Error      error;
221     ParsedInfo info;
222 
223     SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
224 
225     VerifyOrExit(info.mLength >= aMinLength, error = kErrorParse);
226 
227     aMessage.ReadBytes(info.mValueOffset, aValue, aMinLength);
228 
229 exit:
230     return error;
231 }
232 
FindStringTlv(const Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,char * aValue)233 Error Tlv::FindStringTlv(const Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, char *aValue)
234 {
235     Error    error = kErrorNone;
236     uint16_t offset;
237 
238     SuccessOrExit(error = FindTlvOffset(aMessage, aType, offset));
239     error = ReadStringTlv(aMessage, offset, aMaxStringLength, aValue);
240 
241 exit:
242     return error;
243 }
244 
FindUintTlv(const Message & aMessage,uint8_t aType,UintType & aValue)245 template <typename UintType> Error Tlv::FindUintTlv(const Message &aMessage, uint8_t aType, UintType &aValue)
246 {
247     Error    error = kErrorNone;
248     uint16_t offset;
249 
250     SuccessOrExit(error = FindTlvOffset(aMessage, aType, offset));
251     error = ReadUintTlv<UintType>(aMessage, offset, aValue);
252 
253 exit:
254     return error;
255 }
256 
257 // Explicit instantiations of `FindUintTlv<>()`
258 template Error Tlv::FindUintTlv<uint8_t>(const Message &aMessage, uint8_t aType, uint8_t &aValue);
259 template Error Tlv::FindUintTlv<uint16_t>(const Message &aMessage, uint8_t aType, uint16_t &aValue);
260 template Error Tlv::FindUintTlv<uint32_t>(const Message &aMessage, uint8_t aType, uint32_t &aValue);
261 
FindTlv(const Message & aMessage,uint8_t aType,void * aValue,uint8_t aLength)262 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, void *aValue, uint8_t aLength)
263 {
264     Error    error;
265     uint16_t offset;
266     uint16_t length;
267 
268     SuccessOrExit(error = FindTlvValueOffset(aMessage, aType, offset, length));
269     VerifyOrExit(length >= aLength, error = kErrorParse);
270     aMessage.ReadBytes(offset, aValue, aLength);
271 
272 exit:
273     return error;
274 }
275 
AppendStringTlv(Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,const char * aValue)276 Error Tlv::AppendStringTlv(Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, const char *aValue)
277 {
278     uint16_t length = (aValue == nullptr) ? 0 : StringLength(aValue, aMaxStringLength);
279 
280     return AppendTlv(aMessage, aType, aValue, static_cast<uint8_t>(length));
281 }
282 
AppendUintTlv(Message & aMessage,uint8_t aType,UintType aValue)283 template <typename UintType> Error Tlv::AppendUintTlv(Message &aMessage, uint8_t aType, UintType aValue)
284 {
285     UintType value = Encoding::BigEndian::HostSwap<UintType>(aValue);
286 
287     return AppendTlv(aMessage, aType, &value, sizeof(UintType));
288 }
289 
290 // Explicit instantiations of `AppendUintTlv<>()`
291 template Error Tlv::AppendUintTlv<uint8_t>(Message &aMessage, uint8_t aType, uint8_t aValue);
292 template Error Tlv::AppendUintTlv<uint16_t>(Message &aMessage, uint8_t aType, uint16_t aValue);
293 template Error Tlv::AppendUintTlv<uint32_t>(Message &aMessage, uint8_t aType, uint32_t aValue);
294 
AppendTlv(Message & aMessage,uint8_t aType,const void * aValue,uint8_t aLength)295 Error Tlv::AppendTlv(Message &aMessage, uint8_t aType, const void *aValue, uint8_t aLength)
296 {
297     Error error = kErrorNone;
298     Tlv   tlv;
299 
300     OT_ASSERT(aLength <= Tlv::kBaseTlvMaxLength);
301 
302     tlv.SetType(aType);
303     tlv.SetLength(aLength);
304     SuccessOrExit(error = aMessage.Append(tlv));
305 
306     VerifyOrExit(aLength > 0);
307     error = aMessage.AppendBytes(aValue, aLength);
308 
309 exit:
310     return error;
311 }
312 
313 } // namespace ot
314