virtio: Allocate virtqueue in page-size units

In preparation for explicit bouncing of virtqueue pages for devices
advertising the VIRTIO_F_IOMMU_PLATFORM feature, introduce a couple
of wrappers around virtqueue allocation and freeing operations,
ensuring that buffers are handled in terms of page-size units.

Signed-off-by: Will Deacon <willdeacon@google.com>
[ Paul: picked from the Android tree and rebased onto upstream ]
Signed-off-by: Ying-Chun Liu (PaulLiu) <paul.liu@linaro.org>
Cc: Bin Meng <bmeng.cn@gmail.com>
Link: https://android.googlesource.com/platform/external/u-boot/+/b4bb5227d4cf4fdfcd8b4e1ff2692d3a54d1482a
Reviewed-by: Simon Glass <sjg@chromium.org>
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index f71bab7..5aeb13f 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -15,6 +15,17 @@
 #include <virtio_ring.h>
 #include <linux/bug.h>
 #include <linux/compat.h>
+#include <linux/kernel.h>
+
+static void *virtio_alloc_pages(struct udevice *vdev, u32 npages)
+{
+	return memalign(PAGE_SIZE, npages * PAGE_SIZE);
+}
+
+static void virtio_free_pages(struct udevice *vdev, void *ptr, u32 npages)
+{
+	free(ptr);
+}
 
 static unsigned int virtqueue_attach_desc(struct virtqueue *vq, unsigned int i,
 					  struct virtio_sg *sg, u16 flags)
@@ -271,6 +282,8 @@
 					 unsigned int vring_align,
 					 struct udevice *udev)
 {
+	struct virtio_dev_priv *uc_priv = dev_get_uclass_priv(udev);
+	struct udevice *vdev = uc_priv->vdev;
 	struct virtqueue *vq;
 	void *queue = NULL;
 	struct vring vring;
@@ -283,7 +296,9 @@
 
 	/* TODO: allocate each queue chunk individually */
 	for (; num && vring_size(num, vring_align) > PAGE_SIZE; num /= 2) {
-		queue = memalign(PAGE_SIZE, vring_size(num, vring_align));
+		size_t sz = vring_size(num, vring_align);
+
+		queue = virtio_alloc_pages(vdev, DIV_ROUND_UP(sz, PAGE_SIZE));
 		if (queue)
 			break;
 	}
@@ -293,7 +308,7 @@
 
 	if (!queue) {
 		/* Try to get a single page. You are my only hope! */
-		queue = memalign(PAGE_SIZE, vring_size(num, vring_align));
+		queue = virtio_alloc_pages(vdev, 1);
 	}
 	if (!queue)
 		return NULL;
@@ -303,7 +318,7 @@
 
 	vq = __vring_new_virtqueue(index, vring, udev);
 	if (!vq) {
-		free(queue);
+		virtio_free_pages(vdev, queue, DIV_ROUND_UP(vring.size, PAGE_SIZE));
 		return NULL;
 	}
 	debug("(%s): created vring @ %p for vq @ %p with num %u\n", udev->name,
@@ -314,7 +329,8 @@
 
 void vring_del_virtqueue(struct virtqueue *vq)
 {
-	free(vq->vring.desc);
+	virtio_free_pages(vq->vdev, vq->vring.desc,
+			  DIV_ROUND_UP(vq->vring.size, PAGE_SIZE));
 	free(vq->vring_desc_shadow);
 	list_del(&vq->list);
 	free(vq);