Skip to content

Commit

Permalink
added catalog extensions and moved tensorflow arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
artidoro committed Feb 12, 2019
1 parent 8a3df28 commit 35e93ef
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ private void MapDist(in VBuffer<float> src, ref float score, ref float prob)
}

/// <summary>
/// Learns the prior distribution for 0/1 class labels and just outputs that.
/// Learns the prior distribution for 0/1 class labels and outputs that.
/// </summary>
public sealed class PriorTrainer : TrainerBase<PriorModelParameters>,
ITrainerEstimator<BinaryPredictionTransformer<PriorModelParameters>, PriorModelParameters>
Expand Down Expand Up @@ -263,8 +263,8 @@ internal PriorTrainer(IHostEnvironment env, Options options)
/// <summary>
/// Initializes PriorTrainer object.
/// </summary>
internal PriorTrainer(IHost host, String labelColumn, String weightColunn = null)
: base(host, LoadNameValue)
internal PriorTrainer(IHostEnvironment env, String labelColumn, String weightColunn = null)
: base(env, LoadNameValue)
{
Contracts.CheckValue(labelColumn, nameof(labelColumn));
Contracts.CheckValueOrNull(weightColunn);
Expand Down
33 changes: 16 additions & 17 deletions src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -558,28 +558,27 @@ public static LinearSvmTrainer LinearSupportVectorMachines(this BinaryClassifica
}

/// <summary>
/// Predict a target using a linear binary classification model trained with the <see cref="LinearSvmTrainer"/> trainer.
/// Predict a target using a random binary classification model <see cref="RandomTrainer"/>.
/// </summary>
/// <remarks>
/// <para>
/// The idea behind support vector machines, is to map instances into a high dimensional space
/// in which the two classes are linearly separable, i.e., there exists a hyperplane such that all the positive examples are on one side of it,
/// and all the negative examples are on the other.
/// </para>
/// <para>
/// After this mapping, quadratic programming is used to find the separating hyperplane that maximizes the
/// margin, i.e., the minimal distance between it and the instances.
/// </para>
/// </remarks>
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
/// <param name="options">Advanced arguments to the algorithm.</param>
public static LinearSvmTrainer LinearSupportVectorMachines(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
LinearSvmTrainer.Options options)
public static RandomTrainer Random(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog)
{
Contracts.CheckValue(catalog, nameof(catalog));
Contracts.CheckValue(options, nameof(options));
return new RandomTrainer(CatalogUtils.GetEnvironment(catalog), new RandomTrainer.Options());
}

return new LinearSvmTrainer(CatalogUtils.GetEnvironment(catalog), options);
/// <summary>
/// Predict a target using a binary classification model trained with <see cref="PriorTrainer"/> trainer.
/// </summary>
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
/// <param name="labelColumn">The name of the label column. </param>
/// <param name="weightsColumn">The optional name of the weights column.</param>
public static PriorTrainer Prior(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
string labelColumn = DefaultColumnNames.Label,
string weightsColumn = null)
{
Contracts.CheckValue(catalog, nameof(catalog));
return new PriorTrainer(CatalogUtils.GetEnvironment(catalog), labelColumn, weightsColumn);
}
}
}
14 changes: 7 additions & 7 deletions src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,23 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, tensorFlowModel);

/// <summary>
/// Score or Retrain a tensorflow model (based on setting of the <see cref="TensorFlowTransformer.Options.ReTrain"/>) setting.
/// The model is specified in the <see cref="TensorFlowTransformer.Options.ModelLocation"/>.
/// Score or Retrain a tensorflow model (based on setting of the <see cref="TensorFlowEstimator.Options.ReTrain"/>) setting.
/// The model is specified in the <see cref="TensorFlowEstimator.Options.ModelLocation"/>.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
TensorFlowTransformer.Options options)
TensorFlowEstimator.Options options)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);

/// <summary>
/// Scores or retrains (based on setting of the <see cref="TensorFlowTransformer.Options.ReTrain"/>) a pre-traiend TensorFlow model specified via <paramref name="tensorFlowModel"/>.
/// Scores or retrains (based on setting of the <see cref="TensorFlowEstimator.Options.ReTrain"/>) a pre-traiend TensorFlow model specified via <paramref name="tensorFlowModel"/>.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The <see cref="TensorFlowTransformer.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param>
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
TensorFlowTransformer.Options options,
TensorFlowEstimator.Options options,
TensorFlowModelInfo tensorFlowModel)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);
}
Expand Down
Loading

0 comments on commit 35e93ef

Please sign in to comment.