1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3 *
4 * The iopt_pages is the center of the storage and motion of PFNs. Each
5 * iopt_pages represents a logical linear array of full PFNs. The array is 0
6 * based and has npages in it. Accessors use 'index' to refer to the entry in
7 * this logical array, regardless of its storage location.
8 *
9 * PFNs are stored in a tiered scheme:
10 * 1) iopt_pages::pinned_pfns xarray
11 * 2) An iommu_domain
12 * 3) The origin of the PFNs, i.e. the userspace pointer
13 *
14 * PFNs have to be copied between all combinations of tiers, depending on the
15 * configuration.
16 *
17 * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18 * The storage locations of the PFN's index are tracked in the two interval
19 * trees. If no interval includes the index then it is not pinned.
20 *
21 * If access_itree includes the PFN's index then an in-kernel access has
22 * requested the page. The PFN is stored in the xarray so other requestors can
23 * continue to find it.
24 *
25 * If the domains_itree includes the PFN's index then an iommu_domain is storing
26 * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27 * duplicating storage the xarray is not used if only iommu_domains are using
28 * the PFN's index.
29 *
30 * As a general principle this is designed so that destroy never fails. This
31 * means removing an iommu_domain or releasing an in-kernel access will not fail
32 * due to insufficient memory. In practice this means some cases have to hold
33 * PFNs in the xarray even though they are also being stored in an iommu_domain.
34 *
35 * While the iopt_pages can use an iommu_domain as storage, it does not have an
36 * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37 * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38 * and reference their own slice of the PFN array, with sub page granularity.
39 *
40 * In this file the term 'last' indicates an inclusive and closed interval, eg
41 * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42 * no PFNs.
43 *
44 * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45 * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46 * ULONG_MAX so last_index + 1 cannot overflow.
47 */
48 #include <linux/overflow.h>
49 #include <linux/slab.h>
50 #include <linux/iommu.h>
51 #include <linux/sched/mm.h>
52 #include <linux/highmem.h>
53 #include <linux/kthread.h>
54 #include <linux/iommufd.h>
55
56 #include "io_pagetable.h"
57 #include "double_span.h"
58
/*
 * Upper bound, in bytes, on the size temp_kmalloc() will try to allocate.
 * Under CONFIG_IOMMUFD_TEST the bound comes from iommufd_test_memory_limit
 * instead of being a compile time constant.
 */
#ifndef CONFIG_IOMMUFD_TEST
#define TEMP_MEMORY_LIMIT 65536
#else
#define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
#endif
/*
 * NOTE(review): not referenced in the visible part of the file; presumably
 * the element count of the on-stack backup buffers - confirm against users.
 */
#define BATCH_BACKUP_SIZE 32
65
66 /*
67 * More memory makes pin_user_pages() and the batching more efficient, but as
68 * this is only a performance optimization don't try too hard to get it. A 64k
69 * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
70 * pfn_batch. Various destroy paths cannot fail and provide a small amount of
71 * stack memory as a backup contingency. If backup_len is given this cannot
72 * fail.
73 */
temp_kmalloc(size_t * size,void * backup,size_t backup_len)74 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
75 {
76 void *res;
77
78 if (WARN_ON(*size == 0))
79 return NULL;
80
81 if (*size < backup_len)
82 return backup;
83
84 if (!backup && iommufd_should_fail())
85 return NULL;
86
87 *size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
88 res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
89 if (res)
90 return res;
91 *size = PAGE_SIZE;
92 if (backup_len) {
93 res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
94 if (res)
95 return res;
96 *size = backup_len;
97 return backup;
98 }
99 return kmalloc(*size, GFP_KERNEL);
100 }
101
interval_tree_double_span_iter_update(struct interval_tree_double_span_iter * iter)102 void interval_tree_double_span_iter_update(
103 struct interval_tree_double_span_iter *iter)
104 {
105 unsigned long last_hole = ULONG_MAX;
106 unsigned int i;
107
108 for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
109 if (interval_tree_span_iter_done(&iter->spans[i])) {
110 iter->is_used = -1;
111 return;
112 }
113
114 if (iter->spans[i].is_hole) {
115 last_hole = min(last_hole, iter->spans[i].last_hole);
116 continue;
117 }
118
119 iter->is_used = i + 1;
120 iter->start_used = iter->spans[i].start_used;
121 iter->last_used = min(iter->spans[i].last_used, last_hole);
122 return;
123 }
124
125 iter->is_used = 0;
126 iter->start_hole = iter->spans[0].start_hole;
127 iter->last_hole =
128 min(iter->spans[0].last_hole, iter->spans[1].last_hole);
129 }
130
interval_tree_double_span_iter_first(struct interval_tree_double_span_iter * iter,struct rb_root_cached * itree1,struct rb_root_cached * itree2,unsigned long first_index,unsigned long last_index)131 void interval_tree_double_span_iter_first(
132 struct interval_tree_double_span_iter *iter,
133 struct rb_root_cached *itree1, struct rb_root_cached *itree2,
134 unsigned long first_index, unsigned long last_index)
135 {
136 unsigned int i;
137
138 iter->itrees[0] = itree1;
139 iter->itrees[1] = itree2;
140 for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
141 interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
142 first_index, last_index);
143 interval_tree_double_span_iter_update(iter);
144 }
145
interval_tree_double_span_iter_next(struct interval_tree_double_span_iter * iter)146 void interval_tree_double_span_iter_next(
147 struct interval_tree_double_span_iter *iter)
148 {
149 unsigned int i;
150
151 if (iter->is_used == -1 ||
152 iter->last_hole == iter->spans[0].last_index) {
153 iter->is_used = -1;
154 return;
155 }
156
157 for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
158 interval_tree_span_iter_advance(
159 &iter->spans[i], iter->itrees[i], iter->last_hole + 1);
160 interval_tree_double_span_iter_update(iter);
161 }
162
/*
 * Add npages to pages->npinned. Test builds warn if the addition overflows or
 * the counter ends up larger than pages->npages.
 */
static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
{
	int overflowed;

	overflowed =
		check_add_overflow(pages->npinned, npages, &pages->npinned);
	if (!IS_ENABLED(CONFIG_IOMMUFD_TEST))
		return;
	WARN_ON(overflowed || pages->npinned > pages->npages);
}
171
/*
 * Subtract npages from pages->npinned. Test builds warn if the subtraction
 * underflows or the counter ends up larger than pages->npages.
 */
static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
{
	int underflowed;

	underflowed =
		check_sub_overflow(pages->npinned, npages, &pages->npinned);
	if (!IS_ENABLED(CONFIG_IOMMUFD_TEST))
		return;
	WARN_ON(underflowed || pages->npinned > pages->npages);
}
180
/*
 * Unpin the pages in page_list, which cover the inclusive index range
 * [start_index, last_index], and remove them from the npinned count.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	unsigned long count = last_index + 1 - start_index;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
191
192 /*
193 * index is the number of PAGE_SIZE units from the start of the area's
194 * iopt_pages. If the iova is sub page-size then the area has an iova that
195 * covers a portion of the first and last pages in the range.
196 */
iopt_area_index_to_iova(struct iopt_area * area,unsigned long index)197 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
198 unsigned long index)
199 {
200 if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
201 WARN_ON(index < iopt_area_index(area) ||
202 index > iopt_area_last_index(area));
203 index -= iopt_area_index(area);
204 if (index == 0)
205 return iopt_area_iova(area);
206 return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
207 }
208
iopt_area_index_to_iova_last(struct iopt_area * area,unsigned long index)209 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
210 unsigned long index)
211 {
212 if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
213 WARN_ON(index < iopt_area_index(area) ||
214 index > iopt_area_last_index(area));
215 if (index == iopt_area_last_index(area))
216 return iopt_area_last_iova(area);
217 return iopt_area_iova(area) - area->page_offset +
218 (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
219 }
220
/*
 * Unmap [iova, iova + size) from the domain and warn if the driver reports
 * unmapping anything other than exactly size bytes.
 */
static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
			       size_t size)
{
	size_t unmapped = iommu_unmap(domain, iova, size);

	/*
	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
	 * something other than exactly as requested. This implies that the
	 * iommu driver may not fail unmap for reasons beyond bad arguments.
	 * Particularly, the iommu driver may not do a memory allocation on the
	 * unmap path.
	 */
	WARN_ON(unmapped != size);
}
236
/*
 * Unmap the IOVA belonging to the pages at indexes [start_index, last_index]
 * of the area from the domain.
 */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first_iova = iopt_area_index_to_iova(area, start_index);
	unsigned long final_iova =
		iopt_area_index_to_iova_last(area, last_index);

	/* final_iova is inclusive, hence the + 1 to get a length */
	iommu_unmap_nofail(domain, first_iova, final_iova - first_iova + 1);
}
248
iopt_pages_find_domain_area(struct iopt_pages * pages,unsigned long index)249 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
250 unsigned long index)
251 {
252 struct interval_tree_node *node;
253
254 node = interval_tree_iter_first(&pages->domains_itree, index, index);
255 if (!node)
256 return NULL;
257 return container_of(node, struct iopt_area, pages_node);
258 }
259
/*
 * A simple datastructure to hold a vector of PFNs, optimized for contiguous
 * PFNs. This is used as a temporary holding memory for shuttling pfns from one
 * place to another. Generally everything is made more efficient if operations
 * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
 * better cache locality, etc
 *
 * Each used entry i describes one run of npfns[i] physically contiguous PFNs
 * starting at pfns[i].
 */
struct pfn_batch {
	/* First PFN of each run */
	unsigned long *pfns;
	/* Length of each run; placed in the same buffer, right after pfns[] */
	u32 *npfns;
	/* Number of entries pfns[] and npfns[] can each hold */
	unsigned int array_size;
	/* Number of entries currently in use */
	unsigned int end;
	/* Total number of PFNs held, ie the sum of the used npfns[] */
	unsigned int total_pfns;
};
274
batch_clear(struct pfn_batch * batch)275 static void batch_clear(struct pfn_batch *batch)
276 {
277 batch->total_pfns = 0;
278 batch->end = 0;
279 batch->pfns[0] = 0;
280 batch->npfns[0] = 0;
281 }
282
283 /*
284 * Carry means we carry a portion of the final hugepage over to the front of the
285 * batch
286 */
batch_clear_carry(struct pfn_batch * batch,unsigned int keep_pfns)287 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
288 {
289 if (!keep_pfns)
290 return batch_clear(batch);
291
292 if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
293 WARN_ON(!batch->end ||
294 batch->npfns[batch->end - 1] < keep_pfns);
295
296 batch->total_pfns = keep_pfns;
297 batch->pfns[0] = batch->pfns[batch->end - 1] +
298 (batch->npfns[batch->end - 1] - keep_pfns);
299 batch->npfns[0] = keep_pfns;
300 batch->end = 1;
301 }
302
batch_skip_carry(struct pfn_batch * batch,unsigned int skip_pfns)303 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
304 {
305 if (!batch->total_pfns)
306 return;
307 if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
308 WARN_ON(batch->total_pfns != batch->npfns[0]);
309 skip_pfns = min(batch->total_pfns, skip_pfns);
310 batch->pfns[0] += skip_pfns;
311 batch->npfns[0] -= skip_pfns;
312 batch->total_pfns -= skip_pfns;
313 }
314
/*
 * Set up an empty batch with room for up to max_pages entries. The memory
 * comes from temp_kmalloc(), so the batch may end up smaller than requested
 * or be placed in the backup memory. A single buffer holds the pfns[] array
 * followed by the npfns[] array.
 *
 * Returns 0 on success, -ENOMEM if no memory could be obtained and, in test
 * builds, -EINVAL if the memory cannot hold even one entry.
 */
static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
			size_t backup_len)
{
	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
	size_t bytes = max_pages * elmsz;

	batch->pfns = temp_kmalloc(&bytes, backup, backup_len);
	if (!batch->pfns)
		return -ENOMEM;
	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(bytes < elmsz))
		return -EINVAL;

	/* bytes is what temp_kmalloc() settled on, size the arrays from it */
	batch->array_size = bytes / elmsz;
	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
	batch_clear(batch);
	return 0;
}
331
/*
 * Set up an empty batch for up to max_pages entries without any backup
 * memory. Returns whatever __batch_init() returns.
 */
static int batch_init(struct pfn_batch *batch, size_t max_pages)
{
	int rc = __batch_init(batch, max_pages, NULL, 0);

	return rc;
}
336
/*
 * Set up an empty batch that can fall back to the caller's backup memory. The
 * return value of __batch_init() is deliberately not propagated.
 */
static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
			      void *backup, size_t backup_len)
{
	(void)__batch_init(batch, max_pages, backup, backup_len);
}
342
batch_destroy(struct pfn_batch * batch,void * backup)343 static void batch_destroy(struct pfn_batch *batch, void *backup)
344 {
345 if (batch->pfns != backup)
346 kfree(batch->pfns);
347 }
348
349 /* true if the pfn was added, false otherwise */
batch_add_pfn(struct pfn_batch * batch,unsigned long pfn)350 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
351 {
352 const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
353
354 if (batch->end &&
355 pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
356 batch->npfns[batch->end - 1] != MAX_NPFNS) {
357 batch->npfns[batch->end - 1]++;
358 batch->total_pfns++;
359 return true;
360 }
361 if (batch->end == batch->array_size)
362 return false;
363 batch->total_pfns++;
364 batch->pfns[batch->end] = pfn;
365 batch->npfns[batch->end] = 1;
366 batch->end++;
367 return true;
368 }
369
370 /*
371 * Fill the batch with pfns from the domain. When the batch is full, or it
372 * reaches last_index, the function will return. The caller should use
373 * batch->total_pfns to determine the starting point for the next iteration.
374 */
batch_from_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)375 static void batch_from_domain(struct pfn_batch *batch,
376 struct iommu_domain *domain,
377 struct iopt_area *area, unsigned long start_index,
378 unsigned long last_index)
379 {
380 unsigned int page_offset = 0;
381 unsigned long iova;
382 phys_addr_t phys;
383
384 iova = iopt_area_index_to_iova(area, start_index);
385 if (start_index == iopt_area_index(area))
386 page_offset = area->page_offset;
387 while (start_index <= last_index) {
388 /*
389 * This is pretty slow, it would be nice to get the page size
390 * back from the driver, or have the driver directly fill the
391 * batch.
392 */
393 phys = iommu_iova_to_phys(domain, iova) - page_offset;
394 if (!batch_add_pfn(batch, PHYS_PFN(phys)))
395 return;
396 iova += PAGE_SIZE - page_offset;
397 page_offset = 0;
398 start_index++;
399 }
400 }
401
raw_pages_from_domain(struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages)402 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
403 struct iopt_area *area,
404 unsigned long start_index,
405 unsigned long last_index,
406 struct page **out_pages)
407 {
408 unsigned int page_offset = 0;
409 unsigned long iova;
410 phys_addr_t phys;
411
412 iova = iopt_area_index_to_iova(area, start_index);
413 if (start_index == iopt_area_index(area))
414 page_offset = area->page_offset;
415 while (start_index <= last_index) {
416 phys = iommu_iova_to_phys(domain, iova) - page_offset;
417 *(out_pages++) = pfn_to_page(PHYS_PFN(phys));
418 iova += PAGE_SIZE - page_offset;
419 page_offset = 0;
420 start_index++;
421 }
422 return out_pages;
423 }
424
425 /* Continues reading a domain until we reach a discontinuity in the pfns. */
batch_from_domain_continue(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index,unsigned long last_index)426 static void batch_from_domain_continue(struct pfn_batch *batch,
427 struct iommu_domain *domain,
428 struct iopt_area *area,
429 unsigned long start_index,
430 unsigned long last_index)
431 {
432 unsigned int array_size = batch->array_size;
433
434 batch->array_size = batch->end;
435 batch_from_domain(batch, domain, area, start_index, last_index);
436 batch->array_size = array_size;
437 }
438
439 /*
440 * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
441 * mode permits splitting a mapped area up, and then one of the splits is
442 * unmapped. Doing this normally would cause us to violate our invariant of
443 * pairing map/unmap. Thus, to support old VFIO compatibility disable support
444 * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
445 * PAGE_SIZE units, not larger or smaller.
446 */
batch_iommu_map_small(struct iommu_domain * domain,unsigned long iova,phys_addr_t paddr,size_t size,int prot)447 static int batch_iommu_map_small(struct iommu_domain *domain,
448 unsigned long iova, phys_addr_t paddr,
449 size_t size, int prot)
450 {
451 unsigned long start_iova = iova;
452 int rc;
453
454 if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
455 WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
456 size % PAGE_SIZE);
457
458 while (size) {
459 rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
460 GFP_KERNEL_ACCOUNT);
461 if (rc)
462 goto err_unmap;
463 iova += PAGE_SIZE;
464 paddr += PAGE_SIZE;
465 size -= PAGE_SIZE;
466 }
467 return 0;
468
469 err_unmap:
470 if (start_iova != iova)
471 iommu_unmap_nofail(domain, start_iova, iova - start_iova);
472 return rc;
473 }
474
batch_to_domain(struct pfn_batch * batch,struct iommu_domain * domain,struct iopt_area * area,unsigned long start_index)475 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
476 struct iopt_area *area, unsigned long start_index)
477 {
478 bool disable_large_pages = area->iopt->disable_large_pages;
479 unsigned long last_iova = iopt_area_last_iova(area);
480 unsigned int page_offset = 0;
481 unsigned long start_iova;
482 unsigned long next_iova;
483 unsigned int cur = 0;
484 unsigned long iova;
485 int rc;
486
487 /* The first index might be a partial page */
488 if (start_index == iopt_area_index(area))
489 page_offset = area->page_offset;
490 next_iova = iova = start_iova =
491 iopt_area_index_to_iova(area, start_index);
492 while (cur < batch->end) {
493 next_iova = min(last_iova + 1,
494 next_iova + batch->npfns[cur] * PAGE_SIZE -
495 page_offset);
496 if (disable_large_pages)
497 rc = batch_iommu_map_small(
498 domain, iova,
499 PFN_PHYS(batch->pfns[cur]) + page_offset,
500 next_iova - iova, area->iommu_prot);
501 else
502 rc = iommu_map(domain, iova,
503 PFN_PHYS(batch->pfns[cur]) + page_offset,
504 next_iova - iova, area->iommu_prot,
505 GFP_KERNEL_ACCOUNT);
506 if (rc)
507 goto err_unmap;
508 iova = next_iova;
509 page_offset = 0;
510 cur++;
511 }
512 return 0;
513 err_unmap:
514 if (start_iova != iova)
515 iommu_unmap_nofail(domain, start_iova, iova - start_iova);
516 return rc;
517 }
518
batch_from_xarray(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)519 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
520 unsigned long start_index,
521 unsigned long last_index)
522 {
523 XA_STATE(xas, xa, start_index);
524 void *entry;
525
526 rcu_read_lock();
527 while (true) {
528 entry = xas_next(&xas);
529 if (xas_retry(&xas, entry))
530 continue;
531 WARN_ON(!xa_is_value(entry));
532 if (!batch_add_pfn(batch, xa_to_value(entry)) ||
533 start_index == last_index)
534 break;
535 start_index++;
536 }
537 rcu_read_unlock();
538 }
539
batch_from_xarray_clear(struct pfn_batch * batch,struct xarray * xa,unsigned long start_index,unsigned long last_index)540 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
541 unsigned long start_index,
542 unsigned long last_index)
543 {
544 XA_STATE(xas, xa, start_index);
545 void *entry;
546
547 xas_lock(&xas);
548 while (true) {
549 entry = xas_next(&xas);
550 if (xas_retry(&xas, entry))
551 continue;
552 WARN_ON(!xa_is_value(entry));
553 if (!batch_add_pfn(batch, xa_to_value(entry)))
554 break;
555 xas_store(&xas, NULL);
556 if (start_index == last_index)
557 break;
558 start_index++;
559 }
560 xas_unlock(&xas);
561 }
562
clear_xarray(struct xarray * xa,unsigned long start_index,unsigned long last_index)563 static void clear_xarray(struct xarray *xa, unsigned long start_index,
564 unsigned long last_index)
565 {
566 XA_STATE(xas, xa, start_index);
567 void *entry;
568
569 xas_lock(&xas);
570 xas_for_each(&xas, entry, last_index)
571 xas_store(&xas, NULL);
572 xas_unlock(&xas);
573 }
574
/*
 * Store the pfn of each page in pages[] into the xarray as a value entry, at
 * consecutive indexes from start_index to last_index.
 *
 * Returns 0 on success or the error recorded in the xa_state. On error the
 * entries between start_index and the failing index are erased again before
 * returning.
 */
static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
			   unsigned long last_index, struct page **pages)
{
	struct page **end_pages = pages + (last_index - start_index) + 1;
	/* Middle of the range, the point where test fault injection may fire */
	struct page **half_pages = pages + (end_pages - pages) / 2;
	XA_STATE(xas, xa, start_index);

	/*
	 * Stores are done with the xarray lock held. When a store records an
	 * error the lock is dropped and xas_nomem() is called; the loop runs
	 * again for as long as xas_nomem() returns true, continuing from the
	 * page that failed.
	 */
	do {
		void *old;

		xas_lock(&xas);
		while (pages != end_pages) {
			/* xarray does not participate in fault injection */
			if (pages == half_pages && iommufd_should_fail()) {
				xas_set_err(&xas, -EINVAL);
				xas_unlock(&xas);
				/* aka xas_destroy() */
				xas_nomem(&xas, GFP_KERNEL);
				goto err_clear;
			}

			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
			if (xas_error(&xas))
				break;
			/* The slot is expected to have been empty */
			WARN_ON(old);
			pages++;
			xas_next(&xas);
		}
		xas_unlock(&xas);
	} while (xas_nomem(&xas, GFP_KERNEL));

err_clear:
	if (xas_error(&xas)) {
		/* Undo the stores made before the failing index, if any */
		if (xas.xa_index != start_index)
			clear_xarray(xa, start_index, xas.xa_index - 1);
		return xas_error(&xas);
	}
	return 0;
}
614
/*
 * Add the pfns of the first npages entries of pages[] to the batch, stopping
 * early when the batch is full.
 */
static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
			     size_t npages)
{
	size_t i;

	for (i = 0; i != npages; i++) {
		if (!batch_add_pfn(batch, page_to_pfn(pages[i])))
			return;
	}
}
624
/*
 * Unpin npages pages of the batch, starting first_page_off pages into it. The
 * pages are unpinned run by run, passing pages->writable to
 * unpin_user_page_range_dirty_lock(), and taken out of the npinned count.
 */
static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
			unsigned int first_page_off, size_t npages)
{
	unsigned int cur = 0;
	size_t to_unpin;

	/* Find the run containing the first page, and the offset inside it */
	while (first_page_off && batch->npfns[cur] <= first_page_off) {
		first_page_off -= batch->npfns[cur];
		cur++;
	}

	for (; npages; npages -= to_unpin) {
		to_unpin = min_t(size_t, npages,
				 batch->npfns[cur] - first_page_off);

		unpin_user_page_range_dirty_lock(
			pfn_to_page(batch->pfns[cur] + first_page_off),
			to_unpin, pages->writable);
		iopt_pages_sub_npinned(pages, to_unpin);

		/* Every run after the first is taken from its beginning */
		first_page_off = 0;
		cur++;
	}
}
650
/*
 * Copy length bytes between data and the page at byte offset, through a
 * temporary local mapping of the page. With IOMMUFD_ACCESS_RW_WRITE in flags
 * the copy goes from data into the page and the page is marked dirty,
 * otherwise the copy goes from the page into data.
 */
static void copy_data_page(struct page *page, void *data, unsigned long offset,
			   size_t length, unsigned int flags)
{
	void *kaddr = kmap_local_page(page);

	if (!(flags & IOMMUFD_ACCESS_RW_WRITE)) {
		memcpy(data, kaddr + offset, length);
	} else {
		memcpy(kaddr + offset, data, length);
		set_page_dirty_lock(page);
	}
	kunmap_local(kaddr);
}
665
batch_rw(struct pfn_batch * batch,void * data,unsigned long offset,unsigned long length,unsigned int flags)666 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
667 unsigned long offset, unsigned long length,
668 unsigned int flags)
669 {
670 unsigned long copied = 0;
671 unsigned int npage = 0;
672 unsigned int cur = 0;
673
674 while (cur < batch->end) {
675 unsigned long bytes = min(length, PAGE_SIZE - offset);
676
677 copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
678 offset, bytes, flags);
679 offset = 0;
680 length -= bytes;
681 data += bytes;
682 copied += bytes;
683 npage++;
684 if (npage == batch->npfns[cur]) {
685 npage = 0;
686 cur++;
687 }
688 if (!length)
689 break;
690 }
691 return copied;
692 }
693
/* pfn_reader_user is just the pin_user_pages() path */
struct pfn_reader_user {
	/* Temporary array the pin_user_pages functions fill with pinned pages */
	struct page **upages;
	/* Size of the upages allocation in bytes */
	size_t upages_len;
	/* Index of the first page held in upages after the last pin */
	unsigned long upages_start;
	/* One past the index of the last page held in upages */
	unsigned long upages_end;
	/* FOLL_LONGTERM, plus FOLL_WRITE when the iopt_pages is writable */
	unsigned int gup_flags;
	/*
	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
	 * neither
	 */
	int locked;
};
707
pfn_reader_user_init(struct pfn_reader_user * user,struct iopt_pages * pages)708 static void pfn_reader_user_init(struct pfn_reader_user *user,
709 struct iopt_pages *pages)
710 {
711 user->upages = NULL;
712 user->upages_start = 0;
713 user->upages_end = 0;
714 user->locked = -1;
715
716 user->gup_flags = FOLL_LONGTERM;
717 if (pages->writable)
718 user->gup_flags |= FOLL_WRITE;
719 }
720
pfn_reader_user_destroy(struct pfn_reader_user * user,struct iopt_pages * pages)721 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
722 struct iopt_pages *pages)
723 {
724 if (user->locked != -1) {
725 if (user->locked)
726 mmap_read_unlock(pages->source_mm);
727 if (pages->source_mm != current->mm)
728 mmput(pages->source_mm);
729 user->locked = -1;
730 }
731
732 kfree(user->upages);
733 user->upages = NULL;
734 }
735
pfn_reader_user_pin(struct pfn_reader_user * user,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)736 static int pfn_reader_user_pin(struct pfn_reader_user *user,
737 struct iopt_pages *pages,
738 unsigned long start_index,
739 unsigned long last_index)
740 {
741 bool remote_mm = pages->source_mm != current->mm;
742 unsigned long npages;
743 uintptr_t uptr;
744 long rc;
745
746 if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
747 WARN_ON(last_index < start_index))
748 return -EINVAL;
749
750 if (!user->upages) {
751 /* All undone in pfn_reader_destroy() */
752 user->upages_len =
753 (last_index - start_index + 1) * sizeof(*user->upages);
754 user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
755 if (!user->upages)
756 return -ENOMEM;
757 }
758
759 if (user->locked == -1) {
760 /*
761 * The majority of usages will run the map task within the mm
762 * providing the pages, so we can optimize into
763 * get_user_pages_fast()
764 */
765 if (remote_mm) {
766 if (!mmget_not_zero(pages->source_mm))
767 return -EFAULT;
768 }
769 user->locked = 0;
770 }
771
772 npages = min_t(unsigned long, last_index - start_index + 1,
773 user->upages_len / sizeof(*user->upages));
774
775
776 if (iommufd_should_fail())
777 return -EFAULT;
778
779 uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
780 if (!remote_mm)
781 rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
782 user->upages);
783 else {
784 if (!user->locked) {
785 mmap_read_lock(pages->source_mm);
786 user->locked = 1;
787 }
788 rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
789 user->gup_flags, user->upages,
790 &user->locked);
791 }
792 if (rc <= 0) {
793 if (WARN_ON(!rc))
794 return -EFAULT;
795 return rc;
796 }
797 iopt_pages_add_npinned(pages, rc);
798 user->upages_start = start_index;
799 user->upages_end = start_index + rc;
800 return 0;
801 }
802
803 /* This is the "modern" and faster accounting method used by io_uring */
incr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)804 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
805 {
806 unsigned long lock_limit;
807 unsigned long cur_pages;
808 unsigned long new_pages;
809
810 lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
811 PAGE_SHIFT;
812 do {
813 cur_pages = atomic_long_read(&pages->source_user->locked_vm);
814 new_pages = cur_pages + npages;
815 if (new_pages > lock_limit)
816 return -ENOMEM;
817 } while (atomic_long_cmpxchg(&pages->source_user->locked_vm, cur_pages,
818 new_pages) != cur_pages);
819 return 0;
820 }
821
decr_user_locked_vm(struct iopt_pages * pages,unsigned long npages)822 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
823 {
824 if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
825 return;
826 atomic_long_sub(npages, &pages->source_user->locked_vm);
827 }
828
829 /* This is the accounting method used for compatibility with VFIO */
update_mm_locked_vm(struct iopt_pages * pages,unsigned long npages,bool inc,struct pfn_reader_user * user)830 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
831 bool inc, struct pfn_reader_user *user)
832 {
833 bool do_put = false;
834 int rc;
835
836 if (user && user->locked) {
837 mmap_read_unlock(pages->source_mm);
838 user->locked = 0;
839 /* If we had the lock then we also have a get */
840 } else if ((!user || !user->upages) &&
841 pages->source_mm != current->mm) {
842 if (!mmget_not_zero(pages->source_mm))
843 return -EINVAL;
844 do_put = true;
845 }
846
847 mmap_write_lock(pages->source_mm);
848 rc = __account_locked_vm(pages->source_mm, npages, inc,
849 pages->source_task, false);
850 mmap_write_unlock(pages->source_mm);
851
852 if (do_put)
853 mmput(pages->source_mm);
854 return rc;
855 }
856
/*
 * Apply a change of npages pinned pages to the accounting selected by
 * pages->account_mode, then record the current npinned in last_npinned and
 * adjust the pinned_vm counter of pages->source_mm. Returns 0 on success or
 * the error from the accounting method, in which case last_npinned and
 * pinned_vm are not touched.
 */
static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
			    bool inc, struct pfn_reader_user *user)
{
	int rc = 0;

	/* IOPT_PAGES_ACCOUNT_NONE has nothing to charge */
	if (pages->account_mode == IOPT_PAGES_ACCOUNT_USER) {
		if (inc)
			rc = incr_user_locked_vm(pages, npages);
		else
			decr_user_locked_vm(pages, npages);
	} else if (pages->account_mode == IOPT_PAGES_ACCOUNT_MM) {
		rc = update_mm_locked_vm(pages, npages, inc, user);
	}
	if (rc)
		return rc;

	pages->last_npinned = pages->npinned;
	if (!inc)
		atomic64_sub(npages, &pages->source_mm->pinned_vm);
	else
		atomic64_add(npages, &pages->source_mm->pinned_vm);
	return 0;
}
885
update_unpinned(struct iopt_pages * pages)886 static void update_unpinned(struct iopt_pages *pages)
887 {
888 if (WARN_ON(pages->npinned > pages->last_npinned))
889 return;
890 if (pages->npinned == pages->last_npinned)
891 return;
892 do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
893 NULL);
894 }
895
896 /*
897 * Changes in the number of pages pinned is done after the pages have been read
898 * and processed. If the user lacked the limit then the error unwind will unpin
899 * everything that was just pinned. This is because it is expensive to calculate
900 * how many pages we have already pinned within a range to generate an accurate
901 * prediction in advance of doing the work to actually pin them.
902 */
pfn_reader_user_update_pinned(struct pfn_reader_user * user,struct iopt_pages * pages)903 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
904 struct iopt_pages *pages)
905 {
906 unsigned long npages;
907 bool inc;
908
909 lockdep_assert_held(&pages->mutex);
910
911 if (pages->npinned == pages->last_npinned)
912 return 0;
913
914 if (pages->npinned < pages->last_npinned) {
915 npages = pages->last_npinned - pages->npinned;
916 inc = false;
917 } else {
918 if (iommufd_should_fail())
919 return -ENOMEM;
920 npages = pages->npinned - pages->last_npinned;
921 inc = true;
922 }
923 return do_update_pinned(pages, npages, inc, user);
924 }
925
/*
 * PFNs are stored in three places, in order of preference:
 * - The iopt_pages xarray. This is only populated if there is a
 *   iopt_pages_access
 * - The iommu_domain under an area
 * - The original PFN source, ie pages->source_mm
 *
 * This iterator reads the pfns optimizing to load according to the
 * above order.
 */
struct pfn_reader {
	/* The iopt_pages being read */
	struct iopt_pages *pages;
	/* Position in the combined walk of the two interval trees */
	struct interval_tree_double_span_iter span;
	/* The pfns loaded so far for the current step */
	struct pfn_batch batch;
	/* Index of the first pfn held in batch */
	unsigned long batch_start_index;
	/* One past the index of the last pfn held in batch */
	unsigned long batch_end_index;
	/*
	 * NOTE(review): presumably the final index this reader will produce;
	 * its use is not visible in this part of the file - confirm.
	 */
	unsigned long last_index;

	/* State for the pin_user_pages() path */
	struct pfn_reader_user user;
};
946
pfn_reader_update_pinned(struct pfn_reader * pfns)947 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
948 {
949 return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
950 }
951
952 /*
953 * The batch can contain a mixture of pages that are still in use and pages that
954 * need to be unpinned. Unpin only pages that are not held anywhere else.
955 */
pfn_reader_unpin(struct pfn_reader * pfns)956 static void pfn_reader_unpin(struct pfn_reader *pfns)
957 {
958 unsigned long last = pfns->batch_end_index - 1;
959 unsigned long start = pfns->batch_start_index;
960 struct interval_tree_double_span_iter span;
961 struct iopt_pages *pages = pfns->pages;
962
963 lockdep_assert_held(&pages->mutex);
964
965 interval_tree_for_each_double_span(&span, &pages->access_itree,
966 &pages->domains_itree, start, last) {
967 if (span.is_used)
968 continue;
969
970 batch_unpin(&pfns->batch, pages, span.start_hole - start,
971 span.last_hole - span.start_hole + 1);
972 }
973 }
974
975 /* Process a single span to load it from the proper storage */
pfn_reader_fill_span(struct pfn_reader * pfns)976 static int pfn_reader_fill_span(struct pfn_reader *pfns)
977 {
978 struct interval_tree_double_span_iter *span = &pfns->span;
979 unsigned long start_index = pfns->batch_end_index;
980 struct iopt_area *area;
981 int rc;
982
983 if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
984 WARN_ON(span->last_used < start_index))
985 return -EINVAL;
986
987 if (span->is_used == 1) {
988 batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
989 start_index, span->last_used);
990 return 0;
991 }
992
993 if (span->is_used == 2) {
994 /*
995 * Pull as many pages from the first domain we find in the
996 * target span. If it is too small then we will be called again
997 * and we'll find another area.
998 */
999 area = iopt_pages_find_domain_area(pfns->pages, start_index);
1000 if (WARN_ON(!area))
1001 return -EINVAL;
1002
1003 /* The storage_domain cannot change without the pages mutex */
1004 batch_from_domain(
1005 &pfns->batch, area->storage_domain, area, start_index,
1006 min(iopt_area_last_index(area), span->last_used));
1007 return 0;
1008 }
1009
1010 if (start_index >= pfns->user.upages_end) {
1011 rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1012 span->last_hole);
1013 if (rc)
1014 return rc;
1015 }
1016
1017 batch_from_pages(&pfns->batch,
1018 pfns->user.upages +
1019 (start_index - pfns->user.upages_start),
1020 pfns->user.upages_end - start_index);
1021 return 0;
1022 }
1023
pfn_reader_done(struct pfn_reader * pfns)1024 static bool pfn_reader_done(struct pfn_reader *pfns)
1025 {
1026 return pfns->batch_start_index == pfns->last_index + 1;
1027 }
1028
/*
 * Advance the reader to the next batch. The previous batch contents are
 * cleared and spans are loaded into the batch until either a fill adds no
 * further PFNs or last_index has been reached. Returns 0 on success or a
 * negative errno.
 */
static int pfn_reader_next(struct pfn_reader *pfns)
{
	int rc;

	batch_clear(&pfns->batch);
	pfns->batch_start_index = pfns->batch_end_index;

	while (pfns->batch_end_index != pfns->last_index + 1) {
		/* Remember the count so we can tell if this pass added PFNs */
		unsigned int npfns = pfns->batch.total_pfns;

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
		    WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
			return -EINVAL;

		rc = pfn_reader_fill_span(pfns);
		if (rc)
			return rc;

		/* A successful fill must leave something in the batch */
		if (WARN_ON(!pfns->batch.total_pfns))
			return -EINVAL;

		pfns->batch_end_index =
			pfns->batch_start_index + pfns->batch.total_pfns;
		/* The whole span was consumed, move on to the next one */
		if (pfns->batch_end_index == pfns->span.last_used + 1)
			interval_tree_double_span_iter_next(&pfns->span);

		/* Batch is full */
		if (npfns == pfns->batch.total_pfns)
			return 0;
	}
	return 0;
}
1061
pfn_reader_init(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1062 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1063 unsigned long start_index, unsigned long last_index)
1064 {
1065 int rc;
1066
1067 lockdep_assert_held(&pages->mutex);
1068
1069 pfns->pages = pages;
1070 pfns->batch_start_index = start_index;
1071 pfns->batch_end_index = start_index;
1072 pfns->last_index = last_index;
1073 pfn_reader_user_init(&pfns->user, pages);
1074 rc = batch_init(&pfns->batch, last_index - start_index + 1);
1075 if (rc)
1076 return rc;
1077 interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1078 &pages->domains_itree, start_index,
1079 last_index);
1080 return 0;
1081 }
1082
1083 /*
1084 * There are many assertions regarding the state of pages->npinned vs
1085 * pages->last_pinned, for instance something like unmapping a domain must only
1086 * decrement the npinned, and pfn_reader_destroy() must be called only after all
1087 * the pins are updated. This is fine for success flows, but error flows
1088 * sometimes need to release the pins held inside the pfn_reader before going on
1089 * to complete unmapping and releasing pins held in domains.
1090 */
pfn_reader_release_pins(struct pfn_reader * pfns)1091 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1092 {
1093 struct iopt_pages *pages = pfns->pages;
1094
1095 if (pfns->user.upages_end > pfns->batch_end_index) {
1096 size_t npages = pfns->user.upages_end - pfns->batch_end_index;
1097
1098 /* Any pages not transferred to the batch are just unpinned */
1099 unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
1100 pfns->user.upages_start),
1101 npages);
1102 iopt_pages_sub_npinned(pages, npages);
1103 pfns->user.upages_end = pfns->batch_end_index;
1104 }
1105 if (pfns->batch_start_index != pfns->batch_end_index) {
1106 pfn_reader_unpin(pfns);
1107 pfns->batch_start_index = pfns->batch_end_index;
1108 }
1109 }
1110
pfn_reader_destroy(struct pfn_reader * pfns)1111 static void pfn_reader_destroy(struct pfn_reader *pfns)
1112 {
1113 struct iopt_pages *pages = pfns->pages;
1114
1115 pfn_reader_release_pins(pfns);
1116 pfn_reader_user_destroy(&pfns->user, pfns->pages);
1117 batch_destroy(&pfns->batch, NULL);
1118 WARN_ON(pages->last_npinned != pages->npinned);
1119 }
1120
pfn_reader_first(struct pfn_reader * pfns,struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1121 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1122 unsigned long start_index, unsigned long last_index)
1123 {
1124 int rc;
1125
1126 if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1127 WARN_ON(last_index < start_index))
1128 return -EINVAL;
1129
1130 rc = pfn_reader_init(pfns, pages, start_index, last_index);
1131 if (rc)
1132 return rc;
1133 rc = pfn_reader_next(pfns);
1134 if (rc) {
1135 pfn_reader_destroy(pfns);
1136 return rc;
1137 }
1138 return 0;
1139 }
1140
/*
 * Allocate an iopt_pages covering the page aligned region that contains
 * [uptr, uptr + length) in the current process. @writable is recorded in the
 * new object. References are taken on current's mm, thread group leader and
 * user, to be dropped in iopt_release_pages().
 *
 * Returns the new iopt_pages or an ERR_PTR: -EINVAL for a zero or too large
 * length, -EOVERFLOW if uptr + length wraps, -ENOMEM on allocation failure.
 */
struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
				    bool writable)
{
	struct iopt_pages *pages;
	unsigned long end;

	/*
	 * The iommu API takes a size_t length, and bounding it here also keeps
	 * the DIV_ROUND_UP below from overflowing.
	 */
	if (!length || length > SIZE_MAX - PAGE_SIZE)
		return ERR_PTR(-EINVAL);

	if (check_add_overflow((unsigned long)uptr, length, &end))
		return ERR_PTR(-EOVERFLOW);

	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
	if (!pages)
		return ERR_PTR(-ENOMEM);

	kref_init(&pages->kref);
	mutex_init(&pages->mutex);
	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
	pages->access_itree = RB_ROOT_CACHED;
	pages->domains_itree = RB_ROOT_CACHED;
	pages->writable = writable;

	/* Index 0 is the page containing uptr, count every page touched */
	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);

	pages->account_mode = capable(CAP_IPC_LOCK) ?
				      IOPT_PAGES_ACCOUNT_NONE :
				      IOPT_PAGES_ACCOUNT_USER;

	pages->source_mm = current->mm;
	mmgrab(pages->source_mm);
	pages->source_task = current->group_leader;
	get_task_struct(pages->source_task);
	pages->source_user = get_uid(current_user());
	return pages;
}
1180
iopt_release_pages(struct kref * kref)1181 void iopt_release_pages(struct kref *kref)
1182 {
1183 struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1184
1185 WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1186 WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1187 WARN_ON(pages->npinned);
1188 WARN_ON(!xa_empty(&pages->pinned_pfns));
1189 mmdrop(pages->source_mm);
1190 mutex_destroy(&pages->mutex);
1191 put_task_struct(pages->source_task);
1192 free_uid(pages->source_user);
1193 kfree(pages);
1194 }
1195
/*
 * Unmap and unpin the PFNs of the area in the inclusive index range
 * [start_index, last_index] using the domain as the PFN source.
 *
 * *unmapped_end_index records how far the domain has already been unmapped and
 * is advanced whenever more is unmapped here. real_last_index is the last
 * index of the caller's overall operation; PFNs may be read up to it in order
 * to find a place where an unmap is allowed to end.
 */
static void
iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
		       struct iopt_pages *pages, struct iommu_domain *domain,
		       unsigned long start_index, unsigned long last_index,
		       unsigned long *unmapped_end_index,
		       unsigned long real_last_index)
{
	while (start_index <= last_index) {
		unsigned long batch_last_index;

		if (*unmapped_end_index <= last_index) {
			/* Skip over anything already read and unmapped */
			unsigned long start =
				max(start_index, *unmapped_end_index);

			if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
			    batch->total_pfns)
				WARN_ON(*unmapped_end_index -
						batch->total_pfns !=
					start_index);
			batch_from_domain(batch, domain, area, start,
					  last_index);
			batch_last_index = start_index + batch->total_pfns - 1;
		} else {
			/* Everything up to last_index was already unmapped */
			batch_last_index = last_index;
		}

		if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
			WARN_ON(batch_last_index > real_last_index);

		/*
		 * unmaps must always 'cut' at a place where the pfns are not
		 * contiguous to pair with the maps that always install
		 * contiguous pages. Thus, if we have to stop unpinning in the
		 * middle of the domains we need to keep reading pfns until we
		 * find a cut point to do the unmap. The pfns we read are
		 * carried over and either skipped or integrated into the next
		 * batch.
		 */
		if (batch_last_index == last_index &&
		    last_index != real_last_index)
			batch_from_domain_continue(batch, domain, area,
						   last_index + 1,
						   real_last_index);

		if (*unmapped_end_index <= batch_last_index) {
			/* Unmap everything the batch now describes */
			iopt_area_unmap_domain_range(
				area, domain, *unmapped_end_index,
				start_index + batch->total_pfns - 1);
			*unmapped_end_index = start_index + batch->total_pfns;
		}

		/* unpin must follow unmap */
		batch_unpin(batch, pages, 0,
			    batch_last_index - start_index + 1);
		start_index = batch_last_index + 1;

		/* Keep the PFNs that were unmapped but not unpinned */
		batch_clear_carry(batch,
				  *unmapped_end_index - batch_last_index - 1);
	}
}
1256
/*
 * Unmap the area from the domain for the indexes from the start of the area up
 * to and including last_index, and unpin the PFNs in that range that are not
 * covered by the domains_itree or the access_itree. The pinned page accounting
 * is updated at the end. Must be called with pages->mutex held.
 */
static void __iopt_area_unfill_domain(struct iopt_area *area,
				      struct iopt_pages *pages,
				      struct iommu_domain *domain,
				      unsigned long last_index)
{
	struct interval_tree_double_span_iter span;
	unsigned long start_index = iopt_area_index(area);
	unsigned long unmapped_end_index = start_index;
	u64 backup[BATCH_BACKUP_SIZE];
	struct pfn_batch batch;

	lockdep_assert_held(&pages->mutex);

	/*
	 * For security we must not unpin something that is still DMA mapped,
	 * so this must unmap any IOVA before we go ahead and unpin the pages.
	 * This creates a complexity where we need to skip over unpinning pages
	 * held in the xarray, but continue to unmap from the domain.
	 *
	 * The domain unmap cannot stop in the middle of a contiguous range of
	 * PFNs. To solve this problem the unpinning step will read ahead to the
	 * end of any contiguous span, unmap that whole span, and then only
	 * unpin the leading part that does not have any accesses. The residual
	 * PFNs that were unmapped but not unpinned are called a "carry" in the
	 * batch as they are moved to the front of the PFN list and continue on
	 * to the next iteration(s).
	 */
	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
	interval_tree_for_each_double_span(&span, &pages->domains_itree,
					   &pages->access_itree, start_index,
					   last_index) {
		if (span.is_used) {
			/* Still held elsewhere, pass over without unpinning */
			batch_skip_carry(&batch,
					 span.last_used - span.start_used + 1);
			continue;
		}
		iopt_area_unpin_domain(&batch, area, pages, domain,
				       span.start_hole, span.last_hole,
				       &unmapped_end_index, last_index);
	}
	/*
	 * If the range ends in an access then we do the residual unmap without
	 * any unpins.
	 */
	if (unmapped_end_index != last_index + 1)
		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
					     last_index);
	WARN_ON(batch.total_pfns);
	batch_destroy(&batch, backup);
	update_unpinned(pages);
}
1308
/*
 * Undo a partially completed fill of the domain. end_index is one past the
 * last index that was mapped; nothing is done when it equals the first index
 * of the area, meaning nothing was mapped.
 */
static void iopt_area_unfill_partial_domain(struct iopt_area *area,
					    struct iopt_pages *pages,
					    struct iommu_domain *domain,
					    unsigned long end_index)
{
	if (end_index == iopt_area_index(area))
		return;

	__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
}
1317
/**
 * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
 * @area: The IOVA range to unmap
 * @domain: The domain to unmap
 *
 * The caller must know that unpinning is not required, usually because there
 * are other domains in the iopt.
 */
void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
{
	/* Unmap the area's whole IOVA range from the domain in one call */
	iommu_unmap_nofail(domain, iopt_area_iova(area),
			   iopt_area_length(area));
}
1331
/**
 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
 * @area: IOVA area to use
 * @pages: page supplier for the area (area->pages is NULL)
 * @domain: Domain to unmap from
 *
 * The caller should have removed the domain from the domains_itree already.
 * The whole area is always unmapped from the domain. PFNs that are still
 * covered by an access remain pinned.
 */
void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
			     struct iommu_domain *domain)
{
	unsigned long last_index = iopt_area_last_index(area);

	__iopt_area_unfill_domain(area, pages, domain, last_index);
}
1348
1349 /**
1350 * iopt_area_fill_domain() - Map PFNs from the area into a domain
1351 * @area: IOVA area to use
1352 * @domain: Domain to load PFNs into
1353 *
1354 * Read the pfns from the area's underlying iopt_pages and map them into the
1355 * given domain. Called when attaching a new domain to an io_pagetable.
1356 */
iopt_area_fill_domain(struct iopt_area * area,struct iommu_domain * domain)1357 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1358 {
1359 unsigned long done_end_index;
1360 struct pfn_reader pfns;
1361 int rc;
1362
1363 lockdep_assert_held(&area->pages->mutex);
1364
1365 rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1366 iopt_area_last_index(area));
1367 if (rc)
1368 return rc;
1369
1370 while (!pfn_reader_done(&pfns)) {
1371 done_end_index = pfns.batch_start_index;
1372 rc = batch_to_domain(&pfns.batch, domain, area,
1373 pfns.batch_start_index);
1374 if (rc)
1375 goto out_unmap;
1376 done_end_index = pfns.batch_end_index;
1377
1378 rc = pfn_reader_next(&pfns);
1379 if (rc)
1380 goto out_unmap;
1381 }
1382
1383 rc = pfn_reader_update_pinned(&pfns);
1384 if (rc)
1385 goto out_unmap;
1386 goto out_destroy;
1387
1388 out_unmap:
1389 pfn_reader_release_pins(&pfns);
1390 iopt_area_unfill_partial_domain(area, area->pages, domain,
1391 done_end_index);
1392 out_destroy:
1393 pfn_reader_destroy(&pfns);
1394 return rc;
1395 }
1396
/**
 * iopt_area_fill_domains() - Install PFNs into the area's domains
 * @area: The area to act on
 * @pages: The pages associated with the area (area->pages is NULL)
 *
 * Called during area creation. The area is freshly created and not inserted in
 * the domains_itree yet. PFNs are read and loaded into every domain held in the
 * area's io_pagetable and the area is installed in the domains_itree.
 *
 * The caller must hold the io_pagetable's domains_rwsem. pages->mutex is taken
 * and released internally.
 *
 * On failure all domains are left unchanged.
 *
 * Return: 0 on success (including when there are no domains), otherwise a
 * negative errno.
 */
int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
{
	unsigned long done_first_end_index;
	unsigned long done_all_end_index;
	struct iommu_domain *domain;
	unsigned long unmap_index;
	struct pfn_reader pfns;
	unsigned long index;
	int rc;

	lockdep_assert_held(&area->iopt->domains_rwsem);

	if (xa_empty(&area->iopt->domains))
		return 0;

	mutex_lock(&pages->mutex);
	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
			      iopt_area_last_index(area));
	if (rc)
		goto out_unlock;

	while (!pfn_reader_done(&pfns)) {
		/*
		 * While this batch is being installed, domains visited before
		 * a failure are mapped up to batch_end_index and the remaining
		 * ones only up to batch_start_index. The error unwind relies
		 * on these two values together with index.
		 */
		done_first_end_index = pfns.batch_end_index;
		done_all_end_index = pfns.batch_start_index;
		xa_for_each(&area->iopt->domains, index, domain) {
			rc = batch_to_domain(&pfns.batch, domain, area,
					     pfns.batch_start_index);
			if (rc)
				goto out_unmap;
		}
		/* Every domain now holds the batch */
		done_all_end_index = done_first_end_index;

		rc = pfn_reader_next(&pfns);
		if (rc)
			goto out_unmap;
	}
	rc = pfn_reader_update_pinned(&pfns);
	if (rc)
		goto out_unmap;

	/* The domain stored at index 0 is recorded as the storage domain */
	area->storage_domain = xa_load(&area->iopt->domains, 0);
	interval_tree_insert(&area->pages_node, &pages->domains_itree);
	goto out_destroy;

out_unmap:
	pfn_reader_release_pins(&pfns);
	xa_for_each(&area->iopt->domains, unmap_index, domain) {
		unsigned long end_index;

		/* Domains before the failing one got the whole last batch */
		if (unmap_index < index)
			end_index = done_first_end_index;
		else
			end_index = done_all_end_index;

		/*
		 * The area is not yet part of the domains_itree so we have to
		 * manage the unpinning specially. The last domain does the
		 * unpin, every other domain is just unmapped.
		 */
		if (unmap_index != area->iopt->next_domain_id - 1) {
			if (end_index != iopt_area_index(area))
				iopt_area_unmap_domain_range(
					area, domain, iopt_area_index(area),
					end_index - 1);
		} else {
			iopt_area_unfill_partial_domain(area, pages, domain,
							end_index);
		}
	}
out_destroy:
	pfn_reader_destroy(&pfns);
out_unlock:
	mutex_unlock(&pages->mutex);
	return rc;
}
1483
1484 /**
1485 * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1486 * @area: The area to act on
1487 * @pages: The pages associated with the area (area->pages is NULL)
1488 *
1489 * Called during area destruction. This unmaps the iova's covered by all the
1490 * area's domains and releases the PFNs.
1491 */
iopt_area_unfill_domains(struct iopt_area * area,struct iopt_pages * pages)1492 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1493 {
1494 struct io_pagetable *iopt = area->iopt;
1495 struct iommu_domain *domain;
1496 unsigned long index;
1497
1498 lockdep_assert_held(&iopt->domains_rwsem);
1499
1500 mutex_lock(&pages->mutex);
1501 if (!area->storage_domain)
1502 goto out_unlock;
1503
1504 xa_for_each(&iopt->domains, index, domain)
1505 if (domain != area->storage_domain)
1506 iopt_area_unmap_domain_range(
1507 area, domain, iopt_area_index(area),
1508 iopt_area_last_index(area));
1509
1510 interval_tree_remove(&area->pages_node, &pages->domains_itree);
1511 iopt_area_unfill_domain(area, pages, area->storage_domain);
1512 area->storage_domain = NULL;
1513 out_unlock:
1514 mutex_unlock(&pages->mutex);
1515 }
1516
iopt_pages_unpin_xarray(struct pfn_batch * batch,struct iopt_pages * pages,unsigned long start_index,unsigned long end_index)1517 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1518 struct iopt_pages *pages,
1519 unsigned long start_index,
1520 unsigned long end_index)
1521 {
1522 while (start_index <= end_index) {
1523 batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1524 end_index);
1525 batch_unpin(batch, pages, 0, batch->total_pfns);
1526 start_index += batch->total_pfns;
1527 batch_clear(batch);
1528 }
1529 }
1530
1531 /**
1532 * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1533 * @pages: The pages to act on
1534 * @start_index: Starting PFN index
1535 * @last_index: Last PFN index
1536 *
1537 * Called when an iopt_pages_access is removed, removes pages from the itree.
1538 * The access should already be removed from the access_itree.
1539 */
iopt_pages_unfill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index)1540 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1541 unsigned long start_index,
1542 unsigned long last_index)
1543 {
1544 struct interval_tree_double_span_iter span;
1545 u64 backup[BATCH_BACKUP_SIZE];
1546 struct pfn_batch batch;
1547 bool batch_inited = false;
1548
1549 lockdep_assert_held(&pages->mutex);
1550
1551 interval_tree_for_each_double_span(&span, &pages->access_itree,
1552 &pages->domains_itree, start_index,
1553 last_index) {
1554 if (!span.is_used) {
1555 if (!batch_inited) {
1556 batch_init_backup(&batch,
1557 last_index - start_index + 1,
1558 backup, sizeof(backup));
1559 batch_inited = true;
1560 }
1561 iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1562 span.last_hole);
1563 } else if (span.is_used == 2) {
1564 /* Covered by a domain */
1565 clear_xarray(&pages->pinned_pfns, span.start_used,
1566 span.last_used);
1567 }
1568 /* Otherwise covered by an existing access */
1569 }
1570 if (batch_inited)
1571 batch_destroy(&batch, backup);
1572 update_unpinned(pages);
1573 }
1574
1575 /**
1576 * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1577 * @pages: The pages to act on
1578 * @start_index: The first page index in the range
1579 * @last_index: The last page index in the range
1580 * @out_pages: The output array to return the pages
1581 *
1582 * This can be called if the caller is holding a refcount on an
1583 * iopt_pages_access that is known to have already been filled. It quickly reads
1584 * the pages directly from the xarray.
1585 *
1586 * This is part of the SW iommu interface to read pages for in-kernel use.
1587 */
iopt_pages_fill_from_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1588 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1589 unsigned long start_index,
1590 unsigned long last_index,
1591 struct page **out_pages)
1592 {
1593 XA_STATE(xas, &pages->pinned_pfns, start_index);
1594 void *entry;
1595
1596 rcu_read_lock();
1597 while (start_index <= last_index) {
1598 entry = xas_next(&xas);
1599 if (xas_retry(&xas, entry))
1600 continue;
1601 WARN_ON(!xa_is_value(entry));
1602 *(out_pages++) = pfn_to_page(xa_to_value(entry));
1603 start_index++;
1604 }
1605 rcu_read_unlock();
1606 }
1607
iopt_pages_fill_from_domain(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1608 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1609 unsigned long start_index,
1610 unsigned long last_index,
1611 struct page **out_pages)
1612 {
1613 while (start_index != last_index + 1) {
1614 unsigned long domain_last;
1615 struct iopt_area *area;
1616
1617 area = iopt_pages_find_domain_area(pages, start_index);
1618 if (WARN_ON(!area))
1619 return -EINVAL;
1620
1621 domain_last = min(iopt_area_last_index(area), last_index);
1622 out_pages = raw_pages_from_domain(area->storage_domain, area,
1623 start_index, domain_last,
1624 out_pages);
1625 start_index = domain_last + 1;
1626 }
1627 return 0;
1628 }
1629
iopt_pages_fill_from_mm(struct iopt_pages * pages,struct pfn_reader_user * user,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1630 static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1631 struct pfn_reader_user *user,
1632 unsigned long start_index,
1633 unsigned long last_index,
1634 struct page **out_pages)
1635 {
1636 unsigned long cur_index = start_index;
1637 int rc;
1638
1639 while (cur_index != last_index + 1) {
1640 user->upages = out_pages + (cur_index - start_index);
1641 rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1642 if (rc)
1643 goto out_unpin;
1644 cur_index = user->upages_end;
1645 }
1646 return 0;
1647
1648 out_unpin:
1649 if (start_index != cur_index)
1650 iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1651 out_pages);
1652 return rc;
1653 }
1654
1655 /**
1656 * iopt_pages_fill_xarray() - Read PFNs
1657 * @pages: The pages to act on
1658 * @start_index: The first page index in the range
1659 * @last_index: The last page index in the range
1660 * @out_pages: The output array to return the pages, may be NULL
1661 *
1662 * This populates the xarray and returns the pages in out_pages. As the slow
1663 * path this is able to copy pages from other storage tiers into the xarray.
1664 *
1665 * On failure the xarray is left unchanged.
1666 *
1667 * This is part of the SW iommu interface to read pages for in-kernel use.
1668 */
iopt_pages_fill_xarray(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,struct page ** out_pages)1669 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1670 unsigned long last_index, struct page **out_pages)
1671 {
1672 struct interval_tree_double_span_iter span;
1673 unsigned long xa_end = start_index;
1674 struct pfn_reader_user user;
1675 int rc;
1676
1677 lockdep_assert_held(&pages->mutex);
1678
1679 pfn_reader_user_init(&user, pages);
1680 user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1681 interval_tree_for_each_double_span(&span, &pages->access_itree,
1682 &pages->domains_itree, start_index,
1683 last_index) {
1684 struct page **cur_pages;
1685
1686 if (span.is_used == 1) {
1687 cur_pages = out_pages + (span.start_used - start_index);
1688 iopt_pages_fill_from_xarray(pages, span.start_used,
1689 span.last_used, cur_pages);
1690 continue;
1691 }
1692
1693 if (span.is_used == 2) {
1694 cur_pages = out_pages + (span.start_used - start_index);
1695 iopt_pages_fill_from_domain(pages, span.start_used,
1696 span.last_used, cur_pages);
1697 rc = pages_to_xarray(&pages->pinned_pfns,
1698 span.start_used, span.last_used,
1699 cur_pages);
1700 if (rc)
1701 goto out_clean_xa;
1702 xa_end = span.last_used + 1;
1703 continue;
1704 }
1705
1706 /* hole */
1707 cur_pages = out_pages + (span.start_hole - start_index);
1708 rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1709 span.last_hole, cur_pages);
1710 if (rc)
1711 goto out_clean_xa;
1712 rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1713 span.last_hole, cur_pages);
1714 if (rc) {
1715 iopt_pages_err_unpin(pages, span.start_hole,
1716 span.last_hole, cur_pages);
1717 goto out_clean_xa;
1718 }
1719 xa_end = span.last_hole + 1;
1720 }
1721 rc = pfn_reader_user_update_pinned(&user, pages);
1722 if (rc)
1723 goto out_clean_xa;
1724 user.upages = NULL;
1725 pfn_reader_user_destroy(&user, pages);
1726 return 0;
1727
1728 out_clean_xa:
1729 if (start_index != xa_end)
1730 iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1731 user.upages = NULL;
1732 pfn_reader_user_destroy(&user, pages);
1733 return rc;
1734 }
1735
1736 /*
1737 * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1738 * do every scenario and is fully consistent with what an iommu_domain would
1739 * see.
1740 */
iopt_pages_rw_slow(struct iopt_pages * pages,unsigned long start_index,unsigned long last_index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1741 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1742 unsigned long start_index,
1743 unsigned long last_index, unsigned long offset,
1744 void *data, unsigned long length,
1745 unsigned int flags)
1746 {
1747 struct pfn_reader pfns;
1748 int rc;
1749
1750 mutex_lock(&pages->mutex);
1751
1752 rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1753 if (rc)
1754 goto out_unlock;
1755
1756 while (!pfn_reader_done(&pfns)) {
1757 unsigned long done;
1758
1759 done = batch_rw(&pfns.batch, data, offset, length, flags);
1760 data += done;
1761 length -= done;
1762 offset = 0;
1763 pfn_reader_unpin(&pfns);
1764
1765 rc = pfn_reader_next(&pfns);
1766 if (rc)
1767 goto out_destroy;
1768 }
1769 if (WARN_ON(length != 0))
1770 rc = -EINVAL;
1771 out_destroy:
1772 pfn_reader_destroy(&pfns);
1773 out_unlock:
1774 mutex_unlock(&pages->mutex);
1775 return rc;
1776 }
1777
1778 /*
1779 * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1780 * memory allocations or interval tree searches.
1781 */
iopt_pages_rw_page(struct iopt_pages * pages,unsigned long index,unsigned long offset,void * data,unsigned long length,unsigned int flags)1782 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1783 unsigned long offset, void *data,
1784 unsigned long length, unsigned int flags)
1785 {
1786 struct page *page = NULL;
1787 int rc;
1788
1789 if (!mmget_not_zero(pages->source_mm))
1790 return iopt_pages_rw_slow(pages, index, index, offset, data,
1791 length, flags);
1792
1793 if (iommufd_should_fail()) {
1794 rc = -EINVAL;
1795 goto out_mmput;
1796 }
1797
1798 mmap_read_lock(pages->source_mm);
1799 rc = pin_user_pages_remote(
1800 pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1801 1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1802 NULL);
1803 mmap_read_unlock(pages->source_mm);
1804 if (rc != 1) {
1805 if (WARN_ON(rc >= 0))
1806 rc = -EINVAL;
1807 goto out_mmput;
1808 }
1809 copy_data_page(page, data, offset, length, flags);
1810 unpin_user_page(page);
1811 rc = 0;
1812
1813 out_mmput:
1814 mmput(pages->source_mm);
1815 return rc;
1816 }
1817
1818 /**
1819 * iopt_pages_rw_access - Copy to/from a linear slice of the pages
1820 * @pages: pages to act on
1821 * @start_byte: First byte of pages to copy to/from
1822 * @data: Kernel buffer to get/put the data
1823 * @length: Number of bytes to copy
1824 * @flags: IOMMUFD_ACCESS_RW_* flags
1825 *
1826 * This will find each page in the range, kmap it and then memcpy to/from
1827 * the given kernel buffer.
1828 */
iopt_pages_rw_access(struct iopt_pages * pages,unsigned long start_byte,void * data,unsigned long length,unsigned int flags)1829 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1830 void *data, unsigned long length, unsigned int flags)
1831 {
1832 unsigned long start_index = start_byte / PAGE_SIZE;
1833 unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1834 bool change_mm = current->mm != pages->source_mm;
1835 int rc = 0;
1836
1837 if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1838 (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1839 change_mm = true;
1840
1841 if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1842 return -EPERM;
1843
1844 if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1845 if (start_index == last_index)
1846 return iopt_pages_rw_page(pages, start_index,
1847 start_byte % PAGE_SIZE, data,
1848 length, flags);
1849 return iopt_pages_rw_slow(pages, start_index, last_index,
1850 start_byte % PAGE_SIZE, data, length,
1851 flags);
1852 }
1853
1854 /*
1855 * Try to copy using copy_to_user(). We do this as a fast path and
1856 * ignore any pinning inconsistencies, unlike a real DMA path.
1857 */
1858 if (change_mm) {
1859 if (!mmget_not_zero(pages->source_mm))
1860 return iopt_pages_rw_slow(pages, start_index,
1861 last_index,
1862 start_byte % PAGE_SIZE, data,
1863 length, flags);
1864 kthread_use_mm(pages->source_mm);
1865 }
1866
1867 if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1868 if (copy_to_user(pages->uptr + start_byte, data, length))
1869 rc = -EFAULT;
1870 } else {
1871 if (copy_from_user(data, pages->uptr + start_byte, length))
1872 rc = -EFAULT;
1873 }
1874
1875 if (change_mm) {
1876 kthread_unuse_mm(pages->source_mm);
1877 mmput(pages->source_mm);
1878 }
1879
1880 return rc;
1881 }
1882
1883 static struct iopt_pages_access *
iopt_pages_get_exact_access(struct iopt_pages * pages,unsigned long index,unsigned long last)1884 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1885 unsigned long last)
1886 {
1887 struct interval_tree_node *node;
1888
1889 lockdep_assert_held(&pages->mutex);
1890
1891 /* There can be overlapping ranges in this interval tree */
1892 for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1893 node; node = interval_tree_iter_next(node, index, last))
1894 if (node->start == index && node->last == last)
1895 return container_of(node, struct iopt_pages_access,
1896 node);
1897 return NULL;
1898 }
1899
1900 /**
1901 * iopt_area_add_access() - Record an in-knerel access for PFNs
1902 * @area: The source of PFNs
1903 * @start_index: First page index
1904 * @last_index: Inclusive last page index
1905 * @out_pages: Output list of struct page's representing the PFNs
1906 * @flags: IOMMUFD_ACCESS_RW_* flags
1907 *
1908 * Record that an in-kernel access will be accessing the pages, ensure they are
1909 * pinned, and return the PFNs as a simple list of 'struct page *'.
1910 *
1911 * This should be undone through a matching call to iopt_area_remove_access()
1912 */
iopt_area_add_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index,struct page ** out_pages,unsigned int flags)1913 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1914 unsigned long last_index, struct page **out_pages,
1915 unsigned int flags)
1916 {
1917 struct iopt_pages *pages = area->pages;
1918 struct iopt_pages_access *access;
1919 int rc;
1920
1921 if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1922 return -EPERM;
1923
1924 mutex_lock(&pages->mutex);
1925 access = iopt_pages_get_exact_access(pages, start_index, last_index);
1926 if (access) {
1927 area->num_accesses++;
1928 access->users++;
1929 iopt_pages_fill_from_xarray(pages, start_index, last_index,
1930 out_pages);
1931 mutex_unlock(&pages->mutex);
1932 return 0;
1933 }
1934
1935 access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1936 if (!access) {
1937 rc = -ENOMEM;
1938 goto err_unlock;
1939 }
1940
1941 rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1942 if (rc)
1943 goto err_free;
1944
1945 access->node.start = start_index;
1946 access->node.last = last_index;
1947 access->users = 1;
1948 area->num_accesses++;
1949 interval_tree_insert(&access->node, &pages->access_itree);
1950 mutex_unlock(&pages->mutex);
1951 return 0;
1952
1953 err_free:
1954 kfree(access);
1955 err_unlock:
1956 mutex_unlock(&pages->mutex);
1957 return rc;
1958 }
1959
1960 /**
1961 * iopt_area_remove_access() - Release an in-kernel access for PFNs
1962 * @area: The source of PFNs
1963 * @start_index: First page index
1964 * @last_index: Inclusive last page index
1965 *
1966 * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1967 * must stop using the PFNs before calling this.
1968 */
iopt_area_remove_access(struct iopt_area * area,unsigned long start_index,unsigned long last_index)1969 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1970 unsigned long last_index)
1971 {
1972 struct iopt_pages *pages = area->pages;
1973 struct iopt_pages_access *access;
1974
1975 mutex_lock(&pages->mutex);
1976 access = iopt_pages_get_exact_access(pages, start_index, last_index);
1977 if (WARN_ON(!access))
1978 goto out_unlock;
1979
1980 WARN_ON(area->num_accesses == 0 || access->users == 0);
1981 area->num_accesses--;
1982 access->users--;
1983 if (access->users)
1984 goto out_unlock;
1985
1986 interval_tree_remove(&access->node, &pages->access_itree);
1987 iopt_pages_unfill_xarray(pages, start_index, last_index);
1988 kfree(access);
1989 out_unlock:
1990 mutex_unlock(&pages->mutex);
1991 }
1992