Skip to content

Commit

Permalink
Sql markdown response (#16103)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoseLuckmann authored Sep 20, 2024
1 parent 81ecb2a commit 762a45e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from llama_index.core.response_synthesizers import (
get_response_synthesizer,
)
from llama_index.core.schema import QueryBundle
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.settings import Settings
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import Table
Expand Down Expand Up @@ -328,6 +328,7 @@ def __init__(
self,
llm: Optional[LLM] = None,
synthesize_response: bool = True,
markdown_response: bool = False,
response_synthesis_prompt: Optional[BasePromptTemplate] = None,
callback_manager: Optional[CallbackManager] = None,
refine_synthesis_prompt: Optional[BasePromptTemplate] = None,
Expand All @@ -352,6 +353,7 @@ def __init__(
_validate_prompt(self._refine_synthesis_prompt, DEFAULT_REFINE_PROMPT)

self._synthesize_response = synthesize_response
self._markdown_response = markdown_response
self._verbose = verbose
self._streaming = streaming
super().__init__(callback_manager=callback_manager or Settings.callback_manager)
Expand All @@ -374,6 +376,27 @@ def _get_prompt_modules(self) -> PromptMixinType:
def sql_retriever(self) -> NLSQLRetriever:
"""Get SQL retriever."""

def _format_result_markdown(self, retrieved_nodes: List[NodeWithScore]) -> str:
"""Format the result in markdown."""
tables = []
for node_with_score in retrieved_nodes:
node = node_with_score.node
metadata = node.metadata

col_keys = metadata.get("col_keys", [])
results = metadata.get("result", [])
table_header = "| " + " | ".join(col_keys) + " |\n"
table_separator = "|" + "|".join(["---"] * len(col_keys)) + "|\n"

table_rows = ""
for row in results:
table_rows += "| " + " | ".join(str(item) for item in row) + " |\n"

markdown_table = table_header + table_separator + table_rows
tables.append(markdown_table)

return "\n\n".join(tables).strip()

def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
"""Answer a query."""
retrieved_nodes, metadata = self.sql_retriever.retrieve_with_metadata(
Expand Down Expand Up @@ -402,7 +425,10 @@ def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
return cast(StreamingResponse, response)
return cast(Response, response)
else:
response_str = "\n".join([node.text for node in retrieved_nodes])
if self._markdown_response:
response_str = self._format_result_markdown(retrieved_nodes)
else:
response_str = "\n".join([node.text for node in retrieved_nodes])
return Response(response=response_str, metadata=metadata)

async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
Expand Down Expand Up @@ -457,6 +483,7 @@ def __init__(
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
context_query_kwargs: Optional[dict] = None,
synthesize_response: bool = True,
markdown_response: bool = False,
response_synthesis_prompt: Optional[BasePromptTemplate] = None,
refine_synthesis_prompt: Optional[BasePromptTemplate] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
Expand All @@ -483,6 +510,7 @@ def __init__(
)
super().__init__(
synthesize_response=synthesize_response,
markdown_response=markdown_response,
response_synthesis_prompt=response_synthesis_prompt,
refine_synthesis_prompt=refine_synthesis_prompt,
llm=llm,
Expand Down
13 changes: 13 additions & 0 deletions llama-index-core/tests/indices/struct_store/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ def test_sql_index_query(
response = nl_table_engine.query("test_table:user_id,foo")
assert str(response) == sql_to_test

# query with markdown return
nl_table_engine = NLSQLTableQueryEngine(
index.sql_database, synthesize_response=False, markdown_response=True
)
response = nl_table_engine.query("test_table:user_id,foo")
assert (
str(response)
== """| user_id | foo |
|---|---|
| 2 | bar |
| 8 | hello |"""
)


def test_sql_index_async_query(
allow_networking: Any,
Expand Down

0 comments on commit 762a45e

Please sign in to comment.