
/*
 * Copyright (c) Abraham vd Merwe <abz@blio.com>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *	  notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *	  notice, this list of conditions and the following disclaimer in the
 *	  documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the author nor the names of other contributors
 *	  may be used to endorse or promote products derived from this software
 *	  without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <string.h>
#include <sys/types.h>
#include <errno.h>
#include <time.h>

#include <libiptc/libiptc.h>

#include <debug/memory.h>

#include <abz/aton.h>
#include <abz/error.h>
#include <abz/contigmask.h>

#include "clients.h"
#include "stats.h"

/*
 * Expand an IPv4 address stored in network byte order (as assigned to
 * s_addr) into four unsigned int arguments, one per octet, for a
 * "%u.%u.%u.%u" format. The octets are read from the object's bytes in
 * memory order, so the output does not depend on the host byte order.
 * x must be an lvalue (its address is taken).
 */
#define ADDR(x)												\
	(unsigned int) ((const unsigned char *) &(x))[0],	\
	(unsigned int) ((const unsigned char *) &(x))[1],	\
	(unsigned int) ((const unsigned char *) &(x))[2],	\
	(unsigned int) ((const unsigned char *) &(x))[3]

/*
 * Record a failed assertion as the current abz error and make the
 * enclosing function return -1. Wrapped in do/while (0) so that
 * "ASSERT (x);" is exactly one statement and is safe inside an
 * unbraced if/else.
 */
#define ASSERT(x)																\
	do {																		\
		if (!(x)) {																\
			abz_set_error ("%s:%d: assert(%s) failed",__FILE__,__LINE__,#x);	\
			return (-1);														\
		}																		\
	} while (0)

/*
 * Module state. Only one set of client counters can be active at a
 * time: client_create() asserts that priv.table is NULL and
 * client_destroy() sets it back to NULL.
 */
struct iptc
{
   char *table;				/* table name; NULL while no session is active */
   char *chain;				/* chain that holds the counting rules */
   unsigned int rule;		/* rules appended to chain so far, i.e. the next insert position */
   iptc_handle_t handle;	/* libiptc snapshot of table */
};

static struct iptc priv = { .table = NULL };

/*
 * Record an out-of-memory condition as the current abz error.
 */
static void out_of_memory (void)
{
   abz_set_error ("out of memory");
}

/*
 * Initialize client counters. Returns 0 if successful, -1 if some
 * error occurred. You can check which error occurred by calling
 * abz_get_error().
 *
 * Takes a libiptc snapshot of table, checks that chain exists in it and
 * flushes the chain; client_add() then appends the counting rules to
 * it. Only one instance can be active: client_create() fails until the
 * previous one has been released with client_destroy(). *clients is set
 * to the empty list.
 */
int client_create (struct client **clients,const char *table,const char *chain)
{
   int result = -1;

   abz_clear_error ();

   ASSERT (priv.table == NULL);

   *clients = NULL;

   /*
	* priv.chain may still hold a stale pointer from an earlier session;
	* clear it so the cleanup below never frees it a second time.
	*/
   priv.chain = NULL;
   priv.rule = 0;

   do
	 {
		if ((priv.table = (char *) mem_alloc ((strlen (table) + 1) * sizeof (char))) == NULL ||
			(priv.chain = (char *) mem_alloc ((strlen (chain) + 1) * sizeof (char))) == NULL)
		  {
			 out_of_memory ();
			 break;
		  }

		strcpy (priv.table,table);
		strcpy (priv.chain,chain);

		/*
		 * NOTE(review): the snapshot taken here is not released on the
		 * error paths below; whether iptc_free() exists depends on the
		 * libiptc version in use -- confirm before adding it.
		 */
		if ((priv.handle = iptc_init (priv.table)) == NULL)
		  {
			 abz_set_error ("%s",iptc_strerror (errno));
			 break;
		  }

		if (!iptc_is_chain (priv.chain,priv.handle))
		  {
			 abz_set_error ("there is no chain called %s in table %s",chain,table);
			 break;
		  }

		/* start from an empty chain so rule numbers begin at 0 */
		if (!iptc_flush_entries (priv.chain,&priv.handle))
		  {
			 abz_set_error ("%s",iptc_strerror (errno));
			 break;
		  }

		result = 0;
	 }
   while (0);

   if (result < 0)
	 {
		if (priv.table != NULL) mem_free (priv.table);
		if (priv.chain != NULL) mem_free (priv.chain);

		/*
		 * Back to the "no session" state: client_create() may be retried
		 * and client_destroy() will not touch the freed names.
		 */
		priv.table = priv.chain = NULL;
	 }

   return (result);
}

static void destroy (struct client **clients)
{
   while (*clients != NULL)
	 {
		struct client *tmp1 = *clients;

		*clients = (*clients)->next;
		while (tmp1->counter != NULL)
		  {
			 struct counter *tmp2 = tmp1->counter;
			 tmp1->counter = tmp1->counter->next;
			 mem_free (tmp2);
		  }
		mem_free (tmp1->name);
		stats_destroy (&tmp1->stats);
		mem_free (tmp1);
	 }
}

/*
 * Exchange the nodes at positions a and b (0-based) of the list by
 * relinking them, updating *clients when either one is the head.
 * Positions past the end are clamped to the last node; the list must
 * not be empty. The client structures are not copied, so existing
 * pointers to them keep referring to the same (now moved) clients.
 */
static void swap (struct client **clients,int a,int b)
{
   struct client *tmp,*n1,*n2,*n1prev,*n2prev;
   int i;

   /* locate node a (n1) and its predecessor (NULL for the head) */
   for (n1 = *clients, n1prev = NULL, i = 0; n1->next != NULL && i < a; i++)
	 n1prev = n1, n1 = n1->next;

   /* locate node b (n2) and its predecessor (NULL for the head) */
   for (n2 = *clients, n2prev = NULL, i = 0; n2->next != NULL && i < b; i++)
	 n2prev = n2, n2 = n2->next;

   /*
	* Relink the predecessors, then exchange the successors. When the two
	* nodes are adjacent, one of these two assignments briefly makes a
	* node point at itself; the exchange of next pointers that follows
	* repairs it, so the order of these statements matters.
	*/
   if (n1prev != NULL) n1prev->next = n2;
   if (n2prev != NULL) n2prev->next = n1;

   tmp = n1->next;
   n1->next = n2->next;
   n2->next = tmp;

   /* a node without predecessor was the head; its partner takes over */
   if (n1prev == NULL) *clients = n2;
   if (n2prev == NULL) *clients = n1;
}

/*
 * Return the client at position index (0-based). Walks from the head;
 * an index past the end yields the last client and an index <= 0 the
 * head. The list must not be empty.
 */
static const struct client *nth (const struct client *clients,int index)
{
   int steps = 0;

   while (clients->next != NULL && steps < index)
	 {
		clients = clients->next;
		steps++;
	 }

   return clients;
}

/*
 * Sort the nodes at positions left..right (inclusive, 0-based) of the
 * list by name, ignoring case (strcasecmp()), with a recursive
 * quicksort that pivots on the middle position. Nodes are reached
 * through nth(), which walks the list from the head on every call.
 * The list must not be empty (nth() dereferences its head).
 */
static void quicksort (struct client **clients,int left,int right)
{
   int i = left,j = right;
   /*
	* median points at the pivot node itself rather than a copy of its
	* name; swap() relinks whole nodes, so median->name stays the same
	* pivot value for the whole partition pass.
	*/
   const struct client *median = nth (*clients,(left + right) / 2);

   do
	 {
		/* move i forward past names that sort before the pivot ... */
		while (strcasecmp (nth (*clients,i)->name,median->name) < 0) i++;
		/* ... and j backward past names that sort after it */
		while (strcasecmp (median->name,nth (*clients,j)->name) < 0) j--;
		/* exchange the out-of-place pair and narrow the window */
		if (i <= j) swap (clients,i++,j--);
	 }
   while (i <= j);

   /* recurse into the two partitions that still span several nodes */
   if (left < j) quicksort (clients,left,j);
   if (i < right) quicksort (clients,i,right);
}

/*
 * Commit the pending changes of the libiptc snapshot and take a fresh
 * snapshot of the table into priv.handle, so that later reads (such as
 * rule counters) use the new snapshot. Returns 0 if successful, -1 if
 * some error occurred; the error is available through abz_get_error().
 *
 * NOTE(review): if iptc_init() fails after a successful commit,
 * priv.handle is left NULL and later libiptc calls will receive it.
 */
static int commit (void)
{
   if (!iptc_commit (&priv.handle) || (priv.handle = iptc_init (priv.table)) == NULL)
	 {
		abz_set_error ("%s",iptc_strerror (errno));
		return (-1);
	 }

   return (0);
}

/*
 * Commit changes and sort list of clients alphabetically. Return 0
 * if successful, -1 if some error occurred. You can check which error
 * occurred by calling abz_get_error().
 *
 * Names are compared with strcasecmp(), i.e. ignoring case. An empty
 * list is accepted and left as is.
 */
int client_commit (struct client **clients)
{
   struct client *tmp;
   int n = 0;

   abz_clear_error ();

   ASSERT (priv.table != NULL);

   if (commit () < 0) return (-1);

   for (tmp = *clients; tmp != NULL; tmp = tmp->next) n++;

   /*
	* quicksort() dereferences the head of the list, so it must never see
	* an empty list (previously quicksort (clients,0,-1) crashed); a
	* single client needs no sorting either.
	*/
   if (n > 1) quicksort (clients,0,n - 1);

   return (0);
}

/*
 * Destroy client counters.
 *
 * Flushes the counting chain, commits that, frees the list of clients
 * (leaving *clients NULL) and releases the table and chain names. Does
 * nothing while no session is active (priv.table == NULL). libiptc
 * failures are ignored: this is a best-effort teardown with no way to
 * report errors.
 */
void client_destroy (struct client **clients)
{
   if (priv.table == NULL)
	 return;

   (void) iptc_flush_entries (priv.chain,&priv.handle);
   (void) iptc_commit (&priv.handle);

   destroy (clients);

   mem_free (priv.table);
   mem_free (priv.chain);

   /*
	* Clear both names: priv.table marks the session as closed, and
	* priv.chain must not survive as a dangling pointer into the next
	* client_create().
	*/
   priv.table = priv.chain = NULL;
}

/*
 * Add a client. Return 0 if successful, -1 if some error occurred.
 * You can check which error occurred by calling abz_get_error().
 */
int client_add (struct client **clients,
				const char *name,
				const struct bandwidth *input,
				const struct bandwidth *output,
				const struct network *network,
				size_t n)
{
   struct ipt_entry *entry;
   struct ipt_standard_target *target;
   size_t match_size = 0;
   size_t target_size = IPT_ALIGN (sizeof (struct ipt_standard_target));
   size_t entry_size = IPT_ALIGN (sizeof (struct ipt_entry));
   size_t size = entry_size + match_size + target_size;
   size_t i;
   struct client *client;
   struct counter *tmp;

   abz_clear_error ();

   ASSERT (priv.table != NULL);

   if ((client = mem_alloc (sizeof (struct client))) == NULL)
	 {
		out_of_memory ();
		return (-1);
	 }
   client->next = NULL;

   memset (&client->stats,0L,sizeof (struct stats));
   client->input = *input;
   client->output = *output;

   if ((client->name = (char *) mem_alloc ((strlen (name) + 1) * sizeof (char))) == NULL)
	 {
		out_of_memory ();
		mem_free (client);
		return (-1);
	 }
   strcpy (client->name,name);

   client->counter = tmp = NULL;
   stats_create (&client->stats,61);	/* store 1 minute worth of samples */

   for (i = 0; i < n; i++)
	 {
		if (tmp != NULL)
		  {
			 if ((tmp->next = mem_alloc (sizeof (struct counter))) == NULL)
			   {
				  out_of_memory ();
				  destroy (&client);
				  return (-1);
			   }

			 tmp = tmp->next;
		  }
		else
		  {
			 if ((client->counter = mem_alloc (sizeof (struct counter))) == NULL)
			   {
				  out_of_memory ();
				  destroy (&client);
				  return (-1);
			   }

			 tmp = client->counter;
		  }

		tmp->next = NULL;
	 }
   tmp = client->counter;

   if ((entry = (struct ipt_entry *) mem_alloc (size)) == NULL)
	 {
		out_of_memory ();
		destroy (&client);
		return (-1);
	 }

   memset (entry,0L,size);
   entry->target_offset = size - target_size;
   entry->next_offset = size;

   target = (struct ipt_standard_target *) ((uint8_t *) entry + entry->target_offset);
   target->target.u.target_size = target->target.u.user.target_size = target->target.u.kernel.target_size = target_size;
   memset (target->target.u.user.name,0L,IPT_FUNCTION_MAXNAMELEN);
   strcpy (target->target.u.user.name,IPTC_LABEL_RETURN);

   entry->ip.proto = IPPROTO_IP;
   memset (entry->ip.iniface,0L,IFNAMSIZ);
   memset (entry->ip.outiface,0L,IFNAMSIZ);
   memset (entry->ip.iniface_mask,0L,IFNAMSIZ);
   memset (entry->ip.outiface_mask,0L,IFNAMSIZ);

   for (i = 0; i < n; i++)
	 {
		tmp->network = network[i];

		/* -d ... */
		entry->ip.src.s_addr = entry->ip.smsk.s_addr = 0;
		entry->ip.dst.s_addr = network[i].address;
		entry->ip.dmsk.s_addr = network[i].netmask;
		entry->nfcache = NFC_IP_DST;
		if (!iptc_insert_entry (priv.chain,entry,priv.rule,&priv.handle))
		  {
			 abz_set_error ("failed to add counter: ip dst %u.%u.%u.%u/%u: %s",
							ADDR (network[i].address),contigmask (network[i].netmask),
							iptc_strerror (errno));
			 mem_free (entry);
			 destroy (&client);
			 return (-1);
		  }
		tmp->inrule = priv.rule++;

		/* -s ... */
		entry->ip.dst.s_addr = entry->ip.dmsk.s_addr = 0;
		entry->ip.src.s_addr = network[i].address;
		entry->ip.smsk.s_addr = network[i].netmask;
		entry->nfcache = NFC_IP_SRC;
		if (!iptc_insert_entry (priv.chain,entry,priv.rule,&priv.handle))
		  {
			 abz_set_error ("failed to add counter: ip src %u.%u.%u.%u/%u: %s",
							ADDR (network[i].address),contigmask (network[i].netmask),
							iptc_strerror (errno));
			 mem_free (entry);
			 destroy (&client);
			 return (-1);
		  }
		tmp->outrule = priv.rule++;

		tmp = tmp->next;
	 }

   client->next = *clients, *clients = client;
   mem_free (entry);
   return (0);
}

static int client_read_counters (struct sample *sample,const struct client *client)
{
   struct counter *tmp;
   struct ipt_counters *counter;

   ASSERT (priv.table != NULL);
   ASSERT (client != NULL);

   sample->input = sample->output = 0;

   for (tmp = client->counter; tmp != NULL; tmp = tmp->next)
	 {
		sample->timestamp = time (NULL);

		if ((counter = iptc_read_counter (priv.chain,tmp->inrule,&priv.handle)) == NULL)
		  {
			 abz_set_error ("%s",iptc_strerror (errno));
			 return (-1);
		  }

		sample->input += counter->bcnt;

		if ((counter = iptc_read_counter (priv.chain,tmp->outrule,&priv.handle)) == NULL)
		  {
			 abz_set_error ("%s",iptc_strerror (errno));
			 return (-1);
		  }

		sample->output += counter->bcnt;
	 }

   return (0);
}

/*
 * Update client counters. This function should be called at least
 * once per second. Returns 0 if successful, -1 if some error occurred.
 */
int client_update (struct client *clients)
{
   struct sample sample;

   abz_clear_error ();

   if (commit () < 0)
	 return (-1);

   while (clients != NULL)
	 {
		if (client_read_counters (&sample,clients) < 0 ||
			stats_update (&clients->stats,&sample) < 0)
		  return (-1);

		clients = clients->next;
	 }

   return (0);
}

/*
 * Compute the average rate of client over the given period and store it
 * in avg; this simply delegates to stats_calc_rate() on the client's
 * statistics. Returns 0 on success, -1 on failure, in which case
 * abz_get_error() tells what went wrong.
 */
int client_calc_rate (struct rate *avg,const struct client *client,time_t period)
{
   const struct stats *history = &client->stats;

   return stats_calc_rate (avg,history,period);
}

/*
 * Reset client counters. Returns 0 if successful, -1 if some error
 * occurred by calling abz_get_error().
 */
int client_reset (struct client *clients)
{
   abz_clear_error ();

   while (clients != NULL)
	 {
		stats_destroy (&clients->stats);
		clients = clients->next;
	 }

   if (!iptc_zero_entries (priv.chain,&priv.handle))
	 {
		abz_set_error ("%s",iptc_strerror (errno));
		return (-1);
	 }

   return (commit ());
}

