From 43544b3eb4654b24f6cc7060c1487d62c40d4b0b Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 25 Oct 2024 17:57:48 +0800 Subject: [PATCH 1/2] add document lock for multi-thread --- .../service_api/dataset/document.py | 1 + api/services/dataset_service.py | 308 +++++++++--------- 2 files changed, 156 insertions(+), 153 deletions(-) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index fb48a6c76c05e..47f3fde7f718f 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -17,6 +17,7 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db +from extensions.ext_redis import redis_client from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ca084bde5656d..414ef0224a233 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -760,166 +760,168 @@ def save_document_with_dataset_id( ) db.session.add(dataset_process_rule) db.session.commit() - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] - for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) - .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() - ) + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if document_data["data_source"]["type"] == "upload_file": + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) - # raise error if file not found - if not file: - raise FileNotExistsError() + # raise error if file not found + if not file: + raise FileNotExistsError() - file_name = file.name - data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if document_data.get("duplicate", False): - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() - if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() - document.created_from = created_from - document.doc_form = document_data["doc_form"] - document.doc_language = document_data["doc_language"] - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = "waiting" - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, - 
created_from, - position, - account, - file_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - exist_page_ids = [] - exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", - enabled=True, - ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - exist_document[data_source_info["notion_page_id"]] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if document_data.get("duplicate", False): + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id + document.updated_at = datetime.datetime.utcnow() + document.created_from = created_from + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + file_name, + batch, ) - ).first() - if not data_source_binding: - raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: - if page["page_id"] not in exist_page_ids: - data_source_info = { - "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], - } - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, - created_from, - position, - account, - page["page_name"], - batch, + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif document_data["data_source"]["type"] == "notion_import": + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + exist_page_ids = [] + exist_document = {} + documents = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) 
+ exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info["workspace_id"] + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], + } + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page["page_id"]) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] + for url in urls: + data_source_info = { + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." else: - exist_document.pop(page["page_id"]) - # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] - for url in urls: - data_source_info = { - "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), - "mode": "crawl", - } - if len(url) > 255: - document_name = url[:200] + "..." 
- else: - document_name = url - document = DocumentService.build_document( - dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, - created_from, - position, - account, - document_name, - batch, - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() - # trigger async task - if document_ids: - document_indexing_task.delay(dataset.id, document_ids) - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) - return documents, batch + return documents, batch @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): From 1d1a25111fcd16814ec928d93b44fa5e1ecd2453 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 25 Oct 2024 18:01:51 +0800 Subject: [PATCH 2/2] add document lock for multi-thread --- api/controllers/service_api/dataset/document.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 47f3fde7f718f..fb48a6c76c05e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -17,7 +17,6 @@ from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db -from extensions.ext_redis import redis_client from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment
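
The substance of this series is a single pattern: per-dataset serialization of document creation behind a Redis distributed lock, so that concurrent requests writing to the same dataset (which read a starting position via `DocumentService.get_documents_position()` and then insert Document rows) run one at a time; the controller-side `redis_client` import added in PATCH 1/2 is reverted in PATCH 2/2, leaving the lock entirely inside the service layer. Below is a minimal, self-contained sketch of that locking pattern, assuming a plain redis-py client; `_documents`, `get_next_position`, and `add_documents` are hypothetical stand-ins for the SQLAlchemy-backed service code, not part of the patch.

```python
import redis

# Plain redis-py client; in the patch this role is played by Dify's shared
# `redis_client` from extensions.ext_redis.
redis_client = redis.Redis(host="localhost", port=6379, db=0)

# Hypothetical in-memory stand-in for the documents table, used only so the
# sketch runs on its own; the real code reads and writes Document rows via
# SQLAlchemy inside the same lock.
_documents: dict[str, list[int]] = {}


def get_next_position(dataset_id: str) -> int:
    # Stand-in for DocumentService.get_documents_position(): the next free
    # position for documents in this dataset.
    return len(_documents.get(dataset_id, [])) + 1


def add_documents(dataset_id: str, count: int) -> None:
    # One lock name per dataset: concurrent requests for the same dataset are
    # serialized, while different datasets still proceed in parallel.
    lock_name = "add_document_lock_dataset_id_{}".format(dataset_id)
    # timeout=600 mirrors the patch: the lock auto-expires after ten minutes,
    # so a crashed worker cannot leave the dataset locked forever.
    with redis_client.lock(lock_name, timeout=600):
        position = get_next_position(dataset_id)
        docs = _documents.setdefault(dataset_id, [])
        for _ in range(count):
            # Without the lock, two threads could read the same starting
            # position above and insert documents with duplicate positions.
            docs.append(position)
            position += 1


if __name__ == "__main__":
    add_documents("dataset-1", count=3)
    print(_documents)  # {'dataset-1': [1, 2, 3]}
```

The 600-second expiry is a safety valve rather than an exact budget: if registering a single batch of documents can legitimately take longer than that, the timeout would need to be raised or the work split, otherwise a second request could acquire the lock while the first is still writing.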