Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sql markdown response #16103

Merged
merged 6 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from llama_index.core.response_synthesizers import (
get_response_synthesizer,
)
from llama_index.core.schema import QueryBundle
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.settings import Settings
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import Table
Expand Down Expand Up @@ -328,6 +328,7 @@ def __init__(
self,
llm: Optional[LLM] = None,
synthesize_response: bool = True,
markdown_response: bool = False,
response_synthesis_prompt: Optional[BasePromptTemplate] = None,
callback_manager: Optional[CallbackManager] = None,
refine_synthesis_prompt: Optional[BasePromptTemplate] = None,
Expand All @@ -352,6 +353,7 @@ def __init__(
_validate_prompt(self._refine_synthesis_prompt, DEFAULT_REFINE_PROMPT)

self._synthesize_response = synthesize_response
self._markdown_response = markdown_response
self._verbose = verbose
self._streaming = streaming
super().__init__(callback_manager=callback_manager or Settings.callback_manager)
Expand All @@ -374,6 +376,27 @@ def _get_prompt_modules(self) -> PromptMixinType:
def sql_retriever(self) -> NLSQLRetriever:
"""Get SQL retriever."""

def _format_result_markdown(self, retrieved_nodes: List[NodeWithScore]) -> str:
"""Format the result in markdown."""
tables = []
for node_with_score in retrieved_nodes:
node = node_with_score.node
metadata = node.metadata

col_keys = metadata.get("col_keys", [])
results = metadata.get("result", [])
table_header = "| " + " | ".join(col_keys) + " |\n"
table_separator = "|" + "|".join(["---"] * len(col_keys)) + "|\n"

table_rows = ""
for row in results:
table_rows += "| " + " | ".join(str(item) for item in row) + " |\n"

markdown_table = table_header + table_separator + table_rows
tables.append(markdown_table)

return "\n\n".join(tables).strip()

def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
"""Answer a query."""
retrieved_nodes, metadata = self.sql_retriever.retrieve_with_metadata(
Expand Down Expand Up @@ -402,7 +425,10 @@ def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
return cast(StreamingResponse, response)
return cast(Response, response)
else:
response_str = "\n".join([node.text for node in retrieved_nodes])
if self._markdown_response:
response_str = self._format_result_markdown(retrieved_nodes)
else:
response_str = "\n".join([node.text for node in retrieved_nodes])
return Response(response=response_str, metadata=metadata)

async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
Expand Down Expand Up @@ -457,6 +483,7 @@ def __init__(
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
context_query_kwargs: Optional[dict] = None,
synthesize_response: bool = True,
markdown_response: bool = False,
response_synthesis_prompt: Optional[BasePromptTemplate] = None,
refine_synthesis_prompt: Optional[BasePromptTemplate] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
Expand All @@ -483,6 +510,7 @@ def __init__(
)
super().__init__(
synthesize_response=synthesize_response,
markdown_response=markdown_response,
response_synthesis_prompt=response_synthesis_prompt,
refine_synthesis_prompt=refine_synthesis_prompt,
llm=llm,
Expand Down
13 changes: 13 additions & 0 deletions llama-index-core/tests/indices/struct_store/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ def test_sql_index_query(
response = nl_table_engine.query("test_table:user_id,foo")
assert str(response) == sql_to_test

# query with markdown return
nl_table_engine = NLSQLTableQueryEngine(
index.sql_database, synthesize_response=False, markdown_response=True
)
response = nl_table_engine.query("test_table:user_id,foo")
assert (
str(response)
== """| user_id | foo |
|---|---|
| 2 | bar |
| 8 | hello |"""
)


def test_sql_index_async_query(
allow_networking: Any,
Expand Down
Loading