{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Consistency checking: the compilation phase that advances a 'Qualified'
-- AST to a 'Checked' AST (the last stage; see the 'StageAfter' instance for
-- 'Checked').
--
-- Along the way it:
--
--   * computes each unit's 'UnitInterface' from its declarations and imports,
--   * computes the sizes of object-type and register bodies,
--   * reports errors for failed size and position assertions, and
--   * warns about jagged unions and register sizes that are not a multiple
--     of 8.
module Language.Fiddle.Compiler.ConsistencyCheck (consistencyCheckPhase) where

import Control.Monad (forM_, when)
import Control.Monad.RWS (MonadWriter (tell))
import Control.Monad.Trans.Writer (Writer, execWriter)
import Data.Foldable (foldlM, toList)
import Data.Functor.Identity
import qualified Data.List.NonEmpty as NonEmpty
import Data.Typeable
import Data.Word (Word32)
import GHC.TypeError as TypeError
import Language.Fiddle.Ast
import Language.Fiddle.Compiler
import Language.Fiddle.Internal.UnitInterface as UnitInterface
import Language.Fiddle.Types
import Text.Printf (printf)
import Prelude hiding (unzip)

type S = Qualified

type S' = Checked

type F = Identity

type A = Commented SourceSpan

type M = Compile ()

-- 'Checked' is terminal: asking for the stage after it is a type error.
instance CompilationStage Checked where
  type StageAfter Checked = TypeError (TypeError.Text "No stage after Checked")
  type StageMonad Checked = M
  type StageState Checked = ()
  type StageFunctor Checked = Identity
  type StageAnnotation Checked = A

instance CompilationStage S where
  type StageAfter S = S'
  type StageMonad S = M
  type StageState S = ()
  type StageFunctor S = F
  type StageAnnotation S = A

-- | The consistency-check phase, taking a 'Qualified' AST to a 'Checked' one.
consistencyCheckPhase :: CompilationPhase S S'
consistencyCheckPhase = pureCompilationPhase $ advanceStage ()

-- A top-level object body is laid out starting at offset 0; its computed
-- size is not needed by the caller, so only the checked body is kept.
instance AdvanceStage S ObjTypeBody where
  advanceStage () objTypeBody = snd <$> advanceObjTypeBody objTypeBody 0

deriving instance AdvanceStage S DeferredRegisterBody

deriving instance AdvanceStage S RegisterBody

deriving instance AdvanceStage S AnonymousBitsType

deriving instance AdvanceStage S ImportStatement

deriving instance AdvanceStage S BitType

deriving instance AdvanceStage S EnumBody

deriving instance AdvanceStage S EnumConstantDecl

deriving instance AdvanceStage S RegisterBitsDecl

deriving instance AdvanceStage S PackageBody

deriving instance AdvanceStage S FiddleDecl

-- Advancing a unit also (re)computes its 'UnitInterface'. The interface
-- collects the qualification metadata of every package, location, bits,
-- object-type and object declaration in the unit, plus the dependencies
-- recorded on every import statement. It is gathered with 'walk_' over the
-- whole (pre-advance) unit.
instance AdvanceStage S FiddleUnit where
  advanceStage () fu@(FiddleUnit _ decls a) =
    FiddleUnit (getUnitInterface fu)
      <$> mapM (advanceStage ()) decls
      <*> pure a
    where
      getUnitInterface = execWriter . walk_ doWalk

      -- Called on every node visited by 'walk_'; emits an interface entry
      -- for the node kinds listed below and ignores everything else.
      doWalk :: forall t'. (Walk t', Typeable t') => t' F A -> Writer UnitInterface ()
      doWalk t =
        case () of
          ()
            | (Just (PackageDecl {packageQualificationMetadata = (Identity d)})) <- castTS t ->
                tell (UnitInterface.singleton d)
            | (Just (LocationDecl {locationQualificationMetadata = (Identity d)})) <- castTS t ->
                tell (UnitInterface.singleton d)
            | (Just (BitsDecl {bitsQualificationMetadata = (Identity d)})) <- castTS t ->
                tell (UnitInterface.singleton d)
            | (Just (ObjTypeDecl {objTypeQualificationMetadata = (Identity d)})) <- castTS t ->
                tell (UnitInterface.singleton d)
            | (Just (ObjectDecl {objectQualificationMetadata = (Identity d)})) <- castTS t ->
                tell (UnitInterface.singleton d)
            | (Just (ImportStatement {importInterface = ii})) <- castTS t ->
                tell (UnitInterface mempty (dependencies ii))
          _ -> return ()

      -- Runtime type test: view an arbitrary walked node as a stage-'S' node
      -- of the AST type @t@ demanded by the pattern it is matched against.
      castTS ::
        (Typeable t', Typeable t, Typeable f, Typeable a) =>
        t' f a ->
        Maybe (t S f a)
      castTS = cast

deriving instance AdvanceStage S Expression

deriving instance AdvanceStage S RegisterBitsTypeRef

deriving instance AdvanceStage S ObjType

deriving instance (AdvanceStage S t) => AdvanceStage S (Directed t)

-- | Size of a union: the size of its largest member, or 0 when it has no
-- members. This is the total replacement for 'maximum', which crashes on an
-- empty list.
unionSize :: [(Word32, x)] -> Word32
unionSize = foldr (max . fst) 0

-- | Check an object-type body laid out at @startOffset@ and compute its size.
--
-- Returns @(size, checkedBody)@. Each member contributes:
--
--   * registers: their size expression passed through 'checkBitsSizeMod8'
--     (divided by 8, rounded up), after checking that any register body
--     matches the declared size;
--   * reserved declarations: the value of their size expression, unchanged;
--   * sub-structures: their own recursively computed size.
--
-- A struct's size is the sum of its members, and its running offset advances
-- past each member. A union's size is its largest member ('unionSize', so an
-- empty union has size 0), and every member sits at the union's offset.
-- Position assertions occupy no space; they are checked against the running
-- offset.
advanceObjTypeBody :: ObjTypeBody S F A -> Word32 -> M (Word32, ObjTypeBody S' F A)
advanceObjTypeBody (ObjTypeBody us decls a) startOffset = do
  (decls', _) <- advanceDecls

  calcSize <- case us of
    Union {} -> do
      checkJagged decls'
      -- 'decls'' is empty for an empty union or one containing only position
      -- assertions; 'unionSize' yields 0 there instead of crashing.
      return $ unionSize decls'
    Struct {} -> return $ sum (map fst decls')

  -- 'step' conses onto its accumulator, so restore the source order here.
  return (calcSize, ObjTypeBody us (reverse $ map snd decls') a)
  where
    advanceDecls :: M ([(Word32, Directed ObjTypeDecl S' F A)], Word32)
    advanceDecls = foldlM step ([], startOffset) decls

    -- Check one member. The accumulator holds the members advanced so far
    -- (most recent first, each paired with its size) and the current offset.
    step ::
      ([(Word32, Directed ObjTypeDecl S' F A)], Word32) ->
      Directed ObjTypeDecl S F A ->
      M ([(Word32, Directed ObjTypeDecl S' F A)], Word32)
    step (ret, offset) d =
      case undirected d of
        e@AssertPosStatement {assertExpr = expr} -> do
          -- Assertions take no space and are dropped from the checked body.
          assertedPos <- expressionToIntM expr
          checkPositionAssertion (annot e) assertedPos offset
          return (ret, offset)
        RegisterDecl modifier ident size Nothing an -> do
          (sizeExpr, reifiedSize) <- advanceAndGetSize size
          doReturn (RegisterDecl modifier ident sizeExpr Nothing an)
            =<< checkBitsSizeMod8 an reifiedSize
        RegisterDecl modifier ident size (Just body) an -> do
          -- The declared size must agree with the size computed from the
          -- register's body.
          declaredSize <- expressionToIntM size
          (actualSize, body') <- advanceRegisterBody body
          checkSizeMismatch an declaredSize actualSize

          (sizeExpr, reifiedSize) <- advanceAndGetSize size
          doReturn (RegisterDecl modifier ident sizeExpr (Just body') an)
            =<< checkBitsSizeMod8 an reifiedSize
        ReservedDecl size an -> do
          (sizeExpr, reifiedSize) <- advanceAndGetSize size
          doReturn (ReservedDecl sizeExpr an) reifiedSize
        TypeSubStructure (Identity body) name an -> do
          -- A nested body is laid out starting at the current offset.
          (size, body') <- advanceObjTypeBody body offset
          doReturn (TypeSubStructure (Identity body') name an) size
      where
        -- Struct members follow one another; union members all overlap.
        advanceOffset = case us of
          Union {} -> const
          Struct {} -> (+)

        -- Record the checked member with its size and move the offset on.
        doReturn x size =
          return
            ( (size, mapDirected (const x) d) : ret,
              advanceOffset offset size
            )

    -- Advance a size expression to the next stage and also evaluate it.
    advanceAndGetSize e = (,) <$> advanceStage () e <*> expressionToIntM e

-- | View a 'RegisterBody' together with the contents of its deferred body.
-- @u@ is the body type and @decls@ the bit declarations. @a@ is the
-- annotation of the 'RegisterBody' and @b@ the annotation of the
-- 'DeferredRegisterBody'. The synonym is bidirectional, so it also builds
-- bodies.
pattern RegisterBodyPattern ::
  BodyType F A -> [Directed RegisterBitsDecl s F A] -> A -> A -> RegisterBody s F A
pattern RegisterBodyPattern u decls a b =
  RegisterBody u (Identity (DeferredRegisterBody decls b)) a

-- | Check a register body and compute its size.
--
-- The result is in the same units as the sizes of its bit declarations (see
-- 'advanceDecl'). A struct body's size is the sum of its declarations. A
-- union body's size is its largest declaration, with a warning for each
-- declaration of a different size. An empty body has size 0.
advanceRegisterBody :: RegisterBody S F A -> M (Word32, RegisterBody S' F A)
-- Non-empty body (union or struct): advance every declaration and combine
-- their sizes according to the body type.
advanceRegisterBody (RegisterBodyPattern us (NonEmpty.nonEmpty -> Just decls) a b) = do
  decls' <-
    mapM
      ( \d -> do
          (sz, t) <- advanceDecl (undirected d)
          return (sz, mapDirected (const t) d)
      )
      decls

  calcSize <- case us of
    Union {} -> do
      checkJagged (toList decls')
      return $ unionSize (toList decls')
    Struct {} -> return $ sum (map fst (toList decls'))

  return (calcSize, RegisterBodyPattern us (map snd $ toList decls') a b)
-- Empty body: it occupies no space.
advanceRegisterBody (RegisterBodyPattern u _ a b) =
  return (0, RegisterBodyPattern u [] a b)
-- 'RegisterBodyPattern' has no COMPLETE pragma, so GHC cannot see that the
-- equations above are exhaustive. NOTE(review): presumably unreachable;
-- confirm 'DeferredRegisterBody' has a single constructor.
advanceRegisterBody RegisterBody {} = error "GHC not smart enuf"

-- | Warn about every union member whose size differs from the largest
-- member's size. The warning suggests the reserved padding that would even
-- it out.
checkJagged :: (Annotated t) => [(Word32, t f A)] -> Compile s ()
checkJagged decls = do
  -- Only forced inside the loop, so an empty list is harmless.
  let expectedSize = unionSize decls
  forM_ decls $ \(sz, annot -> a) ->
    when (sz /= expectedSize) $
      emitDiagnosticWarning
        ( printf
            "[JaggedUnion] - All elements of a union should be the same size. \
            \ this element is size %d, expected size %d. Maybe bundle this with \
            \ reserved(%d)?"
            sz
            expectedSize
            (expectedSize - sz)
        )
        a

-- | Advance one register bit declaration and return its size together with
-- the checked declaration.
--
-- * Reserved bits: the value of their size expression.
-- * Defined bits: the size of their type ('bitsTypeSize').
-- * Sub-structures: their recursively computed size.
advanceDecl :: RegisterBitsDecl S F A -> M (Word32, RegisterBitsDecl S' F A)
advanceDecl = \case
  ReservedBits expr an -> do
    sz <- expressionToIntM expr
    (sz,) <$> (ReservedBits <$> advanceStage () expr <*> pure an)
  DefinedBits modifier ident typ an -> do
    size <- bitsTypeSize typ
    (size,) <$> (DefinedBits modifier ident <$> advanceStage () typ <*> pure an)
  BitsSubStructure subBody subName ann -> do
    (sz, body') <- advanceRegisterBody subBody
    return (sz, BitsSubStructure body' subName ann)

-- | Size of a bits type.
--
-- * Arrays: element size times element count.
-- * References: the size recorded on the referenced, already-qualified bits
--   declaration.
-- * Plain bits: the value of their size expression.
--
-- NOTE(review): the array product is computed in 'Word32' and wraps on
-- overflow.
bitsTypeSize :: RegisterBitsTypeRef S F A -> M Word32
bitsTypeSize (RegisterBitsArray tr nExpr _) = do
  sz <- bitsTypeSize tr
  n <- expressionToIntM nExpr
  return (sz * n)
bitsTypeSize
  RegisterBitsReference
    { bitsRefQualificationMetadata =
        Identity (ExportedBitsDecl {exportedBitsDeclSizeBits = sz})
    } = return sz
bitsTypeSize (RegisterBitsJustBits expr _) = expressionToIntM expr

-- | Report an error at the given location when a declared size differs from
-- the calculated one.
checkSizeMismatch :: A -> Word32 -> Word32 -> Compile s ()
checkSizeMismatch _ a b | a == b = return ()
checkSizeMismatch pos declaredSize calculatedSize =
  emitDiagnosticError
    ( printf
        "Size assertion failed. Declared size %d, calculated %d"
        declaredSize
        calculatedSize
    )
    pos

-- | Report an error at the given location when an asserted position differs
-- from the calculated one.
checkPositionAssertion :: A -> Word32 -> Word32 -> Compile s ()
checkPositionAssertion _ a b | a == b = return ()
checkPositionAssertion pos declaredPosition calculatedPosition =
  emitDiagnosticError
    ( printf
        "Position assertion failed. Asserted 0x%x, calculated 0x%x"
        declaredPosition
        calculatedPosition
    )
    pos

-- | Evaluate an expression to an integer. If 'expressionToInt' fails, its
-- reason becomes an error diagnostic at the expression's source location,
-- resolved via 'resolveOrFail'.
expressionToIntM ::
  (Integral i, Integral (NumberType stage)) =>
  Expression stage f A ->
  Compile s i
expressionToIntM expr =
  resolveOrFail $
    either
      (\reason -> Left [Diagnostic Error reason (unCommented $ annot expr)])
      return
      (expressionToInt expr)

-- | Convert a register size from bits to whole bytes, rounding up. A size
-- that is not a multiple of 8 gets a warning suggesting padding.
checkBitsSizeMod8 :: A -> Word32 -> M Word32
checkBitsSizeMod8 _ w
  | w `mod` 8 == 0 = return (w `div` 8)
checkBitsSizeMod8 a w = do
  emitDiagnosticWarning
    (printf "Register size %d is not a multiple of 8. Please add padding to this register." w)
    a
  return ((w `div` 8) + 1)