Skip to content

Commit

Permalink
merging ML.ColumnOptions and ML.InputOutputColumnPair
Browse files Browse the repository at this point in the history
  • Loading branch information
artidoro committed Mar 22, 2019
1 parent b4c3ea0 commit 802e4de
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 95 deletions.
2 changes: 1 addition & 1 deletion docs/code/MlNetCookBook.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ var pipeline =
// Use the multi-class SDCA model to predict the label using features.
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated())
// Apply the inverse conversion from 'PredictedLabel' column back to string value.
.Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Data")));
.Append(mlContext.Transforms.Conversion.MapKeyToValue("Data", "PredictedLabel"));

// Train the model.
var model = pipeline.Fit(trainData);
Expand Down
34 changes: 9 additions & 25 deletions src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,6 @@ public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.Co
public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, InputOutputColumnPair[] columns)
=> new KeyToValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns.Select(x => (x.OutputColumnName, x.InputColumnName)).ToArray());

/// <summary>
/// Convert the key types (name of the column specified in the first item of the tuple) back to their original values
/// (named as specified in the second item of the tuple).
/// </summary>
/// <param name="catalog">The conversion transform's catalog</param>
/// <param name="columns">The pairs of input and output columns.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[KeyToValueMappingEstimator](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/ValueMappingStringToKeyType.cs)]
/// ]]></format>
/// </example>
[BestFriend]
internal static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, params ColumnOptions[] columns)
=> new KeyToValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns));

/// <summary>
/// Maps key types or key values into a floating point vector.
/// </summary>
Expand Down Expand Up @@ -218,7 +202,7 @@ public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.Co
}

/// <summary>
/// Converts value types into <see cref="KeyType"/>, optionally loading the keys to use from <paramref name="keyData"/>.
/// Converts value types into <see cref="KeyDataViewType"/>, optionally loading the keys to use from <paramref name="keyData"/>.
/// </summary>
/// <param name="catalog">The conversion transform's catalog.</param>
/// <param name="columns">The data columns to map to keys.</param>
Expand Down Expand Up @@ -292,11 +276,11 @@ public static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputType
internal static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputType, TOutputType>(
this TransformsCatalog.ConversionTransforms catalog,
IEnumerable<KeyValuePair<TInputType, TOutputType>> keyValuePairs,
params ColumnOptions[] columns)
params InputOutputColumnPair[] columns)
{
var keys = keyValuePairs.Select(pair => pair.Key);
var values = keyValuePairs.Select(pair => pair.Value);
return new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values, ColumnOptions.ConvertToValueTuples(columns));
return new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values, InputOutputColumnPair.ConvertToValueTuples(columns));
}

/// <summary>
Expand All @@ -320,12 +304,12 @@ internal static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputTy
this TransformsCatalog.ConversionTransforms catalog,
IEnumerable<KeyValuePair<TInputType, TOutputType>> keyValuePairs,
bool treatValuesAsKeyType,
params ColumnOptions[] columns)
params InputOutputColumnPair[] columns)
{
var keys = keyValuePairs.Select(pair => pair.Key);
var values = keyValuePairs.Select(pair => pair.Value);
return new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values, treatValuesAsKeyType,
ColumnOptions.ConvertToValueTuples(columns));
InputOutputColumnPair.ConvertToValueTuples(columns));
}

/// <summary>
Expand Down Expand Up @@ -381,12 +365,12 @@ public static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputType
internal static ValueMappingEstimator<TInputType, TOutputType> MapValue<TInputType, TOutputType>(
this TransformsCatalog.ConversionTransforms catalog,
IEnumerable<KeyValuePair<TInputType, TOutputType[]>> keyValuePairs,
params ColumnOptions[] columns)
params InputOutputColumnPair[] columns)
{
var keys = keyValuePairs.Select(pair => pair.Key);
var values = keyValuePairs.Select(pair => pair.Value);
return new ValueMappingEstimator<TInputType, TOutputType>(CatalogUtils.GetEnvironment(catalog), keys, values,
ColumnOptions.ConvertToValueTuples(columns));
InputOutputColumnPair.ConvertToValueTuples(columns));
}

/// <summary>
Expand Down Expand Up @@ -437,8 +421,8 @@ public static ValueMappingEstimator MapValue(
[BestFriend]
internal static ValueMappingEstimator MapValue(
this TransformsCatalog.ConversionTransforms catalog,
IDataView lookupMap, DataViewSchema.Column keyColumn, DataViewSchema.Column valueColumn, params ColumnOptions[] columns)
IDataView lookupMap, DataViewSchema.Column keyColumn, DataViewSchema.Column valueColumn, params InputOutputColumnPair[] columns)
=> new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), lookupMap, keyColumn.Name, valueColumn.Name,
ColumnOptions.ConvertToValueTuples(columns));
InputOutputColumnPair.ConvertToValueTuples(columns));
}
}
37 changes: 4 additions & 33 deletions src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,40 +32,11 @@ public InputOutputColumnPair(string outputColumnName, string inputColumnName = n
InputColumnName = inputColumnName ?? outputColumnName;
OutputColumnName = outputColumnName;
}
}

/// <summary>
/// Specifies input and output column names for a transformation.
/// </summary>
[BestFriend]
internal sealed class ColumnOptions
{
private readonly string _outputColumnName;
private readonly string _inputColumnName;

/// <summary>
/// Specifies input and output column names for a transformation.
/// </summary>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
public ColumnOptions(string outputColumnName, string inputColumnName = null)
{
_outputColumnName = outputColumnName;
_inputColumnName = inputColumnName ?? outputColumnName;
}

/// <summary>
/// Instantiates a <see cref="ColumnOptions"/> from a tuple of input and output column names.
/// </summary>
public static implicit operator ColumnOptions((string outputColumnName, string inputColumnName) value)
{
return new ColumnOptions(value.outputColumnName, value.inputColumnName);
}

[BestFriend]
internal static (string outputColumnName, string inputColumnName)[] ConvertToValueTuples(ColumnOptions[] infos)
internal static (string outputColumnName, string inputColumnName)[] ConvertToValueTuples(InputOutputColumnPair[] infos)
{
return infos.Select(info => (info._outputColumnName, info._inputColumnName)).ToArray();
return infos.Select(info => (info.OutputColumnName, info.InputColumnName)).ToArray();
}
}

Expand Down Expand Up @@ -104,8 +75,8 @@ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog,
/// </format>
/// </example>
[BestFriend]
internal static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, params ColumnOptions[] columns)
=> new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns));
internal static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, params InputOutputColumnPair[] columns)
=> new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), InputOutputColumnPair.ConvertToValueTuples(columns));

/// <summary>
/// Concatenates columns together.
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalo
/// ]]></format>
/// </example>
[BestFriend]
internal static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, params ColumnOptions[] columns)
=> new ImageGrayscalingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns));
internal static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, params InputOutputColumnPair[] columns)
=> new ImageGrayscalingEstimator(CatalogUtils.GetEnvironment(catalog), InputOutputColumnPair.ConvertToValueTuples(columns));

/// <summary>
/// Loads the images from the <see cref="ImageLoadingTransformer.ImageFolder" /> into memory.
Expand Down Expand Up @@ -80,8 +80,8 @@ public static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, s
/// ]]></format>
/// </example>
[BestFriend]
internal static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string imageFolder, params ColumnOptions[] columns)
=> new ImageLoadingEstimator(CatalogUtils.GetEnvironment(catalog), imageFolder, ColumnOptions.ConvertToValueTuples(columns));
internal static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string imageFolder, params InputOutputColumnPair[] columns)
=> new ImageLoadingEstimator(CatalogUtils.GetEnvironment(catalog), imageFolder, InputOutputColumnPair.ConvertToValueTuples(columns));

/// <include file='doc.xml' path='doc/members/member[@name="ImagePixelExtractingEstimator"]/*' />
/// <param name="catalog">The transform's catalog.</param>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/ConversionsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ public static class ConversionsCatalog
/// <param name="columns">Specifies the output and input columns on which the transformation should be applied.</param>
[BestFriend]
internal static KeyToBinaryVectorMappingEstimator MapKeyToBinaryVector(this TransformsCatalog.ConversionTransforms catalog,
params ColumnOptions[] columns)
=> new KeyToBinaryVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), ColumnOptions.ConvertToValueTuples(columns));
params InputOutputColumnPair[] columns)
=> new KeyToBinaryVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), InputOutputColumnPair.ConvertToValueTuples(columns));

/// <summary>
/// Convert the key types back to binary vector.
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/NormalizerCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public static NormalizingEstimator Normalize(this TransformsCatalog catalog,
[BestFriend]
internal static NormalizingEstimator Normalize(this TransformsCatalog catalog,
NormalizingEstimator.NormalizationMode mode,
params ColumnOptions[] columns)
=> new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), mode, ColumnOptions.ConvertToValueTuples(columns));
params InputOutputColumnPair[] columns)
=> new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), mode, InputOutputColumnPair.ConvertToValueTuples(columns));

/// <summary>
/// Normalize (rescale) columns according to specified custom parameters.
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/TextCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public static TokenizingByCharactersEstimator TokenizeIntoCharactersAsKeys(this
[BestFriend]
internal static TokenizingByCharactersEstimator TokenizeIntoCharactersAsKeys(this TransformsCatalog.TextTransforms catalog,
bool useMarkerCharacters = CharTokenizingDefaults.UseMarkerCharacters,
params ColumnOptions[] columns)
=> new TokenizingByCharactersEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), useMarkerCharacters, ColumnOptions.ConvertToValueTuples(columns));
params InputOutputColumnPair[] columns)
=> new TokenizingByCharactersEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), useMarkerCharacters, InputOutputColumnPair.ConvertToValueTuples(columns));

/// <summary>
/// Normalizes incoming text in <paramref name="inputColumnName"/> by changing case, removing diacritical marks, punctuation marks and/or numbers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ private ITransformer TrainOnIris(string irisDataPath)
// [2] -9.709775 float

// Apply the inverse conversion from 'PredictedLabel' column back to string value.
var finalPipeline = pipeline.Append(mlContext.Transforms.Conversion.MapKeyToValue(("Data", "PredictedLabel")));
var finalPipeline = pipeline.Append(mlContext.Transforms.Conversion.MapKeyToValue("Data", "PredictedLabel"));
dataPreview = finalPipeline.Preview(trainData);

return finalPipeline.Fit(trainData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.MulticlassClassification.Trainers.SdcaCalibrated(
new SdcaCalibratedMulticlassTrainer.Options { NumberOfThreads = 1 }))
.Append(mlContext.Transforms.Conversion.MapKeyToValue(("Plant", "PredictedLabel")));
.Append(mlContext.Transforms.Conversion.MapKeyToValue("Plant", "PredictedLabel"));

// Train the pipeline
var trainedModel = pipe.Fit(trainData);
Expand Down
Loading

0 comments on commit 802e4de

Please sign in to comment.