diff --git a/src/fsharp/build.fs b/src/fsharp/build.fs index 276f833e1dd..cda122e9a8d 100644 --- a/src/fsharp/build.fs +++ b/src/fsharp/build.fs @@ -3875,7 +3875,9 @@ type TcImports(tcConfigP:TcConfigProvider, initialResolutions:TcAssemblyResoluti // Add the invalidation signal handlers to each provider for provider in providers do - provider.PUntaint((fun tp -> tp.Invalidate.Add(fun _ -> invalidateCcu.Trigger ("The provider '" + fileNameOfRuntimeAssembly + "' reported a change"))), m) + provider.PUntaint((fun tp -> + let handler = tp.Invalidate.Subscribe(fun _ -> invalidateCcu.Trigger ("The provider '" + fileNameOfRuntimeAssembly + "' reported a change")) + tcImports.AttachDisposeAction(fun () -> try handler.Dispose() with _ -> ())), m) match providers with | [] -> diff --git a/vsintegration/src/unittests/Resources.MockTypeProviders/DummyProviderForLanguageServiceTesting/TypeProviderEmit.fs b/vsintegration/src/unittests/Resources.MockTypeProviders/DummyProviderForLanguageServiceTesting/TypeProviderEmit.fs index 82c59297daf..ecdced32403 100644 --- a/vsintegration/src/unittests/Resources.MockTypeProviders/DummyProviderForLanguageServiceTesting/TypeProviderEmit.fs +++ b/vsintegration/src/unittests/Resources.MockTypeProviders/DummyProviderForLanguageServiceTesting/TypeProviderEmit.fs @@ -1429,6 +1429,13 @@ module Local = failwith (sprintf "Unknown type '%s' in namespace '%s' (contains %s)" typeName namespaceName typenames) } +// Used by unit testing to check that invalidation handlers are being disconnected +module GlobalCountersForInvalidation = + let mutable invalidationHandlersAdded = 0 + let mutable invalidationHandlersRemoved = 0 + let GetInvalidationHandlersAdded() = invalidationHandlersAdded + let GetInvalidationHandlersRemoved() = invalidationHandlersRemoved + type TypeProviderForNamespaces(namespacesAndTypes : list<(string * list)>) = let otherNamespaces = ResizeArray>() @@ -1440,6 +1447,26 @@ type TypeProviderForNamespaces(namespacesAndTypes : list<(string * list() + let invalidateP = invalidateE.Publish + let invalidatePCounting = + { new obj() with + member x.ToString() = "" + interface IEvent + interface IDelegateEvent with + member e.AddHandler(d) = + GlobalCountersForInvalidation.invalidationHandlersAdded <- GlobalCountersForInvalidation.invalidationHandlersAdded + 1 + invalidateP.AddHandler(d) + member e.RemoveHandler(d) = + GlobalCountersForInvalidation.invalidationHandlersRemoved <- GlobalCountersForInvalidation.invalidationHandlersRemoved + 1 + invalidateP.RemoveHandler(d) + interface System.IObservable with + member e.Subscribe(observer) = + GlobalCountersForInvalidation.invalidationHandlersAdded <- GlobalCountersForInvalidation.invalidationHandlersAdded + 1 + let d = invalidateP.Subscribe(observer) + { new System.IDisposable with + member x.Dispose() = + GlobalCountersForInvalidation.invalidationHandlersRemoved <- GlobalCountersForInvalidation.invalidationHandlersRemoved + 1 + d.Dispose() } } new (namespaceName:string,types:list) = new TypeProviderForNamespaces([(namespaceName,types)]) new () = new TypeProviderForNamespaces([]) @@ -1450,7 +1477,7 @@ type TypeProviderForNamespaces(namespacesAndTypes : list<(string * list] - override this.Invalidate = invalidateE.Publish + override this.Invalidate = invalidatePCounting override this.GetNamespaces() = Array.copy providedNamespaces.Value member __.GetInvokerExpression(methodBase, parameters) = match methodBase with diff --git a/vsintegration/src/unittests/Tests.LanguageService.Script.fs b/vsintegration/src/unittests/Tests.LanguageService.Script.fs index 270428f94d2..32997a153ba 100644 --- a/vsintegration/src/unittests/Tests.LanguageService.Script.fs +++ b/vsintegration/src/unittests/Tests.LanguageService.Script.fs @@ -1629,16 +1629,30 @@ type ScriptTests() as this = let totalDisposalsMeth = providerCounters.GetMethod("GetTotalDisposals") Assert.IsNotNull(totalDisposalsMeth, "totalDisposalsMeth should not be null") + let providerCounters2 = providerAssembly.GetType("Microsoft.FSharp.TypeProvider.Emit.GlobalCountersForInvalidation") + Assert.IsNotNull(providerCounters2, "provider counters #2 module should not be null") + let totalInvaldiationHandlersAddedMeth = providerCounters2.GetMethod("GetInvalidationHandlersAdded") + Assert.IsNotNull(totalInvaldiationHandlersAddedMeth, "totalInvaldiationHandlersAddedMeth should not be null") + let totalInvaldiationHandlersRemovedMeth = providerCounters2.GetMethod("GetInvalidationHandlersRemoved") + Assert.IsNotNull(totalInvaldiationHandlersRemovedMeth, "totalInvaldiationHandlersRemovedMeth should not be null") + let totalCreations() = totalCreationsMeth.Invoke(null, [| |]) :?> int let totalDisposals() = totalDisposalsMeth.Invoke(null, [| |]) :?> int + let totalInvaldiationHandlersAdded() = totalInvaldiationHandlersAddedMeth.Invoke(null, [| |]) :?> int + let totalInvaldiationHandlersRemoved() = totalInvaldiationHandlersRemovedMeth.Invoke(null, [| |]) :?> int let startCreations = totalCreations() let startDisposals = totalDisposals() + let startInvaldiationHandlersAdded = totalInvaldiationHandlersAdded() + let startInvaldiationHandlersRemoved = totalInvaldiationHandlersRemoved() let countCreations() = totalCreations() - startCreations let countDisposals() = totalDisposals() - startDisposals + let countInvaldiationHandlersAdded() = totalInvaldiationHandlersAdded() - startInvaldiationHandlersAdded + let countInvaldiationHandlersRemoved() = totalInvaldiationHandlersRemoved() - startInvaldiationHandlersRemoved Assert.IsTrue(startCreations >= startDisposals, "Check0") + Assert.IsTrue(startInvaldiationHandlersAdded >= startInvaldiationHandlersRemoved, "Check0") for i in 1 .. 50 do let solution = this.CreateSolution() let project = CreateProject(solution,"testproject" + string (i % 20)) @@ -1662,20 +1676,25 @@ type ScriptTests() as this = Assert.IsTrue(countDisposals() < i, "Check1, countDisposals() < i, iteration " + string i) Assert.IsTrue(countCreations() >= countDisposals(), "Check2, countCreations() >= countDisposals(), iteration " + string i) Assert.IsTrue(countCreations() = i, "Check3, countCreations() = i, iteration " + string i) + Assert.IsTrue(countInvaldiationHandlersAdded() = i, "Check3b, countInvaldiationHandlersAdded() = i, iteration " + string i) if not clearing then // By default we hold 3 build incrementalBuilderCache entries and 5 typeCheckInfo entries, so if we're not clearing // there should be some roots to project builds still present if i >= 3 then Assert.IsTrue(i >= countDisposals() + 3, "Check4a, i >= countDisposals() + 3, iteration " + string i + ", i = " + string i + ", countDisposals() = " + string (countDisposals())) + Assert.IsTrue(i >= countInvaldiationHandlersRemoved() + 3, "Check4a2, i >= countInvaldiationHandlersRemoved() + 3, iteration " + string i + ", i = " + string i + ", countDisposals() = " + string (countDisposals())) // If we forcefully clear out caches and force a collection, then we can say much stronger things... if clearing then ClearLanguageServiceRootCachesAndCollectAndFinalizeAllTransients(this.VS) Assert.IsTrue((i = countDisposals()), "Check4b, countCreations() = countDisposals(), iteration " + string i) + Assert.IsTrue((i = countInvaldiationHandlersRemoved()), "Check4b2, countCreations() = countInvaldiationHandlersRemoved(), iteration " + string i) Assert.IsTrue(countCreations() = 50, "Check5, at end, countCreations() = 50") + Assert.IsTrue(countInvaldiationHandlersAdded() = 50, "Check5, at end, countCreations() = 50") ClearLanguageServiceRootCachesAndCollectAndFinalizeAllTransients(this.VS) - Assert.IsTrue(countDisposals() = 50, "Check6b, at end, countDisposals() = 50 when clearing") + Assert.IsTrue(countDisposals() = 50, "Check6b, at end, countDisposals() = 50 after explicit clearing") + Assert.IsTrue(countInvaldiationHandlersRemoved() = 50, "Check5, at end, countInvaldiationHandlersRemoved() = 50 after explicit cleraring") [] []