diff --git a/src/xmagics/xassist.cpp b/src/xmagics/xassist.cpp index ddcbbe99..60971b0c 100644 --- a/src/xmagics/xassist.cpp +++ b/src/xmagics/xassist.cpp @@ -19,7 +19,6 @@ using json = nlohmann::json; // TODO: Implement xplugin to separate the magics from the main code. -// TODO: Add support for open-source models. namespace xcpp { class api_key_manager @@ -279,6 +278,53 @@ namespace xcpp } }; + std::string escape_special_cases(const std::string& input) + { + std::string escaped; + for (char c : input) + { + switch (c) + { + case '\\': + escaped += "\\\\"; + break; + case '\"': + escaped += "\\\""; + break; + case '\n': + escaped += "\\n"; + break; + case '\t': + escaped += "\\t"; + break; + case '\r': + escaped += "\\r"; + break; + case '\b': + escaped += "\\b"; + break; + case '\f': + escaped += "\\f"; + break; + default: + if (c < 0x20 || c > 0x7E) + { + // Escape non-printable ASCII characters and non-ASCII characters + std::array buffer{}; + std::stringstream ss; + ss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (c & 0xFFFF); + escaped += ss.str(); + } + else + { + escaped += c; + } + break; + } + } + return escaped; + } + std::string gemini(const std::string& cell, const std::string& key) { curl_helper curl_helper; @@ -369,8 +415,8 @@ namespace xcpp } const std::string post_data = R"({ - "model": [)" + model - + R"(], + "model": ")" + model + + R"(", "messages": [)" + chat_message + R"(], "temperature": 0.7 @@ -453,18 +499,21 @@ namespace xcpp } } + + const std::string prompt = escape_special_cases(cell); + std::string response; if (model == "gemini") { - response = gemini(cell, key); + response = gemini(prompt, key); } else if (model == "openai") { - response = openai(cell, key); + response = openai(prompt, key); } else if (model == "ollama") { - response = ollama(cell); + response = ollama(prompt); } std::cout << response; diff --git a/test/test_xcpp_kernel.py b/test/test_xcpp_kernel.py index f02b68d9..b21375a3 100644 --- a/test/test_xcpp_kernel.py +++ b/test/test_xcpp_kernel.py @@ -167,8 +167,6 @@ def test_notebooks(self): with open(out) as f: output_nb = nbformat.read(f, as_version=4) - check = True - # Iterate over the cells in the input and output notebooks for i, (input_cell, output_cell) in enumerate(zip(input_nb.cells, output_nb.cells)): if input_cell.cell_type == 'code' and output_cell.cell_type == 'code':