Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AI Completions #5910

Merged
merged 18 commits into from
Jun 18, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@
quickly understand each button's function.
- [File associations are created on Windows and macOS][6077]. This allows
opening Enso files by double-clicking them in the file explorer.
- [AI-powered code completions][5910]. It is now possible to get AI-powered
  completions when using the node searcher with Tables.
- [Added capability to create node widgets with complex UI][6347]. Node widgets
such as dropdown can now be placed in the node and affect the code text flow.
- [The IDE UI element for selecting the execution mode of the project is now
Expand Down Expand Up @@ -184,6 +186,7 @@
- [Performance and readability of documentation panel was improved][6893]. The
documentation is now split into separate pages, which are much smaller.

[5910]: https://github.com/enso-org/enso/pull/5910
[6279]: https://github.com/enso-org/enso/pull/6279
[6421]: https://github.com/enso-org/enso/pull/6421
[6530]: https://github.com/enso-org/enso/pull/6530
Expand Down
4 changes: 4 additions & 0 deletions app/gui/controller/engine-protocol/src/language_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ trait API {
#[MethodInput=VcsRestoreInput, rpc_name="vcs/restore"]
fn restore_vcs(&self, root: Path, commit_id: Option<String>) -> response::RestoreVcs;

/// An OpenAI-powered completion to the given prompt, with the given stop sequence.
///
/// Sent to the backend as the `ai/completion` JSON-RPC request. The returned
/// [`response::AiCompletion`] carries the generated code. NOTE(review): the
/// `stop_sequence` is presumably where the backend cuts off generation —
/// confirm against the language-server handler.
#[MethodInput=AiCompletionInput, rpc_name="ai/completion"]
fn ai_completion(&self, prompt: String, stop_sequence: String) -> response::AiCompletion;

/// Set the execution environment of the context for future evaluations.
#[MethodInput=SetModeInput, rpc_name="executionContext/setExecutionEnvironment"]
fn set_execution_environment(&self, context_id: ContextId, execution_environment: ExecutionEnvironment) -> ();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ pub struct Completion {
pub current_version: SuggestionsDatabaseVersion,
}

/// Response of `ai/completion` method.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(missing_docs)]
pub struct AiCompletion {
    /// The code returned by the completion backend, intended to replace the
    /// user's query in the searcher input.
    pub code: String,
}

/// Response of `get_component_groups` method.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down
5 changes: 5 additions & 0 deletions app/gui/src/controller/graph/executed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ impl Handle {
self.execution_ctx.attach_visualization(visualization).await
}

/// See [`model::ExecutionContext::get_ai_completion`].
///
/// Thin delegation to the execution context: requests an AI completion for
/// the given `code` prompt, with `stop` as the stop sequence.
pub async fn get_ai_completion(&self, code: &str, stop: &str) -> FallibleResult<String> {
    self.execution_ctx.get_ai_completion(code, stop).await
}

/// See [`model::ExecutionContext::modify_visualization`].
pub fn modify_visualization(
&self,
Expand Down
86 changes: 82 additions & 4 deletions app/gui/src/controller/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ pub mod breadcrumbs;
pub mod component;
pub mod input;

use crate::controller::graph::executed::Handle;
use crate::model::execution_context::QualifiedMethodPointer;
use crate::model::execution_context::Visualization;
pub use action::Action;



// =================
// === Constants ===
// =================
Expand Down Expand Up @@ -82,6 +84,16 @@ pub struct NotSupported {
#[fail(display = "An action cannot be executed when searcher is in \"edit node\" mode.")]
pub struct CannotExecuteWhenEditingNode;

/// Error returned when an AI query is executed while the searcher was opened
/// without a `this` node — the node's data is required to build the
/// data-specific AI prompt.
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Fail)]
#[fail(display = "An action cannot be executed when searcher is run without `this` argument.")]
pub struct CannotRunWithoutThisArgument;

/// Error returned when the `build_ai_prompt` visualization stream ends
/// without yielding any data, so no prompt could be constructed.
#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Fail)]
#[fail(display = "No visualization data received for an AI suggestion.")]
pub struct NoAIVisualizationDataReceived;

#[allow(missing_docs)]
#[derive(Copy, Clone, Debug, Fail)]
#[fail(display = "Cannot commit expression in current mode ({:?}).", mode)]
Expand All @@ -96,14 +108,15 @@ pub struct CannotCommitExpression {
// =====================

/// The notification emitted by Searcher Controller
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Notification {
/// A new Suggestion list is available.
NewActionList,
/// Code should be inserted by means of using an AI autocompletion.
AISuggestionUpdated(String, text::Range<Byte>),
}



// ===================
// === Suggestions ===
// ===================
Expand Down Expand Up @@ -550,12 +563,78 @@ impl Searcher {
self.notifier.notify(Notification::NewActionList);
}

/// Prefix marking searcher input as an AI query (checked in `set_input`).
const AI_QUERY_PREFIX: &'static str = "AI:";
/// Token that must terminate an AI query before it is accepted and executed.
const AI_QUERY_ACCEPT_TOKEN: &'static str = "#";
/// Stop sequence passed to the AI backend along with the prompt.
const AI_STOP_SEQUENCE: &'static str = "`";
/// Placeholder inside the data-specific prompt that is replaced with the
/// user's goal text before the prompt is sent.
const AI_GOAL_PLACEHOLDER: &'static str = "__$$GOAL$$__";

/// Accepts the current AI query and exchanges it for actual expression.
/// To accomplish this, it performs the following steps:
/// 1. Attaches a visualization to `this`, calling `AI.build_ai_prompt`, to
/// get a data-specific prompt for Open AI;
/// 2. Sends the prompt to the Open AI backend proxy, along with the user
/// query.
/// 3. Replaces the query with the result of the Open AI call.
async fn accept_ai_query(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why.

query: String,
query_range: text::Range<Byte>,
this: ThisNode,
graph: Handle,
notifier: notification::Publisher<Notification>,
) -> FallibleResult {
let vis_ptr = QualifiedMethodPointer::from_qualified_text(
"Standard.Visualization.AI",
"Standard.Visualization.AI",
"build_ai_prompt",
)?;
let vis = Visualization::new(this.id, vis_ptr, vec![]);
let mut result = graph.attach_visualization(vis.clone()).await?;
let next = result.next().await.ok_or(NoAIVisualizationDataReceived)?;
let prompt = std::str::from_utf8(&next)?;
let prompt_with_goal = prompt.replace(Self::AI_GOAL_PLACEHOLDER, &query);
graph.detach_visualization(vis.id).await?;
let completion = graph.get_ai_completion(&prompt_with_goal, Self::AI_STOP_SEQUENCE).await?;
notifier.publish(Notification::AISuggestionUpdated(completion, query_range)).await;
Ok(())
}

/// Handles AI queries (i.e. searcher input starting with `"AI:"`). Doesn't
/// do anything if the query doesn't end with a specified "accept"
/// sequence. Otherwise, calls `Self::accept_ai_query` to perform the final
/// replacement.
fn handle_ai_query(&self, query: String) -> FallibleResult {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is PoC, but docs are needed here for future refactoring/development. Please describe shortly what this function does and why.

let len = query.as_bytes().len();
let range = text::Range::new(Byte::from(0), Byte::from(len));
let query = query.trim_start_matches(Self::AI_QUERY_PREFIX);
if !query.ends_with(Self::AI_QUERY_ACCEPT_TOKEN) {
return Ok(());
}
let query = query.trim_end_matches(Self::AI_QUERY_ACCEPT_TOKEN).trim().to_string();
let this = self.this_arg.clone();
if this.is_none() {
return Err(CannotRunWithoutThisArgument.into());
}
let this = this.as_ref().as_ref().unwrap().clone();
let graph = self.graph.clone_ref();
let notifier = self.notifier.clone_ref();
executor::global::spawn(async move {
if let Err(e) = Self::accept_ai_query(query, range, this, graph, notifier).await {
error!("error when handling AI query: {e}");
}
});

Ok(())
}

/// Set the Searcher Input.
///
/// This function should be called each time user modifies Searcher input in view. It may result
/// in a new action list (the appropriate notification will be emitted).
#[profile(Debug)]
pub fn set_input(&self, new_input: String, cursor_position: Byte) -> FallibleResult {
if new_input.starts_with(Self::AI_QUERY_PREFIX) {
return self.handle_ai_query(new_input);
}
debug!("Manually setting input to {new_input} with cursor position {cursor_position}");
let parsed_input = input::Input::parse(self.ide.parser(), new_input, cursor_position);
let new_context = parsed_input.context().map(|ctx| ctx.into_ast().repr());
Expand Down Expand Up @@ -1309,7 +1388,6 @@ fn component_list_for_literal(
}



// =============
// === Tests ===
// =============
Expand Down
8 changes: 7 additions & 1 deletion app/gui/src/model/execution_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ pub trait API: Debug {
FallibleResult<futures::channel::mpsc::UnboundedReceiver<VisualizationUpdateData>>,
>;


/// Detach the visualization from this execution context.
#[allow(clippy::needless_lifetimes)] // Note: Needless lifetimes
fn detach_visualization<'a>(
Expand Down Expand Up @@ -498,6 +497,13 @@ pub trait API: Debug {
futures::future::join_all(detach_actions).boxed_local()
}

/// Get an AI completion for the given `prompt`, with specified `stop` sequence.
fn get_ai_completion<'a>(
&'a self,
prompt: &str,
stop: &str,
) -> BoxFuture<'a, FallibleResult<String>>;

/// Interrupt the program execution.
#[allow(clippy::needless_lifetimes)] // Note: Needless lifetimes
fn interrupt<'a>(&'a self) -> BoxFuture<'a, FallibleResult>;
Expand Down
8 changes: 8 additions & 0 deletions app/gui/src/model/execution_context/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ impl model::execution_context::API for ExecutionContext {
}
}

fn get_ai_completion<'a>(
    &'a self,
    _prompt: &str,
    _stop: &str,
) -> LocalBoxFuture<'a, FallibleResult<String>> {
    // The plain (offline) execution context performs no AI request; it
    // immediately resolves to an empty completion.
    let empty_completion = Ok(String::new());
    futures::future::ready(empty_completion).boxed_local()
}

fn interrupt(&self) -> BoxFuture<FallibleResult> {
futures::future::ready(Ok(())).boxed_local()
}
Expand Down
13 changes: 13 additions & 0 deletions app/gui/src/model/execution_context/synchronized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,19 @@ impl model::execution_context::API for ExecutionContext {
self.model.dispatch_visualization_update(visualization_id, data)
}

fn get_ai_completion<'a>(
    &'a self,
    prompt: &str,
    stop: &str,
) -> BoxFuture<'a, FallibleResult<String>> {
    // Forward the request to the language server's `ai/completion` endpoint
    // and unwrap the response payload into the completed code.
    let request = self.language_server.client.ai_completion(&prompt.to_string(), &stop.to_string());
    let unwrap_code =
        |result: Result<_, _>| result.map_err(Into::into).map(|completion| completion.code);
    request.map(unwrap_code).boxed_local()
}


fn interrupt(&self) -> BoxFuture<FallibleResult> {
async move {
self.language_server.client.interrupt(&self.id).await?;
Expand Down
3 changes: 3 additions & 0 deletions app/gui/src/presenter/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,12 @@ impl Searcher {

let weak_model = Rc::downgrade(&model);
let notifications = model.controller.subscribe();
let graph = model.view.graph().clone();
spawn_stream_handler(weak_model, notifications, move |notification, _| {
match notification {
Notification::NewActionList => action_list_changed.emit(()),
Notification::AISuggestionUpdated(expr, range) =>
graph.edit_node_expression((input_view, range, ImString::new(expr))),
};
std::future::ready(())
});
Expand Down
6 changes: 3 additions & 3 deletions app/gui/view/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ impl View {
.init_style_toggle_frp()
.init_fullscreen_visualization_frp()
.init_debug_mode_frp()
.init_shortcut_observer()
.init_shortcut_observer(app)
}

fn init_top_bar_frp(self, scene: &Scene) -> Self {
Expand Down Expand Up @@ -659,10 +659,10 @@ impl View {
self
}

fn init_shortcut_observer(self) -> Self {
fn init_shortcut_observer(self, app: &Application) -> Self {
let frp = &self.frp;
frp::extend! { network
frp.source.current_shortcut <+ self.model.app.shortcuts.currently_handled;
frp.source.current_shortcut <+ app.shortcuts.currently_handled;
}

self
Expand Down
31 changes: 31 additions & 0 deletions distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from Standard.Base import all
import Standard.Table.Data.Table.Table

## Placeholder substituted with the user's goal text when the prompt is sent.
goal_placeholder = "__$$GOAL$$__"

## Builds a few-shot prompt describing the Table API, example goal->code
   pairs, and this table's column names. The prompt ends with an opened
   backtick right after the goal placeholder, so the model is expected to
   answer with code terminated by a closing backtick (used by the IDE as the
   stop sequence).
Table.build_ai_prompt self =
    ops = ["aggregate","filter_by_expression","order_by","row_count","set","select_columns","transpose","join"]
    aggs = ["Count","Average","Sum","Median","First","Last","Maximum","Minimum"]
    joins = ["Inner","Left_Outer","Right_Outer","Full","Left_Exclusive","Right_Exclusive"]
    examples = """
        Table["id","category","Unit Price","Stock"];goal=get product count by category==>>`aggregate [Aggregate_Column.Group_By "category", Aggregate_Column.Count Nothing]`
        Table["ID","Unit Price","Stock"];goal=order by how many items are available==>>`order_by ["Stock"]`
        Table["Name","Enrolled Year"];goal=select people who enrolled between 2015 and 2018==>>`filter_by_expression "[Enrolled Year] >= 2015 && [Enrolled Year] <= 2018"`
        Table["Number of items","client name","city","unit price"];goal=compute the total value of each order==>>`set "[Number of items] * [unit price]" "total value"`
        Table["Number of items","client name","CITY","unit price","total value"];goal=compute the average order value by city==>>`aggregate [Aggregate_Column.Group_By "CITY", Aggregate_Column.Average "total value"]`
        Table["Area Code", "number"];goal=get full phone numbers==>>`set "'+1 (' + [Area Code] + ') ' + [number]" "full phone number"`
        Table["Name","Grade","Subject"];goal=rank students by their average grade==>>`aggregate [Aggregate_Column.Group_By "Name", Aggregate_Column.Average "Grade" "Average Grade"] . order_by [Sort_Column.Name "Average Grade" Sort_Direction.Descending]`
        Table["Country","Prime minister name","2018","2019","2020","2021"];goal=pivot yearly GDP values to rows==>>`transpose ["Country", "Prime minister name"] "Year" "GDP"`
        Table["Size","Weight","Width","stuff","thing"];goal=only select size and thing of each record==>>`select_columns ["Size", "thing"]`
        Table["ID","Name","Count"];goal=join it with var_17==>>`join var_17 Join_Kind.Inner`
    ops_prompt = "Operations available on Table are: " + (ops . join ",")
    aggs_prompt = "Available ways to aggregate a column are: " + (aggs . join ",")
    joins_prompt = "Available join kinds are: " + (joins . join ",")
    base_prompt = ops_prompt + '\n' + aggs_prompt + '\n' + joins_prompt + '\n' + examples
    columns = self.column_names . map .to_text . join "," "Table[" "];"
    goal_line = "goal=" + goal_placeholder + "==>>`"
    base_prompt + '\n' + columns + goal_line

## Fallback prompt for non-Table values: immediately closes the code fence so
   the model produces no suggestion for unsupported data.
Any.build_ai_prompt self = "````"

## Entry point used by the IDE: dispatches on the runtime type of `subject`.
build_ai_prompt subject = subject.build_ai_prompt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.enso.languageserver.ai

import org.enso.jsonrpc.{HasParams, HasResult, Method}

/** The `ai/completion` JSON-RPC method.
  *
  * Sends a prompt (with a stop sequence) to the AI backend and receives the
  * completed code in response.
  */
case object AICompletion extends Method("ai/completion") {
  /** @param prompt the prompt forwarded to the AI backend
    * @param stopSequence the sequence at which generation should stop
    */
  case class Params(prompt: String, stopSequence: String)
  /** @param code the completed code returned by the backend */
  case class Result(code: String)

  implicit val hasParams: HasParams.Aux[this.type, AICompletion.Params] =
    new HasParams[this.type] {
      type Params = AICompletion.Params
    }

  implicit val hasResult: HasResult.Aux[this.type, AICompletion.Result] =
    new HasResult[this.type] {
      type Result = AICompletion.Result
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: LogLevel) {
ContentRoot.Project(serverConfig.contentRootUuid),
new File(serverConfig.contentRootPath)
)

// The OpenAI API key is sourced from the environment; the AI-completion
// config (and thus the feature) is present only when the key is set.
private val openAiKey = sys.env.get("OPENAI_API_KEY")
private val openAiCfg = openAiKey.map(AICompletionConfig)

val languageServerConfig = Config(
contentRoot,
FileManagerConfig(timeout = 3.seconds),
Expand All @@ -92,7 +96,8 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: LogLevel) {
ExecutionContextConfig(),
directoriesConfig,
serverConfig.profilingConfig,
serverConfig.startupConfig
serverConfig.startupConfig,
openAiCfg
)
log.trace("Created Language Server config [{}].", languageServerConfig)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ case class VcsManagerConfig(
Path.of(ProjectDirectoriesConfig.DataDirectory)
}

case class AICompletionConfig(apiKey: String)

object VcsManagerConfig {
def apply(asyncInit: Boolean = true): VcsManagerConfig =
VcsManagerConfig(initTimeout = 5.seconds, 5.seconds, asyncInit)
Expand Down Expand Up @@ -153,7 +155,8 @@ case class Config(
executionContext: ExecutionContextConfig,
directories: ProjectDirectoriesConfig,
profiling: ProfilingConfig,
startup: StartupConfig
startup: StartupConfig,
aiCompletionConfig: Option[AICompletionConfig]
) extends ToLogString {

/** @inheritdoc */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.typesafe.scalalogging.LazyLogging
import org.enso.cli.task.ProgressUnit
import org.enso.cli.task.notifications.TaskNotificationApi
import org.enso.jsonrpc._
import org.enso.languageserver.ai.AICompletion
import org.enso.languageserver.boot.resource.InitializationComponent
import org.enso.languageserver.capability.CapabilityApi.{
AcquireCapability,
Expand Down Expand Up @@ -500,6 +501,9 @@ class JsonConnectionController(
.props(requestTimeout, suggestionsHandler),
InvalidateSuggestionsDatabase -> search.InvalidateSuggestionsDatabaseHandler
.props(requestTimeout, suggestionsHandler),
AICompletion -> ai.AICompletionHandler.props(
languageServerConfig.aiCompletionConfig
),
Completion -> search.CompletionHandler
.props(requestTimeout, suggestionsHandler),
ExecuteExpression -> ExecuteExpressionHandler
Expand Down
Loading