/*	$NetBSD: vio9p.c,v 1.12 2025/04/22 05:56:25 ozaki-r Exp $	*/

/*
 * Copyright (c) 2019 Internet Initiative Japan, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD: vio9p.c,v 1.12 2025/04/22 05:56:25 ozaki-r Exp $");

#include <sys/param.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <sys/bus.h>
#include <sys/conf.h>
#include <sys/condvar.h>
#include <sys/device.h>
#include <sys/mutex.h>
#include <sys/sysctl.h>
#include <sys/module.h>
#include <sys/syslog.h>
#include <sys/select.h>
#include <sys/kmem.h>

#include <sys/file.h>
#include <sys/filedesc.h>
#include <sys/uio.h>

#include <dev/pci/virtioreg.h>
#include <dev/pci/virtiovar.h>

#include "ioconf.h"

//#define VIO9P_DEBUG	1
//#define VIO9P_DUMP	1
#ifdef VIO9P_DEBUG
#define DLOG(fmt, args...) \
	do { log(LOG_DEBUG, "%s: " fmt "\n", __func__, ##args); } while (0)
#else
#define DLOG(fmt, args...) __nothing
#endif

/* Device-specific feature bits */
#define VIO9P_F_MOUNT_TAG	(UINT64_C(1) << 0) /* mount tag specified */

/* Configuration registers */
#define VIO9P_CONFIG_TAG_LEN	0 /* 16bit */
#define VIO9P_CONFIG_TAG	2

#define VIO9P_FLAG_BITS				\
	VIRTIO_COMMON_FLAG_BITS			\
	"b\x00" "MOUNT_TAG\0"


// Must be the same as P9P_DEFREQLEN of usr.sbin/puffs/mount_9p/ninepuffs.h
#define VIO9P_MAX_REQLEN	(16 * 1024)
#define VIO9P_SEGSIZE		PAGE_SIZE
#define VIO9P_N_SEGMENTS	(VIO9P_MAX_REQLEN / VIO9P_SEGSIZE)

/*
 * QEMU defines this as 32 but includes the final zero byte into the
 * limit.  The code below counts the final zero byte separately, so
 * adjust this define to match.
 */
#define P9_MAX_TAG_LEN		31

CTASSERT((PAGE_SIZE) == (VIRTIO_PAGE_SIZE)); /* XXX */

/*
 * Per-instance software state for one virtio 9p transport device.
 * A single request/reply exchange is in flight at a time; sc_buf_tx
 * and sc_buf_rx are fixed VIO9P_MAX_REQLEN DMA staging buffers.
 */
struct vio9p_softc {
	device_t		sc_dev;		/* autoconf device handle */

	struct virtio_softc	*sc_virtio;	/* parent virtio controller */
	struct virtqueue	sc_vq[1];	/* single request/reply queue */

	uint16_t		sc_taglen;	/* length of sc_tag, w/o NUL */
	uint8_t			sc_tag[P9_MAX_TAG_LEN + 1]; /* mount tag (NUL-terminated) */

	int			sc_flags;	/* VIO9P_INUSE only */
#define VIO9P_INUSE		__BIT(0)	/* device is exclusively open */

	int			sc_state;	/* see state machine below */
#define VIO9P_S_INIT		0
#define VIO9P_S_REQUESTING	1
#define VIO9P_S_REPLIED		2
#define VIO9P_S_CONSUMING	3
	kcondvar_t		sc_wait;	/* reader waits here for a reply */
	struct selinfo		sc_sel;		/* poll/kqueue notification */
	kmutex_t		sc_lock;	/* protects sc_state/sc_wait/sc_sel */

	bus_dmamap_t		sc_dmamap_tx;	/* maps sc_buf_tx for DMA */
	bus_dmamap_t		sc_dmamap_rx;	/* maps sc_buf_rx for DMA */
	char			*sc_buf_tx;	/* request staging buffer */
	char			*sc_buf_rx;	/* reply staging buffer */
	size_t			sc_buf_rx_len;	/* valid bytes in sc_buf_rx */
	off_t			sc_buf_rx_offset; /* bytes already consumed by read */
};

/*
 * Locking notes:
 * - sc_state, sc_wait and sc_sel are protected by sc_lock
 *
 * The state machine (sc_state):
 * - INIT       =(write from client)=> REQUESTING
 * - REQUESTING =(reply from host)=>   REPLIED
 * - REPLIED    =(read from client)=>  CONSUMING
 * - CONSUMING  =(read completed(*))=> INIT
 *
 * (*) read may not finish by one read(2) request, then
 *     the state remains CONSUMING.
 */

static int	vio9p_match(device_t, cfdata_t, void *);
static void	vio9p_attach(device_t, device_t, void *);
static void	vio9p_read_config(struct vio9p_softc *);
static int	vio9p_request_done(struct virtqueue *);

static int	vio9p_read(struct file *, off_t *, struct uio *, kauth_cred_t,
		    int);
static int	vio9p_write(struct file *, off_t *, struct uio *,
		    kauth_cred_t, int);
static int	vio9p_ioctl(struct file *, u_long, void *);
static int	vio9p_close(struct file *);
static int	vio9p_kqfilter(struct file *, struct knote *);

/*
 * File operations for descriptors handed out by vio9p_dev_open() via
 * fd_clone(); f_data points at the vio9p_softc.
 */
static const struct fileops vio9p_fileops = {
	.fo_name = "vio9p",
	.fo_read = vio9p_read,
	.fo_write = vio9p_write,
	.fo_ioctl = vio9p_ioctl,
	.fo_fcntl = fnullop_fcntl,
	.fo_poll = fnullop_poll,
	.fo_stat = fbadop_stat,
	.fo_close = vio9p_close,
	.fo_kqfilter = vio9p_kqfilter,
	.fo_restart = fnullop_restart,
};

static dev_type_open(vio9p_dev_open);

/*
 * Character device switch: only open is implemented here; all further
 * I/O goes through the cloned descriptor's fileops above.
 */
const struct cdevsw vio9p_cdevsw = {
	.d_open = vio9p_dev_open,
	.d_read = noread,
	.d_write = nowrite,
	.d_ioctl = noioctl,
	.d_stop = nostop,
	.d_tty = notty,
	.d_poll = nopoll,
	.d_mmap = nommap,
	.d_kqfilter = nokqfilter,
	.d_discard = nodiscard,
	.d_flag = D_OTHER | D_MPSAFE,
};

/*
 * d_open: enforce exclusive access to the device and hand the caller a
 * cloned descriptor bound to this softc.
 *
 * The previous revision tested and set VIO9P_INUSE without holding any
 * lock (the "FIXME TOCTOU" it carried): two concurrent opens could both
 * pass the busy check.  Claim the device under sc_lock instead, and
 * roll the claim back if fd_allocfile() fails.
 */
static int
vio9p_dev_open(dev_t dev, int flag, int mode, struct lwp *l)
{
	struct vio9p_softc *sc;
	struct file *fp;
	int error, fd;

	sc = device_lookup_private(&vio9p_cd, minor(dev));
	if (sc == NULL)
		return ENXIO;

	/* Atomically check-and-claim under sc_lock (fixes TOCTOU race). */
	mutex_enter(&sc->sc_lock);
	if (ISSET(sc->sc_flags, VIO9P_INUSE)) {
		mutex_exit(&sc->sc_lock);
		return EBUSY;
	}
	sc->sc_flags |= VIO9P_INUSE;
	mutex_exit(&sc->sc_lock);

	/* falloc() will fill in the descriptor for us. */
	error = fd_allocfile(&fp, &fd);
	if (error != 0) {
		/* Undo the claim so the device is not wedged busy. */
		mutex_enter(&sc->sc_lock);
		sc->sc_flags &= ~VIO9P_INUSE;
		mutex_exit(&sc->sc_lock);
		return error;
	}

	return fd_clone(fp, fd, flag, &vio9p_fileops, sc);
}

/*
 * fo_ioctl: FIONBIO is accepted (and ignored, since reads already
 * return EAGAIN when no reply is pending); everything else fails.
 */
static int
vio9p_ioctl(struct file *fp, u_long cmd, void *addr)
{

	if (cmd == FIONBIO)
		return 0;

	return EINVAL;
}

/*
 * fo_read: return the host's 9p reply to the client.
 *
 * Blocks (interruptibly, re-checking every hz ticks) until the
 * outstanding request has been replied to, then dequeues the reply and
 * copies it out.  If the client's buffer is smaller than the reply,
 * the remainder is kept in sc_buf_rx and handed out by subsequent
 * reads while sc_state stays CONSUMING.
 */
static int
vio9p_read(struct file *fp, off_t *offp, struct uio *uio,
    kauth_cred_t cred, int flags)
{
	struct vio9p_softc *sc = fp->f_data;
	struct virtio_softc *vsc = sc->sc_virtio;
	struct virtqueue *vq = &sc->sc_vq[0];
	int error, slot, len;

	DLOG("enter");

	mutex_enter(&sc->sc_lock);

	/* No request is outstanding, so there is nothing to read yet. */
	if (sc->sc_state == VIO9P_S_INIT) {
		DLOG("%s: not requested", device_xname(sc->sc_dev));
		error = EAGAIN;
		goto out;
	}

	if (sc->sc_state == VIO9P_S_CONSUMING) {
		KASSERT(sc->sc_buf_rx_len > 0);
		/* We already have some remaining, consume it. */
		len = sc->sc_buf_rx_len - sc->sc_buf_rx_offset;
		goto consume;
	}

#if 0
	if (uio->uio_resid != VIO9P_MAX_REQLEN)
		return EINVAL;
#else
	if (uio->uio_resid > VIO9P_MAX_REQLEN) {
		error = EINVAL;
		goto out;
	}
#endif

	/*
	 * Wait for the reply.  cv_timedwait_sig() lets a signal or the
	 * one-second timeout abort the wait with an error; in that case
	 * the state stays REQUESTING and the client may read again.
	 */
	error = 0;
	while (sc->sc_state == VIO9P_S_REQUESTING) {
		error = cv_timedwait_sig(&sc->sc_wait, &sc->sc_lock, hz);
		if (error != 0)
			break;
	}
	if (sc->sc_state == VIO9P_S_REPLIED)
		sc->sc_state = VIO9P_S_CONSUMING;

	if (error != 0)
		goto out;

	error = virtio_dequeue(vsc, vq, &slot, &len);
	if (error != 0) {
		log(LOG_ERR, "%s: virtio_dequeue failed: %d\n",
		       device_xname(sc->sc_dev), error);
		goto out;
	}
	DLOG("len=%d", len);
	sc->sc_buf_rx_len = len;
	sc->sc_buf_rx_offset = 0;
	/* Both maps were enqueued in one chain, so post-sync both. */
	bus_dmamap_sync(virtio_dmat(vsc), sc->sc_dmamap_tx, 0, VIO9P_MAX_REQLEN,
	    BUS_DMASYNC_POSTWRITE);
	bus_dmamap_sync(virtio_dmat(vsc), sc->sc_dmamap_rx, 0, VIO9P_MAX_REQLEN,
	    BUS_DMASYNC_POSTREAD);
	virtio_dequeue_commit(vsc, vq, slot);
#ifdef VIO9P_DUMP
	int i;
	log(LOG_DEBUG, "%s: buf: ", __func__);
	for (i = 0; i < len; i++) {
		log(LOG_DEBUG, "%c", (char)sc->sc_buf_rx[i]);
	}
	log(LOG_DEBUG, "\n");
#endif

consume:
	DLOG("uio_resid=%lu", uio->uio_resid);
	/*
	 * The client must not request more than the remaining reply
	 * bytes (presumably mount_9p sizes its reads from the 9p
	 * header).  len is non-negative here, so its promotion to
	 * size_t in this comparison is safe.
	 */
	if (len < uio->uio_resid) {
		error = EINVAL;
		goto out;
	}
	len = uio->uio_resid;
	error = uiomove(sc->sc_buf_rx + sc->sc_buf_rx_offset, len, uio);
	if (error != 0)
		goto out;

	sc->sc_buf_rx_offset += len;
	/* Reply fully drained: back to INIT and wake write pollers. */
	if (sc->sc_buf_rx_offset == sc->sc_buf_rx_len) {
		sc->sc_buf_rx_len = 0;
		sc->sc_buf_rx_offset = 0;

		sc->sc_state = VIO9P_S_INIT;
		selnotify(&sc->sc_sel, 0, 1);
	}

out:
	mutex_exit(&sc->sc_lock);
	return error;
}

/*
 * fo_write: submit one 9p request to the host.
 *
 * Copies the client's request into sc_buf_tx and enqueues a single
 * virtio chain carrying the tx buffer (device-readable) followed by
 * the rx buffer (device-writable) for the reply.  Only one request may
 * be in flight: EAGAIN is returned until the previous reply has been
 * fully consumed by read.
 */
static int
vio9p_write(struct file *fp, off_t *offp, struct uio *uio,
    kauth_cred_t cred, int flags)
{
	struct vio9p_softc *sc = fp->f_data;
	struct virtio_softc *vsc = sc->sc_virtio;
	struct virtqueue *vq = &sc->sc_vq[0];
	int error, slot;
	size_t len;

	DLOG("enter");

	mutex_enter(&sc->sc_lock);

	/* A request is already in flight (or its reply not yet drained). */
	if (sc->sc_state != VIO9P_S_INIT) {
		DLOG("already requesting");
		error = EAGAIN;
		goto out;
	}

	if (uio->uio_resid == 0) {
		error = 0;
		goto out;
	}

	if (uio->uio_resid > VIO9P_MAX_REQLEN) {
		error = EINVAL;
		goto out;
	}

	len = uio->uio_resid;
	error = uiomove(sc->sc_buf_tx, len, uio);
	if (error != 0)
		goto out;

	DLOG("len=%lu", len);
#ifdef VIO9P_DUMP
	int i;
	log(LOG_DEBUG, "%s: buf: ", __func__);
	for (i = 0; i < len; i++) {
		log(LOG_DEBUG, "%c", (char)sc->sc_buf_tx[i]);
	}
	log(LOG_DEBUG, "\n");
#endif

	error = virtio_enqueue_prep(vsc, vq, &slot);
	if (error != 0) {
		log(LOG_ERR, "%s: virtio_enqueue_prep failed\n",
		       device_xname(sc->sc_dev));
		goto out;
	}
	DLOG("slot=%d", slot);
	/* One slot holds both the request and the reply segments. */
	error = virtio_enqueue_reserve(vsc, vq, slot,
	    sc->sc_dmamap_tx->dm_nsegs + sc->sc_dmamap_rx->dm_nsegs);
	if (error != 0) {
		log(LOG_ERR, "%s: virtio_enqueue_reserve failed\n",
		       device_xname(sc->sc_dev));
		goto out;
	}

	/* Tx */
	bus_dmamap_sync(virtio_dmat(vsc), sc->sc_dmamap_tx, 0,
	    len, BUS_DMASYNC_PREWRITE);
	virtio_enqueue(vsc, vq, slot, sc->sc_dmamap_tx, true);
	/* Rx */
	bus_dmamap_sync(virtio_dmat(vsc), sc->sc_dmamap_rx, 0,
	    VIO9P_MAX_REQLEN, BUS_DMASYNC_PREREAD);
	virtio_enqueue(vsc, vq, slot, sc->sc_dmamap_rx, false);
	virtio_enqueue_commit(vsc, vq, slot, true);

	sc->sc_state = VIO9P_S_REQUESTING;
out:
	mutex_exit(&sc->sc_lock);
	return error;
}

/*
 * fo_close: release the exclusive claim taken at open time.
 *
 * sc_flags is now cleared under sc_lock so it cannot race with the
 * busy check in vio9p_dev_open() (previously it was modified with no
 * lock held).
 *
 * NOTE(review): sc_state is deliberately left untouched here; a close
 * with a request still in flight leaves the state machine for the next
 * opener — confirm whether a reset/drain is wanted.
 */
static int
vio9p_close(struct file *fp)
{
	struct vio9p_softc *sc = fp->f_data;

	mutex_enter(&sc->sc_lock);
	KASSERT(ISSET(sc->sc_flags, VIO9P_INUSE));
	sc->sc_flags &= ~VIO9P_INUSE;
	mutex_exit(&sc->sc_lock);

	return 0;
}

/*
 * Knote detach: unhook this knote from the device's selinfo, holding
 * the same lock selrecord_knote() was called under.
 */
static void
filt_vio9p_detach(struct knote *kn)
{
	struct vio9p_softc *sc9p = kn->kn_hook;

	mutex_enter(&sc9p->sc_lock);
	selremove_knote(&sc9p->sc_sel, kn);
	mutex_exit(&sc9p->sc_lock);
}

/*
 * EVFILT_READ event: readable once a request has been submitted
 * (a reply is pending or being consumed) or reply bytes remain.
 */
static int
filt_vio9p_read(struct knote *kn, long hint)
{
	struct vio9p_softc *sc = kn->kn_hook;

	/* XXX need sc_lock? */
	kn->kn_data = sc->sc_buf_rx_len;

	return (kn->kn_data > 0) || (sc->sc_state != VIO9P_S_INIT);
}

/* kqueue EVFILT_READ filter: attach is done in vio9p_kqfilter(). */
static const struct filterops vio9p_read_filtops = {
	.f_flags = FILTEROP_ISFD,
	.f_attach = NULL,
	.f_detach = filt_vio9p_detach,
	.f_event = filt_vio9p_read,
};

/*
 * EVFILT_WRITE event: writable only while idle, i.e. no request is in
 * flight and no reply remains to be consumed.
 */
static int
filt_vio9p_write(struct knote *kn, long hint)
{
	struct vio9p_softc *sc = kn->kn_hook;
	int writable;

	/* XXX need sc_lock? */
	writable = (sc->sc_state == VIO9P_S_INIT);

	return writable;
}

/* kqueue EVFILT_WRITE filter: attach is done in vio9p_kqfilter(). */
static const struct filterops vio9p_write_filtops = {
	.f_flags = FILTEROP_ISFD,
	.f_attach = NULL,
	.f_detach = filt_vio9p_detach,
	.f_event = filt_vio9p_write,
};

/*
 * fo_kqfilter: register a read or write knote against this device.
 * Any other filter type is rejected with EINVAL.
 */
static int
vio9p_kqfilter(struct file *fp, struct knote *kn)
{
	struct vio9p_softc *sc = fp->f_data;
	const struct filterops *fops;

	switch (kn->kn_filter) {
	case EVFILT_READ:
		fops = &vio9p_read_filtops;
		break;
	case EVFILT_WRITE:
		fops = &vio9p_write_filtops;
		break;
	default:
		log(LOG_ERR, "%s: kn_filter=%u\n", __func__, kn->kn_filter);
		return EINVAL;
	}

	kn->kn_fop = fops;
	kn->kn_hook = sc;

	mutex_enter(&sc->sc_lock);
	selrecord_knote(&sc->sc_sel, kn);
	mutex_exit(&sc->sc_lock);

	return 0;
}

/* Autoconf attachment glue; no detach/activate hooks are provided. */
CFATTACH_DECL_NEW(vio9p, sizeof(struct vio9p_softc),
    vio9p_match, vio9p_attach, NULL, NULL);

/*
 * Autoconf match: accept any virtio child advertising the 9p
 * transport device ID.
 */
static int
vio9p_match(device_t parent, cfdata_t match, void *aux)
{
	const struct virtio_attach_args *va = aux;

	return (va->sc_childdevid == VIRTIO_DEVICE_ID_9P) ? 1 : 0;
}

/*
 * Autoconf attach: negotiate MOUNT_TAG, allocate the virtqueue, the
 * tx/rx staging buffers and their DMA maps, read the mount tag, and
 * publish it via sysctl.
 *
 * The error unwinding is restructured: the previous revision leaked
 * the tx dmamap when the rx dmamap_create failed (it jumped past the
 * destroy), never freed the kmem_alloc'd staging buffers on any error
 * path, and destroyed still-loaded dmamaps without unloading them.
 * The ladder below releases exactly what has been set up so far.
 */
static void
vio9p_attach(device_t parent, device_t self, void *aux)
{
	struct vio9p_softc *sc = device_private(self);
	struct virtio_softc *vsc = device_private(parent);
	uint64_t features;
	int error;
	const struct sysctlnode *node;

	if (virtio_child(vsc) != NULL) {
		aprint_normal(": child already attached for %s; "
			      "something wrong...\n", device_xname(parent));
		return;
	}

	sc->sc_dev = self;
	sc->sc_virtio = vsc;

	virtio_child_attach_start(vsc, self, IPL_VM,
	    VIO9P_F_MOUNT_TAG, VIO9P_FLAG_BITS);

	/* The mount tag is how mount_9p finds us; require the feature. */
	features = virtio_features(vsc);
	if ((features & VIO9P_F_MOUNT_TAG) == 0)
		goto err_none;

	virtio_init_vq_vqdone(vsc, &sc->sc_vq[0], 0, vio9p_request_done);
	error = virtio_alloc_vq(vsc, &sc->sc_vq[0], VIO9P_MAX_REQLEN,
	    VIO9P_N_SEGMENTS * 2, "vio9p");
	if (error != 0)
		goto err_none;

	/* KM_SLEEP allocations cannot fail. */
	sc->sc_buf_tx = kmem_alloc(VIO9P_MAX_REQLEN, KM_SLEEP);
	sc->sc_buf_rx = kmem_alloc(VIO9P_MAX_REQLEN, KM_SLEEP);

	error = bus_dmamap_create(virtio_dmat(vsc), VIO9P_MAX_REQLEN,
	    VIO9P_N_SEGMENTS, VIO9P_SEGSIZE, 0, BUS_DMA_WAITOK, &sc->sc_dmamap_tx);
	if (error != 0) {
		aprint_error_dev(sc->sc_dev, "bus_dmamap_create failed: %d\n",
		    error);
		goto err_buf;
	}
	error = bus_dmamap_create(virtio_dmat(vsc), VIO9P_MAX_REQLEN,
	    VIO9P_N_SEGMENTS, VIO9P_SEGSIZE, 0, BUS_DMA_WAITOK, &sc->sc_dmamap_rx);
	if (error != 0) {
		aprint_error_dev(sc->sc_dev, "bus_dmamap_create failed: %d\n",
		    error);
		goto err_dmamap_tx;
	}

	error = bus_dmamap_load(virtio_dmat(vsc), sc->sc_dmamap_tx,
	    sc->sc_buf_tx, VIO9P_MAX_REQLEN, NULL, BUS_DMA_WAITOK | BUS_DMA_WRITE);
	if (error != 0) {
		aprint_error_dev(sc->sc_dev, "bus_dmamap_load failed: %d\n",
		    error);
		goto err_dmamap_rx;
	}
	error = bus_dmamap_load(virtio_dmat(vsc), sc->sc_dmamap_rx,
	    sc->sc_buf_rx, VIO9P_MAX_REQLEN, NULL, BUS_DMA_WAITOK | BUS_DMA_READ);
	if (error != 0) {
		aprint_error_dev(sc->sc_dev, "bus_dmamap_load failed: %d\n",
		    error);
		goto err_load_tx;
	}

	sc->sc_state = VIO9P_S_INIT;
	mutex_init(&sc->sc_lock, MUTEX_DEFAULT, IPL_NONE);
	cv_init(&sc->sc_wait, "vio9p");

	vio9p_read_config(sc);
	aprint_normal_dev(self, "tagged as %s\n", sc->sc_tag);

	/* sysctl failures are non-fatal; the device works without them. */
	sysctl_createv(NULL, 0, NULL, &node, 0, CTLTYPE_NODE,
	    "vio9p", SYSCTL_DESCR("VirtIO 9p toplevel"),
	    NULL, 0, NULL, 0,
	    CTL_HW, CTL_CREATE, CTL_EOL);
	sysctl_createv(NULL, 0, &node, &node, 0, CTLTYPE_NODE,
	    device_xname(self), SYSCTL_DESCR("VirtIO 9p device"),
	    NULL, 0, NULL, 0,
	    CTL_CREATE, CTL_EOL);
	sysctl_createv(NULL, 0, &node, NULL, 0, CTLTYPE_STRING,
	    "tag", SYSCTL_DESCR("VirtIO 9p tag value"),
	    NULL, 0, sc->sc_tag, 0,
	    CTL_CREATE, CTL_EOL);

	error = virtio_child_attach_finish(vsc, sc->sc_vq,
	    __arraycount(sc->sc_vq), NULL,
	    VIRTIO_F_INTR_MPSAFE | VIRTIO_F_INTR_SOFTINT);
	if (error != 0)
		goto err_mutex;

	return;

	/* Unwind in strict reverse order of construction. */
err_mutex:
	cv_destroy(&sc->sc_wait);
	mutex_destroy(&sc->sc_lock);
	bus_dmamap_unload(virtio_dmat(vsc), sc->sc_dmamap_rx);
err_load_tx:
	bus_dmamap_unload(virtio_dmat(vsc), sc->sc_dmamap_tx);
err_dmamap_rx:
	bus_dmamap_destroy(virtio_dmat(vsc), sc->sc_dmamap_rx);
err_dmamap_tx:
	bus_dmamap_destroy(virtio_dmat(vsc), sc->sc_dmamap_tx);
err_buf:
	kmem_free(sc->sc_buf_rx, VIO9P_MAX_REQLEN);
	kmem_free(sc->sc_buf_tx, VIO9P_MAX_REQLEN);
	virtio_free_vq(vsc, &sc->sc_vq[0]);
err_none:
	virtio_child_attach_failed(vsc);
	return;
}

static void
vio9p_read_config(struct vio9p_softc *sc)
{
	device_t dev = sc->sc_dev;
	uint8_t reg;
	int i;

	/* these values are explicitly specified as little-endian */
	sc->sc_taglen = virtio_read_device_config_le_2(sc->sc_virtio,
		VIO9P_CONFIG_TAG_LEN);

	if (sc->sc_taglen > P9_MAX_TAG_LEN) {
		aprint_error_dev(dev, "warning: tag is trimmed from %u to %u\n",
		    sc->sc_taglen, P9_MAX_TAG_LEN);
		sc->sc_taglen = P9_MAX_TAG_LEN;
	}

	for (i = 0; i < sc->sc_taglen; i++) {
		reg = virtio_read_device_config_1(sc->sc_virtio,
		    VIO9P_CONFIG_TAG + i);
		sc->sc_tag[i] = reg;
	}
	sc->sc_tag[i] = '\0';
}

/*
 * Virtqueue completion callback: the host has posted a reply.  Move
 * the state machine to REPLIED and wake both blocked readers and
 * poll/kqueue waiters.
 */
static int
vio9p_request_done(struct virtqueue *vq)
{
	struct virtio_softc *virtio = vq->vq_owner;
	struct vio9p_softc *sc = device_private(virtio_child(virtio));

	DLOG("enter");

	mutex_enter(&sc->sc_lock);
	sc->sc_state = VIO9P_S_REPLIED;
	cv_broadcast(&sc->sc_wait);
	selnotify(&sc->sc_sel, 0, 1);
	mutex_exit(&sc->sc_lock);

	return 1;
}

MODULE(MODULE_CLASS_DRIVER, vio9p, "virtio");

#ifdef _MODULE
#include "ioconf.c"
#endif

/*
 * Module control: attach/detach the character device switch and the
 * autoconf component when built as a loadable module.
 *
 * Fixes over the previous revision: the devsw_attach() return value
 * was ignored, and on FINI the devsw was detached even when
 * config_fini_component() failed, leaving a half-torn-down module.
 */
static int
vio9p_modcmd(modcmd_t cmd, void *opaque)
{
#ifdef _MODULE
	devmajor_t bmajor = NODEVMAJOR, cmajor = NODEVMAJOR;
#endif
	int error = 0;

#ifdef _MODULE
	switch (cmd) {
	case MODULE_CMD_INIT:
		error = devsw_attach(vio9p_cd.cd_name, NULL, &bmajor,
		    &vio9p_cdevsw, &cmajor);
		if (error != 0)
			break;
		error = config_init_component(cfdriver_ioconf_vio9p,
		    cfattach_ioconf_vio9p, cfdata_ioconf_vio9p);
		if (error != 0)
			devsw_detach(NULL, &vio9p_cdevsw);
		break;
	case MODULE_CMD_FINI:
		error = config_fini_component(cfdriver_ioconf_vio9p,
		    cfattach_ioconf_vio9p, cfdata_ioconf_vio9p);
		/* Keep the devsw registered if the component is busy. */
		if (error == 0)
			devsw_detach(NULL, &vio9p_cdevsw);
		break;
	default:
		error = ENOTTY;
		break;
	}
#endif

	return error;
}
