diff --git a/src/SqlClient.Samples/WebApi.Controllers/DataAccess.fs b/src/SqlClient.Samples/WebApi.Controllers/DataAccess.fs index 611fc2db..58caab2f 100644 --- a/src/SqlClient.Samples/WebApi.Controllers/DataAccess.fs +++ b/src/SqlClient.Samples/WebApi.Controllers/DataAccess.fs @@ -5,7 +5,7 @@ open FSharp.Data [] let AdventureWorks2012 = "name=AdventureWorks2012" -type QueryProducts = SqlCommandProvider<"T-SQL\Products.sql", AdventureWorks2012, DataDirectory = "App_Data"> +type QueryProducts = SqlCommandProvider<"T-SQL/Products.sql", AdventureWorks2012, DataDirectory = "App_Data"> type AdventureWorks = SqlProgrammabilityProvider diff --git a/src/SqlClient.Tests/ConfigurationTest.fs b/src/SqlClient.Tests/ConfigurationTest.fs index 19da0e38..0187b08c 100644 --- a/src/SqlClient.Tests/ConfigurationTest.fs +++ b/src/SqlClient.Tests/ConfigurationTest.fs @@ -26,6 +26,15 @@ let RuntimeConfig () = Configuration.GetConnectionStringAtRunTime name |> should equal ConfigurationManager.ConnectionStrings.[name].ConnectionString +[] +let CheckValidFileName() = + let expected = Some "c:\\mysqlfiles\\test.sql" + Configuration.GetValidFileName("test.sql", "c:\\mysqlfiles") |> should equal expected + Configuration.GetValidFileName("../test.sql", "c:\\mysqlfiles\\subfolder") |> should equal expected + Configuration.GetValidFileName("c:\\mysqlfiles/test.sql", "d:\\otherdrive") |> should equal expected + Configuration.GetValidFileName("../mysqlfiles/test.sql", "c:\\otherfolder") |> should equal expected + Configuration.GetValidFileName("a/b/c/../../../test.sql", "c:\\mysqlfiles") |> should equal expected + type Get42RelativePath = SqlCommandProvider<"sampleCommand.sql", "name=AdventureWorks2012", ResolutionFolder="MySqlFolder"> type Get42 = SqlCommandProvider<"SELECT 42", "name=AdventureWorks2012", ConfigFile = "appWithInclude.config"> \ No newline at end of file diff --git a/src/SqlClient/Configuration.fs b/src/SqlClient/Configuration.fs index 50be6669..9a129e16 100644 --- a/src/SqlClient/Configuration.fs +++ b/src/SqlClient/Configuration.fs @@ -20,20 +20,26 @@ open System.Threading.Tasks open System.Collections.Generic type Configuration() = - static let isInvalidPathChars = HashSet(Path.GetInvalidPathChars()) + static let invalidPathChars = HashSet(Path.GetInvalidPathChars()) + static let invalidFileChars = HashSet(Path.GetInvalidFileNameChars()) + + static member GetValidFileName (file:string, resolutionFolder:string) = + if (file.Contains "\n") || (resolutionFolder.Contains "\n") then None else + let f = Path.Combine(resolutionFolder, file) + if invalidPathChars.Overlaps (Path.GetDirectoryName f) || + invalidFileChars.Overlaps (Path.GetFileName f) then None + else + // Canonicalizing the path may throw on bad input, the check above does not cover every error. + try Some (Path.GetFullPath f) with | _ -> None static member ParseTextAtDesignTime(commandTextOrPath : string, resolutionFolder, invalidateCallback) = - if isInvalidPathChars.Overlaps( commandTextOrPath) - then commandTextOrPath, None - else - let path = Path.Combine(resolutionFolder, commandTextOrPath) - if File.Exists(path) |> not - then commandTextOrPath, None - else - if Path.GetExtension(commandTextOrPath) <> ".sql" then failwith "Only files with .sql extension are supported" - let watcher = new FileSystemWatcher(Filter = commandTextOrPath, Path = resolutionFolder) + match Configuration.GetValidFileName (commandTextOrPath, resolutionFolder) with + | Some path when File.Exists path -> + if Path.GetExtension(path) <> ".sql" then failwith "Only files with .sql extension are supported" + let watcher = new FileSystemWatcher(Filter = Path.GetFileName path, Path = Path.GetDirectoryName path) watcher.Changed.Add(fun _ -> invalidateCallback()) watcher.Renamed.Add(fun _ -> invalidateCallback()) + watcher.Deleted.Add(fun _ -> invalidateCallback()) watcher.EnableRaisingEvents <- true let task = Task.Factory.StartNew(fun () -> use stream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.ReadWrite) @@ -41,6 +47,7 @@ type Configuration() = reader.ReadToEnd()) if not (task.Wait(TimeSpan.FromSeconds(1.))) then failwithf "Couldn't read command from file %s" path task.Result, Some watcher + | _ -> commandTextOrPath, None static member ParseConnectionStringName(s: string) = match s.Trim().Split([|'='|], 2, StringSplitOptions.RemoveEmptyEntries) with