#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains functions and classes for converting
UFL elements to internal SyFi element representations.
"""


__author__ = "Martin Sandve Alnes"
__date__   = "2008-08-13 -- 2008-12-16"
__copyright__ = "(C) 2008 Martin Sandve Alnes and Simula Research Laboratory"
__license__  = "GNU GPL Version 2, or (at your option) any later version"


import operator
import ufl
import swiginac
import SyFi

from ufl.common import component_to_index, index_to_component

from sfc.symbolic_utils import grad, ddx, symbols, symbolic_matrix
from sfc.geometry import UFCCell
from sfc.common.names import finite_element_classname, dof_map_classname
from sfc.common import sfc_assert, sfc_info, sfc_warning, sfc_error, sfc_debug
from sfc.common.options import default_options

def product(sequence):
    """Return the product of all items in sequence (1 for an empty sequence).

    Used in this module to compute the flattened value size from a value shape tuple.
    """
    # functools.reduce exists on Python 2.6+ and Python 3; the builtin reduce is Python 2 only.
    from functools import reduce
    return reduce(operator.__mul__, sequence, 1)

# TODO: Replace nsd with topological and geometric dimensions
# everywhere for clarity and future flexibility


def test_polygon(polygon):
    """Debug helper: print the type, attribute list, no_space_dim() and str() of polygon."""
    # Single-argument print(...) produces identical output under Python 2's print
    # statement and Python 3's print function.
    print(type(polygon))
    print(dir(polygon))
    print(polygon.no_space_dim())
    print(polygon.str())

def create_syfi_polygon(cell):
    """Create the SyFi reference polygon for a cell name.

    cell: one of "interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron".
    Returns a new SyFi reference polygon object.
    Raises ValueError for any other cell name.
    """
    sfc_debug("Entering create_syfi_polygon")
    if   cell == "interval":       p = SyFi.ReferenceLine()
    elif cell == "triangle":       p = SyFi.ReferenceTriangle()
    elif cell == "tetrahedron":    p = SyFi.ReferenceTetrahedron()
    elif cell == "quadrilateral":  p = SyFi.ReferenceRectangle()
    elif cell == "hexahedron":     p = SyFi.ReferenceBox()
    else:
        # Raising a bare string is a TypeError since Python 2.6, which hid this
        # message; raise a real exception instance instead.
        raise ValueError("Unknown element cell '%s'." % cell)
    sfc_debug("Leaving create_syfi_polygon")
    return p

def create_syfi_element(e, polygon):
    """Create a basic (non-mixed) SyFi element corresponding to the UFL element e.

    e: a UFL element that is not a ufl.MixedElement.
    polygon: the SyFi reference polygon the element is defined on.
    Returns the SyFi element object.
    Raises NotImplementedError for UFL families that have no SyFi counterpart
    here, and for unknown family names.
    """
    sfc_debug("Entering create_syfi_element")

    sfc_assert(not isinstance(e, ufl.MixedElement), "Only creating SyFi elements for basic elements.")
    f = e.family()
    d = e.degree()

    # Families (full and short names) recognized but not supported by this converter.
    unsupported_families = (
        "Brezzi-Douglas-Marini", "BDM",
        "Brezzi-Douglas-Fortin-Marini", "BDFM",
        "Raviart-Thomas", "RT",
        "Nedelec 1st kind H(div)", "N1div",
        "Nedelec 2nd kind H(div)", "N2div",
        "Nedelec 1st kind H(curl)", "N1curl",
        "Nedelec 2nd kind H(curl)", "N2curl",
        "Quadrature", "Q",
        "Boundary Quadrature", "BQ",
        )

    if f in ("Lagrange", "CG"):
        fe = SyFi.Lagrange(polygon, d)
    elif f in ("Discontinuous Lagrange", "DG"):
        # Degree 0 uses SyFi.P0 rather than SyFi.DiscontinuousLagrange.
        if d == 0:
            fe = SyFi.P0(polygon, 0)
        else:
            fe = SyFi.DiscontinuousLagrange(polygon, d)
    elif f in ("Crouzeix-Raviart", "CR"):
        fe = SyFi.CrouzeixRaviart(polygon, d)
    elif f in ("Bubble", "B"):
        fe = SyFi.Bubble(polygon, d)
    elif f in unsupported_families:
        raise NotImplementedError("Not implemented element family '%s'." % f)
    else:
        raise NotImplementedError("Unknown element family '%s'." % f)

    sfc_debug("Leaving create_syfi_element")
    return fe



#===============================================================================
# # swiginac.matrix(nsd, 1, dof_xi(fe,i))
# def dof_xi(fe, i):
#    """Return a swiginac column vector with the reference coordinates of dof i in fe,
#       assuming elements with point evaluation dofs."""
#    dofi = fe.dof(i)
#    # check if the element is a scalar or vector element
#    if isinstance(dofi[0], swiginac.numeric):
#        dofi0 = dofi
#    elif isinstance(dofi[0], list):
#        dofi0 = dofi[0]
#    return dofi0
#===============================================================================


class ElementRepresentation(object):
    """SFC's internal representation of a UFL finite element.

    Holds the UFL element together with the SyFi element (basic elements only),
    the SyFi reference polygon and UFC cell, value shape information, sub element
    representations and numbering offsets (mixed elements), reference and global
    dof coordinates, dof/entity and dof/facet relations, and caches for symbolic
    basis functions and their derivatives.
    """
    __slots__ = (#Administrative data:
                 "options", "signature",
                 "dof_map_classname", "finite_element_classname",
                 # Element in other representations:
                 "ufl_element", "syfi_element",
                 # Cell data:
                 "quad_rule", "ufl_cell", "polygon", "cell",
                 # Dimensions:
                 "local_dimension", "geometric_dimension", "topological_dimension", 
                 # Value shape info:
                 "value_shape", "value_rank", "value_size", "value_components",
                 # Subelement data
                 "sub_elements", "sub_element_dof_offsets", "sub_element_value_offsets",
                 # Caches for basis functions
                 "_basis_function_cache", "_basis_function_derivative_cache",
                 # Coordinate information:
                 "dof_xi", "dof_x",
                 # Topology information:
                 "entity_dofs", "dof_entity", "num_entity_dofs", "facet_dofs", "num_facet_dofs",
                 # Geometry symbols:
                 "p0", "p", "G", "GinvT",
                 )
    def __init__(self, ufl_element, quad_rule=None, options=None, cache=None):
        """Build the representation of ufl_element.

        ufl_element: a ufl.FiniteElementBase instance.
        quad_rule: quadrature rule providing num_points and points; required when
                   this element (or a sub element) has family "Quadrature".
        options: SFC options, defaults to default_options().
        cache: optional dict from UFL sub elements to ElementRepresentation objects,
               shared through the recursion so a sub element seen before is reused.
        """
        sfc_debug("Entering ElementRepresentation.__init__")
        
        # Handle input and default values
        sfc_assert(isinstance(ufl_element, ufl.FiniteElementBase),
                   "Expecting a ufl.FiniteElementBase instance.")
        self.ufl_element = ufl_element
        self.quad_rule = quad_rule
        if options is None:
            self.options = default_options() 
        else:
            self.options = options
        if cache is None:
            cache = {}
        
        # Some derived strings
        self.signature = repr(self.ufl_element)
        self.dof_map_classname = dof_map_classname(self.ufl_element)
        self.finite_element_classname = finite_element_classname(self.ufl_element)
        
        # Geometry information
        self.ufl_cell = self.ufl_element.cell()
        self.polygon = create_syfi_polygon(self.ufl_cell.domain())
        self.cell = UFCCell(self.polygon)
        
        # Both dimensions are taken from nsd for now (see the module-level TODO).
        self.geometric_dimension = self.cell.nsd 
        self.topological_dimension = self.cell.nsd    
        
        # Handy information about value shape
        self.value_shape = self.ufl_element.value_shape()
        self.value_rank  = len(self.value_shape)
        self.value_size  = product(self.value_shape)
        self.value_components = ufl.permutation.compute_indices(self.value_shape)
        
        # Representations of subelements
        self.sub_elements = []
        self.sub_element_dof_offsets = []
        self.sub_element_value_offsets = []
        
        if isinstance(self.ufl_element, ufl.MixedElement):
            dof_offset = 0
            value_size_offset = 0
            
            # Create ElementRepresentation objects for subelements, reuse if possible
            for se in ufl_element.sub_elements():
                rep = cache.get(se, None)
                if rep is None:
                    rep = ElementRepresentation(se, self.quad_rule, self.options, cache)
                    cache[se] = rep
                self.sub_elements.append(rep)
                
                # Determine numbering offsets of subelements for dofs and values
                self.sub_element_dof_offsets.append(dof_offset)
                self.sub_element_value_offsets.append(value_size_offset)
                
                dof_offset += rep.local_dimension
                value_size_offset += rep.value_size
            
            # Appending final sizes makes some algorithms more elegant
            self.sub_element_dof_offsets.append(dof_offset)
            self.sub_element_value_offsets.append(value_size_offset)
            
            # No SyFi element, local dimension is the sum of subelement dimensions
            self.syfi_element = None
            self.local_dimension = dof_offset
        
        elif self.ufl_element.family() == "Quadrature":
            # No SyFi element, local dimension is the number of quadrature points
            sfc_assert(self.quad_rule is not None, "Quadrature elements require a quadrature rule.")
            self.syfi_element = None
            self.local_dimension = self.quad_rule.num_points
        
        elif self.ufl_element.family() == "Boundary Quadrature":
            # No SyFi element, local dimension is the number of quadrature points (TODO: On one facet or on all facets?)
            sfc_error("Boundary Quadrature elements not implemented!")
            # NOTE(review): only reached if sfc_error returns; facet_quad_rule is not in
            # __slots__ and is never assigned, so the line below would raise AttributeError.
            self.syfi_element = None
            self.local_dimension = self.facet_quad_rule.num_points # TODO: *num_facets?
        
        else:
            # Make SyFi element
            self.syfi_element = create_syfi_element(self.ufl_element, self.polygon)
            self.local_dimension = self.syfi_element.nbf()
        
        # utility symbols
        self._def_symbols()
        
        # compute dof coordinates
        self._precomp_coords()
        
        # compute dof entity relations
        self._build_entity_dofs()
        self._build_facet_dofs()
        
        # initialize cache structures
        self._basis_function_cache = {}
        self._basis_function_derivative_cache = {}
        
        sfc_debug("Leaving ElementRepresentation.__init__")
    
    def _def_symbols(self):
        """Create the coordinate symbols p0, p and the symbolic matrices G, GinvT."""
        nsd = self.cell.nsd
        
        # ... x,y,z symbols for convenience
        self.p0 = swiginac.matrix(nsd, 1, symbols(["x0", "y0", "z0"][:nsd]))
        self.p  = swiginac.matrix(nsd, 1, symbols(["x", "y", "z"][:nsd]))
        #self.x  = swiginac.matrix(nsd, 1, symbols(["x0", "x1", "x2"][:nsd])) # TODO: Use these everywhere! self.x are global coordinates, self.xi are locals.
        #self.xi = swiginac.matrix(nsd, 1, symbols(["xi0", "xi1", "xi2"][:nsd]))
        
        # ... affine mapping symbols for convenience # TODO: Use J instead?
        self.G     = symbolic_matrix(nsd, nsd, "G")
        self.GinvT = symbolic_matrix(nsd, nsd, "GinvT")

    def _dof_xi(self, i):
        """Compute the reference coordinates of local dof i as an nsd x 1 swiginac matrix.

        Mixed elements delegate to the sub element owning dof i, quadrature
        elements use quadrature point i, and basic elements use the coordinates
        from the SyFi dof description (the midpoint when it lists several points).
        """
        nsd = self.cell.nsd
        
        if self.sub_elements:
            sub_element_index = self.dof_to_sub_element_index(i)
            sub_dof = i - self.sub_element_dof_offsets[sub_element_index]
            return self.sub_elements[sub_element_index]._dof_xi(sub_dof)
        
        if self.ufl_element.family() == "Quadrature":
            return swiginac.matrix(nsd, 1, self.quad_rule.points[i])
        
        # check if the element is a scalar or vector element
        dofi = self.syfi_element.dof(i)
        if isinstance(dofi[0], swiginac.numeric):
            # scalar element
            dof_xi_list = dofi
        elif isinstance(dofi[0], list):
            # vector element
            points = dofi[0]
            if isinstance(points[0], list):
                # dofi[0] is a list of points: use their midpoint.
                # (Names are chosen so nothing shadows the dof index i, since
                # comprehension variables leak into the function scope in Python 2.)
                midpoint = [0] * len(points[0])
                for point in points:
                    for p, dp in enumerate(point):
                        midpoint[p] += dp
                for p in range(len(midpoint)):
                    midpoint[p] /= len(points)
                dof_xi_list = midpoint
            else:
                # use coordinate directly
                dof_xi_list = points
        else:
            sfc_error("Unexpected dof format for dof %d: %s" % (i, repr(dofi)))
        return swiginac.matrix(nsd, 1, dof_xi_list)
    
    def _precomp_coords(self):
        """Precompute reference (dof_xi) and global (dof_x) coordinates of all dofs."""
        self.dof_xi = []
        self.dof_x  = []
        
        for i in range(self.local_dimension):
            # point coordinates for this dof in reference coordinates
            dof_xi = self._dof_xi(i)
            self.dof_xi.append(dof_xi)
            
            # apply geometry mapping to get global coordinates
            dof_x  = (self.G.mul(dof_xi).add(self.p0)).evalm() # TODO: Assumes affine mapping!
            self.dof_x.append(dof_x)
    
    def _build_entity_dofs(self):
        """Build dof vs mesh entity relations.

        Sets entity_dofs[d][i] (dofs on entity i of dimension d), dof_entity[k]
        (the (d, i, j) position of dof k) and num_entity_dofs[d].
        """

        # TODO: This may be optimized if necessary for mixed elements, but maybe we don't care.
        
        # The basic structure we're building here is a list of lists of lists
        self.entity_dofs = []
        for i in range(self.topological_dimension+1):
            lists = [[] for j in range(self.cell.num_entities[i])]
            self.entity_dofs.append(lists)
        
        if self.ufl_element.family() == "Discontinuous Lagrange" \
            or self.ufl_element.family() == "Quadrature":
            # associate all dofs with the cell
            self.entity_dofs[self.topological_dimension][0] = list(range(self.local_dimension))
        
        elif self.ufl_element.family() == "Boundary Quadrature":
            sfc_error("Boundary Quadrature not handled.")
        
        else:
            # NB! Assuming location dof_xi coordinates match topological entity!
            # build dof list for each cell entity
            for k in range(self.local_dimension):
                (d, i) = self.cell.find_entity(self.dof_xi[k])
                self.entity_dofs[d][i].append(k)
        
        # Build the inverse mapping: idof -> ((d, i), j) (Not currently used for anything)
        self.dof_entity = [None]*self.local_dimension
        for d in range(self.topological_dimension+1):
            for i in range(self.cell.num_entities[d]):
                for j, k in enumerate(self.entity_dofs[d][i]):
                    sfc_assert(self.dof_entity[k] is None, "Expected to set each dof only once.")
                    self.dof_entity[k] = (d, i, j)
        
        # count number of dofs per entity
        self.num_entity_dofs = tuple(len(self.entity_dofs[d][0]) for d in range(self.topological_dimension+1))

        # assert that all entities have the same number of associated dofs
        # (there's a theoretical risk of floating point comparisons messing up the above logic)
        # sfc_assert is used instead of a bare assert so the check survives python -O.
        for d in range(self.topological_dimension+1):
            for doflist in self.entity_dofs[d]:
                # each doflist is a list of local dofs associated
                # with a particular mesh entity of dimension d
                sfc_assert(len(doflist) == self.num_entity_dofs[d],
                           "Not the same number of dofs on each entity of dimension %d." % d)
    
    def _build_facet_dofs(self):
        """Build facet vs dof relations: facet_dofs[j] and num_facet_dofs."""
        # ... build a list of dofs for each facet:
        self.facet_dofs = [[] for j in range(self.cell.num_facets)]
        
        if self.ufl_element.family() == "Discontinuous Lagrange" \
            or self.ufl_element.family() == "Quadrature":
            pass # no dofs on facets
        
        elif self.ufl_element.family() == "Boundary Quadrature":
            sfc_error("Boundary Quadrature not handled.")
        
        else:
            # for each facet j, loop over the reference coordinates 
            # for all dofs i and check if the dof is on the facet
            for j in range(self.cell.num_facets):
                for (i,p) in enumerate(self.dof_xi):
                    if self.cell.facet_check(j, p):
                        self.facet_dofs[j].append(i)
        
        # ... count number of dofs for each facet (assuming this is constant!)
        self.num_facet_dofs = len(self.facet_dofs[0])
        
        # verify that this number is constant for all facets
        sfc_assert(all(len(fdofs) == self.num_facet_dofs for fdofs in self.facet_dofs),
            "Not the same number of dofs on each facet. This breaks an assumption in UFC.")
    
    # --- subelement access
    
    def dof_to_sub_element_index(self, dof):
        """Return the index of the sub element the given dof is part of.

        Only valid for mixed elements. Reports an error through sfc_error for a
        dof outside [0, local_dimension); previously a negative dof silently
        mapped to index -1.
        """
        n = len(self.sub_elements)
        sfc_assert(n, "Only mixed elements have sub elements.")
        offsets = self.sub_element_dof_offsets
        if 0 <= dof < offsets[n]:
            # offsets[0] == 0, so the first offset exceeding dof has index >= 1
            for i in range(1, n+1):
                if dof < offsets[i]:
                    return i-1
        sfc_error("Invalid dof value!")
    
    def sub_element_to_dofs(self, i):
        """Return a list of all dof indices for sub element with index i."""
        sfc_assert(self.sub_elements, "Only mixed elements have sub elements.")
        a = self.sub_element_dof_offsets[i]
        b = self.sub_element_dof_offsets[i+1] 
        return range(a, b)
    
    # --- function space
    
    def basis_function(self, i, component):
        """Return the symbolic basis function for dof i and value component component.

        component is a tuple index into value_shape (() for scalar elements).
        For mixed elements, components outside the owning sub element's value
        range give zero. Results are cached per (i, component).
        """
        # hit cache?
        N = self._basis_function_cache.get((i,component), None)
        if N is not None:
            return N
        
        if self.sub_elements:
            # get sub element representation corresponding to this dof
            sub_element_index = self.dof_to_sub_element_index(i)
            sub_element = self.sub_elements[sub_element_index]
            
            # dof in sub element numbering
            sub_dof = i - self.sub_element_dof_offsets[sub_element_index]
            
            # component in flattened sub element value index space
            comp_index = component_to_index(component, self.value_shape) 
            value_offset = self.sub_element_value_offsets[sub_element_index]
            
            # check that the component is in the value range of the subelement
            is_nonzero = (comp_index >= value_offset) and (comp_index < (value_offset + sub_element.value_size))
            if is_nonzero:
                # component in unflattened sub element value index space
                sub_comp_index = comp_index - value_offset
                sub_component = index_to_component(sub_comp_index, sub_element.value_shape)
                
                # basis_function from computed component of subelement
                N = sub_element.basis_function(sub_dof, sub_component)
            
            else:
                N = swiginac.numeric(0.0)
        
        elif "Quadrature" in self.ufl_element.family():
            sfc_error("Cannot compute basis functions for quadrature element.")
            N = swiginac.numeric(1.0)
        
        else:
            # Basic element, get basis function from SyFi
            N = self.syfi_element.N(i)
            if isinstance(N, swiginac.matrix):
                # NOTE(review): value_shape is a UFL shape tuple; for a rank-1 value this
                # compares (rows, cols) with a 1-tuple. Confirm this check is intended.
                sfc_assert((int(N.rows()), int(N.cols())) == self.value_shape, "Shape mismatch")
                # Previously returned here without caching; now cached like the other paths.
                N = N[component]
            else:
                sfc_assert(component == (), "Found scalar basic element, expecting no component, got %s." % repr(component))
        
        # put in cache
        self._basis_function_cache[(i,component)] = N
        return N
    
    def basis_function_derivative(self, i, component, directions):
        """Return the derivative of basis_function(i, component) along directions.

        Each direction j is applied with ddx(..., j, self.GinvT). Results are
        cached per (i, component, sorted directions).
        """
        # d/dx and d/dy commute so we sort the derivative variables:
        directions = tuple(sorted(directions))
        
        # cache hit?
        DN = self._basis_function_derivative_cache.get((i,component,directions), None)
        if DN is not None:
            return DN
        
        # compute derivative
        DN = self.basis_function(i, component)
        for j in directions:
            DN = ddx(DN, j, self.GinvT)
        
        # put in cache
        self._basis_function_derivative_cache[(i,component,directions)] = DN
        return DN


#    def component_to_sub_element_index(self, component):
#        sfc_assert(self.sub_elements, "Only mixed elements have sub elements.")
#        
#        # FIXME!
#        sfc_error("FIXME: In component_to_sub_element_index: not sure where this will be used?")
#        
#        component_value = flatten_component(component, self.value_shape)
#        
#        n = len(self.sub_elements)
#        for i in range(1,n+1):
#            if component_value < self.sub_element_value_offsets[i]:
#                sub_component_value = self.sub_element_value_offsets[i] - ccomponent_value
#                sub_element_index = i-1
#                sub_element = self.sub_elements[sub_element_index]
#                sub_component = unflatten_component(sub_component_value, sub_element)
#                return sub_component
#        sfc_error("Invalid component value!")
#        
#        #sub_element_component, sub_element =\
#        #    self.ufl_element.extract_component(component)
#        #return sub_element_component, sub_element

