1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements checksum calculation.
32  */
33 
34 #include "checksum.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/log.hpp"
38 #include "common/message.hpp"
39 #include "net/icmp6.hpp"
40 #include "net/ip4_types.hpp"
41 #include "net/ip6.hpp"
42 #include "net/tcp6.hpp"
43 #include "net/udp6.hpp"
44 
45 namespace ot {
46 
// Registers "Ip6" as this file's log module name. Presumably the `LogNote()` call in
// `VerifyMessageChecksum()` tags its output with it (assumption from the macro name; confirm in common/log.hpp).
RegisterLogModule("Ip6");
48 
AddUint8(uint8_t aUint8)49 void Checksum::AddUint8(uint8_t aUint8)
50 {
51     uint16_t newValue = mValue;
52 
53     // BigEndian encoding: Even index is MSB and odd index is LSB.
54 
55     newValue += mAtOddIndex ? aUint8 : (static_cast<uint16_t>(aUint8) << 8);
56 
57     // Calculate one's complement sum.
58 
59     if (newValue < mValue)
60     {
61         newValue++;
62     }
63 
64     mValue      = newValue;
65     mAtOddIndex = !mAtOddIndex;
66 }
67 
AddUint16(uint16_t aUint16)68 void Checksum::AddUint16(uint16_t aUint16)
69 {
70     // BigEndian encoding
71     AddUint8(static_cast<uint8_t>(aUint16 >> 8));
72     AddUint8(static_cast<uint8_t>(aUint16 & 0xff));
73 }
74 
AddData(const uint8_t * aBuffer,uint16_t aLength)75 void Checksum::AddData(const uint8_t *aBuffer, uint16_t aLength)
76 {
77     for (uint16_t i = 0; i < aLength; i++)
78     {
79         AddUint8(aBuffer[i]);
80     }
81 }
82 
WriteToMessage(uint16_t aOffset,Message & aMessage) const83 void Checksum::WriteToMessage(uint16_t aOffset, Message &aMessage) const
84 {
85     uint16_t checksum = GetValue();
86 
87     if (checksum != 0xffff)
88     {
89         checksum = ~checksum;
90     }
91 
92     checksum = BigEndian::HostSwap16(checksum);
93 
94     aMessage.Write(aOffset, checksum);
95 }
96 
Calculate(const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto,const Message & aMessage)97 void Checksum::Calculate(const Ip6::Address &aSource,
98                          const Ip6::Address &aDestination,
99                          uint8_t             aIpProto,
100                          const Message      &aMessage)
101 {
102     Message::Chunk chunk;
103     uint16_t       length = aMessage.GetLength() - aMessage.GetOffset();
104 
105     // Pseudo-header for checksum calculation (RFC-2460).
106 
107     AddData(aSource.GetBytes(), sizeof(Ip6::Address));
108     AddData(aDestination.GetBytes(), sizeof(Ip6::Address));
109     AddUint16(length);
110     AddUint16(static_cast<uint16_t>(aIpProto));
111 
112     // Add message content (from offset to the end) to checksum.
113 
114     aMessage.GetFirstChunk(aMessage.GetOffset(), length, chunk);
115 
116     while (chunk.GetLength() > 0)
117     {
118         AddData(chunk.GetBytes(), chunk.GetLength());
119         aMessage.GetNextChunk(length, chunk);
120     }
121 }
122 
Calculate(const Ip4::Address & aSource,const Ip4::Address & aDestination,uint8_t aIpProto,const Message & aMessage)123 void Checksum::Calculate(const Ip4::Address &aSource,
124                          const Ip4::Address &aDestination,
125                          uint8_t             aIpProto,
126                          const Message      &aMessage)
127 {
128     Message::Chunk chunk;
129     uint16_t       length = aMessage.GetLength() - aMessage.GetOffset();
130 
131     // Pseudo-header for checksum calculation (RFC-768/792/793).
132     // Note: ICMP checksum won't count the pseudo header like TCP and UDP.
133     if (aIpProto != Ip4::kProtoIcmp)
134     {
135         AddData(aSource.GetBytes(), sizeof(Ip4::Address));
136         AddData(aDestination.GetBytes(), sizeof(Ip4::Address));
137         AddUint16(static_cast<uint16_t>(aIpProto));
138         AddUint16(length);
139     }
140 
141     // Add message content (from offset to the end) to checksum.
142 
143     aMessage.GetFirstChunk(aMessage.GetOffset(), length, chunk);
144 
145     while (chunk.GetLength() > 0)
146     {
147         AddData(chunk.GetBytes(), chunk.GetLength());
148         aMessage.GetNextChunk(length, chunk);
149     }
150 }
151 
VerifyMessageChecksum(const Message & aMessage,const Ip6::MessageInfo & aMessageInfo,uint8_t aIpProto)152 Error Checksum::VerifyMessageChecksum(const Message &aMessage, const Ip6::MessageInfo &aMessageInfo, uint8_t aIpProto)
153 {
154     Error    error = kErrorNone;
155     Checksum checksum;
156 
157     checksum.Calculate(aMessageInfo.GetPeerAddr(), aMessageInfo.GetSockAddr(), aIpProto, aMessage);
158 
159     if (checksum.GetValue() != kValidRxChecksum)
160     {
161         LogNote("Bad %s checksum", Ip6::Ip6::IpProtoToString(aIpProto));
162         error = kErrorDrop;
163     }
164 
165     return error;
166 }
167 
UpdateMessageChecksum(Message & aMessage,const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto)168 void Checksum::UpdateMessageChecksum(Message            &aMessage,
169                                      const Ip6::Address &aSource,
170                                      const Ip6::Address &aDestination,
171                                      uint8_t             aIpProto)
172 {
173     uint16_t headerOffset;
174     Checksum checksum;
175 
176     switch (aIpProto)
177     {
178     case Ip6::kProtoTcp:
179         headerOffset = Ip6::Tcp::Header::kChecksumFieldOffset;
180         break;
181 
182     case Ip6::kProtoUdp:
183         headerOffset = Ip6::Udp::Header::kChecksumFieldOffset;
184         break;
185 
186     case Ip6::kProtoIcmp6:
187         headerOffset = Ip6::Icmp::Header::kChecksumFieldOffset;
188         break;
189 
190     default:
191         ExitNow();
192     }
193 
194     // Clear the checksum before calculating it.
195     aMessage.Write<uint16_t>(aMessage.GetOffset() + headerOffset, 0);
196     checksum.Calculate(aSource, aDestination, aIpProto, aMessage);
197     checksum.WriteToMessage(aMessage.GetOffset() + headerOffset, aMessage);
198 
199 exit:
200     return;
201 }
202 
UpdateMessageChecksum(Message & aMessage,const Ip4::Address & aSource,const Ip4::Address & aDestination,uint8_t aIpProto)203 void Checksum::UpdateMessageChecksum(Message            &aMessage,
204                                      const Ip4::Address &aSource,
205                                      const Ip4::Address &aDestination,
206                                      uint8_t             aIpProto)
207 {
208     uint16_t headerOffset;
209     Checksum checksum;
210 
211     switch (aIpProto)
212     {
213     case Ip4::kProtoTcp:
214         headerOffset = Ip4::Tcp::Header::kChecksumFieldOffset;
215         break;
216 
217     case Ip4::kProtoUdp:
218         headerOffset = Ip4::Udp::Header::kChecksumFieldOffset;
219         break;
220 
221     case Ip4::kProtoIcmp:
222         headerOffset = Ip4::Icmp::Header::kChecksumFieldOffset;
223         break;
224 
225     default:
226         ExitNow();
227     }
228 
229     // Clear the checksum before calculating it.
230     aMessage.Write<uint16_t>(aMessage.GetOffset() + headerOffset, 0);
231     checksum.Calculate(aSource, aDestination, aIpProto, aMessage);
232     checksum.WriteToMessage(aMessage.GetOffset() + headerOffset, aMessage);
233 
234 exit:
235     return;
236 }
237 
UpdateIp4HeaderChecksum(Ip4::Header & aHeader)238 void Checksum::UpdateIp4HeaderChecksum(Ip4::Header &aHeader)
239 {
240     Checksum checksum;
241 
242     aHeader.SetChecksum(0);
243     checksum.AddData(reinterpret_cast<const uint8_t *>(&aHeader), sizeof(aHeader));
244     aHeader.SetChecksum(~checksum.GetValue());
245 }
246 
247 } // namespace ot
248