-- | Type checking and type inference
module L3.Core.Infer (weakInferType, inferType, inferType0, inferType1, wellTyped, wellTyped0) where

import Data.Char (isDigit)
import Data.List (intercalate)
import L3.Core.Eq
import L3.Core.Expr
import L3.Core.Normal
import L3.Log
import L3.Util

-- | Module-local tracing helper: forwards to 'traceU', tagging every message
-- with this module's scope name.
trace :: String -> a -> a
trace = traceU scope
  where
    scope = "Core::Infer"

-- | Type-check an expression within a context (a list of names paired with
-- their types) and return the expression's type, or an 'Error' describing why
-- type-checking failed.
--
-- The returned type is not necessarily normalized: full normalization is not
-- needed just to type-check. Callers that care about the shape of the result
-- should 'normalize' it afterwards.
--
-- \"Weak\" type inference refers to the lack of partial evaluation for
-- contextual beta-equivalence. An application is accepted only when the
-- function's inferred type is syntactically a π-type and the argument's type
-- is beta-equivalent to that π-type's domain. For some X, a by-value and a
-- by-reference use of X should be legal:
--
-- >   (λ (T : *) . λ (f : π (X : *) . X) . λ (x : T) . f x) X
--
-- but resolving T := X as beta-equivalent to X fails for 'weakInferType'.
weakInferType :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Result (Expr a)
weakInferType ctx@(Ctx τ) e =
  trace (concat ["weakInferType ", show τ, ", ", show e]) $
    weakInferType' ctx e

-- | Worker for 'weakInferType': infer the (unnormalized) type of an expression
-- within the given context, one syntactic form at a time.
--
-- Sub-expressions are checked through 'weakInferType', so each recursive step
-- is traced. Most failures are returned as 'Left' and carry a dump of the
-- current context ("in context:" followed by one line per binding).
weakInferType' :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Result (Expr a)
weakInferType' (Ctx τ) expr = case expr of
  -- The type of 'Star' is 'Box'.
  Star -> return Box
  -- 'Box' is the top of the hierarchy and has no type of its own.
  Box -> Left $ inContext ["absurd box"]
  -- A variable has the type recorded by the first matching binding. Bindings
  -- are prepended when entering a binder, so the innermost one wins.
  Var v -> case lookup v τ of
    Nothing -> Left $ inContext ["unbound variable:", showIndent v]
    Just t -> Right t
  -- λ(v : ta). b has type π(v : ta). tb, where b : tb under v : ta.
  Lam v ta b -> do
    tb <- weakInferType (Ctx ((v, ta) : τ)) b
    let tf = Pi v ta tb
    -- Types may themselves be well-typed, since they are expressions. Checking
    -- the π-type also checks that `ta` is typed by a sort (see the Pi case).
    _ <- weakInferType (Ctx τ) tf
    return tf
  -- π(v : ta). tb is well-formed when both the domain and the codomain are
  -- typed by a sort ('Star' or 'Box'). Its type is then the codomain's sort.
  Pi v ta tb -> do
    tta <- weakInferType (Ctx τ) ta
    ttb <- weakInferType (Ctx ((v, ta) : τ)) tb
    if isSort tta && isSort ttb
      then return ttb
      else
        Left $
          inContext
            [ "invalid type:",
              showIndent (Pi v ta tb),
              "had left kind:",
              showIndent tta,
              "had right kind:",
              showIndent ttb
            ]
  -- Applying f to a requires f's inferred type to be (syntactically) a π-type
  -- whose domain is beta-equivalent to a's type. The result is the codomain
  -- with the bound variable substituted by the argument.
  App f a -> do
    (v, ta, tb) <- case weakInferType (Ctx τ) f of
      Right (Pi v ta tb) -> return (v, ta, tb)
      Right tf ->
        Left $
          inContext
            [ "cannot apply to non-function:",
              showIndent f,
              "had type: ",
              showIndent tf,
              "had application:",
              showIndent a
            ]
      -- Re-raise the function's own error with a hint about where it
      -- occurred: the binding being applied for a lambda, otherwise the
      -- whole application.
      Left err -> Left $ case f of
        Lam v _ _ -> rethrowError ["with binding:", showIndent (v, a)] err
        _ -> rethrowError ["in expression:", showIndent (App f a)] err
    ta' <- weakInferType (Ctx τ) a
    if ta `betaEq` ta'
      then return $ substitute v a tb
      else
        Left $
          inContext
            [ "type mismatch for function:",
              showIndent f,
              "expected type:",
              showIndent ta,
              "but given arg:",
              showIndent a,
              "and given type:",
              showIndent ta'
            ]
  where
    -- Build an error from the given lines, wrapped with a dump of the context.
    inContext msgs = rethrowError ("in context:" : map showIndent τ) (throwError msgs)
    -- The sorts of this system: 'Star' and 'Box'.
    isSort Star = True
    isSort Box = True
    isSort _ = False

-- | Type-check an expression within a context and return its type, or an
-- 'Error' if type-checking fails.
--
-- Intended to perform partial evaluation (substituting lambda-applications into
-- types), so that the problem case documented on 'weakInferType' does not fail
-- here.
--
-- NOTE(review): the worker 'inferType1'' currently just delegates to
-- 'weakInferType'. No partial evaluation happens yet, and the returned type is
-- not normalized. Confirm whether this is an unfinished stub.
inferType1 :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Result (Expr a)
inferType1 ctx@(Ctx τ) e =
  trace (concat ["inferType1 ", show τ, ", ", show e]) $
    inferType1' ctx e

-- | Worker for 'inferType1'.
--
-- NOTE(review): this is currently identical to 'weakInferType'. The partial
-- evaluation described on 'inferType1' is not implemented here.
inferType1' :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Result (Expr a)
inferType1' = weakInferType

-- | Type-check an expression within a context and return its fully normalized
-- type, or an 'Error' if type-checking fails.
inferType :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Result (Expr a)
inferType ctx@(Ctx τ) e =
  trace (concat ["inferType ", show τ, ", ", show e]) $
    inferType' ctx e

-- | Worker for 'inferType': infer the type with 'inferType1', then map
-- 'normalize' over the successful result via 'mapR'.
inferType' :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Result (Expr a)
inferType' τ e = mapR normalize (inferType1 τ e)

-- | Like 'inferType', but with an empty context. The expression must therefore
-- be closed (no free variables); otherwise type-checking fails with an unbound
-- variable error.
inferType0 :: (Eq a, Enum a, Show a) => Expr a -> Result (Expr a)
inferType0 e = trace ("inferType0 " ++ show e) $ inferType0' e

-- | Worker for 'inferType0': run 'inferType' with an empty context.
inferType0' :: (Eq a, Enum a, Show a) => Expr a -> Result (Expr a)
inferType0' e = inferType (Ctx []) e

-- | Decide whether an expression is well-typed within a context, i.e. whether
-- its type can be inferred.
wellTyped :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Bool
wellTyped ctx@(Ctx τ) e =
  trace (concat ["wellTyped ", show τ, ", ", show e]) $
    wellTyped' ctx e

-- | Worker for 'wellTyped': 'True' exactly when 'inferType' returns 'Right'.
-- The error itself is discarded.
wellTyped' :: (Eq a, Enum a, Show a) => Context a -> Expr a -> Bool
wellTyped' τ e = either (const False) (const True) (inferType τ e)

-- | Decide whether an expression is well-typed without any context, i.e. it
-- is closed and therefore well-typed on its own.
wellTyped0 :: (Eq a, Enum a, Show a) => Expr a -> Bool
wellTyped0 e = trace ("wellTyped0 " ++ show e) $ wellTyped0' e

-- | Worker for 'wellTyped0': 'wellTyped' under an empty context.
wellTyped0' :: (Eq a, Enum a, Show a) => Expr a -> Bool
wellTyped0' e = wellTyped (Ctx []) e