"""This module handles the Expression class in Python.

The Expression class needs special handling and is not mapped directly
by SWIG from the C++ interface. Instead, a new Expression class is
created which inherits both from the DOLFIN C++ Expression class and
the ufl Coefficient class.

The resulting Expression class may thus act both as a variable in a
UFL form expression and as a DOLFIN C++ Expression.

This module makes heavy use of creating Expression classes and
instantiating them dynamically at runtime.

The whole logic behind this somewhat magic behaviour is handled by:

  1) function __new__ in the Expression class
  2) meta class ExpressionMetaClass
  3) function compile_expressions from the compilemodules/expressions
     module
  4) functions create_compiled_expression_class and
     create_python_derived_expression_class

The __new__ method in the Expression class takes care of the logic
when the class Expression is used to create an instance of Expression,
see use cases 1-2 in the docstring of Expression.

The meta class ExpressionMetaClass takes care of the logic when a user
subclasses Expression to create a user-defined Expression, see use
case 3 in the docstring of Expression.

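The function compile_expressions is a JIT compiler. It compiles and
returns different kinds of cpp.Expression classes, depending on the
arguments. These classes are passed to the function
create_compiled_expression_class, which wraps each of them in a Python
class that also inherits from Expression and ufl.Coefficient.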
"""

__author__ = "Johan Hake (hake@simula.no)"
__date__ = "2008-11-03 -- 2009-12-16"
__copyright__ = "Copyright (C) 2008-2009 Johan Hake"
__license__  = "GNU LGPL Version 2.1"

# Modified by Anders Logg, 2008-2009.

__all__ = ["Expression", "Expressions"]

# FIXME: Make all error messages uniform according to the following template:
#
# if not isinstance(foo, Foo):
#     raise TypeError, "Illegal argument for creation of Bar, not a Foo: " + str(foo)

# Python imports
import types

# Import UFL and SWIG-generated extension module (DOLFIN C++)
import ufl
import dolfin.cpp as cpp

# Local imports
from dolfin.compilemodules.expressions import compile_expressions

def create_compiled_expression_class(cpp_base):
    """Return a Python Expression class wrapping a JIT-compiled cpp.Expression.

    The returned class is named 'CompiledExpression' and inherits, in this
    order, from dolfin.Expression, ufl.Coefficient and cpp_base.

    @param cpp_base:
        A compiled cpp.Expression class (as returned by compile_expressions)
    """
    # The compiled base must be a (new- or old-style) class
    assert(isinstance(cpp_base, (types.ClassType, type)))

    def __init__(self, cppcode, defaults=None, element=None, cell=None, degree=None):
        """ Initialize the Expression """
        # The C++ base class has to be initialized before its value rank and
        # value dimensions can be queried
        cpp_base.__init__(self)
        rank = self.value_rank()
        value_shape = tuple([self.value_dimension(i) for i in range(rank)])

        # Store the dim: 0 for scalars, an int for vectors and the full
        # shape tuple for tensors
        ndims = len(value_shape)
        if ndims > 1:
            self._dim = value_shape
        elif ndims == 1:
            self._dim = value_shape[0]
        else:
            self._dim = 0

        if element is not None:
            # A user-provided element must be a UFL element ...
            if not isinstance(element, ufl.FiniteElementBase):
                raise TypeError("The 'element' argument must be a UFL finite element.")

            # ... whose value shape agrees with the compiled expression
            if not (element.value_shape() == value_shape):
                raise ValueError("The value shape of the passed 'element', is not equal to the value shape of the compiled expression.")
        else:
            # No element given: derive one from the value shape
            element = _auto_select_element_from_shape(value_shape, cell, degree)

        # Finally initialize the UFL base class with the chosen element
        self._ufl_element = element
        ufl.Coefficient.__init__(self, self._ufl_element)

    class_bases = (Expression, ufl.Coefficient, cpp_base)
    class_dict = {"__init__": __init__}
    return type("CompiledExpression", class_bases, class_dict)

def create_python_derived_expression_class(name, user_bases, user_dict):
    """Return a Python-derived Expression class

    This function creates the class used when a user subclasses Expression
    in Python (see ExpressionMetaClass). The returned class inherits
    dolfin.Expression, ufl.Coefficient and cpp.Expression, followed by any
    extra user-provided base classes.

    @param name:
        The name of the class
    @param user_bases:
        List of extra user-defined base classes (without Expression)
    @param user_dict:
        Dict with user specified functions or attributes. NOTE: this dict is
        modified in place; a user '__init__' is popped from it and replaced
        by the generated '__init__'.
    """

    # Check args
    assert(isinstance(name, str))
    assert(isinstance(user_bases, list))
    assert(isinstance(user_dict, dict))

    # Define the bases: the fixed PyDOLFIN bases first, then the user's
    assert(all([isinstance(t, (types.ClassType, type)) for t in user_bases]))
    bases = tuple([Expression, ufl.Coefficient, cpp.Expression] + user_bases)

    # If a user init is not provided create a dummy one
    if "__init__" not in user_dict:
        def user_init(self, *arg, **kwargs): pass
    else:
        user_init = user_dict.pop("__init__")

    def __init__(self, *args, **kwargs):

        # Get element, cell and degree. kwargs is left untouched and is
        # forwarded as-is to the user's __init__ below.
        element = kwargs.get("element", None)
        cell = kwargs.get("cell", None)
        degree = kwargs.get("degree", None)

        # Select an appropriate element if not specified. The value shape
        # comes from self.dim(), which is 0 (scalar) unless the subclass
        # overrides dim() or _dim.
        if element is None:
            element = _auto_select_element_from_dim(self.dim(), cell, degree)
        elif not isinstance(element, ufl.FiniteElementBase):
            raise TypeError("The 'element' argument must be a UFL finite element.")

        # Initialize UFL base class
        self._ufl_element = element
        ufl.Coefficient.__init__(self, self._ufl_element)

        # Initialize the cpp.Expression base with the element's value shape
        cpp.Expression.__init__(self, list(self._ufl_element.value_shape()))

        # Calling the user defined __init__
        user_init(self, *args, **kwargs)

    # Set the doc string of the init function. Every function has a __doc__
    # attribute (None when it has no docstring), so test the value rather
    # than the existence of the attribute.
    __init__.__doc__ = user_init.__doc__ or """ Initialize the Expression"""

    # NOTE: Users are deliberately not prevented from overloading attributes
    # "reserved" by PyDOLFIN (those of ufl.Coefficient and cpp.Expression).
    # Whether such a check should be added is an open question.

    # Add __init__ to the user_dict
    user_dict["__init__"] = __init__

    # Create the class and return it
    return type(name, bases, user_dict)

class ExpressionMetaClass(type):
    """ Meta Class for Expression

    Every class that has Expression (directly or indirectly) as a base is
    created through this meta class. It distinguishes three cases:

      1) the Expression class defined in this module,
      2) fully fledged derived classes whose bases start with
         (Expression, ufl.Coefficient, <subclass of cpp.Expression>), as
         created by create_compiled_expression_class and
         create_python_derived_expression_class,
      3) user subclasses of Expression that overload 'eval' or 'eval_data';
         these are turned into case 2) by
         create_python_derived_expression_class.
    """
    def __new__(mcs, name, bases, dict_):
        """ Returns a new Expression class """

        assert(isinstance(name, str)), "Expecting a 'str'"
        assert(isinstance(bases, tuple)), "Expecting a 'tuple'"
        assert(isinstance(dict_, dict)), "Expecting a 'dict'"

        # First check if we are creating the Expression class
        if name == "Expression":
            # The module-level Expression derives from object only. Any other
            # base means the user has tried to:
            #
            #    class Expression(Expression):
            #        ...
            #
            # Note that bases is then (Expression,), of length 1, so the
            # check must not depend on len(bases) > 1.
            if any(base is not object for base in bases):
                raise TypeError("Cannot name a subclass of Expression: 'Expression'")

            # Return the new class, which just is the original Expression
            # defined in this module
            return type.__new__(mcs, name, bases, dict_)

        # If creating a fullfledged derived expression class, i.e, inheriting
        # dolfin.Expression, ufl.Coefficient and cpp.Expression (or a subclass)
        # then just return the new class.
        if len(bases) >= 3 and bases[0] == Expression and \
               bases[1] == ufl.Coefficient and issubclass(bases[2], cpp.Expression):
            # Return the instantiated class
            return type.__new__(mcs, name, bases, dict_)

        # Otherwise this is a user subclass of Expression. Keep the user's
        # extra base classes; Expression itself is re-added, in the right
        # position, by create_python_derived_expression_class.
        user_bases = list(bases)
        user_bases.remove(Expression)

        # Check the user has provided either an eval or eval_data method
        if not ('eval' in dict_ or 'eval_data' in dict_):
            raise TypeError("expected an overload 'eval' or 'eval_data' method")

        # Get name of eval function ('eval' wins if both are given)
        eval_name = 'eval' if 'eval' in dict_ else 'eval_data'

        user_eval = dict_[eval_name]

        # The overload must be a plain function taking exactly three
        # arguments, e.g. (self, values, x) or (self, values, data)
        if not isinstance(user_eval, types.FunctionType):
            raise TypeError("'%s' attribute must be a 'function'" % eval_name)
        if not user_eval.func_code.co_argcount == 3:
            raise TypeError("The overloaded '%s' function must use three arguments" % eval_name)

        return create_python_derived_expression_class(name, user_bases, dict_)

#--- The user interface ---

# Placed here so it can be reused in Function
def expression__call__(self, *args, **kwargs):
    """ Evaluates the Expression

    Example of use:
    1) Using an iterable as x:

    >>> fs = Expression("sin(x[0])*cos(x[1])*sin(x[2])")
    >>> x0 = (1.,0.5,0.5)
    >>> x1 = [1.,0.5,0.5]
    >>> x2 = numpy.array([1.,0.5,0.5])
    >>> v0 = fs(x0)
    >>> v1 = fs(x1)
    >>> v2 = fs(x2)

    2) Using multiple scalar args for x, interpreted as a point coordinate
    >>> v0 = fs(1.,0.5,0.5)

    3) Passing return array
    >>> fv = Expression(("sin(x[0])*cos(x[1])*sin(x[2])",
                         "2.0","0.0"))
    >>> x0 = numpy.array([1.,0.5,0.5])
    >>> v0 = numpy.zeros(3)
    >>> fv(x0, values = v0)

    Note: The 'values' array must have exactly as many entries as the value
          size of the Expression, but it may be a view (slice) of a larger
          array. In this way one can quickly fill up an array with several
          evaluations:
    >>> x = numpy.random.random(9)
    >>> values = numpy.zeros(9)
    >>> for i in xrange(0,9,3):
            fv(x[i:i+3], values = values[i:i+3])

    A single '+' or '-' argument (UFL restriction), or two arguments where
    the second is a dict containing this Expression (UFL mapping), are
    forwarded to ufl.Coefficient.__call__.

    Returns a scalar when the value size is 1 and no 'values' array was
    passed; otherwise the 'values' array.

    Raises TypeError for missing or malformed coordinate arguments, or when
    'values' is not a double NumPy array of the right length.
    """
    import numpy
    if len(args)==0:
        raise TypeError("expected at least 1 argument")

    # Test for ufl restriction. A NumPy coordinate array must not be compared
    # to the strings: that yields an element-wise bool array whose truth
    # value is ambiguous (ValueError).
    if len(args) == 1 and not isinstance(args[0], numpy.ndarray) \
           and args[0] in ('+','-'):
        return ufl.Coefficient.__call__(self, *args)

    # Test for ufl mapping
    if len(args) == 2 and isinstance(args[1], dict) and self in args[1]:
        return ufl.Coefficient.__call__(self, *args)

    # Some help variables
    value_size = ufl.common.product(self.ufl_element().value_shape())

    # If values (return argument) is passed, check the type and length
    values = kwargs.get("values", None)
    if values is not None:
        if not isinstance(values, numpy.ndarray):
            raise TypeError("expected a NumPy array for 'values'")
        if len(values) != value_size or not numpy.issubdtype(values.dtype, 'd'):
            raise TypeError("expected a double NumPy array of length %d for return values." % value_size)
        values_provided = True
    else:
        values_provided = False
        values = numpy.zeros(value_size, dtype='d')

    # Assume all args are x argument
    x = args

    # If only one x argument has been provided
    if len(x) == 1:
        # Check coordinate argument
        if not isinstance(x[0], (int, float, numpy.ndarray, list, tuple)):
            raise TypeError("expected a scalar or an iterable as coordinate argument")
        # Check for scalar x
        if isinstance(x[0], (int, float)):
            x = numpy.fromiter(x, 'd')
        else:
            x = x[0]
            if isinstance(x, (list, tuple)):
                x = numpy.fromiter(x, 'd')

    # If several x arguments have been provided
    else:
        if not all(isinstance(v, (int, float)) for v in x):
            raise TypeError("expected different number of scalar arguments for the coordinates")
        x = numpy.fromiter(x,'d')

    if len(x) == 0:
        raise TypeError("coordinate argument too short")

    # The actual evaluation
    self.eval(values, x)

    # If scalar return statement, return scalar value.
    if value_size == 1 and not values_provided:
        return values[0]

    return values

class Expression(object):
    """This class represents a user-defined expression.

    Expressions can be used as coefficients in variational forms or
    interpolated into finite element spaces.

    Arguments
    ---------
    @param cppcode:
        C++ argument, see below
    @param defaults:
        Optional C++ argument, see below
    @param element:
        Optional element argument
    @param cell:
        Optional element cell, used together with 'degree' to select an
        element automatically when 'element' is not given
    @param degree:
        Optional element degree, see 'cell'

    1. Simple user-defined JIT-compiled expressions
    -----------------------------------------------

    One may specify C++ code for evaluation of the Expression as follows:

    >>> f0 = Expression('sin(x[0]) + cos(x[1])')
    >>> f1 = Expression(('cos(x[0])', 'sin(x[1])'), element = V.ufl_element())

    Here, f0 is scalar and f1 is vector-valued.

    Tensor expressions of rank 2 (matrices) may also be created:

    >>> f2 = Expression((('exp(x[0])','sin(x[1])'),
                        ('sin(x[0])','tan(x[1])')))

    In general, a single string expression will be interpreted as a
    scalar, a tuple of strings as a tensor of rank 1 (a vector) and a
    tuple of tuples of strings as a tensor of rank 2 (a matrix).

    The expressions may depend on x[0], x[1], and x[2] which carry
    information about the coordinates where the expression is
    evaluated. All math functions defined in <cmath> are available to
    the user.

    Expression parameters can be included as follows:

    >>> f = Expression('A*sin(x[0]) + B*cos(x[1])')
    >>> f.A = 2.0
    >>> f.B = 4.0

    The parameters can only be scalars, and are all initialized to 0.0. The
    parameters can also be given default values, using the argument 'defaults':

    >>> f = Expression('A*sin(x[0]) + B*cos(x[1])',
                       defaults = {'A': 2.0,'B': 4.0})

    2. Complex user-defined JIT-compiled Expressions
    ------------------------------------------------

    One may also define an Expression using more complicated logic with
    the 'cppcode'. This argument should be a string of C++
    code that implements a class that inherits from dolfin::Expression.

    The following code illustrates how to define an Expression that depends
    on material properties of the cells in a Mesh. A MeshFunction is
    used to mark cells with different properties.

    Note the use of the 'data' parameter.

    >>> code = '''
    class MyFunc : public Expression
    {
    public:

      MeshFunction<uint> *cell_data;

      MyFunc() : Expression(2), cell_data(0)
      {
      }

      void eval(Array<double>& values, const Data& data) const
      {
        assert(cell_data);
        switch ((*cell_data)[data.cell()])
        {
        case 0:
          values[0] = exp(-data.x[0]);
          break;
        case 1:
          values[0] = exp(-data.x[2]);
          break;
        default:
          values[0] = 0.0;
        }
      }

    };'''

    >>> cell_data = MeshFunction('uint', V.mesh(), 2)
    >>> f = Expression(code)
    >>> f.cell_data = cell_data

    3. User-defined expressions by subclassing
    ------------------------------------------

    The user can subclass Expression and overload the 'eval' function. The
    subclass can then be instantiated directly. Optionally, a
    ufl.FiniteElement may be passed using the 'element' keyword argument;
    otherwise an element is selected automatically (see 'cell' and 'degree'):

    >>> class MyExpression0(Expression):
            def eval(self, value, x):
                dx = x[0] - 0.5
                dy = x[1] - 0.5
                value[0] = 500.0*exp(-(dx*dx + dy*dy)/0.02)
    >>> f0 = MyExpression0()

    A subclass is scalar valued by default; for a vector- or tensor-valued
    subclass, override the dim() method (or the _dim class attribute).

    The user can also subclass Expression overloading the eval_data function.
    This gives the user access to the more powerful Data structure, with
    e.g., cell, facet and normal information, during assembly.

    >>> class MyExpression1(Expression):
            def eval_data(self, value, data):
                if data.cell().index() > 10:
                    value[0] = 1.0
                else:
                    value[0] = -1.0

    >>> f1 = MyExpression1()

    The user can customize initialization of derived Expressions. However,
    because of the magic behind the scenes, a user needs to pass optional
    arguments to __init__ using **kwargs, and must _not_ call the base class
    __init__:

    >>> class MyExpression2(Expression):
            def __init__(self, **kwargs):
                self._mesh = kwargs['mesh']
                self._domain = kwargs['domain']
            def eval(self, values, x):
                ...

    >>> f2 = MyExpression2(mesh=mesh, domain=domain)

    Note that subclassing may be significantly slower than using JIT-compiled
    expressions. This is because a callback from C++ to Python will be involved
    each time an Expression needs to be evaluated during assembly.
    """

    # Set the meta class
    __metaclass__ = ExpressionMetaClass

    def __new__(cls, cppcode=None, defaults=None, element=None, cell=None, \
                degree=None, **kwargs):
        """ Instantiate a new Expression

        Arguments:
        ----------
        @param cppcode:
          C++ argument.
        @param defaults:
          Optional C++ argument.
        @param element:
          Optional ufl.FiniteElement argument
        @param cell:
          Optional element cell
        @param degree:
          Optional element degree
        """

        # If the __new__ function is called because we are instantiating a
        # python sub class of Expression, then just return a new instance of
        # the passed class
        if cls.__name__ != "Expression":
            return object.__new__(cls)

        # Check arguments
        _check_cppcode(cppcode)
        _check_defaults(defaults)

        # Compile module and get the cpp.Expression class
        cpp_base = compile_expressions([cppcode], [defaults])[0]

        # Store compile arguments for later use
        cpp_base.cppcode = cppcode
        cpp_base.defaults = defaults

        # Return an instance of a freshly created CompiledExpression class;
        # Python then calls that class's __init__ with the same arguments
        return object.__new__(create_compiled_expression_class(cpp_base))

    # This method is only included so a user can check what arguments
    # one should use in IPython using tab completion
    def __init__(self, cppcode=None, defaults=None, element=None, cell=None,
                 degree=None, **kwargs): pass

    # Reuse the docstring from __new__
    __init__.__doc__ = __new__.__doc__

    def ufl_element(self):
        " Return the ufl FiniteElement."
        return self._ufl_element

    def __str__(self):
        "x.__str__() <==> print(x)"
        # FIXME: We might change this using rank and dimension instead
        return "<Expression on a %s>" % str(self._ufl_element)

    def __repr__(self):
        "x.__repr__() <==> repr(x)"
        return ufl.Coefficient.__repr__(self)

    # Default value of dim (0 means scalar valued)
    _dim = 0

    def dim(self):
        """ Returns the dimension of the value"""
        return self._dim

    __call__ = expression__call__

def Expressions(*args, **kwargs):
    """ Batch-processed user-defined JIT-compiled expressions
    -------------------------------------------------------

    By specifying several cppcodes one may compile more than one expression
    at a time:

    >>> f0, f1 = Expressions('sin(x[0]) + cos(x[1])', 'exp(x[1])', degree=3)

    >>> f0, f1, f2 = Expressions((('A*sin(x[0])', 'B*cos(x[1])'),
                                 ('0','1')), {'A':2.0,'B':3.0},
                                 code,
                                 (('cos(x[0])','sin(x[1])'),
                                 ('sin(x[0])','cos(x[1])')), element=element)

    Each cppcode argument may optionally be followed by a dict holding
    default values for its parameters.

    Here code is a C++ code snippet, which should be a string of C++
    code that implements a class that inherits from dolfin::Expression,
    see use case 2 in the Expression docstring.

    The optional keyword arguments 'element', 'cell' and 'degree' are shared
    by all created expressions. When 'element' is given, it must match the
    value shape of every expression, and 'cell' and 'degree' are not used.
    Any other keyword argument raises TypeError.

    Batch-processing of JIT-compiled expressions may significantly speed up
    JIT-compilation at run-time.
    """

    # Get the element, cell and degree from kwargs and reject anything else,
    # so that misspelled keywords (e.g. 'degre') do not pass silently
    element = kwargs.pop("element", None)
    cell = kwargs.pop("cell", None)
    degree = kwargs.pop("degree", None)
    if kwargs:
        raise TypeError("Unexpected keyword argument(s) %s; only 'element', "
                        "'cell' and 'degree' are allowed."
                        % ", ".join(repr(k) for k in sorted(kwargs)))

    # Iterate over the *args and collect input to compile_expressions
    cppcodes = []
    defaults = []
    i = 0
    while i < len(args):

        # Check type of cppcodes
        if not isinstance(args[i], (tuple, list, str)):
            raise TypeError("Expected either a 'list', 'tuple' or 'str' for argument %d" % i)

        cppcodes.append(args[i])
        i += 1

        # If we have more args and the next is a dict
        if i < len(args) and isinstance(args[i], dict):
            # Append the dict to defaults
            _check_defaults(args[i])
            defaults.append(args[i])
            i += 1
        else:
            # If not append None
            defaults.append(None)

    # Compile the cpp.Expressions
    cpp_bases = compile_expressions(cppcodes, defaults)

    # Instantiate the return arguments
    return_expressions = []
    for cppcode, default, cpp_base in zip(cppcodes, defaults, cpp_bases):
        # Store compile arguments for later use, as Expression.__new__ does
        cpp_base.cppcode = cppcode
        cpp_base.defaults = default

        return_expressions.append(create_compiled_expression_class(cpp_base)(
            cppcode,
            cell=cell,
            degree=degree,
            element=element))

    # Return the instantiated Expressions
    return tuple(return_expressions)

#--- Utility functions ---

def _check_cppcode(cppcode):
    "Check that cppcode makes sense"

    # Check that we get a string expression or nested expression
    if not isinstance(cppcode, (str, tuple, list)):
        raise TypeError, "Please provide a 'str', 'tuple' or 'list' for the 'cppcode' argument."

def _check_defaults(defaults):
    "Check that defaults makes sense"

    if defaults is None: return

    # Check that we get a dictionary
    if not isinstance(defaults, dict):
        raise TypeError, "Please provide a 'dict' for the 'defaults' argument."

    # Check types of the values in the dict
    for key, val in defaults.iteritems():
        if not isinstance(key, str):
            raise TypeError, "All keys in 'defaults' must be a 'str'."
        if not isinstance(val, (int, float)):
            raise TypeError, "All values in 'defaults' must be scalars."

def _is_complex_expression(cppcode):
    "Check if cppcode is a complex expression"
    return isinstance(cppcode, str) and "class" in cppcode and "Expression" in cppcode

def _auto_select_element_from_dim(dim, cell=None, degree=None):
    """Automatically select an appropriate element from dim.

    @param dim:
        0 for a scalar, an int n for a vector of length n, or a tuple/list
        giving the value shape of a tensor
    @param cell:
        Optional element cell
    @param degree:
        Optional element degree

    The element itself is chosen by _auto_select_element_from_shape.
    """

    cpp.info("Got expression dimension = " + str(dim))

    # Check dim to get shape
    if isinstance(dim, int):
        shape = () if dim == 0 else (dim,)
    elif isinstance(dim, (tuple, list)):
        # Normalize to a tuple, the form value shapes take elsewhere
        shape = tuple(dim)
    else:
        msg = "Expecting shape to be an integer or a tuple/list, not %s." % str(type(dim))
        cpp.error(msg)
        # Make sure we never continue with an unbound 'shape', even if
        # cpp.error returns instead of raising
        raise TypeError(msg)

    return _auto_select_element_from_shape(shape, cell, degree)

def _auto_select_element_from_shape(shape, cell=None, degree=None):
    """Automatically select an appropriate Lagrange element from a value shape.

    An empty shape gives a ufl.FiniteElement, a rank-1 shape a
    ufl.VectorElement with dim=shape[0], and any higher rank a
    ufl.TensorElement with the given shape. 'cell' and 'degree' are passed
    straight on to the UFL element constructor.
    """
    # Default element family, change to quadrature when working
    family = "Lagrange"

    rank = len(shape)
    if rank > 1:
        element = ufl.TensorElement(family, cell, degree, shape=shape)
    elif rank == 1:
        element = ufl.VectorElement(family, cell, degree, dim=shape[0])
    else:
        element = ufl.FiniteElement(family, cell, degree)

    cpp.info("Automatic selection of expression element: " + str(element))
    return element

def _check_name_and_base(name, cpp_base):
    # Check the name
    assert(isinstance(name, str))
    assert(name != "Expression"), "Cannot create a sub class of Expression with the same name as Expression"

    assert(isinstance(cpp_base, (types.ClassType, type)))
