Merge pull request #323 from truefoundry/carbondataloader
Carbon dataloader and Internet search feature
S1LV3RJ1NX authored Sep 6, 2024
2 parents 9f623b5 + ba1a0d8 commit 783634c
Showing 23 changed files with 1,438 additions and 505 deletions.
7 changes: 7 additions & 0 deletions backend/Dockerfile
@@ -30,12 +30,19 @@ ENV MODELS_CONFIG_PATH=${MODELS_CONFIG_PATH}
ARG INFINITY_API_KEY
ENV INFINITY_API_KEY=${INFINITY_API_KEY}

ARG BRAVE_API_KEY
ENV BRAVE_API_KEY=${BRAVE_API_KEY}

ARG CARBON_AI_API_KEY
ENV CARBON_AI_API_KEY=${CARBON_AI_API_KEY}

ARG UNSTRUCTURED_IO_URL
ENV UNSTRUCTURED_IO_URL=${UNSTRUCTURED_IO_URL}

ARG UNSTRUCTURED_IO_API_KEY
ENV UNSTRUCTURED_IO_API_KEY=${UNSTRUCTURED_IO_API_KEY}


# Copy the project files
COPY . /app

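For context, here is a minimal sketch of how the two new keys are presumably surfaced to the application through backend/settings.py; that module is not part of this diff, so the pydantic BaseSettings style and defaults below are assumptions.

# Hypothetical sketch of the corresponding backend/settings.py additions (not from this diff).
from pydantic import BaseSettings


class Settings(BaseSettings):
    BRAVE_API_KEY: str = ""  # enables the Internet search feature when set
    CARBON_AI_API_KEY: str = ""  # enables the Carbon data loader when set


settings = Settings()  # values are read from the environment variables of the same name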
7 changes: 6 additions & 1 deletion backend/modules/dataloaders/__init__.py
@@ -1,12 +1,17 @@
from backend.modules.dataloaders.githubloader import GithubLoader
from backend.modules.dataloaders.loader import register_dataloader
from backend.modules.dataloaders.localdirloader import LocalDirLoader
from backend.modules.dataloaders.webloader import WebLoader
from backend.settings import settings

register_dataloader("localdir", LocalDirLoader)
register_dataloader("web", WebLoader)
register_dataloader("github", GithubLoader)
if settings.TFY_API_KEY:
from backend.modules.dataloaders.truefoundryloader import TrueFoundryLoader

register_dataloader("truefoundry", TrueFoundryLoader)
if settings.CARBON_AI_API_KEY:
from backend.modules.dataloaders.carbondataloader import CarbonDataLoader

register_dataloader("carbon", CarbonDataLoader)
187 changes: 187 additions & 0 deletions backend/modules/dataloaders/carbondataloader.py
@@ -0,0 +1,187 @@
import os
from typing import Any, Dict, Iterator, List, Optional

import requests
from pydantic import BaseModel, Field

from backend.logger import logger
from backend.modules.dataloaders.loader import BaseDataLoader
from backend.settings import settings
from backend.types import DataIngestionMode, DataPoint, DataSource, LoadedDataPoint


class _FileStatistics(BaseModel):
file_size: Optional[int] = None
mime_type: Optional[str] = None
file_format: Optional[str] = None


class _File(BaseModel):
id: int
name: str
presigned_url: Optional[str] = None
external_file_id: str
source: str
source_created_at: str
file_statistics: _FileStatistics = Field(default_factory=_FileStatistics)


class _UserFilesV2Response(BaseModel):
results: List[_File]
count: int


class _CarbonClient:
def __init__(self, api_key: str, customer_id: str):
self.api_key = api_key
self.customer_id = customer_id

def _request(self, method: str, endpoint: str, **kwargs):
headers = {
"Authorization": f"Bearer {self.api_key}",
"customer-id": self.customer_id,
}
response = requests.request(method, endpoint, headers=headers, **kwargs)
response.raise_for_status()
return response.json()

def query_user_files(
self,
pagination: Optional[Dict[str, int]] = None,
order_by: Optional[str] = None,
order_dir: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
        include_raw_file: Optional[bool] = None,
        include_parsed_text_file: Optional[bool] = None,
        include_additional_files: Optional[bool] = None,
) -> Iterator[_File]:
payload = {}
if pagination is not None:
payload["pagination"] = pagination
if order_by is not None:
payload["order_by"] = order_by
if order_dir is not None:
payload["order_dir"] = order_dir
if filters is not None:
payload["filters"] = filters
if include_raw_file is not None:
payload["include_raw_file"] = include_raw_file
if include_parsed_text_file is not None:
payload["include_parsed_text_file"] = include_parsed_text_file
if include_additional_files is not None:
payload["include_additional_files"] = include_additional_files

total = -1
count = 0

while total == -1 or count < total:
response = self._request(
"POST", "https://api.carbon.ai/user_files_v2", json=payload
)
page = _UserFilesV2Response.parse_obj(response)
if total == -1:
total = page.count
for file in page.results:
# TODO (chiragjn): There can be an edge case here where file.file_metadata.is_folder = True
yield file
count += len(page.results)
payload["pagination"]["offset"] = count


class CarbonDataLoader(BaseDataLoader):
"""
    Load data from a variety of sources like Google Drive, Confluence, Notion and more via Carbon (carbon.ai)
"""

def _download_file(self, url: str, local_filepath: str, chunk_size: int = 8192):
with requests.get(url, stream=True) as response:
response.raise_for_status()
with open(local_filepath, "wb") as local_file:
for chunk in response.iter_content(chunk_size=chunk_size):
local_file.write(chunk)
return local_filepath

def load_filtered_data(
self,
data_source: DataSource,
dest_dir: str,
previous_snapshot: Dict[str, str],
batch_size: int,
data_ingestion_mode: DataIngestionMode,
) -> Iterator[List[LoadedDataPoint]]:
carbon_customer_id, carbon_data_source_id = data_source.uri.split("/")
carbon = _CarbonClient(
api_key=settings.CARBON_AI_API_KEY,
customer_id=carbon_customer_id,
)

user_files = carbon.query_user_files(
pagination={
"limit": 50,
"offset": 0,
},
order_by="created_at",
order_dir="desc",
filters={"organization_user_data_source_id": [int(carbon_data_source_id)]},
include_raw_file=True,
include_parsed_text_file=False,
include_additional_files=False, # TODO (chiragjn): Evaluate later
)

loaded_data_points: List[LoadedDataPoint] = []

for file in user_files:
url = file.presigned_url
filename = file.name
file_id = file.external_file_id
_, file_extension = os.path.splitext(filename)
local_filepath = os.path.join(dest_dir, f"{file_id}-{filename}")
logger.info(
f"Downloading file {filename} from {file.source} data source type to {local_filepath}"
)
self._download_file(url=url, local_filepath=local_filepath)

data_point_uri = f"{file.source}::{file.external_file_id}"
data_point_hash = (
f"{file.source_created_at}::{file.file_statistics.file_size or 0}"
)

data_point = DataPoint(
data_source_fqn=data_source.fqn,
data_point_uri=data_point_uri,
data_point_hash=data_point_hash,
local_filepath=local_filepath,
file_extension=file_extension,
)

# If the data ingestion mode is incremental, check if the data point already exists.
if (
data_ingestion_mode == DataIngestionMode.INCREMENTAL
and previous_snapshot.get(data_point.data_point_fqn)
and previous_snapshot.get(data_point.data_point_fqn)
== data_point.data_point_hash
):
continue

loaded_data_points.append(
LoadedDataPoint(
data_point_hash=data_point.data_point_hash,
data_point_uri=data_point.data_point_uri,
data_source_fqn=data_point.data_source_fqn,
local_filepath=local_filepath,
file_extension=file_extension,
metadata={
"data_source_type": file.source,
"id": file.id,
"external_file_id": file.external_file_id,
"filename": filename,
"file_format": file.file_statistics.file_format,
"mime_type": file.file_statistics.mime_type,
"created_at": file.source_created_at,
},
)
)
if len(loaded_data_points) >= batch_size:
yield loaded_data_points
loaded_data_points.clear()
yield loaded_data_points
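A rough driver sketch for the loader above: load_filtered_data expects data_source.uri to encode "<carbon_customer_id>/<carbon_data_source_id>". The exact DataSource fields and fqn format are assumptions here, since backend/types.py is not part of this diff.

# Hypothetical usage sketch (not part of this change).
from backend.modules.dataloaders.carbondataloader import CarbonDataLoader
from backend.types import DataIngestionMode, DataSource

loader = CarbonDataLoader()
data_source = DataSource(
    type="carbon",  # assumed field
    uri="customer-123/456",  # "<carbon_customer_id>/<carbon_data_source_id>"
    fqn="carbon::customer-123/456",  # assumed format
)

# With an empty snapshot every file is downloaded; in incremental mode, files whose
# source_created_at and size match the previous snapshot are skipped.
for batch in loader.load_filtered_data(
    data_source=data_source,
    dest_dir="/tmp/carbon-downloads",
    previous_snapshot={},
    batch_size=50,
    data_ingestion_mode=DataIngestionMode.INCREMENTAL,
):
    for point in batch:
        print(point.data_point_uri, point.local_filepath)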
2 changes: 2 additions & 0 deletions backend/modules/dataloaders/truefoundryloader.py
@@ -82,6 +82,8 @@ def load_filtered_data
if f.startswith("."):
continue
full_path = os.path.join(root, f)
if ".truefoundry" in full_path:
continue
logger.debug(f"Processing file: {full_path}")
rel_path = os.path.relpath(full_path, dest_dir)
file_ext = os.path.splitext(f)[1]
39 changes: 39 additions & 0 deletions backend/modules/query_controllers/base.py
@@ -3,6 +3,7 @@
from typing import AsyncIterator

import async_timeout
import requests
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
@@ -15,6 +16,7 @@
from backend.modules.model_gateway.model_gateway import model_gateway
from backend.modules.query_controllers.types import *
from backend.modules.vector_db.client import VECTOR_STORE_CLIENT
from backend.settings import settings
from backend.types import Collection, ModelConfig


@@ -180,6 +182,43 @@ def _cleanup_metadata(self, docs):
)
return formatted_docs

def _intent_summary_search(self, query: str):
url = f"https://api.search.brave.com/res/v1/web/search?q={query}&summary=1"

payload = {}
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": f"{settings.BRAVE_API_KEY}",
}

response = requests.request("GET", url, headers=headers, data=payload)
answer = response.json()

if "summarizer" in answer.keys():
summary_query = answer["summarizer"]["key"]
url = f"https://api.search.brave.com/res/v1/summarizer/search?key={summary_query}"
response = requests.request("GET", url, headers=headers, data=payload)
answer = response.json()["summary"][0]["data"]
return answer
return ""

def _internet_search(self, context):
logger.info("Using Internet search...")
if settings.BRAVE_API_KEY:
data_context, question = context["context"], context["question"]
            intent_summary_results = self._intent_summary_search(question)
            # prepend the internet search summary (if any) so it is considered
            # alongside the documents retrieved from the vector store
            if intent_summary_results:
                data_context.insert(
                    0,
                    Document(
                        page_content=intent_summary_results,
                        metadata={"_data_point_fqn": "internet::Internet"},
                    ),
                )
context["context"] = data_context
return context

async def _sse_wrap(self, gen):
async for data in gen:
yield "event: data\n"
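For clarity on the data contract: _internet_search receives the dict produced by the retriever step and returns the same dict, with one extra Document prepended to "context" when a Brave summary is available. A small illustration of the shapes involved; the Document import path and the values are illustrative.

# Shape illustration only (not an executable part of the query flow).
from langchain_core.documents import Document

state = {
    "context": [Document(page_content="chunk retrieved from the vector store")],
    "question": "What does the Carbon data loader do?",
}
# After _internet_search (with BRAVE_API_KEY set and a summary returned), the first
# element of state["context"] would look like:
internet_doc = Document(
    page_content="<Brave summarizer output>",
    metadata={"_data_point_fqn": "internet::Internet"},
)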
23 changes: 21 additions & 2 deletions backend/modules/query_controllers/example/controller.py
@@ -32,6 +32,7 @@ async def answer(
"""
Sample answer method to answer the question using the context from the collection
"""
logger.info(f"Request: {request.dict()}")
try:
# Get the vector store
vector_store = await self._get_vector_store(request.collection_name)
@@ -55,7 +56,12 @@
# Using LCEL
rag_chain_from_docs = (
RunnablePassthrough.assign(
# add internet search results to context
context=(
lambda x: self._format_docs(
x["context"],
)
)
)
| QA_PROMPT
| llm
@@ -64,7 +70,16 @@

rag_chain_with_source = RunnableParallel(
{"context": retriever, "question": RunnablePassthrough()}
)

if request.internet_search_enabled:
rag_chain_with_source = (
rag_chain_with_source | self._internet_search
).assign(answer=rag_chain_from_docs)
else:
rag_chain_with_source = rag_chain_with_source.assign(
answer=rag_chain_from_docs
)

if request.stream:
return StreamingResponse(
@@ -83,6 +98,10 @@
# outputs = await setup_and_retrieval.ainvoke(request.query)
# print(outputs)

# Retriever, internet search
            # outputs = await (setup_and_retrieval | self._internet_search).ainvoke(request.query)
# print(outputs)

# Retriever and QA
# outputs = await (setup_and_retrieval | QA_PROMPT).ainvoke(request.query)
# print(outputs)
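The composition above relies on LCEL coercing a plain method into a runnable when it is piped: rag_chain_with_source | self._internet_search wraps the method in a RunnableLambda, so the internet search runs on the retriever output before the answer chain. Below is a self-contained sketch of the same pattern with illustrative stand-ins for the retriever, the search step, and the answer chain.

# Minimal LCEL sketch of the conditional internet-search composition;
# every name here is a stand-in, not one of the project's actual components.
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

retrieve = RunnableParallel(
    {"context": RunnableLambda(lambda q: [f"doc for {q}"]), "question": RunnablePassthrough()}
)


def add_internet_results(state: dict) -> dict:
    # stand-in for self._internet_search: prepend a search summary to the context
    state["context"].insert(0, "internet summary")
    return state


answer_chain = RunnableLambda(lambda s: f"answered from {len(s['context'])} docs")

internet_search_enabled = True
if internet_search_enabled:
    chain = (retrieve | add_internet_results).assign(answer=answer_chain)
else:
    chain = retrieve.assign(answer=answer_chain)

print(chain.invoke("what is carbon?"))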