1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 #include "common/encoding.hpp"
30 #include "common/instance.hpp"
31 #include "common/message.hpp"
32 #include "common/random.hpp"
33 #include "net/checksum.hpp"
34 #include "net/icmp6.hpp"
35 #include "net/udp6.hpp"
36 
37 #include "test_platform.h"
38 #include "test_util.hpp"
39 
40 namespace ot {
41 
uint16_t CalculateChecksum(const void *aBuffer, uint16_t aLength)
{
    // Reference Internet checksum (in the style of RFC 1071) over a flat buffer.
    //
    // Bytes are consumed as big-endian 16-bit words. A trailing odd byte is
    // treated as the high byte of a final word whose low byte is zero. The
    // 32-bit accumulator is folded to 16 bits by repeatedly adding the carry
    // back in (end-around carry).
    //
    // Returns the folded sum itself, NOT its one's complement.

    const uint8_t *cursor    = static_cast<const uint8_t *>(aBuffer);
    uint16_t       remaining = aLength;
    uint32_t       total     = 0;

    for (; remaining > 1; remaining -= 2, cursor += 2)
    {
        // Big-endian: first byte is the most significant.
        total += (static_cast<uint32_t>(cursor[0]) << 8) | cursor[1];
    }

    if (remaining == 1)
    {
        // Odd trailing byte: pad with a zero low byte.
        total += static_cast<uint32_t>(cursor[0]) << 8;
    }

    while ((total >> 16) != 0)
    {
        total = (total >> 16) + (total & 0xffffu);
    }

    return static_cast<uint16_t>(total);
}
71 
CalculateChecksum(const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto,const Message & aMessage)72 uint16_t CalculateChecksum(const Ip6::Address &aSource,
73                            const Ip6::Address &aDestination,
74                            uint8_t             aIpProto,
75                            const Message &     aMessage)
76 {
77     // This method calculates the checksum over an IPv6 message.
78 
79     enum : uint16_t
80     {
81         kMaxPayload = 1024,
82     };
83 
84     OT_TOOL_PACKED_BEGIN
85     struct PseudoHeader
86     {
87         Ip6::Address mSource;
88         Ip6::Address mDestination;
89         uint32_t     mPayloadLength;
90         uint32_t     mProtocol;
91     } OT_TOOL_PACKED_END;
92 
93     OT_TOOL_PACKED_BEGIN
94     struct ChecksumData
95     {
96         PseudoHeader mPseudoHeader;
97         uint8_t      mPayload[kMaxPayload];
98     } OT_TOOL_PACKED_END;
99 
100     ChecksumData data;
101     uint16_t     payloadLength;
102 
103     payloadLength = aMessage.GetLength() - aMessage.GetOffset();
104 
105     data.mPseudoHeader.mSource        = aSource;
106     data.mPseudoHeader.mDestination   = aDestination;
107     data.mPseudoHeader.mProtocol      = Encoding::BigEndian::HostSwap32(aIpProto);
108     data.mPseudoHeader.mPayloadLength = Encoding::BigEndian::HostSwap32(payloadLength);
109 
110     SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
111 
112     return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
113 }
114 
CorruptMessage(Message & aMessage)115 void CorruptMessage(Message &aMessage)
116 {
117     // Change a random bit in the message.
118 
119     uint16_t byteOffset;
120     uint8_t  bitOffset;
121     uint8_t  byte;
122 
123     byteOffset = Random::NonCrypto::GetUint16InRange(0, aMessage.GetLength());
124 
125     SuccessOrQuit(aMessage.Read(byteOffset, byte));
126 
127     bitOffset = Random::NonCrypto::GetUint8InRange(0, CHAR_BIT);
128 
129     byte ^= (1 << bitOffset);
130 
131     aMessage.Write(byteOffset, byte);
132 }
133 
TestUdpMessageChecksum(void)134 void TestUdpMessageChecksum(void)
135 {
136     enum : uint16_t
137     {
138         kMinSize = sizeof(Ip6::Udp::Header),
139         kMaxSize = kBufferSize * 3 + 24,
140     };
141 
142     const char *kSourceAddress = "fd00:1122:3344:5566:7788:99aa:bbcc:ddee";
143     const char *kDestAddress   = "fd01:2345:6789:abcd:ef01:2345:6789:abcd";
144 
145     Instance *instance = static_cast<Instance *>(testInitInstance());
146 
147     VerifyOrQuit(instance != nullptr);
148 
149     for (uint16_t size = kMinSize; size <= kMaxSize; size++)
150     {
151         Message *        message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Udp::Header));
152         Ip6::Udp::Header udpHeader;
153         Ip6::MessageInfo messageInfo;
154 
155         VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
156         SuccessOrQuit(message->SetLength(size));
157 
158         // Write UDP header with a random payload.
159 
160         Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&udpHeader), sizeof(udpHeader));
161         udpHeader.SetChecksum(0);
162         message->Write(0, udpHeader);
163 
164         if (size > sizeof(udpHeader))
165         {
166             uint8_t  buffer[kMaxSize];
167             uint16_t payloadSize = size - sizeof(udpHeader);
168 
169             Random::NonCrypto::FillBuffer(buffer, payloadSize);
170             message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
171         }
172 
173         SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
174         SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
175 
176         // Verify that the `Checksum::UpdateMessageChecksum` correctly
177         // updates the checksum field in the UDP header on the message.
178 
179         Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp);
180 
181         SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
182         VerifyOrQuit(udpHeader.GetChecksum() != 0);
183 
184         // Verify that the calculated UDP checksum is valid.
185 
186         VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp,
187                                        *message) == 0xffff);
188 
189         // Verify that `Checksum::VerifyMessageChecksum()` accepts the
190         // message and its calculated checksum.
191 
192         SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp));
193 
194         // Corrupt the message and verify that checksum is no longer accepted.
195 
196         CorruptMessage(*message);
197 
198         VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp) != kErrorNone,
199                      "Checksum passed on corrupted message");
200 
201         message->Free();
202     }
203 }
204 
TestIcmp6MessageChecksum(void)205 void TestIcmp6MessageChecksum(void)
206 {
207     enum : uint16_t
208     {
209         kMinSize = sizeof(Ip6::Icmp::Header),
210         kMaxSize = kBufferSize * 3 + 24,
211     };
212 
213     const char *kSourceAddress = "fd00:feef:dccd:baab:9889:7667:5444:3223";
214     const char *kDestAddress   = "fd01:abab:beef:cafe:1234:5678:9abc:0";
215 
216     Instance *instance = static_cast<Instance *>(testInitInstance());
217 
218     VerifyOrQuit(instance != nullptr, "Null OpenThread instance\n");
219 
220     for (uint16_t size = kMinSize; size <= kMaxSize; size++)
221     {
222         Message *         message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Icmp::Header));
223         Ip6::Icmp::Header icmp6Header;
224         Ip6::MessageInfo  messageInfo;
225 
226         VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
227         SuccessOrQuit(message->SetLength(size));
228 
229         // Write ICMP6 header with a random payload.
230 
231         Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&icmp6Header), sizeof(icmp6Header));
232         icmp6Header.SetChecksum(0);
233         message->Write(0, icmp6Header);
234 
235         if (size > sizeof(icmp6Header))
236         {
237             uint8_t  buffer[kMaxSize];
238             uint16_t payloadSize = size - sizeof(icmp6Header);
239 
240             Random::NonCrypto::FillBuffer(buffer, payloadSize);
241             message->WriteBytes(sizeof(icmp6Header), &buffer[0], payloadSize);
242         }
243 
244         SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
245         SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
246 
247         // Verify that the `Checksum::UpdateMessageChecksum` correctly
248         // updates the checksum field in the ICMP6 header on the message.
249 
250         Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(),
251                                         Ip6::kProtoIcmp6);
252 
253         SuccessOrQuit(message->Read(message->GetOffset(), icmp6Header));
254         VerifyOrQuit(icmp6Header.GetChecksum() != 0, "Failed to update checksum");
255 
256         // Verify that the calculated ICMP6 checksum is valid.
257 
258         VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoIcmp6,
259                                        *message) == 0xffff);
260 
261         // Verify that `Checksum::VerifyMessageChecksum()` accepts the
262         // message and its calculated checksum.
263 
264         SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6));
265 
266         // Corrupt the message and verify that checksum is no longer accepted.
267 
268         CorruptMessage(*message);
269 
270         VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6) != kErrorNone,
271                      "Checksum passed on corrupted message");
272 
273         message->Free();
274     }
275 }
276 
277 class ChecksumTester
278 {
279 public:
TestExampleVector(void)280     static void TestExampleVector(void)
281     {
282         // Example from RFC 1071
283         const uint8_t  kTestVector[]       = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7};
284         const uint16_t kTestVectorChecksum = 0xddf2;
285 
286         Checksum checksum;
287 
288         VerifyOrQuit(checksum.GetValue() == 0, "Incorrect initial checksum value");
289 
290         checksum.AddData(kTestVector, sizeof(kTestVector));
291         VerifyOrQuit(checksum.GetValue() == kTestVectorChecksum);
292         VerifyOrQuit(checksum.GetValue() == CalculateChecksum(kTestVector, sizeof(kTestVector)), );
293     }
294 };
295 
296 } // namespace ot
297 
main(void)298 int main(void)
299 {
300     ot::ChecksumTester::TestExampleVector();
301     ot::TestUdpMessageChecksum();
302     ot::TestIcmp6MessageChecksum();
303     printf("All tests passed\n");
304     return 0;
305 }
306