diff --git a/.vsts-dotnet-ci.yml b/.vsts-dotnet-ci.yml
index c2e0e2c0d4..4e18da7c90 100644
--- a/.vsts-dotnet-ci.yml
+++ b/.vsts-dotnet-ci.yml
@@ -11,5 +11,3 @@ phases:
buildScript: build.cmd
queue:
name: Hosted VS2017
- demands:
- - agent.os -equals Windows_NT
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index da15d0493f..321ad3dcf0 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -16,7 +16,7 @@ If you are new to GitHub [here](https://help.github.com/categories/collaborating
As a first time contributor, you will be invited to sign the Contributor License Agreement (CLA). Please follow the instructions of the dotnet foundation bot reviewer on your PR to sign the agreement indicating that you have appropriate rights to your contribution.
-Your pull request needs to reference a filed issue. Please fill in the template that is populated for the pull request. Only pull requests adressing small typos can have no issues associated with them.
+Your pull request needs to reference a filed issue. Please fill in the template that is populated for the pull request. Only pull requests addressing small typos can have no issues associated with them.
An ML.NET team member will be assigned to your pull request once the continuous integration checks have passed successfully.
diff --git a/Directory.Build.props b/Directory.Build.props
index 73144201c7..bdca231554 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -114,4 +114,11 @@
true
+
+ $(Configuration.EndsWith('-Intrinsics'))
+
+
+
+ $(RepoRoot)build\AfterCommonTargets.targets
+
diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 58e24041f1..18d9d3867e 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -97,6 +97,19 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.PerformanceTests", "test\Microsoft.ML.CpuMath.PerformanceTests\Microsoft.ML.CpuMath.PerformanceTests.csproj", "{7333EDEF-4144-405C-A5EC-6F42201857D8}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTests.netstandard", "test\Microsoft.ML.CpuMath.UnitTests.netstandard\Microsoft.ML.CpuMath.UnitTests.netstandard.csproj", "{A0E562A9-0E6D-470D-B180-6EB44BA84D60}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTests.netcoreapp", "test\Microsoft.ML.CpuMath.UnitTests.netcoreapp\Microsoft.ML.CpuMath.UnitTests.netcoreapp.csproj", "{5F81A2A4-73AD-494C-B387-07D605EC8826}"
+EndProject
+
+Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -329,6 +342,54 @@ Global
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {7333EDEF-4144-405C-A5EC-6F42201857D8}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Release|Any CPU.Build.0 = Release|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {5F81A2A4-73AD-494C-B387-07D605EC8826}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release|Any CPU.Build.0 = Release|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release|Any CPU.Build.0 = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -365,8 +426,14 @@ Global
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
+ {7333EDEF-4144-405C-A5EC-6F42201857D8} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {A0E562A9-0E6D-470D-B180-6EB44BA84D60} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {5F81A2A4-73AD-494C-B387-07D605EC8826} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/README.md b/README.md
index 9ccd06c165..e9d256352b 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ Along with these ML capabilities this first release of ML.NET also brings the fi
ML.NET runs on Windows, Linux, and macOS - any platform where 64 bit [.NET Core](https://github.com/dotnet/core) or later is available.
-The current release is 0.3. Check out the [release notes](docs/release-notes/0.3/release-0.3.md).
+The current release is 0.4. Check out the [release notes](docs/release-notes/0.4/release-0.4.md).
First ensure you have installed [.NET Core 2.0](https://www.microsoft.com/net/learn/get-started) or later. ML.NET also works on the .NET Framework. Note that ML.NET currently must run in a 64 bit process.
diff --git a/build.proj b/build.proj
index 77b82dea2c..c9be14e930 100644
--- a/build.proj
+++ b/build.proj
@@ -34,24 +34,26 @@
BuildNative;
$(TraversalBuildDependsOn);
DownloadExternalTestFiles;
- RunTests;
+ Targets="Restore"
+ Properties="MSBuildWarningsAsMessages=NU1503" />
-
+
+ DependsOnTargets="CreateOrUpdateCurrentVersionFile;RestoreProjects">
@@ -77,7 +79,7 @@
TreatErrorsAsWarnings="true"/>
-
+
diff --git a/build/AfterCommonTargets.targets b/build/AfterCommonTargets.targets
new file mode 100644
index 0000000000..cba4c80b5c
--- /dev/null
+++ b/build/AfterCommonTargets.targets
@@ -0,0 +1,13 @@
+
+
+ $(MSBuildAllProjects);$(MSBuildThisFileFullPath)
+
+
+
+
+
\ No newline at end of file
diff --git a/build/BranchInfo.props b/build/BranchInfo.props
index 193aff7d35..b6d49773ec 100644
--- a/build/BranchInfo.props
+++ b/build/BranchInfo.props
@@ -1,7 +1,7 @@
0
- 4
+ 5
0
preview
diff --git a/build/Dependencies.props b/build/Dependencies.props
index 5325011f05..79ae31c598 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -8,5 +8,7 @@
4.3.0
1.0.0-beta-62824-02
2.1.2.2
+ 0.0.0.5
+ 4.5.0
diff --git a/build/Empty.targets b/build/Empty.targets
new file mode 100644
index 0000000000..72abf9cd60
--- /dev/null
+++ b/build/Empty.targets
@@ -0,0 +1,29 @@
+
+
+ $(MSBuildAllProjects);$(MSBuildThisFileFullPath)
+
+ ignore.targets
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/build/ci/phase-template.yml b/build/ci/phase-template.yml
index bd326afb69..037c207dbb 100644
--- a/build/ci/phase-template.yml
+++ b/build/ci/phase-template.yml
@@ -13,16 +13,18 @@ phases:
${{ if ne(parameters.dockerImage, '') }}:
_PREVIEW_VSTS_DOCKER_IMAGE: ${{ parameters.dockerImage }}
queue:
- parallel: 2
+ parallel: 99
matrix:
Build_Debug:
_configuration: Debug
Build_Release:
- _configuration: Release
+ _configuration: Release
${{ insert }}: ${{ parameters.queue }}
steps:
+ - script: $(_buildScript) -$(_configuration)
+ displayName: Build
- script: $(_buildScript) -$(_configuration) -runtests
- displayName: Build and Test
+ displayName: Run Tests
- task: PublishTestResults@2
displayName: Publish Test Results
condition: succeededOrFailed()
diff --git a/build/vsts-ci.yml b/build/vsts-ci.yml
index 5998c341db..1e3d601704 100644
--- a/build/vsts-ci.yml
+++ b/build/vsts-ci.yml
@@ -10,7 +10,7 @@ phases:
DOTNET_SKIP_FIRST_TIME_EXPERIENCE: 1
DOTNET_MULTILEVEL_LOOKUP: 0
queue:
- name: DotNetCore-Test
+ name: DotNet-Build
demands:
- agent.os -equals linux
steps:
diff --git a/config.json b/config.json
index 04ff73c8b7..8436586e61 100644
--- a/config.json
+++ b/config.json
@@ -67,8 +67,8 @@
"defaultValue": ""
},
"RunTests": {
- "description": "Run tests after building.",
- "valueType": "property",
+ "description": "MsBuild target that run the tests. Call this after building.",
+ "valueType": "target",
"values": [],
"defaultValue": ""
},
@@ -113,9 +113,9 @@
}
},
"runtests": {
- "description": "Runs the tests after building.",
+ "description": "Runs the tests. Call this after building.",
"settings": {
- "RunTests": "true"
+ "RunTests": "default"
}
},
"verbose": {
diff --git a/docs/building/windows-instructions.md b/docs/building/windows-instructions.md
index 4b58d1c09b..9ffc9c82ae 100644
--- a/docs/building/windows-instructions.md
+++ b/docs/building/windows-instructions.md
@@ -46,14 +46,14 @@ You can use the Developer Command Prompt, Powershell or work in any regular cmd.
From a (non-admin) Command Prompt window:
- `build.cmd` - builds the assemblies
-- `build.cmd -runTests` - builds the assemblies and runs tests
+- `build.cmd -runTests` - called after a normal "build.cmd" will run all tests
- `build.cmd -buildPackages` called after a normal “build.cmd” will create the NuGet packages with the assemblies in “bin"
**Note**: Before working on individual projects or test projects you **must** run `build.cmd` from the root once before beginning that work. It is also a good idea to run `build.cmd` whenever you pull a large set of unknown changes into your branch.
### Running tests from the command line
-From the root, use `build.cmd -runTests`.
+From the root, run `build.cmd` and then `build.cmd -runTests`.
For more details, or to test an individual project, you can navigate to the test project directory and then use `dotnet test`
### Running tests from Visual Studio
diff --git a/docs/project-docs/developer-guide.md b/docs/project-docs/developer-guide.md
index 60788f8e96..074ea03dab 100644
--- a/docs/project-docs/developer-guide.md
+++ b/docs/project-docs/developer-guide.md
@@ -32,7 +32,8 @@ build.cmd -Release -TargetArchitecture:x64
- Building the src and then building and running the tests
```
-build.cmd -RunTests
+build.cmd
+build.cmd -runTests
```
### Building individual projects
diff --git a/docs/release-notes/0.4/release-0.4.md b/docs/release-notes/0.4/release-0.4.md
new file mode 100644
index 0000000000..41c436f14e
--- /dev/null
+++ b/docs/release-notes/0.4/release-0.4.md
@@ -0,0 +1,88 @@
+# ML.NET 0.4 Release Notes
+
+Today we are releasing ML.NET 0.4. During this release we have started
+exploring new APIs for ML.NET that enable functionality that is missing from
+the current APIs. We welcome feedback and contributions to the
+conversation (relevant issues can be found [here](https://github.com/dotnet/machinelearning/projects/4)). While the
+focus has been on designing the new APIs, we have also moved several
+components from the internal codebase to ML.NET.
+
+### Installation
+
+ML.NET supports Windows, MacOS, and Linux. See [supported OS versions of .NET
+Core
+2.0](https://github.com/dotnet/core/blob/master/release-notes/2.0/2.0-supported-os.md)
+for more details.
+
+You can install ML.NET NuGet from the CLI using:
+```
+dotnet add package Microsoft.ML
+```
+
+From package manager:
+```
+Install-Package Microsoft.ML
+```
+
+### Release Notes
+
+Below are some of the highlights from this release.
+
+* Added SymSGD learner for binary classification
+ ([#624](https://github.com/dotnet/machinelearning/pull/624))
+
+ * [SymSGD](https://arxiv.org/abs/1705.08030) is a technique for
+ parallelizing
+ [SGD](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)
+ (Stochastic Gradient Descent). This enables it to sometimes perform
+ faster than existing SGD implementations (e.g. [Hogwild
+ SGD](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.trainers.stochasticgradientdescentbinaryclassifier?view=ml-dotnet)).
+ * SymSGD is available for binary classification, but can be used in
+ multiclass classification with
+ [One-Versus-All](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.models.oneversusall?view=ml-dotnet)
+ * SymSGD requires adding the Microsoft.ML.HalLearners NuGet package to your project
+ * The current implementation in ML.NET does not yet have multi-threading
+ enabled due to build system limitations (tracked by
+ [#655](https://github.com/dotnet/machinelearning/issues/655)), but
+ SymSGD can still be helpful in scenarios where you want to try many
+ different learners and limit each of them to a single thread.
+ * Documentation can be found
+ [here](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.trainers.symsgdbinaryclassifier?view=ml-dotnet)
+
+* Added Word Embeddings Transform for text scenarios
+ ([#545](https://github.com/dotnet/machinelearning/pull/545))
+
+ * [Word embeddings](https://en.wikipedia.org/wiki/Word_embedding) is a
+ technique for mapping words or phrases to numeric vectors of relatively low
+ dimension (in comparison with the high dimensional n-gram extraction).
+ These numeric vectors are intended to capture some of the meaning of the
+ words so they can be used for training a better model. As an example,
+ SSWE (Sentiment-Specific Word Embedding) can be useful for sentiment
+ related tasks.
+ * This transform enables using pretrained models to get the embeddings
+ (i.e. the embeddings are already trained and available for use).
+ * Several options for pretrained embeddings are available:
+ [GloVe](https://nlp.stanford.edu/projects/glove/),
+ [fastText](https://en.wikipedia.org/wiki/FastText), and
+ [SSWE](http://anthology.aclweb.org/P/P14/P14-1146.pdf). The pretrained model is downloaded automatically on first use.
+ * Documentation can be found
+ [here](https://docs.microsoft.com/en-us/dotnet/api/microsoft.ml.transforms.wordembeddings?view=ml-dotnet).
+
+* Improved support for F# by allowing use of property-based row classes ([#616](https://github.com/dotnet/machinelearning/pull/616))
+
+ * ML.NET now supports F# record types.
+ * The ML.NET samples repository is being updated to include F# samples as part of [#36](https://github.com/dotnet/machinelearning-samples/pull/36).
+
+Additional issues closed in this milestone can be found
+[here](https://github.com/dotnet/machinelearning/milestone/3?closed=1).
+
+### Acknowledgements
+
+Shoutout to [dsyme](https://github.com/dsyme),
+[SolyarA](https://github.com/SolyarA),
+[dan-drews](https://github.com/dan-drews),
+[bojanmisic](https://github.com/bojanmisic),
+[jwood803](https://github.com/jwood803),
+[sharwell](https://github.com/sharwell),
+[JoshuaLight](https://github.com/JoshuaLight), and the ML.NET team for their
+contributions as part of this release!
\ No newline at end of file
diff --git a/netci.groovy b/netci.groovy
index b955bf669f..7c1126aff9 100644
--- a/netci.groovy
+++ b/netci.groovy
@@ -16,6 +16,7 @@ def branch = GithubBranchName
def newJob = job(Utilities.getFullJobName(project, jobName, isPR)) {
steps {
+ shell("./build.sh -$config")
shell("./build.sh -$config -runtests")
shell("./build.sh -buildPackages")
}
diff --git a/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj b/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj
index 918729d99d..18cec880f9 100644
--- a/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj
+++ b/pkg/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.nupkgproj
@@ -11,4 +11,8 @@
+
+
+
+
diff --git a/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.nupkgproj b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.nupkgproj
new file mode 100644
index 0000000000..132b995a2f
--- /dev/null
+++ b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.nupkgproj
@@ -0,0 +1,17 @@
+
+
+
+ netstandard2.0
+ ML.NET additional learners making use of hardware acceleration. They use Intel Mkl.
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.symbols.nupkgproj b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.symbols.nupkgproj
new file mode 100644
index 0000000000..248ae82414
--- /dev/null
+++ b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj
new file mode 100644
index 0000000000..8bdef45d07
--- /dev/null
+++ b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ ML.NET component for Image support
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.symbols.nupkgproj b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.symbols.nupkgproj
new file mode 100644
index 0000000000..b36800ea0b
--- /dev/null
+++ b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
index fc409ae21f..7757e264b6 100644
--- a/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
+++ b/pkg/Microsoft.ML/Microsoft.ML.nupkgproj
@@ -15,6 +15,7 @@
+
diff --git a/pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.props b/pkg/common/CommonPackage.props
similarity index 100%
rename from pkg/Microsoft.ML/build/netstandard2.0/Microsoft.ML.props
rename to pkg/common/CommonPackage.props
diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs
index 8b8cb5871b..96e821f16e 100644
--- a/src/Microsoft.ML.Api/ApiUtils.cs
+++ b/src/Microsoft.ML.Api/ApiUtils.cs
@@ -51,14 +51,31 @@ private static OpCode GetAssignmentOpCode(Type t)
///
internal static Delegate GeneratePeek(InternalSchemaDefinition.Column column)
{
- var fieldInfo = column.FieldInfo;
- Type fieldType = fieldInfo.FieldType;
-
- var assignmentOpCode = GetAssignmentOpCode(fieldType);
- Func func = GeneratePeek;
- var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
- .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
- return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
+ switch (column.MemberInfo)
+ {
+ case FieldInfo fieldInfo:
+ Type fieldType = fieldInfo.FieldType;
+
+ var assignmentOpCode = GetAssignmentOpCode(fieldType);
+ Func func = GeneratePeek;
+ var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
+ .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
+ return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
+
+ case PropertyInfo propertyInfo:
+ Type propertyType = propertyInfo.PropertyType;
+
+ var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
+ Func funcProp = GeneratePeek;
+ var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
+ .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
+ return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo, assignmentOpCodeProp });
+
+ default:
+ Contracts.Assert(false);
+ throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
+
+ }
}
private static Delegate GeneratePeek(FieldInfo fieldInfo, OpCode assignmentOpCode)
@@ -81,6 +98,28 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op
return mb.CreateDelegate(typeof(Peek));
}
+ private static Delegate GeneratePeek(PropertyInfo propertyInfo, OpCode assignmentOpCode)
+ {
+ // REVIEW: It seems like we really should cache these, instead of generating them per cursor.
+ Type[] args = { typeof(TOwn), typeof(TRow), typeof(long), typeof(TValue).MakeByRefType() };
+ var mb = new DynamicMethod("Peek", null, args, typeof(TOwn), true);
+ var il = mb.GetILGenerator();
+ var minfo = propertyInfo.GetGetMethod();
+ var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;
+
+ il.Emit(OpCodes.Ldarg_3); // push arg3
+ il.Emit(OpCodes.Ldarg_1); // push arg1
+ il.Emit(opcode, minfo); // call [stack top].get_[propertyInfo]()
+ // Stobj needs to coupled with a type.
+ if (assignmentOpCode == OpCodes.Stobj) // [stack top-1] = [stack top]
+ il.Emit(assignmentOpCode, propertyInfo.PropertyType);
+ else
+ il.Emit(assignmentOpCode);
+ il.Emit(OpCodes.Ret); // ret
+
+ return mb.CreateDelegate(typeof(Peek));
+ }
+
///
/// Each of the specialized 'poke' methods sets the appropriate field value of an instance of T
/// to the provided value. So, the call is 'peek(userObject, providedValue)' and the logic is
@@ -88,14 +127,30 @@ private static Delegate GeneratePeek(FieldInfo fieldInfo, Op
///
internal static Delegate GeneratePoke(InternalSchemaDefinition.Column column)
{
- var fieldInfo = column.FieldInfo;
- Type fieldType = fieldInfo.FieldType;
-
- var assignmentOpCode = GetAssignmentOpCode(fieldType);
- Func func = GeneratePoke;
- var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
- .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
- return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
+ switch (column.MemberInfo)
+ {
+ case FieldInfo fieldInfo:
+ Type fieldType = fieldInfo.FieldType;
+
+ var assignmentOpCode = GetAssignmentOpCode(fieldType);
+ Func func = GeneratePoke;
+ var methInfo = func.GetMethodInfo().GetGenericMethodDefinition()
+ .MakeGenericMethod(typeof(TOwn), typeof(TRow), fieldType);
+ return (Delegate)methInfo.Invoke(null, new object[] { fieldInfo, assignmentOpCode });
+
+ case PropertyInfo propertyInfo:
+ Type propertyType = propertyInfo.PropertyType;
+
+ var assignmentOpCodeProp = GetAssignmentOpCode(propertyType);
+ Func funcProp = GeneratePoke;
+ var methInfoProp = funcProp.GetMethodInfo().GetGenericMethodDefinition()
+ .MakeGenericMethod(typeof(TOwn), typeof(TRow), propertyType);
+ return (Delegate)methInfoProp.Invoke(null, new object[] { propertyInfo });
+
+ default:
+ Contracts.Assert(false);
+ throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
+ }
}
private static Delegate GeneratePoke(FieldInfo fieldInfo, OpCode assignmentOpCode)
@@ -115,5 +170,20 @@ private static Delegate GeneratePoke(FieldInfo fieldInfo, Op
il.Emit(OpCodes.Ret); // ret
return mb.CreateDelegate(typeof(Poke), null);
}
+
+ private static Delegate GeneratePoke(PropertyInfo propertyInfo)
+ {
+ Type[] args = { typeof(TOwn), typeof(TRow), typeof(TValue) };
+ var mb = new DynamicMethod("Poke", null, args, typeof(TOwn), true);
+ var il = mb.GetILGenerator();
+ var minfo = propertyInfo.GetSetMethod();
+ var opcode = (minfo.IsVirtual || minfo.IsAbstract) ? OpCodes.Callvirt : OpCodes.Call;
+
+ il.Emit(OpCodes.Ldarg_1); // push arg1
+ il.Emit(OpCodes.Ldarg_2); // push arg2
+ il.Emit(opcode, minfo); // call [stack top-1].set_[propertyInfo]([stack top])
+ il.Emit(OpCodes.Ret); // ret
+ return mb.CreateDelegate(typeof(Poke), null);
+ }
}
}
diff --git a/src/Microsoft.ML.Api/ComponentCreation.cs b/src/Microsoft.ML.Api/ComponentCreation.cs
index 3080a8197c..0a1e1cd605 100644
--- a/src/Microsoft.ML.Api/ComponentCreation.cs
+++ b/src/Microsoft.ML.Api/ComponentCreation.cs
@@ -6,6 +6,7 @@
using System.IO;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
namespace Microsoft.ML.Runtime.Api
@@ -304,12 +305,20 @@ public static IDataScorerTransform CreateScorer(this IHostEnvironment env, strin
env.CheckValue(predictor, nameof(predictor));
env.CheckValueOrNull(trainSchema);
- var subComponent = SubComponent.Parse(settings);
- var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, subComponent);
+ ICommandLineComponentFactory scorerFactorySettings = ParseScorerSettings(settings);
+ var bindable = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: scorerFactorySettings);
var mapper = bindable.Bind(env, data.Schema);
return CreateCore(env, settings, data.Data, mapper, trainSchema);
}
+ private static ICommandLineComponentFactory ParseScorerSettings(string settings)
+ {
+ return CmdParser.CreateComponentFactory(
+ typeof(IComponentFactory),
+ typeof(SignatureDataScorer),
+ settings);
+ }
+
///
/// Creates a default data scorer appropriate to the predictor's prediction kind.
///
diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
index e940ea9d4d..c50e48e16f 100644
--- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
+++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
@@ -118,7 +118,7 @@ private Delegate CreateGetter(int index)
var colType = DataView.Schema.GetColumnType(index);
var column = DataView._schema.SchemaDefn.Columns[index];
- var outputType = column.IsComputed ? column.ReturnType : column.FieldInfo.FieldType;
+ var outputType = column.OutputType;
var genericType = outputType;
Func del;
diff --git a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs
index 3edf7599a4..4c20f25d62 100644
--- a/src/Microsoft.ML.Api/InternalSchemaDefinition.cs
+++ b/src/Microsoft.ML.Api/InternalSchemaDefinition.cs
@@ -23,21 +23,23 @@ internal sealed class InternalSchemaDefinition
public class Column
{
public readonly string ColumnName;
- public readonly FieldInfo FieldInfo;
+ public readonly MemberInfo MemberInfo;
public readonly ParameterInfo ReturnParameterInfo;
public readonly ColumnType ColumnType;
public readonly bool IsComputed;
public readonly Delegate Generator;
private readonly Dictionary _metadata;
public Dictionary Metadata { get { return _metadata; } }
- public Type ReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
+ public Type ComputedReturnType {get { return ReturnParameterInfo.ParameterType.GetElementType(); }}
+ public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType;
+ public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType;
- public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo) :
- this(columnName, columnType, fieldInfo, null, null) { }
+ public Column(string columnName, ColumnType columnType, MemberInfo memberInfo) :
+ this(columnName, columnType, memberInfo, null, null) { }
- public Column(string columnName, ColumnType columnType, FieldInfo fieldInfo,
+ public Column(string columnName, ColumnType columnType, MemberInfo memberInfo,
Dictionary metadataInfos) :
- this(columnName, columnType, fieldInfo, null, metadataInfos) { }
+ this(columnName, columnType, memberInfo, null, metadataInfos) { }
public Column(string columnName, ColumnType columnType, Delegate generator) :
this(columnName, columnType, null, generator, null) { }
@@ -46,7 +48,7 @@ public Column(string columnName, ColumnType columnType, Delegate generator,
Dictionary metadataInfos) :
this(columnName, columnType, null, generator, metadataInfos) { }
- private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = null,
+ private Column(string columnName, ColumnType columnType, MemberInfo memberInfo = null,
Delegate generator = null, Dictionary metadataInfos = null)
{
Contracts.AssertNonEmpty(columnName);
@@ -55,8 +57,8 @@ private Column(string columnName, ColumnType columnType, FieldInfo fieldInfo = n
if (generator == null)
{
- Contracts.AssertValue(fieldInfo);
- FieldInfo = fieldInfo;
+ Contracts.AssertValue(memberInfo);
+ MemberInfo = memberInfo;
}
else
{
@@ -95,8 +97,8 @@ public void AssertRep()
// If Column is computed type, it must have a generator.
Contracts.Assert(IsComputed == (Generator != null));
- // Column must have either a generator or a fieldInfo value.
- Contracts.Assert((Generator == null) != (FieldInfo == null));
+ // Column must have either a generator or a memberInfo value.
+ Contracts.Assert((Generator == null) != (MemberInfo == null));
// Additional Checks if there is a generator.
if (Generator == null)
@@ -115,9 +117,7 @@ public void AssertRep()
Contracts.Assert(Generator.GetMethodInfo().ReturnType == typeof(void));
// Checks that the return type of the generator is compatible with ColumnType.
- bool isVector;
- DataKind datakind;
- GetVectorAndKind(ReturnType, "return type", out isVector, out datakind);
+ GetVectorAndKind(ComputedReturnType, "return type", out bool isVector, out DataKind datakind);
Contracts.Assert(isVector == ColumnType.IsVector);
Contracts.Assert(datakind == ColumnType.ItemType.RawKind);
}
@@ -131,19 +131,30 @@ private InternalSchemaDefinition(Column[] columns)
}
///
- /// Given a field info on a type, returns whether this appears to be a vector type,
+ /// Given a field or property info on a type, returns whether this appears to be a vector type,
/// and also the associated data kind for this type. If a data kind could not
/// be determined, this will throw.
///
- /// The field info to inspect.
+ /// The field or property info to inspect.
/// Whether this appears to be a vector type.
/// The data kind of the type, or items of this type if vector.
- public static void GetVectorAndKind(FieldInfo fieldInfo, out bool isVector, out DataKind kind)
+ public static void GetVectorAndKind(MemberInfo memberInfo, out bool isVector, out DataKind kind)
{
- Contracts.AssertValue(fieldInfo);
- Type rawFieldType = fieldInfo.FieldType;
- var name = fieldInfo.Name;
- GetVectorAndKind(rawFieldType, name, out isVector, out kind);
+ Contracts.AssertValue(memberInfo);
+ switch (memberInfo)
+ {
+ case FieldInfo fieldInfo:
+ GetVectorAndKind(fieldInfo.FieldType, fieldInfo.Name, out isVector, out kind);
+ break;
+
+ case PropertyInfo propertyInfo:
+ GetVectorAndKind(propertyInfo.PropertyType, propertyInfo.Name, out isVector, out kind);
+ break;
+
+ default:
+ Contracts.Assert(false);
+ throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
+ }
}
///
@@ -211,23 +222,27 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
bool isVector;
DataKind kind;
- FieldInfo fieldInfo = null;
+ MemberInfo memberInfo = null;
if (!col.IsComputed)
{
- fieldInfo = userType.GetField(col.MemberName);
+ memberInfo = userType.GetField(col.MemberName);
+
+ if (memberInfo == null)
+ memberInfo = userType.GetProperty(col.MemberName);
- if (fieldInfo == null)
- throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'",
+ if (memberInfo == null)
+ throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'",
col.MemberName,
userType.FullName);
//Clause to handle the field that may be used to expose the cursor channel.
//This field does not need a column.
- if (fieldInfo.FieldType == typeof(IChannel))
+ if ( (memberInfo is FieldInfo && (memberInfo as FieldInfo).FieldType == typeof(IChannel)) ||
+ (memberInfo is PropertyInfo && (memberInfo as PropertyInfo).PropertyType == typeof(IChannel)))
continue;
- GetVectorAndKind(fieldInfo, out isVector, out kind);
+ GetVectorAndKind(memberInfo, out isVector, out kind);
}
else
{
@@ -268,7 +283,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
dstCols[i] = col.IsComputed ?
new Column(colName, colType, col.Generator, col.Metadata)
- : new Column(colName, colType, fieldInfo, col.Metadata);
+ : new Column(colName, colType, memberInfo, col.Metadata);
}
return new InternalSchemaDefinition(dstCols);
diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs
index e08845a87e..3258df4ffd 100644
--- a/src/Microsoft.ML.Api/SchemaDefinition.cs
+++ b/src/Microsoft.ML.Api/SchemaDefinition.cs
@@ -14,7 +14,7 @@ namespace Microsoft.ML.Runtime.Api
///
/// Attach to a member of a class to indicate that the item type should be of class key.
///
- [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class KeyTypeAttribute : Attribute
{
// REVIEW: Property based, but should I just have a constructor?
@@ -46,7 +46,7 @@ public KeyTypeAttribute()
/// Allows a member to be marked as a vector valued field, primarily allowing one to set
/// the dimensionality of the resulting array.
///
- [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class VectorTypeAttribute : Attribute
{
private readonly int[] _dims;
@@ -66,7 +66,7 @@ public VectorTypeAttribute(params int[] dims)
/// Describes column information such as name and the source columns indicies that this
/// column encapsulates.
///
- [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class ColumnAttribute : Attribute
{
public ColumnAttribute(string ordinal, string name = null)
@@ -97,7 +97,7 @@ public ColumnAttribute(string ordinal, string name = null)
/// Allows a member to specify its column name directly, as opposed to the default
/// behavior of using the member name as the column name.
///
- [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class ColumnNameAttribute : Attribute
{
private readonly string _name;
@@ -119,7 +119,7 @@ public ColumnNameAttribute(string name)
///
/// Mark this member as not being exposed as a column in the schema.
///
- [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class NoColumnAttribute : Attribute
{
}
@@ -128,7 +128,7 @@ public sealed class NoColumnAttribute : Attribute
/// Mark a member that implements exactly IChannel as being permitted to receive
/// channel information from an external channel.
///
- [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
+ [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = true)]
public sealed class CursorChannelAttribute : Attribute
{
///
@@ -158,19 +158,40 @@ public static bool TrySetCursorChannel(IExceptionContext ectx, T obj, IChanne
.Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any())
.ToArray();
+ var cursorChannelAttrProperties = typeof(T)
+ .GetProperties(BindingFlags.Public | BindingFlags.Instance)
+ .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0)
+ .Where(x => x.GetCustomAttributes(typeof(CursorChannelAttribute), false).Any());
+
+ var cursorChannelAttrMembers = (cursorChannelAttrFields as IEnumerable).Concat(cursorChannelAttrProperties).ToArray();
+
//Check that there is at most one such field.
- if (cursorChannelAttrFields.Length == 0)
+ if (cursorChannelAttrMembers.Length == 0)
return false;
- ectx.Check(cursorChannelAttrFields.Length == 1,
- "Only one field with CursorChannel attribute is allowed.");
+ ectx.Check(cursorChannelAttrMembers.Length == 1,
+ "Only one public field or property with CursorChannel attribute is allowed.");
//Check that the marked field has type IChannel.
- var cursorChannelFieldInfo = cursorChannelAttrFields[0];
- ectx.Check(cursorChannelFieldInfo.FieldType == typeof(IChannel),
- "Field marked as CursorChannel must have type IChannel.");
-
- cursorChannelFieldInfo.SetValue(obj, channel);
+ var cursorChannelAttrMemberInfo = cursorChannelAttrMembers[0];
+ switch (cursorChannelAttrMemberInfo)
+ {
+ case FieldInfo cursorChannelAttrFieldInfo:
+ ectx.Check(cursorChannelAttrFieldInfo.FieldType == typeof(IChannel),
+ "Field marked as CursorChannel must have type IChannel.");
+ cursorChannelAttrFieldInfo.SetValue(obj, channel);
+ break;
+
+ case PropertyInfo cursorChannelAttrPropertyInfo:
+ ectx.Check(cursorChannelAttrPropertyInfo.PropertyType == typeof(IChannel),
+ "Property marked as CursorChannel must have type IChannel.");
+ cursorChannelAttrPropertyInfo.SetValue(obj, channel);
+ break;
+
+ default:
+ Contracts.Assert(false);
+ throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
+ }
return true;
}
}
@@ -319,37 +340,63 @@ public static SchemaDefinition Create(Type userType)
SchemaDefinition cols = new SchemaDefinition();
HashSet colNames = new HashSet();
- foreach (var fieldInfo in userType.GetFields())
+
+ var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
+ var propertyInfos =
+ userType
+ .GetProperties(BindingFlags.Public | BindingFlags.Instance)
+ .Where(x => x.CanRead && x.CanWrite && x.GetGetMethod() != null && x.GetSetMethod() != null && x.GetIndexParameters().Length == 0);
+
+ var memberInfos = (fieldInfos as IEnumerable).Concat(propertyInfos).ToArray();
+
+ foreach (var memberInfo in memberInfos)
{
// Clause to handle the field that may be used to expose the cursor channel.
// This field does not need a column.
// REVIEW: maybe validate the channel attribute now, instead
// of later at cursor creation.
- if (fieldInfo.FieldType == typeof(IChannel))
- continue;
- // Const fields do not need to be mapped.
- if (fieldInfo.IsLiteral)
- continue;
+ switch (memberInfo)
+ {
+ case FieldInfo fieldInfo:
+ if (fieldInfo.FieldType == typeof(IChannel))
+ continue;
+
+ // Const fields do not need to be mapped.
+ if (fieldInfo.IsLiteral)
+ continue;
+
+ break;
- if (fieldInfo.GetCustomAttribute() != null)
+ case PropertyInfo propertyInfo:
+ if (propertyInfo.PropertyType == typeof(IChannel))
+ continue;
+ break;
+
+ default:
+ Contracts.Assert(false);
+ throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
+ }
+
+ if (memberInfo.GetCustomAttribute() != null)
continue;
- var mappingAttr = fieldInfo.GetCustomAttribute();
- var mappingNameAttr = fieldInfo.GetCustomAttribute();
- string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name;
+
+ var mappingAttr = memberInfo.GetCustomAttribute();
+ var mappingNameAttr = memberInfo.GetCustomAttribute();
+ string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? memberInfo.Name;
// Disallow duplicate names, because the field enumeration order is not actually
// well defined, so we are not gauranteed to have consistent "hiding" from run to
// run, across different .NET versions.
if (!colNames.Add(name))
throw Contracts.ExceptParam(nameof(userType), "Duplicate column name '{0}' detected, this is disallowed", name);
- InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out bool isVector, out DataKind kind);
+ InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
PrimitiveType itemType;
- var keyAttr = fieldInfo.GetCustomAttribute();
+ var keyAttr = memberInfo.GetCustomAttribute();
if (keyAttr != null)
{
if (!KeyType.IsValidDataKind(kind))
- throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", fieldInfo.Name);
+ throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
itemType = new KeyType(kind, keyAttr.Min, keyAttr.Count, keyAttr.Contiguous);
}
else
@@ -357,9 +404,9 @@ public static SchemaDefinition Create(Type userType)
// Get the column type.
ColumnType columnType;
- var vectorAttr = fieldInfo.GetCustomAttribute();
+ var vectorAttr = memberInfo.GetCustomAttribute();
if (vectorAttr != null && !isVector)
- throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", fieldInfo.Name);
+ throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with VectorType attribute, but does not appear to be a vector type", memberInfo.Name);
if (isVector)
{
int[] dims = vectorAttr?.Dims;
@@ -373,7 +420,7 @@ public static SchemaDefinition Create(Type userType)
else
columnType = itemType;
- cols.Add(new Column() { MemberName = fieldInfo.Name, ColumnName = name, ColumnType = columnType });
+ cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType });
}
return cols;
}
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index cd8198e14d..19f9a7cf72 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -103,11 +103,11 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing
throw _host.Except("Column '{0}' not found in the data view", col.ColumnName);
}
var realColType = _data.Schema.GetColumnType(colIndex);
- if (!IsCompatibleType(realColType, col.FieldInfo))
+ if (!IsCompatibleType(realColType, col.MemberInfo))
{
throw _host.Except(
- "Can't bind the IDataView column '{0}' of type '{1}' to field '{2}' of type '{3}'.",
- col.ColumnName, realColType, col.FieldInfo.Name, col.FieldInfo.FieldType.FullName);
+ "Can't bind the IDataView column '{0}' of type '{1}' to field or property '{2}' of type '{3}'.",
+ col.ColumnName, realColType, col.MemberInfo.Name, col.FieldOrPropertyType.FullName);
}
acceptedCols.Add(col);
@@ -130,14 +130,12 @@ private TypedCursorable(IHostEnvironment env, IDataView data, bool ignoreMissing
}
///
- /// Returns whether the column type can be bound to field .
+ /// Returns whether the column type can be bound to field .
/// They must both be vectors or scalars, and the raw data kind should match.
///
- private static bool IsCompatibleType(ColumnType colType, FieldInfo fieldInfo)
+ private static bool IsCompatibleType(ColumnType colType, MemberInfo memberInfo)
{
- bool isVector;
- DataKind kind;
- InternalSchemaDefinition.GetVectorAndKind(fieldInfo, out isVector, out kind);
+ InternalSchemaDefinition.GetVectorAndKind(memberInfo, out bool isVector, out DataKind kind);
if (isVector)
return colType.IsVector && colType.ItemType.RawKind == kind;
else
@@ -269,8 +267,7 @@ public ValueGetter GetIdGetter()
private Action GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek)
{
var colType = input.Schema.GetColumnType(index);
- var fieldInfo = column.FieldInfo;
- var fieldType = fieldInfo.FieldType;
+ var fieldType = column.OutputType;
var genericType = fieldType;
Func> del;
if (fieldType.IsArray)
@@ -431,7 +428,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
else
{
// REVIEW: Is this even possible?
- throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", fieldInfo.FieldType.FullName);
+ throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", column.OutputType.FullName);
}
MethodInfo meth = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(genericType);
return (Action)meth.Invoke(this, new object[] { input, index, poke, peek });
diff --git a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
index ae327a26c2..f9a1b5ef27 100644
--- a/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
+++ b/src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
@@ -15,6 +15,7 @@
+
diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
index 405e207773..64fe3b5b80 100644
--- a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
+++ b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
@@ -34,6 +34,7 @@ public enum VisibilityType
private string _specialPurpose;
private VisibilityType _visibility;
private string _name;
+ private Type _signatureType;
///
/// Allows control of command line parsing.
@@ -139,5 +140,11 @@ public bool IsRequired
{
get { return ArgumentType.Required == (_type & ArgumentType.Required); }
}
+
+ public Type SignatureType
+ {
+ get { return _signatureType; }
+ set { _signatureType = value; }
+ }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index eb85fcce12..bd37a96f7b 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -249,6 +249,18 @@ public enum SettingsFlags
Default = ShortNames | NoSlashes
}
+ ///
+ /// An IComponentFactory that is used in the command line.
+ ///
+ /// This allows components to be created by name, signature type, and a settings string.
+ ///
+ public interface ICommandLineComponentFactory : IComponentFactory
+ {
+ Type SignatureType { get; }
+ string Name { get; }
+ string GetSettingsString();
+ }
+
///
/// Parser for command line arguments.
///
@@ -797,7 +809,8 @@ private bool ParseArgumentList(ArgumentInfo info, string[] strs, object destinat
ModuleCatalog.ComponentInfo component;
if (IsCurlyGroup(value) && value.Length == 2)
arg.Field.SetValue(destination, null);
- else if (_catalog.Value.TryFindComponentCaseInsensitive(arg.Field.FieldType, value, out component))
+ else if (!arg.IsCollection &&
+ _catalog.Value.TryFindComponentCaseInsensitive(arg.Field.FieldType, value, out component))
{
var activator = Activator.CreateInstance(component.ArgumentType);
if (!IsCurlyGroup(value) && i + 1 < strs.Length && IsCurlyGroup(strs[i + 1]))
@@ -810,8 +823,9 @@ private bool ParseArgumentList(ArgumentInfo info, string[] strs, object destinat
}
else
{
- Report("Error: Failed to find component with name '{0}' for option '{1}'", value, arg.LongName);
- hadError |= true;
+ hadError |= !arg.SetValue(this, ref values[arg.Index], value, tag, destination);
+ if (!IsCurlyGroup(value) && i + 1 < strs.Length && IsCurlyGroup(strs[i + 1]))
+ hadError |= !arg.SetValue(this, ref values[arg.Index], strs[++i], "", destination);
}
continue;
}
@@ -1283,6 +1297,41 @@ private static bool IsValidItemType(Type type)
typeBase.IsEnum;
}
+ ///
+ /// Creates an ICommandLineComponentFactory given the factory type, signature type,
+ /// and a command line string.
+ ///
+ public static ICommandLineComponentFactory CreateComponentFactory(
+ Type factoryType,
+ Type signatureType,
+ string settings)
+ {
+ ParseComponentStrings(settings, out string name, out string args);
+
+ string[] argsArray = string.IsNullOrEmpty(args) ? Array.Empty() : new string[] { args };
+
+ return ComponentFactoryFactory.CreateComponentFactory(factoryType, signatureType, name, argsArray);
+ }
+
+ private static void ParseComponentStrings(string str, out string kind, out string args)
+ {
+ kind = args = null;
+ if (string.IsNullOrWhiteSpace(str))
+ return;
+ str = str.Trim();
+ int ich = str.IndexOf('{');
+ if (ich < 0)
+ {
+ kind = str;
+ return;
+ }
+ if (ich == 0 || str[str.Length - 1] != '}')
+ throw Contracts.Except("Invalid Component string: mismatched braces, or empty component name.");
+
+ kind = str.Substring(0, ich);
+ args = CmdLexer.UnquoteValue(str.Substring(ich));
+ }
+
private sealed class ArgValue
{
public readonly string FirstValue;
@@ -1532,6 +1581,8 @@ private sealed class Argument
// Used for help and composing settings strings.
public readonly object DefaultValue;
+ private readonly Type _signatureType;
+
// For custom types.
private readonly ArgumentInfo _infoCustom;
private readonly ConstructorInfo _ctorCustom;
@@ -1559,6 +1610,7 @@ public Argument(int index, string name, string[] nicks, object defaults, Argumen
IsDefault = attr is DefaultArgumentAttribute;
Contracts.Assert(!IsDefault || Utils.Size(ShortNames) == 0);
IsHidden = attr.Hide;
+ _signatureType = attr.SignatureType;
if (field.FieldType.IsArray)
{
@@ -1664,6 +1716,48 @@ public bool Finish(CmdParser owner, ArgValue val, object destination)
Field.SetValue(destination, com);
}
+ else if (IsSingleComponentFactory)
+ {
+ bool haveName = false;
+ string name = null;
+ string[] settings = null;
+ for (int i = 0; i < Utils.Size(values);)
+ {
+ string str = (string)values[i].Value;
+ if (str.StartsWith("{"))
+ {
+ i++;
+ continue;
+ }
+ if (haveName)
+ {
+ owner.Report("Duplicate component kind for argument {0}", LongName);
+ error = true;
+ }
+ name = str;
+ haveName = true;
+ values.RemoveAt(i);
+ }
+
+ if (Utils.Size(values) > 0)
+ settings = values.Select(x => (string)x.Value).ToArray();
+
+ Contracts.Check(_signatureType != null, "ComponentFactory Arguments need a SignatureType set.");
+ if (ComponentFactoryFactory.TryCreateComponentFactory(
+ ItemType,
+ _signatureType,
+ name,
+ settings,
+ out ICommandLineComponentFactory factory))
+ {
+ Field.SetValue(destination, factory);
+ }
+ else
+ {
+ owner.Report("There was an error creating the ComponentFactory. Ensure '{0}' is configured correctly.", LongName);
+ error = true;
+ }
+ }
else if (IsMultiSubComponent)
{
// REVIEW: the kind should not be separated from settings: everything related
@@ -1711,6 +1805,79 @@ public bool Finish(CmdParser owner, ArgValue val, object destination)
Field.SetValue(destination, arr);
}
}
+ else if (IsMultiComponentFactory)
+ {
+ // REVIEW: the kind should not be separated from settings: everything related
+ // to one item should go into one value, not multiple values
+ if (IsTaggedCollection)
+ {
+ // Tagged collection of IComponentFactory
+ var comList = new List>();
+
+ for (int i = 0; i < Utils.Size(values);)
+ {
+ string tag = values[i].Key;
+ string name = (string)values[i++].Value;
+ string[] settings = null;
+ if (i < values.Count && IsCurlyGroup((string)values[i].Value) && string.IsNullOrEmpty(values[i].Key))
+ settings = new string[] { (string)values[i++].Value };
+ if (ComponentFactoryFactory.TryCreateComponentFactory(
+ ItemValueType,
+ _signatureType,
+ name,
+ settings,
+ out ICommandLineComponentFactory factory))
+ {
+ comList.Add(new KeyValuePair(tag, factory));
+ }
+ else
+ {
+ owner.Report("There was an error creating the ComponentFactory. Ensure '{0}' is configured correctly.", LongName);
+ error = true;
+ }
+ }
+
+ var arr = Array.CreateInstance(ItemType, comList.Count);
+ for (int i = 0; i < arr.Length; i++)
+ {
+ var kvp = Activator.CreateInstance(ItemType, comList[i].Key, comList[i].Value);
+ arr.SetValue(kvp, i);
+ }
+
+ Field.SetValue(destination, arr);
+ }
+ else
+ {
+ // Collection of IComponentFactory
+ var comList = new List();
+ for (int i = 0; i < Utils.Size(values);)
+ {
+ string name = (string)values[i++].Value;
+ string[] settings = null;
+ if (i < values.Count && IsCurlyGroup((string)values[i].Value))
+ settings = new string[] { (string)values[i++].Value };
+ if (ComponentFactoryFactory.TryCreateComponentFactory(
+ ItemValueType,
+ _signatureType,
+ name,
+ settings,
+ out ICommandLineComponentFactory factory))
+ {
+ comList.Add(factory);
+ }
+ else
+ {
+ owner.Report("There was an error creating the ComponentFactory. Ensure '{0}' is configured correctly.", LongName);
+ error = true;
+ }
+ }
+
+ var arr = Array.CreateInstance(ItemValueType, comList.Count);
+ for (int i = 0; i < arr.Length; i++)
+ arr.SetValue(comList[i], i);
+ Field.SetValue(destination, arr);
+ }
+ }
else if (IsTaggedCollection)
{
var res = Array.CreateInstance(ItemType, Utils.Size(values));
@@ -1784,7 +1951,7 @@ public bool SetValue(CmdParser owner, ref ArgValue val, string value, string tag
}
val.Values.Add(new KeyValuePair(tag, newValue));
}
- else if (IsSingleSubComponent)
+ else if (IsSingleSubComponent || IsComponentFactory)
{
Contracts.Assert(newValue is string || newValue == null);
Contracts.Assert((string)newValue != "");
@@ -1834,7 +2001,7 @@ private bool ParseValue(CmdParser owner, string data, out object value)
return false;
}
- if (IsSubComponentItemType)
+ if (IsSubComponentItemType || IsComponentFactory)
{
value = data;
return true;
@@ -2186,19 +2353,28 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf
string name;
var catalog = ModuleCatalog.CreateInstance(ectx);
var type = value.GetType();
- bool success = catalog.TryGetComponentShortName(type, out name);
- Contracts.Assert(success);
-
- var settings = GetSettings(ectx, value, Activator.CreateInstance(type));
- buffer.Clear();
- buffer.Append(name);
- if (!string.IsNullOrWhiteSpace(settings))
+ bool isModuleComponent = catalog.TryGetComponentShortName(type, out name);
+ if (isModuleComponent)
{
- StringBuilder sb = new StringBuilder();
- CmdQuoter.QuoteValue(settings, sb, true);
- buffer.Append(sb);
+ var settings = GetSettings(ectx, value, Activator.CreateInstance(type));
+ buffer.Clear();
+ buffer.Append(name);
+ if (!string.IsNullOrWhiteSpace(settings))
+ {
+ StringBuilder sb = new StringBuilder();
+ CmdQuoter.QuoteValue(settings, sb, true);
+ buffer.Append(sb);
+ }
+ return buffer.ToString();
+ }
+ else if (value is ICommandLineComponentFactory)
+ {
+ return value.ToString();
+ }
+ else
+ {
+ throw ectx.Except($"IComponentFactory instances either need to be EntryPointComponents or implement {nameof(ICommandLineComponentFactory)}.");
}
- return buffer.ToString();
}
return value.ToString();
@@ -2344,9 +2520,191 @@ public bool IsMultiSubComponent {
get { return IsSubComponentItemType && Field.FieldType.IsArray; }
}
+ public bool IsSingleComponentFactory
+ {
+ get { return IsComponentFactory && !Field.FieldType.IsArray; }
+ }
+
+ public bool IsMultiComponentFactory
+ {
+ get { return IsComponentFactory && Field.FieldType.IsArray; }
+ }
+
public bool IsCustomItemType {
get { return _infoCustom != null; }
}
}
+
+ ///
+ /// A factory class for creating IComponentFactory instances.
+ ///
+ private static class ComponentFactoryFactory
+ {
+ public static ICommandLineComponentFactory CreateComponentFactory(
+ Type factoryType,
+ Type signatureType,
+ string name,
+ string[] settings)
+ {
+ if (!TryCreateComponentFactory(factoryType, signatureType, name, settings, out ICommandLineComponentFactory factory))
+ {
+ throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create IComponentFactory<> types with 4 or less type args.");
+ }
+
+ return factory;
+ }
+
+ public static bool TryCreateComponentFactory(
+ Type factoryType,
+ Type signatureType,
+ string name,
+ string[] settings,
+ out ICommandLineComponentFactory factory)
+ {
+
+ if (factoryType == null ||
+ !typeof(IComponentFactory).IsAssignableFrom(factoryType) ||
+ !factoryType.IsGenericType)
+ {
+ factory = null;
+ return false;
+ }
+
+ Type componentFactoryType;
+ switch (factoryType.GenericTypeArguments.Length)
+ {
+ case 1: componentFactoryType = typeof(ComponentFactory<>); break;
+ case 2: componentFactoryType = typeof(ComponentFactory<,>); break;
+ case 3: componentFactoryType = typeof(ComponentFactory<,,>); break;
+ case 4: componentFactoryType = typeof(ComponentFactory<,,,>); break;
+ default:
+ factory = null;
+ return false;
+ }
+
+ factory = (ICommandLineComponentFactory)Activator.CreateInstance(
+ componentFactoryType.MakeGenericType(factoryType.GenericTypeArguments),
+ signatureType,
+ name,
+ settings);
+ return true;
+ }
+
+ private abstract class ComponentFactory : ICommandLineComponentFactory
+ {
+ public Type SignatureType { get; }
+ public string Name { get; }
+ private string[] Settings { get; }
+
+ protected ComponentFactory(Type signatureType, string name, string[] settings)
+ {
+ SignatureType = signatureType;
+ Name = name;
+
+ if (settings == null || (settings.Length == 1 && string.IsNullOrEmpty(settings[0])))
+ {
+ settings = Array.Empty();
+ }
+ Settings = settings;
+ }
+
+ public string GetSettingsString()
+ {
+ return CombineSettings(Settings);
+ }
+
+ public override string ToString()
+ {
+ if (string.IsNullOrEmpty(Name) && Settings.Length == 0)
+ return "{}";
+
+ if (Settings.Length == 0)
+ return Name;
+
+ string str = CombineSettings(Settings);
+ StringBuilder sb = new StringBuilder();
+ CmdQuoter.QuoteValue(str, sb, true);
+ return Name + sb.ToString();
+ }
+ }
+
+ private class ComponentFactory : ComponentFactory, IComponentFactory
+ where TComponent : class
+ {
+ public ComponentFactory(Type signatureType, string name, string[] settings)
+ : base(signatureType, name, settings)
+ {
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env)
+ {
+ return ComponentCatalog.CreateInstance(
+ env,
+ SignatureType,
+ Name,
+ GetSettingsString());
+ }
+ }
+
+ private class ComponentFactory : ComponentFactory, IComponentFactory
+ where TComponent : class
+ {
+ public ComponentFactory(Type signatureType, string name, string[] settings)
+ : base(signatureType, name, settings)
+ {
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
+ {
+ return ComponentCatalog.CreateInstance(
+ env,
+ SignatureType,
+ Name,
+ GetSettingsString(),
+ argument1);
+ }
+ }
+
+ private class ComponentFactory : ComponentFactory, IComponentFactory
+ where TComponent : class
+ {
+ public ComponentFactory(Type signatureType, string name, string[] settings)
+ : base(signatureType, name, settings)
+ {
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2)
+ {
+ return ComponentCatalog.CreateInstance(
+ env,
+ SignatureType,
+ Name,
+ GetSettingsString(),
+ argument1,
+ argument2);
+ }
+ }
+
+ private class ComponentFactory : ComponentFactory, IComponentFactory
+ where TComponent : class
+ {
+ public ComponentFactory(Type signatureType, string name, string[] settings)
+ : base(signatureType, name, settings)
+ {
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3)
+ {
+ return ComponentCatalog.CreateInstance(
+ env,
+ SignatureType,
+ Name,
+ GetSettingsString(),
+ argument1,
+ argument2,
+ argument3);
+ }
+ }
+ }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index 3b56e8bb36..ddbbd2a500 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -343,7 +343,7 @@ private static bool ShouldSkipPath(string path)
case "libvw.dll":
case "matrixinterf.dll":
case "Microsoft.ML.neuralnetworks.gpucuda.dll":
- case "Microsoft.ML.mklimports.dll":
+ case "MklImports.dll":
case "microsoft.research.controls.decisiontrees.dll":
case "Microsoft.ML.neuralnetworks.sse.dll":
case "neuraltreeevaluator.dll":
@@ -832,10 +832,15 @@ public static LoadableClassInfo[] FindLoadableClasses()
public static LoadableClassInfo GetLoadableClassInfo(string loadName)
{
- Contracts.CheckParam(typeof(TSig).BaseType == typeof(MulticastDelegate), nameof(TSig), "TSig must be a delegate type");
+ return GetLoadableClassInfo(loadName, typeof(TSig));
+ }
+
+ public static LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
+ {
+ Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type");
Contracts.CheckValueOrNull(loadName);
loadName = (loadName ?? "").ToLowerInvariant().Trim();
- return FindClassCore(new LoadableClassInfo.Key(loadName, typeof(TSig)));
+ return FindClassCore(new LoadableClassInfo.Key(loadName, signatureType));
}
public static LoadableClassInfo GetLoadableClassInfo(SubComponent sub)
@@ -886,6 +891,18 @@ public static TRes CreateInstance(this SubComponent comp
throw Contracts.Except("Unknown loadable class: {0}", comp.Kind).MarkSensitive(MessageSensitivity.None);
}
+ ///
+ /// Create an instance of the indicated component with the given extra parameters.
+ ///
+ public static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra)
+ where TRes : class
+ {
+ TRes result;
+ if (TryCreateInstance(env, signatureType, out result, name, options, extra))
+ return result;
+ throw Contracts.Except("Unknown loadable class: {0}", name).MarkSensitive(MessageSensitivity.None);
+ }
+
///
/// Try to create an instance of the indicated component with the given extra parameters. If there is no
/// such component in the catalog, returns false. Any other error results in an exception.
@@ -913,13 +930,19 @@ public static bool TryCreateInstance(IHostEnvironment env, out TRes
///
public static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra)
where TRes : class
+ {
+ return TryCreateInstance(env, typeof(TSig), out result, name, options, extra);
+ }
+
+ private static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
+ where TRes : class
{
Contracts.CheckValue(env, nameof(env));
- env.Check(typeof(TSig).BaseType == typeof(MulticastDelegate));
+ env.Check(signatureType.BaseType == typeof(MulticastDelegate));
env.CheckValueOrNull(name);
string nameLower = (name ?? "").ToLowerInvariant().Trim();
- LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, typeof(TSig)));
+ LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));
if (info == null)
{
result = null;
diff --git a/src/Microsoft.ML.Core/Data/ColumnType.cs b/src/Microsoft.ML.Core/Data/ColumnType.cs
index 0cff911e77..96764d68f1 100644
--- a/src/Microsoft.ML.Core/Data/ColumnType.cs
+++ b/src/Microsoft.ML.Core/Data/ColumnType.cs
@@ -242,6 +242,10 @@ public bool IsStandardScalar
///
internal virtual int ValueCountCore { get { return 1; } }
+ // IEquatable interface recommends also to override base class implementations of
+ // Object.Equals(Object) and GetHashCode. In classes below where Equals(ColumnType other)
+ // is effectively a referencial comparison, there is no need to override base class implementations
+ // of Object.Equals(Object) (and GetHashCode) since its also a referencial comparison.
public abstract bool Equals(ColumnType other);
///
@@ -789,6 +793,16 @@ public override bool Equals(ColumnType other)
return true;
}
+ public override bool Equals(object other)
+ {
+ return other is ColumnType tmp && Equals(tmp);
+ }
+
+ public override int GetHashCode()
+ {
+ return Hashing.CombinedHash(RawKind.GetHashCode(), _contiguous, _min, _count);
+ }
+
public override string ToString()
{
if (_count > 0)
@@ -940,6 +954,21 @@ public override bool Equals(ColumnType other)
return true;
}
+ public override bool Equals(object other)
+ {
+ return other is ColumnType tmp && Equals(tmp);
+ }
+
+ public override int GetHashCode()
+ {
+ int hash = Hashing.CombinedHash(_itemType.GetHashCode(), _size);
+ int count = Utils.Size(_dims);
+ hash = Hashing.CombineHash(hash, count.GetHashCode());
+ for (int i = 0; i < count; i++)
+ hash = Hashing.CombineHash(hash, _dims[i].GetHashCode());
+ return hash;
+ }
+
///
/// Returns true if current has the same item type of other, and the size
/// of other is unknown or the current size is equal to the size of other.
diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
index 9334f0f225..1f8d5b0e8f 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
@@ -24,7 +24,7 @@ public interface IArgsComponent : IComponentFactory
///
/// An interface for creating a component with no extra parameters (other than an ).
///
- public interface IComponentFactory: IComponentFactory
+ public interface IComponentFactory : IComponentFactory
{
TComponent CreateComponent(IHostEnvironment env);
}
@@ -44,4 +44,104 @@ public interface IComponentFactory : ICompon
{
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2);
}
+
+ ///
+ /// An interface for creating a component when we take three extra parameters (and an ).
+ ///
+ public interface IComponentFactory : IComponentFactory
+ {
+ TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3);
+ }
+
+ ///
+ /// A utility class for creating instances.
+ ///
+ public static class ComponentFactoryUtils
+ {
+ ///
+ /// Creates a component factory with no extra parameters (other than an )
+ /// that simply wraps a delegate which creates the component.
+ ///
+ public static IComponentFactory CreateFromFunction(Func factory)
+ {
+ return new SimpleComponentFactory(factory);
+ }
+
+ ///
+ /// Creates a component factory when we take one extra parameter (and an
+ /// ) that simply wraps a delegate which creates the component.
+ ///
+ public static IComponentFactory CreateFromFunction(Func factory)
+ {
+ return new SimpleComponentFactory(factory);
+ }
+
+ ///
+ /// Creates a component factory when we take three extra parameters (and an
+ /// ) that simply wraps a delegate which creates the component.
+ ///
+ public static IComponentFactory CreateFromFunction(Func factory)
+ {
+ return new SimpleComponentFactory(factory);
+ }
+
+ ///
+ /// A class for creating a component with no extra parameters (other than an )
+ /// that simply wraps a delegate which creates the component.
+ ///
+ private sealed class SimpleComponentFactory : IComponentFactory
+ {
+ private readonly Func _factory;
+
+ public SimpleComponentFactory(Func factory)
+ {
+ _factory = factory;
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env)
+ {
+ return _factory(env);
+ }
+ }
+
+ ///
+ /// A class for creating a component when we take one extra parameter
+ /// (and an ) that simply wraps a delegate which
+ /// creates the component.
+ ///
+ private sealed class SimpleComponentFactory : IComponentFactory
+ {
+ private readonly Func _factory;
+
+ public SimpleComponentFactory(Func factory)
+ {
+ _factory = factory;
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
+ {
+ return _factory(env, argument1);
+ }
+ }
+
+ ///
+ /// A class for creating a component when we take three extra parameters
+ /// (and an ) that simply wraps a delegate which
+ /// creates the component.
+ ///
+ private sealed class SimpleComponentFactory : IComponentFactory
+ {
+ private readonly Func _factory;
+
+ public SimpleComponentFactory(Func factory)
+ {
+ _factory = factory;
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3)
+ {
+ return _factory(env, argument1, argument2, argument3);
+ }
+ }
+ }
}
diff --git a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
index 67f05ee7cf..f76ec7815d 100644
--- a/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
+++ b/src/Microsoft.ML.CpuMath/AlignedMatrix.cs
@@ -550,7 +550,7 @@ public void ZeroItems(int[] indices)
// REVIEW: Ideally, we'd adjust the indices once so we wouldn't need to
// repeatedly deal with padding adjustments.
- SseUtils.ZeroMatrixItems(Items, ColCount, ColCountPhy, indices);
+ CpuMathUtils.ZeroMatrixItems(Items, ColCount, ColCountPhy, indices);
}
}
diff --git a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs
index ad53810ff3..9c7fa5ae1f 100644
--- a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs
+++ b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs
@@ -16,7 +16,7 @@ public static void AssertCompatible(ICpuFullMatrix values)
#if DEBUG
var mat = values as TMatrix;
Contracts.AssertValue(mat);
- Contracts.Assert(mat.Items.CbAlign == SseUtils.CbAlign);
+ Contracts.Assert(mat.Items.CbAlign == CpuMathUtils.Vector128Alignment);
#endif
}
@@ -29,7 +29,7 @@ public static void AssertCompatible(ICpuVector values)
#if DEBUG
CpuAlignedVector vec = values as CpuAlignedVector;
Contracts.AssertValue(vec);
- Contracts.Assert(vec.Items.CbAlign == SseUtils.CbAlign);
+ Contracts.Assert(vec.Items.CbAlign == CpuMathUtils.Vector128Alignment);
#endif
}
@@ -89,7 +89,7 @@ public static void MatTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICp
bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol);
AssertCompatible(mat, src, dst);
var m = A(mat);
- SseUtils.MatTimesSrc(colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
+ CpuMathUtils.MatTimesSrc(colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
}
///
@@ -108,7 +108,7 @@ public static void MatTranTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src,
bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol);
AssertCompatible(mat, dst, src);
var m = A(mat);
- SseUtils.MatTimesSrc(!colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
+ CpuMathUtils.MatTimesSrc(!colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
}
}
diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
new file mode 100644
index 0000000000..81d7acf25a
--- /dev/null
+++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs
@@ -0,0 +1,962 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.Intrinsics.X86;
+using System;
+
+namespace Microsoft.ML.Runtime.Internal.CpuMath
+{
+ public static partial class CpuMathUtils
+ {
+ // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray
+ public const int Vector128Alignment = 16;
+
+ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
+ {
+ Contracts.Assert(mat.Size == dst.Size * src.Size);
+ Contracts.Assert(crun >= 0);
+
+ if (Sse.IsSupported)
+ {
+ if (!tran)
+ {
+ Contracts.Assert(crun <= dst.Size);
+ SseIntrinsics.MatMulA(add, mat, src, dst, crun, src.Size);
+ }
+ else
+ {
+ Contracts.Assert(crun <= src.Size);
+ SseIntrinsics.MatMulTranA(add, mat, src, dst, dst.Size, crun);
+ }
+ }
+ else
+ {
+ if (!tran)
+ {
+ Contracts.Assert(crun <= dst.Size);
+ for (int i = 0; i < crun; i++)
+ {
+ float dotProduct = 0;
+ for (int j = 0; j < src.Size; j++)
+ {
+ dotProduct += mat[i * src.Size + j] * src[j];
+ }
+
+ if (add)
+ {
+ dst[i] += dotProduct;
+ }
+ else
+ {
+ dst[i] = dotProduct;
+ }
+ }
+ }
+ else
+ {
+ Contracts.Assert(crun <= src.Size);
+ for (int i = 0; i < dst.Size; i++)
+ {
+ float dotProduct = 0;
+ for (int j = 0; j < crun; j++)
+ {
+ dotProduct += mat[j * src.Size + i] * src[j];
+ }
+
+ if (add)
+ {
+ dst[i] += dotProduct;
+ }
+ else
+ {
+ dst[i] = dotProduct;
+ }
+ }
+ }
+ }
+ }
+
+ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
+ int posMin, int iposMin, int iposLim, AlignedArray dst, int crun)
+ {
+ Contracts.AssertValue(rgposSrc);
+ Contracts.Assert(iposMin >= 0);
+ Contracts.Assert(iposMin <= iposLim);
+ Contracts.Assert(iposLim <= rgposSrc.Length);
+ Contracts.Assert(mat.Size == dst.Size * srcValues.Size);
+
+ if (iposMin >= iposLim)
+ {
+ if (!add)
+ dst.ZeroItems();
+ return;
+ }
+
+ Contracts.AssertNonEmpty(rgposSrc);
+ Contracts.Assert(crun >= 0);
+
+ if (Sse.IsSupported)
+ {
+ if (!tran)
+ {
+ Contracts.Assert(crun <= dst.Size);
+ SseIntrinsics.MatMulPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
+ }
+ else
+ {
+ Contracts.Assert(crun <= srcValues.Size);
+ SseIntrinsics.MatMulTranPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size);
+ }
+ }
+ else
+ {
+ if (!tran)
+ {
+ Contracts.Assert(crun <= dst.Size);
+ for (int i = 0; i < crun; i++)
+ {
+ float dotProduct = 0;
+ for (int j = iposMin; j < iposLim; j++)
+ {
+ int col = rgposSrc[j] - posMin;
+ dotProduct += mat[i * srcValues.Size + col] * srcValues[col];
+ }
+
+ if (add)
+ {
+ dst[i] += dotProduct;
+ }
+ else
+ {
+ dst[i] = dotProduct;
+ }
+ }
+ }
+ else
+ {
+ Contracts.Assert(crun <= srcValues.Size);
+ for (int i = 0; i < dst.Size; i++)
+ {
+ float dotProduct = 0;
+ for (int j = iposMin; j < iposLim; j++)
+ {
+ int col = rgposSrc[j] - posMin;
+ dotProduct += mat[col * dst.Size + i] * srcValues[col];
+ }
+
+ if (add)
+ {
+ dst[i] += dotProduct;
+ }
+ else
+ {
+ dst[i] = dotProduct;
+ }
+ }
+
+ }
+ }
+ }
+
+ public static void Add(float a, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= dst.Length);
+
+ Add(a, new Span(dst, 0, count));
+ }
+
+ private static void Add(float a, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.AddScalarU(a, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] += a;
+ }
+ }
+ }
+
+ public static void Scale(float a, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= dst.Length);
+
+ Scale(a, new Span(dst, 0, count));
+ }
+
+ public static void Scale(float a, float[] dst, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset < (dst.Length - count));
+
+ Scale(a, new Span(dst, offset, count));
+ }
+
+ private static void Scale(float a, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.ScaleU(a, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] *= a;
+ }
+ }
+ }
+
+ // dst = a * src
+ public static void Scale(float a, float[] src, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= dst.Length);
+
+ Scale(a, new Span(src, 0, count), new Span(dst, 0, count));
+ }
+
+ private static void Scale(float a, Span src, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.ScaleSrcU(a, src, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] = a * src[i];
+ }
+ }
+ }
+
+ // dst[i] = a * (dst[i] + b)
+ public static void ScaleAdd(float a, float b, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= dst.Length);
+
+ ScaleAdd(a, b, new Span(dst, 0, count));
+ }
+
+ private static void ScaleAdd(float a, float b, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.ScaleAddU(a, b, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] = a * (dst[i] + b);
+ }
+ }
+ }
+
+ public static void AddScale(float a, float[] src, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= dst.Length);
+
+ AddScale(a, new Span(src, 0, count), new Span(dst, 0, count));
+ }
+
+ public static void AddScale(float a, float[] src, float[] dst, int dstOffset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(dstOffset >= 0);
+ Contracts.Assert(dstOffset < dst.Length);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= (dst.Length - dstOffset));
+
+ AddScale(a, new Span(src, 0, count), new Span(dst, dstOffset, count));
+ }
+
+ private static void AddScale(float a, Span src, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.AddScaleU(a, src, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] += a * src[i];
+ }
+ }
+ }
+
+ public static void AddScale(float a, float[] src, int[] indices, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= indices.Length);
+ Contracts.Assert(count < dst.Length);
+
+ AddScale(a, new Span(src), new Span(indices, 0, count), new Span(dst));
+ }
+
+ public static void AddScale(float a, float[] src, int[] indices, float[] dst, int dstOffset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(dstOffset >= 0);
+ Contracts.Assert(dstOffset < dst.Length);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= indices.Length);
+ Contracts.Assert(count < (dst.Length - dstOffset));
+
+ AddScale(a, new Span(src), new Span(indices, 0, count),
+ new Span(dst, dstOffset, dst.Length - dstOffset));
+ }
+
+ private static void AddScale(float a, Span src, Span indices, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.AddScaleSU(a, src, indices, dst);
+ }
+ else
+ {
+ for (int i = 0; i < indices.Length; i++)
+ {
+ int index = indices[i];
+ dst[index] += a * src[i];
+ }
+ }
+ }
+
+ public static void AddScaleCopy(float a, float[] src, float[] dst, float[] res, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.AssertNonEmpty(res);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= dst.Length);
+ Contracts.Assert(count <= res.Length);
+
+ AddScaleCopy(a, new Span(src, 0, count), new Span(dst, 0, count), new Span(res, 0, count));
+ }
+
+ private static void AddScaleCopy(float a, Span src, Span dst, Span res)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.AddScaleCopyU(a, src, dst, res);
+ }
+ else
+ {
+ for (int i = 0; i < res.Length; i++)
+ {
+ res[i] = a * src[i] + dst[i];
+ }
+ }
+ }
+
+ public static void Add(float[] src, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= dst.Length);
+
+ Add(new Span(src, 0, count), new Span(dst, 0, count));
+ }
+
+ private static void Add(Span src, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.AddU(src, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] += src[i];
+ }
+ }
+ }
+
+ public static void Add(float[] src, int[] indices, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= indices.Length);
+ Contracts.Assert(count < dst.Length);
+
+ Add(new Span(src), new Span(indices, 0, count), new Span(dst));
+ }
+
+ public static void Add(float[] src, int[] indices, float[] dst, int dstOffset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(dstOffset >= 0);
+ Contracts.Assert(dstOffset < dst.Length);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= indices.Length);
+ Contracts.Assert(count <= (dst.Length - dstOffset));
+
+ Add(new Span(src), new Span(indices, 0, count),
+ new Span(dst, dstOffset, dst.Length - dstOffset));
+ }
+
+ private static void Add(Span src, Span indices, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.AddSU(src, indices, dst);
+ }
+ else
+ {
+ for (int i = 0; i < indices.Length; i++)
+ {
+ int index = indices[i];
+ dst[index] += src[i];
+ }
+ }
+ }
+
+ public static void MulElementWise(float[] src1, float[] src2, float[] dst, int count)
+ {
+ Contracts.AssertNonEmpty(src1);
+ Contracts.AssertNonEmpty(src2);
+ Contracts.AssertNonEmpty(dst);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src1.Length);
+ Contracts.Assert(count <= src2.Length);
+
+ MulElementWise(new Span(src1, 0, count), new Span(src2, 0, count),
+ new Span(dst, 0, count));
+ }
+
+ private static void MulElementWise(Span src1, Span src2, Span dst)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.MulElementWiseU(src1, src2, dst);
+ }
+ else
+ {
+ for (int i = 0; i < dst.Length; i++)
+ {
+ dst[i] = src1[i] * src2[i];
+ }
+ }
+ }
+
+ public static float Sum(float[] src, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+
+ return Sum(new Span(src, 0, count));
+ }
+
+ public static float Sum(float[] src, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (src.Length - count));
+
+ return Sum(new Span(src, offset, count));
+ }
+
+ private static float Sum(Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.SumU(src);
+ }
+ else
+ {
+ float sum = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ sum += src[i];
+ }
+ return sum;
+ }
+ }
+
+ public static float SumSq(float[] src, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+
+ return SumSq(new Span(src, 0, count));
+ }
+
+ public static float SumSq(float[] src, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (src.Length - count));
+
+ return SumSq(new Span(src, offset, count));
+ }
+
+ private static float SumSq(Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.SumSqU(src);
+ }
+ else
+ {
+ float result = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ result += src[i] * src[i];
+ }
+ return result;
+ }
+ }
+
+ public static float SumSq(float mean, float[] src, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (src.Length - count));
+
+ return SumSq(mean, new Span(src, offset, count));
+ }
+
+ private static float SumSq(float mean, Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return (mean == 0) ? SseIntrinsics.SumSqU(src) : SseIntrinsics.SumSqDiffU(mean, src);
+ }
+ else
+ {
+ float result = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ result += (src[i] - mean) * (src[i] - mean);
+ }
+ return result;
+ }
+ }
+
+ public static float SumAbs(float[] src, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+
+ return SumAbs(new Span(src, 0, count));
+ }
+
+ public static float SumAbs(float[] src, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (src.Length - count));
+
+ return SumAbs(new Span(src, offset, count));
+ }
+
+ private static float SumAbs(Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.SumAbsU(src);
+ }
+ else
+ {
+ float sum = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ sum += Math.Abs(src[i]);
+ }
+ return sum;
+ }
+ }
+
+ public static float SumAbs(float mean, float[] src, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (src.Length - count));
+
+ return SumAbs(mean, new Span(src, offset, count));
+ }
+
+ private static float SumAbs(float mean, Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return (mean == 0) ? SseIntrinsics.SumAbsU(src) : SseIntrinsics.SumAbsDiffU(mean, src);
+ }
+ else
+ {
+ float sum = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ sum += Math.Abs(src[i] - mean);
+ }
+ return sum;
+ }
+ }
+
+ public static float MaxAbs(float[] src, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+
+ return MaxAbs(new Span(src, 0, count));
+ }
+
+ public static float MaxAbs(float[] src, int offset, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (src.Length - count));
+
+ return MaxAbs(new Span(src, offset, count));
+ }
+
+ private static float MaxAbs(Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.MaxAbsU(src);
+ }
+ else
+ {
+ float max = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ float abs = Math.Abs(src[i]);
+ if (abs > max)
+ {
+ max = abs;
+ }
+ }
+ return max;
+ }
+ }
+
+ public static float MaxAbsDiff(float mean, float[] src, int count)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+
+ return MaxAbsDiff(mean, new Span(src, 0, count));
+ }
+
+ private static float MaxAbsDiff(float mean, Span src)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.MaxAbsDiffU(mean, src);
+ }
+ else
+ {
+ float max = 0;
+ for (int i = 0; i < src.Length; i++)
+ {
+ float abs = Math.Abs(src[i] - mean);
+ if (abs > max)
+ {
+ max = abs;
+ }
+ }
+ return max;
+ }
+ }
+
+ public static float DotProductDense(float[] a, float[] b, int count)
+ {
+ Contracts.AssertNonEmpty(a);
+ Contracts.AssertNonEmpty(b);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(a.Length >= count);
+ Contracts.Assert(b.Length >= count);
+
+ return DotProductDense(new Span(a, 0, count), new Span(b, 0, count));
+ }
+
+ public static float DotProductDense(float[] a, int offset, float[] b, int count)
+ {
+ Contracts.AssertNonEmpty(a);
+ Contracts.AssertNonEmpty(b);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= b.Length);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset <= (a.Length - count));
+
+ return DotProductDense(new Span(a, offset, count), new Span(b, 0, count));
+ }
+
+ private static float DotProductDense(Span a, Span b)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.DotU(a, b);
+ }
+ else
+ {
+ float result = 0;
+ for (int i = 0; i < b.Length; i++)
+ {
+ result += a[i] * b[i];
+ }
+ return result;
+ }
+ }
+
+ public static float DotProductSparse(float[] a, float[] b, int[] indices, int count)
+ {
+ Contracts.AssertNonEmpty(a);
+ Contracts.AssertNonEmpty(b);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count < a.Length);
+ Contracts.Assert(count <= b.Length);
+ Contracts.Assert(count <= indices.Length);
+
+ return DotProductSparse(new Span(a), new Span(b),
+ new Span(indices, 0, count));
+ }
+
+ public static float DotProductSparse(float[] a, int offset, float[] b, int[] indices, int count)
+ {
+ Contracts.AssertNonEmpty(a);
+ Contracts.AssertNonEmpty(b);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count < (a.Length - offset));
+ Contracts.Assert(count <= b.Length);
+ Contracts.Assert(count <= indices.Length);
+ Contracts.Assert(offset >= 0);
+ Contracts.Assert(offset < a.Length);
+
+ return DotProductSparse(new Span(a, offset, a.Length - offset),
+ new Span(b), new Span(indices, 0, count));
+ }
+
+ private static float DotProductSparse(Span a, Span b, Span indices)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.DotSU(a, b, indices);
+ }
+ else
+ {
+ float result = 0;
+ for (int i = 0; i < indices.Length; i++)
+ {
+ int index = indices[i];
+ result += a[index] * b[i];
+ }
+ return result;
+ }
+ }
+
+ public static float L2DistSquared(float[] a, float[] b, int count)
+ {
+ Contracts.AssertNonEmpty(a);
+ Contracts.AssertNonEmpty(b);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= a.Length);
+ Contracts.Assert(count <= b.Length);
+
+ return L2DistSquared(new Span(a, 0, count), new Span(b, 0, count));
+ }
+
+ private static float L2DistSquared(Span a, Span b)
+ {
+ if (Sse.IsSupported)
+ {
+ return SseIntrinsics.Dist2(a, b);
+ }
+ else
+ {
+ float norm = 0;
+ for (int i = 0; i < b.Length; i++)
+ {
+ float distance = a[i] - b[i];
+ norm += distance * distance;
+ }
+ return norm;
+ }
+ }
+
+ public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices)
+ {
+ Contracts.Assert(ccol > 0);
+ Contracts.Assert(ccol <= cfltRow);
+
+ if (ccol == cfltRow)
+ {
+ ZeroItemsU(dst, dst.Size, indices, indices.Length);
+ }
+ else
+ {
+ ZeroMatrixItemsCore(dst, dst.Size, ccol, cfltRow, indices, indices.Length);
+ }
+ }
+
+ private static unsafe void ZeroItemsU(AlignedArray dst, int c, int[] indices, int cindices)
+ {
+ fixed (float* pdst = &dst.Items[0])
+ fixed (int* pidx = &indices[0])
+ {
+ for (int i = 0; i < cindices; ++i)
+ {
+ int index = pidx[i];
+ Contracts.Assert(index >= 0);
+ Contracts.Assert(index < c);
+ pdst[index] = 0;
+ }
+ }
+ }
+
+ private static unsafe void ZeroMatrixItemsCore(AlignedArray dst, int c, int ccol, int cfltRow, int[] indices, int cindices)
+ {
+ fixed (float* pdst = &dst.Items[0])
+ fixed (int* pidx = &indices[0])
+ {
+ int ivLogMin = 0;
+ int ivLogLim = ccol;
+ int ivPhyMin = 0;
+
+ for (int i = 0; i < cindices; ++i)
+ {
+ int index = pidx[i];
+ Contracts.Assert(index >= 0);
+ Contracts.Assert(index < c);
+
+ int col = index - ivLogMin;
+ if ((uint)col >= (uint)ccol)
+ {
+ Contracts.Assert(ivLogMin > index || index >= ivLogLim);
+
+ int row = index / ccol;
+ ivLogMin = row * ccol;
+ ivLogLim = ivLogMin + ccol;
+ ivPhyMin = row * cfltRow;
+
+ Contracts.Assert(index >= ivLogMin);
+ Contracts.Assert(index < ivLogLim);
+ col = index - ivLogMin;
+ }
+
+ pdst[ivPhyMin + col] = 0;
+ }
+ }
+ }
+
+ public static void SdcaL1UpdateDense(float primalUpdate, int length, float[] src, float threshold, float[] v, float[] w)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(v);
+ Contracts.AssertNonEmpty(w);
+ Contracts.Assert(length > 0);
+ Contracts.Assert(length <= src.Length);
+ Contracts.Assert(length <= v.Length);
+ Contracts.Assert(length <= w.Length);
+
+ SdcaL1UpdateDense(primalUpdate, new Span(src, 0, length), threshold, new Span(v, 0, length), new Span(w, 0, length));
+ }
+
+ private static void SdcaL1UpdateDense(float primalUpdate, Span src, float threshold, Span v, Span w)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.SdcaL1UpdateU(primalUpdate, src, threshold, v, w);
+ }
+ else
+ {
+ for (int i = 0; i < src.Length; i++)
+ {
+ v[i] += src[i] * primalUpdate;
+ float value = v[i];
+ w[i] = Math.Abs(value) > threshold ? (value > 0 ? value - threshold : value + threshold) : 0;
+ }
+ }
+ }
+
+ // REVIEW NEEDED: The second argument "length" is unused even in the existing code.
+ public static void SdcaL1UpdateSparse(float primalUpdate, int length, float[] src, int[] indices, int count, float threshold, float[] v, float[] w)
+ {
+ Contracts.AssertNonEmpty(src);
+ Contracts.AssertNonEmpty(indices);
+ Contracts.AssertNonEmpty(v);
+ Contracts.AssertNonEmpty(w);
+ Contracts.Assert(count > 0);
+ Contracts.Assert(count <= src.Length);
+ Contracts.Assert(count <= indices.Length);
+ Contracts.Assert(count < length);
+ Contracts.Assert(length <= v.Length);
+ Contracts.Assert(length <= w.Length);
+
+ SdcaL1UpdateSparse(primalUpdate, new Span(src, 0, count), new Span(indices, 0, count), threshold, new Span(v), new Span(w));
+ }
+
+ private static void SdcaL1UpdateSparse(float primalUpdate, Span src, Span indices, float threshold, Span v, Span w)
+ {
+ if (Sse.IsSupported)
+ {
+ SseIntrinsics.SdcaL1UpdateSU(primalUpdate, src, indices, threshold, v, w);
+ }
+ else
+ {
+ for (int i = 0; i < indices.Length; i++)
+ {
+ int index = indices[i];
+ v[index] += src[i] * primalUpdate;
+ float value = v[index];
+ w[index] = Math.Abs(value) > threshold ? (value > 0 ? value - threshold : value + threshold) : 0;
+ }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs
new file mode 100644
index 0000000000..6f480b0f25
--- /dev/null
+++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs
@@ -0,0 +1,85 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Runtime.Internal.CpuMath
+{
+ public static partial class CpuMathUtils
+ {
+ // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray
+ public const int Vector128Alignment = 16;
+
+ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, src, dst, crun);
+
+ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
+ int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun);
+
+ public static void Add(float a, float[] dst, int count) => SseUtils.Add(a, dst, count);
+
+ public static void Scale(float a, float[] dst, int count) => SseUtils.Scale(a, dst, count);
+
+ public static void Scale(float a, float[] dst, int offset, int count) => SseUtils.Scale(a, dst, offset, count);
+
+ public static void Scale(float a, float[] src, float[] dst, int count) => SseUtils.Scale(a, src, dst, count);
+
+ public static void ScaleAdd(float a, float b, float[] dst, int count) => SseUtils.ScaleAdd(a, b, dst, count);
+
+ public static void AddScale(float a, float[] src, float[] dst, int count) => SseUtils.AddScale(a, src, dst, count);
+
+ public static void AddScale(float a, float[] src, float[] dst, int dstOffset, int count) => SseUtils.AddScale(a, src, dst, dstOffset, count);
+
+ public static void AddScale(float a, float[] src, int[] indices, float[] dst, int count) => SseUtils.AddScale(a, src, indices, dst, count);
+
+ public static void AddScale(float a, float[] src, int[] indices, float[] dst, int dstOffset, int count) => SseUtils.AddScale(a, src, indices, dst, dstOffset, count);
+
+ public static void AddScaleCopy(float a, float[] src, float[] dst, float[] res, int count) => SseUtils.AddScaleCopy(a, src, dst, res, count);
+
+ public static void Add(float[] src, float[] dst, int count) => SseUtils.Add(src, dst, count);
+
+ public static void Add(float[] src, int[] indices, float[] dst, int count) => SseUtils.Add(src, indices, dst, count);
+
+ public static void Add(float[] src, int[] indices, float[] dst, int dstOffset, int count) => SseUtils.Add(src, indices, dst, dstOffset, count);
+
+ public static void MulElementWise(float[] src1, float[] src2, float[] dst, int count) => SseUtils.MulElementWise(src1, src2, dst, count);
+
+ public static float Sum(float[] src, int count) => SseUtils.Sum(src, count);
+
+ public static float Sum(float[] src, int offset, int count) => SseUtils.Sum(src, offset, count);
+
+ public static float SumSq(float[] src, int count) => SseUtils.SumSq(src, count);
+
+ public static float SumSq(float[] src, int offset, int count) => SseUtils.SumSq(src, offset, count);
+
+ public static float SumSq(float mean, float[] src, int offset, int count) => SseUtils.SumSq(mean, src, offset, count);
+
+ public static float SumAbs(float[] src, int count) => SseUtils.SumAbs(src, count);
+
+ public static float SumAbs(float[] src, int offset, int count) => SseUtils.SumAbs(src, offset, count);
+
+ public static float SumAbs(float mean, float[] src, int offset, int count) => SseUtils.SumAbs(mean, src, offset, count);
+
+ public static float MaxAbs(float[] src, int count) => SseUtils.MaxAbs(src, count);
+
+ public static float MaxAbs(float[] src, int offset, int count) => SseUtils.MaxAbs(src, offset, count);
+
+ public static float MaxAbsDiff(float mean, float[] src, int count) => SseUtils.MaxAbsDiff(mean, src, count);
+
+ public static float DotProductDense(float[] a, float[] b, int count) => SseUtils.DotProductDense(a, b, count);
+
+ public static float DotProductDense(float[] a, int offset, float[] b, int count) => SseUtils.DotProductDense(a, offset, b, count);
+
+ public static float DotProductSparse(float[] a, float[] b, int[] indices, int count) => SseUtils.DotProductSparse(a, b, indices, count);
+
+ public static float DotProductSparse(float[] a, int offset, float[] b, int[] indices, int count) => SseUtils.DotProductSparse(a, offset, b, indices, count);
+
+ public static float L2DistSquared(float[] a, float[] b, int count) => SseUtils.L2DistSquared(a, b, count);
+
+ public static void ZeroMatrixItems(AlignedArray dst, int ccol, int cfltRow, int[] indices) => SseUtils.ZeroMatrixItems(dst, ccol, cfltRow, indices);
+
+ public static void SdcaL1UpdateDense(float primalUpdate, int length, float[] src, float threshold, float[] v, float[] w)
+ => SseUtils.SdcaL1UpdateDense(primalUpdate, length, src, threshold, v, w);
+
+ public static void SdcaL1UpdateSparse(float primalUpdate, int length, float[] src, int[] indices, int count, float threshold, float[] v, float[] w)
+ => SseUtils.SdcaL1UpdateSparse(primalUpdate, length, src, indices, count, threshold, v, w);
+ }
+}
diff --git a/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj b/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj
index bde7ae89f5..b6c95b93f4 100644
--- a/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj
+++ b/src/Microsoft.ML.CpuMath/Microsoft.ML.CpuMath.csproj
@@ -2,21 +2,29 @@
Debug;Release;Debug-Intrinsics;Release-Intrinsics
- $(Configuration.EndsWith('-Intrinsics'))
-
netstandard2.0
netstandard2.0;netcoreapp3.0
Microsoft.ML.CpuMath
true
$(DefineConstants);CORECLR;PRIVATE_CONTRACTS
+ 7.3
+
+
+
+
-
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs
index 68e6ee906b..13de22dd5b 100644
--- a/src/Microsoft.ML.CpuMath/Sse.cs
+++ b/src/Microsoft.ML.CpuMath/Sse.cs
@@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-
namespace Microsoft.ML.Runtime.Internal.CpuMath
{
///
diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
new file mode 100644
index 0000000000..bf7ad03e34
--- /dev/null
+++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
@@ -0,0 +1,1189 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+// The exported function names need to be unique (can't be disambiguated based on signature), hence
+// we introduce suffix letters to indicate the general patterns used.
+// * A suffix means aligned and padded for SSE operations.
+// * U suffix means unaligned and unpadded.
+// * S suffix means sparse (unaligned) vector.
+// * P suffix means sparse (unaligned) partial vector - the vector is only part of a larger sparse vector.
+// * R suffix means sparse matrix.
+// * C suffix means convolution matrix.
+// * D suffix means convolution matrix, with implicit source padding.
+// * Tran means the matrix is transposed.
+
+using System;
+using System.Runtime.CompilerServices;
+using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.X86;
+
+namespace Microsoft.ML.Runtime.Internal.CpuMath
+{
+ internal static class SseIntrinsics
+ {
+ // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray
+ private const int Vector128Alignment = 16;
+
+ private static bool Compat(AlignedArray a)
+ {
+ Contracts.AssertValue(a);
+ Contracts.Assert(a.Size > 0);
+ return a.CbAlign == Vector128Alignment;
+ }
+
+ private static unsafe float* Ptr(AlignedArray a, float* p)
+ {
+ Contracts.AssertValue(a);
+ float* q = p + a.GetBase((long)p);
+ Contracts.Assert(((long)q & (Vector128Alignment - 1)) == 0);
+ return q;
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector128 Load1(float* src, int* idx)
+ {
+ return Sse.SetScalarVector128(src[idx[0]]);
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static unsafe Vector128 Load4(float* src, int* idx)
+ {
+ return Sse.SetVector128(src[idx[3]], src[idx[2]], src[idx[1]], src[idx[0]]);
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 Rotate(in Vector128 x)
+ {
+ // The control byte shuffles the four 32-bit floats of x: ABCD -> BCDA.
+ return Sse.Shuffle(x, x, 0x39);
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static unsafe void Store4(in Vector128 x, float* dst, int* idx)
+ {
+ Sse.StoreScalar(dst + idx[0], x);
+ Vector128 rotated = Rotate(in x);
+ Sse.StoreScalar(dst + idx[1], rotated);
+ rotated = Rotate(in rotated);
+ Sse.StoreScalar(dst + idx[2], rotated);
+ rotated = Rotate(in rotated);
+ Sse.StoreScalar(dst + idx[3], rotated);
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 VectorSum(in Vector128 vector)
+ {
+ if (Sse3.IsSupported)
+ {
+ Vector128 partialSum = Sse3.HorizontalAdd(vector, vector);
+ return Sse3.HorizontalAdd(partialSum, partialSum);
+ }
+ else
+ {
+ Vector128 partialSum = Sse.Add(vector, Sse.MoveHighToLow(vector, vector));
+ // The control byte shuffles the four 32-bit floats of partialSum: ABCD -> BADC.
+ return Sse.Add(partialSum, Sse.Shuffle(partialSum, partialSum, 0xB1));
+ }
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 VectorMax(in Vector128 vector)
+ {
+ Vector128 x1 = Sse.Shuffle(vector, vector, 0xB1);
+ Vector128 partialMax = Sse.Max(vector, x1);
+ x1 = Sse.Shuffle(partialMax, partialMax, 0x02);
+ return Sse.MaxScalar(partialMax, x1);
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 GetAbsMask()
+ {
+ return Sse2.IsSupported ?
+ Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) :
+ Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF));
+ }
+
+ [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 GetNewDst(in Vector128 xDst1, in Vector128 signMask, in Vector128 xThreshold)
+ {
+ Vector128 xSign = Sse.And(xDst1, signMask); // result = 0x8000 0000 if xDst1 is negative or 0x0000 0000 otherwise
+ Vector128 xDst1Abs = Sse.Xor(xDst1, xSign);
+ Vector128 xCond = Sse.CompareGreaterThan(xDst1Abs, xThreshold); // result = 0xFFFF FFFF if true
+ Vector128 x2 = Sse.Xor(xSign, xThreshold); // -xThreshold if xDst1 is negative and +xThreshold otherwise
+ return Sse.And(Sse.Subtract(xDst1, x2), xCond);
+ }
+
+ // Multiply matrix times vector into vector.
+ internal static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+ {
+ Contracts.Assert(Compat(mat));
+ Contracts.Assert(Compat(src));
+ Contracts.Assert(Compat(dst));
+
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
+ {
+ float* psrc = Ptr(src, pSrcStart);
+ float* pdst = Ptr(dst, pDstStart);
+ float* pmat = Ptr(mat, pMatStart);
+
+ float* pSrcEnd = psrc + ccol;
+ float* pDstEnd = pdst + crow;
+ float* pDstCurrent = pdst;
+ float* pMatCurrent = pmat;
+
+ while (pDstCurrent < pDstEnd)
+ {
+ Vector128 res0 = Sse.SetZeroVector128();
+ Vector128 res1 = res0;
+ Vector128 res2 = res0;
+ Vector128 res3 = res0;
+
+ float* pSrcCurrent = psrc;
+
+ while (pSrcCurrent < pSrcEnd)
+ {
+ float* pMatTemp = pMatCurrent;
+
+ Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp);
+ Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol);
+ Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol);
+ Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol);
+ Vector128 x02 = Sse.LoadAlignedVector128(pSrcCurrent);
+
+ res0 = Sse.Add(res0, Sse.Multiply(x01, x02));
+ res1 = Sse.Add(res1, Sse.Multiply(x11, x02));
+ res2 = Sse.Add(res2, Sse.Multiply(x21, x02));
+ res3 = Sse.Add(res3, Sse.Multiply(x31, x02));
+
+ pSrcCurrent += 4;
+ pMatCurrent += 4;
+ }
+
+ // Add up the entries of each, with the 4 results in res0
+ res0 = Sse3.HorizontalAdd(res0, res1);
+ res2 = Sse3.HorizontalAdd(res2, res3);
+ res0 = Sse3.HorizontalAdd(res0, res2);
+
+ if (add)
+ {
+ res0 = Sse.Add(res0, Sse.LoadAlignedVector128(pDstCurrent));
+ }
+ Sse.StoreAligned(pDstCurrent, res0);
+
+ pDstCurrent += 4;
+ pMatCurrent += 3 * ccol;
+ }
+ }
+ }
+
+ // Partial sparse source vector.
+ internal static unsafe void MatMulPA(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src,
+ int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
+ {
+ Contracts.Assert(Compat(mat));
+ Contracts.Assert(Compat(src));
+ Contracts.Assert(Compat(dst));
+
+ // REVIEW: For extremely sparse inputs, interchanging the loops would
+ // likely be more efficient.
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
+ fixed (int* pposSrc = &rgposSrc[0])
+ {
+ float* psrc = Ptr(src, pSrcStart);
+ float* pdst = Ptr(dst, pDstStart);
+ float* pmat = Ptr(mat, pMatStart);
+
+ int* pposMin = pposSrc + iposMin;
+ int* pposEnd = pposSrc + iposEnd;
+ float* pDstEnd = pdst + crow;
+ float* pm0 = pmat - posMin;
+ float* pSrcCurrent = psrc - posMin;
+ float* pDstCurrent = pdst;
+
+ while (pDstCurrent < pDstEnd)
+ {
+ float* pm1 = pm0 + ccol;
+ float* pm2 = pm1 + ccol;
+ float* pm3 = pm2 + ccol;
+ Vector128 result = Sse.SetZeroVector128();
+
+ int* ppos = pposMin;
+
+ while (ppos < pposEnd)
+ {
+ int col = *ppos;
+ Vector128 x1 = Sse.SetVector128(pm3[col], pm2[col], pm1[col], pm0[col]);
+ Vector128 x2 = Sse.SetAllVector128(pSrcCurrent[col]);
+ x2 = Sse.Multiply(x2, x1);
+ result = Sse.Add(result, x2);
+
+ ppos++;
+ }
+
+ if (add)
+ {
+ result = Sse.Add(result, Sse.LoadAlignedVector128(pDstCurrent));
+ }
+ Sse.StoreAligned(pDstCurrent, result);
+
+ pDstCurrent += 4;
+ pm0 += 4 * ccol;
+ }
+ }
+ }
+
+ internal static unsafe void MatMulTranA(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+ {
+ Contracts.Assert(Compat(mat));
+ Contracts.Assert(Compat(src));
+ Contracts.Assert(Compat(dst));
+
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
+ {
+ float* psrc = Ptr(src, pSrcStart);
+ float* pdst = Ptr(dst, pDstStart);
+ float* pmat = Ptr(mat, pMatStart);
+
+ float* pSrcEnd = psrc + ccol;
+ float* pDstEnd = pdst + crow;
+ float* pSrcCurrent = psrc;
+ float* pMatCurrent = pmat;
+
+ if (!add)
+ {
+ Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent);
+ // Replicate each 32-bit slot of x01 (ABCD) into its own register.
+ Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B
+ Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C
+ Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D
+ x01 = Sse.Shuffle(x01, x01, 0x00); // A
+
+ pSrcCurrent += 4;
+
+ float* pDstCurrent = pdst;
+
+ while (pDstCurrent < pDstEnd)
+ {
+ float* pMatTemp = pMatCurrent;
+ Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp);
+ Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow);
+ Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow);
+ Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow);
+
+ x02 = Sse.Multiply(x01, x02);
+ x12 = Sse.Multiply(x11, x12);
+ x22 = Sse.Multiply(x21, x22);
+ x32 = Sse.Multiply(x31, x32);
+
+ x02 = Sse.Add(x02, x12);
+ x22 = Sse.Add(x22, x32);
+ x02 = Sse.Add(x02, x22);
+
+ Sse.StoreAligned(pDstCurrent, x02);
+
+ pDstCurrent += 4;
+ pMatCurrent += 4;
+ }
+
+ pMatCurrent += 3 * crow;
+ }
+
+ while (pSrcCurrent < pSrcEnd)
+ {
+ Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent);
+ // Replicate each 32-bit slot of x01 (ABCD) into its own register.
+ Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B
+ Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C
+ Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D
+ x01 = Sse.Shuffle(x01, x01, 0x00); // A
+
+ float* pDstCurrent = pdst;
+
+ while (pDstCurrent < pDstEnd)
+ {
+ float* pMatTemp = pMatCurrent;
+
+ Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp);
+ Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow);
+ Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow);
+ Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow);
+ Vector128 x3 = Sse.LoadAlignedVector128(pDstCurrent);
+
+ x02 = Sse.Multiply(x01, x02);
+ x12 = Sse.Multiply(x11, x12);
+ x22 = Sse.Multiply(x21, x22);
+ x32 = Sse.Multiply(x31, x32);
+
+ x02 = Sse.Add(x02, x12);
+ x22 = Sse.Add(x22, x32);
+ x02 = Sse.Add(x02, x22);
+ x3 = Sse.Add(x02, x3);
+
+ Sse.StoreAligned(pDstCurrent, x3);
+
+ pDstCurrent += 4;
+ pMatCurrent += 4;
+ }
+
+ pMatCurrent += 3 * crow;
+ pSrcCurrent += 4;
+ }
+ }
+ }
+
+ // Partial sparse source vector.
+ internal static unsafe void MatMulTranPA(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src,
+ int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow)
+ {
+ Contracts.Assert(Compat(mat));
+ Contracts.Assert(Compat(src));
+ Contracts.Assert(Compat(dst));
+
+ fixed (float* pSrcStart = &src.Items[0])
+ fixed (float* pDstStart = &dst.Items[0])
+ fixed (float* pMatStart = &mat.Items[0])
+ fixed (int* pposSrc = &rgposSrc[0])
+ {
+ float* psrc = Ptr(src, pSrcStart);
+ float* pdst = Ptr(dst, pDstStart);
+ float* pmat = Ptr(mat, pMatStart);
+
+ int* ppos = pposSrc + iposMin;
+ int* pposEnd = pposSrc + iposEnd;
+ float* pDstEnd = pdst + crow;
+
+ if (!add)
+ {
+ int col = *ppos - posMin;
+ ppos++;
+
+ Vector128 x0 = Sse.SetAllVector128(psrc[col]);
+ float* pDstCurrent = pdst;
+ float* pMatCurrent = pmat + col * crow;
+
+ while (pDstCurrent < pDstEnd)
+ {
+ Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent);
+ x1 = Sse.Multiply(x1, x0);
+ Sse.StoreAligned(pDstCurrent, x1);
+
+ pDstCurrent += 4;
+ pMatCurrent += 4;
+ }
+ }
+
+ // REVIEW: Should we explore unrolling the outer loop?
+ while (ppos < pposEnd)
+ {
+ int col = *ppos - posMin;
+
+ Vector128 x0 = Sse.SetAllVector128(psrc[col]);
+ float* pDstCurrent = pdst;
+ float* pMatCurrent = pmat + col * crow;
+
+ while (pDstCurrent < pDstEnd)
+ {
+ Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent);
+ Vector128 x2 = Sse.LoadAlignedVector128(pDstCurrent);
+ x1 = Sse.Multiply(x1, x0);
+ x2 = Sse.Add(x2, x1);
+ Sse.StoreAligned(pDstCurrent, x2);
+
+ pDstCurrent += 4;
+ pMatCurrent += 4;
+ }
+
+ ppos++;
+ }
+ }
+ }
+
+ // dst[i] += scale
+ internal static unsafe void AddScalarU(float scale, Span dst)
+ {
+ fixed (float* pdst = dst)
+ {
+ float* pDstEnd = pdst + dst.Length;
+ float* pDstCurrent = pdst;
+
+ Vector128