-
Notifications
You must be signed in to change notification settings - Fork 323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AI Completions #5910
Merged
Merged
AI Completions #5910
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
f9e859f
checkpoint
kustosz e34c3d5
un-cleanups
kustosz b78afd0
final touches
kustosz e1cd873
Merge branch 'develop' into wip/mk/ai-stuff
kustosz 00fdce4
fmt
kustosz a922cda
lint
kustosz bbdcca2
fix tests
kustosz 78f72e5
ehhhh
kustosz a381513
Merge branch 'develop' into wip/mk/ai-stuff
kustosz bf58dbb
undo fmt
kustosz 3a4aa47
more fixes
kustosz eb5763a
fix compilation
kustosz 6cd55d8
Merge branch 'develop' into wip/mk/ai-stuff
kustosz 47b59e4
use edit_node_expression
kustosz 92404d5
Merge branch 'develop' into wip/mk/ai-stuff
kustosz 4fdccf6
fix lints
kustosz 6549b8f
fix more
kustosz bce1281
Fix compilation
vitvakatu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,10 +36,12 @@ pub mod breadcrumbs; | |
pub mod component; | ||
pub mod input; | ||
|
||
use crate::controller::graph::executed::Handle; | ||
use crate::model::execution_context::QualifiedMethodPointer; | ||
use crate::model::execution_context::Visualization; | ||
pub use action::Action; | ||
|
||
|
||
|
||
// ================= | ||
// === Constants === | ||
// ================= | ||
|
@@ -82,6 +84,16 @@ pub struct NotSupported { | |
#[fail(display = "An action cannot be executed when searcher is in \"edit node\" mode.")] | ||
pub struct CannotExecuteWhenEditingNode; | ||
|
||
#[allow(missing_docs)] | ||
#[derive(Copy, Clone, Debug, Fail)] | ||
#[fail(display = "An action cannot be executed when searcher is run without `this` argument.")] | ||
pub struct CannotRunWithoutThisArgument; | ||
|
||
#[allow(missing_docs)] | ||
#[derive(Copy, Clone, Debug, Fail)] | ||
#[fail(display = "No visualization data received for an AI suggestion.")] | ||
pub struct NoAIVisualizationDataReceived; | ||
|
||
#[allow(missing_docs)] | ||
#[derive(Copy, Clone, Debug, Fail)] | ||
#[fail(display = "Cannot commit expression in current mode ({:?}).", mode)] | ||
|
@@ -96,14 +108,15 @@ pub struct CannotCommitExpression { | |
// ===================== | ||
|
||
/// The notification emitted by Searcher Controller | ||
#[derive(Copy, Clone, Debug, Eq, PartialEq)] | ||
#[derive(Clone, Debug, Eq, PartialEq)] | ||
pub enum Notification { | ||
/// A new Suggestion list is available. | ||
NewActionList, | ||
/// Code should be inserted by means of using an AI autocompletion. | ||
AISuggestionUpdated(String, text::Range<Byte>), | ||
} | ||
|
||
|
||
|
||
// =================== | ||
// === Suggestions === | ||
// =================== | ||
|
@@ -550,12 +563,78 @@ impl Searcher { | |
self.notifier.notify(Notification::NewActionList); | ||
} | ||
|
||
const AI_QUERY_PREFIX: &'static str = "AI:"; | ||
const AI_QUERY_ACCEPT_TOKEN: &'static str = "#"; | ||
const AI_STOP_SEQUENCE: &'static str = "`"; | ||
const AI_GOAL_PLACEHOLDER: &'static str = "__$$GOAL$$__"; | ||
|
||
/// Accepts the current AI query and exchanges it for actual expression. | ||
/// To accomplish this, it performs the following steps: | ||
/// 1. Attaches a visualization to `this`, calling `AI.build_ai_prompt`, to | ||
/// get a data-specific prompt for Open AI; | ||
/// 2. Sends the prompt to the Open AI backend proxy, along with the user | ||
/// query. | ||
/// 3. Replaces the query with the result of the Open AI call. | ||
async fn accept_ai_query( | ||
query: String, | ||
query_range: text::Range<Byte>, | ||
this: ThisNode, | ||
graph: Handle, | ||
notifier: notification::Publisher<Notification>, | ||
) -> FallibleResult { | ||
let vis_ptr = QualifiedMethodPointer::from_qualified_text( | ||
"Standard.Visualization.AI", | ||
"Standard.Visualization.AI", | ||
"build_ai_prompt", | ||
)?; | ||
let vis = Visualization::new(this.id, vis_ptr, vec![]); | ||
let mut result = graph.attach_visualization(vis.clone()).await?; | ||
let next = result.next().await.ok_or(NoAIVisualizationDataReceived)?; | ||
let prompt = std::str::from_utf8(&next)?; | ||
let prompt_with_goal = prompt.replace(Self::AI_GOAL_PLACEHOLDER, &query); | ||
graph.detach_visualization(vis.id).await?; | ||
let completion = graph.get_ai_completion(&prompt_with_goal, Self::AI_STOP_SEQUENCE).await?; | ||
notifier.publish(Notification::AISuggestionUpdated(completion, query_range)).await; | ||
Ok(()) | ||
} | ||
|
||
/// Handles AI queries (i.e. searcher input starting with `"AI:"`). Doesn't | ||
/// do anything if the query doesn't end with a specified "accept" | ||
/// sequence. Otherwise, calls `Self::accept_ai_query` to perform the final | ||
/// replacement. | ||
fn handle_ai_query(&self, query: String) -> FallibleResult { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why. |
||
let len = query.as_bytes().len(); | ||
let range = text::Range::new(Byte::from(0), Byte::from(len)); | ||
let query = query.trim_start_matches(Self::AI_QUERY_PREFIX); | ||
if !query.ends_with(Self::AI_QUERY_ACCEPT_TOKEN) { | ||
return Ok(()); | ||
} | ||
let query = query.trim_end_matches(Self::AI_QUERY_ACCEPT_TOKEN).trim().to_string(); | ||
let this = self.this_arg.clone(); | ||
if this.is_none() { | ||
return Err(CannotRunWithoutThisArgument.into()); | ||
} | ||
let this = this.as_ref().as_ref().unwrap().clone(); | ||
let graph = self.graph.clone_ref(); | ||
let notifier = self.notifier.clone_ref(); | ||
executor::global::spawn(async move { | ||
if let Err(e) = Self::accept_ai_query(query, range, this, graph, notifier).await { | ||
error!("error when handling AI query: {e}"); | ||
} | ||
}); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Set the Searcher Input. | ||
/// | ||
/// This function should be called each time user modifies Searcher input in view. It may result | ||
/// in a new action list (the appropriate notification will be emitted). | ||
#[profile(Debug)] | ||
pub fn set_input(&self, new_input: String, cursor_position: Byte) -> FallibleResult { | ||
if new_input.starts_with(Self::AI_QUERY_PREFIX) { | ||
return self.handle_ai_query(new_input); | ||
} | ||
debug!("Manually setting input to {new_input} with cursor position {cursor_position}"); | ||
let parsed_input = input::Input::parse(self.ide.parser(), new_input, cursor_position); | ||
let new_context = parsed_input.context().map(|ctx| ctx.into_ast().repr()); | ||
|
@@ -1309,7 +1388,6 @@ fn component_list_for_literal( | |
} | ||
|
||
|
||
|
||
// ============= | ||
// === Tests === | ||
// ============= | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from Standard.Base import all | ||
import Standard.Table.Data.Table.Table | ||
|
||
goal_placeholder = "__$$GOAL$$__" | ||
|
||
Table.build_ai_prompt self = | ||
ops = ["aggregate","filter_by_expression","order_by","row_count","set","select_columns","transpose","join"] | ||
aggs = ["Count","Average","Sum","Median","First","Last","Maximum","Minimum"] | ||
joins = ["Inner","Left_Outer","Right_Outer","Full","Left_Exclusive","Right_Exclusive"] | ||
examples = """ | ||
Table["id","category","Unit Price","Stock"];goal=get product count by category==>>`aggregate [Aggregate_Column.Group_By "category", Aggregate_Column.Count Nothing]` | ||
Table["ID","Unit Price","Stock"];goal=order by how many items are available==>>`order_by ["Stock"]` | ||
Table["Name","Enrolled Year"];goal=select people who enrolled between 2015 and 2018==>>`filter_by_expression "[Enrolled Year] >= 2015 && [Enrolled Year] <= 2018` | ||
Table["Number of items","client name","city","unit price"];goal=compute the total value of each order==>>`set "[Number of items] * [unit price]" "total value"` | ||
Table["Number of items","client name","CITY","unit price","total value"];goal=compute the average order value by city==>>`aggregate [Aggregate_Column.Group_By "CITY", Aggregate_Column.Average "total value"]` | ||
Table["Area Code", "number"];goal=get full phone numbers==>>`set "'+1 (' + [Area Code] + ') ' + [number]" "full phone number"` | ||
Table["Name","Grade","Subject"];goal=rank students by their average grade==>>`aggregate [Aggregate_Column.Group_By "Name", Aggregate_Column.Average "Grade" "Average Grade"] . order_by [Sort_Column.Name "Average Grade" Sort_Direction.Descending]` | ||
Table["Country","Prime minister name","2018","2019","2020","2021"];goal=pivot yearly GDP values to rows==>>`transpose ["Country", "Prime minister name"] "Year" "GDP"` | ||
Table["Size","Weight","Width","stuff","thing"];goal=only select size and thing of each record==>>`select_columns ["Size", "thing"]` | ||
Table["ID","Name","Count"];goal=join it with var_17==>>`join var_17 Join_Kind.Inner` | ||
ops_prompt = "Operations available on Table are: " + (ops . join ",") | ||
aggs_prompt = "Available ways to aggregate a column are: " + (aggs . join ",") | ||
joins_prompt = "Available join kinds are: " + (joins . join ",") | ||
base_prompt = ops_prompt + '\n' + aggs_prompt + '\n' + joins_prompt + '\n' + examples | ||
columns = self.column_names . map .to_text . join "," "Table[" "];" | ||
goal_line = "goal=" + goal_placeholder + "==>>`" | ||
base_prompt + '\n' + columns + goal_line | ||
|
||
Any.build_ai_prompt self = "````" | ||
|
||
build_ai_prompt subject = subject.build_ai_prompt |
18 changes: 18 additions & 0 deletions
18
engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package org.enso.languageserver.ai | ||
|
||
import org.enso.jsonrpc.{HasParams, HasResult, Method} | ||
|
||
case object AICompletion extends Method("ai/completion") { | ||
case class Params(prompt: String, stopSequence: String) | ||
case class Result(code: String) | ||
|
||
implicit val hasParams: HasParams.Aux[this.type, AICompletion.Params] = | ||
new HasParams[this.type] { | ||
type Params = AICompletion.Params | ||
} | ||
|
||
implicit val hasResult: HasResult.Aux[this.type, AICompletion.Result] = | ||
new HasResult[this.type] { | ||
type Result = AICompletion.Result | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why.