/*	$NetBSD: viomb.c,v 1.17 2023/03/25 11:04:34 mlelstv Exp $	*/

/*
 * Copyright (c) 2010 Minoura Makoto.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD: viomb.c,v 1.17 2023/03/25 11:04:34 mlelstv Exp $");

#include <sys/param.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <sys/bus.h>
#include <sys/condvar.h>
#include <sys/device.h>
#include <sys/kthread.h>
#include <sys/mutex.h>
#include <sys/sysctl.h>
#include <uvm/uvm_page.h>
#include <sys/module.h>

#include <dev/pci/virtioreg.h>
#include <dev/pci/virtiovar.h>

#include "ioconf.h"

/* Configuration registers */
#define VIRTIO_BALLOON_CONFIG_NUM_PAGES	0 /* 32bit */
#define VIRTIO_BALLOON_CONFIG_ACTUAL	4 /* 32bit */

/* Feature bits */
#define VIRTIO_BALLOON_F_MUST_TELL_HOST (1<<0)
#define VIRTIO_BALLOON_F_STATS_VQ	(1<<1)

/* snprintb(3)-style bit description for feature reporting */
#define VIRTIO_BALLOON_FLAG_BITS		\
	VIRTIO_COMMON_FLAG_BITS			\
	"b\x01" "STATS_VQ\0"			\
	"b\x00" "MUST_TELL_HOST\0"

#define PGS_PER_REQ		(256) /* 1MB, 4KB/page */
/* indices into sc_vq[] */
#define VQ_INFLATE	0
#define VQ_DEFLATE	1


/* the driver assumes the VM page size matches virtio's page unit */
CTASSERT((PAGE_SIZE) == (VIRTIO_PAGE_SIZE)); /* XXX */

/*
 * A single in-flight balloon request: up to PGS_PER_REQ page frame
 * numbers handed to the host in one virtqueue transfer.
 */
struct balloon_req {
	bus_dmamap_t			bl_dmamap;	/* maps bl_pages[] for DMA */
	struct pglist			bl_pglist;	/* the pages being moved */
	int				bl_nentries;	/* entries used in bl_pages[] */
	uint32_t			bl_pages[PGS_PER_REQ]; /* little-endian PFNs */
};

/* Per-device balloon state. */
struct viomb_softc {
	device_t		sc_dev;

	struct virtio_softc	*sc_virtio;	/* parent virtio device */
	struct virtqueue	sc_vq[2];	/* VQ_INFLATE / VQ_DEFLATE */

	unsigned int		sc_npages;	/* host's target page count */
	unsigned int		sc_actual;	/* pages currently ballooned */
	int			sc_inflight;	/* pages in the outstanding request */
	struct balloon_req	sc_req;		/* the single outstanding request */
	struct pglist		sc_balloon_pages; /* pages held by the balloon */

	int			sc_inflate_done; /* completion flags, set at */
	int			sc_deflate_done; /* interrupt, read by thread */

	kcondvar_t		sc_wait;	/* worker thread sleeps here */
	kmutex_t		sc_waitlock;	/* protects the flags and sc_wait */
};

static int	balloon_initialized = 0; /* multiple balloon is not allowed */

static int	viomb_match(device_t, cfdata_t, void *);
static void	viomb_attach(device_t, device_t, void *);
static void	viomb_read_config(struct viomb_softc *);
static int	viomb_config_change(struct virtio_softc *);
static int	inflate(struct viomb_softc *);
static int	inflateq_done(struct virtqueue *);
static int	inflate_done(struct viomb_softc *);
static int	deflate(struct viomb_softc *);
static int	deflateq_done(struct virtqueue *);
static int	deflate_done(struct viomb_softc *);
static void	viomb_thread(void *);

CFATTACH_DECL_NEW(viomb, sizeof(struct viomb_softc),
    viomb_match, viomb_attach, NULL, NULL);

/*
 * Autoconf match: accept only the virtio memory balloon device id.
 */
static int
viomb_match(device_t parent, cfdata_t match, void *aux)
{
	const struct virtio_attach_args *va = aux;

	return (va->sc_childdevid == VIRTIO_DEVICE_ID_BALLOON) ? 1 : 0;
}

/*
 * Attach: negotiate features, allocate the inflate/deflate virtqueues
 * and the DMA map that carries page-frame lists to the host, start the
 * worker thread, and publish npages/actual via sysctl under hw.viomb.
 */
static void
viomb_attach(device_t parent, device_t self, void *aux)
{
	struct viomb_softc *sc = device_private(self);
	struct virtio_softc *vsc = device_private(parent);
	const struct sysctlnode *node;
	uint64_t features;

	if (virtio_child(vsc) != NULL) {
		aprint_normal(": child already attached for %s; "
			      "something wrong...\n", device_xname(parent));
		return;
	}

	/* only a single balloon instance is supported */
	if (balloon_initialized++) {
		aprint_normal(": balloon already exists; something wrong...\n");
		return;
	}

	/* fail on non-4K page size archs */
	if (VIRTIO_PAGE_SIZE != PAGE_SIZE){
		aprint_normal("non-4K page size arch found, needs %d, got %d\n",
		    VIRTIO_PAGE_SIZE, PAGE_SIZE);
		return;
	}

	sc->sc_dev = self;
	sc->sc_virtio = vsc;

	/* request MUST_TELL_HOST; STATS_VQ is not implemented here */
	virtio_child_attach_start(vsc, self, IPL_VM,
	    VIRTIO_BALLOON_F_MUST_TELL_HOST, VIRTIO_BALLOON_FLAG_BITS);

	features = virtio_features(vsc);
	if (features == 0)
		goto err_none;

	viomb_read_config(sc);
	sc->sc_inflight = 0;
	TAILQ_INIT(&sc->sc_balloon_pages);

	sc->sc_inflate_done = sc->sc_deflate_done = 0;
	mutex_init(&sc->sc_waitlock, MUTEX_DEFAULT, IPL_VM); /* spin */
	cv_init(&sc->sc_wait, "balloon");

	virtio_init_vq_vqdone(vsc, &sc->sc_vq[VQ_INFLATE], VQ_INFLATE,
	    inflateq_done);
	virtio_init_vq_vqdone(vsc, &sc->sc_vq[VQ_DEFLATE], VQ_DEFLATE,
	    deflateq_done);

	/* each queue carries one segment of up to PGS_PER_REQ PFNs */
	if (virtio_alloc_vq(vsc, &sc->sc_vq[VQ_INFLATE],
			     sizeof(uint32_t)*PGS_PER_REQ, 1,
			     "inflate") != 0)
		goto err_mutex;
	if (virtio_alloc_vq(vsc, &sc->sc_vq[VQ_DEFLATE],
			     sizeof(uint32_t)*PGS_PER_REQ, 1,
			     "deflate") != 0)
		goto err_vq0;

	/* one DMA map covering bl_pages[], reused for every request */
	if (bus_dmamap_create(virtio_dmat(vsc), sizeof(uint32_t)*PGS_PER_REQ,
			      1, sizeof(uint32_t)*PGS_PER_REQ, 0,
			      BUS_DMA_NOWAIT, &sc->sc_req.bl_dmamap)) {
		aprint_error_dev(sc->sc_dev, "dmamap creation failed.\n");
		goto err_vq;
	}
	if (bus_dmamap_load(virtio_dmat(vsc), sc->sc_req.bl_dmamap,
			    &sc->sc_req.bl_pages[0],
			    sizeof(uint32_t) * PGS_PER_REQ,
			    NULL, BUS_DMA_NOWAIT)) {
		aprint_error_dev(sc->sc_dev, "dmamap load failed.\n");
		goto err_dmamap;
	}

	if (virtio_child_attach_finish(vsc, sc->sc_vq, __arraycount(sc->sc_vq),
	    viomb_config_change, VIRTIO_F_INTR_MPSAFE) != 0)
		goto err_out;

	if (kthread_create(PRI_IDLE, KTHREAD_MPSAFE, NULL,
			   viomb_thread, sc, NULL, "viomb")) {
		aprint_error_dev(sc->sc_dev, "cannot create kthread.\n");
		goto err_out;
	}

	sysctl_createv(NULL, 0, NULL, &node, 0, CTLTYPE_NODE,
		       "viomb", SYSCTL_DESCR("VirtIO Balloon status"),
		       NULL, 0, NULL, 0,
		       CTL_HW, CTL_CREATE, CTL_EOL);
	sysctl_createv(NULL, 0, NULL, NULL, 0, CTLTYPE_INT,
		       "npages", SYSCTL_DESCR("VirtIO Balloon npages value"),
		       NULL, 0, &sc->sc_npages, 0,
		       CTL_HW, node->sysctl_num, CTL_CREATE, CTL_EOL);
	sysctl_createv(NULL, 0, NULL, NULL, 0, CTLTYPE_INT,
		       "actual", SYSCTL_DESCR("VirtIO Balloon actual value"),
		       NULL, 0, &sc->sc_actual, 0,
		       CTL_HW, node->sysctl_num, CTL_CREATE, CTL_EOL);
	return;

err_out:
	/*
	 * NOTE(review): the dmamap is still loaded on this path, and we
	 * fall through to destroy it without bus_dmamap_unload() —
	 * confirm whether an explicit unload is required here.
	 */
err_dmamap:
	bus_dmamap_destroy(virtio_dmat(vsc), sc->sc_req.bl_dmamap);
err_vq:
	virtio_free_vq(vsc, &sc->sc_vq[VQ_DEFLATE]);
err_vq0:
	virtio_free_vq(vsc, &sc->sc_vq[VQ_INFLATE]);
err_mutex:
	cv_destroy(&sc->sc_wait);
	mutex_destroy(&sc->sc_waitlock);
err_none:
	virtio_child_attach_failed(vsc);
	return;
}

/*
 * Refresh the cached configuration registers: the host's requested
 * page count (npages) and the currently acknowledged count (actual).
 */
static void
viomb_read_config(struct viomb_softc *sc)
{
	struct virtio_softc *vsc = sc->sc_virtio;

	/* both registers are explicitly specified as little-endian */
	sc->sc_npages = virtio_read_device_config_le_4(vsc,
	    VIRTIO_BALLOON_CONFIG_NUM_PAGES);
	sc->sc_actual = virtio_read_device_config_le_4(vsc,
	    VIRTIO_BALLOON_CONFIG_ACTUAL);
}

/*
 * Config change callback: wakeup the kthread.
 */
static int
viomb_config_change(struct viomb_softc *vsc)
{
	struct viomb_softc *sc = device_private(virtio_child(vsc));
	unsigned int old;

	/* remember the previous target so we can log the direction */
	old = sc->sc_npages;
	viomb_read_config(sc);
	/* the worker thread does the actual inflating/deflating */
	mutex_enter(&sc->sc_waitlock);
	cv_signal(&sc->sc_wait);
	mutex_exit(&sc->sc_waitlock);
	if (sc->sc_npages > old)
		printf("%s: inflating balloon from %u to %u.\n",
		       device_xname(sc->sc_dev), old, sc->sc_npages);
	else if  (sc->sc_npages < old)
		printf("%s: deflating balloon from %u to %u.\n",
		       device_xname(sc->sc_dev), old, sc->sc_npages);

	return 1;
}

/*
 * Inflate: consume some amount of physical memory.
 *
 * Allocates up to PGS_PER_REQ pages, fills the request with their
 * page frame numbers and hands it to the host via the inflate queue.
 * Returns 1 if the caller should back off (allocation failed),
 * 0 otherwise.
 */
static int
inflate(struct viomb_softc *sc)
{
	struct virtio_softc *vsc = sc->sc_virtio;
	int i, slot;
	uint64_t nvpages, nhpages;
	struct balloon_req *b;
	struct vm_page *p;
	struct virtqueue *vq = &sc->sc_vq[VQ_INFLATE];

	/* only one request may be outstanding at a time */
	if (sc->sc_inflight)
		return 0;
	/* how far we are below the host's target, capped per request */
	nvpages = sc->sc_npages - sc->sc_actual;
	if (nvpages > PGS_PER_REQ)
		nvpages = PGS_PER_REQ;
	/* virtio pages and VM pages are the same size (CTASSERT above) */
	nhpages = nvpages * VIRTIO_PAGE_SIZE / PAGE_SIZE;

	b = &sc->sc_req;
	/* upper bound keeps every page frame number within 32 bits */
	if (uvm_pglistalloc(nhpages*PAGE_SIZE, 0, UINT32_MAX*(paddr_t)PAGE_SIZE,
			    0, 0, &b->bl_pglist, nhpages, 0)) {
		printf("%s: %" PRIu64 " pages of physical memory "
		       "could not be allocated, retrying...\n",
		       device_xname(sc->sc_dev), nhpages);
		return 1;	/* sleep longer */
	}

	/* fill the request with little-endian page frame numbers */
	b->bl_nentries = nvpages;
	i = 0;
	TAILQ_FOREACH(p, &b->bl_pglist, pageq.queue) {
		b->bl_pages[i++] =
			htole32(VM_PAGE_TO_PHYS(p) / VIRTIO_PAGE_SIZE);
	}
	KASSERT(i == nvpages);

	if (virtio_enqueue_prep(vsc, vq, &slot) != 0) {
		printf("%s: inflate enqueue failed.\n",
		       device_xname(sc->sc_dev));
		uvm_pglistfree(&b->bl_pglist);
		return 0;
	}
	if (virtio_enqueue_reserve(vsc, vq, slot, 1)) {
		printf("%s: inflate enqueue failed.\n",
		       device_xname(sc->sc_dev));
		uvm_pglistfree(&b->bl_pglist);
		return 0;
	}
	/* hand the PFN array to the host; completion wakes the thread */
	bus_dmamap_sync(virtio_dmat(vsc), b->bl_dmamap, 0,
	    sizeof(uint32_t)*nvpages, BUS_DMASYNC_PREWRITE);
	virtio_enqueue(vsc, vq, slot, b->bl_dmamap, true);
	virtio_enqueue_commit(vsc, vq, slot, true);
	sc->sc_inflight += nvpages;

	return 0;
}

/*
 * Virtqueue completion callback for the inflate queue; runs at
 * interrupt level.  Record the completion and poke the worker thread,
 * which does the real post-processing in inflate_done().
 */
static int
inflateq_done(struct virtqueue *vq)
{
	struct viomb_softc *sc;

	sc = device_private(virtio_child(vq->vq_owner));

	mutex_enter(&sc->sc_waitlock);
	sc->sc_inflate_done = 1;
	cv_signal(&sc->sc_wait);
	mutex_exit(&sc->sc_waitlock);

	return 1;
}

/*
 * Post-process a completed inflate request (thread context): dequeue
 * the used slot, move the donated pages onto sc_balloon_pages, and
 * advertise the new "actual" value to the host.
 */
static int
inflate_done(struct viomb_softc *sc)
{
	struct virtio_softc *vsc = sc->sc_virtio;
	struct virtqueue *vq = &sc->sc_vq[VQ_INFLATE];
	struct balloon_req *b;
	int r, slot;
	uint64_t nvpages;
	struct vm_page *p;

	r = virtio_dequeue(vsc, vq, &slot, NULL);
	if (r != 0) {
		printf("%s: inflate dequeue failed, errno %d.\n",
		       device_xname(sc->sc_dev), r);
		return 1;
	}
	virtio_dequeue_commit(vsc, vq, slot);

	b = &sc->sc_req;
	nvpages = b->bl_nentries;
	bus_dmamap_sync(virtio_dmat(vsc), b->bl_dmamap,
			0,
			sizeof(uint32_t)*nvpages,
			BUS_DMASYNC_POSTWRITE);
	/* the host holds these pages now; keep them on the balloon list */
	while (!TAILQ_EMPTY(&b->bl_pglist)) {
		p = TAILQ_FIRST(&b->bl_pglist);
		TAILQ_REMOVE(&b->bl_pglist, p, pageq.queue);
		TAILQ_INSERT_TAIL(&sc->sc_balloon_pages, p, pageq.queue);
	}

	sc->sc_inflight -= nvpages;
	/* tell the host how many pages we have actually given up */
	virtio_write_device_config_le_4(vsc,
		     VIRTIO_BALLOON_CONFIG_ACTUAL,
		     sc->sc_actual + nvpages);
	viomb_read_config(sc);

	return 1;
}
	
/*
 * Deflate: free previously allocated memory.
 *
 * Takes up to PGS_PER_REQ pages back off sc_balloon_pages and tells
 * the host about them via the deflate queue.  The pages themselves
 * are freed here unless MUST_TELL_HOST was negotiated, in which case
 * deflate_done() frees them after the host acknowledges.
 */
static int
deflate(struct viomb_softc *sc)
{
	struct virtio_softc *vsc = sc->sc_virtio;
	int i, slot;
	uint64_t nvpages, nhpages;
	struct balloon_req *b;
	struct vm_page *p;
	struct virtqueue *vq = &sc->sc_vq[VQ_DEFLATE];

	/* how far we are above the host's target, capped per request */
	nvpages = (sc->sc_actual + sc->sc_inflight) - sc->sc_npages;
	if (nvpages > PGS_PER_REQ)
		nvpages = PGS_PER_REQ;
	nhpages = nvpages * VIRTIO_PAGE_SIZE / PAGE_SIZE;

	b = &sc->sc_req;

	/*
	 * NOTE(review): bl_nentries is set before the loop below; if
	 * sc_balloon_pages runs out early, entries of bl_pages[] past
	 * the break point are stale — confirm the balloon list can
	 * never be shorter than nhpages here.
	 */
	b->bl_nentries = nvpages;
	TAILQ_INIT(&b->bl_pglist);
	for (i = 0; i < nhpages; i++) {
		p = TAILQ_FIRST(&sc->sc_balloon_pages);
		if (p == NULL)
			break;
		TAILQ_REMOVE(&sc->sc_balloon_pages, p, pageq.queue);
		TAILQ_INSERT_TAIL(&b->bl_pglist, p, pageq.queue);
		b->bl_pages[i] =
			htole32(VM_PAGE_TO_PHYS(p) / VIRTIO_PAGE_SIZE);
	}

	if (virtio_enqueue_prep(vsc, vq, &slot) != 0) {
		printf("%s: deflate enqueue failed.\n",
		       device_xname(sc->sc_dev));
		/* put the pages back where we found them, in order */
		TAILQ_FOREACH_REVERSE(p, &b->bl_pglist, pglist, pageq.queue) {
			TAILQ_REMOVE(&b->bl_pglist, p, pageq.queue);
			TAILQ_INSERT_HEAD(&sc->sc_balloon_pages, p,
			    pageq.queue);
		}
		return 0;
	}
	if (virtio_enqueue_reserve(vsc, vq, slot, 1) != 0) {
		printf("%s: deflate enqueue failed.\n",
		       device_xname(sc->sc_dev));
		/* put the pages back where we found them, in order */
		TAILQ_FOREACH_REVERSE(p, &b->bl_pglist, pglist, pageq.queue) {
			TAILQ_REMOVE(&b->bl_pglist, p, pageq.queue);
			TAILQ_INSERT_HEAD(&sc->sc_balloon_pages, p,
			    pageq.queue);
		}
		return 0;
	}
	bus_dmamap_sync(virtio_dmat(vsc), b->bl_dmamap, 0,
	    sizeof(uint32_t)*nvpages, BUS_DMASYNC_PREWRITE);
	virtio_enqueue(vsc, vq, slot, b->bl_dmamap, true);
	virtio_enqueue_commit(vsc, vq, slot, true);
	/* deflates in flight count negatively; deflate_done() adds back */
	sc->sc_inflight -= nvpages;

	/* without MUST_TELL_HOST the pages may be reused right away */
	if (!(virtio_features(vsc) & VIRTIO_BALLOON_F_MUST_TELL_HOST))
		uvm_pglistfree(&b->bl_pglist);

	return 0;
}

/*
 * Virtqueue completion callback for the deflate queue; runs at
 * interrupt level.  Flag the completion and wake the worker thread,
 * which finishes the job in deflate_done().
 */
static int
deflateq_done(struct virtqueue *vq)
{
	struct viomb_softc *sc;

	sc = device_private(virtio_child(vq->vq_owner));

	mutex_enter(&sc->sc_waitlock);
	sc->sc_deflate_done = 1;
	cv_signal(&sc->sc_wait);
	mutex_exit(&sc->sc_waitlock);

	return 1;
}
	
/*
 * Post-process a completed deflate request (thread context): dequeue
 * the used slot, free the pages if the host had to be told first
 * (MUST_TELL_HOST), and advertise the reduced "actual" value.
 */
static int
deflate_done(struct viomb_softc *sc)
{
	struct virtio_softc *vsc = sc->sc_virtio;
	struct virtqueue *vq = &sc->sc_vq[VQ_DEFLATE];
	struct balloon_req *b;
	int r, slot;
	uint64_t nvpages;

	r = virtio_dequeue(vsc, vq, &slot, NULL);
	if (r != 0) {
		printf("%s: deflate dequeue failed, errno %d\n",
		       device_xname(sc->sc_dev), r);
		return 1;
	}
	virtio_dequeue_commit(vsc, vq, slot);

	b = &sc->sc_req;
	nvpages = b->bl_nentries;
	bus_dmamap_sync(virtio_dmat(vsc), b->bl_dmamap,
			0,
			sizeof(uint32_t)*nvpages,
			BUS_DMASYNC_POSTWRITE);

	/* with MUST_TELL_HOST the pages may only be freed after the ack */
	if (virtio_features(vsc) & VIRTIO_BALLOON_F_MUST_TELL_HOST)
		uvm_pglistfree(&b->bl_pglist);

	/* undo the negative accounting done in deflate() */
	sc->sc_inflight += nvpages;
	virtio_write_device_config_le_4(vsc,
		     VIRTIO_BALLOON_CONFIG_ACTUAL,
		     sc->sc_actual - nvpages);
	viomb_read_config(sc);

	return 1;
}

/*
 * Balloon worker thread: compares the host's target (sc_npages) with
 * what we have handed over so far (sc_actual + sc_inflight), issues
 * inflate/deflate requests accordingly, processes completions flagged
 * by the interrupt handlers, and sleeps until a timeout expires or a
 * completion/config change signals sc_wait.
 *
 * Fix: the return value of deflate() was assigned to 'r' but never
 * read (dead store); the call result is now explicitly discarded.
 */
static void
viomb_thread(void *arg)
{
	struct viomb_softc *sc = arg;
	int sleeptime, r;

	for ( ; ; ) {
		sleeptime = 30000;
		if (sc->sc_npages > sc->sc_actual + sc->sc_inflight) {
			/* host wants more memory; inflate when idle */
			if (sc->sc_inflight == 0) {
				r = inflate(sc);
				if (r != 0)
					sleeptime = 10000; /* alloc failed, back off */
				else
					sleeptime = 100;
			} else
				sleeptime = 20;
		} else if (sc->sc_npages < sc->sc_actual + sc->sc_inflight) {
			/* host released memory; deflate when idle */
			if (sc->sc_inflight == 0)
				(void)deflate(sc);
			sleeptime = 100;
		}

	again:
		/* drain completions before going back to sleep */
		mutex_enter(&sc->sc_waitlock);
		if (sc->sc_inflate_done) {
			sc->sc_inflate_done = 0;
			mutex_exit(&sc->sc_waitlock);
			inflate_done(sc);
			goto again;
		}
		if (sc->sc_deflate_done) {
			sc->sc_deflate_done = 0;
			mutex_exit(&sc->sc_waitlock);
			deflate_done(sc);
			goto again;
		}
		cv_timedwait(&sc->sc_wait, &sc->sc_waitlock,
			     mstohz(sleeptime));
		mutex_exit(&sc->sc_waitlock);
	}
}

MODULE(MODULE_CLASS_DRIVER, viomb, "virtio");
 
#ifdef _MODULE
#include "ioconf.c"
#endif
 
/*
 * Module (un)load hook: wire the autoconf tables in and out when built
 * as a loadable module; the built-in case has nothing to do.
 */
static int
viomb_modcmd(modcmd_t cmd, void *opaque)
{
#ifdef _MODULE
	switch (cmd) {
	case MODULE_CMD_INIT:
		return config_init_component(cfdriver_ioconf_viomb,
		    cfattach_ioconf_viomb, cfdata_ioconf_viomb);
	case MODULE_CMD_FINI:
		return config_fini_component(cfdriver_ioconf_viomb,
		    cfattach_ioconf_viomb, cfdata_ioconf_viomb);
	default:
		return ENOTTY;
	}
#else
	return 0;
#endif
}
