From 5a5c9f30bd232502ddc7d238f17c1382281eda9e Mon Sep 17 00:00:00 2001 From: hwzhuhao <923196325@qq.com> Date: Tue, 15 Oct 2024 10:29:25 +0800 Subject: [PATCH] refactor: Add an enumeration type and use the factory pattern to obtain the corresponding class --- .../rag/datasource/keyword/keyword_factory.py | 25 ++++++++++-------- .../rag/datasource/keyword/keyword_type.py | 5 ++++ api/services/auth/api_key_auth_factory.py | 26 +++++++++++++------ api/services/auth/auth_type.py | 6 +++++ api/services/auth/firecrawl/__init__.py | 0 .../auth/{ => firecrawl}/firecrawl.py | 0 api/services/auth/jina/__init__.py | 0 api/services/auth/{ => jina}/jina.py | 0 8 files changed, 43 insertions(+), 19 deletions(-) create mode 100644 api/core/rag/datasource/keyword/keyword_type.py create mode 100644 api/services/auth/auth_type.py create mode 100644 api/services/auth/firecrawl/__init__.py rename api/services/auth/{ => firecrawl}/firecrawl.py (100%) create mode 100644 api/services/auth/jina/__init__.py rename api/services/auth/{ => jina}/jina.py (100%) diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 3c99f33be61e3..f1a6ade91f9bd 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -1,8 +1,8 @@ from typing import Any from configs import dify_config -from core.rag.datasource.keyword.jieba.jieba import Jieba from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.datasource.keyword.keyword_type import KeyWordType from core.rag.models.document import Document from models.dataset import Dataset @@ -13,16 +13,19 @@ def __init__(self, dataset: Dataset): self._keyword_processor = self._init_keyword() def _init_keyword(self) -> BaseKeyword: - config = dify_config - keyword_type = config.KEYWORD_STORE - - if not keyword_type: - raise ValueError("Keyword store must be specified.") - - if keyword_type == "jieba": - return Jieba(dataset=self._dataset) - else: - raise ValueError(f"Keyword store {keyword_type} is not supported.") + keyword_type = dify_config.KEYWORD_STORE + keyword_factory = self.get_keyword_factory(keyword_type) + return keyword_factory(self._dataset) + + @staticmethod + def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]: + match keyword_type: + case KeyWordType.JIEBA: + from core.rag.datasource.keyword.jieba.jieba import Jieba + + return Jieba + case _: + raise ValueError(f"Keyword store {keyword_type} is not supported.") def create(self, texts: list[Document], **kwargs): self._keyword_processor.create(texts, **kwargs) diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py new file mode 100644 index 0000000000000..d6deba3fb09fd --- /dev/null +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -0,0 +1,5 @@ +from enum import Enum + + +class KeyWordType(str, Enum): + JIEBA = "jieba" diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 36387e9c2efdb..f91c448fb94a2 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,15 +1,25 @@ -from services.auth.firecrawl import FirecrawlAuth -from services.auth.jina import JinaAuth +from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.auth_type import AuthType class ApiKeyAuthFactory: def __init__(self, provider: str, credentials: dict): - if provider == "firecrawl": - self.auth = FirecrawlAuth(credentials) - elif provider == "jinareader": - self.auth = JinaAuth(credentials) - else: - raise ValueError("Invalid provider") + auth_factory = self.get_apikey_auth_factory(provider) + self.auth = auth_factory(credentials) def validate_credentials(self): return self.auth.validate_credentials() + + @staticmethod + def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]: + match provider: + case AuthType.FIRECRAWL: + from services.auth.firecrawl.firecrawl import FirecrawlAuth + + return FirecrawlAuth + case AuthType.JINA: + from services.auth.jina.jina import JinaAuth + + return JinaAuth + case _: + raise ValueError("Invalid provider") diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py new file mode 100644 index 0000000000000..2d6e901447c36 --- /dev/null +++ b/api/services/auth/auth_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class AuthType(str, Enum): + FIRECRAWL = "firecrawl" + JINA = "jinareader" diff --git a/api/services/auth/firecrawl/__init__.py b/api/services/auth/firecrawl/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py similarity index 100% rename from api/services/auth/firecrawl.py rename to api/services/auth/firecrawl/firecrawl.py diff --git a/api/services/auth/jina/__init__.py b/api/services/auth/jina/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/api/services/auth/jina.py b/api/services/auth/jina/jina.py similarity index 100% rename from api/services/auth/jina.py rename to api/services/auth/jina/jina.py