RemovePatBinds.hs

module RemovePatBinds where

--import SrcLoc
import HasBaseStruct
import HsDeclStruct
import HsPatStruct
import HsPatUtil(isPVar)
import HsExpStruct
import HsGuardsStruct
import HsLiteral 
--import HsModule
import HsModuleMaps(mapDecls)
import DefinedNames
import MapDeclM
import HsIdent
import TiNames(ValueId,localVal)

import Maybe(isNothing)
import Monad(liftM)
import StateM
import PrettyPrint(pp,(<>))

-- should add a class here...

remPats pErr = remPatBinds' pErr names
  where names = [ localVal ("patbind__" ++ show n) | n <- [0..] ]

remPatBinds' pErr names = {-mapDecls-} (removePatBinds pErr names)

-- a perhaps more convinient version...
removePatBinds pErr names = withSt names . removePatBinds' err
  where
    -- the error we give if a pattern binding fails
    err pos =
      hsApp (hsId pErr) 
            (hsLit pos $ HsString $ pp $ pos<>": pattern binding failed")

-- removes all pattern bindings from a list of declarations (recursively)
{-
removePatBinds' :: 
    ( GetBaseStruct d (DI i e p [d] t c tp)
    , GetBaseStruct p (PI i p)
    , HasBaseStruct d (DI i e p [d] t c tp)
    , HasBaseStruct p (PI i p)
    , HasBaseStruct e (EI i e p [d] t c)
    , DefinedNames i p
    , ValueId i
    , MapDeclM d [d] )
    => [d] -> StateM [i] [d]
-}
removePatBinds' err ds = liftM concat $ mapM remPB ds
    where
    remPB d = do
        d' <- mapDeclM (removePatBinds' err) d
        case basestruct d' of
            Just d''@(HsPatBind _ pat _ _) 
                | isNothing (isPVar pat) -> do
                    n:_ <- updSt tail
                    return (remPatBind err n d'')
            _ -> return [d']

-- remove pattern bindings in a single declaration, non-recursive 
{-
remPatBind :: 
    ( HasBaseStruct d (DI i e p [d] t c tp)
    , HasBaseStruct p (PI i p)
    , HasBaseStruct e (EI i e p [d] t c)
    , ValueId i
    , DefinedNames i p
    )
    => i -> DI i e p [d] t c tp -> [d]
-}
remPatBind err n (HsPatBind pos pat rhs ds) 
    | nVars == 1 = [hsPatBind pos (hsPId (head vars)) tple ds]
    | otherwise = hsPatBind pos (hsPVar n) tple ds
                    : zipWith (mkDecl pos n nVars) [1..] vars
    where 
    vars    = definedVars pat 
    nVars   = length vars
    caseBody
        | nVars == 1    = hsId (head vars)
        | otherwise     = hsTuple (map hsId vars)

    toCase e = hsCase e 
        [ HsAlt pos pat (HsBody caseBody) ds
        , HsAlt pos hsPWildCard (HsBody (err pos)) []
        ]

    mkDecl pos n nVars x v = hsPatBind pos (hsPId v) (HsBody body) []
        where
        body = hsCase (hsEVar n) [ HsAlt pos pat (HsBody (hsEVar i)) [] ]
        i    = localVal "x"
        wcs  = repeat hsPWildCard 
        pat  = hsPTuple pos $ take (x - 1) wcs ++ [hsPVar i] ++ take (nVars - x) wcs


    tple = HsBody $ case rhs of
        HsBody e -> toCase e
        HsGuard gs -> foldr (\(_,g,e) -> hsIf g (toCase e)) (err pos) gs

Plain-text version of RemovePatBinds.hs | Valid HTML?