1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * csum_partial_copy - do IP checksumming and copy
4 *
5 * (C) Copyright 1996 Linus Torvalds
6 * accelerated versions (and 21264 assembly versions ) contributed by
7 * Rick Gorton <rick.gorton@alpha-processor.com>
8 *
9 * Don't look at this too closely - you'll go mad. The things
10 * we do for performance..
11 */
12
13 #include <linux/types.h>
14 #include <linux/string.h>
15 #include <linux/uaccess.h>
16
17
/*
 * One-instruction wrappers around the Alpha unaligned-quadword and
 * byte-manipulation instructions used by the copy routines in this file.
 *
 * Every wrapper writes its result into the argument named "out" (or, for
 * stq_u, into memory at "addr"); none of them yields a value.  The exact
 * instruction semantics -- in particular that only the low three bits of
 * "off" select the byte position -- are defined by the Alpha architecture,
 * not by anything in this file.
 */

/* ldq_u: load a quadword from "addr" into "out". */
#define ldq_u(out, addr) \
	__asm__ __volatile__("ldq_u %0,%1" \
		: "=r" (out) \
		: "m" (*(const unsigned long *)(addr)))

/* stq_u: store the quadword "val" to "addr". */
#define stq_u(val, addr) \
	__asm__ __volatile__("stq_u %1,%0" \
		: "=m" (*(unsigned long *)(addr)) \
		: "r" (val))

/* extql: extract-quadword-low of "in" at byte offset "off" into "out". */
#define extql(in, off, out) \
	__asm__ __volatile__("extql %1,%2,%0" \
		: "=r" (out) \
		: "r" (in), "r" (off))

/* extqh: extract-quadword-high of "in" at byte offset "off" into "out". */
#define extqh(in, off, out) \
	__asm__ __volatile__("extqh %1,%2,%0" \
		: "=r" (out) \
		: "r" (in), "r" (off))

/* mskql: mask-quadword-low of "in" at byte offset "off" into "out". */
#define mskql(in, off, out) \
	__asm__ __volatile__("mskql %1,%2,%0" \
		: "=r" (out) \
		: "r" (in), "r" (off))

/* mskqh: mask-quadword-high of "in" at byte offset "off" into "out". */
#define mskqh(in, off, out) \
	__asm__ __volatile__("mskqh %1,%2,%0" \
		: "=r" (out) \
		: "r" (in), "r" (off))

/* insql: insert-quadword-low of "in" at byte offset "off" into "out". */
#define insql(in, off, out) \
	__asm__ __volatile__("insql %1,%2,%0" \
		: "=r" (out) \
		: "r" (in), "r" (off))

/* insqh: insert-quadword-high of "in" at byte offset "off" into "out". */
#define insqh(in, off, out) \
	__asm__ __volatile__("insqh %1,%2,%0" \
		: "=r" (out) \
		: "r" (in), "r" (off))
41
42
/*
 * __get_user_u - load a quadword from "ptr" with ldq_u, tolerating a fault.
 *
 * "val" receives the loaded quadword.  The statement expression evaluates
 * to a long status that is initialised to 0 through the matching
 * constraint "1"(0).  EXC() and __m() are defined outside this file;
 * presumably EXC() registers an exception-table fixup for the instruction
 * at label 1 that resumes at label 2 and records the error in operand %1
 * -- confirm against the uaccess header.
 */
#define __get_user_u(val, ptr)					\
({								\
	long __guu_status;					\
	__asm__ __volatile__(					\
	"1: ldq_u %0,%2\n"					\
	"2:\n"							\
	EXC(1b,2b,%0,%1)					\
		: "=r" (val), "=r" (__guu_status)		\
		: "m" (__m(ptr)), "1" (0));			\
	__guu_status;						\
})
54
/*
 * __put_user_u - store the quadword "x" to "ptr" with stq_u, tolerating a
 * fault.
 *
 * The statement expression evaluates to a long status that is initialised
 * to 0 through the matching constraint "0"(0).  EXC() and __m() are
 * defined outside this file; presumably EXC() registers an exception-table
 * fixup for the instruction at label 1 that resumes at label 2 and records
 * the error in operand %0 -- confirm against the uaccess header.
 *
 * The memory operand is built from the macro parameter "ptr", so the macro
 * does not depend on any identifier in the caller's scope.
 */
#define __put_user_u(x,ptr)					\
({								\
	long __puu_err;						\
	__asm__ __volatile__(					\
	"1: stq_u %2,%1\n"					\
	"2:\n"							\
	EXC(1b,2b,$31,%0)					\
		: "=r"(__puu_err)				\
		: "m"(__m(ptr)), "rJ"(x), "0"(0));		\
	__puu_err;						\
})
66
67
/*
 * from64to16 - fold a 64-bit checksum accumulator down to 16 bits.
 *
 * The two 32-bit halves of @x are added, then the 16-bit pieces of that
 * sum, then the two 16-bit pieces of the result, so every carry out of a
 * narrower width is added back in (end-around carry).
 *
 * The union overlays rely on a 64-bit unsigned long, 32-bit unsigned int,
 * 16-bit unsigned short and little-endian storage (index 0 is the least
 * significant piece), as on the Alpha targets this file is written for.
 *
 * Returns the folded 16-bit sum; it is not complemented here.
 */
static inline unsigned short from64to16(unsigned long x)
{
	/* Using extract instructions is a bit more efficient
	   than the original shift/bitmask version.  */

	union {
		unsigned long	ul;
		unsigned int	ui[2];
		unsigned short	us[4];
	} in_v, tmp_v, out_v;

	in_v.ul = x;
	tmp_v.ul = (unsigned long) in_v.ui[0] + (unsigned long) in_v.ui[1];

	/* The sum of two 32-bit values fits in 33 bits, so tmp_v.us[3]
	   is always zero and does not have to be added in.  */
	out_v.ul = (unsigned long) tmp_v.us[0] + (unsigned long) tmp_v.us[1]
			+ (unsigned long) tmp_v.us[2];

	/* Similarly, out_v.us[2] is always zero for the final add, and
	   the result cannot exceed 0xffff, so the narrowing is lossless.  */
	return (unsigned short) (out_v.us[0] + out_v.us[1]);
}
90
91
92
93 /*
94 * Ok. This isn't fun, but this is the EASY case.
95 */
96 static inline unsigned long
csum_partial_cfu_aligned(const unsigned long __user * src,unsigned long * dst,long len,unsigned long checksum,int * errp)97 csum_partial_cfu_aligned(const unsigned long __user *src, unsigned long *dst,
98 long len, unsigned long checksum,
99 int *errp)
100 {
101 unsigned long carry = 0;
102 int err = 0;
103
104 while (len >= 0) {
105 unsigned long word;
106 err |= __get_user(word, src);
107 checksum += carry;
108 src++;
109 checksum += word;
110 len -= 8;
111 carry = checksum < word;
112 *dst = word;
113 dst++;
114 }
115 len += 8;
116 checksum += carry;
117 if (len) {
118 unsigned long word, tmp;
119 err |= __get_user(word, src);
120 tmp = *dst;
121 mskql(word, len, word);
122 checksum += word;
123 mskqh(tmp, len, tmp);
124 carry = checksum < word;
125 *dst = word | tmp;
126 checksum += carry;
127 }
128 if (err && errp) *errp = err;
129 return checksum;
130 }
131
132 /*
133 * This is even less fun, but this is still reasonably
134 * easy.
135 */
136 static inline unsigned long
csum_partial_cfu_dest_aligned(const unsigned long __user * src,unsigned long * dst,unsigned long soff,long len,unsigned long checksum,int * errp)137 csum_partial_cfu_dest_aligned(const unsigned long __user *src,
138 unsigned long *dst,
139 unsigned long soff,
140 long len, unsigned long checksum,
141 int *errp)
142 {
143 unsigned long first;
144 unsigned long word, carry;
145 unsigned long lastsrc = 7+len+(unsigned long)src;
146 int err = 0;
147
148 err |= __get_user_u(first,src);
149 carry = 0;
150 while (len >= 0) {
151 unsigned long second;
152
153 err |= __get_user_u(second, src+1);
154 extql(first, soff, word);
155 len -= 8;
156 src++;
157 extqh(second, soff, first);
158 checksum += carry;
159 word |= first;
160 first = second;
161 checksum += word;
162 *dst = word;
163 dst++;
164 carry = checksum < word;
165 }
166 len += 8;
167 checksum += carry;
168 if (len) {
169 unsigned long tmp;
170 unsigned long second;
171 err |= __get_user_u(second, lastsrc);
172 tmp = *dst;
173 extql(first, soff, word);
174 extqh(second, soff, first);
175 word |= first;
176 mskql(word, len, word);
177 checksum += word;
178 mskqh(tmp, len, tmp);
179 carry = checksum < word;
180 *dst = word | tmp;
181 checksum += carry;
182 }
183 if (err && errp) *errp = err;
184 return checksum;
185 }
186
187 /*
188 * This is slightly less fun than the above..
189 */
190 static inline unsigned long
csum_partial_cfu_src_aligned(const unsigned long __user * src,unsigned long * dst,unsigned long doff,long len,unsigned long checksum,unsigned long partial_dest,int * errp)191 csum_partial_cfu_src_aligned(const unsigned long __user *src,
192 unsigned long *dst,
193 unsigned long doff,
194 long len, unsigned long checksum,
195 unsigned long partial_dest,
196 int *errp)
197 {
198 unsigned long carry = 0;
199 unsigned long word;
200 unsigned long second_dest;
201 int err = 0;
202
203 mskql(partial_dest, doff, partial_dest);
204 while (len >= 0) {
205 err |= __get_user(word, src);
206 len -= 8;
207 insql(word, doff, second_dest);
208 checksum += carry;
209 stq_u(partial_dest | second_dest, dst);
210 src++;
211 checksum += word;
212 insqh(word, doff, partial_dest);
213 carry = checksum < word;
214 dst++;
215 }
216 len += 8;
217 if (len) {
218 checksum += carry;
219 err |= __get_user(word, src);
220 mskql(word, len, word);
221 len -= 8;
222 checksum += word;
223 insql(word, doff, second_dest);
224 len += doff;
225 carry = checksum < word;
226 partial_dest |= second_dest;
227 if (len >= 0) {
228 stq_u(partial_dest, dst);
229 if (!len) goto out;
230 dst++;
231 insqh(word, doff, partial_dest);
232 }
233 doff = len;
234 }
235 ldq_u(second_dest, dst);
236 mskqh(second_dest, doff, second_dest);
237 stq_u(partial_dest | second_dest, dst);
238 out:
239 checksum += carry;
240 if (err && errp) *errp = err;
241 return checksum;
242 }
243
244 /*
245 * This is so totally un-fun that it's frightening. Don't
246 * look at this too closely, you'll go blind.
247 */
248 static inline unsigned long
csum_partial_cfu_unaligned(const unsigned long __user * src,unsigned long * dst,unsigned long soff,unsigned long doff,long len,unsigned long checksum,unsigned long partial_dest,int * errp)249 csum_partial_cfu_unaligned(const unsigned long __user * src,
250 unsigned long * dst,
251 unsigned long soff, unsigned long doff,
252 long len, unsigned long checksum,
253 unsigned long partial_dest,
254 int *errp)
255 {
256 unsigned long carry = 0;
257 unsigned long first;
258 unsigned long lastsrc;
259 int err = 0;
260
261 err |= __get_user_u(first, src);
262 lastsrc = 7+len+(unsigned long)src;
263 mskql(partial_dest, doff, partial_dest);
264 while (len >= 0) {
265 unsigned long second, word;
266 unsigned long second_dest;
267
268 err |= __get_user_u(second, src+1);
269 extql(first, soff, word);
270 checksum += carry;
271 len -= 8;
272 extqh(second, soff, first);
273 src++;
274 word |= first;
275 first = second;
276 insql(word, doff, second_dest);
277 checksum += word;
278 stq_u(partial_dest | second_dest, dst);
279 carry = checksum < word;
280 insqh(word, doff, partial_dest);
281 dst++;
282 }
283 len += doff;
284 checksum += carry;
285 if (len >= 0) {
286 unsigned long second, word;
287 unsigned long second_dest;
288
289 err |= __get_user_u(second, lastsrc);
290 extql(first, soff, word);
291 extqh(second, soff, first);
292 word |= first;
293 first = second;
294 mskql(word, len-doff, word);
295 checksum += word;
296 insql(word, doff, second_dest);
297 carry = checksum < word;
298 stq_u(partial_dest | second_dest, dst);
299 if (len) {
300 ldq_u(second_dest, dst+1);
301 insqh(word, doff, partial_dest);
302 mskqh(second_dest, len, second_dest);
303 stq_u(partial_dest | second_dest, dst+1);
304 }
305 checksum += carry;
306 } else {
307 unsigned long second, word;
308 unsigned long second_dest;
309
310 err |= __get_user_u(second, lastsrc);
311 extql(first, soff, word);
312 extqh(second, soff, first);
313 word |= first;
314 ldq_u(second_dest, dst);
315 mskql(word, len-doff, word);
316 checksum += word;
317 mskqh(second_dest, len, second_dest);
318 carry = checksum < word;
319 insql(word, doff, word);
320 stq_u(partial_dest | word | second_dest, dst);
321 checksum += carry;
322 }
323 if (err && errp) *errp = err;
324 return checksum;
325 }
326
327 __wsum
csum_partial_copy_from_user(const void __user * src,void * dst,int len,__wsum sum,int * errp)328 csum_partial_copy_from_user(const void __user *src, void *dst, int len,
329 __wsum sum, int *errp)
330 {
331 unsigned long checksum = (__force u32) sum;
332 unsigned long soff = 7 & (unsigned long) src;
333 unsigned long doff = 7 & (unsigned long) dst;
334
335 if (len) {
336 if (!access_ok(src, len)) {
337 if (errp) *errp = -EFAULT;
338 memset(dst, 0, len);
339 return sum;
340 }
341 if (!doff) {
342 if (!soff)
343 checksum = csum_partial_cfu_aligned(
344 (const unsigned long __user *) src,
345 (unsigned long *) dst,
346 len-8, checksum, errp);
347 else
348 checksum = csum_partial_cfu_dest_aligned(
349 (const unsigned long __user *) src,
350 (unsigned long *) dst,
351 soff, len-8, checksum, errp);
352 } else {
353 unsigned long partial_dest;
354 ldq_u(partial_dest, dst);
355 if (!soff)
356 checksum = csum_partial_cfu_src_aligned(
357 (const unsigned long __user *) src,
358 (unsigned long *) dst,
359 doff, len-8, checksum,
360 partial_dest, errp);
361 else
362 checksum = csum_partial_cfu_unaligned(
363 (const unsigned long __user *) src,
364 (unsigned long *) dst,
365 soff, doff, len-8, checksum,
366 partial_dest, errp);
367 }
368 checksum = from64to16 (checksum);
369 }
370 return (__force __wsum)checksum;
371 }
372 EXPORT_SYMBOL(csum_partial_copy_from_user);
373
374 __wsum
csum_partial_copy_nocheck(const void * src,void * dst,int len,__wsum sum)375 csum_partial_copy_nocheck(const void *src, void *dst, int len, __wsum sum)
376 {
377 __wsum checksum;
378 mm_segment_t oldfs = get_fs();
379 set_fs(KERNEL_DS);
380 checksum = csum_partial_copy_from_user((__force const void __user *)src,
381 dst, len, sum, NULL);
382 set_fs(oldfs);
383 return checksum;
384 }
385 EXPORT_SYMBOL(csum_partial_copy_nocheck);
386