1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "common/encoding.hpp"
30 #include "common/message.hpp"
31 #include "common/numeric_limits.hpp"
32 #include "common/random.hpp"
33 #include "instance/instance.hpp"
34 #include "net/checksum.hpp"
35 #include "net/icmp6.hpp"
36 #include "net/ip4_types.hpp"
37 #include "net/udp6.hpp"
38 
39 #include "test_platform.h"
40 #include "test_util.hpp"
41 
42 namespace ot {
43 
uint16_t CalculateChecksum(const void *aBuffer, uint16_t aLength)
{
    // Reference Internet-checksum accumulation in the style of RFC 1071.
    //
    // The buffer is summed as a sequence of big-endian 16-bit words. An odd
    // trailing byte is treated as the high byte of a final word whose low
    // byte is zero. Carries out of the low 16 bits are then folded back in.
    // The result is NOT complemented, so a buffer that already embeds a
    // correct checksum sums to 0xffff (which is what the tests below check).

    const uint8_t *cursor    = static_cast<const uint8_t *>(aBuffer);
    uint16_t       remaining = aLength;
    uint32_t       total     = 0;

    for (; remaining > 1; remaining -= 2, cursor += 2)
    {
        total += (static_cast<uint32_t>(cursor[0]) << 8) | cursor[1];
    }

    if (remaining == 1)
    {
        total += static_cast<uint32_t>(cursor[0]) << 8;
    }

    // End-around carry: keep folding until the sum fits in 16 bits.
    while ((total >> 16) != 0)
    {
        total = (total >> 16) + (total & 0xffff);
    }

    return static_cast<uint16_t>(total);
}
73 
CalculateChecksum(const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto,const Message & aMessage)74 uint16_t CalculateChecksum(const Ip6::Address &aSource,
75                            const Ip6::Address &aDestination,
76                            uint8_t             aIpProto,
77                            const Message      &aMessage)
78 {
79     // This method calculates the checksum over an IPv6 message.
80     constexpr uint16_t kMaxPayload = 1024;
81 
82     OT_TOOL_PACKED_BEGIN
83     struct PseudoHeader
84     {
85         Ip6::Address mSource;
86         Ip6::Address mDestination;
87         uint32_t     mPayloadLength;
88         uint32_t     mProtocol;
89     } OT_TOOL_PACKED_END;
90 
91     OT_TOOL_PACKED_BEGIN
92     struct ChecksumData
93     {
94         PseudoHeader mPseudoHeader;
95         uint8_t      mPayload[kMaxPayload];
96     } OT_TOOL_PACKED_END;
97 
98     ChecksumData data;
99     uint16_t     payloadLength;
100 
101     payloadLength = aMessage.GetLength() - aMessage.GetOffset();
102 
103     data.mPseudoHeader.mSource        = aSource;
104     data.mPseudoHeader.mDestination   = aDestination;
105     data.mPseudoHeader.mProtocol      = BigEndian::HostSwap32(aIpProto);
106     data.mPseudoHeader.mPayloadLength = BigEndian::HostSwap32(payloadLength);
107 
108     SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
109 
110     return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
111 }
112 
CalculateChecksum(const Ip4::Address & aSource,const Ip4::Address & aDestination,uint8_t aIpProto,const Message & aMessage)113 uint16_t CalculateChecksum(const Ip4::Address &aSource,
114                            const Ip4::Address &aDestination,
115                            uint8_t             aIpProto,
116                            const Message      &aMessage)
117 {
118     // This method calculates the checksum over an IPv4 message.
119     constexpr uint16_t kMaxPayload = 1024;
120 
121     OT_TOOL_PACKED_BEGIN
122     struct PseudoHeader
123     {
124         Ip4::Address mSource;
125         Ip4::Address mDestination;
126         uint16_t     mPayloadLength;
127         uint16_t     mProtocol;
128     } OT_TOOL_PACKED_END;
129 
130     OT_TOOL_PACKED_BEGIN
131     struct ChecksumData
132     {
133         PseudoHeader mPseudoHeader;
134         uint8_t      mPayload[kMaxPayload];
135     } OT_TOOL_PACKED_END;
136 
137     ChecksumData data;
138     uint16_t     payloadLength;
139 
140     payloadLength = aMessage.GetLength() - aMessage.GetOffset();
141 
142     data.mPseudoHeader.mSource        = aSource;
143     data.mPseudoHeader.mDestination   = aDestination;
144     data.mPseudoHeader.mProtocol      = BigEndian::HostSwap16(aIpProto);
145     data.mPseudoHeader.mPayloadLength = BigEndian::HostSwap16(payloadLength);
146 
147     SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
148 
149     return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
150 }
151 
CorruptMessage(Message & aMessage)152 void CorruptMessage(Message &aMessage)
153 {
154     // Change a random bit in the message.
155 
156     uint16_t byteOffset;
157     uint8_t  bitOffset;
158     uint8_t  byte;
159 
160     byteOffset = Random::NonCrypto::GetUint16InRange(0, aMessage.GetLength());
161 
162     SuccessOrQuit(aMessage.Read(byteOffset, byte));
163 
164     bitOffset = Random::NonCrypto::GetUint8InRange(0, kBitsPerByte);
165 
166     byte ^= (1 << bitOffset);
167 
168     aMessage.Write(byteOffset, byte);
169 }
170 
TestUdpMessageChecksum(void)171 void TestUdpMessageChecksum(void)
172 {
173     constexpr uint16_t kMinSize = sizeof(Ip6::Udp::Header);
174     constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
175 
176     const char *kSourceAddress = "fd00:1122:3344:5566:7788:99aa:bbcc:ddee";
177     const char *kDestAddress   = "fd01:2345:6789:abcd:ef01:2345:6789:abcd";
178 
179     Instance *instance = static_cast<Instance *>(testInitInstance());
180 
181     VerifyOrQuit(instance != nullptr);
182 
183     for (uint16_t size = kMinSize; size <= kMaxSize; size++)
184     {
185         Message         *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Udp::Header));
186         Ip6::Udp::Header udpHeader;
187         Ip6::MessageInfo messageInfo;
188 
189         VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
190         SuccessOrQuit(message->SetLength(size));
191 
192         // Write UDP header with a random payload.
193 
194         Random::NonCrypto::Fill(udpHeader);
195         udpHeader.SetChecksum(0);
196         message->Write(0, udpHeader);
197 
198         if (size > sizeof(udpHeader))
199         {
200             uint8_t  buffer[kMaxSize];
201             uint16_t payloadSize = size - sizeof(udpHeader);
202 
203             Random::NonCrypto::FillBuffer(buffer, payloadSize);
204             message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
205         }
206 
207         SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
208         SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
209 
210         // Verify that the `Checksum::UpdateMessageChecksum` correctly
211         // updates the checksum field in the UDP header on the message.
212 
213         Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp);
214 
215         SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
216         VerifyOrQuit(udpHeader.GetChecksum() != 0);
217 
218         // Verify that the calculated UDP checksum is valid.
219 
220         VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp,
221                                        *message) == 0xffff);
222 
223         // Verify that `Checksum::VerifyMessageChecksum()` accepts the
224         // message and its calculated checksum.
225 
226         SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp));
227 
228         // Corrupt the message and verify that checksum is no longer accepted.
229 
230         CorruptMessage(*message);
231 
232         VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp) != kErrorNone,
233                      "Checksum passed on corrupted message");
234 
235         message->Free();
236     }
237 }
238 
TestIcmp6MessageChecksum(void)239 void TestIcmp6MessageChecksum(void)
240 {
241     constexpr uint16_t kMinSize = sizeof(Ip6::Icmp::Header);
242     constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
243 
244     const char *kSourceAddress = "fd00:feef:dccd:baab:9889:7667:5444:3223";
245     const char *kDestAddress   = "fd01:abab:beef:cafe:1234:5678:9abc:0";
246 
247     Instance *instance = static_cast<Instance *>(testInitInstance());
248 
249     VerifyOrQuit(instance != nullptr, "Null OpenThread instance\n");
250 
251     for (uint16_t size = kMinSize; size <= kMaxSize; size++)
252     {
253         Message          *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Icmp::Header));
254         Ip6::Icmp::Header icmp6Header;
255         Ip6::MessageInfo  messageInfo;
256 
257         VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
258         SuccessOrQuit(message->SetLength(size));
259 
260         // Write ICMP6 header with a random payload.
261 
262         Random::NonCrypto::Fill(icmp6Header);
263         icmp6Header.SetChecksum(0);
264         message->Write(0, icmp6Header);
265 
266         if (size > sizeof(icmp6Header))
267         {
268             uint8_t  buffer[kMaxSize];
269             uint16_t payloadSize = size - sizeof(icmp6Header);
270 
271             Random::NonCrypto::FillBuffer(buffer, payloadSize);
272             message->WriteBytes(sizeof(icmp6Header), &buffer[0], payloadSize);
273         }
274 
275         SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
276         SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
277 
278         // Verify that the `Checksum::UpdateMessageChecksum` correctly
279         // updates the checksum field in the ICMP6 header on the message.
280 
281         Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(),
282                                         Ip6::kProtoIcmp6);
283 
284         SuccessOrQuit(message->Read(message->GetOffset(), icmp6Header));
285         VerifyOrQuit(icmp6Header.GetChecksum() != 0, "Failed to update checksum");
286 
287         // Verify that the calculated ICMP6 checksum is valid.
288 
289         VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoIcmp6,
290                                        *message) == 0xffff);
291 
292         // Verify that `Checksum::VerifyMessageChecksum()` accepts the
293         // message and its calculated checksum.
294 
295         SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6));
296 
297         // Corrupt the message and verify that checksum is no longer accepted.
298 
299         CorruptMessage(*message);
300 
301         VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6) != kErrorNone,
302                      "Checksum passed on corrupted message");
303 
304         message->Free();
305     }
306 }
307 
TestTcp4MessageChecksum(void)308 void TestTcp4MessageChecksum(void)
309 {
310     constexpr size_t kMinSize = sizeof(Ip4::Tcp::Header);
311     constexpr size_t kMaxSize = kBufferSize * 3 + 24;
312 
313     const char *kSourceAddress = "12.34.56.78";
314     const char *kDestAddress   = "87.65.43.21";
315 
316     Ip4::Address sourceAddress;
317     Ip4::Address destAddress;
318 
319     Instance *instance = static_cast<Instance *>(testInitInstance());
320 
321     VerifyOrQuit(instance != nullptr);
322 
323     SuccessOrQuit(sourceAddress.FromString(kSourceAddress));
324     SuccessOrQuit(destAddress.FromString(kDestAddress));
325 
326     for (uint16_t size = kMinSize; size <= kMaxSize; size++)
327     {
328         Message         *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip4::Tcp::Header));
329         Ip4::Tcp::Header tcpHeader;
330 
331         VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
332         SuccessOrQuit(message->SetLength(size));
333 
334         // Write TCP header with a random payload.
335 
336         Random::NonCrypto::Fill(tcpHeader);
337         message->Write(0, tcpHeader);
338 
339         if (size > sizeof(tcpHeader))
340         {
341             uint8_t  buffer[kMaxSize];
342             uint16_t payloadSize = size - sizeof(tcpHeader);
343 
344             Random::NonCrypto::FillBuffer(buffer, payloadSize);
345             message->WriteBytes(sizeof(tcpHeader), &buffer[0], payloadSize);
346         }
347 
348         // Verify that the `Checksum::UpdateMessageChecksum` correctly
349         // updates the checksum field in the UDP header on the message.
350 
351         Checksum::UpdateMessageChecksum(*message, sourceAddress, destAddress, Ip4::kProtoTcp);
352 
353         SuccessOrQuit(message->Read(message->GetOffset(), tcpHeader));
354         VerifyOrQuit(tcpHeader.GetChecksum() != 0);
355 
356         // Verify that the calculated UDP checksum is valid.
357 
358         VerifyOrQuit(CalculateChecksum(sourceAddress, destAddress, Ip4::kProtoTcp, *message) == 0xffff);
359         message->Free();
360     }
361 }
362 
TestUdp4MessageChecksum(void)363 void TestUdp4MessageChecksum(void)
364 {
365     constexpr uint16_t kMinSize = sizeof(Ip4::Udp::Header);
366     constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
367 
368     const char *kSourceAddress = "12.34.56.78";
369     const char *kDestAddress   = "87.65.43.21";
370 
371     Ip4::Address sourceAddress;
372     Ip4::Address destAddress;
373 
374     Instance *instance = static_cast<Instance *>(testInitInstance());
375 
376     SuccessOrQuit(sourceAddress.FromString(kSourceAddress));
377     SuccessOrQuit(destAddress.FromString(kDestAddress));
378 
379     VerifyOrQuit(instance != nullptr);
380 
381     for (uint16_t size = kMinSize; size <= kMaxSize; size++)
382     {
383         Message         *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip4::Udp::Header));
384         Ip4::Udp::Header udpHeader;
385 
386         VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
387         SuccessOrQuit(message->SetLength(size));
388 
389         // Write UDP header with a random payload.
390 
391         Random::NonCrypto::Fill(udpHeader);
392         udpHeader.SetChecksum(0);
393         message->Write(0, udpHeader);
394 
395         if (size > sizeof(udpHeader))
396         {
397             uint8_t  buffer[kMaxSize];
398             uint16_t payloadSize = size - sizeof(udpHeader);
399 
400             Random::NonCrypto::FillBuffer(buffer, payloadSize);
401             message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
402         }
403 
404         // Verify that the `Checksum::UpdateMessageChecksum` correctly
405         // updates the checksum field in the UDP header on the message.
406 
407         Checksum::UpdateMessageChecksum(*message, sourceAddress, destAddress, Ip4::kProtoUdp);
408 
409         SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
410         VerifyOrQuit(udpHeader.GetChecksum() != 0);
411 
412         // Verify that the calculated UDP checksum is valid.
413 
414         VerifyOrQuit(CalculateChecksum(sourceAddress, destAddress, Ip4::kProtoUdp, *message) == 0xffff);
415         message->Free();
416     }
417 }
418 
TestIcmp4MessageChecksum(void)419 void TestIcmp4MessageChecksum(void)
420 {
421     // A captured ICMP echo request (ping) message. Checksum field is set to zero.
422     const uint8_t kExampleIcmpMessage[]      = "\x08\x00\x00\x00\x67\x2e\x00\x00\x62\xaf\xf1\x61\x00\x04\xfc\x24"
423                                                "\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17"
424                                                "\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27"
425                                                "\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37";
426     uint16_t      kChecksumForExampleMessage = 0x5594;
427     Instance     *instance                   = static_cast<Instance *>(testInitInstance());
428     Message      *message                    = instance->Get<Ip6::Ip6>().NewMessage(sizeof(kExampleIcmpMessage));
429 
430     Ip4::Address source;
431     Ip4::Address dest;
432 
433     uint8_t           mPayload[sizeof(kExampleIcmpMessage)];
434     Ip4::Icmp::Header icmpHeader;
435 
436     SuccessOrQuit(message->AppendBytes(kExampleIcmpMessage, sizeof(kExampleIcmpMessage)));
437 
438     // Random IPv4 address, ICMP message checksum does not include a presudo header like TCP and UDP.
439     source.mFields.m32 = 0x12345678;
440     dest.mFields.m32   = 0x87654321;
441 
442     Checksum::UpdateMessageChecksum(*message, source, dest, Ip4::kProtoIcmp);
443 
444     SuccessOrQuit(message->Read(0, icmpHeader));
445     VerifyOrQuit(icmpHeader.GetChecksum() == kChecksumForExampleMessage);
446 
447     SuccessOrQuit(message->Read(message->GetOffset(), mPayload, sizeof(mPayload)));
448     VerifyOrQuit(CalculateChecksum(mPayload, sizeof(mPayload)) == 0xffff);
449 }
450 
451 class ChecksumTester
452 {
453 public:
TestExampleVector(void)454     static void TestExampleVector(void)
455     {
456         // Example from RFC 1071
457         const uint8_t  kTestVector[]       = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7};
458         const uint16_t kTestVectorChecksum = 0xddf2;
459 
460         Checksum checksum;
461 
462         VerifyOrQuit(checksum.GetValue() == 0, "Incorrect initial checksum value");
463 
464         checksum.AddData(kTestVector, sizeof(kTestVector));
465         VerifyOrQuit(checksum.GetValue() == kTestVectorChecksum);
466         VerifyOrQuit(checksum.GetValue() == CalculateChecksum(kTestVector, sizeof(kTestVector)), );
467     }
468 };
469 
470 } // namespace ot
471 
main(void)472 int main(void)
473 {
474     ot::ChecksumTester::TestExampleVector();
475     ot::TestUdpMessageChecksum();
476     ot::TestIcmp6MessageChecksum();
477     ot::TestTcp4MessageChecksum();
478     ot::TestUdp4MessageChecksum();
479     ot::TestIcmp4MessageChecksum();
480     printf("All tests passed\n");
481     return 0;
482 }
483