Re: [PATCH] vfio/type1: Empty batch for pfnmap pages

From: Alex Williamson
Date: Thu Mar 25 2021 - 15:26:42 EST


On Wed, 24 Mar 2021 21:05:52 -0400
Daniel Jordan <daniel.m.jordan@xxxxxxxxxx> wrote:

> When vfio_pin_pages_remote() returns with a partial batch consisting of
> a single VM_PFNMAP pfn, a subsequent call will unfortunately try
> restoring it from batch->pages, resulting in vfio mapping the wrong page
> and unbalancing the page refcount.
>
> Prevent the function from returning with this kind of partial batch to
> avoid the issue. There's no explicit check for a VM_PFNMAP pfn because
> such a check is awkward to do here; instead, infer it from characteristics
> of the batch. This may result in occasional false positives but keeps the
> code simpler.
>
> Fixes: 4d83de6da265 ("vfio/type1: Batch page pinning")
> Link: https://lkml.kernel.org/r/20210323133254.33ed9161@xxxxxxxxxxxxxxxxxxxxx/
> Reported-by: Alex Williamson <alex.williamson@xxxxxxxxxx>
> Suggested-by: Alex Williamson <alex.williamson@xxxxxxxxxx>
> Signed-off-by: Daniel Jordan <daniel.m.jordan@xxxxxxxxxx>
> ---
>
> Alex, I couldn't immediately find a way to trigger this bug, but I can
> run your test case if you like.
>
> This is the minimal fix, but it should still protect all calls of
> vfio_batch_unpin() from this problem.

Thanks, applied to my for-linus branch for v5.12. The attached unit
test triggers the issue. I don't have any real-world examples; I was
only just experimenting with this for another series earlier this week.
Thanks,

Alex
/*
* Alternate pages of device memory and anonymous memory within a single DMA
* mapping.
*
* Run with argv[1] as a fully specified PCI device already bound to vfio-pci.
* ex. "alternate-pfnmap 0000:01:00.0"
*/
#include <errno.h>
#include <libgen.h>
#include <fcntl.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/eventfd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>

#include <linux/ioctl.h>
#include <linux/vfio.h>
#include <linux/pci_regs.h>

/*
 * Layout state shared by the mmap helpers. Pages are requested back to back
 * starting at vaddr (a non-MAP_FIXED address hint at 4GiB), and map_size
 * counts the bytes mapped so far, so the next page is requested at
 * vaddr + map_size.
 */
void *vaddr = (void *)(uintptr_t)0x100000000ULL;
size_t map_size;

/*
 * Open the VFIO container device node.
 *
 * Returns the new file descriptor, or open()'s negative return value
 * (after logging the errno reason) on failure.
 */
int get_container(void)
{
	int fd;

	fd = open("/dev/vfio/vfio", O_RDWR);
	if (fd >= 0)
		return fd;

	fprintf(stderr, "Failed to open /dev/vfio/vfio, %d (%s)\n",
		fd, strerror(errno));
	return fd;
}

/*
 * Resolve a PCI device name ("SSSS:BB:DD.F") to an open VFIO group fd.
 *
 * Steps:
 *  1. Read the device's /sys/bus/pci/devices/<dev>/iommu_group link.
 *  2. Parse the group number from the link target's basename.
 *  3. Open /dev/vfio/<group>.
 *  4. Check that VFIO reports the group as viable.
 *
 * Returns the group fd on success. On failure it returns a negative value:
 * -EINVAL for parse or lookup problems, or the failing call's return value
 * (-1) for stat()/open()/ioctl() errors. On every failure path after open(),
 * the group fd is closed.
 */
int get_group(char *name)
{
	unsigned int seg, bus, slot, func;
	int ret, group, groupid;
	char path[256], iommu_group_path[256], *group_name;
	struct stat st;
	ssize_t len;
	struct vfio_group_status group_status = {
		.argsz = sizeof(group_status)
	};

	ret = sscanf(name, "%04x:%02x:%02x.%u", &seg, &bus, &slot, &func);
	if (ret != 4) {
		fprintf(stderr, "Invalid device\n");
		return -EINVAL;
	}

	snprintf(path, sizeof(path),
		 "/sys/bus/pci/devices/%04x:%02x:%02x.%01x/",
		 seg, bus, slot, func);

	ret = stat(path, &st);
	if (ret < 0) {
		fprintf(stderr, "No such device\n");
		return ret;
	}

	strncat(path, "iommu_group", sizeof(path) - strlen(path) - 1);

	/* readlink() does not NUL-terminate; reserve a byte for the terminator. */
	len = readlink(path, iommu_group_path, sizeof(iommu_group_path) - 1);
	if (len <= 0) {
		fprintf(stderr, "No iommu_group for device\n");
		return -EINVAL;
	}

	iommu_group_path[len] = '\0';
	group_name = basename(iommu_group_path);

	if (sscanf(group_name, "%d", &groupid) != 1) {
		fprintf(stderr, "Unknown group\n");
		return -EINVAL;
	}

	snprintf(path, sizeof(path), "/dev/vfio/%d", groupid);
	group = open(path, O_RDWR);
	if (group < 0) {
		fprintf(stderr, "Failed to open %s, %d (%s)\n",
			path, group, strerror(errno));
		return group;
	}

	ret = ioctl(group, VFIO_GROUP_GET_STATUS, &group_status);
	if (ret) {
		fprintf(stderr, "ioctl(VFIO_GROUP_GET_STATUS) failed\n");
		close(group);
		return ret;
	}

	if (!(group_status.flags & VFIO_GROUP_FLAGS_VIABLE)) {
		fprintf(stderr,
			"Group not viable, all devices attached to vfio?\n");
		close(group);
		return -1;
	}

	return group;
}

/*
 * Attach @group to @container via VFIO_GROUP_SET_CONTAINER.
 *
 * Returns 0 on success; otherwise logs and returns the ioctl's nonzero result.
 */
int group_set_container(int group, int container)
{
	int err = ioctl(group, VFIO_GROUP_SET_CONTAINER, &container);

	if (!err)
		return 0;

	fprintf(stderr, "Failed to set group container\n");
	return err;
}

/*
 * Select the type1 IOMMU backend for @container via VFIO_SET_IOMMU.
 *
 * Returns 0 on success; otherwise logs and returns the ioctl's nonzero result.
 */
int container_set_iommu(int container)
{
	int err = ioctl(container, VFIO_SET_IOMMU, VFIO_TYPE1_IOMMU);

	if (!err)
		return 0;

	fprintf(stderr, "Failed to set IOMMU\n");
	return err;
}

/*
 * Obtain a device fd for the device @name (e.g. "0000:01:00.0") in @group
 * via VFIO_GROUP_GET_DEVICE_FD.
 *
 * Returns the device fd, or a negative value (after logging) on failure.
 */
int group_get_device(int group, char *name)
{
	int fd = ioctl(group, VFIO_GROUP_GET_DEVICE_FD, name);

	if (fd >= 0)
		return fd;

	fprintf(stderr, "Failed to get device\n");
	return fd;
}

/*
 * mmap the first page of the first usable memory BAR of @device at the next
 * contiguous slot (vaddr + map_size), with protection @prot.
 *
 * BARs are scanned in order from BAR0:
 *  - I/O BARs are skipped.
 *  - The upper dword of a 64-bit memory BAR is skipped.
 *  - A memory BAR is passed over in favour of the next one if it lacks mmap
 *    support, is smaller than a page, or fails to mmap.
 *
 * On success, the mapping is returned and map_size grows by one page.
 *
 * Returns MAP_FAILED in any of these cases:
 *  - the config space or BAR region info cannot be read;
 *  - no usable BAR is found;
 *  - the kernel did not place the page at the requested contiguous address
 *    (the stray mapping is unmapped first).
 */
void *mmap_device_page(int device, int prot)
{
	struct vfio_region_info config_info = {
		.argsz = sizeof(config_info),
		.index = VFIO_PCI_CONFIG_REGION_INDEX
	};
	int i, ret, step;

	ret = ioctl(device, VFIO_DEVICE_GET_REGION_INFO, &config_info);
	if (ret) {
		fprintf(stderr, "Failed to get config space region info\n");
		return MAP_FAILED;
	}

	for (i = 0; i < 6; i += step) {
		struct vfio_region_info region_info = {
			.argsz = sizeof(region_info)
		};
		unsigned int bar;
		void *map;

		step = 1;

		if (pread(device, &bar, sizeof(bar), config_info.offset +
			  PCI_BASE_ADDRESS_0 + (4 * i)) != sizeof(bar)) {
			fprintf(stderr, "Error reading BAR%d\n", i);
			return MAP_FAILED;
		}

		/* I/O port BARs cannot be mmap'd. */
		if (bar & PCI_BASE_ADDRESS_SPACE)
			continue;

		/* A 64-bit memory BAR also occupies the next BAR register. */
		if (bar & PCI_BASE_ADDRESS_MEM_TYPE_64)
			step = 2;

		region_info.index = VFIO_PCI_BAR0_REGION_INDEX + i;
		ret = ioctl(device, VFIO_DEVICE_GET_REGION_INFO, &region_info);
		if (ret) {
			fprintf(stderr, "Failed to get BAR%d region info\n", i);
			return MAP_FAILED;
		}

		if (!(region_info.flags & VFIO_REGION_INFO_FLAG_MMAP)) {
			printf("No mmap support, try next\n");
			continue;
		}

		if (region_info.size < getpagesize()) {
			printf("Too small for mmap, try next\n");
			continue;
		}

		map = mmap((char *)vaddr + map_size, getpagesize(), prot,
			   MAP_SHARED, device, region_info.offset);
		if (map == MAP_FAILED) {
			fprintf(stderr, "Error mmap'ing BAR: %m\n");
			continue;
		}

		fprintf(stderr, "\t\tmmap_device_page @0x%016llx\n",
			(unsigned long long)(uintptr_t)map);
		if (!vaddr) {
			vaddr = map;
		} else if (map != (char *)vaddr + map_size) {
			/* The address is only a hint; reject a non-adjacent page. */
			fprintf(stderr, "Did not get contiguous mmap\n");
			munmap(map, getpagesize());
			return MAP_FAILED;
		}

		map_size += getpagesize();

		return map;
	}

	fprintf(stderr, "No memory BARs found\n");
	return MAP_FAILED;
}

/*
 * Map one private anonymous page with protection @prot at the next
 * contiguous slot (vaddr + map_size).
 *
 * On success, the mapping is returned and map_size grows by one page.
 * Returns MAP_FAILED if mmap() fails or if the kernel placed the page
 * somewhere other than the requested address; in the latter case the stray
 * mapping is unmapped first.
 */
void *mmap_mem_page(int prot)
{
	void *want = (char *)vaddr + map_size;
	/* Portable code passes fd -1 with MAP_ANONYMOUS. */
	void *map = mmap(want, getpagesize(), prot,
			 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);

	if (map == MAP_FAILED) {
		fprintf(stderr, "Map anonymous page failed: %m\n");
		return map;
	}

	fprintf(stderr, "\t\tmmap_mem_page @0x%016llx\n",
		(unsigned long long)(uintptr_t)map);
	if (!vaddr) {
		vaddr = map;
	} else if (map != want) {
		/* The address is only a hint; reject a non-adjacent page. */
		fprintf(stderr, "Did not get contiguous mmap\n");
		munmap(map, getpagesize());
		return MAP_FAILED;
	}

	map_size += getpagesize();

	return map;
}

/*
 * DMA map @size bytes of process memory starting at @map to IOVA @iova in
 * @container, with both device read and write permission.
 *
 * Returns 0 on success, or the failing ioctl()'s return value (after logging).
 */
int dma_map(int container, void *map, int size, unsigned long iova)
{
	struct vfio_iommu_type1_dma_map req = {
		.argsz = sizeof(req),
		.size = size,
		/*
		 * Cast via uintptr_t so that on 32-bit builds the address is
		 * zero-extended rather than sign-extended into the __u64.
		 */
		.vaddr = (__u64)(uintptr_t)map,
		.iova = iova,
		.flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE
	};
	int ret;

	ret = ioctl(container, VFIO_IOMMU_MAP_DMA, &req);
	if (ret)
		fprintf(stderr, "Failed to DMA map: %m\n");

	return ret;
}

/*
 * DMA unmap @size bytes at IOVA @iova from @container.
 *
 * Returns the size field read back from the request after a successful
 * VFIO_IOMMU_UNMAP_DMA. On failure, returns the ioctl's negative result
 * (after logging), so that a failure cannot be mistaken for a successful
 * unmap of @size bytes.
 *
 * NOTE(review): the returned __u64 size is narrowed to int.
 */
int dma_unmap(int container, int size, unsigned long iova)
{
	struct vfio_iommu_type1_dma_unmap req = {
		.argsz = sizeof(req),
		.iova = iova,
		.size = size,
	};
	int ret;

	ret = ioctl(container, VFIO_IOMMU_UNMAP_DMA, &req);
	if (ret) {
		fprintf(stderr, "Failed to DMA unmap: %m\n");
		return ret;
	}

	return req.size;
}

/*
 * Build a four-page virtual range that alternates device BAR pages with
 * anonymous memory pages. The order is device, memory, device, memory.
 * Then DMA map the whole range with a single VFIO_IOMMU_MAP_DMA call at
 * IOVA 1GiB.
 *
 * argv[1] must name a PCI device already bound to vfio-pci.
 *
 * Returns 0 once the DMA mapping succeeds, otherwise a negative value. The
 * process exit status is that value truncated to 8 bits.
 */
int main(int argc, char **argv)
{
	int container1;
	int group1;
	int device1;
	void *map, *map_base;

	/* argv[1] is used unconditionally below; reject a missing argument. */
	if (argc < 2 || !argv[1]) {
		fprintf(stderr, "Usage: %s SSSS:BB:DD.F\n",
			(argc > 0 && argv[0]) ? argv[0] : "alternate-pfnmap");
		return -EINVAL;
	}

	group1 = get_group(argv[1]);
	if (group1 < 0) {
		fprintf(stderr, "Failed to get group for %s\n", argv[1]);
		return group1;
	}

	fprintf(stderr, "\tGot group for %s\n", argv[1]);

	container1 = get_container();

	if (container1 < 0) {
		fprintf(stderr, "Failed to get container\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot container\n");

	if (group_set_container(group1, container1)) {
		fprintf(stderr, "Failed to set container\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tAttached group to container\n");

	if (container_set_iommu(container1)) {
		fprintf(stderr, "Failed to set iommu\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tSet IOMMU model for container\n");

	device1 = group_get_device(group1, argv[1]);

	if (device1 < 0) {
		fprintf(stderr, "Failed to get devices\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot device file descriptors\n");

	/*
	 * Each mmap helper appends one page directly after the previous one,
	 * so these four calls build a single contiguous virtual range.
	 */
	map = mmap_device_page(device1, PROT_READ | PROT_WRITE);
	if (map == MAP_FAILED) {
		fprintf(stderr, "Failed to mmap device page\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot mmap to device %s\n", argv[1]);

	/* The DMA mapping below starts at the first page. */
	map_base = map;

	map = mmap_mem_page(PROT_READ | PROT_WRITE);
	if (map == MAP_FAILED) {
		fprintf(stderr, "Failed to mmap memory page\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot memory page\n");

	map = mmap_device_page(device1, PROT_READ | PROT_WRITE);
	if (map == MAP_FAILED) {
		fprintf(stderr, "Failed to mmap device page\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot mmap to device %s\n", argv[1]);

	map = mmap_mem_page(PROT_READ | PROT_WRITE);
	if (map == MAP_FAILED) {
		fprintf(stderr, "Failed to mmap memory page\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot memory page\n");

	/* One DMA mapping covering all four pages, at IOVA 1GiB. */
	if (dma_map(container1, map_base, getpagesize() * 4,
		    1024 * 1024 * 1024)) {
		fprintf(stderr, "Failed to DMA map pages\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tDMA mapped pages into container for device %s\n",
		argv[1]);

	return 0;
}