From 89e1f7cdd156b90f37a66942ab0a743a5df6abe2 Mon Sep 17 00:00:00 2001 From: tharun571 Date: Thu, 12 Sep 2024 23:23:23 +0530 Subject: [PATCH] Add LLM Implementation via xassist --- .gitignore | 4 + CMakeLists.txt | 41 ++++- docs/source/gemini.png | Bin 0 -> 6594 bytes docs/source/index.rst | 1 + docs/source/magics.rst | 41 +++++ environment-wasm-build.yml | 2 +- environment-wasm-host.yml | 2 +- src/xinterpreter.cpp | 6 + src/xmagics/xassist.cpp | 328 +++++++++++++++++++++++++++++++++++++ src/xmagics/xassist.hpp | 26 +++ test/test_interpreter.cpp | 76 +++++++++ 11 files changed, 524 insertions(+), 3 deletions(-) create mode 100644 docs/source/gemini.png create mode 100644 docs/source/magics.rst create mode 100644 src/xmagics/xassist.cpp create mode 100644 src/xmagics/xassist.hpp diff --git a/.gitignore b/.gitignore index 7d268e51..9ed5b1a7 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,7 @@ __pycache__/ build/ bld + +# LLM Implementation +*_api_key.txt +*_chat_history.txt \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index d75628ff..ef82619d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,6 +205,13 @@ set(XEUS_CPP_SRC src/xutils.cpp ) +if(NOT EMSCRIPTEN) + list(APPEND XEUS_CPP_SRC + src/xmagics/xassist.hpp + src/xmagics/xassist.cpp + ) +endif() + if(EMSCRIPTEN) list(APPEND XEUS_CPP_SRC src/xinterpreter_wasm.cpp) endif() @@ -309,9 +316,41 @@ macro(xeus_cpp_create_target target_name linkage output_name) else () set(XEUS_CPP_XEUS_TARGET xeus-static) endif () + + #This is a workaround for the issue with the libcurl target on Windows specifically for xassist + if (WIN32) + # Set the MSVC runtime library + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>DLL") + + # Find libcurl + find_package(CURL REQUIRED) + + # Add CURL_STATICLIB definition if linking statically + if (CURL_STATICLIB) + target_compile_definitions(${target_name} PUBLIC CURL_STATICLIB) + endif() - target_link_libraries(${target_name} PUBLIC ${XEUS_CPP_XEUS_TARGET} clangCppInterOp pugixml argparse::argparse) + # Link against the correct libcurl target + if (CURL_FOUND) + target_include_directories(${target_name} PRIVATE ${CURL_INCLUDE_DIRS}) + target_link_libraries(${target_name} PRIVATE ${CURL_LIBRARIES}) + endif() + # Existing target_link_libraries call, adjusted for clarity + target_link_libraries(${target_name} PUBLIC ${XEUS_CPP_XEUS_TARGET} clangCppInterOp pugixml argparse::argparse) + + # Ensure all linked libraries use the same runtime library + if (MSVC) + target_compile_options(${target_name} PRIVATE "/MD$<$:d>") + endif() + elseif (NOT EMSCRIPTEN) + # Curl initialised specifically for xassist + target_link_libraries(${target_name} PUBLIC ${XEUS_CPP_XEUS_TARGET} clangCppInterOp pugixml argparse::argparse curl) + else () + # TODO : Add curl support for emscripten + target_link_libraries(${target_name} PUBLIC ${XEUS_CPP_XEUS_TARGET} clangCppInterOp pugixml argparse::argparse) + endif() + if (WIN32 OR CYGWIN) # elseif (APPLE) diff --git a/docs/source/gemini.png b/docs/source/gemini.png new file mode 100644 index 0000000000000000000000000000000000000000..f99f86d964fed1272dc540ff436e884060ca1332 GIT binary patch literal 6594 zcma)>cQjl<-^YnwgXo<^^d2Qvi47r9Ls-#8R$b98R*4W@BzmtAEtb{miXJ68t9Mq3 zzOlMjp65O1dCq&@_Q#!j=ge>Bo|!ZEetzHiM(Ap*QIIl|;^5#=JXcrI!@aU#OBn|T5lePyMQ~duaZ!Kmj)Ozl@%O^*aV~m` zgTn-RuA=_rc!Byu5&Jd8O%ktog;z`NeN&%g);}0w@6$kob8t(cDc46_Dmby9O~A$}reB|6YWJIODSTpnPFVk)wBaM(3&7-HEi!m}-!gVAHq_n$s;ce-6&kl{8Uty*B6 ztSjBU(%%)uRz4APeU;E0o2d9fvQutUHr_*;O5&_9vzEhp_7QeSw$As)Xc2iD z9VO?u$_287ur1X!Vo9{M2(FepvQlWspK4sZ{2d)%tX;q~jCX)xieiTvpp=Xrx8{|?Y|N!(+gX)%uX`XKr9a{6#T*pGO82Q#Mh~p z&|!NV&wdpzAvayWDiR;AHgBr^7_V|*CEGTCu@`!rs?&58Pf3LaUnuC5(F!j_}*=>pdp(udtHT?PCT zEnOr}1Z$JmwAFKW6q1wEKns%nw&H^|Q(q$QD-6m?Lq-f! zlMHIA1yCb{G>x}Avbqlz)UQ&w;bNX@s`4jagkQ1u%UEB)&EMW3JM|DXu6yb69L%TV zoml}-m<wWZzgd zLx{>n$#s!IZRsLXaxM}Kg^U^R;SKA=Eie4l_zHQT@@m zU2VUgtfNW)3-DF)8+r7w)ahh{pQ(!MeIEwn{g11pyA-U&iLuJ#{$M^`c2A13%`26i zPGX5smg1o?N$#(&`l!r#zwa-$=kGxnzvi(Te_?;7-j%0q)~J{E9zM@cp{-)+awmwQ;FSUD4v(Va8pm!C z#Ya`?R>#&n!eZox&nj=mD|Qt8CtCVr8r?%9#oDm0c8so|`vjE)N@Y{T(@KNSV=(e$ zTccN<=`J)jclxGY+koA%E}UO1#F~iGTIhESm4yEbTJzmeVH!c~r^B7Jw(3|0s~_aw zdXXMVqLGPBdRLA8&Q|Lzcb8VGqJH7bN$yMpP72Z0cNo)-Qk{c}!FFYz8$ks@p?I%H zg@2zwa;Gd~&UFDmLA~sB3KORgg-sg5Ju-wm5$>wtXGh4>s7C4P^oaW}TZ^2$gZ>z2 z2tFONGUZ=(2a^dYUKzAgBl0c}^1lBDZt+E*1OLqYe6k$fXh?ed?ITTOLL$aapLs~7 zrUSQ`$0ZUg!##I)j6?5(b4`jl#IC5O6qm>-y;z!lq zf-wk}Syc$0hi>@f-W<>jFj*GAeG)jRrVh^e2T_J6p8PM20TY*B36`84*_MY;>Dk*I z{)<81gqzzQ2pU`c#41skmH0X0HUTClK=4r&mam%sv?> zQlv_YnCm#q-nhBInqC%iyLw!2uj4Wf8VLkhu;Y2zJ0NP{hPY|ahrgV41>tdp3_$-U z4DXjS6W_LzKmJSQp99K&qyH1p|D^2yO7xV++|9O!Bl^O!t`ls7WPR$#4Y3E~`Z^7} z_kG%1v`OrneeVb>&v<42@Pzsry1F`?HDjhSu69iM?@Qd%eCk~@qWKQL7U+}C5O8&z zWS0V2e`Op;#vd!pP;rB)J3TR(UugGbaI>gb0kaN!u)$E7oNxE^kOx_SU@`?Y1c$|v z=D-{K?cVzC>@NrNMK`5aaM6n^TzAE>G5lu#@!uS$3ZTX!Vi2zjuA^>MJ6n8ek3V@| zamUV$2S)(w^N|GMX4}^0_sIcu?4seJ0f$3=hC31@$z&qhuc{pxjV3rU-Fffq0BlrQ z$N+D?SoB1gbs0~e%q*pjytBgiPm946Sw!1B)LLLyBJP(nKiw@l6rr`8ZO6=B*1CJ% zzvi>GtcBPD0ybOR{3drIXyawlZaDn+76dsap087DSYuK@jxlbDclqjr@42>32}kQ^ z9St{*aoay~{d2@uKNM?QO>TX*x_5VMxwTwx>%RNl%m;-YNaQSb(iVDC)}$2Rh5cn| zaoULOZF0^c{&ek+t*X1U?r+s;_QCcpl+V8u|9O$+cf4n-G=o;#pMXw@MFnQG4CT1` zSLzJXw@oASR*SYsMj3TXFh;AZt&2*NSG@ZtnNmbRp*{UCl1}l*N^yZ3?4r#-FU5z3 zcbYFc++@&~-WOW}8PAkB5>FAz{avNb`ZQ5Gh;=KtBz9VMM_~j;m)<&!F@jvAQA$c7{zK#83WDFgaYei4D1x2-3^Et>FleQcHryjbG%fS98mG3^1s=)S zfrgf%;-!VW(eEB#U#n4=#-d=A^k;tp18_0T=BC)aKz6mUu=3txH;VTvWk1hrmB|CJ zEhMi`F4(-W9g=%lJF=3*zi)zJs|}KGmceCuj)!JHLN+o(on~TDsA&zs4(#dmTGQ0> z`pr4RV$7c>%`Rh}BJhh@+6`;W+36vS9NB^=8YXggi>`&jH084ly=MMK;+nM)H1Ev# zGyfn1sy6j*4|w{&UV(r_Xs0h zwovNYv-Q*XZnTye*PZW7WLQ1*WijFd5q+!rdEuMXv+p;-orfA%Mr)5!N@Gk zvhJhbZ-?D=qZ-vYoR!&oR?|W`64%Nj7pcR!szYVf@irU@rj)f?QXyTeNhtr;k&;xbUJHQra?z z9U|CGtc4h1<#g8GX5;m+Crzje#knx3h0ua;Hs1 z5oG9jTXNpK!#!w3f6{OF4qSSY3yr%OV8ht`o0PU|tqr(+1-U48+WdomHbz?|#0Hc6 z!2B!)D*d}V&3Z60gF9yDW-Wl=-dYSKmyZzdt%*uW0X)?)k}1O&D|e#t8h@xNy_Dl) zUp{A0>2ny2cwh9zkA5wphVm2+L0voW;vK~3*>R>%Skl!ZyV^q!(IPuTx8FgZNBJ+% zZ{3MdEh1jJF>@f(?E0mXJ03r%jr5<6@Kp8qI~{KC^IKf#dU@c08-pLc{q#oI)RP2% zb4j$!Od{O{p>eC9Qhb&iI72?e$o@@7{Xc%v|1P`!J-qiZVEXj5e)j%fTpp3`SAsLj z6R3bTcN%Yz@IMN-e;NI>?seixbJ9Syf?x98%>IM>kkLeC*+G>aVe)wLvI$ zpz$Z2?3%Eu_8Q#kGV_?)c$v9~$1kS0>VtxyhSq-oc$I3zsZK9w7rJDjo4E4hJ}g7f zJ4%;V%J;vc<=*#<-gM>nFsz^%_Lt5uY&q7s#0b-yYv2>Ov(Y}gcv!t_B}M7o==%G& z(vf`3LrTb{{akQT7+NTQ?e0e%@Pq9;Y`qpBsUfr&oi*|j11WLi-QWWk=98u?F?MYd zrGsOsBxV!%n(GkndYtPg5R4YZi5Hj;;T0nO>y+$QZd}M0#f*zxS}hw_CPofGH5c8+ zUL?U|w;%i0jR!6OKHJI^$BqhVa#@+M(MZZ zHI!ktE4JG}MihyF61tAUD6>4QXkwR`0=zHTqc&!h$%{Pq+V3A?jzcE_Tv+_Elv3rZ z1~SA0f62lcn9t(<)4i=(b=B`?bP+a7Kl!|X_{;Ik|D(Jq0Un|bMV zzxGlKJ7kdhZj||0c!bl@Ld{-6&YL9r@T}tz*+-w8d+ana0g_T1S(sp{4*%SU z%W0cgT$9S4U7_4x`Vuhpsk^?wu?YC5P+oBA%Ct+yvH*S5lIK- zMxb47mCx@EL%AS+yv02mWA9L30R!51)ny=XT(Hk1msko1Aj|U{l+5y<)+!21EFuji2xUB6v zCmuGfwR#c1R5sx>cMC&@)e)RnF&TKYk0f5)1HRYZH5wB3hC)}8(?2Bi1fPdIMC(a0 z(;I!D7I%}3cG@LqZ-C^ZRKEY5s(hkPz4h!Pb{haz(#e3xU8N3H?vHQrc06}ULRnku zGv5RKc_Or|<2%wQ3;-BEEUWqej*NDObU#|08JDMArKzjD^4?zCe#yNh+V>%|QJx#j zZ!)59);(%fd(aRd*Zimv)@KsQe?Abmtt}LDc5GS)0fCNZTczDtbpo5C*S$MScdp(0 z-r@Q537`i&yS%2!3i5Hz2W#W;%olvN#D)g^SsRA>^C8#cdFKv_{s=Ht=$Y$8^B2df zgS?%iN4InJoFD`+zzyY-pUY84-*hDa$9X)XRg1P|FghLWa9(w$ zASV|cV~>h=iSkSHUDbNPk!q&k2v9Pr!(nGp$^8P*B1^j`Z`w82g6Vf>&ZNCX1_>L(bU? zkj5kWXkl4EieTm_!HF~5{1+z=S-CdlI~v$7se>0TQ!iw<9!SUT8(is` zQ1{kl%|6I@k3}DqZ-qIbyvmjBOy}c(jTxV_7pcE@Wzv74JQAt=s+XE_M48M_!)+KC zfgeAVj+y1jfS7vjElXrGam_^WDP9pIB?zRT^>gzwV|3iJ>mDVvnVj^7rLwq3Wg6ACTv;)1-gr>r!4pxKz=X7xPF1hIEglV85Jh8&6~POg7;pXqB}08zshN zl>?KBJdC2T!`1m=LfPhB5Id&g^q@0pSE?Xsb}jJSOo&1`x(hbZ2xq_uaWzfPyhvjDq8 ziapU6-kB2{!iZNiK7Ab+-wz@cZr3SBg=a^DPgmJoFTCfC>vHV5WYiyPh||@5eP#7f znBPEbNH|xO$IQQmP}mE(Q318<8`;xVZE(u@T`8b`ww03hLzwrh5tVITgRklL;gNO~ z#Q;?SS-C`!gB@qqR zFlP{!;InFq+h{(_eP%{Dz4tu-jP5CC5-Ny!fBQQ1oh&qawz~rhyll zJ$+}f!Yqh6;QT(AC!v8>e}SiflxFvZT2eyf`k8Kt;L2)D^Km>gCd{xl;jnfR(lR3! z6r*aq1g;Bm>o4J4FU;rsh9q#3VC3LWx^(g}cCrjj={Y`h=+IMw?6u*uB}54=Km8VK z#O*YyUI)X=usNvOj2&xvvD~Plo^d2(4(w|a3SPUnQ%NJClXDQa_W7N2brU2|vB#Vs z)!Ai>)5`Y% literal 0 HcmV?d00001 diff --git a/docs/source/index.rst b/docs/source/index.rst index 3a7ee4b1..6d59becc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,6 +34,7 @@ The Xeus-Cpp is a Jupyter kernel for the C++ programming language InstallationAndUsage UsingXeus-Cpp tutorials + magics dev-build-options debug FAQ diff --git a/docs/source/magics.rst b/docs/source/magics.rst new file mode 100644 index 00000000..d8e64274 --- /dev/null +++ b/docs/source/magics.rst @@ -0,0 +1,41 @@ +Magics commands +-------------------- + +Magics are special commands for the kernel that are not part of the C++ +programming language. + +There are defined with the symbol ``%`` for a line magic and ``%%`` for a cell +magic. + +Here are the magics available in xeus-cpp. + +%%xassist +======================== + +Leverage the large language models to assist in your development process. Currently supported models are Gemini - gemini-1.5-flash, OpenAI - gpt-3.5-turbo-16k. + +- Save the api key + +.. code:: + + %%xassist model --save-key + key + +- Use the model + +.. code:: + + %%xassist model + prompt + +- Reset model and clear chat history + +.. code:: + + %%xassist model --refresh + +- Example + +.. image:: gemini.png + +A new prompt is sent to the model everytime and the functionality to use previous context will be added soon. \ No newline at end of file diff --git a/environment-wasm-build.yml b/environment-wasm-build.yml index 27d414af..130ba900 100644 --- a/environment-wasm-build.yml +++ b/environment-wasm-build.yml @@ -4,4 +4,4 @@ channels: dependencies: - cmake - emsdk >=3.1.11 - - empack >=2.0.1 + - empack >=2.0.1 \ No newline at end of file diff --git a/environment-wasm-host.yml b/environment-wasm-host.yml index 99db9689..7d310366 100644 --- a/environment-wasm-host.yml +++ b/environment-wasm-host.yml @@ -8,4 +8,4 @@ dependencies: - xeus - CppInterOp>=1.3.0 - cpp-argparse - - pugixml + - pugixml \ No newline at end of file diff --git a/src/xinterpreter.cpp b/src/xinterpreter.cpp index e715d457..0bbc7329 100644 --- a/src/xinterpreter.cpp +++ b/src/xinterpreter.cpp @@ -28,6 +28,9 @@ #include "xinput.hpp" #include "xinspect.hpp" #include "xmagics/os.hpp" +#ifndef EMSCRIPTEN +#include "xmagics/xassist.hpp" +#endif #include "xparser.hpp" #include "xsystem.hpp" @@ -404,5 +407,8 @@ __get_cxx_version () // preamble_manager["magics"].get_cast().register_magic("file", writefile()); // preamble_manager["magics"].get_cast().register_magic("timeit", timeit(&m_interpreter)); // preamble_manager["magics"].get_cast().register_magic("python", pythonexec()); +#ifndef EMSCRIPTEN + preamble_manager["magics"].get_cast().register_magic("xassist", xassist()); +#endif } } diff --git a/src/xmagics/xassist.cpp b/src/xmagics/xassist.cpp new file mode 100644 index 00000000..a2985e7d --- /dev/null +++ b/src/xmagics/xassist.cpp @@ -0,0 +1,328 @@ +/************************************************************************************ + * Copyright (c) 2023, xeus-cpp contributors * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ************************************************************************************/ +#include "xassist.hpp" + +#define CURL_STATICLIB +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::json; + +// TODO: Implement xplugin to separate the magics from the main code. +// TODO: Add support for open-source models. +namespace xcpp +{ + class api_key_manager + { + public: + + static void save_api_key(const std::string& model, const std::string& api_key) + { + std::string api_key_file_path = model + "_api_key.txt"; + std::ofstream out(api_key_file_path); + if (out) + { + out << api_key; + out.close(); + std::cout << "API key saved for model " << model << std::endl; + } + else + { + std::cerr << "Failed to open file for writing API key for model " << model << std::endl; + } + } + + // Method to load the API key for a specific model + static std::string load_api_key(const std::string& model) + { + std::string api_key_file_path = model + "_api_key.txt"; + std::ifstream in(api_key_file_path); + std::string api_key; + if (in) + { + std::getline(in, api_key); + in.close(); + return api_key; + } + + std::cerr << "Failed to open file for reading API key for model " << model << std::endl; + return ""; + } + }; + + class chat_history + { + public: + + static std::string chat(const std::string& model, const std::string& user, const std::string& cell) + { + return append_and_read_back(model, user, "\"" + cell + "\""); + } + + static std::string chat(const std::string& model, const std::string& user, const nlohmann::json& cell) + { + return append_and_read_back(model, user, cell.dump()); + } + + static void refresh(const std::string& model) + { + std::string chat_history_file_path = model + "_chat_history.txt"; + std::ofstream out(chat_history_file_path, std::ios::out); + } + + private: + + static std::string + append_and_read_back(const std::string& model, const std::string& user, const std::string& serialized_cell) + { + std::string chat_history_file_path = model + "_chat_history.txt"; + std::ofstream out; + bool is_empty = is_file_empty(chat_history_file_path); + + out.open(chat_history_file_path, std::ios::app); + if (!out) + { + std::cerr << "Failed to open file for writing chat history for model " << model << std::endl; + return ""; + } + + if (!is_empty) + { + out << ", "; + } + + if (model == "gemini") + { + out << R"({ "role": ")" << user << R"(", "parts": [ { "text": )" << serialized_cell << "}]}\n"; + } + else + { + out << R"({ "role": ")" << user << R"(", "content": )" << serialized_cell << "}\n"; + } + + out.close(); + + return read_file_content(chat_history_file_path); + } + + static bool is_file_empty(const std::string& file_path) + { + std::ifstream file(file_path, std::ios::ate); // Open the file at the end + if (!file) // If the file cannot be opened, it might not exist + { + return true; // Consider non-existent files as empty + } + return file.tellg() == 0; + } + + static std::string read_file_content(const std::string& file_path) + { + std::ifstream in(file_path); + std::stringstream buffer_stream; + buffer_stream << in.rdbuf(); + return buffer_stream.str(); + } + }; + + class curl_helper + { + private: + + CURL* m_curl; + curl_slist* m_headers; + + public: + + curl_helper() + : m_curl(curl_easy_init()) + , m_headers(curl_slist_append(nullptr, "Content-Type: application/json")) + { + } + + ~curl_helper() + { + if (m_curl) + { + curl_easy_cleanup(m_curl); + } + if (m_headers) + { + curl_slist_free_all(m_headers); + } + } + + // Delete copy constructor and copy assignment operator + curl_helper(const curl_helper&) = delete; + curl_helper& operator=(const curl_helper&) = delete; + + // Delete move constructor and move assignment operator + curl_helper(curl_helper&&) = delete; + curl_helper& operator=(curl_helper&&) = delete; + + std::string + perform_request(const std::string& url, const std::string& post_data, const std::string& auth_header = "") + { + if (!auth_header.empty()) + { + m_headers = curl_slist_append(m_headers, auth_header.c_str()); + } + + curl_easy_setopt(m_curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(m_curl, CURLOPT_HTTPHEADER, m_headers); + curl_easy_setopt(m_curl, CURLOPT_POSTFIELDS, post_data.c_str()); + + std::string response; + curl_easy_setopt( + m_curl, + CURLOPT_WRITEFUNCTION, + +[](const char* in, size_t size, size_t num, std::string* out) + { + const size_t total_bytes(size * num); + out->append(in, total_bytes); + return total_bytes; + } + ); + curl_easy_setopt(m_curl, CURLOPT_WRITEDATA, &response); + + CURLcode res = curl_easy_perform(m_curl); + if (res != CURLE_OK) + { + std::cerr << "CURL request failed: " << curl_easy_strerror(res) << std::endl; + return ""; + } + + return response; + } + }; + + std::string gemini(const std::string& cell, const std::string& key) + { + curl_helper curl_helper; + const std::string chat_message = xcpp::chat_history::chat("gemini", "user", cell); + const std::string url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=" + + key; + const std::string post_data = R"({"contents": [ )" + chat_message + R"(]})"; + + std::string response = curl_helper.perform_request(url, post_data); + + json j = json::parse(response); + if (j.find("error") != j.end()) + { + std::cerr << "Error: " << j["error"]["message"] << std::endl; + return ""; + } + + const std::string chat = xcpp::chat_history::chat( + "gemini", + "model", + j["candidates"][0]["content"]["parts"][0]["text"] + ); + + return j["candidates"][0]["content"]["parts"][0]["text"]; + } + + std::string openai(const std::string& cell, const std::string& key) + { + curl_helper curl_helper; + const std::string url = "https://api.openai.com/v1/chat/completions"; + const std::string chat_message = xcpp::chat_history::chat("openai", "user", cell); + const std::string post_data = R"({ + "model": "gpt-3.5-turbo-16k", + "messages": [)" + chat_message + + R"(], + "temperature": 0.7 + })"; + std::string auth_header = "Authorization: Bearer " + key; + + std::string response = curl_helper.perform_request(url, post_data, auth_header); + + json j = json::parse(response); + + if (j.find("error") != j.end()) + { + std::cerr << "Error: " << j["error"]["message"] << std::endl; + return ""; + } + + const std::string chat = xcpp::chat_history::chat( + "openai", + "assistant", + j["choices"][0]["message"]["content"] + ); + + return j["choices"][0]["message"]["content"]; + } + + void xassist::operator()(const std::string& line, const std::string& cell) + { + try + { + std::istringstream iss(line); + std::vector tokens( + std::istream_iterator{iss}, + std::istream_iterator() + ); + + std::vector models = {"gemini", "openai"}; + std::string model = tokens[1]; + + if (std::find(models.begin(), models.end(), model) == models.end()) + { + std::cerr << "Model not found." << std::endl; + return; + } + + if (tokens.size() > 2) + { + if (tokens[2] == "--save-key") + { + xcpp::api_key_manager::save_api_key(model, cell); + return; + } + + if (tokens[2] == "--refresh") + { + xcpp::chat_history::refresh(model); + return; + } + } + + std::string key = xcpp::api_key_manager::load_api_key(model); + if (key.empty()) + { + std::cerr << "API key for model " << model << " is not available." << std::endl; + return; + } + + std::string response; + if (model == "gemini") + { + response = gemini(cell, key); + } + else if (model == "openai") + { + response = openai(cell, key); + } + + std::cout << response; + } + catch (const std::runtime_error& e) + { + std::cerr << "Caught an exception: " << e.what() << std::endl; + } + catch (...) + { + std::cerr << "Caught an unknown exception" << std::endl; + } + } +} // namespace xcpp \ No newline at end of file diff --git a/src/xmagics/xassist.hpp b/src/xmagics/xassist.hpp new file mode 100644 index 00000000..363dcbd0 --- /dev/null +++ b/src/xmagics/xassist.hpp @@ -0,0 +1,26 @@ +/************************************************************************************ + * Copyright (c) 2023, xeus-cpp contributors * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ************************************************************************************/ + +#ifndef XEUS_CPP_XASSIST_MAGIC_HPP +#define XEUS_CPP_XASSIST_MAGIC_HPP + +#include + +#include "xeus-cpp/xmagics.hpp" + +namespace xcpp +{ + class xassist : public xmagic_cell + { + public: + + XEUS_CPP_API + void operator()(const std::string& line, const std::string& cell) override; + }; +} // namespace xcpp +#endif \ No newline at end of file diff --git a/test/test_interpreter.cpp b/test/test_interpreter.cpp index 364e1fd7..eea3c28a 100644 --- a/test/test_interpreter.cpp +++ b/test/test_interpreter.cpp @@ -19,6 +19,7 @@ #include "../src/xparser.hpp" #include "../src/xsystem.hpp" #include "../src/xmagics/os.hpp" +#include "../src/xmagics/xassist.hpp" #include "../src/xinspect.hpp" @@ -886,4 +887,79 @@ TEST_SUITE("xinspect"){ cmp.child_value = "nonexistentMethod"; REQUIRE(cmp(node) == false); } +} + +TEST_SUITE("xassist"){ + + TEST_CASE("model_not_found"){ + xcpp::xassist assist; + std::string line = "%%xassist testModel"; + std::string cell = "test input"; + + StreamRedirectRAII redirect(std::cerr); + + assist(line, cell); + + REQUIRE(redirect.getCaptured() == "Model not found.\n"); + + } + + TEST_CASE("gemini_save"){ + xcpp::xassist assist; + std::string line = "%%xassist gemini --save-key"; + std::string cell = "1234"; + + assist(line, cell); + + std::ifstream infile("gemini_api_key.txt"); + std::string content; + std::getline(infile, content); + + REQUIRE(content == "1234"); + infile.close(); + + StreamRedirectRAII redirect(std::cerr); + + assist("%%xassist gemini", "hello"); + + REQUIRE(!redirect.getCaptured().empty()); + + std::remove("gemini_api_key.txt"); + } + + TEST_CASE("gemini"){ + xcpp::xassist assist; + std::string line = "%%xassist gemini"; + std::string cell = "hello"; + + StreamRedirectRAII redirect(std::cerr); + + assist(line, cell); + + REQUIRE(!redirect.getCaptured().empty()); + } + + TEST_CASE("openai"){ + xcpp::xassist assist; + std::string line = "%%xassist openai --save-key"; + std::string cell = "1234"; + + assist(line, cell); + + std::ifstream infile("openai_api_key.txt"); + std::string content; + std::getline(infile, content); + + REQUIRE(content == "1234"); + infile.close(); + + StreamRedirectRAII redirect(std::cerr); + + assist("%%xassist openai", "hello"); + + REQUIRE(!redirect.getCaptured().empty()); + + std::remove("openai_api_key.txt"); + } + } \ No newline at end of file