google drive crawler #101

Merged 17 commits on Jul 18, 2024
1 change: 1 addition & 0 deletions Dockerfile
@@ -34,6 +34,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# install python packages
WORKDIR ${HOME}
COPY requirements.txt requirements-extra.txt $HOME/
COPY crawlers/credentials.json $HOME/
Collaborator:
This is not a good idea. I'll explain offline why.

RUN pip install --no-cache-dir torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache-dir -r requirements.txt \
&& find /usr/local -type d \( -name test -o -name tests \) -exec rm -rf '{}' + \
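
Regarding the collaborator's concern above: an alternative to baking credentials.json into the image is to mount the key at runtime and let the crawler resolve its path from the environment. A minimal sketch, assuming a GOOGLE_SERVICE_ACCOUNT_FILE variable that is not part of this PR:

import os
from google.oauth2 import service_account

SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]

# Hypothetical: resolve the key path from an environment variable (e.g. set via
# `docker run -v ... -e GOOGLE_SERVICE_ACCOUNT_FILE=...`) instead of COPYing
# crawlers/credentials.json into the image at build time.
SERVICE_ACCOUNT_FILE = os.environ.get("GOOGLE_SERVICE_ACCOUNT_FILE", "credentials.json")

def get_credentials(delegated_user: str):
    # Load the service-account key and impersonate the delegated Workspace user.
    credentials = service_account.Credentials.from_service_account_file(
        SERVICE_ACCOUNT_FILE, scopes=SCOPES)
    return credentials.with_subject(delegated_user)
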
12 changes: 12 additions & 0 deletions config/vectara-gdrive.yaml
@@ -0,0 +1,12 @@
vectara:
  corpus_id: 277
  customer_id: 1526022105
  reindex: true

crawling:
  crawler_type: gdrive

gdrive_crawler:
  delegated_users:
    - <add email id>
    - <add email id>
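
For reference, ingest.py reads this config through OmegaConf (as the cfg usage in the diff below suggests), so the nesting above maps directly to attribute access. A minimal stand-alone sketch; the file path matches this PR, everything else is illustrative:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/vectara-gdrive.yaml")
assert cfg.crawling.crawler_type == "gdrive"

# delegated_users lists the Workspace accounts the service account impersonates.
for user in cfg.gdrive_crawler.delegated_users:
    print(f"will crawl Drive files shared with or owned by {user}")
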
178 changes: 178 additions & 0 deletions crawlers/gdrive_crawler.py
@@ -0,0 +1,178 @@
import os
from core.crawler import Crawler
from omegaconf import OmegaConf
import logging
import io
from datetime import datetime, timedelta
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
import pandas as pd
from typing import List
from slugify import slugify

SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
SERVICE_ACCOUNT_FILE = 'credentials.json'

def get_credentials(delegated_user):
    credentials = service_account.Credentials.from_service_account_file(
        SERVICE_ACCOUNT_FILE, scopes=SCOPES)
    delegated_credentials = credentials.with_subject(delegated_user)
    return delegated_credentials

def download_or_export_file(service, file_id, mime_type=None):
    try:
        if mime_type:
            request = service.files().export_media(fileId=file_id, mimeType=mime_type)
        else:
            request = service.files().get_media(fileId=file_id)

        byte_stream = io.BytesIO()  # an in-memory bytestream
        downloader = MediaIoBaseDownload(byte_stream, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
            logging.info(f"Download {int(status.progress() * 100)}%.")
        byte_stream.seek(0)  # Reset the file pointer to the beginning
        return byte_stream
    except HttpError as error:
        logging.info(f"An error occurred: {error}")
        return None
    # Note: Handling of large files that may exceed memory limits should be implemented if necessary.

def save_local_file(service, file_id, name, mime_type=None):
    sanitized_name = slugify(name)
    file_path = os.path.join("/tmp", sanitized_name)
    try:
        byte_stream = download_or_export_file(service, file_id, mime_type)
        if byte_stream:
            with open(file_path, 'wb') as f:
                f.write(byte_stream.read())
            return file_path
    except Exception as e:
        logging.info(f"Error saving local file: {e}")
    return None

class GdriveCrawler(Crawler):

    def __init__(self, cfg: OmegaConf, endpoint: str, customer_id: str, corpus_id: int, api_key: str, delegated_users: List[str]) -> None:
        super().__init__(cfg, endpoint, customer_id, corpus_id, api_key)
        logging.info("Google Drive Crawler initialized")

        self.delegated_users = delegated_users
        self.creds = None
        self.service = None
        self.api_key = api_key
        self.customer_id = customer_id
        self.corpus_id = corpus_id

    def list_files(self, service, parent_id=None, date_threshold=None):
        results = []
        page_token = None
        parent = parent_id if parent_id else 'root'
        query = f"('{parent}' in parents or sharedWithMe) and trashed=false and modifiedTime > '{date_threshold}'"

        while True:
            try:
                params = {
                    'fields': 'nextPageToken, files(id, name, mimeType, permissions, modifiedTime, createdTime, owners, size)',
                    'q': query,
                    'corpora': 'allDrives',
                    'includeItemsFromAllDrives': True,
                    'supportsAllDrives': True
                }
                if page_token:
                    params['pageToken'] = page_token
                response = service.files().list(**params).execute()
                files = response.get('files', [])
                for file in files:
                    permissions = file.get('permissions', [])
                    if any(p.get('displayName') == 'Vectara' or p.get('displayName') == 'all' for p in permissions):
                        results.append(file)
                page_token = response.get('nextPageToken', None)
                if not page_token:
                    break
            except HttpError as error:
                logging.info(f"An error occurred: {error}")
                break
        return results

    def handle_file(self, file):
        file_id = file['id']
        mime_type = file['mimeType']
        name = file['name']
        permissions = file.get('permissions', [])

        logging.info(f"\nHandling file: {name} with MIME type: {mime_type}")

        if not any(p.get('displayName') == 'Vectara' or p.get('displayName') == 'all' for p in permissions):
            logging.info(f"Skipping restricted file: {name}")
            return None, None

        if mime_type == 'application/vnd.google-apps.document':
            local_file_path = save_local_file(self.service, file_id, name + '.docx', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')
            url = f'https://docs.google.com/document/d/{file_id}/edit'
        elif mime_type == 'application/vnd.google-apps.spreadsheet':
            local_file_path = save_local_file(self.service, file_id, name + '.csv', 'text/csv')
            url = f'https://docs.google.com/spreadsheets/d/{file_id}/edit'
        elif mime_type == 'application/vnd.google-apps.presentation':
            local_file_path = save_local_file(self.service, file_id, name + '.pptx', 'application/vnd.openxmlformats-officedocument.presentationml.presentation')
            url = f'https://docs.google.com/presentation/d/{file_id}/edit'
        elif mime_type.startswith('application/'):
            local_file_path = save_local_file(self.service, file_id, name)
            if local_file_path and name.endswith('.xlsx'):
                # Convert downloaded .xlsx files to .csv before indexing
                df = pd.read_excel(local_file_path)
                csv_file_path = local_file_path.replace('.xlsx', '.csv')
                df.to_csv(csv_file_path, index=False)
                local_file_path = csv_file_path
            url = f'https://drive.google.com/file/d/{file_id}/view'
        else:
            logging.info(f"Unsupported file type: {mime_type}")
            return None, None

        if local_file_path:
            logging.info(f"local_file_path :: {local_file_path}")
            return local_file_path, url
        else:
            logging.info("local_file_path :: None")
            return None, None

    def crawl_file(self, file):
        local_file_path, url = self.handle_file(file)
        if local_file_path:
            file_id = file['id']
            name = file['name']
            created_time = file.get('createdTime', 'N/A')
            modified_time = file.get('modifiedTime', 'N/A')
            owners = ', '.join([owner['displayName'] for owner in file.get('owners', [])])
            size = file.get('size', 'N/A')

            logging.info(f'\nCrawling file {name}')

            file_metadata = {
                'id': file_id,
                'name': name,
                'created_at': created_time,
                'modified_at': modified_time,
                'owners': owners,
                'size': size,
                'source': 'gdrive'
            }

            try:
                self.indexer.index_file(filename=local_file_path, uri=url, metadata=file_metadata)
            except Exception as e:
                logging.info(f"Error {e} indexing document for file {name}, file_id {file_id}")

    def crawl(self) -> None:
        N = 7  # Number of days to look back
        date_threshold = datetime.utcnow() - timedelta(days=N)

        for user in self.delegated_users:
            logging.info(f"Processing files for user: {user}")
            self.creds = get_credentials(user)
            self.service = build("drive", "v3", credentials=self.creds)

            list_files = self.list_files(self.service, date_threshold=date_threshold.isoformat() + 'Z')
Collaborator:
I see you're using the ISO format + "Z": is this because the Google API requires the timestamp to be in UTC and in this format?

            for file in list_files:
                self.crawl_file(file)
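
On the reviewer's question above about the trailing "Z": the Drive API documents modifiedTime comparisons as RFC 3339 timestamps (UTC by default), and since datetime.utcnow() returns a naive UTC value, appending 'Z' marks it explicitly as UTC. A small sketch of the query string this produces (values are illustrative only, not from the PR):

from datetime import datetime, timedelta

# Illustrative only: the timestamp format the crawler passes to its files().list query.
date_threshold = (datetime.utcnow() - timedelta(days=7)).isoformat() + 'Z'
query = f"('root' in parents or sharedWithMe) and trashed=false and modifiedTime > '{date_threshold}'"
print(query)
# e.g. ('root' in parents or sharedWithMe) and trashed=false and modifiedTime > '2024-07-11T09:15:02.123456Z'
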
23 changes: 21 additions & 2 deletions ingest.py
@@ -14,6 +14,7 @@
from authlib.integrations.requests_client import OAuth2Session

def instantiate_crawler(base_class, folder_name: str, class_name: str, *args, **kwargs) -> Any: # type: ignore
    logging.info('inside instantiate crawler')
    sys.path.insert(0, os.path.abspath(folder_name))

    crawler_name = class_name.split('Crawler')[0]
@@ -27,6 +28,7 @@ def instantiate_crawler(base_class, folder_name: str, class_name: str, *args, **
        raise TypeError(f"{class_name} is not a subclass of {base_class.__name__}")

    # Instantiate the class and return the instance
    logging.info('end of instantiate crawler')
    return class_(*args, **kwargs)

def get_jwt_token(auth_url: str, auth_id: str, auth_secret: str, customer_id: str) -> Any:
@@ -76,6 +78,8 @@ def main() -> None:
    if len(sys.argv) != 3:
        logging.info("Usage: python ingest.py <config_file> <secrets-profile>")
        return

    logging.info("Starting the Crawler...")
    config_name = sys.argv[1]
    profile_name = sys.argv[2]

@@ -89,7 +93,7 @@ def main() -> None:
    if profile_name not in env_dict:
        logging.info(f'Profile "{profile_name}" not found in secrets.toml')
        return

    logging.info(f'Using profile "{profile_name}" from secrets.toml')
    # Add all keys from "general" section to the vectara config
    general_dict = env_dict.get('general', {})
    for k,v in general_dict.items():
@@ -129,15 +133,30 @@ def main() -> None:
            # default (otherwise) - add to vectara config
            OmegaConf.update(cfg['vectara'], k, v)

logging.info(f"Configuration loaded...")
endpoint = cfg.vectara.get("endpoint", "api.vectara.io")
customer_id = cfg.vectara.customer_id
corpus_id = cfg.vectara.corpus_id
api_key = cfg.vectara.api_key
crawler_type = cfg.crawling.crawler_type

# instantiate the crawler
crawler = instantiate_crawler(Crawler, 'crawlers', f'{crawler_type.capitalize()}Crawler', cfg, endpoint, customer_id, corpus_id, api_key)
# crawler = instantiate_crawler(Crawler, 'crawlers', f'{crawler_type.capitalize()}Crawler', cfg, endpoint, customer_id, corpus_id, api_key)

# Conditionally extract delegated_users if the crawler type is gdrive
if crawler_type == "gdrive":
delegated_users = cfg.gdrive_crawler.delegated_users
crawler = instantiate_crawler(
Crawler, 'crawlers', f'{crawler_type.capitalize()}Crawler',
cfg, endpoint, customer_id, corpus_id, api_key, delegated_users
)
else:
crawler = instantiate_crawler(
Crawler, 'crawlers', f'{crawler_type.capitalize()}Crawler',
cfg, endpoint, customer_id, corpus_id, api_key
)

logging.info(f"Crawling instantiated...")
# When debugging a crawler, it is sometimes useful to reset the corpus (remove all documents)
# To do that you would have to set this to True and also include <auth_url> and <auth_id> in the secrets.toml file
# NOTE: use with caution; this will delete all documents in the corpus and is irreversible
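
As a usage note: with crawler_type: gdrive in the YAML, the class-name derivation in instantiate_crawler resolves to GdriveCrawler. A tiny illustration of the naming convention visible in the diff (the comment about the module file is an assumption, since the import line is not shown here):

crawler_type = "gdrive"
class_name = f"{crawler_type.capitalize()}Crawler"
print(class_name)                       # -> "GdriveCrawler"
print(class_name.split('Crawler')[0])   # -> "Gdrive", presumably used to locate crawlers/gdrive_crawler.py

Per the usage message above, the crawler would then be launched with python ingest.py config/vectara-gdrive.yaml <secrets-profile>.
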