google drive crawler #101

Merged: 17 commits, Jul 18, 2024
2 changes: 1 addition & 1 deletion Dockerfile
@@ -31,7 +31,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

- # install python packages
+ # Install python packages
WORKDIR ${HOME}
COPY requirements.txt requirements-extra.txt $HOME/
RUN pip install --no-cache-dir torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
12 changes: 12 additions & 0 deletions config/vectara-gdrive.yaml
@@ -0,0 +1,12 @@
vectara:
corpus_id: 277
customer_id: 1526022105
reindex: true

crawling:
crawler_type: gdrive

gdrive_crawler:
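# Workspace user emails the crawler will impersonate via domain-wide delegation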
delegated_users:
- <add email id>
- <add email id>
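This config is consumed by ingest.py (see its usage line further down: python ingest.py <config_file> <secrets-profile>). A minimal invocation sketch, assuming a secrets.toml profile named "default" (the profile name is an assumption, not part of this diff):

    python ingest.py config/vectara-gdrive.yaml default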
175 changes: 175 additions & 0 deletions crawlers/gdrive_crawler.py
@@ -0,0 +1,175 @@
import os
from core.crawler import Crawler
from omegaconf import OmegaConf
import logging
import io
from datetime import datetime, timedelta
from google.oauth2 import service_account
from googleapiclient.discovery import build, Resource
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
import pandas as pd
from slugify import slugify
from typing import List, Tuple, Optional

SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
SERVICE_ACCOUNT_FILE = '/home/vectara/env/credentials.json'

def get_credentials(delegated_user: str) -> service_account.Credentials:
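# The service account impersonates each delegated user (domain-wide delegation)
# so the crawler can list and download that user's Drive files.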
credentials = service_account.Credentials.from_service_account_file(
SERVICE_ACCOUNT_FILE, scopes=SCOPES)
delegated_credentials = credentials.with_subject(delegated_user)
return delegated_credentials

def download_or_export_file(service: Resource, file_id: str, mime_type: Optional[str] = None) -> Optional[io.BytesIO]:
try:
if mime_type:
request = service.files().export_media(fileId=file_id, mimeType=mime_type)
else:
request = service.files().get_media(fileId=file_id)

byte_stream = io.BytesIO() # an in-memory bytestream
downloader = MediaIoBaseDownload(byte_stream, request)
done = False
while not done:
status, done = downloader.next_chunk()
logging.info(f"Download {int(status.progress() * 100)}.")
byte_stream.seek(0) # Reset the file pointer to the beginning
return byte_stream
except HttpError as error:
logging.info(f"An error occurred: {error}")
return None
# Note: Handling of large files that may exceed memory limits should be implemented if necessary.
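# A possible approach for large files (a sketch, not part of this PR): stream chunks
# straight to disk instead of buffering the whole file in a BytesIO object.
# MediaIoBaseDownload accepts any writable file-like object, so the in-memory
# byte stream above could be replaced with an open file handle, e.g.:
#
#   def download_file_to_disk(service: Resource, file_id: str, dest_path: str) -> str:
#       request = service.files().get_media(fileId=file_id)
#       with open(dest_path, 'wb') as fh:
#           downloader = MediaIoBaseDownload(fh, request, chunksize=10 * 1024 * 1024)
#           done = False
#           while not done:
#               _status, done = downloader.next_chunk()
#       return dest_path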

def save_local_file(service: Resource, file_id: str, name: str, mime_type: Optional[str] = None) -> Optional[str]:
sanitized_name = slugify(name)
file_path = os.path.join("/tmp", sanitized_name)
try:
byte_stream = download_or_export_file(service, file_id, mime_type)
if byte_stream:
with open(file_path, 'wb') as f:
f.write(byte_stream.read())
return file_path
except Exception as e:
logging.info(f"Error saving local file: {e}")
return None

class GdriveCrawler(Crawler):

def __init__(self, cfg: OmegaConf, endpoint: str, customer_id: str, corpus_id: int, api_key: str) -> None:
super().__init__(cfg, endpoint, customer_id, corpus_id, api_key)
logging.info("Google Drive Crawler initialized")

self.delegated_users = cfg.gdrive_crawler.delegated_users
self.creds = None
self.service = None

def list_files(self, service: Resource, parent_id: Optional[str] = None, date_threshold: Optional[str] = None) -> List[dict]:
results = []
page_token = None
query = f"('{parent_id}' in parents or sharedWithMe) and trashed=false and modifiedTime > '{date_threshold}'" if parent_id else f"('root' in parents or sharedWithMe) and trashed=false and modifiedTime > '{date_threshold}'"

while True:
try:
params = {
'fields': 'nextPageToken, files(id, name, mimeType, permissions, modifiedTime, createdTime, owners, size)',
'q': query,
'corpora': 'allDrives',
'includeItemsFromAllDrives': True,
'supportsAllDrives': True
}
if page_token:
params['pageToken'] = page_token
response = service.files().list(**params).execute()
files = response.get('files', [])
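# Only keep files shared with a principal whose display name is 'Vectara' or 'all';
# everything else is treated as restricted and skipped.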
for file in files:
permissions = file.get('permissions', [])
if any(p.get('displayName') == 'Vectara' or p.get('displayName') == 'all' for p in permissions):
results.append(file)
page_token = response.get('nextPageToken', None)
if not page_token:
break
except HttpError as error:
logging.info(f"An error occurred: {error}")
break
return results

def handle_file(self, file: dict) -> Tuple[Optional[str], Optional[str]]:
file_id = file['id']
mime_type = file['mimeType']
name = file['name']
permissions = file.get('permissions', [])

logging.info(f"\nHandling file: {name} with MIME type: {mime_type}")

if not any(p.get('displayName') == 'Vectara' or p.get('displayName') == 'all' for p in permissions):
logging.info(f"Skipping restricted file: {name}")
return None, None
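# Google-native formats (Docs, Sheets, Slides) cannot be downloaded directly and
# are exported to Office/CSV equivalents; other binary files are fetched as-is.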

if mime_type == 'application/vnd.google-apps.document':
local_file_path = save_local_file(self.service, file_id, name + '.docx', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')
url = f'https://docs.google.com/document/d/{file_id}/edit'
elif mime_type == 'application/vnd.google-apps.spreadsheet':
local_file_path = save_local_file(self.service, file_id, name + '.csv', 'text/csv')
url = f'https://docs.google.com/spreadsheets/d/{file_id}/edit'
elif mime_type == 'application/vnd.google-apps.presentation':
local_file_path = save_local_file(self.service, file_id, name + '.pptx', 'application/vnd.openxmlformats-officedocument.presentationml.presentation')
url = f'https://docs.google.com/presentation/d/{file_id}/edit'
elif mime_type.startswith('application/'):
local_file_path = save_local_file(self.service, file_id, name)
if local_file_path and name.endswith('.xlsx'):
df = pd.read_excel(local_file_path)
csv_file_path = local_file_path.replace('.xlsx', '.csv')
df.to_csv(csv_file_path, index=False)
local_file_path = csv_file_path
url = f'https://drive.google.com/file/d/{file_id}/view'
else:
logging.info(f"Unsupported file type: {mime_type}")
return None, None

if local_file_path:
logging.info(f"local_file_path :: {local_file_path}")
return local_file_path, url
else:
logging.info("local_file_path :: None")
return None, None

def crawl_file(self, file: dict) -> None:
local_file_path, url = self.handle_file(file)
if local_file_path:
file_id = file['id']
name = file['name']
created_time = file.get('createdTime', 'N/A')
modified_time = file.get('modifiedTime', 'N/A')
owners = ', '.join([owner['displayName'] for owner in file.get('owners', [])])
size = file.get('size', 'N/A')

logging.info(f'\nCrawling file {name}')

file_metadata = {
'id': file_id,
'name': name,
'created_at': created_time,
'modified_at': modified_time,
'owners': owners,
'size': size,
'source': 'gdrive'
}

try:
self.indexer.index_file(filename=local_file_path, uri=url, metadata=file_metadata)
except Exception as e:
logging.info(f"Error {e} indexing document for file {name}, file_id {file_id}")

def crawl(self) -> None:
N = 7 # Number of days to look back
date_threshold = datetime.utcnow() - timedelta(days=N)

for user in self.delegated_users:
logging.info(f"Processing files for user: {user}")
self.creds = get_credentials(user)
self.service = build("drive", "v3", credentials=self.creds)

list_files = self.list_files(self.service, date_threshold=date_threshold.isoformat() + 'Z')
Collaborator comment:
I see you are using the ISO format + "Z": is this because the Google API requires the timestamp to be in UTC and in this format?

for file in list_files:
self.crawl_file(file)
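On the reviewer's question above: the Drive API's files.list query language compares modifiedTime against RFC 3339 timestamps, and datetime.utcnow().isoformat() yields a naive timestamp with no timezone designator, so appending 'Z' marks it explicitly as UTC. A small sketch of a timezone-aware alternative (an illustration, not part of this PR):

    from datetime import datetime, timedelta, timezone

    date_threshold = datetime.now(timezone.utc) - timedelta(days=7)
    # isoformat() on an aware datetime ends in "+00:00", a valid RFC 3339 UTC offset
    query_time = date_threshold.isoformat()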
13 changes: 11 additions & 2 deletions ingest.py
@@ -14,6 +14,7 @@
from authlib.integrations.requests_client import OAuth2Session

def instantiate_crawler(base_class, folder_name: str, class_name: str, *args, **kwargs) -> Any: # type: ignore
logging.info("Instantiating crawler...")
sys.path.insert(0, os.path.abspath(folder_name))

crawler_name = class_name.split('Crawler')[0]
@@ -27,6 +28,7 @@ def instantiate_crawler(base_class, folder_name: str, class_name: str, *args, **
raise TypeError(f"{class_name} is not a subclass of {base_class.__name__}")

# Instantiate the class and return the instance
logging.info("Crawler class loaded, creating instance")
return class_(*args, **kwargs)

def get_jwt_token(auth_url: str, auth_id: str, auth_secret: str, customer_id: str) -> Any:
@@ -76,6 +78,8 @@ def main() -> None:
if len(sys.argv) != 3:
logging.info("Usage: python ingest.py <config_file> <secrets-profile>")
return

logging.info("Starting the Crawler...")
config_name = sys.argv[1]
profile_name = sys.argv[2]

@@ -89,7 +93,7 @@ def main() -> None:
if profile_name not in env_dict:
logging.info(f'Profile "{profile_name}" not found in secrets.toml')
return

logging.info(f'Using profile "{profile_name}" from secrets.toml')
# Add all keys from "general" section to the vectara config
general_dict = env_dict.get('general', {})
for k,v in general_dict.items():
@@ -129,15 +133,20 @@ def main() -> None:
# default (otherwise) - add to vectara config
OmegaConf.update(cfg['vectara'], k, v)

logging.info("Configuration loaded...")
endpoint = cfg.vectara.get("endpoint", "api.vectara.io")
customer_id = cfg.vectara.customer_id
corpus_id = cfg.vectara.corpus_id
api_key = cfg.vectara.api_key
crawler_type = cfg.crawling.crawler_type

# instantiate the crawler
- crawler = instantiate_crawler(Crawler, 'crawlers', f'{crawler_type.capitalize()}Crawler', cfg, endpoint, customer_id, corpus_id, api_key)
+ crawler = instantiate_crawler(
+     Crawler, 'crawlers', f'{crawler_type.capitalize()}Crawler',
+     cfg, endpoint, customer_id, corpus_id, api_key
+ )

logging.info("Crawling instantiated...")
# When debugging a crawler, it is sometimes useful to reset the corpus (remove all documents)
# To do that you would have to set this to True and also include <auth_url> and <auth_id> in the secrets.toml file
# NOTE: use with caution; this will delete all documents in the corpus and is irreversible
1 change: 1 addition & 0 deletions run.sh
@@ -23,6 +23,7 @@ fi
# Mount secrets file into docker container
mkdir -p ~/tmp/mount
cp secrets.toml ~/tmp/mount
cp credentials.json ~/tmp/mount
cp $1 ~/tmp/mount/

# Build docker container
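A note on the credentials.json copy above: the crawler expects the service-account key at /home/vectara/env/credentials.json (SERVICE_ACCOUNT_FILE in crawlers/gdrive_crawler.py), so this only works if run.sh mounts ~/tmp/mount at that path inside the container. A hedged sketch of such a mount, with the image name as a placeholder (the actual docker run line is outside this diff):

    docker run -v ~/tmp/mount:/home/vectara/env <vectara-ingest-image>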