module TypeInfer where
import MonadTransformers
import StateMTRefs
import InstsST
import System(getArgs)
-- | Simple types: a (mutable) type variable, or a function type @a :-> b@.
data Type s
  = TyVar (TyVarName s)   -- ^ type variable; may already be bound (see 'TyVarName')
  | Type s :-> Type s     -- ^ function type from the left operand to the right

-- | Untyped lambda-calculus terms.
data Term
  = Var Id                -- ^ variable occurrence
  | Abs Id Term           -- ^ lambda abstraction @\\x. e@
  | App Term Term         -- ^ application @e1 e2@

-- | Names of term variables and of type variables.
type Id = String

-- | Mutable cell behind a type variable: holds @Left name@ while the variable
-- is unbound and @Right ty@ once it has been bound to @ty@.
type TyVarName s = STRef s (Either Id (Type s))

-- | Typing environment as an association list. New bindings are consed on the
-- front, so 'lookup' sees the most recent binding of a name first.
type TyEnv s = [(Id, Type s)]
-- | The type-inference monad: a transformer stack over 'ST', outermost layer
-- first:
--
--   * 'WithEnv'    - the typing environment ('TyEnv'),
--   * 'WithState'  - the supply of fresh type-variable names,
--   * 'WithExcept' - error reporting with 'String' messages.
type TypeInfM s =
  WithEnv (TyEnv s) (WithState [Id] (WithExcept String (ST s)))
-- | Allocate a fresh, unbound type variable.
--
-- Advances the name supply held in the state by one ('updSt' 'tail') and
-- stores the name it receives as @Left name@ in a new reference.
-- NOTE(review): whether that name is the head of the old or the new supply
-- depends on what 'updSt' returns (defined in MonadTransformers) - confirm.
newTyVar :: TypeInfM s (Type s)
newTyVar = do
  (fresh : _) <- updSt tail
  newRef (Left fresh) >>= return . TyVar
-- lookupVar :: Id -> TypeInfM s (Type s)
-- | Look up the type of a term variable in the current typing environment.
-- Raises "Free variable: <x>" when the variable has no binding.
-- NOTE(review): the signature above was already commented out; the reason is
-- not visible in this file, so it is left disabled.
lookupVar x = do
  env <- getEnv
  maybe (raise $ "Free variable: " ++ x) return (lookup x env)
-- | @x `hasType` t@ extends an environment with the binding @x : t@.
-- The pair goes on the front, so it shadows any older binding of @x@.
hasType :: Id -> Type s -> TyEnv s -> TyEnv s
hasType x t = ((x, t) :)
-- | Infer the type of a term in the current typing environment.
--
-- * A variable gets the type recorded for it in the environment.
-- * An abstraction gets a fresh type variable for its parameter. The body is
--   checked with that binding added.
-- * An application infers the function and the argument (in that order) and
--   draws a fresh result variable. It then unifies the function's type with
--   @argument :-> result@.
--
-- The order of effects matters: it fixes which fresh names are handed out.
typeOf :: Term -> TypeInfM s (Type s)
typeOf term = case term of
  Var x -> lookupVar x
  Abs x body -> do
    paramTy <- newTyVar
    bodyTy <- inModEnv (x `hasType` paramTy) (typeOf body)
    return (paramTy :-> bodyTy)
  App fun arg -> do
    funTy <- typeOf fun
    argTy <- typeOf arg
    resTy <- newTyVar
    unify funTy (argTy :-> resTy)
    return resTy
-- | Unify two types, binding type variables so that both become equal.
--
-- Both sides are first resolved through existing bindings with 'follow'.
-- Without this, @unify (TyVar x) t@ on an already-bound @x@ would hand @x@
-- straight to '|->', which overwrites the old binding instead of unifying it
-- with @t@ (e.g. inferring @\\f. \\x. f (f x)@ produced a wrong type).
-- After 'follow', any 'TyVar' seen here is unbound, so binding it is sound.
-- A cyclic binding is rejected by '|->' with "Occurs check!".
unify :: Type s -> Type s -> TypeInfM s ()
unify t1 t2 = do
  t1' <- follow t1
  t2' <- follow t2
  case (t1', t2') of
    (TyVar x, _)       -> x |-> t2'
    (_, TyVar y)       -> y |-> t1'
    (a :-> b, c :-> d) -> unify a c >> unify b d
-- | @x |-> t@ binds the type variable @x@ to @t@ (resolved with 'follow').
--
-- * If @t@ resolves to @x@ itself, nothing is written.
-- * If @x@ occurs inside @t@, raises "Occurs check!" (the binding would
--   create an infinite type).
-- * Otherwise writes @Right t@ into @x@'s cell.
--
-- Note: @x@ is not inspected first. If it is already bound, its old binding
-- is simply replaced, so callers should pass an unbound variable.
(|->) :: TyVarName s -> Type s -> TypeInfM s ()
x |-> t = do
  t' <- follow t
  case t' of
    TyVar y | y == x -> return ()
    _ -> do
      cyclic <- x `occursIn` t'
      if cyclic
        then raise "Occurs check!"
        else writeRef x (Right t')
-- | Resolve a type through chains of bound type variables.
-- Returns either a function type or a type variable whose cell still holds
-- @Left name@ (i.e. is unbound). Does not compress the chain it walks.
follow ty = case ty of
  TyVar ref -> readRef ref >>= either (\_ -> return ty) follow
  _ -> return ty
-- | @x `occursIn` t@: does type variable @x@ appear anywhere in @t@, looking
-- through bound variables with 'follow'? Used as the occurs check by '|->'.
occursIn :: TyVarName s -> Type s -> TypeInfM s Bool
x `occursIn` t = follow t >>= search
  where
    search (TyVar y) = return (y == x)
    search (dom :-> cod) = do
      inDom <- x `occursIn` dom
      inCod <- x `occursIn` cod
      return (inDom || inCod)
-- | Infer the type of a closed term and render it with 'showType'.
-- Inference starts with an empty typing environment and with 'names' as the
-- supply of fresh type-variable names.
-- NOTE(review): presumably 'removeExcept' yields @Left msg@ for an error
-- raised with 'raise' and @Right rendering@ otherwise - defined elsewhere.
test :: Term -> Either String String
test t =
  runST
    ( removeExcept
    . removeState_ names
    . removeEnv []
    $ typeOf t >>= showType
    )
-- | Render a type as a string, resolving bound variables with 'follow'.
-- An unbound variable is shown by its name. A function type is shown as
-- "(dom) -> (cod)" with both sides always parenthesised.
showType ty = follow ty >>= render
  where
    -- After 'follow', a variable's cell is expected to hold @Left name@.
    render (TyVar ref) = do
      Left name <- readRef ref
      return name
    render (dom :-> cod) = do
      domStr <- showType dom
      codStr <- showType cod
      return ("(" ++ domStr ++ ") -> (" ++ codStr ++ ")")
-- | Infinite supply of type-variable names: a..z, then a'..z', then
-- a1..z1, a2..z2, and so on.
names :: [Id]
names =
  [ letter : suffix
  | suffix <- "" : "'" : map show [1 ..]
  , letter <- ['a' .. 'z']
  ]
-- | Identity: \x. x
tId :: Term
tId = Abs "x" $ Var "x"

-- | K combinator: \x. \y. x
tK :: Term
tK = Abs "x" $ Abs "y" $ Var "x"

-- | Application: \f. \x. f x
tApp :: Term
tApp = Abs "f" $ Abs "x" $ App (Var "f") (Var "x")

-- | S combinator: \f. \g. \x. f x (g x)
tS :: Term
tS = Abs "f" $ Abs "g" $ Abs "x" $ App (App (Var "f") (Var "x"))
                                       (App (Var "g") (Var "x"))

-- | Self-application: \x. x x. This term is ill-typed and must fail the occurs check.
tSApp :: Term
tSApp = Abs "x" $ App (Var "x") (Var "x")

-- | Omega: (\x. x x) (\x. x x). This term is also ill-typed.
tO :: Term
tO = App tSApp tSApp

-- | Twice: \f. \x. f (f x). Its correct type is (v -> v) -> v -> v.
-- It unifies an already-bound type variable (the type of @f@ is constrained
-- twice), which is a regression case for 'unify'.
tTwice :: Term
tTwice = Abs "f" $ Abs "x" $ App (Var "f") (App (Var "f") (Var "x"))

-- | Sample terms, selected by index in 'main'. New entries are appended at the
-- end so that existing indices stay stable.
tests :: [Term]
tests = [tId, tK, tApp, tS, tSApp, tO, tTwice]
-- | Run the sample term whose index is given as the first command-line
-- argument, and print the result of 'test'.
--
-- The original pipeline used 'head', 'read' and '!!', so it crashed on a
-- missing, non-numeric or out-of-range argument. Those cases now raise a
-- usage error via 'ioError' instead. Extra arguments are still ignored.
main :: IO ()
main = do
  args <- getArgs
  case selectTest args of
    Just term -> print (test term)
    Nothing   -> ioError (userError usage)
  where
    usage = "usage: TypeInfer N   (0 <= N < " ++ show (length tests) ++ ")"

    -- Same acceptance rule as 'read': one parse, only whitespace left over.
    selectTest (arg : _) =
      case [ n | (n, rest) <- reads arg, ("", "") <- lex rest ] of
        [n] | n >= 0 && n < length tests -> Just (tests !! n)
        _ -> Nothing
    selectTest [] = Nothing