/*************************************************************************
***	Authentication, authorization, accounting + firewalling package
***	Copyright 1998-2002 Anton Vinokurov <anton@netams.com>
***	Copyright 2002-2008 NeTAMS Development Team
***	This code is GPL v3
***	For latest version and more info, visit this project web page
***	located at http://www.netams.com
***
*************************************************************************/
/* $Id: ng_netams.c,v 1.24 2008-02-23 08:35:03 anton Exp $ */

/* NetGraph-related stuff was taken from:
 * 	http://www.daemonnews.org/200003/netgraph.html 
 * 	http://cell.sick.ru/~glebius/ng_netflow/
 * 	ftp://ftp.wuppy.net.ru/pub/FreeBSD/local/kernel/ng_ipacct
 * 	man 4 netgraph
 *  /usr/include/netgraph/ and /sys/netgraph/
 */

#include <sys/param.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <sys/mbuf.h>
#include <sys/malloc.h>
#include <sys/ctype.h>
#include <sys/errno.h>
#include <sys/time.h>
#include <sys/socket.h>
#include <sys/socketvar.h>
#include <sys/ucred.h>
#include <sys/sysctl.h>	
#include <net/bpf.h>
#include <net/ethernet.h>
#include <net/route.h>
#include <sys/lock.h>
#include <sys/mutex.h>

#include <netinet/in_systm.h>
#include <netinet/in.h>
#include <netinet/in_pcb.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>
#include <netinet/ip_var.h>
#include <netinet/tcp.h>
#include <netinet/tcp_var.h>
#include <netinet/udp.h>
#include <netinet/udp_var.h>

#include <netgraph/ng_message.h>
#include <netgraph/ng_parse.h>
#include <netgraph/netgraph.h>
 
/* Kernel malloc type used for every allocation made in this file
 * (node private data, flow table, flow entries and queue items). */
MALLOC_DECLARE(M_NETAMS);
MALLOC_DEFINE(M_NETAMS, "netams", "NeTAMS Kernel Module");

#include "ng_netams.h"

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Per-node private state, allocated in ng_netams_constructor() and attached
 * to the netgraph node with NG_NODE_SET_PRIVATE().
 */
struct netams_module {
	node_p	node;	/* back pointer to the owning netgraph node */
	u_int32_t	packets_in;	/* items received on the if_in hook */
	u_int32_t	packets_out;	/* items received on the if_out hook */
	hook_p	if_in, if_out;	/* the two data hooks; traffic from one is forwarded to the other */
	u_int32_t	mode;	/* NG_NETAMS_MODE_TEE or NG_NETAMS_MODE_DIVERT */
	u_int32_t	debug;	/* non-zero enables printf() tracing */
	unsigned daemon_cookie;	/* token placed in messages sent to the registered daemon */
	ng_ID_t daemon_node;	/* netgraph ID of the registered daemon node, 0 if none */
	struct callout callout;	/* periodic timer running ng_netams_callout() */
	ng_entry **table;	/* flow hash table, NG_IPV4_HASH_SIZE buckets chained via ->next */
	u_int32_t	active_flows;	/* flow entries currently linked in the table */
	u_int32_t	total_flows;	/* flows created so far; also the source of entry ids */
	u_int32_t	queued_flows;	/* entries that currently own a packet queue */
	u_int32_t	queued_bytes;	/* sum of queue_bytes over all entries */
	u_int32_t	queued_packets;	/* sum of queue_packets over all entries */
	u_int32_t	blocked_flows;	/* count of DROP verdicts received via NG_NETAMS_FWREPLY */
	u_int32_t	default_policy;	/* NG_NETAMS_DEFAULT_*; applied by the callout to stalled queues */
	ng_entry *active;	/* head of the list of all entries, chained via ->next_active */
	struct mtx mtx_queues;	/* taken around per-entry packet queue manipulation */
	struct mtx mtx_active;	/* taken around flow table / active list access */
};
typedef struct netams_module *netams_module_p;

/////////////////////////////////////////////////////////////////////////////////////
/* Netgraph node-type methods, wired into ng_netams_typestruct below. */
static ng_constructor_t ng_netams_constructor;
static ng_rcvmsg_t      ng_netams_rcvmsg;
static ng_shutdown_t    ng_netams_shutdown;
static ng_newhook_t     ng_netams_newhook;
static ng_rcvdata_t     ng_netams_rcvdata;
static ng_disconnect_t  ng_netams_disconnect;
/* Internal helpers (not static, although only used within this file as far as visible here). */
/* Periodic timer: resolves stalled queues, expires flows, reports to the daemon. */
void ng_netams_callout(void*);
/* Find or create the flow entry for an IPv4 packet; NULL means "drop". */
ng_entry* ng_netams_processpacket(struct ip *ip, netams_module_p data);
/* Forward or free every item queued on an entry according to e->flags. */
void ng_netams_flushqueue(ng_entry *e, netams_module_p data);
/* Copy the node counters into the ng_netams_info payload of a message. */
void ng_netams_set_info(struct ng_mesg *msg, netams_module_p data);
/////////////////////////////////////////////////////////////////////////////////////
/* ng_parse description of struct ng_netams_info, used as the response type
 * of the "info" command; the field list comes from NG_NETAMS_INFO_TYPE. */
static const struct ng_parse_struct_field ng_netams_info_type_fields[] = NG_NETAMS_INFO_TYPE;
static const struct ng_parse_type ng_netams_info_type = {
	&ng_parse_struct_type,
	&ng_netams_info_type_fields
};

/*
 * Control messages that have an ASCII form. Each entry lists, in order:
 * type cookie, command number, ASCII name, argument parse type and
 * response parse type. NG_NETAMS_FWREPLY is handled in ng_netams_rcvmsg()
 * but has no entry here.
 */
static const struct ng_cmdlist ng_netams_cmds[] = {
	/* "info": no argument, replies with struct ng_netams_info */
	{
	NG_NETAMS_COOKIE,
	NG_NETAMS_INFO,
	"info",
	NULL,
	&ng_netams_info_type 
	},
	/* "mode": int32 argument, no response payload */
	{
	NG_NETAMS_COOKIE,
	NG_NETAMS_SETMODE,
	"mode",
	&ng_parse_int32_type,
	NULL 
	},
	/* "setdefault": int32 argument, no response payload */
	{
	NG_NETAMS_COOKIE,
	NG_NETAMS_SETDEFAULT,
	"setdefault",
	&ng_parse_int32_type,
	NULL 
	},
	/* "debug": int32 argument (0 or 1), no response payload */
	{
	NG_NETAMS_COOKIE,
	NG_NETAMS_DEBUG,
	"debug",
	&ng_parse_int32_type,
	NULL 
	},
	/* "register": int32 cookie; the sender becomes the stats destination */
	{
	NG_NETAMS_COOKIE,
	NG_NETAMS_REGISTER,
	"register",
	&ng_parse_int32_type,
	NULL 
	},
	/* terminator */
    { 0 }
};

/////////////////////////////////////////////////////////////////////////////////////
/* Netgraph node type descriptor: binds the type name to the method
 * implementations in this file and to the command list above. */
static struct ng_type ng_netams_typestruct = {
	.version =	NG_ABI_VERSION,
	.name =		NG_NETAMS_NODE_TYPE,
	.constructor =	ng_netams_constructor,
	.rcvmsg =	ng_netams_rcvmsg,
	.shutdown =	ng_netams_shutdown,
	.newhook =	ng_netams_newhook,
	.rcvdata =	ng_netams_rcvdata,
	.disconnect =	ng_netams_disconnect,
	.cmdlist =	ng_netams_cmds,
};
/* Register the node type with the netgraph framework. */
NETGRAPH_INIT(netams, &ng_netams_typestruct);

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Node constructor: allocate and initialise the per-node private state.
 *
 * Returns 0 on success, or ENOMEM when either the private structure or the
 * flow hash table cannot be allocated. On failure nothing stays allocated
 * and the periodic callout is not armed.
 */
static int ng_netams_constructor(node_p nodep) {
	netams_module_p data;

	MALLOC(data, netams_module_p, sizeof(*data), M_NETAMS, M_NOWAIT|M_ZERO);
	if (data==NULL)	return ENOMEM;

	/* The table is dereferenced on every packet and every timer tick, so a
	 * failed allocation must abort construction instead of being ignored. */
	MALLOC(data->table, ng_entry**, NG_IPV4_HASH_SIZE*sizeof(ng_entry*), M_NETAMS, M_NOWAIT|M_ZERO);
	if (data->table==NULL) {
		FREE(data, M_NETAMS);
		return ENOMEM;
	}

	data->node = nodep;
	data->packets_in=data->packets_out=0;
	data->if_in=NULL;	data->if_out=NULL;
	data->mode=NG_NETAMS_MODE_TEE;
	data->debug=1;
	data->default_policy=NG_NETAMS_DEFAULT_PASS;
	data->active_flows=data->total_flows=data->blocked_flows=0;
	data->queued_flows=data->queued_bytes=data->queued_packets=0;
	data->daemon_cookie=data->daemon_node=0;
	data->active=NULL;

	mtx_init(&data->mtx_queues,  "ng_netams mbuf queues mutex", NULL, MTX_DEF);
	mtx_init(&data->mtx_active, "ng_netams active entries list mutex", NULL, MTX_DEF);

	NG_NODE_SET_PRIVATE(nodep, data);

	if (data->debug) printf("Starting \"%s\" kernel module\n", NG_NETAMS_NODE_TYPE);

	/* Arm the timer last: ng_netams_callout() takes both mutexes and walks
	 * the table, so all of them have to exist before it can possibly run. */
	callout_init(&data->callout, 1);
	callout_reset(&data->callout, hz, &ng_netams_callout, (void*)data);
	return 0;
}

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Hook creation callback. A hook whose name begins with
 * NG_NETAMS_HOOK_IF_IN becomes the "in" hook, one whose name begins with
 * NG_NETAMS_HOOK_IF_OUT becomes the "out" hook (the "in" prefix is tested
 * first). Any other name is rejected with EINVAL.
 */
static int ng_netams_newhook(node_p node, hook_p hook, const char *name) {
	netams_module_p priv = NG_NODE_PRIVATE(node);

	if (strncmp(name, NG_NETAMS_HOOK_IF_IN, strlen(NG_NETAMS_HOOK_IF_IN)) == 0) {
		priv->if_in = hook;
		return 0;
	}
	if (strncmp(name, NG_NETAMS_HOOK_IF_OUT, strlen(NG_NETAMS_HOOK_IF_OUT)) == 0) {
		priv->if_out = hook;
		return 0;
	}
	return EINVAL;
}

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Control message handler.
 *
 * Commands (all under NG_NETAMS_COOKIE):
 *   NG_NETAMS_INFO       - reply with a struct ng_netams_info snapshot
 *   NG_NETAMS_SETMODE    - u_int32_t: NG_NETAMS_MODE_TEE or _DIVERT
 *   NG_NETAMS_SETDEFAULT - u_int32_t: one of NG_NETAMS_DEFAULT_*
 *   NG_NETAMS_DEBUG      - u_int32_t: 0 or 1
 *   NG_NETAMS_REGISTER   - u_int32_t cookie; the sender becomes the node
 *                          that receives flow records and FW requests
 *   NG_NETAMS_FWREPLY    - ng_entry carrying hash, id and verdict in flags
 *                          (0 = pass, 2 = drop) for a queued flow
 *
 * Returns 0, EINVAL for unknown/malformed messages or ENOMEM when the
 * response cannot be allocated. The message and item are always consumed.
 */
static int ng_netams_rcvmsg(node_p node, item_p item, hook_p lasthook) {

	struct ng_mesg *resp = NULL;
	int error = 0;
	struct ng_mesg *msg;
	const netams_module_p data = NG_NODE_PRIVATE(node);

	NGI_GET_MSG(item, msg);

	switch (msg->header.typecookie) {
	case NG_NETAMS_COOKIE:
		switch (msg->header.cmd) {
		case NG_NETAMS_INFO:
		{
			NG_MKRESPONSE(resp, msg, sizeof(struct ng_netams_info), M_NOWAIT);
			/* M_NOWAIT allocation may fail; never write through NULL */
			if (resp==NULL) { error = ENOMEM; break; }
			ng_netams_set_info(resp, data);
			break;
		}
		case NG_NETAMS_SETMODE:
		{
			u_int32_t *m;
			if (msg->header.arglen < sizeof(u_int32_t)) { error = EINVAL; break; }
			m = (u_int32_t*)(msg->data);
			if (*m==NG_NETAMS_MODE_TEE || *m==NG_NETAMS_MODE_DIVERT) data->mode=*m;
			else error = EINVAL;
			break;
		}
		case NG_NETAMS_SETDEFAULT:
		{
			u_int32_t *m;
			if (msg->header.arglen < sizeof(u_int32_t)) { error = EINVAL; break; }
			m = (u_int32_t*)(msg->data);
			/* the value is a default policy, so it belongs in default_policy
			 * (it used to be stored in data->mode by mistake) */
			if (*m==NG_NETAMS_DEFAULT_DROP || *m==NG_NETAMS_DEFAULT_PASS || *m==NG_NETAMS_DEFAULT_PASS_SSH_LOCAL) data->default_policy=*m;
			else error = EINVAL;
			break;
		}
		case NG_NETAMS_DEBUG:
		{
			u_int32_t *m;
			if (msg->header.arglen < sizeof(u_int32_t)) { error = EINVAL; break; }
			m = (u_int32_t*)(msg->data);
			if (*m==1 || *m==0) data->debug=*m;
			else error = EINVAL;
			break;
		}
		case NG_NETAMS_REGISTER:
		{
			u_int32_t *m;
			if (msg->header.arglen < sizeof(u_int32_t)) { error = EINVAL; break; }
			m = (u_int32_t*)(msg->data);
			data->daemon_cookie=*m;
			data->daemon_node = NGI_RETADDR(item);
			if (data->debug) printf("%s will send stats to node %i with cookie %u\n", NG_NETAMS_NODE_TYPE, data->daemon_node, data->daemon_cookie);
			break;
		}
		case NG_NETAMS_FWREPLY:
		{
			ng_entry *e_rcvd;
			ng_entry *e;

			/* the payload is read as an ng_entry, so it has to be one */
			if (msg->header.arglen < sizeof(ng_entry)) { error = EINVAL; break; }
			e_rcvd = (ng_entry*)msg->data;
			/* hash comes from outside the kernel: never index the table with it unchecked */
			if ((unsigned)e_rcvd->hash >= NG_IPV4_HASH_SIZE) { error = EINVAL; break; }

			/*
			 * Keep mtx_active held until we are done with the entry, otherwise
			 * ng_netams_callout() could expire and free it between the lookup
			 * and the flush. Lock order (mtx_active, then mtx_queues) matches
			 * the one used by the callout.
			 */
			mtx_lock(&data->mtx_active);
			for(e=data->table[e_rcvd->hash]; e!=NULL; e=e->next) if (e->id==e_rcvd->id) break;
			if (e==NULL) { // we got some reply for nonexisting entry
				mtx_unlock(&data->mtx_active);
				error = EINVAL;
				break;
			}
			if (data->debug) printf("%s fwreply for entry id=%u, flags=%u, queue %u/%u\n", NG_NETAMS_NODE_TYPE, e->id, e_rcvd->flags, e->queue_packets, e->queue_bytes);

			mtx_lock(&data->mtx_queues);
			if (e_rcvd->flags==2) { e->flags=NG_NETAMS_FW_DROP; data->blocked_flows++; }
			else if (e_rcvd->flags==0) e->flags=NG_NETAMS_FW_PASS;
			ng_netams_flushqueue(e, data);
			mtx_unlock(&data->mtx_queues);
			mtx_unlock(&data->mtx_active);
			break;
		}
		default:
			error = EINVAL;		/* unknown command */
			break;
		}
		break;

	default:
		error = EINVAL;			/* unknown cookie type */
		break;
	}

	NG_RESPOND_MSG(error, node, item, resp);
	NG_FREE_MSG(msg);

	return error;
}

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Data path. Every item arriving on one data hook is destined for the
 * opposite one. Non-IPv4 frames and frames that fail the CHECK_MLEN /
 * CHECK_PULLUP tests (macros defined outside this file) are forwarded
 * untouched. IPv4 frames are accounted by ng_netams_processpacket() and
 * then, depending on the flow entry's verdict, dropped, forwarded, or
 * parked on the entry's queue until a verdict arrives.
 *
 * Returns the error produced by forwarding, 0 otherwise.
 */
static int ng_netams_rcvdata(hook_p hook, item_p item) {
	const node_p node = NG_HOOK_NODE(hook);
	netams_module_p data = NG_NODE_PRIVATE(node);
	struct mbuf *m;
	hook_p hk=NULL; // where we will forward (if will!) that packet
	int error=0;

	if (hook==data->if_in) { data->packets_in++; hk=data->if_out; }
	else if (hook==data->if_out) { data->packets_out++; hk=data->if_in; }

	m=NGI_M(item);
	if (CHECK_MLEN(m, (sizeof(struct ether_header)+sizeof(struct ip)))) {
		NG_FWD_ITEM_HOOK(error, item, hk);
		return (error);
	}

	if (CHECK_PULLUP(m, (sizeof(struct ether_header)+sizeof(struct ip)))) {
		NG_FWD_ITEM_HOOK(error, item, hk);
		return (error);
	}

	register struct ether_header *eh = mtod(m, struct ether_header *);
	struct ip *ip =(struct ip*)(m->m_data+sizeof (struct ether_header));

	if (ntohs(eh->ether_type)!=ETHERTYPE_IP || ip->ip_v!=4) {
		NG_FWD_ITEM_HOOK(error, item, hk);
		return (error);
	}

	ng_entry *e=ng_netams_processpacket(ip, data);

	if (e==NULL || e->flags==NG_NETAMS_FW_DROP) {
		NG_FREE_ITEM(item);
	}
	else if (e->flags==NG_NETAMS_FW_PASS) {
		NG_FWD_ITEM_HOOK(error, item, hk);
	}
	else if (e->flags==NG_NETAMS_FW_QUEUE) {
		// DoS protection
		if (e->queue_bytes>NG_MAX_QUEUE_BYTES
			|| e->queue_packets>NG_MAX_QUEUE_PACKETS
			|| data->queued_bytes>NG_MAX_TOTAL_QUEUED_BYTES
			|| data->queued_packets>NG_MAX_TOTAL_QUEUED_PACKETS
			|| data->queued_flows>NG_MAX_TOTAL_QUEUED_FLOWS) {
			NG_FREE_ITEM(item);
		}
		else {
			ng_queue_item  *q;

			MALLOC(q, ng_queue_item*, sizeof(ng_queue_item), M_NETAMS, M_NOWAIT|M_ZERO);
			if (q==NULL) {
				/* no memory for the queue node: drop rather than dereference NULL */
				NG_FREE_ITEM(item);
				return (ENOMEM);
			}
			q->i=item;
			q->next=NULL;

			mtx_lock(&data->mtx_queues);
			if (e->queue==NULL) {
				e->sendhook=(void*)hk;
				e->queue=q; e->queue_last=q;
				data->queued_flows++;
				if (data->debug) printf("netams: created queue %p for id=%u, hash=%05u\n", q, e->id, e->hash);
			}
			else {
				e->queue_last->next=q;
				e->queue_last=q;
			}
			e->queue_packets++; e->queue_bytes+=m->m_len;
			data->queued_packets++; data->queued_bytes+=m->m_len;
			mtx_unlock(&data->mtx_queues);
		}
	}
	else {
		/* unknown verdict: the item is ours to dispose of, so do not leak it */
		NG_FREE_ITEM(item);
	}

	return (error);
}
/////////////////////////////////////////////////////////////////////////////////////
/*
 * Node shutdown: stop the timer, detach the private data from the node
 * and release everything the node owns (queued items, flow entries, the
 * hash table, both mutexes and the private structure itself).
 *
 * Always returns 0.
 */
static int ng_netams_shutdown(node_p node) {
	const netams_module_p data = NG_NODE_PRIVATE(node);
	ng_entry *ptr, *ptr1;

	callout_drain(&data->callout);

	/* must be printed while data is still valid; it used to be read after FREE(data) */
	if (data->debug) printf("Shutting down \"%s\" kernel module\n", NG_NETAMS_NODE_TYPE);

	NG_NODE_SET_PRIVATE(node, NULL);
	NG_NODE_UNREF(node);

	if (data->table!=NULL) {
		for (int i=0; i<NG_IPV4_HASH_SIZE; i++) {
			ptr=data->table[i];
			while(ptr) {
				ptr1=ptr;
				ptr=ptr->next;
				if (ptr1->queue) {
					/* packets still waiting for a verdict would otherwise be
					 * leaked together with their queue nodes: drop them */
					ptr1->flags=NG_NETAMS_FW_DROP;
					ng_netams_flushqueue(ptr1, data);
				}
				FREE(ptr1, M_NETAMS);
			}
		}
		FREE(data->table, M_NETAMS);
	}

	mtx_destroy(&data->mtx_queues);
	mtx_destroy(&data->mtx_active);
	FREE(data, M_NETAMS);
	return 0;
}

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Hook disconnect callback: once the last hook is gone and the node is
 * still valid, ask netgraph to remove the node. Always returns 0.
 */
static int ng_netams_disconnect(hook_p hook) {
	const node_p owner = NG_HOOK_NODE(hook);

	if (NG_NODE_NUMHOOKS(owner) != 0)
		return 0;
	if (!NG_NODE_IS_VALID(owner))
		return 0;

	ng_rmnode_self(owner);
	return 0;
}

/////////////////////////////////////////////////////////////////////////////////////
/*
 * Periodic timer, re-armed every hz ticks.
 *
 * For each entry on the active list:
 *   - if it owns a packet queue older than NG_NETAMS_DEFAULT_TIMEOUT, the
 *     default policy decides its verdict and the queue is flushed;
 *   - if it exceeded NG_ACTIVE_TIMEOUT / NG_INACTIVE_TIMEOUT and a daemon is
 *     registered, it is unlinked, reported to the daemon as NG_NETAMS_DATA
 *     and freed.
 * When a daemon is registered and time_second is a multiple of 10, an
 * NG_NETAMS_INFO snapshot is sent as well.
 *
 * arg is the node's netams_module_p.
 */
void ng_netams_callout(void* arg){
	int error=0;
	netams_module_p data = (netams_module_p)arg;
	struct ng_mesg *msg;
	ng_entry *e;
	mtx_lock(&data->mtx_active);
	ng_entry *e2, *prev=NULL;
	unsigned t_active=0, f_active=0, f_queued=0;

	for (e=data->active; e!=NULL; e=e2) {
		t_active++;
		e2=e->next_active;	/* fetched up front: e may be freed below */
		if (e->queue) {
			if (time_second - e->First > NG_NETAMS_DEFAULT_TIMEOUT) { // entry queue stalled!
				// no verdict arrived in time: apply the default policy and flush
				mtx_lock(&data->mtx_queues);
				switch (data->default_policy) {
					case NG_NETAMS_DEFAULT_DROP:
						e->flags=NG_NETAMS_FW_DROP;
						break;
					case NG_NETAMS_DEFAULT_PASS:
						e->flags=NG_NETAMS_FW_PASS;
						break;
					case NG_NETAMS_DEFAULT_PASS_SSH_LOCAL:
						if (ntohs(e->srcport)==22 || ntohs(e->dstport)==22) e->flags=NG_NETAMS_FW_PASS;
						else e->flags=NG_NETAMS_FW_DROP;
						break;
				}
				ng_netams_flushqueue(e, data);
				mtx_unlock(&data->mtx_queues);
			}
			else f_queued++;
		}
		/* An entry that still owns a queue is never expired: freeing it would
		 * leak the queued items. It is reconsidered once the queue is flushed. */
		if (e->queue==NULL
			&& ( (e->Last - e->First) > NG_ACTIVE_TIMEOUT || (time_second - e->Last) > NG_INACTIVE_TIMEOUT )
			&& data->daemon_cookie && data->daemon_node ) {
			/* Allocate the report before touching the lists. If there is no
			 * memory right now the entry simply stays and is retried on the
			 * next tick. */
			NG_MKMESSAGE(msg, NG_NETAMS_COOKIE, NG_NETAMS_DATA, sizeof (struct ng_entry), M_NOWAIT);
			if (msg==NULL) {
				prev=e;
				continue;
			}
			if (data->debug) printf("callout: expired ng_entry hash=%u [%lu bytes], e_nextactive=%p\n", e->hash, e->dOctets, e->next_active);
			// detach entry
			data->active_flows--;
			unsigned hash=e->hash;
			if (e==data->table[hash])
				data->table[hash]=(ng_entry *)e->next;
			else {
				ng_entry *p;
				/* stop at the end of the chain instead of running off it */
				for (p=data->table[hash]; p!=NULL && p->next!=e; p=p->next);
				if (p!=NULL) p->next=e->next;
			}
			if (e==data->active)
				data->active=e->next_active;
			else {
				prev->next_active=e->next_active;
			}
			// fill in and send control message
			msg->header.token=data->daemon_cookie;
			memcpy(msg->data, e, sizeof (struct ng_entry));
			FREE(e, M_NETAMS);
			NG_SEND_MSG_ID(error, data->node, msg, data->daemon_node, NG_NODE_ID(data->node));
			if (error) { data->daemon_cookie=data->daemon_node=0; } // something went wrong, reset userspace destination to prevent further loss
			f_active++;
		} else
			prev=e;
	}
	if (data->debug)
		printf("callout/%lu%c active %u, checked %u, queued=%u, flushed %u\n",
			(u_long)time_second, data->daemon_node?'+':'-',
			data->active_flows, t_active, f_queued, f_active);
	mtx_unlock(&data->mtx_active);

	if (data->daemon_node!=0 && time_second%10==0) { // time to send info to daemon
		NG_MKMESSAGE(msg, NG_NETAMS_COOKIE, NG_NETAMS_INFO, sizeof (struct ng_netams_info), M_NOWAIT);
		if (msg!=NULL) {	/* skip this report when out of memory */
			msg->header.token=data->daemon_cookie;
			ng_netams_set_info(msg, data);
			NG_SEND_MSG_ID(error, data->node, msg, data->daemon_node, NG_NODE_ID(data->node));
			if (data->debug)
				printf("info/%lu: sent to daemon [%u] with error=%u\n",
					(u_long)time_second, data->daemon_node, error);
			if (error) { data->daemon_cookie=data->daemon_node=0; } // something went wrong, reset userspace destination to prevent further loss
		}
	}

	callout_reset(&data->callout, hz, &ng_netams_callout, (void*)data);
}
/////////////////////////////////////////////////////////////////////////////////////
/*
 * Account one IPv4 packet against its flow entry, creating the entry when
 * the flow is new.
 *
 * The flow key is (src, dst, proto, tos, sport, dport). Ports are taken
 * from the transport header for unfragmented TCP/UDP and are 0 otherwise.
 * A new entry starts as NG_NETAMS_FW_DROP in TEE mode and NG_NETAMS_FW_QUEUE
 * in any other mode; for the latter an NG_NETAMS_FWREQUEST is sent to the
 * registered daemon.
 *
 * ip   - points at the IP header inside the received frame
 * data - node private state
 *
 * Returns the entry, or NULL when the hash chain is longer than
 * NG_MAX_CHUNK or no memory is available (callers drop the packet on NULL).
 *
 * NOTE(review): the TCP/UDP ports are read at ip + ip_hl*4 without a length
 * check here; the caller's CHECK_PULLUP only asks for the Ethernet and IP
 * headers. Confirm that the port bytes are always contiguous and present.
 */
ng_entry* ng_netams_processpacket(struct ip *ip, netams_module_p data){
	unsigned hash;
	u_short sp=0, dp=0;
	in_addr_t src, dst;
	u_char proto;
	u_char tos;
	u_char tcp_flags=0;
	ng_entry *e;
	u_short chunk=0;

	dst = ip->ip_dst.s_addr;
	src = ip->ip_src.s_addr;
	proto = ip->ip_p;
	tos = ip->ip_tos;

	if (ntohs(ip->ip_off) & (IP_OFFMASK | IP_MF)) {
		//packet is fragmented. check obtained from /usr/src/sys/netinet/ip_fastfwd.c
		sp=dp=0;
		hash=NG_IPV4_ADDR_HASH(src, dst);
	} else if (proto==IPPROTO_TCP || proto==IPPROTO_UDP) {
		/* the port fields sit at the same offsets in TCP and UDP headers */
		struct tcphdr *th;
		th=(struct tcphdr *)((unsigned char *)ip + ip->ip_hl*4);
		sp=th->th_sport;
		dp=th->th_dport;
		if(proto==IPPROTO_TCP) tcp_flags=th->th_flags;
		hash=NG_IPV4_FULL_HASH(src, dst, sp, dp);
	} else {
		hash=NG_IPV4_ADDR_HASH(src, dst);
	}
	(void)tcp_flags;	/* collected but not used yet */

	mtx_lock(&data->mtx_active);
	for(e=data->table[hash]; e!=NULL; e=e->next){
		if (e->dstaddr.s_addr==dst && e->srcaddr.s_addr==src &&
			e->prot==proto && e->tos==tos &&
			e->srcport==sp && e->dstport==dp) {
			//if(BWCheck(e, ntohs(ip->ip_len), tv)) return 0; //drop packet
			e->dOctets+=ntohs(ip->ip_len);
			e->dPkts++;
			e->Last=time_second;
			break;
		}
		chunk++;
	}
	mtx_unlock(&data->mtx_active);

	if (e==NULL) {
		if(chunk>NG_MAX_CHUNK)  {
			if (data->debug) printf("processpacket: max chunk reached for hash=%u\n", hash);
			return NULL; //protection against DoS
		}
		MALLOC(e, ng_entry*, sizeof(ng_entry), M_NETAMS, M_NOWAIT|M_ZERO);
		if (e==NULL) return NULL;	/* out of memory: treated like any other "drop" */

		e->dstaddr.s_addr=dst;
		e->srcaddr.s_addr=src;
		e->dstport=dp;
		e->srcport=sp;
		e->dOctets=ntohs(ip->ip_len);
		e->prot=proto;
		e->tos=tos;
		e->First=e->Last=time_second;
		//if (data->mode==NG_NETAMS_MODE_TEE) e->flags=NG_NETAMS_FW_DROP; else e->flags=NG_NETAMS_FW_PASS;
		if (data->mode==NG_NETAMS_MODE_TEE) e->flags=NG_NETAMS_FW_DROP; else e->flags=NG_NETAMS_FW_QUEUE;
		e->id=(data->total_flows++);
		e->dPkts=1;
		e->hash=hash;
		e->queue=NULL;
		e->queue_bytes=e->queue_packets=0;

		//if((flags&FLOW_FW_CHECK) && !FW(e) e->flags|=ENTRY_BLOCKED;
		mtx_lock(&data->mtx_active);
		e->next_active=data->active;
		data->active=e;
		e->next=data->table[hash];
		data->table[hash]=e;
		/* counted under the same lock the callout holds when it decrements */
		data->active_flows++;
		mtx_unlock(&data->mtx_active);

		if (e->flags==NG_NETAMS_FW_QUEUE) { // create a FW request
			struct ng_mesg *msg;
			int error;
			NG_MKMESSAGE(msg, NG_NETAMS_COOKIE, NG_NETAMS_FWREQUEST, sizeof (struct ng_entry), M_NOWAIT);
			/* If the request cannot be allocated the entry stays queued and the
			 * callout's stall timeout applies the default policy to it. */
			if (msg!=NULL) {
				msg->header.token=data->daemon_cookie;
				memcpy(msg->data, e, sizeof (struct ng_entry));
				NG_SEND_MSG_ID(error, data->node, msg, data->daemon_node, NG_NODE_ID(data->node));
				if (error) { data->daemon_cookie=data->daemon_node=0; } // something went wrong, reset userspace destination to prevent further loss
			}
		}

		if (data->debug) printf("netams: created flow record id=%u, hash=%05u, time=%lu, proto=%u\n", e->id, hash, e->First, e->prot);
	}

	return e;
}
/////////////////////////////////////////////////////////////////////////////////////
/*
 * Empty the packet queue of one flow entry according to e->flags:
 * NG_NETAMS_FW_PASS forwards every queued item to the hook remembered in
 * e->sendhook, any other value frees the items. Queue nodes are released
 * and the per-entry and per-node queue counters are brought back in sync.
 *
 * Callers in this file hold data->mtx_queues around the call, except on
 * node shutdown. Calling it for an entry without a queue is a no-op apart
 * from the debug trace.
 *
 * NOTE(review): e->sendhook is stored without taking a hook reference;
 * confirm the hook cannot go away while items are still queued.
 */
void ng_netams_flushqueue(ng_entry *e, netams_module_p data){

	ng_queue_item  *q, *next;
	int error;
	hook_p h=(hook_p)e->sendhook;

	if (data->debug) printf("netams: flush queue for entry id=%u, hash=%u, size=%u, action=%u\n", e->id, e->hash, e->queue_packets, e->flags);

	/* No queue, nothing to account for. Decrementing queued_flows here would
	 * wrap the unsigned counter and trip the DoS limit in the data path. */
	if (e->queue==NULL) return;

	for (q=e->queue; q!=NULL; q=next) {
		item_p i = q->i;
		next=q->next;	/* read before the node is released */
		if (e->flags==NG_NETAMS_FW_PASS) {
			NG_FWD_ITEM_HOOK(error, i, h);
			(void)error;	/* per-item forwarding errors are not reported */
		} else {
			/* DROP, or no usable verdict: the item must not be leaked */
			NG_FREE_ITEM(i);
		}
		FREE(q, M_NETAMS);
	}

	data->queued_flows--;
	data->queued_packets-=e->queue_packets; data->queued_bytes-=e->queue_bytes;
	e->queue=e->queue_last=NULL; e->queue_packets=e->queue_bytes=0;
}
/////////////////////////////////////////////////////////////////////////////////////
/*
 * Fill the payload of msg, interpreted as struct ng_netams_info, with a
 * snapshot of the node's configuration and counters. msg must be non-NULL
 * and its data area large enough for the structure.
 */
void ng_netams_set_info(struct ng_mesg *msg, netams_module_p data){
	struct ng_netams_info *info = (struct ng_netams_info*)msg->data;

	/* configuration */
	info->mode           = data->mode;
	info->debug          = data->debug;
	info->default_policy = data->default_policy;

	/* traffic counters */
	info->packets_in     = data->packets_in;
	info->packets_out    = data->packets_out;

	/* flow counters */
	info->active_flows   = data->active_flows;
	info->total_flows    = data->total_flows;
	info->blocked_flows  = data->blocked_flows;

	/* queue counters */
	info->queued_flows   = data->queued_flows;
	info->queued_packets = data->queued_packets;
	info->queued_bytes   = data->queued_bytes;
}
/////////////////////////////////////////////////////////////////////////////////////
	

