From b3d5aabe3a33e2f4f8d7dad02bd3d6102ced395d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Hr=C4=8Dek?= Date: Sat, 13 Jul 2024 09:52:26 +0200 Subject: [PATCH] Add refs.is_generated to distinguish references from source/generated by ghc --- src/HieDb/Compat.hs | 29 +++++++++++++++++++++++++++++ src/HieDb/Create.hs | 12 ++++++------ src/HieDb/Run.hs | 4 ++-- src/HieDb/Types.hs | 28 +++++++++++++++------------- src/HieDb/Utils.hs | 19 +++++++++---------- 5 files changed, 61 insertions(+), 31 deletions(-) diff --git a/src/HieDb/Compat.hs b/src/HieDb/Compat.hs index 4710eec..74d682a 100644 --- a/src/HieDb/Compat.hs +++ b/src/HieDb/Compat.hs @@ -88,6 +88,8 @@ module HieDb.Compat ( , fieldNameSpace_maybe , fieldName , mkFastStringByteString + , generateReferencesMap2 + , RefMap2 ) where import Compat.HieTypes @@ -216,3 +218,30 @@ fieldNameSpace_maybe _ = Nothing fieldName :: FastString -> NameSpace fieldName _ = varName #endif + +-- We want to distinguish between references from source (NodeOrigin is SourceInfo) +-- vs. generated by compiler (NodeOrigin is GeneratedInfo). +-- +-- But GHC's generateReferencesMap throws away the info about NodeOrigin. +-- Compare: https://hackage.haskell.org/package/ghc-9.10.1/docs/GHC-Iface-Ext-Utils.html#t:RefMap +-- RefMap a = M.Map Identifier [( Span, IdentifierDetails a)] +type RefMap2 a = M.Map Identifier [(NodeOrigin, Span, IdentifierDetails a)] + +generateReferencesMap2 + :: Foldable f + => f (HieAST a) + -> RefMap2 a +generateReferencesMap2 = foldr (\ast m -> M.unionWith (++) (go ast) m) M.empty + where + go :: HieAST a -> RefMap2 a + go ast = M.unionsWith (++) (this : map go (nodeChildren ast)) + where + span = nodeSpan ast + this = + M.unionsWith (++) + $ M.mapWithKey + (\nodeOrigin nodeInfo -> + (\identDetails -> [(nodeOrigin, span, identDetails)]) <$> nodeIdentifiers nodeInfo + ) + $ getSourcedNodeInfo + $ sourcedNodeInfo ast diff --git a/src/HieDb/Create.hs b/src/HieDb/Create.hs index 8572734..3869fc0 100644 --- a/src/HieDb/Create.hs +++ b/src/HieDb/Create.hs @@ -37,7 +37,7 @@ import HieDb.Types import HieDb.Utils sCHEMA_VERSION :: Integer -sCHEMA_VERSION = 8 +sCHEMA_VERSION = 9 dB_VERSION :: Integer dB_VERSION = read (show sCHEMA_VERSION ++ "999" ++ show hieVersion) @@ -117,6 +117,7 @@ initConn (getConn -> conn) = do \, sc INTEGER NOT NULL \ \, el INTEGER NOT NULL \ \, ec INTEGER NOT NULL \ + \, is_generated BOOLEAN NOT NULL \ \, FOREIGN KEY(hieFile) REFERENCES mods(hieFile) ON UPDATE CASCADE ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED \ \)" execute_ conn "CREATE INDEX IF NOT EXISTS refs_mod ON refs(hieFile)" @@ -331,7 +332,7 @@ addRefsFromLoaded_unsafe mod = moduleName smod uid = moduleUnit smod smod = hie_module hf - refmap = generateReferencesMap $ getAsts $ hie_asts hf + refmap = generateReferencesMap2 $ getAsts $ hie_asts hf (srcFile, isReal) = case sourceFile of RealFile f -> (Just f, True) FakeFile mf -> (mf, False) @@ -339,10 +340,10 @@ addRefsFromLoaded_unsafe execute conn "INSERT INTO mods VALUES (?,?,?,?,?,?,?)" modrow - let AstInfo rows decls imports = genAstInfo path smod refmap + let AstInfo refs decls imports = genAstInfo path smod refmap unless (skipRefs skipOptions) $ - executeMany conn "INSERT INTO refs VALUES (?,?,?,?,?,?,?,?)" rows + executeMany conn "INSERT INTO refs VALUES (?,?,?,?,?,?,?,?,?)" refs unless (skipDecls skipOptions) $ executeMany conn "INSERT INTO decls VALUES (?,?,?,?,?,?,?)" decls unless (skipImports skipOptions) $ @@ -350,8 +351,7 @@ addRefsFromLoaded_unsafe let defs = genDefRow path smod refmap unless (skipDefs skipOptions) $ - forM_ defs $ \def -> - execute conn "INSERT INTO defs VALUES (?,?,?,?,?,?)" def + executeMany conn "INSERT INTO defs VALUES (?,?,?,?,?,?)" defs let exports = generateExports path $ hie_exports hf unless (skipExports skipOptions) $ diff --git a/src/HieDb/Run.hs b/src/HieDb/Run.hs index 1d6e396..d91ec87 100644 --- a/src/HieDb/Run.hs +++ b/src/HieDb/Run.hs @@ -372,8 +372,8 @@ runCommand libdir opts cmd = withHieDbAndFlags libdir (database opts) $ \dynFlag Just mod -> do reportRefs opts =<< findReferences conn False (nameOccName name) (Just $ moduleName mod) (Just $ moduleUnit mod) [] Nothing -> do - let refmap = generateReferencesMap (getAsts $ hie_asts hf) - refs = map (toRef . fst) $ M.findWithDefault [] (Right name) refmap + let refmap = generateReferencesMap2 $ getAsts $ hie_asts hf + refs = map (\(_, spn, _) -> toRef spn) $ M.findWithDefault [] (Right name) refmap toRef spn = (hie_module hf ,(srcSpanStartLine spn , srcSpanStartCol spn) ,(srcSpanEndLine spn , srcSpanEndCol spn) diff --git a/src/HieDb/Types.hs b/src/HieDb/Types.hs index 1dc32e6..529ed5a 100644 --- a/src/HieDb/Types.hs +++ b/src/HieDb/Types.hs @@ -150,15 +150,17 @@ data RefRow , refSCol :: Int , refELine :: Int , refECol :: Int + , refIsGenerated :: Bool -- ^ True if the reference to this name is generated by GHC (NodeOrigin is GeneratedInfo) + -- False if it comes from the source code (NodeOrigin is SourceInfo) } instance ToRow RefRow where - toRow (RefRow a b c d e f g h) = toRow ((a,b,c):.(d,e,f):.(g,h)) + toRow (RefRow a b c d e f g h i) = toRow ((a,b,c):.(d,e,f):.(g,h,i)) instance FromRow RefRow where fromRow = RefRow <$> field <*> field <*> field <*> field <*> field <*> field - <*> field <*> field + <*> field <*> field <*> field data DeclRow = DeclRow @@ -178,23 +180,23 @@ instance FromRow DeclRow where fromRow = DeclRow <$> field <*> field <*> field <*> field <*> field <*> field <*> field -data ImportRow - = ImportRow +data ImportRow + = ImportRow { importSrc :: FilePath , importModuleName :: ModuleName - , importSLine :: Int - , importSCol :: Int - , importELine :: Int - , importECol :: Int + , importSLine :: Int + , importSCol :: Int + , importELine :: Int + , importECol :: Int } -instance FromRow ImportRow where - fromRow = - ImportRow - <$> field <*> field <*> field <*> field +instance FromRow ImportRow where + fromRow = + ImportRow + <$> field <*> field <*> field <*> field <*> field <*> field -instance ToRow ImportRow where +instance ToRow ImportRow where toRow (ImportRow a b c d e f) = toRow ((a,b,c,d):.(e,f)) data TypeName = TypeName diff --git a/src/HieDb/Utils.hs b/src/HieDb/Utils.hs index b1c8339..3d41d45 100644 --- a/src/HieDb/Utils.hs +++ b/src/HieDb/Utils.hs @@ -179,18 +179,17 @@ instance Semigroup AstInfo where instance Monoid AstInfo where mempty = AstInfo [] [] [] -genAstInfo :: FilePath -> Module -> M.Map Identifier [(Span, IdentifierDetails a)] -> AstInfo +genAstInfo :: FilePath -> Module -> RefMap2 a -> AstInfo genAstInfo path smdl refmap = genRows $ flat $ M.toList refmap where flat = concatMap (\(a,xs) -> map (a,) xs) - genRows = foldMap go - go = mkAstInfo + genRows = foldMap mkAstInfo mkAstInfo x = AstInfo (maybeToList $ goRef x) (maybeToList $ goDec x) (maybeToList $ goImport x) - goRef (Right name, (sp,_)) + goRef (Right name, (nodeOrigin, sp, _)) | Just mod <- nameModule_maybe name = Just $ - RefRow path occ (moduleName mod) (moduleUnit mod) sl sc el ec + RefRow path occ (moduleName mod) (moduleUnit mod) sl sc el ec (nodeOrigin == GeneratedInfo) where occ = nameOccName name sl = srcSpanStartLine sp @@ -199,7 +198,7 @@ genAstInfo path smdl refmap = genRows $ flat $ M.toList refmap ec = srcSpanEndCol sp goRef _ = Nothing - goImport (Left modName, (sp, IdentifierDetails _ contextInfos)) = do + goImport (Left modName, (_, sp, IdentifierDetails _ contextInfos)) = do _ <- guard $ not $ S.disjoint contextInfos $ S.fromList [IEThing Import, IEThing ImportAs, IEThing ImportHiding] let sl = srcSpanStartLine sp @@ -209,7 +208,7 @@ genAstInfo path smdl refmap = genRows $ flat $ M.toList refmap Just $ ImportRow path modName sl sc el ec goImport _ = Nothing - goDec (Right name,(_,dets)) + goDec (Right name,(_,_,dets)) | Just mod <- nameModule_maybe name , mod == smdl , occ <- nameOccName name @@ -235,17 +234,17 @@ genAstInfo path smdl refmap = genRows $ flat $ M.toList refmap goDecl (RecField _ sp) = sp goDecl _ = Nothing -genDefRow :: FilePath -> Module -> M.Map Identifier [(Span, IdentifierDetails a)] -> [DefRow] +genDefRow :: FilePath -> Module -> RefMap2 a -> [DefRow] genDefRow path smod refmap = genRows $ M.toList refmap where genRows = mapMaybe go getSpan name dets | RealSrcSpan sp _ <- nameSrcSpan name = Just sp | otherwise = do - (sp, _dets) <- find defSpan dets + (_, sp, _dets) <- find defSpan dets pure sp - defSpan = any isDef . identInfo . snd + defSpan (_, _, dets)= any isDef $ identInfo dets isDef (ValBind RegularBind _ _) = True isDef PatternBind{} = True isDef Decl{} = True