// Copyright (C) 2006-2009 Kent-Andre Mardal and Simula Research Laboratory.
// Licensed under the GNU GPL Version 2, or (at your option) any later version.

#include "symbol_factory.h"
#include "utilities.h"

#include <stdexcept>
#include <sstream>
#include <map>
using namespace std;
using namespace GiNaC;

namespace SyFi
{

	// TODO: make structure for these
	// (the commented-out "namespace space" below sketches one possible grouping)
	//namespace space
	//{

	// Number of spatial dimensions; 2 until initSyFi() stores its argument here.
	unsigned int nsd = 2;
	//GiNaC::symbol p[4];

	// Global coordinate (x, y, z) and time (t) symbols. Until initSyFi() is called
	// they carry placeholder names that say so; initSyFi() overwrites all four with
	// symbols obtained from the symbol factory (get_symbol).
	GiNaC::symbol x("(x is not initialized since initSyFi has never been called)");
	GiNaC::symbol y("(y is not initialized since initSyFi has never been called)");
	GiNaC::symbol z("(z is not initialized since initSyFi has never been called)");
	GiNaC::symbol t("(t is not initialized since initSyFi has never been called)");
	// List of the active coordinate symbols; emptied and refilled by initSyFi().
	GiNaC::lst    p;
	//}

	// Further global symbols; like the ones above, they hold placeholder names until
	// initSyFi() replaces them with get_symbol("infinity") and get_symbol("DUMMY").
	GiNaC::symbol infinity("(infinity is not initialized since initSyFi has never been called)");
	GiNaC::symbol DUMMY("(DUMMY is not initialized since initSyFi has never been called)");

	// Initialize global variables of SyFi
	void initSyFi(unsigned int nsd_)
	{
		// initSyFi uses the global coordinates x      for nsd == 1
		// initSyFi uses the global coordinates x,y    for nsd == 2
		// initSyFi uses the global coordinates x,y,z  for nsd == 3
		// when nsd > 3 the coordinates can be found in the p, which is of type lst

		// FIXME: this whole thing is just a mess, but it's a nontrivial job to fix it all over syfi...

		SyFi::nsd      = nsd_;
		SyFi::t        = get_symbol("t");

		SyFi::infinity = get_symbol("infinity");
		SyFi::DUMMY    = get_symbol("DUMMY");

		SyFi::x        = get_symbol("(SyFi::x is not initialized)");
		SyFi::y        = get_symbol("(SyFi::y is not initialized)");
		SyFi::z        = get_symbol("(SyFi::z is not initialized)");

		/*
		std::cout << "SyFi::p before remove_all:" << std::endl;
		std::cout << SyFi::p << std::endl;
		*/

		SyFi::p.remove_all();

		/*
		std::cout << "SyFi::p after remove_all:" << std::endl;
		std::cout << SyFi::p << std::endl;
		*/

		if ( nsd  > 3 )
		{
			SyFi::x = get_symbol("(SyFi::x is an invalid symbol when nsd>3)");
			SyFi::y = get_symbol("(SyFi::y is an invalid symbol when nsd>3)");
			SyFi::z = get_symbol("(SyFi::z is an invalid symbol when nsd>3)");

			ex tmp = get_symbolic_vector(nsd, "x");
			for (unsigned int i=0; i<tmp.nops(); i++)
			{
				p.append(tmp.op(i));
			}
		}
		else
		{
			if ( nsd  > 0 )
			{
				SyFi::x = get_symbol("x");
				SyFi::p.append(SyFi::x);
			}
			if ( nsd  > 1 )
			{
				SyFi::y = get_symbol("y");
				SyFi::p.append(SyFi::y);
			}
			if ( nsd  > 2 )
			{
				SyFi::z = get_symbol("z");
				SyFi::p.append(SyFi::z);
			}
		}

		/*
		std::cout << "SyFi::p at end of initSyFi:" << std::endl;
		std::cout << SyFi::p << std::endl;
		*/
	}

	// =========== symbol factory implementation from ginac manual page 14-15
	//
	// GiNaC symbols that are constructed separately are distinct objects even when
	// they share a name, so SyFi hands out symbols through a factory instead.
	// symbol_collection maps each requested name to the single symbol that
	// get_symbol() created for it. Nothing in this file ever erases an entry.

	map<string, symbol> symbol_collection;

	// Returns true if a symbol named `name` is already registered in symbol_collection.
	// Within this file, get_symbol() is the only function that registers names.
	bool symbol_exists(const string & name)
	{
		return symbol_collection.count(name) > 0;
	}

	// Symbol factory: returns the one symbol registered under `name`, constructing and
	// registering a new symbol(name) only on the first request for that name.
	// The returned reference points into symbol_collection; std::map never relocates
	// its elements, so the reference stays valid for as long as the entry is not erased.
	const symbol & get_symbol(const string & name)
	{
		// lower_bound yields either the existing entry or the position where a new one
		// belongs; that position is then reused as the insertion hint.
		map<string, symbol>::iterator slot = symbol_collection.lower_bound(name);
		const bool missing = (slot == symbol_collection.end())
		                     || symbol_collection.key_comp()(name, slot->first);
		if (missing)
		{
			// Only construct a new symbol on a miss, exactly like the find-then-insert form.
			slot = symbol_collection.insert(slot, map<string, symbol>::value_type(name, symbol(name)));
		}
		return slot->second;
	}

	// Returns the factory symbol whose name is produced by istr(a, b)
	// from the base name `a` and the single index `b`.
	const symbol & isymb(const string & a, int b)
	{
		const string indexed_name = istr(a, b);
		return get_symbol(indexed_name);
	}

	// Returns the factory symbol whose name is produced by istr(a, b, c)
	// from the base name `a` and the index pair (`b`, `c`).
	const symbol & isymb(const string & a, int b, int c)
	{
		const string indexed_name = istr(a, b, c);
		return get_symbol(indexed_name);
	}

	// Returns an m x 1 GiNaC::matrix (wrapped in a GiNaC::ex) whose i-th entry is the
	// factory symbol isymb(basename, i). Because the entries come from the symbol
	// factory, they are the same symbol objects any other request for those names gets.
	//
	// Throws std::invalid_argument if m is negative: GiNaC::matrix's constructor takes
	// unsigned dimensions, so a negative m would otherwise wrap around to a huge row
	// count and trigger an enormous allocation.
	GiNaC::ex get_symbolic_vector(int m, const std::string & basename)
	{
		if (m < 0)
		{
			std::ostringstream msg;
			msg << "get_symbolic_vector: number of rows must be non-negative, got " << m;
			throw std::invalid_argument(msg.str());
		}

		GiNaC::matrix A(m, 1);
		for (int i = 0; i < m; i++)
		{
			A.set(i, 0, isymb(basename, i));
		}
		return A;
	}

	// Returns an m x n GiNaC::matrix (wrapped in a GiNaC::ex) whose entry (i, j) is the
	// factory symbol isymb(basename, i, j). Because the entries come from the symbol
	// factory, they are the same symbol objects any other request for those names gets.
	//
	// Throws std::invalid_argument if m or n is negative: GiNaC::matrix's constructor
	// takes unsigned dimensions, so a negative value would otherwise wrap around to a
	// huge size and trigger an enormous allocation.
	GiNaC::ex get_symbolic_matrix(int m, int n, const std::string & basename)
	{
		if (m < 0 || n < 0)
		{
			std::ostringstream msg;
			msg << "get_symbolic_matrix: dimensions must be non-negative, got "
			    << m << " x " << n;
			throw std::invalid_argument(msg.str());
		}

		GiNaC::matrix A(m, n);
		for (int i = 0; i < m; i++)
		{
			for (int j = 0; j < n; j++)
			{
				A.set(i, j, isymb(basename, i, j));
			}
		}
		return A;
	}

}								 //namespace SyFi
