Skip to content

Commit

Permalink
Predict parameter values in the suggestion (#12984)
Browse files Browse the repository at this point in the history
* Get the parameter value from the history.

* Add a mock ps console for testing purpose.

- The mock ps console will echo back most of the commands. So that we
  don't need to really execute the Az command on Azure to test the
  prediction.
  • Loading branch information
kceiw authored Sep 18, 2020
1 parent 7893ac5 commit 57cab6d
Show file tree
Hide file tree
Showing 14 changed files with 465 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ public AzPredictorServiceTests(ModelFixture fixture)
{
this._fixture = fixture;
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory]);
this._commandsPredictor = new Predictor(this._fixture.CommandCollection);
this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory], null);
this._commandsPredictor = new Predictor(this._fixture.CommandCollection, null);

this._service = new MockAzPredictorService(startHistory, this._fixture.PredictionCollection[startHistory], this._fixture.CommandCollection);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public MockAzPredictorService(string history, IList<string> suggestions, IList<s
}

/// <inheritdoc/>
public override void RequestPredictions(string history)
public override void RequestPredictions(IEnumerable<string> history)
{
this.IsPredictionRequested = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public PredictorTests(ModelFixture fixture)
{
this._fixture = fixture;
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory]);
this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory], null);
}

/// <summary>
Expand Down
8 changes: 7 additions & 1 deletion tools/Az.Tools.Predictor/Az.Tools.Predictor.sln
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ VisualStudioVersion = 16.0.30426.262
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor", "Az.Tools.Predictor\Az.Tools.Predictor.csproj", "{E4A5F697-086C-4908-B90E-A31EE47ECF13}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand All @@ -21,6 +23,10 @@ Global
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Release|Any CPU.Build.0 = Release|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Release|Any CPU.ActiveCfg = Release|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
70 changes: 37 additions & 33 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ public sealed class AzPredictor : ICommandPredictor
private const int SuggestionCountForTelemetry = 5;
private const string ParameterValueMask = "***";
private const char ParameterValueSeperator = ':';
private const char ParameterIndicator = '-';

private static readonly string[] CommonParameters = new string[] { "location" };

Expand All @@ -74,45 +73,45 @@ public void StartEarlyProcessing(IReadOnlyList<string> history)
{
if (history.Count > 0)
{
var historyLines = history.TakeLast(AzPredictorConstants.CommandHistoryCountToProcess).ToList();
var historyLines = history.TakeLast(AzPredictorConstants.CommandHistoryCountToProcess);

while (historyLines.Count < AzPredictorConstants.CommandHistoryCountToProcess)
while (historyLines.Count() < AzPredictorConstants.CommandHistoryCountToProcess)
{
historyLines.Insert(0, AzPredictorConstants.CommandHistoryPlaceholder);
historyLines = historyLines.Prepend(AzPredictorConstants.CommandHistoryPlaceholder);
}

for (int i = historyLines.Count - 1; i >= 0; --i)
{
var ast = Parser.ParseInput(historyLines[i], out Token[] tokens, out _);
var commandAsts = ast.FindAll((ast) => ast is CommandAst, true);
var commandAsts = historyLines.Select((h) =>
{
var ast = Parser.ParseInput(h, out Token[] tokens, out _);
var allAsts = ast?.FindAll((ast) => ast is CommandAst, true);
return allAsts?.LastOrDefault() as CommandAst;
}).ToArray();

if (!commandAsts.Any())
{
historyLines[i] = AzPredictorConstants.CommandHistoryPlaceholder;
continue;
}
var maskedHistoryLines = commandAsts.Select((c) =>
{
var commandName = c?.CommandElements?.FirstOrDefault().ToString();

var lastCommandAst = commandAsts.Last() as CommandAst;
var lastCommand = lastCommandAst?.CommandElements?.FirstOrDefault()?.ToString();
if (!_service.IsSupportedCommand(commandName))
{
return AzPredictorConstants.CommandHistoryPlaceholder;
}

if (string.IsNullOrWhiteSpace(lastCommand) || !_service.IsSupportedCommand(lastCommand))
{
historyLines[i] = AzPredictorConstants.CommandHistoryPlaceholder;
continue;
}
return AzPredictor.MaskCommandLine(c);
});

historyLines[i] = MaskCommandLine(lastCommandAst);
var lastMaskedHistoryLines = maskedHistoryLines.Last();

if (i == historyLines.Count - 1)
{
var suggestionIndex = _service.GetRankOfSuggestion(lastCommandAst, ast);
var fallbackIndex = _service.GetRankOfFallback(lastCommandAst, ast);
var topFiveSuggestion = _service.GetTopNSuggestions(AzPredictor.SuggestionCountForTelemetry);
_telemetryClient.OnSuggestionForHistory(historyLines[i], suggestionIndex, fallbackIndex, topFiveSuggestion);
}
if (lastMaskedHistoryLines != AzPredictorConstants.CommandHistoryPlaceholder)
{
var commandName = (commandAsts.LastOrDefault()?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var suggestionIndex = _service.GetRankOfSuggestion(commandName);
var fallbackIndex = _service.GetRankOfFallback(commandName);
var topFiveSuggestion = _service.GetTopNSuggestions(AzPredictor.SuggestionCountForTelemetry);
_telemetryClient.OnSuggestionForHistory(maskedHistoryLines.LastOrDefault(), suggestionIndex, fallbackIndex, topFiveSuggestion);
}

_service.RequestPredictions(String.Join(AzPredictorConstants.CommandConcatenator, historyLines));
_service.RecordHistory(commandAsts);
_service.RequestPredictions(maskedHistoryLines);
}
}

Expand Down Expand Up @@ -175,7 +174,12 @@ private static string MergeStrings(string a, string b)
/// <param name="cmdAst">The last user input command</param>
private static string MaskCommandLine(CommandAst cmdAst)
{
var commandElements = cmdAst.CommandElements;
var commandElements = cmdAst?.CommandElements;

if (commandElements == null)
{
return null;
}

if (commandElements.Count == 1)
{
Expand All @@ -196,15 +200,15 @@ private static string MaskCommandLine(CommandAst cmdAst)
if (param.Argument != null)
{
// Parameter is in the form of `-Name:name`
_ = sb.Append(AzPredictor.ParameterIndicator)
_ = sb.Append(AzPredictorConstants.ParameterIndicator)
.Append(param.ParameterName)
.Append(AzPredictor.ParameterValueSeperator)
.Append(AzPredictor.ParameterValueMask);
}
else
{
// Parameter is in the form of `-Name`
_ = sb.Append(AzPredictor.ParameterIndicator)
_ = sb.Append(AzPredictorConstants.ParameterIndicator)
.Append(param.ParameterName)
.Append(AzPredictorConstants.CommandParameterSeperator)
.Append(AzPredictor.ParameterValueMask);
Expand All @@ -223,8 +227,8 @@ public class PredictorInitializer : IModuleAssemblyInitializer
public void OnImport()
{
var settings = Settings.GetSettings();
var azPredictorService = new AzPredictorService(settings.ServiceUri);
var telemetryClient = new AzPredictorTelemetryClient();
var azPredictorService = new AzPredictorService(settings.ServiceUri);
var predictor = new AzPredictor(azPredictorService, telemetryClient);
SubsystemManager.RegisterSubsystem<ICommandPredictor, AzPredictor>(predictor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal static class AzPredictorConstants
/// <summary>
/// The value to check to determine if it's an Az command.
/// </summary>
public const string AzCommandMoniktor = "az";
public const string AzCommandMoniker = "-Az";

/// <summary>
/// The character to use when we join the commands together.
Expand All @@ -54,6 +54,11 @@ internal static class AzPredictorConstants
/// </summary>
public const char CommandParameterSeperator = ' ';

/// <summary>
/// The character that begins a parameter.
/// </summary>
public const char ParameterIndicator = '-';

/// <summary>
/// The setting file name.
/// </summary>
Expand Down
51 changes: 30 additions & 21 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using Microsoft.WindowsAzure.Commands.Utilities.Common;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation.Language;
Expand Down Expand Up @@ -48,16 +50,15 @@ public sealed class RequestContext
public PredictionRequestBody(string history) => this.History = history;
};

private const int PredictionRequestInProgress = 1;
private const int PredictionRequestNotInProgress = 0;
private static readonly HttpClient _client = new HttpClient();
private readonly string _commandsEndpoint;
private readonly string _predictionsEndpoint;
private volatile Tuple<string, Predictor> _historySuggestions; // The history and the prediction for that.
private volatile Predictor _commands;
private volatile string _history;
private HashSet<string> _commandSet = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
private HashSet<string> _commandSet;
private CancellationTokenSource _predictionRequestCancellationSource;
private ParameterValuePredictor _parameterValuePredictor = new ParameterValuePredictor();

/// <summary>
/// The AzPredictor service interacts with the Aladdin service specified in serviceUri.
Expand Down Expand Up @@ -144,53 +145,58 @@ public Tuple<string, PredictionSource> GetSuggestion(Ast input, CancellationToke
}

/// <inheritdoc/>
public virtual void RequestPredictions(string history)
public virtual void RequestPredictions(IEnumerable<string> history)
{
// Even if it's called multiple times, we only need to keep the one for the latest history.

this._predictionRequestCancellationSource?.Cancel();
this._predictionRequestCancellationSource = new CancellationTokenSource();
var cancellationToken = this._predictionRequestCancellationSource.Token;
this._history = history;
var localHistory = string.Join(AzPredictorConstants.CommandConcatenator, history);
this._history = localHistory;

// We don't need to block on the task. We send the HTTP request and update prediction list at the background.
Task.Run(async () => {
var requestBody = JsonConvert.SerializeObject(new PredictionRequestBody(history));
var requestBody = JsonConvert.SerializeObject(new PredictionRequestBody(localHistory));
var httpResponseMessage = await _client.PostAsync(this._predictionsEndpoint, new StringContent(requestBody, Encoding.UTF8, "application/json"), cancellationToken);

var reply = await httpResponseMessage.Content.ReadAsStringAsync(cancellationToken);
var suggestionsList = JsonConvert.DeserializeObject<List<string>>(reply);

this.SetSuggestionPredictor(history, suggestionsList);
this.SetSuggestionPredictor(localHistory, suggestionsList);
},
cancellationToken);
}

/// <summary>
/// For logging purposes, get the rank of the user input in the model suggestions list.
/// </summary>
public int? GetRankOfSuggestion(CommandAst command, Ast input)
/// <inheritdoc/>
public virtual void RecordHistory(IEnumerable<CommandAst> history)
{
history.ForEach((h) => this._parameterValuePredictor.ProcessHistoryCommand(h));
}

/// <inhericdoc/>
public int? GetRankOfSuggestion(string commandName)
{
var historySuggestions = this._historySuggestions;
return historySuggestions?.Item2?.GetCommandPrediction(command, input, CancellationToken.None).Item2;
return historySuggestions?.Item2?.GetCommandPrediction(commandName, isCommandNameComplete: true, cancellationToken:CancellationToken.None).Item2;
}

/// <inheritdoc/>
public int? GetRankOfFallback(CommandAst command, Ast input)
/// <inhericdoc/>
public int? GetRankOfFallback(string commandName)
{
var commands = this._commands;
return commands?.GetCommandPrediction(command, input, CancellationToken.None).Item2;
return commands?.GetCommandPrediction(commandName, isCommandNameComplete:true, cancellationToken:CancellationToken.None).Item2;
}

/// <inheritdoc/>
/// <inhericdoc/>
public IEnumerable<string> GetTopNSuggestions(int n)
{
var historySuggestions = this._historySuggestions;
return historySuggestions?.Item2?.GetTopNPrediction(n);
}

/// <inheritdoc/>
public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && _commandSet.Contains(cmd);
public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && (_commandSet?.Contains(cmd) == true);

/// <summary>
/// Requests a list of popular commands from service. These commands are used as fallback suggestion
Expand All @@ -209,7 +215,10 @@ protected virtual void RequestCommands()

// Initialize predictions
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
RequestPredictions(startHistory);
RequestPredictions(new string[] {
AzPredictorConstants.CommandHistoryPlaceholder,
AzPredictorConstants.CommandHistoryPlaceholder});

});
}

Expand All @@ -219,8 +228,8 @@ protected virtual void RequestCommands()
/// <param name="commands">The command collection to set the predictor</param>
protected void SetCommandsPredictor(IList<string> commands)
{
this._commands = new Predictor(commands);
this._commandSet = new HashSet<string>(commands.Select(x => AzPredictorService.GetCommandName(x))); // this could be slow
this._commands = new Predictor(commands, this._parameterValuePredictor);
this._commandSet = commands.Select(x => AzPredictorService.GetCommandName(x)).ToHashSet<string>(StringComparer.OrdinalIgnoreCase); // this could be slow

}

Expand All @@ -231,7 +240,7 @@ protected void SetCommandsPredictor(IList<string> commands)
/// <param name="suggestions">The suggestion collection to set the predictor</param>
protected void SetSuggestionPredictor(string history, IList<string> suggestions)
{
this._historySuggestions = Tuple.Create(history, new Predictor(suggestions));
this._historySuggestions = Tuple.Create(history, new Predictor(suggestions, this._parameterValuePredictor));
}

/// <summary>
Expand Down
16 changes: 11 additions & 5 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,29 @@ public interface IAzPredictorService
/// <summary>
/// Requests predictions, given a history string.
/// </summary>
/// <param name="history">A history string could look like: "Get-AzContext -Name NAME\nSet-AzContext"</param>
public void RequestPredictions(string history);
/// <param name="history">A list of history commands</param>
public void RequestPredictions(IEnumerable<string> history);

/// <summary>
/// For logging purposes, get the rank of the user input in the model suggestions list.
/// Record the history from PSReadLine.
/// </summary>
public int? GetRankOfSuggestion(CommandAst command, Ast input);
/// <param name="history">A list of history commands</param>
public void RecordHistory(IEnumerable<CommandAst> history);

/// <summary>
/// Return true if command is part of known set of Az cmdlets, false otherwise.
/// </summary>
public bool IsSupportedCommand(string cmd);

/// <summary>
/// For logging purposes, get the rank of the user input in the model suggestions list.
/// </summary>
public int? GetRankOfSuggestion(string commandName);

/// <summary>
/// For logging purposes, get the rank of the user input in the fallback commands cache.
/// </summary>
public int? GetRankOfFallback(CommandAst command, Ast input);
public int? GetRankOfFallback(string commandName);

/// <summary>
/// For logging purposes, get the top N suggestions from the model suggestions list.
Expand Down
Loading

0 comments on commit 57cab6d

Please sign in to comment.