Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add document lock for multi-thread #9873

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 155 additions & 153 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,166 +760,168 @@ def save_document_with_dataset_id(
)
db.session.add(dataset_process_rule)
db.session.commit()
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)

# raise error if file not found
if not file:
raise FileNotExistsError()
# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if document_data.get("duplicate", False):
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from
document.doc_form = document_data["doc_form"]
document.doc_language = document_data["doc_language"]
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if document_data.get("duplicate", False):
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from
document.doc_form = document_data["doc_form"]
document.doc_language = document_data["doc_language"]
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
).first()
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info["pages"]:
if page["page_id"] not in exist_page_ids:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page["page_id"],
"notion_page_icon": page["page_icon"],
"type": page["type"],
}
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info,
created_from,
position,
account,
page["page_name"],
batch,
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif document_data["data_source"]["type"] == "notion_import":
notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
).first()
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info["pages"]:
if page["page_id"] not in exist_page_ids:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page["page_id"],
"notion_page_icon": page["page_icon"],
"type": page["type"],
}
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info,
created_from,
position,
account,
page["page_name"],
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page["page_id"])
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]["info_list"]["website_info_list"]
urls = website_info["urls"]
for url in urls:
data_source_info = {
"url": url,
"provider": website_info["provider"],
"job_id": website_info["job_id"],
"only_main_content": website_info.get("only_main_content", False),
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
exist_document.pop(page["page_id"])
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif document_data["data_source"]["type"] == "website_crawl":
website_info = document_data["data_source"]["info_list"]["website_info_list"]
urls = website_info["urls"]
for url in urls:
data_source_info = {
"url": url,
"provider": website_info["provider"],
"job_id": website_info["job_id"],
"only_main_content": website_info.get("only_main_content", False),
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
document_data["doc_language"],
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()

# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

return documents, batch
return documents, batch

@staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel):
Expand Down