feat(VirtualDataFrame): virtual dataframe to load data on demand and enable direct_sql #1434

Merged: 5 commits, Nov 20, 2024
6 changes: 3 additions & 3 deletions pandasai/__init__.py
@@ -7,7 +7,7 @@
from .agent import Agent
from .helpers.cache import Cache
from .dataframe.base import DataFrame
from .dataframe.loader import DatasetLoader
from .data_loader.loader import DatasetLoader

# Global variable to store the current agent
_current_agent = None
@@ -61,7 +61,7 @@ def follow_up(query: str):
_dataset_loader = DatasetLoader()


def load(dataset_path: str) -> DataFrame:
def load(dataset_path: str, virtualized=False) -> DataFrame:
"""
Load data based on the provided dataset path.

@@ -72,7 +72,7 @@ def load(dataset_path: str) -> DataFrame:
DataFrame: A new PandasAI DataFrame instance with loaded data.
"""
global _dataset_loader
return _dataset_loader.load(dataset_path)
return _dataset_loader.load(dataset_path, virtualized)


__all__ = [
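A minimal usage sketch of the new `virtualized` flag on `pandasai.load`; the dataset path `acme/users` is hypothetical and assumes a matching `datasets/acme/users/schema.yaml` in the project root:

```python
import pandasai as pai

# Default path: data is pulled from the source, transformed, and cached locally.
df = pai.load("acme/users")

# Virtualized path: returns a VirtualDataFrame that keeps only a small head
# sample in memory and defers full reads to the underlying source.
vdf = pai.load("acme/users", virtualized=True)
```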
49 changes: 40 additions & 9 deletions pandasai/agent/base.py
@@ -5,6 +5,8 @@

import pandas as pd
from pandasai.agent.base_security import BaseSecurity

from pandasai.data_loader.schema_validator import is_schema_source_same
from pandasai.llm.bamboo_llm import BambooLLM
from pandasai.pipelines.chat.chat_pipeline_input import ChatPipelineInput
from pandasai.pipelines.chat.code_execution_pipeline_input import (
@@ -62,17 +64,13 @@ def __init__(

self.dfs = dfs if isinstance(dfs, list) else [dfs]

# Validate SQL connectors
sql_connectors = [
df
for df in self.dfs
if hasattr(df, "type") and df.type in ["sql", "postgresql"]
]
if len(sql_connectors) > 1:
raise InvalidConfigError("Cannot use multiple SQL connectors")

# Instantiate the context
self.config = self.get_config(config)

# Validate df input with configurations
self.validate_input()

# Initialize the context
self.context = PipelineContext(
dfs=self.dfs,
config=self.config,
@@ -106,6 +104,39 @@ def __init__(
self.pipeline = None
self.security = security

def validate_input(self):
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame

# Check if all DataFrames are VirtualDataFrame, and set direct_sql accordingly
all_virtual = all(isinstance(df, VirtualDataFrame) for df in self.dfs)
if all_virtual:
self.config.direct_sql = True

        # When direct_sql is enabled, all DataFrames must share the same source
if self.config.direct_sql and all_virtual:
base_schema_source = self.dfs[0].schema
for df in self.dfs[1:]:
                # Ensure all DataFrames have the same source in direct_sql mode
if not is_schema_source_same(base_schema_source, df.schema):
raise InvalidConfigError(
"Direct SQL requires all connectors to be of the same type, "
"belong to the same datasource, and have the same credentials."
)
else:
# If not using direct_sql, ensure all DataFrames have the same source
if any(isinstance(df, VirtualDataFrame) for df in self.dfs):
base_schema_source = self.dfs[0].schema
for df in self.dfs[1:]:
if not is_schema_source_same(base_schema_source, df.schema):
raise InvalidConfigError(
"All DataFrames must belong to the same source."
)
self.config.direct_sql = True
else:
                # All DataFrames are plain (non-virtual) pandas DataFrames
self.config.direct_sql = False

def configure(self):
# Add project root path if save_charts_path is default
if (
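A sketch of how `validate_input()` is expected to behave for all-virtual inputs; the dataset paths are hypothetical and assume both schemas point at the same SQL source with identical credentials:

```python
import pandasai as pai

users = pai.load("acme/users", virtualized=True)
orders = pai.load("acme/orders", virtualized=True)

# Both inputs are VirtualDataFrames backed by the same source, so
# validate_input() switches config.direct_sql on during construction.
agent = pai.Agent([users, orders])
print(agent.config.direct_sql)  # expected: True

# Virtual DataFrames pointing at different sources would instead raise
# InvalidConfigError.
```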
103 changes: 80 additions & 23 deletions pandasai/dataframe/loader.py → pandasai/data_loader/loader.py
@@ -1,12 +1,14 @@
import copy
import os
import yaml
import pandas as pd
from datetime import datetime, timedelta
import hashlib

from pandasai.dataframe.base import DataFrame
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.helpers.path import find_project_root
from .base import DataFrame
import importlib
from typing import Any
from .query_builder import QueryBuilder
@@ -18,27 +20,35 @@ def __init__(self):
self.schema = None
self.dataset_path = None

def load(self, dataset_path: str, lazy=False) -> DataFrame:
def load(self, dataset_path: str, virtualized=False) -> DataFrame:
self.dataset_path = dataset_path
self._load_schema()
self._validate_source_type()
if not virtualized:
cache_file = self._get_cache_file_path()

cache_file = self._get_cache_file_path()
if self._is_cache_valid(cache_file):
return self._read_cache(cache_file)

if self._is_cache_valid(cache_file):
return self._read_cache(cache_file)
df = self._load_from_source()
df = self._apply_transformations(df)
self._cache_data(df, cache_file)

df = self._load_from_source()
df = self._apply_transformations(df)
self._cache_data(df, cache_file)
table_name = self.schema["source"]["table"]

return DataFrame(df, schema=self.schema)
return DataFrame(df, schema=self.schema, name=table_name)
else:
# Initialize new dataset loader for virtualization
data_loader = self.copy()
table_name = self.schema["source"]["table"]
return VirtualDataFrame(
schema=self.schema, data_loader=data_loader, name=table_name
)

def _load_schema(self):
schema_path = os.path.join(
find_project_root(), "datasets", self.dataset_path, "schema.yaml"
)
if not os.path.exists(schema_path):
raise FileNotFoundError(f"Schema file not found: {schema_path}")

@@ -82,32 +92,67 @@ def _read_cache(self, cache_file: str) -> DataFrame:
else:
raise ValueError(f"Unsupported cache format: {cache_format}")

def _load_from_source(self) -> pd.DataFrame:
source_type = self.schema["source"]["type"]
connection_info = self.schema["source"].get("connection", {})
query_builder = QueryBuilder(self.schema)
query = query_builder.build_query()

def _get_loader_function(self, source_type: str):
"""
Get the loader function for a specified data source type.
"""
try:
module_name = SUPPORTED_SOURCES[source_type]
module = importlib.import_module(module_name)

if source_type in [
if source_type not in {
"mysql",
"postgres",
"cockroach",
"sqlite",
"cockroachdb",
]:
load_function = getattr(module, f"load_from_{source_type}")
return load_function(connection_info, query)
else:
raise InvalidDataSourceType("Invalid data source type")
}:
raise InvalidDataSourceType(
f"Unsupported data source type: {source_type}"
)

return getattr(module, f"load_from_{source_type}")

except KeyError:
raise InvalidDataSourceType(f"Unsupported data source type: {source_type}")

except ImportError as e:
raise ImportError(
f"{source_type.capitalize()} connector not found. "
f"Please install the {module_name} library."
f"Please install the {SUPPORTED_SOURCES[source_type]} library."
) from e

def _load_from_source(self) -> pd.DataFrame:
query_builder = QueryBuilder(self.schema)
query = query_builder.build_query()
return self.execute_query(query)

def load_head(self) -> pd.DataFrame:
query_builder = QueryBuilder(self.schema)
query = query_builder.get_head_query()
return self.execute_query(query)

def get_row_count(self) -> int:
query_builder = QueryBuilder(self.schema)
query = query_builder.get_row_count()
result = self.execute_query(query)
return result.iloc[0, 0]

def execute_query(self, query: str) -> pd.DataFrame:
source = self.schema.get("source", {})
source_type = source.get("type")
connection_info = source.get("connection", {})

if not source_type:
raise ValueError("Source type is missing in the schema.")

load_function = self._get_loader_function(source_type)

try:
return load_function(connection_info, query)
except Exception as e:
raise RuntimeError(
f"Failed to execute query for source type '{source_type}' with query: {query}"
) from e

def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -140,3 +185,15 @@ def _cache_data(self, df: pd.DataFrame, cache_file: str):
df.to_csv(cache_file, index=False)
else:
raise ValueError(f"Unsupported cache format: {cache_format}")

def copy(self) -> "DatasetLoader":
"""
Create a new independent copy of the current DatasetLoader instance.

Returns:
DatasetLoader: A new instance with the same state.
"""
new_loader = DatasetLoader()
new_loader.schema = copy.deepcopy(self.schema)
new_loader.dataset_path = self.dataset_path
return new_loader
55 changes: 55 additions & 0 deletions pandasai/data_loader/query_builder.py
@@ -0,0 +1,55 @@
from typing import Dict, Any, List, Union


class QueryBuilder:
def __init__(self, schema: Dict[str, Any]):
self.schema = schema

def build_query(self) -> str:
columns = self._get_columns()
table_name = self.schema["source"]["table"]
query = f"SELECT {columns} FROM {table_name}"

query += self._add_order_by()
query += self._add_limit()

return query

def _get_columns(self) -> str:
if "columns" in self.schema:
return ", ".join([col["name"] for col in self.schema["columns"]])
else:
return "*"

def _add_order_by(self) -> str:
if "order_by" not in self.schema:
return ""

order_by = self.schema["order_by"]
order_by_clause = self._format_order_by(order_by)
return f" ORDER BY {order_by_clause}"

def _format_order_by(self, order_by: Union[List[str], str]) -> str:
return ", ".join(order_by) if isinstance(order_by, list) else order_by

def _add_limit(self, n=None) -> str:
limit = n if n else (self.schema["limit"] if "limit" in self.schema else "")
return f" LIMIT {self.schema['limit']}" if limit else ""

def get_head_query(self, n=5):
source = self.schema.get("source", {})
source_type = source.get("type")

table_name = self.schema["source"]["table"]

columns = self._get_columns()

order_by = "RAND()"
if source_type in {"sqlite", "postgres"}:
order_by = "RANDOM()"

return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}"

def get_row_count(self):
table_name = self.schema["source"]["table"]
return f"SELECT COUNT(*) FROM {table_name}"
9 changes: 9 additions & 0 deletions pandasai/data_loader/schema_validator.py
@@ -0,0 +1,9 @@
import json


def is_schema_source_same(schema1: dict, schema2: dict) -> bool:
return schema1.get("source").get("type") == schema2.get("source").get(
"type"
) and json.dumps(
schema1.get("source").get("connection"), sort_keys=True
) == json.dumps(schema2.get("source").get("connection"), sort_keys=True)
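A quick example of the comparison this helper performs; the connection details are illustrative only:

```python
from pandasai.data_loader.schema_validator import is_schema_source_same

conn = {"host": "db.internal", "port": 5432, "database": "analytics"}
users = {"source": {"type": "postgres", "connection": conn, "table": "users"}}
orders = {"source": {"type": "postgres", "connection": conn, "table": "orders"}}
legacy = {"source": {"type": "mysql", "connection": conn, "table": "users"}}

print(is_schema_source_same(users, orders))  # True: same type and connection
print(is_schema_source_same(users, legacy))  # False: different source type
```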
41 changes: 41 additions & 0 deletions pandasai/dataframe/virtual_dataframe.py
@@ -0,0 +1,41 @@
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
import pandas as pd
from pandasai.dataframe.base import DataFrame

if TYPE_CHECKING:
from pandasai.data_loader.loader import DatasetLoader


class VirtualDataFrame(DataFrame):
_metadata: ClassVar[list] = [
"_loader",
"head",
"_head",
"name",
"description",
"schema",
"config",
"_agent",
"_column_hash",
]

def __init__(self, *args, **kwargs):
self._loader: DatasetLoader = kwargs.pop("data_loader", None)
if not self._loader:
raise Exception("Data loader is required for virtualization!")
self._head = None
super().__init__(self.get_head(), *args, **kwargs)

def head(self):
if self._head is None:
self._head = self._loader.load_head()

return self._head

@property
def rows_count(self) -> int:
return self._loader.get_row_count()

def execute_sql_query(self, query: str) -> pd.DataFrame:
return self._loader.execute_query(query)
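A usage sketch of the lazy API, again assuming a hypothetical `acme/users` dataset loaded with virtualization enabled:

```python
import pandasai as pai

vdf = pai.load("acme/users", virtualized=True)

# Only a small sample is fetched up front; each call below goes back to the
# source instead of materializing the full table.
sample = vdf.head()      # sample rows via DatasetLoader.load_head()
total = vdf.rows_count   # SELECT COUNT(*) pushed to the source
emails = vdf.execute_sql_query("SELECT email FROM users LIMIT 10")
```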