diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d0ff80e2f9a..3f5a232e9085 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -132,6 +132,8 @@ quickly understand each button's function. - [File associations are created on Windows and macOS][6077]. This allows opening Enso files by double-clicking them in the file explorer. +- [AI-powered code completions][5910]. It is now possible to get AI-powered + completions when using node searcher with Tables. - [Added capability to create node widgets with complex UI][6347]. Node widgets such as dropdown can now be placed in the node and affect the code text flow. - [The IDE UI element for selecting the execution mode of the project is now @@ -184,6 +186,7 @@ - [Performance and readability of documentation panel was improved][6893]. The documentation is now split into separate pages, which are much smaller. +[5910]: https://github.com/enso-org/enso/pull/5910 [6279]: https://github.com/enso-org/enso/pull/6279 [6421]: https://github.com/enso-org/enso/pull/6421 [6530]: https://github.com/enso-org/enso/pull/6530 diff --git a/app/gui/controller/engine-protocol/src/language_server.rs b/app/gui/controller/engine-protocol/src/language_server.rs index eb9c1f3b6821..e250d712cc42 100644 --- a/app/gui/controller/engine-protocol/src/language_server.rs +++ b/app/gui/controller/engine-protocol/src/language_server.rs @@ -206,6 +206,10 @@ trait API { #[MethodInput=VcsRestoreInput, rpc_name="vcs/restore"] fn restore_vcs(&self, root: Path, commit_id: Option) -> response::RestoreVcs; + /// An OpenAI-powered completion to the given prompt, with the given stop sequence. + #[MethodInput=AiCompletionInput, rpc_name="ai/completion"] + fn ai_completion(&self, prompt: String, stop_sequence: String) -> response::AiCompletion; + /// Set the execution environment of the context for future evaluations. 
#[MethodInput=SetModeInput, rpc_name="executionContext/setExecutionEnvironment"] fn set_execution_environment(&self, context_id: ContextId, execution_environment: ExecutionEnvironment) -> (); diff --git a/app/gui/controller/engine-protocol/src/language_server/response.rs b/app/gui/controller/engine-protocol/src/language_server/response.rs index 6c3387795f30..4673e0b3d16f 100644 --- a/app/gui/controller/engine-protocol/src/language_server/response.rs +++ b/app/gui/controller/engine-protocol/src/language_server/response.rs @@ -99,6 +99,14 @@ pub struct Completion { pub current_version: SuggestionsDatabaseVersion, } +/// Response of `ai/completion` method. +#[derive(Hash, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(missing_docs)] +pub struct AiCompletion { + pub code: String, +} + /// Response of `get_component_groups` method. #[derive(Hash, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/app/gui/src/controller/graph/executed.rs b/app/gui/src/controller/graph/executed.rs index 2c8f4f676d12..ba49d4be70ac 100644 --- a/app/gui/src/controller/graph/executed.rs +++ b/app/gui/src/controller/graph/executed.rs @@ -158,6 +158,11 @@ impl Handle { self.execution_ctx.attach_visualization(visualization).await } + /// See [`model::ExecutionContext::get_ai_completion`]. + pub async fn get_ai_completion(&self, code: &str, stop: &str) -> FallibleResult { + self.execution_ctx.get_ai_completion(code, stop).await + } + /// See [`model::ExecutionContext::modify_visualization`]. 
pub fn modify_visualization( &self, diff --git a/app/gui/src/controller/searcher.rs b/app/gui/src/controller/searcher.rs index bcf19b89ad15..210cac59ca52 100644 --- a/app/gui/src/controller/searcher.rs +++ b/app/gui/src/controller/searcher.rs @@ -36,10 +36,12 @@ pub mod breadcrumbs; pub mod component; pub mod input; +use crate::controller::graph::executed::Handle; +use crate::model::execution_context::QualifiedMethodPointer; +use crate::model::execution_context::Visualization; pub use action::Action; - // ================= // === Constants === // ================= @@ -82,6 +84,16 @@ pub struct NotSupported { #[fail(display = "An action cannot be executed when searcher is in \"edit node\" mode.")] pub struct CannotExecuteWhenEditingNode; +#[allow(missing_docs)] +#[derive(Copy, Clone, Debug, Fail)] +#[fail(display = "An action cannot be executed when searcher is run without `this` argument.")] +pub struct CannotRunWithoutThisArgument; + +#[allow(missing_docs)] +#[derive(Copy, Clone, Debug, Fail)] +#[fail(display = "No visualization data received for an AI suggestion.")] +pub struct NoAIVisualizationDataReceived; + #[allow(missing_docs)] #[derive(Copy, Clone, Debug, Fail)] #[fail(display = "Cannot commit expression in current mode ({:?}).", mode)] @@ -96,14 +108,15 @@ pub struct CannotCommitExpression { // ===================== /// The notification emitted by Searcher Controller -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Notification { /// A new Suggestion list is available. NewActionList, + /// Code should be inserted by means of using an AI autocompletion. 
+ AISuggestionUpdated(String, text::Range), } - // =================== // === Suggestions === // =================== @@ -550,12 +563,78 @@ impl Searcher { self.notifier.notify(Notification::NewActionList); } + const AI_QUERY_PREFIX: &'static str = "AI:"; + const AI_QUERY_ACCEPT_TOKEN: &'static str = "#"; + const AI_STOP_SEQUENCE: &'static str = "`"; + const AI_GOAL_PLACEHOLDER: &'static str = "__$$GOAL$$__"; + + /// Accepts the current AI query and exchanges it for actual expression. + /// To accomplish this, it performs the following steps: + /// 1. Attaches a visualization to `this`, calling `AI.build_ai_prompt`, to + /// get a data-specific prompt for Open AI; + /// 2. Sends the prompt to the Open AI backend proxy, along with the user + /// query. + /// 3. Replaces the query with the result of the Open AI call. + async fn accept_ai_query( + query: String, + query_range: text::Range, + this: ThisNode, + graph: Handle, + notifier: notification::Publisher, + ) -> FallibleResult { + let vis_ptr = QualifiedMethodPointer::from_qualified_text( + "Standard.Visualization.AI", + "Standard.Visualization.AI", + "build_ai_prompt", + )?; + let vis = Visualization::new(this.id, vis_ptr, vec![]); + let mut result = graph.attach_visualization(vis.clone()).await?; + let next = result.next().await.ok_or(NoAIVisualizationDataReceived)?; + let prompt = std::str::from_utf8(&next)?; + let prompt_with_goal = prompt.replace(Self::AI_GOAL_PLACEHOLDER, &query); + graph.detach_visualization(vis.id).await?; + let completion = graph.get_ai_completion(&prompt_with_goal, Self::AI_STOP_SEQUENCE).await?; + notifier.publish(Notification::AISuggestionUpdated(completion, query_range)).await; + Ok(()) + } + + /// Handles AI queries (i.e. searcher input starting with `"AI:"`). Doesn't + /// do anything if the query doesn't end with a specified "accept" + /// sequence. Otherwise, calls `Self::accept_ai_query` to perform the final + /// replacement. 
+ fn handle_ai_query(&self, query: String) -> FallibleResult { + let len = query.as_bytes().len(); + let range = text::Range::new(Byte::from(0), Byte::from(len)); + let query = query.trim_start_matches(Self::AI_QUERY_PREFIX); + if !query.ends_with(Self::AI_QUERY_ACCEPT_TOKEN) { + return Ok(()); + } + let query = query.trim_end_matches(Self::AI_QUERY_ACCEPT_TOKEN).trim().to_string(); + let this = self.this_arg.clone(); + if this.is_none() { + return Err(CannotRunWithoutThisArgument.into()); + } + let this = this.as_ref().as_ref().unwrap().clone(); + let graph = self.graph.clone_ref(); + let notifier = self.notifier.clone_ref(); + executor::global::spawn(async move { + if let Err(e) = Self::accept_ai_query(query, range, this, graph, notifier).await { + error!("error when handling AI query: {e}"); + } + }); + + Ok(()) + } + /// Set the Searcher Input. /// /// This function should be called each time user modifies Searcher input in view. It may result /// in a new action list (the appropriate notification will be emitted). #[profile(Debug)] pub fn set_input(&self, new_input: String, cursor_position: Byte) -> FallibleResult { + if new_input.starts_with(Self::AI_QUERY_PREFIX) { + return self.handle_ai_query(new_input); + } debug!("Manually setting input to {new_input} with cursor position {cursor_position}"); let parsed_input = input::Input::parse(self.ide.parser(), new_input, cursor_position); let new_context = parsed_input.context().map(|ctx| ctx.into_ast().repr()); @@ -1309,7 +1388,6 @@ fn component_list_for_literal( } - // ============= // === Tests === // ============= diff --git a/app/gui/src/model/execution_context.rs b/app/gui/src/model/execution_context.rs index aa7cea01b44f..ee2263ed0752 100644 --- a/app/gui/src/model/execution_context.rs +++ b/app/gui/src/model/execution_context.rs @@ -459,7 +459,6 @@ pub trait API: Debug { FallibleResult>, >; - /// Detach the visualization from this execution context. 
#[allow(clippy::needless_lifetimes)] // Note: Needless lifetimes fn detach_visualization<'a>( @@ -498,6 +497,13 @@ pub trait API: Debug { futures::future::join_all(detach_actions).boxed_local() } + /// Get an AI completion for the given `prompt`, with specified `stop` sequence. + fn get_ai_completion<'a>( + &'a self, + prompt: &str, + stop: &str, + ) -> BoxFuture<'a, FallibleResult>; + /// Interrupt the program execution. #[allow(clippy::needless_lifetimes)] // Note: Needless lifetimes fn interrupt<'a>(&'a self) -> BoxFuture<'a, FallibleResult>; diff --git a/app/gui/src/model/execution_context/plain.rs b/app/gui/src/model/execution_context/plain.rs index b83fc3e9c7a5..db1bb28781bf 100644 --- a/app/gui/src/model/execution_context/plain.rs +++ b/app/gui/src/model/execution_context/plain.rs @@ -282,6 +282,14 @@ impl model::execution_context::API for ExecutionContext { } } + fn get_ai_completion<'a>( + &'a self, + _prompt: &str, + _stop: &str, + ) -> LocalBoxFuture<'a, FallibleResult> { + futures::future::ready(Ok("".to_string())).boxed_local() + } + fn interrupt(&self) -> BoxFuture { futures::future::ready(Ok(())).boxed_local() } diff --git a/app/gui/src/model/execution_context/synchronized.rs b/app/gui/src/model/execution_context/synchronized.rs index 6b2d2c4a3e23..46ad45f1d853 100644 --- a/app/gui/src/model/execution_context/synchronized.rs +++ b/app/gui/src/model/execution_context/synchronized.rs @@ -292,6 +292,19 @@ impl model::execution_context::API for ExecutionContext { self.model.dispatch_visualization_update(visualization_id, data) } + fn get_ai_completion<'a>( + &'a self, + prompt: &str, + stop: &str, + ) -> BoxFuture<'a, FallibleResult> { + self.language_server + .client + .ai_completion(&prompt.to_string(), &stop.to_string()) + .map(|result| result.map(|completion| completion.code).map_err(Into::into)) + .boxed_local() + } + + fn interrupt(&self) -> BoxFuture { async move { self.language_server.client.interrupt(&self.id).await?; diff --git 
a/app/gui/src/presenter/searcher.rs b/app/gui/src/presenter/searcher.rs index 01a5d9099cc6..c7656914c7b9 100644 --- a/app/gui/src/presenter/searcher.rs +++ b/app/gui/src/presenter/searcher.rs @@ -363,9 +363,12 @@ impl Searcher { let weak_model = Rc::downgrade(&model); let notifications = model.controller.subscribe(); + let graph = model.view.graph().clone(); spawn_stream_handler(weak_model, notifications, move |notification, _| { match notification { Notification::NewActionList => action_list_changed.emit(()), + Notification::AISuggestionUpdated(expr, range) => + graph.edit_node_expression((input_view, range, ImString::new(expr))), }; std::future::ready(()) }); diff --git a/app/gui/view/src/project.rs b/app/gui/view/src/project.rs index cb80d2497e58..a38e418a7191 100644 --- a/app/gui/view/src/project.rs +++ b/app/gui/view/src/project.rs @@ -358,7 +358,7 @@ impl View { .init_style_toggle_frp() .init_fullscreen_visualization_frp() .init_debug_mode_frp() - .init_shortcut_observer() + .init_shortcut_observer(app) } fn init_top_bar_frp(self, scene: &Scene) -> Self { @@ -659,10 +659,10 @@ impl View { self } - fn init_shortcut_observer(self) -> Self { + fn init_shortcut_observer(self, app: &Application) -> Self { let frp = &self.frp; frp::extend! 
{ network - frp.source.current_shortcut <+ self.model.app.shortcuts.currently_handled; + frp.source.current_shortcut <+ app.shortcuts.currently_handled; } self diff --git a/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso b/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso new file mode 100644 index 000000000000..6b7187ea4c6d --- /dev/null +++ b/distribution/lib/Standard/Visualization/0.0.0-dev/src/AI.enso @@ -0,0 +1,31 @@ +from Standard.Base import all +import Standard.Table.Data.Table.Table + +goal_placeholder = "__$$GOAL$$__" + +Table.build_ai_prompt self = + ops = ["aggregate","filter_by_expression","order_by","row_count","set","select_columns","transpose","join"] + aggs = ["Count","Average","Sum","Median","First","Last","Maximum","Minimum"] + joins = ["Inner","Left_Outer","Right_Outer","Full","Left_Exclusive","Right_Exclusive"] + examples = """ + Table["id","category","Unit Price","Stock"];goal=get product count by category==>>`aggregate [Aggregate_Column.Group_By "category", Aggregate_Column.Count Nothing]` + Table["ID","Unit Price","Stock"];goal=order by how many items are available==>>`order_by ["Stock"]` + Table["Name","Enrolled Year"];goal=select people who enrolled between 2015 and 2018==>>`filter_by_expression "[Enrolled Year] >= 2015 && [Enrolled Year] <= 2018"` + Table["Number of items","client name","city","unit price"];goal=compute the total value of each order==>>`set "[Number of items] * [unit price]" "total value"` + Table["Number of items","client name","CITY","unit price","total value"];goal=compute the average order value by city==>>`aggregate [Aggregate_Column.Group_By "CITY", Aggregate_Column.Average "total value"]` + Table["Area Code", "number"];goal=get full phone numbers==>>`set "'+1 (' + [Area Code] + ') ' + [number]" "full phone number"` + Table["Name","Grade","Subject"];goal=rank students by their average grade==>>`aggregate [Aggregate_Column.Group_By "Name", Aggregate_Column.Average "Grade" "Average Grade"] . 
order_by [Sort_Column.Name "Average Grade" Sort_Direction.Descending]` + Table["Country","Prime minister name","2018","2019","2020","2021"];goal=pivot yearly GDP values to rows==>>`transpose ["Country", "Prime minister name"] "Year" "GDP"` + Table["Size","Weight","Width","stuff","thing"];goal=only select size and thing of each record==>>`select_columns ["Size", "thing"]` + Table["ID","Name","Count"];goal=join it with var_17==>>`join var_17 Join_Kind.Inner` + ops_prompt = "Operations available on Table are: " + (ops . join ",") + aggs_prompt = "Available ways to aggregate a column are: " + (aggs . join ",") + joins_prompt = "Available join kinds are: " + (joins . join ",") + base_prompt = ops_prompt + '\n' + aggs_prompt + '\n' + joins_prompt + '\n' + examples + columns = self.column_names . map .to_text . join "," "Table[" "];" + goal_line = "goal=" + goal_placeholder + "==>>`" + base_prompt + '\n' + columns + goal_line + +Any.build_ai_prompt self = "````" + +build_ai_prompt subject = subject.build_ai_prompt diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala new file mode 100644 index 000000000000..953f3d058f03 --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/ai/AICompletion.scala @@ -0,0 +1,18 @@ +package org.enso.languageserver.ai + +import org.enso.jsonrpc.{HasParams, HasResult, Method} + +case object AICompletion extends Method("ai/completion") { + case class Params(prompt: String, stopSequence: String) + case class Result(code: String) + + implicit val hasParams: HasParams.Aux[this.type, AICompletion.Params] = + new HasParams[this.type] { + type Params = AICompletion.Params + } + + implicit val hasResult: HasResult.Aux[this.type, AICompletion.Result] = + new HasResult[this.type] { + type Result = AICompletion.Result + } +} diff --git 
a/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala b/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala index 07c9d7e45bb0..f81219b81cf3 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/boot/MainModule.scala @@ -80,6 +80,10 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: LogLevel) { ContentRoot.Project(serverConfig.contentRootUuid), new File(serverConfig.contentRootPath) ) + + private val openAiKey = sys.env.get("OPENAI_API_KEY") + private val openAiCfg = openAiKey.map(AICompletionConfig) + val languageServerConfig = Config( contentRoot, FileManagerConfig(timeout = 3.seconds), @@ -92,7 +96,8 @@ class MainModule(serverConfig: LanguageServerConfig, logLevel: LogLevel) { ExecutionContextConfig(), directoriesConfig, serverConfig.profilingConfig, - serverConfig.startupConfig + serverConfig.startupConfig, + openAiCfg ) log.trace("Created Language Server config [{}].", languageServerConfig) diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/data/Config.scala b/engine/language-server/src/main/scala/org/enso/languageserver/data/Config.scala index 46f201189a1d..c319337dcdbf 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/data/Config.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/data/Config.scala @@ -69,6 +69,8 @@ case class VcsManagerConfig( Path.of(ProjectDirectoriesConfig.DataDirectory) } +case class AICompletionConfig(apiKey: String) + object VcsManagerConfig { def apply(asyncInit: Boolean = true): VcsManagerConfig = VcsManagerConfig(initTimeout = 5.seconds, 5.seconds, asyncInit) @@ -153,7 +155,8 @@ case class Config( executionContext: ExecutionContextConfig, directories: ProjectDirectoriesConfig, profiling: ProfilingConfig, - startup: StartupConfig + startup: StartupConfig, + aiCompletionConfig: 
Option[AICompletionConfig] ) extends ToLogString { /** @inheritdoc */ diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala index 3f8878c45bce..8a4d44df63d3 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonConnectionController.scala @@ -7,6 +7,7 @@ import com.typesafe.scalalogging.LazyLogging import org.enso.cli.task.ProgressUnit import org.enso.cli.task.notifications.TaskNotificationApi import org.enso.jsonrpc._ +import org.enso.languageserver.ai.AICompletion import org.enso.languageserver.boot.resource.InitializationComponent import org.enso.languageserver.capability.CapabilityApi.{ AcquireCapability, @@ -500,6 +501,9 @@ class JsonConnectionController( .props(requestTimeout, suggestionsHandler), InvalidateSuggestionsDatabase -> search.InvalidateSuggestionsDatabaseHandler .props(requestTimeout, suggestionsHandler), + AICompletion -> ai.AICompletionHandler.props( + languageServerConfig.aiCompletionConfig + ), Completion -> search.CompletionHandler .props(requestTimeout, suggestionsHandler), ExecuteExpression -> ExecuteExpressionHandler diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala index 2ee1c019df1d..69ccfc8c85f0 100644 --- a/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala +++ b/engine/language-server/src/main/scala/org/enso/languageserver/protocol/json/JsonRpc.scala @@ -7,6 +7,7 @@ import org.enso.cli.task.notifications.TaskNotificationApi.{ TaskStarted } import org.enso.jsonrpc.Protocol +import org.enso.languageserver.ai.AICompletion import 
org.enso.languageserver.capability.CapabilityApi.{ AcquireCapability, ForceReleaseCapability, @@ -79,6 +80,7 @@ object JsonRpc { .registerRequest(GetSuggestionsDatabaseVersion) .registerRequest(InvalidateSuggestionsDatabase) .registerRequest(Completion) + .registerRequest(AICompletion) .registerRequest(RenameProject) .registerRequest(ProjectInfo) .registerRequest(EditionsListAvailable) diff --git a/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala new file mode 100644 index 000000000000..2266b34214c7 --- /dev/null +++ b/engine/language-server/src/main/scala/org/enso/languageserver/requesthandler/ai/AICompletionHandler.scala @@ -0,0 +1,113 @@ +package org.enso.languageserver.requesthandler.ai + +import akka.actor.{Actor, ActorRef, Props} +import com.typesafe.scalalogging.LazyLogging +import org.enso.jsonrpc.{Errors, Id, Request, ResponseError, ResponseResult} +import org.enso.languageserver.ai.AICompletion +import org.enso.languageserver.util.UnhandledLogging +import akka.http.scaladsl.model._ +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.OAuth2BearerToken +import akka.pattern.PipeToSupport +import akka.stream.Materializer +import akka.util.ByteString +import io.circe.Json +import org.enso.languageserver.data.AICompletionConfig + +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.FiniteDuration + +class AICompletionHandler(cfg: AICompletionConfig) + extends Actor + with LazyLogging + with UnhandledLogging + with PipeToSupport { + override def receive: Receive = requestStage + + case class AIResponse(status: StatusCode, data: ByteString) + + val http = Http(context.system) + implicit val ec: ExecutionContext = context.dispatcher + implicit val materializer: Materializer = Materializer(context) + + private def requestStage: Receive = { + case 
Request(AICompletion, id, AICompletion.Params(prompt, stop)) => + val body = Json.fromFields( + Seq( + ("model", Json.fromString("text-davinci-003")), + ("prompt", Json.fromString(prompt)), + ("stop", Json.fromString(stop)), + ("temperature", Json.fromDoubleOrNull(0)), + ("max_tokens", Json.fromInt(64)) + ) + ) + val req = + HttpRequest( + uri = "https://api.openai.com/v1/completions", + method = HttpMethods.POST, + headers = Seq( + headers.Authorization(OAuth2BearerToken(cfg.apiKey)) + ), + entity = HttpEntity(ContentTypes.`application/json`, body.noSpaces) + ) + + http + .singleRequest(req) + .flatMap(response => { + response.entity + .toStrict(FiniteDuration(10, "s")) + .map(e => { + AIResponse(response.status, e.data) + }) + }) + .pipeTo(self) + context.become(awaitingCompletionResponse(id, sender())) + } + + private def awaitingCompletionResponse(id: Id, replyTo: ActorRef): Receive = { + case AIResponse(StatusCodes.OK, data) => + val response = + for { + parsed <- io.circe.parser.parse(data.utf8String).toOption + obj <- parsed.asObject + choices <- obj("choices") + choicesVec <- choices.asArray + firstChoice <- choicesVec.headOption + firstChoiceObj <- firstChoice.asObject + firstChoiceText <- firstChoiceObj("text") + firstChoiceTextStr <- firstChoiceText.asString + } yield ResponseResult( + AICompletion, + id, + AICompletion.Result(firstChoiceTextStr) + ) + val handledErrors = + response.getOrElse(ResponseError(Some(id), Errors.ServiceError)) + replyTo ! handledErrors + case AIResponse(status, data) => + replyTo ! ResponseError( + Some(id), + Errors.UnknownError(status.intValue(), data.utf8String, None) + ) + } +} + +class UnsupportedHandler extends Actor with LazyLogging with UnhandledLogging { + override def receive: Receive = { case Request(AICompletion, id, _) => + sender() ! 
ResponseError( + Some(id), + Errors.MethodNotFound + ) + + } +} + +object AICompletionHandler { + def props(cfg: Option[AICompletionConfig]): Props = cfg + .map(conf => + Props( + new AICompletionHandler(conf) + ) + ) + .getOrElse(Props(new UnsupportedHandler())) +} diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/boot/resource/RepoInitializationSpec.scala b/engine/language-server/src/test/scala/org/enso/languageserver/boot/resource/RepoInitializationSpec.scala index e800c2b622d3..b1420f8b74ef 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/boot/resource/RepoInitializationSpec.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/boot/resource/RepoInitializationSpec.scala @@ -197,7 +197,8 @@ class RepoInitializationSpec ExecutionContextConfig(requestTimeout = 3.seconds.dilated), ProjectDirectoriesConfig.initialize(root.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) } diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/filemanager/ContentRootManagerSpec.scala b/engine/language-server/src/test/scala/org/enso/languageserver/filemanager/ContentRootManagerSpec.scala index 1df1c5e48e78..3e5b2704a4d3 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/filemanager/ContentRootManagerSpec.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/filemanager/ContentRootManagerSpec.scala @@ -50,7 +50,8 @@ class ContentRootManagerSpec ExecutionContextConfig(requestTimeout = 3.seconds.dilated), ProjectDirectoriesConfig.initialize(root.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) rootActor = system.actorOf(ContentRootManagerActor.props(config)) rootManager = new ContentRootManagerWrapper(config, rootActor) diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/runtime/ContextEventsListenerSpec.scala 
b/engine/language-server/src/test/scala/org/enso/languageserver/runtime/ContextEventsListenerSpec.scala index 98659a6ac00f..3f4ce71825bf 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/runtime/ContextEventsListenerSpec.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/runtime/ContextEventsListenerSpec.scala @@ -430,7 +430,8 @@ class ContextEventsListenerSpec ExecutionContextConfig(requestTimeout = 3.seconds), ProjectDirectoriesConfig.initialize(root.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) } diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/search/SuggestionsHandlerSpec.scala b/engine/language-server/src/test/scala/org/enso/languageserver/search/SuggestionsHandlerSpec.scala index a9f9dfa2abbe..6ebccf51717c 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/search/SuggestionsHandlerSpec.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/search/SuggestionsHandlerSpec.scala @@ -994,7 +994,8 @@ class SuggestionsHandlerSpec ExecutionContextConfig(requestTimeout = 3.seconds), ProjectDirectoriesConfig.initialize(root.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) } diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala index 65e1922adc20..1a141650f7fd 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/binary/BaseBinaryServerTest.scala @@ -50,7 +50,8 @@ class BaseBinaryServerTest extends BinaryServerTestKit { ExecutionContextConfig(requestTimeout = 3.seconds), ProjectDirectoriesConfig.initialize(testContentRoot.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) 
sys.addShutdownHook(FileUtils.deleteQuietly(testContentRoot.file)) diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/BaseServerTest.scala b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/BaseServerTest.scala index a7fef1577b8c..c05917bb83ca 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/BaseServerTest.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/BaseServerTest.scala @@ -101,7 +101,8 @@ class BaseServerTest ExecutionContextConfig(requestTimeout = 3.seconds), ProjectDirectoriesConfig(testContentRoot.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) override def protocolFactory: ProtocolFactory = diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/FileManagerTest.scala b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/FileManagerTest.scala index 87d304059710..07d0a81c2a24 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/FileManagerTest.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/FileManagerTest.scala @@ -29,7 +29,8 @@ class FileManagerTest extends BaseServerTest with RetrySpec { ExecutionContextConfig(requestTimeout = 3.seconds), ProjectDirectoriesConfig.initialize(testContentRoot.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) } diff --git a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/VcsManagerTest.scala b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/VcsManagerTest.scala index 164c90164f0b..3305b97fffdb 100644 --- a/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/VcsManagerTest.scala +++ b/engine/language-server/src/test/scala/org/enso/languageserver/websocket/json/VcsManagerTest.scala @@ -32,7 +32,8 @@ class VcsManagerTest 
extends BaseServerTest with RetrySpec with FlakySpec { ExecutionContextConfig(requestTimeout = 3.seconds), ProjectDirectoriesConfig.initialize(testContentRoot.file), ProfilingConfig(), - StartupConfig() + StartupConfig(), + None ) }