{-
    Kaya - My favourite toy language.
    Copyright (C) 2004, 2005 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

module InfGadgets where

-- Helper functions for type inference

import Language
import Debug.Trace

-- Substitutions and unification, after SPJ87 (Simon Peyton Jones, "The
-- Implementation of Functional Programming Languages", 1987).

-- | A substitution: maps the name of a type variable to the type it is to be
-- replaced by. 'id_subst' maps every name back to its own 'TyVar'.
type Subst = Name -> Type

-- | Apply a substitution to a type, replacing each type variable with
-- whatever the substitution maps it to. Function, array and type-application
-- types are traversed structurally; every other type (e.g. 'Prim') is
-- returned unchanged.
subst :: Subst -> Type -> Type
subst s ty = case ty of
    TyVar n        -> s n
    Fn ns args ret -> Fn ns (map go args) (go ret)
    Array elt      -> Array (go elt)
    TyApp con args -> TyApp (go con) (map go args)
    other          -> other
  where
    go = subst s

-- | Substitution composition: @scomp s2 s1@ applies @s1@ to a variable first,
-- then applies @s2@ to the resulting type.
scomp :: Subst -> Subst -> Subst
scomp s2 s1 = subst s2 . s1

-- | The identity substitution: every type variable maps to itself.
id_subst :: Subst
id_subst = TyVar

-- | A single-binding substitution: @delta tn t@ maps the variable @tn@ to @t@
-- and leaves every other variable as itself.
delta :: Name -> Type -> Subst
delta tn t tn' = if tn == tn' then t else TyVar tn'

-- | Extend the substitution @phi@ by binding type variable @tvn@ to @t@.
--
-- * Binding a variable to itself leaves @phi@ unchanged.
-- * If @t@ mentions @tvn@ (occurs check), fail with an error tagged with
--   @file@:@line@.
-- * Otherwise return @phi@ followed by the new single binding. The @$!@
--   forces the single binding and @phi@ before they are composed.
extend :: Monad m => String -> Int -> Subst -> Name -> Type -> m Subst
extend file line phi tvn t = case t of
    TyVar n | tvn == n -> return phi
    _ | tvn `elem` getVars t -> fail infiniteType
      | otherwise -> return $ (scomp $! delta tvn t) $! phi
  where
    infiniteType = file++":"++show line++":Type error - possible infinite type"

-- | Unify two types under the substitution @phi@, returning a substitution
-- under which they are equal. The tuple holds (type found, type expected,
-- source file, source line, context description); the last three are only
-- used to build error messages.
unify :: Monad m => Subst -> (Type, Type, String, Int, String) -> m Subst
unify phi e = unify' phi e

-- Worker for 'unify'. Equations are tried top to bottom:
--   1. a type variable on the left,
--   2. matching structured types (arrays, functions, type applications),
--   3. a type variable on the right (handled by swapping the sides),
--   4. anything else, which must be equal outright.
unify' :: Monad m => Subst -> (Type, Type, String, Int, String) -> m Subst
unify' phi (TyVar tvn, t, f, l, ctxt)
    -- tvn is mapped to itself by phi: bind it to t (with phi applied).
    | phitvn == TyVar tvn = extend f l phi tvn phit
    -- tvn is already bound: unify its binding with t instead.
    | otherwise = unify phi (phitvn, phit, f, l, ctxt)
  where
    phitvn = phi tvn
    phit = subst phi t
unify' phi (Array t, Array t', f, l, ctxt) = unify phi (t, t', f, l, ctxt)
unify' phi (t1@(Fn _ ts t), t2@(Fn _ ts' t'), f, l, ctxt) = do
    -- Return types and argument types are unified pairwise; zipfl fails
    -- with err when the argument counts differ.
    zls <- zipfl (t : ts) (t' : ts') f l ctxt err
    -- Pass the top-level types through for any error message.
    unifyl phi zls t1 t2
  where
    err = f ++ ":" ++ show l ++ ":" ++ cantunify t1 t2 ctxt
unify' phi (t1@(TyApp t ts), t2@(TyApp t' ts'), f, l, ctxt) = do
    -- Same as the Fn case: zipfl already fails with err when the argument
    -- counts differ, so no separate length check is needed.
    zls <- zipfl (t : ts) (t' : ts') f l ctxt err
    unifyl phi zls t1 t2
  where
    err = f ++ ":" ++ show l ++ ":" ++ cantunify t1 t2 ctxt
-- A type variable on the right: try it the other way round.
unify' phi (t, TyVar tvn, f, l, ctxt) = unify phi (TyVar tvn, t, f, l, ctxt)
-- Anything else (e.g. primitive types) must be identical.
unify' phi (t, t', f, l, ctxt)
    | t == t' = return phi
    | otherwise = fail $ f ++ ":" ++ show l ++ ":" ++ cantunify t t' ctxt
-- | Zip two lists into 5-tuples, attaching the same three extra values to
-- every pair (callers in this file pass file, line and context). Fails with
-- @err@ if the two lists have different lengths.
zipfl :: Monad m => [a] -> [b] -> c -> d -> e -> String -> m [(a,b,c,d,e)]
zipfl (x:xs) (y:ys) z w c err =
    zipfl xs ys z w c err >>= \rest -> return ((x, y, z, w, c) : rest)
zipfl [] [] _ _ _ _ = return []
zipfl _ _ _ _ _ err = fail err

-- | Unify a list of component pairs in order, threading the substitution
-- through. Each component is unified in 'Maybe' so that, on failure, its own
-- message can be replaced by one about the top-level types @t1@ and @t2@,
-- tagged with the failing component's file, line and context.
unifyl :: Monad m => Subst -> [(Type,Type,String,Int,String)] ->
          Type -> Type ->  -- Top level types being unified
          m Subst
unifyl phi [] _ _ = return phi
unifyl phi (x@(_,_,f,l,ctxt):xs) t1 t2 =
    maybe (fail topLevelErr) return (unify phi x) >>= \phi' ->
        unifyl phi' xs t1 t2
  where
    topLevelErr = f ++ ":" ++ show l ++ ":" ++ cantunify t1 t2 ctxt

-- | Convert the global names (the Ps) to local variable indexes (Vs).
-- (The name is a reference to McKinna-Pollack '91. Apologies...)
--
-- @cs@ is the list of locals in scope. A 'Global' whose name 'getpos' finds
-- in @cs@ (a non-negative position) becomes a 'Loc' at that position.
-- 'Bind' and 'Declare' append their variable to @cs@ before converting their
-- body. Every other form is rebuilt with its subexpressions converted,
-- including expressions nested inside lvalues, case alternatives and catch
-- handlers.
pToV :: Locals -> Expr Name -> Expr Name
pToV cs expr = case expr of
    Global n m ar ->
        let pos = getpos n cs
        in if pos >= 0 then Loc pos else Global n m ar
    Loc l -> Loc l
    GVar x -> GVar x
    GConst c -> GConst c
    Lambda iv ns sc -> Lambda iv ns (go sc)
    Closure ns rt sc -> Closure ns rt (go sc)
    Bind n t v sc -> Bind n t (go v) (pToV (bindLocal n t) sc)
    Declare f l (n, loc) t sc ->
        Declare f l (n, loc) t (pToV (bindLocal n t) sc)
    Return r -> Return (go r)
    Assign l e -> Assign (goLval l) (go e)
    AssignOp op l e -> AssignOp op (goLval l) (go e)
    Seq a b -> Seq (go a) (go b)
    Apply f as -> Apply (go f) (fmap go as)
    ConApply f as -> ConApply (go f) (fmap go as)
    Partial f as i -> Partial (go f) (fmap go as) i
    Foreign ty f as -> Foreign ty f (fmap (\(x, y) -> (go x, y)) as)
    While t e -> While (go t) (go e)
    DoWhile e t -> DoWhile (go e) (go t)
    -- Uses the same lvalue conversion as Assign, so a global or field loop
    -- variable no longer hits a missing pattern.
    For x nm y l ar e -> For x nm y (goLval l) (go ar) (go e)
    TryCatch e1 e2 n f -> TryCatch (go e1) (go e2) (go n) (go f)
    NewTryCatch e ctchs -> NewTryCatch (go e) (map goCatch ctchs)
    Throw e -> Throw (go e)
    Except e1 e2 -> Except (go e1) (go e2)
    InferPrint e t f l -> InferPrint (go e) t f l
    PrintStr e -> PrintStr (go e)
    PrintNum e -> PrintNum (go e)
    PrintExc e -> PrintExc (go e)
    Infix op a b -> Infix op (go a) (go b)
    InferInfix op a b ts f l -> InferInfix op (go a) (go b) ts f l
    Append a b -> Append (go a) (go b)
    Unary op a -> Unary op (go a)
    InferUnary op a ts f l -> InferUnary op (go a) ts f l
    Coerce t1 t2 v -> Coerce t1 t2 (go v)
    InferCoerce t1 t2 v f l -> InferCoerce t1 t2 (go v) f l
    Case t alts -> Case (go t) (map goAlt alts)
    ArrayInit xs -> ArrayInit (map go xs)
    If a t e -> If (go a) (go t) (go e)
    Index l es -> Index (go l) (go es)
    Field v n a t -> Field (go v) n a t
    Noop -> Noop
    VMPtr -> VMPtr
    Length s -> Length (go s)
    Break f l -> Break f l
    VoidReturn -> VoidReturn
    Metavar f l i -> Metavar f l i
    Annotation a e -> Annotation a (go e)
  where
    -- Convert a subexpression in the current scope.
    go = pToV cs

    -- Scope extended with a newly bound local (marked Public, as before).
    bindLocal n t = cs ++ [(n, (t, [Public]))]

    -- Convert the expressions inside an lvalue; handles every lvalue form.
    goLval (AName i) = AName i
    goLval (AGlob i) = AGlob i
    goLval (AIndex l r) = AIndex (goLval l) (go r)
    goLval (AField l n a t) = AField (goLval l) n a t

    -- Convert the arguments/expression and body of a catch handler.
    goCatch (Catch (Left (n, as)) h) = Catch (Left (n, fmap go as)) (go h)
    goCatch (Catch (Right e) h) = Catch (Right (go e)) (go h)

    -- Convert the patterns and body of a case alternative.
    goAlt (Default ex) = Default (go ex)
    goAlt (ConstAlt pt c ex) = ConstAlt pt c (go ex)
    goAlt (ArrayAlt exs ex) = ArrayAlt (map go exs) (go ex)
    goAlt (Alt n t exs ex) = Alt n t (map go exs) (go ex)

-- | Check that the inferred type @t1@ is at least as general as the given
-- type @t2@, i.e. equal up to alpha conversion of type variables.
--
-- Fails, tagged with @file@:@line@, in two cases:
--   * the given type has a type variable where the inferred type does not;
--   * one inferred variable would have to stand for two different given
--     variables.
-- Other structural mismatches are not reported here (for example different
-- constructors, or different numbers of arguments); only the parts the two
-- types have in common are compared.
checkEq :: Monad m => String -> Int -> Type -> Type -> m ()
checkEq file line t1 t2 = cg t1 t2 [] >> return ()
  where
    lessGeneral = file ++ ":" ++ show line ++ ":" ++
                  "Inferred type less general than given type"
                  ++ " - Inferred " ++ show t1 ++ ", given "
                  ++ show t2

    -- Walk both types together, accumulating a mapping from inferred type
    -- variables to given type variables.
    cg (TyVar x) (TyVar y) tvm =
        case lookup x tvm of
          Just z | y == z    -> return tvm
                 | otherwise -> fail lessGeneral
          Nothing -> return ((x, y) : tvm)
    cg _ (TyVar _) _ = fail lessGeneral
    cg (Array x) (Array y) tvm = cg x y tvm
    cg (Fn _ ts t) (Fn _ ts' t') tvm = cg t t' tvm >>= cgl ts ts'
    cg (TyApp t ts) (TyApp t' ts') tvm = cgl (t : ts) (t' : ts') tvm
    cg _ _ tvm = return tvm

    -- Pairwise 'cg' over argument lists. When the lengths differ it stops at
    -- the end of the shorter list instead of failing with a pattern-match
    -- error.
    cgl (x : xs) (y : ys) tvm = cg x y tvm >>= cgl xs ys
    cgl _ _ tvm = return tvm


-- | Return whether an expression returns a value. 'Throw' counts as well,
-- since it also leaves the function.
--
-- * A sequence returns if either part does.
-- * An 'If' returns only if both branches do.
-- * A 'Case' or 'NewTryCatch' returns only if it has at least one
--   alternative/handler and every one of them returns.
-- * Loops count as returning if their body does. Note that a 'While' or
--   'For' body may run zero times, so this is lenient for them.
containsReturn :: Expr Name -> Bool
containsReturn expr = case expr of
    Return _ -> True
    Throw _ -> True -- kind of the same thing!
    Lambda _ _ body -> containsReturn body
    Bind _ _ _ body -> containsReturn body
    Declare _ _ _ _ body -> containsReturn body
    Seq e1 e2 -> containsReturn e1 || containsReturn e2
    While _ body -> containsReturn body
    DoWhile body _ -> containsReturn body
    For _ _ _ _ _ body -> containsReturn body
    Case _ alts -> allNonEmpty altReturns alts
    If _ thenE elseE -> containsReturn thenE && containsReturn elseE
    TryCatch tryE catchE _ f ->
        (containsReturn tryE && containsReturn catchE) || containsReturn f
    NewTryCatch tryE handlers ->
        containsReturn tryE && allNonEmpty handlerReturns handlers
    Annotation _ body -> containsReturn body
    _ -> False
  where
    -- True when the list is non-empty and every element satisfies p.
    allNonEmpty p xs = not (null xs) && all p xs

    altReturns (Default r) = containsReturn r
    altReturns (ConstAlt _ _ r) = containsReturn r
    altReturns (Alt _ _ _ r) = containsReturn r
    altReturns (ArrayAlt _ r) = containsReturn r

    handlerReturns (Catch _ h) = containsReturn h

-- | Return whether a function type needs a runtime check on its result
-- (True) or has been suitably checked at compile time (False).
--
-- A check is needed when some type variable in the return type occurs in
-- none of the argument types (e.g. unmarshal or subvert, which this function
-- really exists for). Non-function types never need one.
needsCheck :: Type -> Bool
needsCheck ty = case ty of
    Fn _ args r ->
        let argvars = concatMap getTyVars args
        in any (`notElem` argvars) (getTyVars r)
    _ -> False -- Not a function type

-- | Suggest why two types failed to unify, as a suffix for the error
-- message. The first argument is the type found, the second the type
-- expected.
--
-- Currently only one reason is recognised: both are function types with
-- different numbers of arguments, suggesting arguments are missing or
-- extra in a call. Otherwise the result is empty.
guessReason :: Type -> Type -> String
guessReason (Fn _ haveArgs _) (Fn _ wantArgs _) =
    case compare (length haveArgs) (length wantArgs) of
      GT -> " (possible reason: too few arguments in function call)"
      LT -> " (possible reason: too many arguments in function call)"
      EQ -> ""
guessReason _ _ = ""

-- | Build the text of a type error. @t1@ is the type found, @t2@ the type
-- expected and @ctxt@ says where the error occurred. A hint from
-- 'guessReason' is appended when one applies.
cantunify :: Type -> Type -> String -> String
cantunify t1 t2 ctxt =
    "Type error in " ++ ctxt ++
    "; has type " ++ show t1 ++ ", expected " ++
    show t2 ++ guessReason t1 t2

-- | For error messages: name the function being applied as " of 'f'" when
-- the raw expression is a plain or qualified variable ('RVar' / 'RQVar');
-- any more complex expression yields the empty string.
showraw r = case r of
    RVar _ _ n  -> named n
    RQVar _ _ n -> named n
    _           -> ""
  where
    named n = " of '" ++ showuser n ++ "'"

