From 1a541b6c365f57d04ed71f52df687606494f3e91 Mon Sep 17 00:00:00 2001 From: NeonKirill <74428618+NeonKirill@users.noreply.github.com> Date: Sat, 13 Apr 2024 19:50:33 +0200 Subject: [PATCH] Dev -> master (#68) * Increment Version * Fixed issues while making structure for neon api (#67) * Increment Version * Fixed issue when conversation id specified but unset (#69) * Increment Version * Project dependencies update (#70) * Migrated Python version to 3.10 * Cleaned up and restructured dependencies * Added pre commit hooks * Ran black formatter * Made setting up MySQLConnector optional * Updated project requirements * updated python version in references * Incremented subversion * Increment Version * Update for Helm deployment (#71) * Update containers to fully support ovos-config * WIP troubleshooting container init failures * Troubleshooting unit test failures * Troubleshooting unit test failures * Fix legacy config file checks * Replace 'update' with 'update_many' per PyMongo docs https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_many Update command validation to reference known commands enum directly --------- Co-authored-by: Daniel McKnight * Increment Version * Support for admin api (#72) * Added K8S and Rabbit MQ management utilities * added callback on missing k8s config * Fixed backward compatibility issue with pymongo * Added submind state processor * Moved out mq validation utility * incremented subversion * added retry-wait for klatchat observer * Fixed issue with caching * Simplified socket io logic * Added license notice --------- Co-authored-by: NeonKirill * Increment Version * Fix configuration handling (#75) Co-authored-by: Daniel McKnight * Increment Version * Added support for banning/unbanning subminds and admins endpoint for fetching chats data (#76) * Increment Version * Update PyPI action spec (#73) * Increment Version * Replaces `pull_master` with `propose_release` action (#78) Updates 
deployment automation to address #77 Co-authored-by: Daniel McKnight * Increment Version * Reorganized Mongo DB API (#79) * Refactored Mongo DB API to be more granular, optimized and generified queries in place * Fixed unittests * Addressed comments * Fixed issue with prompts * removed redundant property from request_tts * Increment Version * Fixed issue with unauthorized user * Increment Version * Troubleshooting Proctored Conversations (#80) * Increment Version * Upgraded Socket IO to use any transport * Increment Version * Support for personas (#81) * Initial implementation of Personas API in klatchat * Added personas endpoints and optimized middleware * Support for generic configs management --------- Co-authored-by: NeonKirill * Increment Version * Alpha -> Dev (#85) * added new deployment instructions to target alpha * Disabled pull request hook for CI flow * Fixed import issue * fixing "/ is not a connected namespace." * incremented controller expiration time and decreased ttl length * Fixing issues with updating user account (#83) * Reorganized Socket IO handlers * Fixing Chat Messages Overflow Issue (#84) * Removed dependency on chat_flow property to track messages in conversation * Increment Version --------- Co-authored-by: NeonDaniel Co-authored-by: NeonKirill --- .github/workflows/deploy_containers.yml | 17 +- .github/workflows/propose_release.yml | 27 + .github/workflows/publish_release.yml | 6 +- .github/workflows/publish_test_build.yml | 4 +- .github/workflows/pull_master.yml | 21 - .github/workflows/unit_tests.yml | 5 +- .pre-commit-config.yaml | 16 + chat_client/__main__.py | 11 +- chat_client/app.py | 65 +- chat_client/blueprints/auth.py | 61 +- chat_client/blueprints/base.py | 6 +- chat_client/blueprints/chat.py | 57 +- chat_client/blueprints/components.py | 79 +- chat_client/blueprints/users.py | 92 +- chat_client/client_config.py | 23 +- chat_client/client_utils/api_utils.py | 23 +- chat_client/static/js/chat_utils.js | 19 +- 
chat_client/static/js/db.js | 52 +- chat_client/static/js/klatchatNano.js | 2 +- chat_client/static/js/message_utils.js | 33 +- chat_client/static/js/sio.js | 6 +- chat_server/__main__.py | 12 +- chat_server/app.py | 108 +-- chat_server/blueprints/__init__.py | 40 + chat_server/blueprints/admin.py | 100 ++ chat_server/blueprints/auth.py | 84 +- chat_server/blueprints/chat.py | 146 +-- chat_server/blueprints/configs.py | 68 ++ chat_server/blueprints/files_api.py | 98 +- chat_server/blueprints/languages.py | 8 +- chat_server/blueprints/personas.py | 161 ++++ chat_server/blueprints/preferences.py | 23 +- chat_server/blueprints/users.py | 197 ++-- chat_server/constants/conversations.py | 8 +- chat_server/constants/users.py | 31 +- chat_server/server_config.py | 70 +- chat_server/server_utils/admin_utils.py | 52 ++ chat_server/server_utils/auth.py | 300 +++--- chat_server/server_utils/cache_utils.py | 13 +- .../server_utils/conversation_utils.py | 42 +- chat_server/server_utils/db_utils.py | 499 ---------- chat_server/server_utils/dependencies.py | 37 + chat_server/server_utils/enums.py | 9 +- chat_server/server_utils/exceptions.py | 55 ++ chat_server/server_utils/factory_utils.py | 6 +- chat_server/server_utils/http_utils.py | 95 +- chat_server/server_utils/k8s_utils.py | 80 ++ chat_server/server_utils/languages.py | 81 +- chat_server/server_utils/middleware.py | 80 ++ chat_server/server_utils/models/chats.py | 43 + chat_server/server_utils/models/configs.py | 13 + chat_server/server_utils/models/personas.py | 79 ++ chat_server/server_utils/models/users.py | 44 + chat_server/server_utils/os_utils.py | 2 +- chat_server/server_utils/prompt_utils.py | 103 --- chat_server/server_utils/rmq_utils.py | 174 ++++ chat_server/server_utils/sftp_utils.py | 18 +- chat_server/server_utils/user_utils.py | 84 +- chat_server/services/popularity_counter.py | 118 ++- chat_server/sio.py | 616 ------------- chat_server/sio/__init__.py | 37 + chat_server/sio/handlers/prompt.py | 137 +++ 
chat_server/sio/handlers/session.py | 62 ++ chat_server/sio/handlers/stt.py | 121 +++ chat_server/sio/handlers/translation.py | 125 +++ chat_server/sio/handlers/tts.py | 164 ++++ chat_server/sio/handlers/user_message.py | 192 ++++ chat_server/sio/server.py | 31 + chat_server/sio/utils.py | 120 +++ chat_server/tests/test_sio.py | 184 ++-- chat_server/tests/utils/app_utils.py | 7 +- chat_server/wsgi.py | 3 +- config.py | 57 +- dockerfiles/Dockerfile.base | 6 +- dockerfiles/Dockerfile.server | 1 + migration_scripts/__main__.py | 112 ++- .../constants/migration_constants.py | 6 +- migration_scripts/conversations.py | 105 ++- migration_scripts/shouts.py | 174 ++-- migration_scripts/users.py | 164 ++-- migration_scripts/utils/__init__.py | 25 +- migration_scripts/utils/conversation_utils.py | 12 +- migration_scripts/utils/shout_utils.py | 4 +- migration_scripts/utils/sql_utils.py | 4 +- .../legacy_migration_requirements.txt | 3 + requirements/requirements.txt | 35 +- scripts/file_merger.py | 115 ++- scripts/files_manipulator.py | 32 +- scripts/minifier.py | 78 +- services/klatchat_observer/__main__.py | 22 +- .../constants/neon_api_constants.py | 14 +- services/klatchat_observer/controller.py | 866 ++++++++++++------ .../klatchat_observer/utils/neon_api_utils.py | 14 +- setup.py | 45 +- tests/mock.py | 3 +- tests/test_db_utils.py | 119 ++- utils/__init__.py | 27 + utils/common.py | 32 +- utils/connection_utils.py | 28 +- utils/database_utils/base_connector.py | 8 +- utils/database_utils/db_controller.py | 39 +- utils/database_utils/mongo_utils/__init__.py | 8 +- .../mongo_utils/queries/__init__.py | 27 + .../mongo_utils/queries/constants.py | 54 ++ .../mongo_utils/queries/dao/__init__.py | 27 + .../mongo_utils/queries/dao/abc.py | 224 +++++ .../mongo_utils/queries/dao/chats.py | 124 +++ .../mongo_utils/queries/dao/configs.py | 57 ++ .../mongo_utils/queries/dao/personas.py | 37 + .../mongo_utils/queries/dao/prompts.py | 192 ++++ .../mongo_utils/queries/dao/shouts.py | 198 
++++ .../mongo_utils/queries/dao/users.py | 206 +++++ .../mongo_utils/queries/mongo_queries.py | 240 +++++ .../mongo_utils/queries/wrapper.py | 72 ++ .../database_utils/mongo_utils/structures.py | 129 +-- .../database_utils/mongo_utils/user_utils.py | 17 +- utils/database_utils/mongodb_connector.py | 61 +- utils/database_utils/mysql_connector.py | 12 +- utils/exceptions.py | 35 + utils/http_utils.py | 13 +- utils/logging_utils.py | 20 +- utils/template_utils.py | 20 +- version.py | 2 +- version_bump.py | 8 +- 124 files changed, 6386 insertions(+), 3152 deletions(-) create mode 100644 .github/workflows/propose_release.yml delete mode 100644 .github/workflows/pull_master.yml create mode 100644 .pre-commit-config.yaml create mode 100644 chat_server/blueprints/__init__.py create mode 100644 chat_server/blueprints/admin.py create mode 100644 chat_server/blueprints/configs.py create mode 100644 chat_server/blueprints/personas.py create mode 100644 chat_server/server_utils/admin_utils.py delete mode 100644 chat_server/server_utils/db_utils.py create mode 100644 chat_server/server_utils/dependencies.py create mode 100644 chat_server/server_utils/exceptions.py create mode 100644 chat_server/server_utils/k8s_utils.py create mode 100644 chat_server/server_utils/middleware.py create mode 100644 chat_server/server_utils/models/chats.py create mode 100644 chat_server/server_utils/models/configs.py create mode 100644 chat_server/server_utils/models/personas.py create mode 100644 chat_server/server_utils/models/users.py delete mode 100644 chat_server/server_utils/prompt_utils.py create mode 100644 chat_server/server_utils/rmq_utils.py delete mode 100644 chat_server/sio.py create mode 100644 chat_server/sio/__init__.py create mode 100644 chat_server/sio/handlers/prompt.py create mode 100644 chat_server/sio/handlers/session.py create mode 100644 chat_server/sio/handlers/stt.py create mode 100644 chat_server/sio/handlers/translation.py create mode 100644 
chat_server/sio/handlers/tts.py create mode 100644 chat_server/sio/handlers/user_message.py create mode 100644 chat_server/sio/server.py create mode 100644 chat_server/sio/utils.py create mode 100644 requirements/legacy_migration_requirements.txt create mode 100644 utils/__init__.py create mode 100644 utils/database_utils/mongo_utils/queries/__init__.py create mode 100644 utils/database_utils/mongo_utils/queries/constants.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/__init__.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/abc.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/chats.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/configs.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/personas.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/prompts.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/shouts.py create mode 100644 utils/database_utils/mongo_utils/queries/dao/users.py create mode 100644 utils/database_utils/mongo_utils/queries/mongo_queries.py create mode 100644 utils/database_utils/mongo_utils/queries/wrapper.py create mode 100644 utils/exceptions.py diff --git a/.github/workflows/deploy_containers.yml b/.github/workflows/deploy_containers.yml index 4ebbd1fd..227aa60d 100644 --- a/.github/workflows/deploy_containers.yml +++ b/.github/workflows/deploy_containers.yml @@ -5,17 +5,13 @@ on: branches: - dev - main - pull_request: + - alpha workflow_dispatch: inputs: version: description: 'Images tag' required: true - default: 'dev' -# images: -# description: 'Include Images (klatchat_observer|chat_client|chat_server) space-separated' -# required: false -# default: klatchat_observer chat_client chat_server + default: 'alpha' permissions: contents: write @@ -57,17 +53,10 @@ jobs: if: github.event_name == 'workflow_dispatch' run: | echo "VERSION=${{ github.event.inputs.version }}" >> $GITHUB_ENV -# echo ${{ github.event.inputs.images 
}} >> $IMAGES - - name: Apply Pull Request Environments - if: github.event_name == 'pull_request' - run: | - echo "VERSION=${{ github.event.pull_request.base.ref }}" >> $GITHUB_ENV -# echo klatchat_observer chat_client chat_server >> $IMAGES - name: Apply Push Environments if: github.event_name == 'push' run: | - echo "VERSION=$(echo ${GITHUB_HEAD_REF} | tr / -)" >> $GITHUB_ENV -# echo klatchat_observer chat_client chat_server >> $IMAGES + echo "VERSION=$(echo ${{ github.ref_name }} | tr / -)" >> $GITHUB_ENV - name: Clean Up Version run: | echo "VERSION=${VERSION//['/']/-}" >> $GITHUB_ENV diff --git a/.github/workflows/propose_release.yml b/.github/workflows/propose_release.yml new file mode 100644 index 00000000..81dfe43b --- /dev/null +++ b/.github/workflows/propose_release.yml @@ -0,0 +1,27 @@ +name: Propose Stable Release +on: + workflow_dispatch: + inputs: + release_type: + type: choice + description: Release Type + options: + - patch + - minor + - major +jobs: + update_version: + uses: neongeckocom/.github/.github/workflows/propose_semver_release.yml@master + with: + branch: dev + release_type: ${{ inputs.release_type }} + update_changelog: True + pull_changes: + uses: neongeckocom/.github/.github/workflows/pull_master.yml@master + needs: update_version + with: + pr_reviewer: neonreviewers + pr_assignee: ${{ github.actor }} + pr_draft: false + pr_title: ${{ needs.update_version.outputs.version }} + pr_body: ${{ needs.update_version.outputs.changelog }} \ No newline at end of file diff --git a/.github/workflows/publish_release.yml b/.github/workflows/publish_release.yml index 65ee102b..314e5363 100644 --- a/.github/workflows/publish_release.yml +++ b/.github/workflows/publish_release.yml @@ -26,7 +26,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: '3.10' - name: Install Build Tools run: | python -m pip install build wheel @@ -35,6 +35,6 @@ jobs: run: | python setup.py bdist_wheel - name: Publish to 
Test PyPI - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@release/v1 with: - password: ${{secrets.PYPI_TOKEN}} \ No newline at end of file + password: ${{secrets.PYPI_TOKEN}} diff --git a/.github/workflows/publish_test_build.yml b/.github/workflows/publish_test_build.yml index 4c31b14a..bc5bbba5 100644 --- a/.github/workflows/publish_test_build.yml +++ b/.github/workflows/publish_test_build.yml @@ -18,7 +18,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: '3.10' - name: Install Build Tools run: | python -m pip install build wheel @@ -34,6 +34,6 @@ jobs: run: | python setup.py bdist_wheel - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{secrets.PYPI_TOKEN}} diff --git a/.github/workflows/pull_master.yml b/.github/workflows/pull_master.yml deleted file mode 100644 index 8e9c5a87..00000000 --- a/.github/workflows/pull_master.yml +++ /dev/null @@ -1,21 +0,0 @@ -# This workflow will generate a PR for changes in cert into master - -name: Pull to Master -on: - push: - branches: - - dev - workflow_dispatch: - -jobs: - pull_changes: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: pull-request-action - uses: repo-sync/pull-request@v2 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - pr_reviewer: 'neonreviewers' - pr_assignee: 'neondaniel' - pr_draft: true \ No newline at end of file diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 0a00255f..8c13576e 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -7,7 +7,7 @@ jobs: unit_tests: strategy: matrix: - python-version: [ '3.8', '3.9' ] + python-version: [ '3.10' ] max-parallel: 1 runs-on: ubuntu-latest env: @@ -24,6 +24,7 @@ jobs: python -m pip install --upgrade pip pip install -r requirements/requirements.txt pip install -r 
requirements/test_requirements.txt + pip install -r requirements/legacy_migration_requirements.txt - name: Get Credential run: | mkdir -p ~/.local/share/neon @@ -51,7 +52,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: '3.10' - name: Install Build Tools run: | python -m pip install build wheel diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..51acee97 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: check-json + - id: check-symlinks + - id: end-of-file-fixer + - id: trailing-whitespace + - id: pretty-format-json + - id: requirements-txt-fixer + - id: sort-simple-yaml +- repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black diff --git a/chat_client/__main__.py b/chat_client/__main__.py index bdf894a2..a67aa0e9 100644 --- a/chat_client/__main__.py +++ b/chat_client/__main__.py @@ -31,7 +31,10 @@ from chat_client.wsgi import app -if __name__ == '__main__': - uvicorn.run(app=app, host=os.environ.get('HOST', '127.0.0.1'), - port=int(os.environ.get('PORT', 8001)), - log_level=os.environ.get('LOG_LEVEL', 'INFO').lower()) +if __name__ == "__main__": + uvicorn.run( + app=app, + host=os.environ.get("HOST", "127.0.0.1"), + port=int(os.environ.get("PORT", 8001)), + log_level=os.environ.get("LOG_LEVEL", "INFO").lower(), + ) diff --git a/chat_client/app.py b/chat_client/app.py index b5528f0c..96ad43dd 100644 --- a/chat_client/app.py +++ b/chat_client/app.py @@ -47,44 +47,55 @@ sys.path.append(os.path.pardir) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from .blueprints import base as base_blueprint, \ - chat as chat_blueprint, \ - components as components_blueprint +from .blueprints import ( + base as base_blueprint, + chat as chat_blueprint, + components as components_blueprint, +) def create_app() -> FastAPI: 
""" - Application factory for the Klatchat Client + Application factory for the Klatchat Client """ - app_version = get_version('chat_client/version.py') - LOG.name = os.environ.get('LOG_NAME', 'client_err') - LOG.base_path = os.environ.get('LOG_BASE_PATH', '.') - LOG.init(config={'level': os.environ.get('LOG_LEVEL', 'INFO'), 'path': os.environ.get('LOG_PATH', os.getcwd())}) - logger = LOG.create_logger('chat_client') + app_version = get_version("chat_client/version.py") + LOG.name = os.environ.get("LOG_NAME", "client_err") + LOG.base_path = os.environ.get("LOG_BASE_PATH", ".") + LOG.init( + config={ + "level": os.environ.get("LOG_LEVEL", "INFO"), + "path": os.environ.get("LOG_PATH", os.getcwd()), + } + ) + logger = LOG.create_logger("chat_client") logger.addHandler(logging.StreamHandler()) - LOG.info(f'Starting Klatchat Client v{app_version}') - chat_app = FastAPI(title="Klatchat Client", - version=app_version) + LOG.info(f"Starting Klatchat Client v{app_version}") + chat_app = FastAPI(title="Klatchat Client", version=app_version) @chat_app.middleware("http") async def log_requests(request: Request, call_next): """Logs requests and gracefully handles Internal Server Errors""" - idem = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) + idem = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) LOG.info(f"rid={idem} start request path={request.url.path}") start_time = time.time() try: response = await call_next(request) process_time = (time.time() - start_time) * 1000 - formatted_process_time = '{0:.2f}'.format(process_time) - LOG.info(f"rid={idem} completed_in={formatted_process_time}ms status_code={response.status_code}") + formatted_process_time = "{0:.2f}".format(process_time) + LOG.info( + f"rid={idem} completed_in={formatted_process_time}ms status_code={response.status_code}" + ) return response except ConnectionError as ex: LOG.error(ex) from .client_config import app_config - return Response(f'Connection error : 
{app_config["SERVER_URL"]}', status_code=404) + + return Response( + f'Connection error : {app_config["SERVER_URL"]}', status_code=404 + ) except Exception as ex: LOG.error(f"rid={idem} received an exception {ex}") - return Response(f'Chat server error occurred', status_code=500) + return Response(f"Chat server error occurred", status_code=500) # Redirects any not found pages to chats page @chat_app.exception_handler(StarletteHTTPException) @@ -92,20 +103,28 @@ async def custom_http_exception_handler(request, exc): if exc.status_code == status.HTTP_404_NOT_FOUND: return RedirectResponse("/chats") - __cors_allowed_origins = os.environ.get('COST_ALLOWED_ORIGINS', '') or '*' + __cors_allowed_origins = os.environ.get("CORS_ALLOWED_ORIGINS", "") or "*" - LOG.info(f'CORS_ALLOWED_ORIGINS={__cors_allowed_origins}') + LOG.info(f"CORS_ALLOWED_ORIGINS={__cors_allowed_origins}") chat_app.add_middleware( CORSMiddleware, - allow_origins=__cors_allowed_origins.split(','), + allow_origins=__cors_allowed_origins.split(","), allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) - static_suffix = '/build' if os.environ.get('KLAT_ENV', 'dev').upper() == 'PROD' else '' - chat_app.mount("/css", StaticFiles(directory=f"chat_client/static/css{static_suffix}"), name="css") - chat_app.mount("/js", StaticFiles(directory=f"chat_client/static/js{static_suffix}"), name="js") + static_suffix = ( + "/build" if os.environ.get("KLAT_ENV", "dev").upper() == "PROD" else "" + ) + chat_app.mount( + "/css", + StaticFiles(directory=f"chat_client/static/css{static_suffix}"), + name="css", + ) + chat_app.mount( + "/js", StaticFiles(directory=f"chat_client/static/js{static_suffix}"), name="js" + ) chat_app.mount("/img", StaticFiles(directory=f"chat_client/static/img"), name="img") chat_app.include_router(base_blueprint.router) diff --git a/chat_client/blueprints/auth.py b/chat_client/blueprints/auth.py index 9346219c..9344c7dc 100644 --- a/chat_client/blueprints/auth.py +++ 
b/chat_client/blueprints/auth.py @@ -35,24 +35,22 @@ router = APIRouter( prefix="/auth", - responses={'404': {"description": "Unknown endpoint"}}, + responses={"404": {"description": "Unknown endpoint"}}, ) @router.post("/login", response_class=JSONResponse) -async def login(username: str = Form(...), - password: str = Form(...)): +async def login(username: str = Form(...), password: str = Form(...)): """ - Forwards input login data to the Server API endpoint and handles the returned response + Forwards input login data to the Server API endpoint and handles the returned response - :param username: posted Form Data username param - :param password: posted Form Data password param + :param username: posted Form Data username param + :param password: posted Form Data password param - :returns Response object depending on returned status with refreshed session cookies if status_code == 200 + :returns Response object depending on returned status with refreshed session cookies if status_code == 200 """ - data = dict(username=username, - password=password) + data = dict(username=username, password=password) post_response = requests.post(f'{app_config["SERVER_URL"]}/auth/login', data=data) @@ -61,38 +59,37 @@ async def login(username: str = Form(...), response = JSONResponse(content=json_data, status_code=post_response.status_code) if post_response.status_code == 200: - for cookie in post_response.cookies: - - response.delete_cookie('session') + response.delete_cookie("session") response.set_cookie(key=cookie.name, value=cookie.value, httponly=True) - LOG.info(f'Login response for {username}: {json_data}') + LOG.info(f"Login response for {username}: {json_data}") return response @router.post("/signup", response_class=JSONResponse) -async def signup(nickname: str = Form(...), - first_name: str = Form(...), - last_name: str = Form(...), - password: str = Form(...)): +async def signup( + nickname: str = Form(...), + first_name: str = Form(...), + last_name: str = Form(...), 
+ password: str = Form(...), +): """ - Forwards new user signup data to the Server API endpoint and handles the returned response + Forwards new user signup data to the Server API endpoint and handles the returned response - :param nickname: posted Form Data nickname param - :param first_name: posted Form Data first name param - :param last_name: posted Form Data last name param - :param password: posted Form Data password param + :param nickname: posted Form Data nickname param + :param first_name: posted Form Data first name param + :param last_name: posted Form Data last name param + :param password: posted Form Data password param - :returns Response object depending on returned status with refreshed session cookies if status_code == 200 + :returns Response object depending on returned status with refreshed session cookies if status_code == 200 """ - data = dict(nickname=nickname, - first_name=first_name, - last_name=last_name, - password=password) + data = dict( + nickname=nickname, first_name=first_name, last_name=last_name, password=password + ) post_response = requests.post(f'{app_config["SERVER_URL"]}/auth/signup', data=data) @@ -101,13 +98,12 @@ async def signup(nickname: str = Form(...), response = JSONResponse(content=json_data, status_code=post_response.status_code) if post_response.status_code == 200: - - response.delete_cookie('session') + response.delete_cookie("session") for cookie in post_response.cookies: response.set_cookie(key=cookie.name, value=cookie.value, httponly=True) - LOG.info(f'Signup response for {nickname}: {json_data}') + LOG.info(f"Signup response for {nickname}: {json_data}") return response @@ -123,12 +119,11 @@ async def logout(): response = JSONResponse(content=json_data, status_code=logout_response.status_code) if logout_response.status_code == 200: - - response.delete_cookie('session') + response.delete_cookie("session") for cookie in logout_response.cookies: response.set_cookie(key=cookie.name, value=cookie.value, 
httponly=True) - LOG.info(f'Logout response: {json_data}') + LOG.info(f"Logout response: {json_data}") return response diff --git a/chat_client/blueprints/base.py b/chat_client/blueprints/base.py index 71ad85c1..0eea3450 100644 --- a/chat_client/blueprints/base.py +++ b/chat_client/blueprints/base.py @@ -33,7 +33,7 @@ router = APIRouter( prefix="/base", - responses={'404': {"description": "Unknown endpoint"}}, + responses={"404": {"description": "Unknown endpoint"}}, ) @@ -41,8 +41,8 @@ async def fetch_runtime_config(): """Fetches runtime config from local JSON file in provided location""" try: - runtime_configs = app_config.get('RUNTIME_CONFIG', {}) + runtime_configs = app_config.get("RUNTIME_CONFIG", {}) except Exception as ex: - LOG.error(f'Exception while fetching runtime configs: {ex}') + LOG.error(f"Exception while fetching runtime configs: {ex}") runtime_configs = {} return JSONResponse(content=runtime_configs) diff --git a/chat_client/blueprints/chat.py b/chat_client/blueprints/chat.py index 6e876a19..da466e1c 100644 --- a/chat_client/blueprints/chat.py +++ b/chat_client/blueprints/chat.py @@ -38,46 +38,53 @@ router = APIRouter( prefix="/chats", - responses={'404': {"description": "Unknown endpoint"}}, + responses={"404": {"description": "Unknown endpoint"}}, ) conversation_templates = Jinja2Templates(directory="chat_client/templates") -@router.get('/') +@router.get("/") async def chats(request: Request): """ - Renders chats page HTML as a response related to the input request + Renders chats page HTML as a response related to the input request - :param request: input Request object + :param request: input Request object - :returns chats template response + :returns chats template response """ - return conversation_templates.TemplateResponse("conversation/base.html", - {"request": request, - 'section': 'Followed Conversations', - 'add_sio': True, - 'redirect_to_https': - app_config.get('FORCE_HTTPS', False)}) + return 
conversation_templates.TemplateResponse( + "conversation/base.html", + { + "request": request, + "section": "Followed Conversations", + "add_sio": True, + "redirect_to_https": app_config.get("FORCE_HTTPS", False), + }, + ) @router.get("/nano_demo") async def nano_demo(request: Request): """ - Minimal working Example of Nano + Minimal working Example of Nano """ client_url = f'"{request.url.scheme}://{request.url.netloc}"' server_url = f'"{app_config["SERVER_URL"]}"' - if app_config.get('FORCE_HTTPS', False): - client_url = client_url.replace('http://', 'https://') - server_url = server_url.replace('http://', 'https://') - client_url_unquoted = client_url.replace('"', '') - return conversation_templates.TemplateResponse("sample_nano.html", - {"request": request, - 'title': 'Nano Demonstration', - 'description': 'Klatchat Nano is injectable JS module, ' - 'allowing to render Klat conversations on any third-party pages, ' - 'supporting essential features.', - 'server_url': server_url, - 'client_url': client_url, - 'client_url_unquoted': client_url_unquoted}) + if app_config.get("FORCE_HTTPS", False): + client_url = client_url.replace("http://", "https://") + server_url = server_url.replace("http://", "https://") + client_url_unquoted = client_url.replace('"', "") + return conversation_templates.TemplateResponse( + "sample_nano.html", + { + "request": request, + "title": "Nano Demonstration", + "description": "Klatchat Nano is injectable JS module, " + "allowing to render Klat conversations on any third-party pages, " + "supporting essential features.", + "server_url": server_url, + "client_url": client_url, + "client_url_unquoted": client_url_unquoted, + }, + ) diff --git a/chat_client/blueprints/components.py b/chat_client/blueprints/components.py index c414c6d5..6a3f0e85 100644 --- a/chat_client/blueprints/components.py +++ b/chat_client/blueprints/components.py @@ -38,69 +38,76 @@ router = APIRouter( prefix="/components", - responses={'404': {"description": 
"Unknown endpoint"}}, + responses={"404": {"description": "Unknown endpoint"}}, ) -@router.get('/profile') -async def get_profile_modal(request: Request, nickname: str = '', edit: str = '0'): - """ Callbacks template with matching modal populated with user's data """ - auth_header = 'Authorization' - headers = {auth_header: request.headers.get(auth_header, '')} - if edit == '1': +@router.get("/profile") +async def get_profile_modal(request: Request, nickname: str = "", edit: str = "0"): + """Callbacks template with matching modal populated with user's data""" + auth_header = "Authorization" + headers = {auth_header: request.headers.get(auth_header, "")} + if edit == "1": resp = requests.get(f'{app_config["SERVER_URL"]}/users_api/', headers=headers) if resp.ok: - user = resp.json()['data'] + user = resp.json()["data"] # if user.get('is_tmp'): # raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, # detail='Cannot render edit modal for tmp user') else: - return respond('Server was not able to process the request', 422) - template_name = 'edit_profile_modal' + return respond("Server was not able to process the request", 422) + template_name = "edit_profile_modal" else: if not nickname: - return respond('No nickname provided', 422) - resp = requests.get(f'{app_config["SERVER_URL"]}/users_api/get_users?nicknames={nickname}', headers=headers) + return respond("No nickname provided", 422) + resp = requests.get( + f'{app_config["SERVER_URL"]}/users_api/get_users?nicknames={nickname}', + headers=headers, + ) if resp.ok: - user_data = resp.json().get('users', []) + user_data = resp.json().get("users", []) if not user_data: - return respond(f'User with nickname={nickname} not found', 404) + return respond(f"User with nickname={nickname} not found", 404) else: user = user_data[0] else: - return respond('Server was not able to process the request', 422) - template_name = 'profile_modal' - context = {'server_url': app_config["SERVER_URL"], - 'user_id': user['_id'], - 
'nickname': user['nickname'], - 'first_name': user.get('first_name', 'Klat'), - 'last_name': user.get('last_name', 'User'), - 'bio': user.get('bio', f'No information about {user["nickname"]}')} - return callback_template(request=request, template_name=template_name, context=context) + return respond("Server was not able to process the request", 422) + template_name = "profile_modal" + context = { + "server_url": app_config["SERVER_URL"], + "user_id": user["_id"], + "nickname": user["nickname"], + "first_name": user.get("first_name", "Klat"), + "last_name": user.get("last_name", "User"), + "bio": user.get("bio", f'No information about {user["nickname"]}'), + } + return callback_template( + request=request, template_name=template_name, context=context + ) -@router.get('/conversation') -async def render_conversation(request: Request, skin: str = 'base'): +@router.get("/conversation") +async def render_conversation(request: Request, skin: str = "base"): """ - Base renderer by the provided HTML template name + Base renderer by the provided HTML template name - :param request: FastAPI request object - :param skin: conversation skin to render (defaults to 'base') + :param request: FastAPI request object + :param skin: conversation skin to render (defaults to 'base') - :returns chats conversation response + :returns chats conversation response """ - folder = 'conversation_skins' - return callback_template(request=request, template_name=f'{folder}/{skin}') + folder = "conversation_skins" + return callback_template(request=request, template_name=f"{folder}/{skin}") -@router.get('/{template_name}') +@router.get("/{template_name}") async def render_template(request: Request, template_name: str): """ - Base renderer by the provided HTML template name + Base renderer by the provided HTML template name - :param request: FastAPI request object - :param template_name: name of template to fetch + :param request: FastAPI request object + :param template_name: name of template to 
fetch - :returns chats conversation response + :returns chats conversation response """ return callback_template(request=request, template_name=template_name) diff --git a/chat_client/blueprints/users.py b/chat_client/blueprints/users.py index 0fd6165a..3312a28d 100644 --- a/chat_client/blueprints/users.py +++ b/chat_client/blueprints/users.py @@ -40,24 +40,24 @@ router = APIRouter( prefix="/users", - responses={'404': {"description": "Unknown user"}}, + responses={"404": {"description": "Unknown user"}}, ) @router.get("/") async def get_user(response: Response, request: Request, user_id: Optional[str] = None): """ - Forwards getting user by id to the server API and handles the response cookies + Forwards getting user by id to the server API and handles the response cookies - :param request: input request object - :param response: output response object with applied cookies from server response - :param user_id: requested user id + :param request: input request object + :param response: output response object with applied cookies from server response + :param user_id: requested user id - :returns JSON-formatted response from server + :returns JSON-formatted response from server """ - user_id = user_id or '' + user_id = user_id or "" url = f'{app_config["SERVER_URL"]}/users_api?user_id={user_id}' - LOG.info(f'Getting user from url = {url}') + LOG.info(f"Getting user from url = {url}") get_user_response = requests.get(url, cookies=request.cookies) if not get_user_response or get_user_response.status_code != 200: raise HTTPException( @@ -70,45 +70,55 @@ async def get_user(response: Response, request: Request, user_id: Optional[str] @router.post("/update") -async def update_user(request: Request, - user_id: str = Form(...), - first_name: str = Form(""), - last_name: str = Form(""), - bio: str = Form(""), - nickname: str = Form(""), - password: str = Form(""), - repeat_password: str = Form(""), - avatar: UploadFile = File(None), ): +async def update_user( + request: 
Request, + user_id: str = Form(...), + first_name: str = Form(""), + last_name: str = Form(""), + bio: str = Form(""), + nickname: str = Form(""), + password: str = Form(""), + repeat_password: str = Form(""), + avatar: UploadFile = File(None), +): """ - Forwards getting user by id to the server API and handles the response cookies + Forwards getting user by id to the server API and handles the response cookies - :param request: input request object - :param user_id: requested user id - :param first_name: new first name value - :param last_name: new last name value - :param nickname: new nickname value - :param bio: updated user's bio - :param password: new password - :param repeat_password: repeat new password - :param avatar: new avatar image + :param request: input request object + :param user_id: requested user id + :param first_name: new first name value + :param last_name: new last name value + :param nickname: new nickname value + :param bio: updated user's bio + :param password: new password + :param repeat_password: repeat new password + :param avatar: new avatar image - :returns JSON-formatted response from server + :returns JSON-formatted response from server """ send_kwargs = { - 'data': { - 'user_id': user_id, - 'first_name': first_name, - 'last_name': last_name, - 'bio': bio, - 'nickname': nickname, - 'password': password, - 'repeat_password': repeat_password, + "data": { + "user_id": user_id, + "first_name": first_name, + "last_name": last_name, + "bio": bio, + "nickname": nickname, + "password": password, + "repeat_password": repeat_password, } } if avatar and avatar.filename: - send_kwargs['files'] = {'avatar': (avatar.filename, avatar.file.read(), avatar.content_type, )} + send_kwargs["files"] = { + "avatar": ( + avatar.filename, + avatar.file.read(), + avatar.content_type, + ) + } - return call_server(url_suffix='/users_api/update', - request_method='post', - request=request, - **send_kwargs) + return call_server( + 
url_suffix="/users_api/update", + request_method="post", + request=request, + **send_kwargs, + ) diff --git a/chat_client/client_config.py b/chat_client/client_config.py index 9f9e935d..acf75365 100644 --- a/chat_client/client_config.py +++ b/chat_client/client_config.py @@ -32,9 +32,20 @@ from config import Configuration from utils.logging_utils import LOG -config_file_path = os.environ.get('CHATCLIENT_CONFIG', '~/.local/share/neon/credentials_client.json') - -config = Configuration(from_files=[config_file_path]) -app_config = config.get('CHAT_CLIENT', {}).get(Configuration.KLAT_ENV) - -LOG.info(f'App config: {app_config}') +config_file_path = os.path.expanduser(os.environ.get( + "CHATCLIENT_CONFIG", "~/.local/share/neon/credentials_client.json" +)) +if os.path.isfile(config_file_path): + LOG.warning(f"Using legacy configuration at {config_file_path}") + config = Configuration(from_files=[config_file_path]) + app_config = config.get("CHAT_CLIENT", {}).get(Configuration.KLAT_ENV) +else: + # ovos-config has built-in mechanisms for loading configuration files based + # on envvars, so the configuration structure is simplified + from ovos_config.config import Configuration + app_config = Configuration().get("CHAT_CLIENT") + env_spec = os.environ.get("KLAT_ENV") + if env_spec and app_config.get(env_spec): + LOG.warning("Legacy configuration handling KLAT_ENV envvar") + app_config = app_config.get(env_spec) +LOG.info(f"App config: {app_config}") diff --git a/chat_client/client_utils/api_utils.py b/chat_client/client_utils/api_utils.py index 5544989c..9fb5df82 100644 --- a/chat_client/client_utils/api_utils.py +++ b/chat_client/client_utils/api_utils.py @@ -35,18 +35,25 @@ from utils.http_utils import respond -def call_server(url_suffix: str, request_method: str = 'get', - return_type: str = 'json', request: Request = None, **kwargs): - """ Convenience wrapper to call application server from client server""" +def call_server( + url_suffix: str, + request_method: str = 
"get", + return_type: str = "json", + request: Request = None, + **kwargs, +): + """Convenience wrapper to call application server from client server""" url = f'{app_config["SERVER_URL"]}{url_suffix}' if request: - kwargs['cookies'] = request.cookies + kwargs["cookies"] = request.cookies response = getattr(requests, request_method)(url, **kwargs) if response.ok: - if return_type == 'json': + if return_type == "json": return JSONResponse(content=response.json()) - elif return_type == 'text': + elif return_type == "text": return response.text else: - return respond(msg=response.json().get('msg', 'Server Error'), - status_code=response.status_code) + return respond( + msg=response.json().get("msg", "Server Error"), + status_code=response.status_code, + ) diff --git a/chat_client/static/js/chat_utils.js b/chat_client/static/js/chat_utils.js index 6a71c92e..d331796a 100644 --- a/chat_client/static/js/chat_utils.js +++ b/chat_client/static/js/chat_utils.js @@ -262,7 +262,7 @@ async function buildConversation(conversationData={}, skin = CONVERSATION_SKINS. const newConversationHTML = await buildConversationHTML(conversationData, skin); const conversationsBody = document.getElementById(conversationParentID); conversationsBody.insertAdjacentHTML('afterbegin', newConversationHTML); - initMessages(conversationData, skin); + await initMessages(conversationData, skin); const messageListContainer = getMessageListContainer(cid); const currentConversation = document.getElementById(cid); @@ -363,18 +363,18 @@ async function buildConversation(conversationData={}, skin = CONVERSATION_SKINS. 
/** * Gets conversation data based on input string * @param input: input string text - * @param firstMessageID: id of the the most recent message + * @param oldestMessageTS: creation timestamp of the oldest displayed message * @param skin: resolves by server for which data to return * @param maxResults: max number of messages to fetch * @param alertParent: parent of error alert (optional) * @returns {Promise<{}>} promise resolving conversation data returned */ -async function getConversationDataByInput(input="", skin=CONVERSATION_SKINS.BASE, firstMessageID=null, maxResults=20, alertParent=null){ +async function getConversationDataByInput(input="", skin=CONVERSATION_SKINS.BASE, oldestMessageTS=null, maxResults=20, alertParent=null){ let conversationData = {}; - if(input && typeof input === "string"){ - let query_url = `chat_api/search/${input}?limit_chat_history=${maxResults}&skin=${skin}`; - if(firstMessageID){ - query_url += `&first_message_id=${firstMessageID}`; + if(input){ + let query_url = `chat_api/search/${input.toString()}?limit_chat_history=${maxResults}&skin=${skin}`; + if(oldestMessageTS){ + query_url += `&creation_time_from=${oldestMessageTS}`; } await fetchServer(query_url) .then(response => { @@ -443,7 +443,8 @@ async function addNewCID(cid, skin){ * @param cid: conversation id to remove */ async function removeConversation(cid){ - return await getChatAlignmentTable().where({cid: cid}).delete(); + return await Promise.all([DBGateway.getInstance(DB_TABLES.CHAT_ALIGNMENT).deleteItem(cid), + DBGateway.getInstance(DB_TABLES.CHAT_MESSAGES_PAGINATION).deleteItem(cid)]); } /** @@ -698,7 +699,7 @@ async function createNewConversation(conversationName, isPrivate=false, conversa let formData = new FormData(); formData.append('conversation_name', conversationName); - formData.append('id', conversationID); + formData.append('conversation_id', conversationID); formData.append('is_private', isPrivate? 
'1': '0') formData.append('bound_service', boundServiceID?boundServiceID: ''); diff --git a/chat_client/static/js/db.js b/chat_client/static/js/db.js index bb654592..e60adbd3 100644 --- a/chat_client/static/js/db.js +++ b/chat_client/static/js/db.js @@ -3,12 +3,14 @@ const DATABASES = { } const DB_TABLES = { CHAT_ALIGNMENT: 'chat_alignment', - MINIFY_SETTINGS: 'minify_settings' + MINIFY_SETTINGS: 'minify_settings', + CHAT_MESSAGES_PAGINATION: 'chat_messages_pagination' } const __db_instances = {} const __db_definitions = { - "chats": { - "chat_alignment": `cid, added_on, skin` + [DATABASES.CHATS]: { + [DB_TABLES.CHAT_ALIGNMENT]: `cid, added_on, skin`, + [DB_TABLES.CHAT_MESSAGES_PAGINATION]: `cid, oldest_created_on` } } @@ -30,4 +32,46 @@ const getDb = (db, table) => { _instance = __db_instances[db]; } return _instance[table]; -} \ No newline at end of file +} + + +class DBGateway { + constructor(db, table) { + this.db = db; + this.table = table; + + this._db_instance = getDb(this.db, this.table); + this._db_columns_definitions = __db_definitions[this.db][this.table] + this._db_key = this._db_columns_definitions.split(',')[0] + } + + async getItem(key = "") { + return await this._db_instance.where( {[this._db_key]: key} ).first(); + } + + async listItems(orderBy="") { + let expression = this._db_instance; + if (orderBy !== ""){ + expression = expression.orderBy(orderBy) + } + return await expression.toArray(); + } + + async putItem(data = {}){ + return await this._db_instance.put(data, [data[this._db_key]]) + } + + updateItem(data = {}) { + const key = data[this._db_key] + delete data[this._db_key] + return this._db_instance.update(key, data); + } + + async deleteItem(key = "") { + return await this._db_instance.where({[this._db_key]: key}).delete(); + } + + static getInstance(table){ + return new DBGateway(DATABASES.CHATS, table); + } +} diff --git a/chat_client/static/js/klatchatNano.js b/chat_client/static/js/klatchatNano.js index 400a1067..4ef31b49 100644 --- 
a/chat_client/static/js/klatchatNano.js +++ b/chat_client/static/js/klatchatNano.js @@ -1441,7 +1441,7 @@ async function createNewConversation(conversationName, isPrivate = false, conver let formData = new FormData(); formData.append('conversation_name', conversationName); - formData.append('id', conversationID); + formData.append('conversation_id', conversationID); formData.append('is_private', isPrivate ? '1' : '0') formData.append('bound_service', boundServiceID ? boundServiceID : ''); diff --git a/chat_client/static/js/message_utils.js b/chat_client/static/js/message_utils.js index c80774ab..937e6581 100644 --- a/chat_client/static/js/message_utils.js +++ b/chat_client/static/js/message_utils.js @@ -165,10 +165,10 @@ async function addOldMessages(cid, skin=CONVERSATION_SKINS.BASE) { if (messageContainer.children.length > 0) { for (let i = 0; i < messageContainer.children.length; i++) { const firstMessageItem = messageContainer.children[i]; - const firstMessageID = getFirstMessageFromCID( firstMessageItem ); - if (firstMessageID) { + const oldestMessageTS = await DBGateway.getInstance(DB_TABLES.CHAT_MESSAGES_PAGINATION).getItem(cid).then(res=> res?.oldest_created_on || null); + if (oldestMessageTS) { const numMessages = await getCurrentSkin(cid) === CONVERSATION_SKINS.PROMPTS? 
50: 20; - await getConversationDataByInput( cid, skin, firstMessageID, numMessages, null ).then( async conversationData => { + await getConversationDataByInput( cid, skin, oldestMessageTS, numMessages, null ).then( async conversationData => { if (messageContainer) { const userMessageList = getUserMessages( conversationData, null ); userMessageList.sort( (a, b) => { @@ -183,7 +183,7 @@ async function addOldMessages(cid, skin=CONVERSATION_SKINS.BASE) { console.debug( `!!message_id=${message["message_id"]} is already displayed` ) } } - initMessages( conversationData, skin ); + await initMessages( conversationData, skin ); } } ).then( _ => { firstMessageItem.scrollIntoView( {behavior: "smooth"} ); @@ -293,7 +293,7 @@ function addProfileDisplay(cid, messageId, messageType='plain'){ /** * Inits addProfileDisplay() on each message of provided conversation - * @param conversationData: target conversation data + * @param conversationData - target conversation data */ function initProfileDisplay(conversationData){ getUserMessages(conversationData, null).forEach(message => { @@ -302,9 +302,25 @@ function initProfileDisplay(conversationData){ } +/** + * Inits pagination based on the oldest message creation timestamp + * @param conversationData - target conversation data + */ +async function initPagination(conversationData) { + const userMessages = getUserMessages(conversationData, null); + if (userMessages.length > 0){ + const oldestMessage = Math.min(...userMessages.map(msg => parseInt(msg.created_on))); + await DBGateway + .getInstance(DB_TABLES.CHAT_MESSAGES_PAGINATION) + .putItem({cid: conversationData['_id'], + oldest_created_on: oldestMessage}) + } +} + + /** * Initializes messages based on provided conversation aata - * @param conversationData: JS Object containing conversation data of type: + * @param conversationData - JS Object containing conversation data of type: * { * '_id': 'id of conversation', * 'conversation_name': 'title of the conversation', @@ -318,14 
+334,15 @@ function initProfileDisplay(conversationData){ * 'created_on': 'creation time of the message' * }, ... (num of user messages returned)] * } - * @param skin: target conversation skin to consider + * @param skin - target conversation skin to consider */ -function initMessages(conversationData, skin = CONVERSATION_SKINS.BASE){ +async function initMessages(conversationData, skin = CONVERSATION_SKINS.BASE){ initProfileDisplay(conversationData); attachReplies(conversationData); addAttachments(conversationData); addCommunicationChannelTransformCallback(conversationData); initLoadOldMessages(conversationData, skin); + await initPagination(conversationData); } /** diff --git a/chat_client/static/js/sio.js b/chat_client/static/js/sio.js index 286f9802..508871e3 100644 --- a/chat_client/static/js/sio.js +++ b/chat_client/static/js/sio.js @@ -14,8 +14,8 @@ sioTriggeringEvents.forEach(event=>{ */ function initSIO(){ - const sioServerURL = configData['CHAT_SERVER_URL_BASE']; - const socket = io(sioServerURL, {transports: ['polling'], extraHeaders: { + const sioServerURL = configData['CHAT_SERVER_URL_BASE'].replace("http", 'ws'); + const socket = io(sioServerURL, {extraHeaders: { "session": getSessionToken() }}); @@ -105,4 +105,4 @@ function initSIO(){ // }); return socket; -} \ No newline at end of file +} diff --git a/chat_server/__main__.py b/chat_server/__main__.py index a7172bb9..ffd6eb74 100644 --- a/chat_server/__main__.py +++ b/chat_server/__main__.py @@ -31,7 +31,11 @@ from .wsgi import app -if __name__ == '__main__': - uvicorn.run(app=app, root_path=os.environ.get('URL_PREFIX', ''), host=os.environ.get('HOST', '127.0.0.1'), - port=int(os.environ.get('PORT', 8000)), - log_level=os.environ.get('LOG_LEVEL', 'INFO').lower()) +if __name__ == "__main__": + uvicorn.run( + app=app, + root_path=os.environ.get("URL_PREFIX", ""), + host=os.environ.get("HOST", "127.0.0.1"), + port=int(os.environ.get("PORT", 8000)), + log_level=os.environ.get("LOG_LEVEL", "INFO").lower(), 
+ ) diff --git a/chat_server/app.py b/chat_server/app.py index f28ffc59..8634b5a5 100644 --- a/chat_server/app.py +++ b/chat_server/app.py @@ -25,94 +25,66 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import logging +import importlib import os -import random -import string import sys -import time import socketio from typing import Union from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware from fastapi.testclient import TestClient -from starlette.requests import Request - -from utils.common import get_version -from utils.logging_utils import LOG +from starlette.middleware.cors import CORSMiddleware sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from .sio import sio -from .blueprints import auth as auth_blueprint, \ - chat as chat_blueprint, \ - users as users_blueprint, \ - languages as languages_blueprint, \ - files_api as files_blueprint, \ - preferences as preferences_blueprint +from utils.common import get_version +from utils.logging_utils import LOG +from chat_server.server_utils.middleware import SUPPORTED_MIDDLEWARE -def create_app(testing_mode: bool = False, sio_server: socketio.AsyncServer = sio) -> Union[FastAPI, socketio.ASGIApp]: +def create_app( + testing_mode: bool = False, sio_server: socketio.AsyncServer = None +) -> Union[FastAPI, socketio.ASGIApp]: """ - Application factory for the Klatchat Server + Application factory for the Klatchat Server - :param testing_mode: to run application in testing mode (defaults to False) - :param sio_server: socket io server instance (optional) + :param testing_mode: to run application in testing mode (defaults to False) + :param sio_server: socket io server instance (optional) """ - app_version = get_version('chat_server/version.py') - LOG.name = os.environ.get('LOG_NAME', 'server_err') - LOG.base_path = 
os.environ.get('LOG_BASE_PATH', '.') - LOG.init(config={'level': os.environ.get('LOG_LEVEL', 'INFO'), 'path': os.environ.get('LOG_PATH', os.getcwd())}) - logger = LOG.create_logger('chat_server') - logger.addHandler(logging.StreamHandler()) - LOG.info(f'Starting Klatchat Server v{app_version}') - chat_app = FastAPI(title="Klatchat Server API", - version=app_version) - - @chat_app.middleware("http") - async def log_requests(request: Request, call_next): - """Logs requests and gracefully handles Internal Server Errors""" - idem = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6)) - LOG.info(f"rid={idem} start request path={request.url.path}") - start_time = time.time() - try: - response = await call_next(request) - process_time = (time.time() - start_time) * 1000 - formatted_process_time = '{0:.2f}'.format(process_time) - log_message = f"rid={idem} completed_in={formatted_process_time}ms status_code={response.status_code}" - LOG.info(log_message) - return response - except Exception as ex: - LOG.error(f"rid={idem} received an exception {ex}") - return None + app_version = get_version("version.py") + chat_app = FastAPI(title="Klatchat Server API", version=app_version) - chat_app.include_router(auth_blueprint.router) - chat_app.include_router(chat_blueprint.router) - chat_app.include_router(users_blueprint.router) - chat_app.include_router(languages_blueprint.router) - chat_app.include_router(files_blueprint.router) - chat_app.include_router(preferences_blueprint.router) - - # __cors_allowed_origins = os.environ.get('COST_ALLOWED_ORIGINS', '').split(',') or ['*'] - # - # LOG.info(f'CORS_ALLOWED_ORIGINS={__cors_allowed_origins}') - # - # chat_app.user_middleware.clear() - chat_app.add_middleware( - CORSMiddleware, - allow_origins=['*'], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - # chat_app.middleware_stack = chat_app.build_middleware_stack() + _init_middleware(app=chat_app) + _init_blueprints(app=chat_app) if 
testing_mode: chat_app = TestClient(chat_app) if sio_server: - chat_app = socketio.ASGIApp(socketio_server=sio_server, - other_asgi_app=chat_app) + chat_app = socketio.ASGIApp(socketio_server=sio_server, other_asgi_app=chat_app) + + LOG.info(f"Starting Klatchat Server v{app_version}") return chat_app + + +def _init_blueprints(app: FastAPI): + blueprint_module = importlib.import_module("blueprints") + for blueprint_module_name in dir(blueprint_module): + if blueprint_module_name.endswith("blueprint"): + blueprint_obj = importlib.import_module( + f"blueprints.{blueprint_module_name.split('_blueprint')[0]}" + ) + app.include_router(blueprint_obj.router) + + +def _init_middleware(app: FastAPI): + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + for middleware_class in SUPPORTED_MIDDLEWARE: + app.add_middleware(middleware_class=middleware_class) diff --git a/chat_server/blueprints/__init__.py b/chat_server/blueprints/__init__.py new file mode 100644 index 00000000..6b527bb1 --- /dev/null +++ b/chat_server/blueprints/__init__.py @@ -0,0 +1,40 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Import blueprint here to include it to Web App +from . import ( + admin as admin_blueprint, + auth as auth_blueprint, + chat as chat_blueprint, + users as users_blueprint, + languages as languages_blueprint, + files_api as files_api_blueprint, + preferences as preferences_blueprint, + personas as personas_blueprint, + configs as configs_blueprint, +) diff --git a/chat_server/blueprints/admin.py b/chat_server/blueprints/admin.py new file mode 100644 index 00000000..63dcb0a2 --- /dev/null +++ b/chat_server/blueprints/admin.py @@ -0,0 +1,100 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. 
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from fastapi import APIRouter +from starlette.requests import Request +from starlette.responses import JSONResponse + +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI +from utils.logging_utils import LOG +from utils.http_utils import respond + +from chat_server.server_config import k8s_config +from chat_server.server_utils.auth import login_required +from chat_server.server_utils.k8s_utils import restart_deployment +from chat_server.server_utils.admin_utils import run_mq_validation + +router = APIRouter( + prefix="/admin", + responses={"404": {"description": "Unknown authorization endpoint"}}, +) + + +@router.post("/refresh/{service_name}") +@login_required(tmp_allowed=False, required_roles=["admin"]) +async def refresh_state( + request: Request, service_name: str, target_items: str | None = "" +): + """ + Refreshes state of the target + + :param request: Starlette Request Object + :param service_name: name of service to refresh + :param target_items: comma-separated list of items to refresh + + :returns JSON-formatted response from server + """ + target_items = [x for x in target_items.split(",") if x] + if service_name == "k8s": + if not k8s_config: + return respond("K8S Service Unavailable", 503) + deployments = target_items + if deployments == "*": + deployments = k8s_config.get("MANAGED_DEPLOYMENTS", []) + LOG.info(f"Restarting {deployments=!r}") + for deployment in deployments: + restart_deployment(deployment_name=deployment) + elif service_name == "mq": + run_mq_validation() + else: + return respond(f"Unknown refresh type: {service_name!r}", 404) + return respond("OK") + + +@router.get("/chats/list") +@login_required(tmp_allowed=False, required_roles=["admin"]) +async def chats_overview(request: Request, search_str: str = ""): + conversations_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=search_str, + limit=100, + allow_regex_search=True, + ) + result_data = [] + + for conversation_data in 
conversations_data: + + result_data.append( + { + "cid": conversation_data["_id"], + "conversation_name": conversation_data["conversation_name"], + "bound_service": conversation_data.get("bound_service", ""), + } + ) + # TODO: sort it based on PopularityCounter.get_first_n_items + + return JSONResponse(content=dict(data=result_data)) diff --git a/chat_server/blueprints/auth.py b/chat_server/blueprints/auth.py index 48f8456e..f86da171 100644 --- a/chat_server/blueprints/auth.py +++ b/chat_server/blueprints/auth.py @@ -31,51 +31,57 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import JSONResponse -from chat_server.server_config import db_controller from utils.common import get_hash, generate_uuid -from chat_server.server_utils.auth import check_password_strength, get_current_user_data, generate_session_token +from chat_server.server_utils.auth import ( + check_password_strength, + get_current_user_data, + generate_session_token, +) +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond router = APIRouter( prefix="/auth", - responses={'404': {"description": "Unknown authorization endpoint"}}, + responses={"404": {"description": "Unknown authorization endpoint"}}, ) @router.post("/signup") -async def signup(first_name: str = Form(...), - last_name: str = Form(...), - nickname: str = Form(...), - password: str = Form(...)): +async def signup( + first_name: str = Form(...), + last_name: str = Form(...), + nickname: str = Form(...), + password: str = Form(...), +): """ - Creates new user based on received form data + Creates new user based on received form data - :param first_name: new user first name - :param last_name: new user last name - :param nickname: new user nickname (unique) - :param password: new user password + :param first_name: new user first name + :param last_name: new user last name + :param nickname: new user nickname (unique) + :param password: new user password - 
:returns JSON response with status corresponding to the new user creation status, - sets session cookies if creation is successful + :returns JSON response with status corresponding to the new user creation status, + sets session cookies if creation is successful """ - existing_user = db_controller.exec_query(query={'command': 'find_one', - 'document': 'users', - 'data': {'nickname': nickname}}) + existing_user = MongoDocumentsAPI.USERS.get_user(nickname=nickname) if existing_user: return respond("Nickname is already in use", 400) password_check = check_password_strength(password) - if password_check != 'OK': + if password_check != "OK": return respond(password_check, 400) - new_user_record = dict(_id=generate_uuid(length=20), - first_name=first_name, - last_name=last_name, - password=get_hash(password), - nickname=nickname, - date_created=int(time()), - is_tmp=False) - db_controller.exec_query(query=dict(document='users', command='insert_one', data=new_user_record)) - - token = generate_session_token(user_id=new_user_record['_id']) + new_user_record = dict( + _id=generate_uuid(length=20), + first_name=first_name, + last_name=last_name, + password=get_hash(password), + nickname=nickname, + date_created=int(time()), + is_tmp=False, + ) + MongoDocumentsAPI.USERS.add_item(data=new_user_record) + + token = generate_session_token(user_id=new_user_record["_id"]) return JSONResponse(content=dict(token=token)) @@ -83,22 +89,20 @@ async def signup(first_name: str = Form(...), @router.post("/login") async def login(username: str = Form(...), password: str = Form(...)): """ - Logs In user based on provided credentials + Logs In user based on provided credentials - :param username: provided username (nickname) - :param password: provided password matching username + :param username: provided username (nickname) + :param password: provided password matching username - :returns JSON response with status corresponding to authorization status, sets session cookie with response + 
:returns JSON response with status corresponding to authorization status, sets session cookie with response """ - user = db_controller.exec_query(query={'command': 'find_one', - 'document': 'users', - 'data': {'nickname': username}}) - if not user or user.get('is_tmp', False): + user = MongoDocumentsAPI.USERS.get_user(nickname=username) + if not user or user.get("is_tmp", False): return respond("Invalid username or password", 400) db_password = user["password"] if get_hash(password) != db_password: return respond("Invalid username or password", 400) - token = generate_session_token(user_id=user['_id']) + token = generate_session_token(user_id=user["_id"]) response = JSONResponse(content=dict(token=token)) return response @@ -107,11 +111,11 @@ async def login(username: str = Form(...), password: str = Form(...)): @router.get("/logout") async def logout(request: Request): """ - Erases current user session cookies and returns temporal credentials + Erases current user session cookies and returns temporal credentials - :param request: logout intended request + :param request: logout intended request - :returns response with temporal cookie + :returns response with temporal cookie """ user_data = get_current_user_data(request=request, force_tmp=True) response = JSONResponse(content=dict(token=user_data.session)) diff --git a/chat_server/blueprints/chat.py b/chat_server/blueprints/chat.py index 714d2c26..6d81a57b 100644 --- a/chat_server/blueprints/chat.py +++ b/chat_server/blueprints/chat.py @@ -25,117 +25,131 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from typing import Optional +import warnings from time import time -from fastapi import APIRouter, Request, Form +from fastapi import APIRouter, Request, Form, Depends from fastapi.responses import JSONResponse -from chat_server.constants.conversations import ConversationSkins -from chat_server.server_config import db_controller from chat_server.server_utils.auth import login_required from chat_server.server_utils.conversation_utils import build_message_json -from chat_server.server_utils.db_utils import DbUtils, MongoQuery, MongoCommands, MongoDocuments +from chat_server.server_utils.dependencies import CurrentUserDependency +from chat_server.server_utils.models.chats import GetConversationModel from chat_server.services.popularity_counter import PopularityCounter from utils.common import generate_uuid +from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators +from utils.database_utils.mongo_utils.queries.mongo_queries import fetch_message_data +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG router = APIRouter( prefix="/chat_api", - responses={'404': {"description": "Unknown authorization endpoint"}}, + responses={"404": {"description": "Unknown authorization endpoint"}}, ) @router.post("/new") @login_required -async def new_conversation(request: Request, - conversation_id: str = Form(None), - conversation_name: str = Form(...), - is_private: str = Form(False), - bound_service: str = Form('')): +async def new_conversation( + request: Request, + conversation_id: str = Form(""), # DEPRECATED + conversation_name: str = Form(...), + is_private: str = Form(False), + bound_service: str = Form(""), +): """ - Creates new conversation from provided conversation data + Creates new conversation from provided conversation data - :param request: Starlette Request object - :param conversation_id: new conversation id (optional) - :param 
conversation_name: new conversation name (optional) - :param is_private: if new conversation should be private (defaults to False) - :param bound_service: name of the bound service (ignored if empty value) + :param request: Starlette Request object + :param conversation_id: new conversation id (DEPRECATED) + :param conversation_name: new conversation name (optional) + :param is_private: if new conversation should be private (defaults to False) + :param bound_service: name of the bound service (ignored if empty value) - :returns JSON response with new conversation data if added, 401 error message otherwise + :returns JSON response with new conversation data if added, 401 error message otherwise """ - - conversation_data = DbUtils.get_conversation_data(search_str=[conversation_id, conversation_name]) + if conversation_id: + warnings.warn( + "Param conversation id is no longer considered", DeprecationWarning + ) + conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=[conversation_name], + ) if conversation_data: return respond(f'Conversation "{conversation_name}" already exists', 400) - cid = conversation_id or generate_uuid() - request_data_dict = {'_id': cid, - 'conversation_name': conversation_name, - 'is_private': True if is_private == '1' else False, - 'bound_service': bound_service, - 'created_on': int(time())} - db_controller.exec_query(query=MongoQuery(command=MongoCommands.INSERT_ONE, - document=MongoDocuments.CHATS, - data=request_data_dict)) - PopularityCounter.add_new_chat(cid=cid, name=conversation_name) + cid = generate_uuid() + request_data_dict = { + "_id": cid, + "conversation_name": conversation_name, + "is_private": True if is_private == "1" else False, + "bound_service": bound_service, + "created_on": int(time()), + } + MongoDocumentsAPI.CHATS.add_item(data=request_data_dict) + PopularityCounter.add_new_chat(cid=cid) return JSONResponse(content=request_data_dict) @router.get("/search/{search_str}") -# @login_required 
-async def get_matching_conversation(request: Request, - search_str: str, - chat_history_from: int = 0, - first_message_id: Optional[str] = None, - limit_chat_history: int = 100, - skin: str = ConversationSkins.BASE): +async def get_matching_conversation( + current_user: CurrentUserDependency, model: GetConversationModel = Depends() +): """ - Gets conversation data matching search string + Gets conversation data matching search string - :param request: Starlette Request object - :param search_str: provided search string - :param chat_history_from: upper time bound for messages - :param first_message_id: id of the first message to start from - :param limit_chat_history: lower time bound for messages - :param skin: conversation skin type from ConversationSkins + :param current_user: current user data + :param model: request data model described in GetConversationModel - :returns conversation data if found, 401 error code otherwise + :returns conversation data if found, 401 error code otherwise """ - conversation_data = DbUtils.get_conversation_data(search_str=search_str) + conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=model.search_str, requested_user_id=current_user.user_id + ) if not conversation_data: - return respond(f'No conversation matching = "{search_str}"', 404) - - message_data = DbUtils.fetch_skin_message_data(skin=skin, - conversation_data=conversation_data, - start_idx=chat_history_from, - limit=limit_chat_history, - start_message_id=first_message_id) or [] - conversation_data['chat_flow'] = [] - for i in range(len(message_data)): - message_record = build_message_json(raw_message=message_data[i], skin=skin) - conversation_data['chat_flow'].append(message_record) + return respond(f'No conversation matching = "{model.search_str}"', 404) + + if model.creation_time_from: + query_filter = MongoFilter( + key="created_on", + logical_operator=MongoLogicalOperators.LT, + value=int(model.creation_time_from), + ) + else: + 
query_filter = None + + message_data = ( + fetch_message_data( + skin=model.skin, + conversation_data=conversation_data, + limit=model.limit_chat_history, + creation_time_filter=query_filter, + ) + or [] + ) + conversation_data["chat_flow"] = [ + build_message_json(raw_message=message_data[i], skin=model.skin) + for i in range(len(message_data)) + ] return conversation_data @router.get("/get_popular_cids") -async def get_popular_cids(search_str: str = "", - exclude_items="", - limit: int = 10): +async def get_popular_cids(search_str: str = "", exclude_items="", limit: int = 10): """ - Returns n-most popular conversations + Returns n-most popular conversations - :param search_str: Searched substring to match - :param exclude_items: list of conversation ids to exclude from search - :param limit: limit returned amount of matched instances + :param search_str: Searched substring to match + :param exclude_items: list of conversation ids to exclude from search + :param limit: limit returned amount of matched instances """ try: if exclude_items: - exclude_items = exclude_items.split(',') + exclude_items = exclude_items.split(",") items = PopularityCounter.get_first_n_items(search_str, exclude_items, limit) except Exception as ex: - LOG.error(f'Failed to extract most popular items - {ex}') + LOG.error(f"Failed to extract most popular items - {ex}") items = [] return JSONResponse(content=items) diff --git a/chat_server/blueprints/configs.py b/chat_server/blueprints/configs.py new file mode 100644 index 00000000..6f500011 --- /dev/null +++ b/chat_server/blueprints/configs.py @@ -0,0 +1,68 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. 
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from fastapi import APIRouter, Depends +from starlette.responses import JSONResponse + +from chat_server.server_utils.dependencies import CurrentUserDependency +from chat_server.server_utils.exceptions import ( + ItemNotFoundException, + UserUnauthorizedException, +) +from chat_server.server_utils.http_utils import KlatAPIResponse +from chat_server.server_utils.models.configs import SetConfigModel, ConfigModel +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI + +router = APIRouter( + prefix="/configs", + responses={"404": {"description": "Unknown endpoint"}}, +) + + +@router.get("/{config_property}") +async def get_config_data(model: ConfigModel = Depends()) -> JSONResponse: + """Retrieves configured data by name""" + items = MongoDocumentsAPI.CONFIGS.get_by_name( + config_name=model.config_property, version=model.version + ) + return JSONResponse(content=items) + + +@router.put("/{config_property}") +async def update_config( + current_user: CurrentUserDependency, model: SetConfigModel = Depends() +) -> JSONResponse: + """Updates provided config by name""" + if "admin" not in current_user.roles: + raise UserUnauthorizedException + updated_data = MongoDocumentsAPI.CONFIGS.update_by_name( + config_name=model.config_property, version=model.version, data=model.data + ) + if updated_data.matched_count == 0: + raise ItemNotFoundException + return KlatAPIResponse.OK diff --git a/chat_server/blueprints/files_api.py b/chat_server/blueprints/files_api.py index 6f282795..47be5778 100644 --- a/chat_server/blueprints/files_api.py +++ b/chat_server/blueprints/files_api.py @@ -31,93 +31,103 @@ from starlette.requests import Request from starlette.responses import JSONResponse -from chat_server.server_config import db_controller from chat_server.server_utils.auth import login_required -from chat_server.server_utils.db_utils import DbUtils from chat_server.server_utils.http_utils import get_file_response, save_file +from 
utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG router = APIRouter( prefix="/files", - responses={'404': {"description": "Unknown authorization endpoint"}}, + responses={"404": {"description": "Unknown authorization endpoint"}}, ) @router.get("/audio/{message_id}") -async def get_audio_message(request: Request, message_id: str,): - """ Gets file based on the name """ - matching_shouts = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if matching_shouts and matching_shouts[0].get('is_audio', '0') == '1': - LOG.info(f'Fetching audio for message_id={message_id}') - return get_file_response(matching_shouts[0]["message_text"], - location_prefix='audio', - media_type='audio/wav') +async def get_audio_message( + request: Request, + message_id: str, +): + """Gets file based on the name""" + matching_shout = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id) + if matching_shout and matching_shout.get("is_audio", "0") == "1": + LOG.info(f"Fetching audio for message_id={message_id}") + return get_file_response( + matching_shout["message_text"], + location_prefix="audio", + media_type="audio/wav", + ) else: - return respond('Matching shout not found', 404) + return respond("Matching shout not found", 404) @router.get("/avatar/{user_id}") async def get_avatar(user_id: str): """ - Gets file from the server + Gets file from the server - :param user_id: target user id + :param user_id: target user id """ - LOG.debug(f'Getting avatar of user id: {user_id}') - user_data = db_controller.exec_query(query={'document': 'users', - 'command': 'find_one', - 'data': {'_id': user_id}}) or {} - if user_data.get('avatar', None): + LOG.debug(f"Getting avatar of user id: {user_id}") + user_data = MongoDocumentsAPI.USERS.get_user(user_id=user_id) or {} + if user_data.get("avatar", None): num_attempts = 0 try: - return get_file_response(filename=user_data['avatar'], 
location_prefix='avatars') + return get_file_response( + filename=user_data["avatar"], location_prefix="avatars" + ) except Exception as ex: - LOG.error(f'(attempt={num_attempts}) get_file_response(filename={user_data["avatar"]}, ' - f'location_prefix="avatars") failed with ex - {ex}') - return respond(f'Failed to get avatar of {user_id}', 404) + LOG.error( + f'(attempt={num_attempts}) get_file_response(filename={user_data["avatar"]}, ' + f'location_prefix="avatars") failed with ex - {ex}' + ) + return respond(f"Failed to get avatar of {user_id}", 404) @router.get("/{msg_id}/get_attachment/{filename}") # @login_required async def get_message_attachment(request: Request, msg_id: str, filename: str): """ - Gets file from the server + Gets file from the server - :param request: Starlette Request Object - :param msg_id: parent message id - :param filename: name of the file to get + :param request: Starlette Request Object + :param msg_id: parent message id + :param filename: name of the file to get """ - LOG.debug(f'{msg_id} - {filename}') - message_files = db_controller.exec_query(query={'document': 'shouts', - 'command': 'find_one', - 'data': {'_id': msg_id}}) - if message_files: - attachment_data = [attachment for attachment in message_files['attachments'] if attachment['name'] == filename][0] - media_type = attachment_data['mime'] - file_response = get_file_response(filename=filename, media_type=media_type, location_prefix='attachments') + LOG.debug(f"{msg_id} - {filename}") + shout_data = MongoDocumentsAPI.SHOUTS.get_item(item_id=msg_id) + if shout_data: + attachment_data = [ + attachment + for attachment in shout_data["attachments"] + if attachment["name"] == filename + ][0] + media_type = attachment_data["mime"] + file_response = get_file_response( + filename=filename, media_type=media_type, location_prefix="attachments" + ) if file_response is None: - return JSONResponse({'msg': 'Missing attachments in destination'}, 400) + return JSONResponse({"msg": "Missing 
attachments in destination"}, 400) return file_response else: - return JSONResponse({'msg': f'invalid message id: {msg_id}'}, 400) + return JSONResponse({"msg": f"invalid message id: {msg_id}"}, 400) @router.post("/attachments") @login_required async def save_attachments(request: Request, files: List[UploadFile] = File(...)): """ - Stores received files in filesystem + Stores received files in filesystem - :param request: Starlette Request Object - :param files: list of files to process + :param request: Starlette Request Object + :param files: list of files to process - :returns JSON-formatted response from server + :returns JSON-formatted response from server """ response = {} for file in files: name = file.filename - stored_location = await save_file(location_prefix='attachments', file=file) - LOG.info(f'Stored location for {file.filename} - {stored_location}') + stored_location = await save_file(location_prefix="attachments", file=file) + LOG.info(f"Stored location for {file.filename} - {stored_location}") response[name] = stored_location - return JSONResponse(content={'location_mapping': response}) + return JSONResponse(content={"location_mapping": response}) diff --git a/chat_server/blueprints/languages.py b/chat_server/blueprints/languages.py index 7bc43264..e560ebc0 100644 --- a/chat_server/blueprints/languages.py +++ b/chat_server/blueprints/languages.py @@ -33,15 +33,15 @@ router = APIRouter( prefix="/language_api", - responses={'404': {"description": "Unknown endpoint"}}, + responses={"404": {"description": "Unknown endpoint"}}, ) @router.get("/settings") async def list_language_settings(): """ - Returns language settings + Returns language settings - :returns JSON-formatted response from server + :returns JSON-formatted response from server """ - return JSONResponse(content={'supported_languages': LanguageSettings.list()}) + return JSONResponse(content={"supported_languages": LanguageSettings.list()}) diff --git a/chat_server/blueprints/personas.py 
b/chat_server/blueprints/personas.py new file mode 100644 index 00000000..ccc35941 --- /dev/null +++ b/chat_server/blueprints/personas.py @@ -0,0 +1,161 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from typing import Annotated + +from fastapi import APIRouter, Depends +from starlette.responses import JSONResponse + +from chat_server.server_utils.auth import is_authorized_for_user_id +from chat_server.server_utils.dependencies import CurrentUserDependency +from chat_server.server_utils.exceptions import ( + UserUnauthorizedException, + ItemNotFoundException, + DuplicatedItemException, +) +from chat_server.server_utils.http_utils import KlatAPIResponse +from chat_server.server_utils.models.personas import ( + AddPersonaModel, + DeletePersonaModel, + SetPersonaModel, + TogglePersonaStatusModel, + ListPersonasQueryModel, +) +from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI + +router = APIRouter( + prefix="/personas", + responses={"404": {"description": "Unknown endpoint"}}, +) + + +@router.get("/list") +async def list_personas( + current_user: CurrentUserDependency, model: ListPersonasQueryModel = Depends() +): + """Lists personas matching query params""" + filters = [] + if model.llms: + filters.append( + MongoFilter( + key="supported_llms", + value=model.llms, + logical_operator=MongoLogicalOperators.ALL, + ) + ) + if model.user_id: + if ( + model.user_id == "*" and "admin" not in current_user.roles + ) or not is_authorized_for_user_id(current_user, user_id=model.user_id): + raise UserUnauthorizedException + elif model.user_id != "*": + filters.append(MongoFilter(key="user_id", value=model.user_id)) + else: + user_filter = [{"user_id": None}, {"user_id": current_user.user_id}] + filters.append( + MongoFilter(value=user_filter, logical_operator=MongoLogicalOperators.OR) + ) + if model.only_enabled: + filters.append(MongoFilter(key="enabled", value=True)) + items = MongoDocumentsAPI.PERSONAS.list_items( + filters=filters, result_as_cursor=False + ) + for item in items: + item["id"] = item.pop("_id") + item["enabled"] = item.get("enabled", False) + 
return JSONResponse(content={"items": items}) + + +@router.get("/get/{persona_id}") +async def get_persona( + current_user: CurrentUserDependency, + persona_id: str, +): + """Gets persona details for a given persona_id""" + personas_tokens = persona_id.split("_") + if len(personas_tokens) >= 2: + persona_user_id = personas_tokens[1] + if not is_authorized_for_user_id(current_user, user_id=persona_user_id): + raise ItemNotFoundException + item = MongoDocumentsAPI.PERSONAS.get_item(item_id=persona_id) + if not item: + raise ItemNotFoundException + return JSONResponse(content=item) + + +@router.put("/add") +async def add_persona(current_user: CurrentUserDependency, model: AddPersonaModel): + """Adds new persona""" + if not is_authorized_for_user_id(current_user=current_user, user_id=model.user_id): + raise UserUnauthorizedException + existing_model = MongoDocumentsAPI.PERSONAS.get_item(item_id=model.persona_id) + if existing_model: + raise DuplicatedItemException + MongoDocumentsAPI.PERSONAS.add_item(data=model.model_dump()) + return KlatAPIResponse.OK + + +@router.post("/set") +async def set_persona(current_user: CurrentUserDependency, model: SetPersonaModel): + """Sets persona's data""" + if not is_authorized_for_user_id(current_user=current_user, user_id=model.user_id): + raise UserUnauthorizedException + existing_model = MongoDocumentsAPI.PERSONAS.get_item(item_id=model.persona_id) + if not existing_model: + raise ItemNotFoundException + mongo_filter = MongoFilter(key="_id", value=model.persona_id) + MongoDocumentsAPI.PERSONAS.update_item( + filters=mongo_filter, data=model.model_dump() + ) + return KlatAPIResponse.OK + + +@router.delete("/delete") +async def delete_persona( + current_user: CurrentUserDependency, model: DeletePersonaModel = Depends() +): + """Deletes persona""" + if not is_authorized_for_user_id(current_user=current_user, user_id=model.user_id): + raise UserUnauthorizedException + MongoDocumentsAPI.PERSONAS.delete_item(item_id=model.persona_id) + 
return KlatAPIResponse.OK + + +@router.post("/toggle") +async def toggle_persona_state( + current_user: CurrentUserDependency, model: TogglePersonaStatusModel +): + if not is_authorized_for_user_id(current_user=current_user, user_id=model.user_id): + raise UserUnauthorizedException + updated_data = MongoDocumentsAPI.PERSONAS.update_item( + filters=MongoFilter(key="_id", value=model.persona_id), + data={"enabled": model.enabled}, + ) + if updated_data.matched_count == 0: + raise ItemNotFoundException + return KlatAPIResponse.OK diff --git a/chat_server/blueprints/preferences.py b/chat_server/blueprints/preferences.py index d78575d3..cc4105de 100644 --- a/chat_server/blueprints/preferences.py +++ b/chat_server/blueprints/preferences.py @@ -27,27 +27,30 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from fastapi import APIRouter, Request, Form -from chat_server.server_config import db_controller from chat_server.server_utils.auth import get_current_user, login_required -from chat_server.server_utils.db_utils import DbUtils +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG router = APIRouter( prefix="/preferences", - responses={'404': {"description": "Unknown user"}}, + responses={"404": {"description": "Unknown user"}}, ) @router.post("/update_language/{cid}/{input_type}") @login_required -async def update_language(request: Request, cid: str, input_type: str, lang: str = Form(...)): - """ Updates preferred language of user in conversation """ +async def update_language( + request: Request, cid: str, input_type: str, lang: str = Form(...) 
+): + """Updates preferred language of user in conversation""" try: - current_user_id = get_current_user(request)['_id'] + current_user_id = get_current_user(request)["_id"] except Exception as ex: LOG.error(ex) - return respond(f'Failed to update language of {cid}/{input_type} to {lang}') - DbUtils.set_user_preferences(user_id=current_user_id, - preferences_mapping={f'chat_language_mapping.{cid}.{input_type}': lang}) - return respond(f'Updated cid={cid}, input_type={input_type} to language={lang}') + return respond(f"Failed to update language of {cid}/{input_type} to {lang}") + MongoDocumentsAPI.USERS.set_preferences( + user_id=current_user_id, + preferences_mapping={f"chat_language_mapping.{cid}.{input_type}": lang}, + ) + return respond(f"Updated cid={cid}, input_type={input_type} to language={lang}") diff --git a/chat_server/blueprints/users.py b/chat_server/blueprints/users.py index 52505c62..4ebbe8c9 100644 --- a/chat_server/blueprints/users.py +++ b/chat_server/blueprints/users.py @@ -32,145 +32,164 @@ from fastapi.encoders import jsonable_encoder from chat_server.server_config import db_controller -from chat_server.server_utils.auth import get_current_user, check_password_strength, get_current_user_data, \ - login_required -from chat_server.server_utils.db_utils import DbUtils +from chat_server.server_utils.auth import ( + get_current_user, + check_password_strength, + get_current_user_data, + login_required, +) from chat_server.server_utils.http_utils import save_file from utils.common import get_hash +from utils.database_utils.mongo_utils import MongoFilter +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.http_utils import respond from utils.logging_utils import LOG router = APIRouter( prefix="/users_api", - responses={'404': {"description": "Unknown user"}}, + responses={"404": {"description": "Unknown user"}}, ) @router.get("/") -async def get_user(request: Request, - response: Response, - nano_token: str = 
None, - user_id: Optional[str] = None): +async def get_user( + request: Request, + nano_token: str = None, + user_id: Optional[str] = None, +): """ - Gets current user data from session cookies + Gets current user data from session cookies - :param request: active client session request - :param response: response object to be returned to user - :param nano_token: token from nano client (optional) - :param user_id: id of external user (optional, if not provided - current user is returned) + :param request: active client session request + :param nano_token: token from nano client (optional) + :param user_id: id of external user (optional, if not provided - current user is returned) - :returns JSON response containing data of current user + :returns JSON response containing data of current user """ - session_token = '' + session_token = "" if user_id: - user = db_controller.exec_query(query={'document': 'users', - 'command': 'find_one', - 'data': {'_id': user_id}}) - user.pop('password', None) - user.pop('date_created', None) - user.pop('tokens', None) - LOG.info(f'Fetched user data (id={user_id}): {user}') + user = MongoDocumentsAPI.USERS.get_user(user_id=user_id) + user.pop("password", None) + user.pop("date_created", None) + user.pop("tokens", None) + LOG.info(f"Fetched user data (id={user_id}): {user}") else: - current_user_data = get_current_user_data(request=request, nano_token=nano_token) + current_user_data = get_current_user_data( + request=request, nano_token=nano_token + ) user = current_user_data.user session_token = current_user_data.session if not user: - return respond('User not found', 404) + return respond("User not found", 404) return dict(data=user, token=session_token) -@router.get('/get_users') +@router.get("/get_users") # @login_required -async def fetch_received_user_ids(request: Request, user_ids: str = None, nicknames: str = None): +async def fetch_received_user_ids( + request: Request, user_ids: str = None, nicknames: str = None +): """ - 
Gets users data based on provided user ids + Gets users data based on provided user ids - :param request: Starlette Request Object - :param user_ids: list of provided user ids - :param nicknames: list of provided nicknames + :param request: Starlette Request Object + :param user_ids: list of provided user ids + :param nicknames: list of provided nicknames - :returns JSON response containing array of fetched user data + :returns JSON response containing array of fetched user data """ filter_data = {} if not any(x for x in (user_ids, nicknames)): - return respond('Either user_ids or nicknames should be provided', 422) + return respond("Either user_ids or nicknames should be provided", 422) if user_ids: - filter_data['_id'] = {'$in': user_ids.split(',')} + filter_data["_id"] = {"$in": user_ids.split(",")} if nicknames: - filter_data['nickname'] = {'$in': nicknames.split(',')} + filter_data["nickname"] = {"$in": nicknames.split(",")} - users = db_controller.exec_query(query={'document': 'users', - 'command': 'find', - 'data': filter_data}, - as_cursor=False) + users = MongoDocumentsAPI.USERS.list_items( + filters=filter_data, result_as_cursor=False + ) for user in users: - user.pop('password', None) - user.pop('is_tmp', None) - user.pop('tokens', None) - user.pop('date_created', None) + user.pop("password", None) + user.pop("is_tmp", None) + user.pop("tokens", None) + user.pop("date_created", None) - return JSONResponse(content={'users': jsonable_encoder(users)}) + return JSONResponse(content={"users": jsonable_encoder(users)}) @router.post("/update") @login_required -async def update_profile(request: Request, - user_id: str = Form(...), - first_name: str = Form(""), - last_name: str = Form(""), - bio: str = Form(""), - nickname: str = Form(""), - password: str = Form(""), - repeat_password: str = Form(""), - avatar: UploadFile = File(None), ): +async def update_profile( + request: Request, + user_id: str = Form(...), + first_name: str = Form(""), + last_name: str = 
Form(""), + bio: str = Form(""), + nickname: str = Form(""), + password: str = Form(""), + repeat_password: str = Form(""), + avatar: UploadFile = File(None), +): """ - Gets file from the server - - :param request: FastAPI Request Object - :param user_id: submitted user id - :param first_name: new first name value - :param last_name: new last name value - :param nickname: new nickname value - :param bio: updated user's bio - :param password: new password - :param repeat_password: repeat new password - :param avatar: new avatar image - - :returns status: 200 if data updated successfully, 403 if operation is on tmp user, 401 if something went wrong + Gets file from the server + + :param request: FastAPI Request Object + :param user_id: submitted user id + :param first_name: new first name value + :param last_name: new last name value + :param nickname: new nickname value + :param bio: updated user's bio + :param password: new password + :param repeat_password: repeat new password + :param avatar: new avatar image + + :returns status: 200 if data updated successfully, 403 if operation is on tmp user, 401 if something went wrong """ user = get_current_user(request=request) - if user.get('is_tmp'): - return respond(msg=f"Unable to update data of 'tmp' user", status_code=status.HTTP_403_FORBIDDEN) - update_dict = {'first_name': first_name, - 'last_name': last_name, - 'bio': bio, - 'nickname': nickname} + if user.get("is_tmp"): + return respond( + msg=f"Unable to update data of 'tmp' user", + status_code=status.HTTP_403_FORBIDDEN, + ) + update_dict = { + "first_name": first_name, + "last_name": last_name, + "bio": bio, + "nickname": nickname, + } if password: if password != repeat_password: - return respond(msg='Passwords do not match', status_code=status.HTTP_401_UNAUTHORIZED) + return respond( + msg="Passwords do not match", status_code=status.HTTP_401_UNAUTHORIZED + ) password_check = check_password_strength(password) - if password_check != 'OK': + if password_check != 
"OK": return respond(msg=password_check, status_code=status.HTTP_401_UNAUTHORIZED) - update_dict['password'] = get_hash(password) + update_dict["password"] = get_hash(password) if avatar: - update_dict['avatar'] = await save_file(location_prefix='avatars', file=avatar) + update_dict["avatar"] = await save_file(location_prefix="avatars", file=avatar) try: - filter_expression = {'_id': user['_id']} - update_expression = {'$set': {k: v for k, v in update_dict.items() if v}} - db_controller.exec_query(query={'document': 'users', - 'command': 'update', - 'data': (filter_expression, - update_expression,)}) + filter_expression = MongoFilter(key="_id", value=user_id) + update_dict = {k: v for k, v in update_dict.items() if v} + MongoDocumentsAPI.USERS.update_item( + filters=(filter_expression,), data=update_dict + ) return respond(msg="OK") except Exception as ex: LOG.error(ex) - return respond(msg='Unable to update user data at the moment', status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) + return respond( + msg="Unable to update user data at the moment", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + ) @router.post("/settings/update") @login_required -async def update_settings(request: Request, - minify_messages: str = Form("0"),): +async def update_settings( + request: Request, + minify_messages: str = Form("0"), +): """ Updates user settings with provided form data :param request: FastAPI Request Object @@ -178,8 +197,8 @@ async def update_settings(request: Request, :return: status 200 if OK, error code otherwise """ user = get_current_user(request=request) - preferences_mapping = { - 'minify_messages': minify_messages - } - DbUtils.set_user_preferences(user_id=user['_id'], preferences_mapping=preferences_mapping) - return respond(msg='OK') + preferences_mapping = {"minify_messages": minify_messages} + MongoDocumentsAPI.USERS.set_preferences( + user_id=user["_id"], preferences_mapping=preferences_mapping + ) + return respond(msg="OK") diff --git 
a/chat_server/constants/conversations.py b/chat_server/constants/conversations.py index 9847cf86..19de3332 100644 --- a/chat_server/constants/conversations.py +++ b/chat_server/constants/conversations.py @@ -26,7 +26,9 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + class ConversationSkins: - """ List of supported conversation skins """ - BASE = 'base' - PROMPTS = 'prompts' + """List of supported conversation skins""" + + BASE = "base" + PROMPTS = "prompts" diff --git a/chat_server/constants/users.py b/chat_server/constants/users.py index b3f70bd5..e4ee78f6 100644 --- a/chat_server/constants/users.py +++ b/chat_server/constants/users.py @@ -32,33 +32,28 @@ class UserPatterns(Enum): """Collection of user patterns used for commonly in conversations""" + UNRECOGNIZED_USER = { - 'first_name': 'Deleted', - 'last_name': 'User', - 'nickname': 'deleted_user' - } - GUEST = { - 'first_name': 'Klat', - 'last_name': 'Guest' + "first_name": "Deleted", + "last_name": "User", + "nickname": "deleted_user", } + GUEST = {"first_name": "Klat", "last_name": "Guest"} NEON = { - 'first_name': 'Neon', - 'last_name': 'AI', - 'nickname': 'neon', - 'avatar': 'neon.webp' - } - GUEST_NANO = { - 'first_name': 'Nano', - 'last_name': 'Guest', - 'tokens': [] + "first_name": "Neon", + "last_name": "AI", + "nickname": "neon", + "avatar": "neon.webp", } + GUEST_NANO = {"first_name": "Nano", "last_name": "Guest", "tokens": []} class ChatPatterns(Enum): """Collection of chat patterns used for create conversations""" + TEST_CHAT = { - "_id": '-1', + "_id": "-1", "conversation_name": "test", "is_private": False, - "created_on": int(time.time()) + "created_on": int(time.time()), } diff --git a/chat_server/server_config.py b/chat_server/server_config.py index 42e74618..57dd1750 100644 --- a/chat_server/server_config.py +++ b/chat_server/server_config.py @@ -27,22 +27,76 @@ # SOFTWARE, EVEN IF ADVISED OF THE 
POSSIBILITY OF SUCH DAMAGE. import os +from typing import Optional + from config import Configuration from chat_server.server_utils.sftp_utils import init_sftp_connector +from chat_server.server_utils.rmq_utils import RabbitMQAPI + from utils.logging_utils import LOG +from utils.database_utils import DatabaseController +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI + +server_config_path = os.path.expanduser( + os.environ.get("CHATSERVER_CONFIG", "~/.local/share/neon/credentials.json") +) +database_config_path = os.path.expanduser( + os.environ.get("DATABASE_CONFIG", "~/.local/share/neon/credentials.json") +) + + +def _init_db_controller(db_config: dict) -> Optional[DatabaseController]: + + # Determine configured database dialect + dialect = db_config.pop("dialect", "mongo") + + try: + # Create a database connection + db_controller = DatabaseController(config_data=db_config) + db_controller.attach_connector(dialect=dialect) + db_controller.connect() + return db_controller + except Exception as e: + LOG.exception(f"DatabaseController init failed: {e}") + return None + -server_config_path = os.environ.get('CHATSERVER_CONFIG', '~/.local/share/neon/credentials.json') -database_config_path = os.environ.get('DATABASE_CONFIG', '~/.local/share/neon/credentials.json') +if os.path.isfile(server_config_path) or os.path.isfile(database_config_path): + LOG.warning(f"Using legacy configuration at {server_config_path}") + LOG.warning(f"Using legacy configuration at {database_config_path}") + LOG.info(f"KLAT_ENV : {Configuration.KLAT_ENV}") + config = Configuration(from_files=[server_config_path, database_config_path]) + app_config = config.get("CHAT_SERVER", {}).get(Configuration.KLAT_ENV, {}) + db_controller = config.get_db_controller(name="pyklatchat_3333") +else: + # ovos-config has built-in mechanisms for loading configuration files based + # on envvars, so the configuration structure is simplified + from ovos_config.config import Configuration 
-LOG.info(f'KLAT_ENV : {Configuration.KLAT_ENV}') + config = Configuration() + app_config = config.get("CHAT_SERVER") or dict() + env_spec = os.environ.get("KLAT_ENV") + if env_spec and app_config.get(env_spec): + LOG.warning("Legacy configuration handling KLAT_ENV envvar") + app_config = app_config.get(env_spec) + db_controller = _init_db_controller( + app_config.get("connection_properties", config.get("DATABASE_CONFIG", {})) + ) -config = Configuration(from_files=[server_config_path, database_config_path]) +LOG.info(f"App config: {app_config}") -app_config = config.get('CHAT_SERVER', {}).get(Configuration.KLAT_ENV, {}) +sftp_connector = init_sftp_connector(config=app_config.get("SFTP", {})) -LOG.info(f'App config: {app_config}') +MongoDocumentsAPI.init(db_controller=db_controller, sftp_connector=sftp_connector) -db_controller = config.get_db_controller(name='pyklatchat_3333') +mq_api = None +mq_management_config = config.get("MQ_MANAGEMENT", {}) +if mq_management_url := mq_management_config.get("MQ_MANAGEMENT_URL"): + mq_api = RabbitMQAPI(url=mq_management_url) + mq_api.login( + username=mq_management_config["MQ_MANAGEMENT_LOGIN"], + password=mq_management_config["MQ_MANAGEMENT_PASSWORD"], + ) -sftp_connector = init_sftp_connector(config=app_config.get('SFTP', {})) +k8s_config = config.get("K8S_CONFIG", {}) diff --git a/chat_server/server_utils/admin_utils.py b/chat_server/server_utils/admin_utils.py new file mode 100644 index 00000000..937d00b9 --- /dev/null +++ b/chat_server/server_utils/admin_utils.py @@ -0,0 +1,52 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. 
Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +from chat_server.server_config import mq_api, mq_management_config, LOG + + +def run_mq_validation(): + if mq_api: + for vhost in mq_management_config.get("VHOSTS", []): + status = mq_api.add_vhost(vhost=vhost["name"]) + if not status.ok: + raise ConnectionError(f'Failed to add {vhost["name"]}, {status=}') + for user_creds in mq_management_config.get("USERS", []): + mq_api.add_user( + user=user_creds["name"], + password=user_creds["password"], + tags=user_creds.get("tags", ""), + ) + for user_vhost_permissions in mq_management_config.get( + "USER_VHOST_PERMISSIONS", [] + ): + mq_api.configure_vhost_user_permissions(**user_vhost_permissions) + else: + LOG.error("MQ API is unavailable") + + +if __name__ == "__main__": + run_mq_validation() diff --git a/chat_server/server_utils/auth.py b/chat_server/server_utils/auth.py index 1196d273..de56d06b 100644 --- a/chat_server/server_utils/auth.py +++ b/chat_server/server_utils/auth.py @@ -33,68 +33,69 @@ from time import time from fastapi import Request + +from chat_server.server_utils.models.users import CurrentUserModel +from utils.database_utils.mongo_utils import MongoFilter, MongoLogicalOperators +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG -from chat_server.constants.users import UserPatterns -from chat_server.server_config import db_controller, app_config -from chat_server.server_utils.db_utils import DbUtils -from utils.common import generate_uuid +from chat_server.server_config import config, app_config from utils.http_utils import respond -cookies_config = app_config.get('COOKIES', {}) - -secret_key = cookies_config.get('SECRET', None) - -session_lifetime = int(cookies_config.get('LIFETIME', 60 * 60)) -session_refresh_rate = int(cookies_config.get('REFRESH_RATE', 5 * 60)) +cookies_config = app_config.get("COOKIES", {}) -jwt_encryption_algo = cookies_config.get('JWT_ALGO', 'HS256') - -AUTHORIZATION_HEADER = 'Authorization' +secret_key = 
cookies_config.get("SECRET", None) +session_lifetime = int(cookies_config.get("LIFETIME", 60 * 60)) +session_refresh_rate = int(cookies_config.get("REFRESH_RATE", 5 * 60)) +jwt_encryption_algo = cookies_config.get("JWT_ALGO", "HS256") +AUTHORIZATION_HEADER = "Authorization" @dataclass class UserData: - """ Dataclass wrapping user data """ + """Dataclass wrapping user data""" + user: dict session: str def check_password_strength(password: str) -> str: """ - Checks if input string is a strong password + Checks if input string is a strong password - :param password: input string + :param password: input string - :returns: 'OK' if input string is strong enough, unfulfilled condition otherwise + :returns: 'OK' if input string is strong enough, unfulfilled condition otherwise """ if len(password) < 8: - return 'Password should be longer than 8 symbols' + return "Password should be longer than 8 symbols" else: - return 'OK' + return "OK" def get_cookie_from_request(request: Request, cookie_name: str) -> Optional[str]: """ - Gets cookie from response by its name + Gets cookie from response by its name - :param request: Starlet request object - :param cookie_name: name of the desired cookie + :param request: Starlet request object + :param cookie_name: name of the desired cookie - :returns value of cookie if present + :returns value of cookie if present """ return request.cookies.get(cookie_name) -def get_header_from_request(request: Union[Request, str], header_name: str, sio_request: bool = False) -> Optional[str]: +def get_header_from_request( + request: Union[Request, str], header_name: str, sio_request: bool = False +) -> Optional[str]: """ - Gets header value from response by its name + Gets header value from response by its name - :param request: Starlet request object - :param header_name: name of the desired cookie - :param sio_request: is request from Socket IO service endpoint (defaults to False) + :param request: Starlet request object + :param header_name: name 
of the desired cookie + :param sio_request: is request from Socket IO service endpoint (defaults to False) - :returns value of cookie if present + :returns value of cookie if present """ if sio_request: return request @@ -104,134 +105,174 @@ def get_header_from_request(request: Union[Request, str], header_name: str, sio_ def generate_session_token(user_id) -> str: """ - Generates JWT token based on the user id - :returns generate JWT token string + Generates JWT token based on the user id + :returns generate JWT token string """ - return jwt.encode(payload={"sub": user_id, - 'creation_time': int(time()), - 'last_refresh_time': int(time())}, - key=secret_key, - algorithm=jwt_encryption_algo) - - -def create_unauthorized_user(authorize: bool = True, nano_token: str = None) -> UserData: + return jwt.encode( + payload={ + "sub": user_id, + "creation_time": int(time()), + "last_refresh_time": int(time()), + }, + key=secret_key, + algorithm=jwt_encryption_algo, + ) + + +def create_unauthorized_user( + authorize: bool = True, nano_token: str = None +) -> UserData: """ - Creates unauthorized user and sets its credentials to cookies + Creates unauthorized user and sets its credentials to cookies - :param authorize: to authorize new user - :param nano_token: nano token to append to user on creation + :param authorize: to authorize new user + :param nano_token: nano token to append to user on creation - :returns: generated UserData + :returns: generated UserData """ - from chat_server.server_utils.user_utils import create_from_pattern - - guest_nickname = f'guest_{generate_uuid(length=8)}' - - if nano_token: - new_user = create_from_pattern(source=UserPatterns.GUEST_NANO, - override_defaults=dict(nickname=guest_nickname, - tokens=[nano_token])) - else: - new_user = create_from_pattern(source=UserPatterns.GUEST, - override_defaults=dict(nickname=guest_nickname)) - db_controller.exec_query(query={'document': 'users', - 'command': 'insert_one', - 'data': new_user}) - token = 
generate_session_token(user_id=new_user['_id']) if authorize else '' + new_user = MongoDocumentsAPI.USERS.create_guest(nano_token=nano_token) + token = "" + if authorize: + token = generate_session_token(user_id=new_user["_id"]) + LOG.debug(f"Created new user with name {new_user['nickname']}") return UserData(user=new_user, session=token) -def get_current_user_data(request: Request, - force_tmp: bool = False, - nano_token: str = None, - sio_request: bool = False) -> UserData: +def get_current_user_data( + request: Request, + force_tmp: bool = False, + nano_token: str = None, + sio_request: bool = False, +) -> UserData: """ - Gets current user according to response cookies + Gets current user according to response cookies - :param request: Starlet request object - :param force_tmp: to force setting temporal credentials - :param nano_token: token from nano client (optional) - :param sio_request: if request is from Socket IO server + :param request: Starlet request object + :param force_tmp: to force setting temporal credentials + :param nano_token: token from nano client (optional) + :param sio_request: if request is from Socket IO server - :returns UserData based on received authorization header or sets temporal user credentials if not found + :returns UserData based on received authorization header or sets temporal user credentials if not found """ - user_id = None - user_data = {} + user_data: UserData = None if not force_tmp: if nano_token: - user = db_controller.exec_query(query={'command': 'find_one', - 'document': 'users', - 'data': {'tokens': {'$all': [nano_token]}}}) + user = MongoDocumentsAPI.USERS.get_item( + filters=MongoFilter( + key="tokens", + value=[nano_token], + logical_operator=MongoLogicalOperators.ALL, + ) + ) if not user: - LOG.info('Creating new user for nano agent') - user_data = create_unauthorized_user(nano_token=nano_token, authorize=False) + LOG.info("Creating new user for nano agent") + user_data = create_unauthorized_user( + 
nano_token=nano_token, authorize=False + ) else: try: - session = get_header_from_request(request, AUTHORIZATION_HEADER, sio_request) + session = get_header_from_request( + request, AUTHORIZATION_HEADER, sio_request + ) if session: - payload = jwt.decode(jwt=session, key=secret_key, algorithms=jwt_encryption_algo) + payload = jwt.decode( + jwt=session, key=secret_key, algorithms=jwt_encryption_algo + ) current_timestamp = time() - if (int(current_timestamp) - int(payload.get('creation_time', 0))) <= session_lifetime: - user_id = payload['sub'] - user = DbUtils.get_user(user_id=user_id) - LOG.info(f'Fetched user data: {user}') - user['preferences'] = DbUtils.get_user_preferences(user_id=user_id) - LOG.info(f'Fetched user preferences data: {user["preferences"]}') + if ( + int(current_timestamp) - int(payload.get("creation_time", 0)) + ) <= session_lifetime: + user_id = payload["sub"] + user = MongoDocumentsAPI.USERS.get_user(user_id=user_id) + LOG.info(f"Fetched user data for nickname = {user['nickname']}") if not user: - LOG.info(f'{payload["sub"]} is not found among users, setting temporal user credentials') + LOG.info( + f'{payload["sub"]} is not found among users, setting temporal user credentials' + ) else: - if (int(current_timestamp) - int(payload.get('last_refresh_time', 0))) >= session_refresh_rate: + if ( + int(current_timestamp) + - int(payload.get("last_refresh_time", 0)) + ) >= session_refresh_rate: session = refresh_session(payload=payload) - LOG.info('Session was refreshed') - user_data = UserData(user=user, session=session) + LOG.info("Session was refreshed") + user_data = UserData(user=user, session=session) except BaseException as ex: - LOG.warning(f'Problem resolving current user: {ex}, setting tmp user credentials') - if not user_id or force_tmp: + LOG.exception( + f"Problem resolving current user: {ex}\n" + f"setting tmp user credentials" + ) + if not user_data: + LOG.debug("Creating temp user") user_data = create_unauthorized_user() - 
user_data.user.pop('password', None) - user_data.user.pop('date_created', None) - user_data.user.pop('tokens', None) + LOG.debug(f"Resolved user: {user_data}") + user_data.user.pop("password", None) + user_data.user.pop("date_created", None) + user_data.user.pop("tokens", None) return user_data -def get_current_user(request: Request, force_tmp: bool = False, nano_token: str = None) -> dict: - """ Backward compatibility method to support previous invocations """ - return get_current_user_data(request=request, force_tmp=force_tmp, nano_token=nano_token).user +def get_current_user( + request: Request, force_tmp: bool = False, nano_token: str = None +) -> dict: + """Backward compatibility method to support previous invocations""" + return get_current_user_data( + request=request, force_tmp=force_tmp, nano_token=nano_token + ).user def refresh_session(payload: dict): """ - Refreshes session token + Refreshes session token - :param payload: dictionary with decoded token params + :param payload: dictionary with decoded token params """ - session = jwt.encode({"sub": payload['sub'], - 'creation_time': payload['creation_time'], - 'last_refresh_time': time()}, secret_key) + session = jwt.encode( + { + "sub": payload["sub"], + "creation_time": payload["creation_time"], + "last_refresh_time": time(), + }, + secret_key, + ) return session -def validate_session(request: Union[str, Request], check_tmp: bool = False, sio_request: bool = False) -> Tuple[str, int]: +def validate_session( + request: Union[str, Request], + check_tmp: bool = False, + required_roles: list = None, + sio_request: bool = False, +) -> Tuple[str, int]: """ - Check if session token contained in request is valid - :returns validation output + Check if session token contained in request is valid + :returns validation output """ session = get_header_from_request(request, AUTHORIZATION_HEADER, sio_request) if session: - payload = jwt.decode(jwt=session, key=secret_key, algorithms=jwt_encryption_algo) - if 
check_tmp: - from chat_server.server_utils.db_utils import DbUtils - user = DbUtils.get_user(user_id=payload['sub']) - if user.get('is_tmp'): - return 'Permission denied', 403 - if (int(time()) - int(payload.get('creation_time', 0))) <= session_lifetime: - return 'OK', 200 - return 'Session Expired', 401 + payload = jwt.decode( + jwt=session, key=secret_key, algorithms=jwt_encryption_algo + ) + should_check_user_data = check_tmp or required_roles + is_authorized = True + if should_check_user_data: + user = MongoDocumentsAPI.USERS.get_user(user_id=payload["sub"]) + if check_tmp and user.get("is_tmp"): + is_authorized = False + elif required_roles and not any( + user_role in required_roles for user_role in user.get("roles", []) + ): + is_authorized = False + if not is_authorized: + return "Permission denied", 403 + if (int(time()) - int(payload.get("creation_time", 0))) <= session_lifetime: + return "OK", 200 + return "Session Expired", 401 def login_required(*outer_args, **outer_kwargs): """ - Decorator that validates current authorization token + Decorator that validates current authorization token """ no_args = False @@ -241,14 +282,19 @@ def login_required(*outer_args, **outer_kwargs): no_args = True func = outer_args[0] - outer_kwargs.setdefault('tmp_allowed', True) + outer_kwargs.setdefault("tmp_allowed", True) def outer(func): - @wraps(func) async def wrapper(request: Request, *args, **kwargs): - session_validation_output = validate_session(request, check_tmp=not outer_kwargs.get('tmp_allowed')) - LOG.debug(f'(url={request.url}) Received session validation output: {session_validation_output}') + session_validation_output = validate_session( + request, + check_tmp=not outer_kwargs.get("tmp_allowed"), + required_roles=outer_kwargs.get("required_roles"), + ) + LOG.debug( + f"(url={request.url}) Received session validation output: {session_validation_output}" + ) if session_validation_output[1] != 200: return respond(*session_validation_output) return await 
func(request, *args, **kwargs) @@ -260,3 +306,23 @@ async def wrapper(request: Request, *args, **kwargs): else: return outer + +def is_authorized_for_user_id(current_user: CurrentUserModel, user_id: str) -> bool: + """ + Checks if provided to current user model and is authorized to perform actions on behalf of the target user data + :param current_user: current user model created from request + :param user_id: target user id to check authority on + :return: True if authorized, False otherwise + """ + return current_user.user_id == user_id or "admin" in current_user.roles + + +def get_current_user_model(request: Request) -> CurrentUserModel: + """ + Get current user from request objects and returns it as a CurrentUserModel instance + :param request: Starlette request object to process + :return: CurrentUserModel instance + :raises ValidationError: if pydantic validation failed for provided request + """ + current_user = get_current_user(request=request) + return CurrentUserModel.model_validate(current_user, strict=True) diff --git a/chat_server/server_utils/cache_utils.py b/chat_server/server_utils/cache_utils.py index 60d25238..7acff20d 100644 --- a/chat_server/server_utils/cache_utils.py +++ b/chat_server/server_utils/cache_utils.py @@ -30,22 +30,23 @@ class CacheFactory: - """ Cache creation factory """ + """Cache creation factory""" __active_caches = {} @classmethod def get(cls, name: str, cache_type: Type = None, **kwargs): """ - Get cache instance based on name and type + Get cache instance based on name and type - :param name: name of the cache to retrieve - :param cache_type: type of the cache to create if not found - :param kwargs: keyword args to provide along with cache instance creation + :param name: name of the cache to retrieve + :param cache_type: type of the cache to create if not found + :param kwargs: keyword args to provide along with cache instance creation """ if not cls.__active_caches.get(name): if cache_type: + kwargs.setdefault("maxsize", 
124) cls.__active_caches[name] = cache_type(**kwargs) else: - raise KeyError(f'Missing cache instance under {name}') + raise KeyError(f"Missing cache instance under {name}") return cls.__active_caches[name] diff --git a/chat_server/server_utils/conversation_utils.py b/chat_server/server_utils/conversation_utils.py index 478f67b1..8a91c7d8 100644 --- a/chat_server/server_utils/conversation_utils.py +++ b/chat_server/server_utils/conversation_utils.py @@ -30,26 +30,30 @@ from utils.logging_utils import LOG -def build_message_json(raw_message: dict, skin: ConversationSkins = ConversationSkins.BASE) -> dict: - """ Builds user message json based on provided conversation skin """ - if raw_message['message_type'] == 'plain': - message = {'user_id': raw_message['user_id'], - 'created_on': int(raw_message['created_on']), - 'message_id': raw_message['message_id'], - 'message_text': raw_message['message_text'], - 'message_type': raw_message['message_type'], - 'is_audio': raw_message.get('is_audio', '0'), - 'is_announcement': raw_message.get('is_announcement', '0'), - 'replied_message': raw_message.get('replied_message', ''), - 'attachments': raw_message.get('attachments', []), - 'user_first_name': raw_message['first_name'], - 'user_last_name': raw_message['last_name'], - 'user_nickname': raw_message['nickname'], - 'user_is_bot': raw_message.get('is_bot', '0'), - 'user_avatar': raw_message.get('avatar', '')} - elif raw_message['message_type'] == 'prompt': +def build_message_json( + raw_message: dict, skin: ConversationSkins = ConversationSkins.BASE +) -> dict: + """Builds user message json based on provided conversation skin""" + if raw_message["message_type"] == "plain": + message = { + "user_id": raw_message["user_id"], + "created_on": int(raw_message["created_on"]), + "message_id": raw_message["message_id"], + "message_text": raw_message["message_text"], + "message_type": raw_message["message_type"], + "is_audio": raw_message.get("is_audio", "0"), + "is_announcement": 
raw_message.get("is_announcement", "0"), + "replied_message": raw_message.get("replied_message", ""), + "attachments": raw_message.get("attachments", []), + "user_first_name": raw_message["first_name"], + "user_last_name": raw_message["last_name"], + "user_nickname": raw_message["nickname"], + "user_is_bot": raw_message.get("is_bot", "0"), + "user_avatar": raw_message.get("avatar", ""), + } + elif raw_message["message_type"] == "prompt": return raw_message else: - LOG.error(f'Undefined skin = {skin}') + LOG.error(f"Undefined skin = {skin}") message = {} return message diff --git a/chat_server/server_utils/db_utils.py b/chat_server/server_utils/db_utils.py deleted file mode 100644 index 437d9df7..00000000 --- a/chat_server/server_utils/db_utils.py +++ /dev/null @@ -1,499 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2022 Neongecko.com Inc. -# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, -# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo -# BSD-3 License -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. 
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from typing import List, Tuple, Union, Dict - -import pymongo -from bson import ObjectId -from pymongo import UpdateOne - -from chat_server.constants.conversations import ConversationSkins -from chat_server.constants.users import UserPatterns -from chat_server.server_utils.factory_utils import Singleton -from chat_server.server_utils.user_utils import create_from_pattern -from utils.common import buffer_to_base64 -from utils.database_utils.mongo_utils import * -from utils.logging_utils import LOG - - -class DbUtils(metaclass=Singleton): - """ Singleton DB Utils class for convenience""" - - db_controller = None - - @classmethod - def init(cls, db_controller): - """ Inits Singleton with specified database controller """ - cls.db_controller = db_controller - - @classmethod - def get_user(cls, user_id=None, nickname=None) -> Union[dict, None]: - """ - Gets user data based on provided params - :param user_id: target user id - :param nickname: target user nickname - """ - if not any(x for x in (user_id, nickname,)): - LOG.warning('Neither user_id nor nickname was provided') - return - filter_data = {} - if user_id: - filter_data['_id'] = user_id - if nickname: - filter_data['nickname'] = 
nickname - return cls.db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ONE, - document=MongoDocuments.USERS, - filters=filter_data)) - - @classmethod - def list_items(cls, document: MongoDocuments, source_set: list, key: str = 'id', value_keys: list = None) -> dict: - """ - Lists items under provided document belonging to source set of provided column values - - :param document: source document to query - :param key: document's key to check - :param source_set: list of :param key values to check - :param value_keys: list of value keys to return - :returns results aggregated by :param column value - """ - if not value_keys: - value_keys = [] - if key == 'id': - key = '_id' - aggregated_data = {} - if source_set: - source_set = list(set(source_set)) - items = cls.db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ALL, - document=document, - filters=MongoFilter(key=key, - value=source_set, - logical_operator=MongoLogicalOperators.IN))) - for item in items: - items_key = item.pop(key, None) - if items_key: - aggregated_data.setdefault(items_key, []).append({k: v for k, v in item.items() if k in value_keys - or not value_keys}) - return aggregated_data - - @classmethod - def get_conversation_data(cls, search_str: Union[list, str], column_identifiers: List[str] = None) -> Union[None, - dict]: - """ - Gets matching conversation data - :param search_str: search string to lookup - :param column_identifiers: desired column identifiers to lookup - """ - if isinstance(search_str, str): - search_str = [search_str] - if not column_identifiers: - column_identifiers = ['_id', 'conversation_name'] - or_expression = [] - for _keyword in [item for item in search_str if item is not None]: - for identifier in column_identifiers: - if identifier == '_id' and isinstance(_keyword, str): - try: - or_expression.append({identifier: ObjectId(_keyword)}) - except: - pass - or_expression.append({identifier: _keyword}) - - conversation_data = 
cls.db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ONE, - document=MongoDocuments.CHATS, - filters=MongoFilter(value=or_expression, - logical_operator=MongoLogicalOperators.OR))) - if not conversation_data: - return - conversation_data['_id'] = str(conversation_data['_id']) - return conversation_data - - @classmethod - def fetch_shout_data(cls, conversation_data: dict, start_idx: int = 0, limit: int = 100, - fetch_senders: bool = True, id_from: str = None, - shout_ids: List[str] = None) -> List[dict]: - """ - Fetches shout data out of conversation data - - :param conversation_data: input conversation data - :param start_idx: message index to start from (sorted by recency) - :param limit: number of shouts to fetch - :param fetch_senders: to fetch shout senders data - :param id_from: message id to start from - :param shout_ids: list of shout ids to fetch - """ - if not shout_ids and conversation_data.get('chat_flow', None): - if id_from: - try: - start_idx = len(conversation_data["chat_flow"]) - \ - conversation_data["chat_flow"].index(id_from) - except ValueError: - LOG.warning('Matching start message id not found') - return [] - if start_idx == 0: - conversation_data['chat_flow'] = conversation_data['chat_flow'][start_idx - limit:] - else: - conversation_data['chat_flow'] = conversation_data['chat_flow'][-start_idx - limit: - -start_idx] - shout_ids = [str(msg_id) for msg_id in conversation_data["chat_flow"]] - shouts_data = cls.fetch_shouts(shout_ids=shout_ids, fetch_senders=fetch_senders) - return sorted(shouts_data, key=lambda user_shout: int(user_shout['created_on'])) - - @classmethod - def fetch_users_from_prompt(cls, prompt: dict): - """ Fetches user ids detected in provided prompt """ - prompt_data = prompt['data'] - user_ids = prompt_data.get('participating_subminds', []) - return cls.list_items(document=MongoDocuments.USERS, source_set=user_ids, value_keys=['first_name', - 'last_name', - 'nickname', - 'is_bot', - 'avatar']) - - @classmethod 
- def fetch_messages_from_prompt(cls, prompt: dict): - """ Fetches message ids detected in provided prompt """ - prompt_data = prompt['data'] - message_ids = [] - for column in ('proposed_responses', 'submind_opinions', 'votes',): - message_ids.extend(list(prompt_data.get(column, {}).values())) - return cls.list_items(document=MongoDocuments.SHOUTS, source_set=message_ids) - - @classmethod - def fetch_prompt_data(cls, cid: str, limit: int = 100, id_from: str = None, - prompt_ids: List[str] = None, fetch_user_data: bool = False, - created_from: int = None) -> List[dict]: - """ - Fetches prompt data out of conversation data - - :param cid: target conversation id - :param limit: number of prompts to fetch - :param id_from: prompt id to start from - :param prompt_ids: prompt ids to fetch - :param fetch_user_data: to fetch user data in the - :param created_from: timestamp to filter messages from - - :returns list of matching prompt data along with matching messages and users - """ - filters = [MongoFilter('cid', cid)] - if id_from: - checkpoint_prompt = cls.db_controller.exec_query(MongoQuery(document=MongoDocuments.PROMPTS, - command=MongoCommands.FIND_ONE, - filters=MongoFilter('_id', id_from))) - if checkpoint_prompt: - filters.append(MongoFilter('created_on', checkpoint_prompt['created_on'], MongoLogicalOperators.LT)) - if prompt_ids: - if isinstance(prompt_ids, str): - prompt_ids = [prompt_ids] - filters.append(MongoFilter('_id', prompt_ids, MongoLogicalOperators.IN)) - if created_from: - filters.append(MongoFilter('created_on', created_from, MongoLogicalOperators.GT)) - matching_prompts = cls.db_controller.exec_query(query=MongoQuery(document=MongoDocuments.PROMPTS, - command=MongoCommands.FIND_ALL, - filters=filters, - result_filters={'sort': [('created_on', - pymongo.DESCENDING)], - 'limit': limit}), - as_cursor=False) - for prompt in matching_prompts: - prompt['user_mapping'] = cls.fetch_users_from_prompt(prompt) - prompt['message_mapping'] = 
cls.fetch_messages_from_prompt(prompt) - if fetch_user_data: - for user in prompt.get('data', {}).get('participating_subminds', []): - try: - nick = prompt['user_mapping'][user][0]['nickname'] - except KeyError: - LOG.warning(f'user_id - "{user}" was not detected setting it as nick') - nick = user - for k in ('proposed_responses', 'submind_opinions', 'votes',): - msg_id = prompt['data'][k].pop(user, '') - if msg_id: - prompt['data'][k][nick] = prompt['message_mapping'].get(msg_id, [{}])[0].get('message_text') or msg_id - prompt['data']['participating_subminds'] = [prompt['user_mapping'][x][0]['nickname'] - for x in prompt['data']['participating_subminds']] - return sorted(matching_prompts, key=lambda _prompt: int(_prompt['created_on'])) - - @classmethod - def fetch_skin_message_data(cls, skin: ConversationSkins, conversation_data: dict, start_idx: int = 0, - limit: int = 100, - fetch_senders: bool = True, start_message_id: str = None): - """ Fetches message data based on provided conversation skin """ - message_data = cls.fetch_shout_data(conversation_data=conversation_data, - fetch_senders=fetch_senders, - start_idx=start_idx, - id_from=start_message_id, - limit=limit) - for message in message_data: - message['message_type'] = 'plain' - if skin == ConversationSkins.PROMPTS: - detected_prompts = list(set(item.get('prompt_id') for item in message_data if item.get('prompt_id'))) - prompt_data = cls.fetch_prompt_data(cid=conversation_data['_id'], - prompt_ids=detected_prompts) - if prompt_data: - detected_prompt_ids = [] - for prompt in prompt_data: - prompt['message_type'] = 'prompt' - detected_prompt_ids.append(prompt['_id']) - message_data = [message for message in message_data if message.get('prompt_id') not in detected_prompt_ids] - message_data.extend(prompt_data) - return sorted(message_data, key=lambda shout: int(shout['created_on'])) - - @classmethod - def fetch_shouts(cls, shout_ids: List[str] = None, fetch_senders: bool = True) -> List[dict]: - """ - 
Fetches shout data from provided shouts list - :param shout_ids: list of shout ids to fetch - :param fetch_senders: to fetch shout senders data - - :returns Data from requested shout ids along with matching user data - """ - if not shout_ids: - return [] - shouts = cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.FIND_ALL, - document=MongoDocuments.SHOUTS, - filters=MongoFilter('_id', list(set(shout_ids)), - MongoLogicalOperators.IN)), - as_cursor=False) - result = list() - - if fetch_senders: - user_ids = list(set([shout['user_id'] for shout in shouts])) - - users_from_shouts = cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.FIND_ALL, - document=MongoDocuments.USERS, - filters=MongoFilter('_id', user_ids, - MongoLogicalOperators.IN))) - - formatted_users = dict() - for users_from_shout in users_from_shouts: - user_id = users_from_shout.pop('_id', None) - formatted_users[user_id] = users_from_shout - - for shout in shouts: - matching_user = formatted_users.get(shout['user_id'], {}) - if not matching_user: - matching_user = create_from_pattern(UserPatterns.UNRECOGNIZED_USER) - - matching_user.pop('password', None) - matching_user.pop('is_tmp', None) - shout['message_id'] = shout['_id'] - shout_data = {**shout, **matching_user} - result.append(shout_data) - shouts = result - return shouts - - @classmethod - def get_translations(cls, translation_mapping: dict) -> Tuple[dict, dict]: - """ - Gets translation from db based on provided mapping - - :param translation_mapping: mapping of cid to desired translation language - - :return translations fetched from db - """ - populated_translations = {} - missing_translations = {} - for cid, cid_data in translation_mapping.items(): - lang = cid_data.get('lang', 'en') - shout_ids = cid_data.get('shouts', []) - conversation_data = cls.get_conversation_data(search_str=cid) - if not conversation_data: - LOG.error(f'Failed to fetch conversation data - {cid}') - continue - shout_data = 
cls.fetch_shout_data(conversation_data=conversation_data, - shout_ids=shout_ids, - fetch_senders=False) - shout_lang = 'en' - if len(shout_data) == 1: - shout_lang = shout_data[0].get('message_lang', 'en') - for shout in shout_data: - message_text = shout.get('message_text') - if shout_lang != 'en' and lang == 'en': - shout_text = message_text - else: - shout_text = shout.get('translations', {}).get(lang) - if shout_text and lang != 'en': - populated_translations.setdefault(cid, {}).setdefault('shouts', {})[shout['_id']] = shout_text - elif message_text: - missing_translations.setdefault(cid, {}).setdefault('shouts', {})[shout['_id']] = message_text - if missing_translations.get(cid): - missing_translations[cid]['lang'] = lang - missing_translations[cid]['source_lang'] = shout_lang - return populated_translations, missing_translations - - @classmethod - def save_translations(cls, translation_mapping: dict) -> Dict[str, List[str]]: - """ - Saves translations in DB - :param translation_mapping: mapping of cid to desired translation language - :returns dictionary containing updated shouts (those which were translated to English) - """ - updated_shouts = {} - for cid, shout_data in translation_mapping.items(): - translations = shout_data.get('shouts', {}) - bulk_update = [] - shouts = cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.FIND_ALL, - document=MongoDocuments.SHOUTS, - filters=MongoFilter('_id', list(translations), - MongoLogicalOperators.IN)), - as_cursor=False) - for shout_id, translation in translations.items(): - matching_instance = None - for shout in shouts: - if shout['_id'] == shout_id: - matching_instance = shout - break - if not matching_instance.get('translations'): - filter_expression = {'_id': shout_id} - cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.SHOUTS, - filters=filter_expression, - data={'translations': {}}, - data_action='set')) - # English is the default language, 
so it is treated as message text - if shout_data.get('lang', 'en') == 'en': - updated_shouts.setdefault(cid, []).append(shout_id) - filter_expression = {'_id': shout_id} - update_expression = {'$set': {'message_lang': 'en'}} - cls.db_controller.exec_query(query={'document': 'shouts', - 'command': 'update', - 'data': (filter_expression, - update_expression,)}) - bulk_update_setter = {'message_text': translation, - 'message_lang': 'en'} - else: - bulk_update_setter = {f'translations.{shout_data["lang"]}': translation} - # TODO: make a convenience wrapper to make bulk insertion easier to follow - bulk_update.append(UpdateOne({'_id': shout_id}, - {'$set': bulk_update_setter})) - if len(bulk_update) > 0: - cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.BULK_WRITE, - document=MongoDocuments.SHOUTS, - data=bulk_update)) - return updated_shouts - - @classmethod - def get_user_preferences(cls, user_id): - """ Gets preferences of specified user, creates default if not exists """ - prefs = { - 'tts': {}, - 'chat_language_mapping': {} - } - if user_id: - user = cls.get_user(user_id=user_id) or {} - if user and not user.get('preferences'): - cls.db_controller.exec_query(MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.USERS, - filters=MongoFilter(key='_id', value=user_id), - data={'preferences': prefs}, - data_action='set')) - else: - prefs = user.get('preferences') - else: - LOG.warning('user_id is None') - return prefs - - @classmethod - def set_user_preferences(cls, user_id, preferences_mapping: dict): - """ Sets user preferences for specified user according to preferences mapping """ - if user_id: - try: - update_mapping = {f'preferences.{key}': val for key, val in preferences_mapping.items()} - cls.db_controller.exec_query(MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.USERS, - filters=MongoFilter('_id', user_id), - data=update_mapping, - data_action='set')) - except Exception as ex: - LOG.error(f'Failed to update 
preferences for user_id={user_id} - {ex}') - - @classmethod - def save_tts_response(cls, shout_id, audio_data: str, lang: str = 'en', gender: str = 'female') -> bool: - """ - Saves TTS Response under corresponding shout id - - :param shout_id: message id to consider - :param audio_data: base64 encoded audio data received - :param lang: language of speech (defaults to English) - :param gender: language gender (defaults to female) - - :return bool if saving was successful - """ - from chat_server.server_config import sftp_connector - - audio_file_name = f'{shout_id}_{lang}_{gender}.wav' - try: - sftp_connector.put_file_object(file_object=audio_data, save_to=f'audio/{audio_file_name}') - cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.SHOUTS, - filters=MongoFilter('_id', shout_id), - data={f'audio.{lang}.{gender}': audio_file_name}, - data_action='set')) - operation_success = True - except Exception as ex: - LOG.error(f'Failed to save TTS response to db - {ex}') - operation_success = False - return operation_success - - @classmethod - def save_stt_response(cls, shout_id, message_text: str, lang: str = 'en'): - """ - Saves STT Response under corresponding shout id - - :param shout_id: message id to consider - :param message_text: STT result transcript - :param lang: language of speech (defaults to English) - """ - try: - cls.db_controller.exec_query(query=MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.SHOUTS, - filters=MongoFilter('_id', shout_id), - data={f'transcripts.{lang}': message_text}, - data_action='set')) - except Exception as ex: - LOG.error(f'Failed to save STT response to db - {ex}') - - @classmethod - def fetch_audio_data_from_message(cls, message_id: str) -> str: - """ - Fetches audio data from message if any - :param message_id: message id to fetch - """ - shout_data = cls.fetch_shouts(shout_ids=[message_id]) - if not shout_data: - LOG.warning('Requested shout does not exist') - 
elif shout_data[0].get('is_audio') != '1': - LOG.warning('Failed to fetch audio data from non-audio message') - else: - from chat_server.server_config import sftp_connector - file_location = f'audio/{shout_data[0]["message_text"]}' - LOG.info(f'Fetching existing file from: {file_location}') - fo = sftp_connector.get_file_object(file_location) - if fo.getbuffer().nbytes > 0: - return buffer_to_base64(fo) - else: - LOG.error(f'Empty buffer received while fetching audio of message id = {message_id}') - return '' diff --git a/chat_server/server_utils/dependencies.py b/chat_server/server_utils/dependencies.py new file mode 100644 index 00000000..553b1e8d --- /dev/null +++ b/chat_server/server_utils/dependencies.py @@ -0,0 +1,37 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from typing import Annotated + +from fastapi import Depends + +from chat_server.server_utils.auth import get_current_user_model +from chat_server.server_utils.models.users import CurrentUserModel + +CurrentUserDependency = Annotated[CurrentUserModel, Depends(get_current_user_model)] diff --git a/chat_server/server_utils/enums.py b/chat_server/server_utils/enums.py index a12e60ff..87ccddd9 100644 --- a/chat_server/server_utils/enums.py +++ b/chat_server/server_utils/enums.py @@ -26,10 +26,11 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-from enum import Enum, IntEnum +from enum import Enum class DataSources(Enum): - """ Enumeration of supported data sources """ - SFTP = 'SFTP' - LOCAL = 'LOCAL' + """Enumeration of supported data sources""" + + SFTP = "SFTP" + LOCAL = "LOCAL" diff --git a/chat_server/server_utils/exceptions.py b/chat_server/server_utils/exceptions.py new file mode 100644 index 00000000..403e9b72 --- /dev/null +++ b/chat_server/server_utils/exceptions.py @@ -0,0 +1,55 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import http + +from utils.http_utils import respond + + +class KlatAPIException(Exception): + + HTTP_CODE = http.HTTPStatus.INTERNAL_SERVER_ERROR + MESSAGE = "Internal Server Error" + + def to_http_response(self): + return respond(msg=self.MESSAGE, status_code=self.HTTP_CODE.value) + + +class UserUnauthorizedException(KlatAPIException): + HTTP_CODE = http.HTTPStatus.FORBIDDEN + MESSAGE = "Requested user is not authorized to perform this action" + + +class ItemNotFoundException(KlatAPIException): + HTTP_CODE = http.HTTPStatus.NOT_FOUND + MESSAGE = "Requested item not found" + + +class DuplicatedItemException(KlatAPIException): + HTTP_CODE = http.HTTPStatus.CONFLICT + MESSAGE = "Requested item already exists" diff --git a/chat_server/server_utils/factory_utils.py b/chat_server/server_utils/factory_utils.py index 72304e71..6f1b47cf 100644 --- a/chat_server/server_utils/factory_utils.py +++ b/chat_server/server_utils/factory_utils.py @@ -26,12 +26,14 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ class Singleton(type): - """ Metaclass for Singleton Implementation""" + """Metaclass for Singleton Implementation""" + _instances = {} def __call__(cls, *args, **kwargs): - update = kwargs.pop('update', False) + update = kwargs.pop("update", False) if cls not in cls._instances or update: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[cls] diff --git a/chat_server/server_utils/http_utils.py b/chat_server/server_utils/http_utils.py index 117dc91d..938002b7 100644 --- a/chat_server/server_utils/http_utils.py +++ b/chat_server/server_utils/http_utils.py @@ -30,7 +30,7 @@ from io import BytesIO import aiofiles -from fastapi import UploadFile +from fastapi import UploadFile, Request from starlette.responses import FileResponse, StreamingResponse from chat_server.server_config import app_config, sftp_connector @@ -40,68 +40,99 @@ from utils.logging_utils import LOG -def get_file_response(filename, location_prefix: str = "", media_type: str = None, - data_source: DataSources = DataSources.SFTP) -> FileResponse: +class KlatAPIResponse: + OK = respond("OK") + BAD_REQUEST = respond("Bad Request") + UNAUTHORIZED = respond("Unauthorized", status_code=401) + FORBIDDEN = respond("Permission Denied", status_code=403) + NOT_FOUND = respond("NOT_FOUND", status_code=404) + INTERNAL_SERVER_ERROR = respond("INTERNAL_SERVER_ERROR", status_code=500) + + +def get_file_response( + filename, + location_prefix: str = "", + media_type: str = None, + data_source: DataSources = DataSources.SFTP, +) -> FileResponse: """ - Gets starlette file response based on provided location + Gets starlette file response based on provided location - :param location_prefix: subdirectory for file to get - :param filename: name of the file to get - :param media_type: type of file to send - :param data_source: source of the data from DataSources + :param location_prefix: subdirectory for file to get + :param filename: name of the file to get + :param media_type: 
type of file to send + :param data_source: source of the data from DataSources - :returns FileResponse in case file is present under specified location + :returns FileResponse in case file is present under specified location """ # TODO: potentially support different ways to access files (e.g. local, S3, remote server, etc..) - LOG.debug(f'Getting file based on filename: {filename}, media type: {media_type}') + LOG.debug(f"Getting file based on filename: {filename}, media type: {media_type}") if data_source == DataSources.SFTP: - sftp_data = sftp_connector.get_file_object(get_from=f'{location_prefix}/{filename}') - file_response_args = dict(content=sftp_data,) + sftp_data = sftp_connector.get_file_object( + get_from=f"{location_prefix}/{filename}" + ) + file_response_args = dict( + content=sftp_data, + ) response_class = StreamingResponse elif data_source == DataSources.LOCAL: - path = os.path.join(app_config['FILE_STORING_LOCATION'], location_prefix, filename) - LOG.debug(f'path: {path}') + path = os.path.join( + app_config["FILE_STORING_LOCATION"], location_prefix, filename + ) + LOG.debug(f"path: {path}") if os.path.exists(os.path.expanduser(path)): - file_response_args = dict(path=path, - filename=filename) + file_response_args = dict(path=path, filename=filename) else: - LOG.error(f'{path} not found') + LOG.error(f"{path} not found") return respond("File not found", 404) response_class = FileResponse else: - LOG.error(f'Data source does not exists - {data_source}') + LOG.error(f"Data source does not exists - {data_source}") return respond("Unable to fetch relevant data source", 403) if media_type: - file_response_args['media_type'] = media_type + file_response_args["media_type"] = media_type return response_class(**file_response_args) -async def save_file(file: UploadFile, location_prefix: str = '', - data_source: DataSources = DataSources.SFTP) -> str: +async def save_file( + file: UploadFile, + location_prefix: str = "", + data_source: DataSources = 
DataSources.SFTP, +) -> str: """ - Saves file in the file system + Saves file in the file system - :param file: file to save - :param location_prefix: subdirectory for file to get - :param data_source: source of the data from DataSources + :param file: file to save + :param location_prefix: subdirectory for file to get + :param data_source: source of the data from DataSources - :returns generated location for the provided file + :returns generated location for the provided file """ new_name = f'{generate_uuid(length=12)}.{file.filename.split(".")[-1]}' if data_source == DataSources.LOCAL: - storing_path = os.path.expanduser(os.path.join(app_config['FILE_STORING_LOCATION'], location_prefix)) + storing_path = os.path.expanduser( + os.path.join(app_config["FILE_STORING_LOCATION"], location_prefix) + ) os.makedirs(storing_path, exist_ok=True) - async with aiofiles.open(os.path.join(storing_path, new_name), 'wb') as out_file: + async with aiofiles.open( + os.path.join(storing_path, new_name), "wb" + ) as out_file: content = file.file.read() # async read await out_file.write(content) elif data_source == DataSources.SFTP: content = BytesIO(file.file.read()) try: - sftp_connector.put_file_object(file_object=content, save_to=f'{location_prefix}/{new_name}') + sftp_connector.put_file_object( + file_object=content, save_to=f"{location_prefix}/{new_name}" + ) except Exception as ex: - LOG.error(f'failed to save file: {file.filename}- {ex}') - return respond('Failed to save attachment due to unexpected error', 422) + LOG.error(f"failed to save file: {file.filename}- {ex}") + return respond("Failed to save attachment due to unexpected error", 422) else: - LOG.error(f'Data source does not exists - {data_source}') + LOG.error(f"Data source does not exists - {data_source}") return respond(f"Unable to fetch relevant data source", 403) return new_name + + +def get_request_path_string(request: Request) -> str: + return f"[{request.method}] {request.url.path} " diff --git 
a/chat_server/server_utils/k8s_utils.py b/chat_server/server_utils/k8s_utils.py new file mode 100644 index 00000000..2302b51f --- /dev/null +++ b/chat_server/server_utils/k8s_utils.py @@ -0,0 +1,80 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import datetime +import os + +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +from chat_server.server_config import k8s_config +from utils.logging_utils import LOG + +k8s_app_api = None +_k8s_default_namespace = "default" + +if _k8s_config_path := k8s_config.get("K8S_CONFIG_PATH"): + _k8s_default_namespace = ( + k8s_config.get("K8S_DEFAULT_NAMESPACE") or _k8s_default_namespace + ) + config.load_kube_config(_k8s_config_path) + + k8s_app_api = client.AppsV1Api() +else: + LOG.warning("K8S config is unset!") + + +def restart_deployment(deployment_name: str, namespace: str = _k8s_default_namespace): + """ + Restarts K8S deployment + :param deployment_name: name of the deployment to restart + :param namespace: name of the namespace + """ + if not k8s_app_api: + LOG.error( + f"Failed to restart {deployment_name=!r} ({namespace=!r}) - missing K8S configs" + ) + return -1 + now = datetime.datetime.utcnow() + now = str(now.isoformat() + "Z") + body = { + "spec": { + "template": { + "metadata": {"annotations": {"kubectl.kubernetes.io/restartedAt": now}} + } + } + } + try: + k8s_app_api.patch_namespaced_deployment( + deployment_name, namespace, body, pretty="true" + ) + except ApiException as e: + LOG.error( + "Exception when calling AppsV1Api->read_namespaced_deployment_status: %s\n" + % e + ) diff --git a/chat_server/server_utils/languages.py b/chat_server/server_utils/languages.py index 
class LanguageSettings:
    """Language Settings controller

    Holds a class-level registry of supported languages which is lazily
    populated from a LibreTranslate "/languages" endpoint, and falls back to
    a small built-in set when no endpoint could be queried.
    """

    __supported_languages__ = {}

    __default_languages__ = {
        "en": dict(name="English", icon="us"),
        "es": dict(name="Español", icon="es"),
    }

    __code_to_icon_mapping__ = {
        "en": "us",
        "hi": "in",
        "zh": "cn",
        "cs": "cz",
        "el": "gr",
        "ja": "jp",
        "ko": "kr",
        "fa": "ir",
        "uk": "ua",
        "ar": "sa",
        "da": "dk",
        "he": "il",
        "vi": "vn",
        "ga": "ie",
    }

    # Comma-separated allow-list of language codes; empty means "no restriction".
    # Entries are stripped so that values like "en, es" match as expected.
    __included_language_codes = [
        code.strip()
        for code in os.environ.get("INCLUDED_LANGUAGES", "").split(",")
        if code.strip()
    ]

    __excluded_language_codes__ = ["ru", "eo"]

    __neon_language_mapping__ = bidict({"en": "en-us"})

    __default_libre_url__ = "https://libretranslate.com/"

    # Upper bound (seconds) for a single "/languages" request, so that an
    # unresponsive endpoint cannot block initialisation indefinitely
    __request_timeout__ = 10

    @classmethod
    def init_supported_languages(cls):
        """
        Inits supported languages from system configuration

        The configured LIBRE_TRANSLATE_URL is tried first and the default URL
        is used as a fallback.

        :returns 0 once an endpoint answered successfully, -1 if none did
        """
        from chat_server.server_config import app_config

        configured_url = app_config.get(
            "LIBRE_TRANSLATE_URL", cls.__default_libre_url__
        )
        # dict.fromkeys() de-duplicates while preserving order (a set would
        # not guarantee that the configured URL is attempted first)
        candidate_urls = dict.fromkeys(
            url for url in (configured_url, cls.__default_libre_url__) if url
        )
        for url in candidate_urls:
            try:
                # Trailing slash is dropped to avoid requesting "//languages"
                res = requests.get(
                    f"{url.rstrip('/')}/languages",
                    timeout=cls.__request_timeout__,
                )
                if res.ok:
                    for item in res.json():
                        code = item["code"]
                        if code not in cls.__excluded_language_codes__ and (
                            not cls.__included_language_codes
                            or code in cls.__included_language_codes
                        ):
                            cls.__supported_languages__[code] = {
                                "name": item["name"],
                                "icon": cls.__code_to_icon_mapping__.get(code, code),
                            }
                    return 0
            except Exception as ex:
                # Best effort: any failure on this URL moves on to the next one
                LOG.error(f"Failed to get translations under URL - {url} (ex={ex})")
        return -1

    @classmethod
    def _active_languages(cls) -> dict:
        """Returns the live languages registry, initialising it on demand;
        falls back to the default languages if initialisation failed"""
        if not cls.__supported_languages__:
            status = cls.init_supported_languages()
            if status == -1:
                LOG.warning("Rollback to default languages")
                return cls.__default_languages__
        return cls.__supported_languages__

    @classmethod
    def get(cls, lang) -> dict:
        """Gets properties based on provided language code

        :param lang: language code to look up
        :returns properties of the language, empty dict if it is unknown
        """
        return cls._active_languages().get(lang, {})

    @classmethod
    def list(cls) -> dict:
        """Lists supported languages

        :returns deep copy of the languages mapping, safe for caller to mutate
        """
        return copy.deepcopy(cls._active_languages())

    @classmethod
    def to_neon_lang(cls, lang):
        """Maps provided language code to the Neon-supported language code"""
        return cls.__neon_language_mapping__.get(lang, lang)

    @classmethod
    def to_system_lang(cls, neon_lang):
        """Maps provided Neon-supported language code to system language code"""
        return cls.__neon_language_mapping__.inverse.get(neon_lang, neon_lang)
class KlatAPIExceptionMiddleware(BaseHTTPMiddleware):
    """Converts KlatAPIException raised by downstream handlers into the HTTP
    response carried by the exception"""

    async def dispatch(self, request: Request, call_next):
        """
        Forwards the request and maps KlatAPIException to its HTTP response

        :param request: incoming request
        :param call_next: callable passing the request to the next handler
        :returns downstream response, or the exception's HTTP response
        """
        try:
            response = await call_next(request)
        except KlatAPIException as exc:
            path = get_request_path_string(request=request)
            LOG.warning(f"Klat API exception occurred for {path = } msg={exc.MESSAGE}")
            response = exc.to_http_response()
        return response


class LogMiddleware(BaseHTTPMiddleware):
    """Logs start, duration and status code of every request under a short
    random request id"""

    async def dispatch(self, request: Request, call_next):
        """
        Forwards the request and logs its processing time and status code

        :param request: incoming request
        :param call_next: callable passing the request to the next handler
        :returns downstream response, or KlatAPIResponse.INTERNAL_SERVER_ERROR
                 if processing raised an exception
        """
        path = get_request_path_string(request=request)
        # Correlation id for log lines only - not a security token
        request_id = "".join(
            random.choices(string.ascii_uppercase + string.digits, k=6)
        )
        LOG.info(f"{request_id = } start at {path = }")
        # perf_counter() is monotonic, so the measured duration is not skewed
        # by system clock adjustments
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # Deliberately not a bare `except:` - BaseException subclasses such
            # as asyncio.CancelledError or KeyboardInterrupt must propagate
            LOG.error(f"{path = }| traceback = {traceback.format_exc()}")
            return KlatAPIResponse.INTERNAL_SERVER_ERROR
        process_time = (time.perf_counter() - start_time) * 1000
        formatted_process_time = "{0:.2f}".format(process_time)
        log_message = (
            f"{request_id = } "
            f"completed_in={formatted_process_time}ms "
            f"status_code={response.status_code}"
        )
        LOG.info(log_message)
        return response


SUPPORTED_MIDDLEWARE = (
    KlatAPIExceptionMiddleware,
    LogMiddleware,
)
class GetConversationModel(BaseModel):
    """Parameters accepted when fetching a conversation: `search_str` is
    declared as a path parameter, the remaining fields as query parameters"""

    search_str: str = Field(Path(), examples=["1"])
    # NOTE: no trailing comma / parentheses here - they would turn the default
    # into a one-element tuple instead of a field definition
    limit_chat_history: int = Field(Query(default=100), examples=[100])
    # Example is rendered as a string to match the declared field type
    creation_time_from: str | None = Field(
        Query(default=None), examples=[str(int(time()))]
    )
    skin: str = Field(
        Query(default=ConversationSkins.BASE), examples=[ConversationSkins.BASE]
    )
All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class Persona(BaseModel):
    """Identity of a persona: its name plus an optional owning user id"""

    persona_name: str = Field(examples=["doctor"])
    user_id: str | None = Field(examples=["test_user_id"], default=None)

    @computed_field
    @property
    def _id(self) -> str:
        """`persona_name`, suffixed with "_<user_id>" when `user_id` is set"""
        if not self.user_id:
            return self.persona_name
        return self.persona_name + f"_{self.user_id}"

    @property
    def persona_id(self):
        """Public accessor for the computed `_id`"""
        return self._id


class AddPersonaModel(Persona):
    """Payload for adding a persona"""

    supported_llms: list[str] = Field(
        default=[], examples=[["chat_gpt", "llama", "fastchat"]]
    )
    default_llm: str | None = Field(default=None, examples=["chat_gpt"])
    description: str = Field(examples=["I am the doctor. I am helping people."])
    enabled: bool = False


class SetPersonaModel(Persona):
    """Payload for updating a persona (same fields as AddPersonaModel, minus
    `enabled`)"""

    supported_llms: list[str] = Field(
        default=[], examples=[["chat_gpt", "llama", "fastchat"]]
    )
    default_llm: str | None = Field(default=None, examples=["chat_gpt"])
    description: str = Field(examples=["I am the doctor. I am helping people."])


class DeletePersonaModel(Persona):
    """Persona identity with both fields redeclared as query parameters"""

    persona_name: str = Field(Query(), examples=["doctor"])
    user_id: str | None = Field(Query(None), examples=["test_user_id"])


class TogglePersonaStatusModel(Persona):
    """Payload carrying the target `enabled` flag of a persona"""

    enabled: bool = Field(default=True, examples=[True, False])


class ListPersonasQueryModel(BaseModel):
    """Optional query filters for listing personas"""

    llms: list[str] | None = Field(Query(default=None), examples=[["doctor"]])
    user_id: str | None = Field(Query(default=None), examples=["test_user_id"])
    only_enabled: bool = Field(Query(default=False), examples=[True, False])
class CurrentUserModel(BaseModel):
    """User record shape; `user_id` is populated from the "_id" key"""

    # Annotated as optional because the field defaults to None
    user_id: str | None = Field(
        default=None, examples=["test_user_id"], alias="_id"
    )
    nickname: str = Field(examples=["test_nickname"])
    first_name: str = Field(examples=["Test"])
    last_name: str = Field(examples=["Test"])
    preferences: dict | None = Field(
        examples=[{"tts": {}, "chat_language_mapping": {}}], default=None
    )
    avatar: str | None = Field(default=None)
    full_nickname: str | None = Field(default=None)
    is_bot: bool | None = Field(examples=[True, False], default=False)
    is_tmp: bool | None = Field(examples=[False, True], default=True)
    # Each example must itself be a valid value of the field, i.e. a list
    roles: list[str] | None = Field(examples=[["admin"], []], default=[])
5830c68e..00000000 --- a/chat_server/server_utils/prompt_utils.py +++ /dev/null @@ -1,103 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2022 Neongecko.com Inc. -# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, -# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo -# BSD-3 License -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -from enum import IntEnum - -from chat_server.server_config import db_controller -from utils.database_utils.mongo_utils import * -from utils.logging_utils import LOG - - -class PromptStates(IntEnum): - """ Prompt States """ - IDLE = 0 # No active prompt - RESP = 1 # Gathering responses to prompt - DISC = 2 # Discussing responses - VOTE = 3 # Voting on responses - PICK = 4 # Proctor will select response - WAIT = 5 # Bot is waiting for the proctor to ask them to respond (not participating) - - -def handle_prompt_message(message: dict) -> bool: - """ - Handles received prompt message - :param message: message dictionary received - :returns True if prompt message was handled, false otherwise - """ - try: - prompt_id = message.get('prompt_id') - prompt_state = PromptStates(int(message.get('promptState', PromptStates.IDLE.value))) - user_id = message['userID'] - message_id = message['messageID'] - ok = True - if prompt_id: - existing_prompt = db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ONE, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key='_id', value=prompt_id))) or {} - if existing_prompt and existing_prompt['is_completed'] == '0': - if user_id not in existing_prompt.get('data', {}).get('participating_subminds', []): - data_kwargs = { - 'data': {'data.participating_subminds': user_id}, - 'data_action': 'push' - } - db_controller.exec_query(MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key='_id', value=prompt_id), - **data_kwargs)) - - prompt_state_mapping = { - # PromptStates.WAIT: {'key': 'participating_subminds', 'type': list}, - PromptStates.RESP: {'key': f'proposed_responses.{user_id}', 'type': dict, 'data': message_id}, - PromptStates.DISC: {'key': f'submind_opinions.{user_id}', 'type': dict, 'data': message_id}, - PromptStates.VOTE: {'key': f'votes.{user_id}', 'type': dict, 'data': message_id} - } - store_key_properties = prompt_state_mapping.get(prompt_state) - if not 
store_key_properties: - LOG.warning(f'Prompt State - {prompt_state.name} has no db store properties') - else: - store_key = store_key_properties['key'] - store_type = store_key_properties['type'] - store_data = store_key_properties['data'] - if user_id in list(existing_prompt.get('data', {}).get(store_key, {})): - LOG.error( - f'user_id={user_id} tried to duplicate data to prompt_id={prompt_id}, store_key={store_key}') - else: - data_kwargs = { - 'data': {f'data.{store_key}': store_data}, - 'data_action': 'push' if store_type == list else 'set' - } - db_controller.exec_query(MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key='_id', value=prompt_id), - **data_kwargs)) - else: - ok = False - except Exception as ex: - LOG.error(f'Failed to handle prompt message - {message} ({ex})') - ok = False - return ok diff --git a/chat_server/server_utils/rmq_utils.py b/chat_server/server_utils/rmq_utils.py new file mode 100644 index 00000000..8dc88b1f --- /dev/null +++ b/chat_server/server_utils/rmq_utils.py @@ -0,0 +1,174 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2021 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. 
# TODO: use this package from neon_diana_utils once its dependencies won't cause conflicts
class RabbitMQAPI:
    """Minimal client for the HTTP management API of a RabbitMQ server"""

    def __init__(self, url: str, verify_ssl: bool = False, timeout: float = 30.0):
        """
        Creates an object used to interface with a RabbitMQ server
        :param url: Management URL (usually IP:port)
        :param verify_ssl: whether to verify the server's TLS certificate
            (NOTE: verification is disabled by default)
        :param timeout: seconds to wait for each HTTP request before failing
        """
        self._verify_ssl = verify_ssl
        self._timeout = timeout
        self.console_url = url
        self._username = None
        self._password = None

    def login(self, username: str, password: str):
        """
        Sets internal username/password parameters used to generate HTTP auth
        :param username: user to authenticate as
        :param password: plaintext password to authenticate with
        """
        self._username = username
        self._password = password
        # TODO: Check auth and return DM

    @property
    def auth(self):
        """
        HTTPBasicAuth object to include with requests.
        """
        return HTTPBasicAuth(self._username, self._password)

    @property
    def _request_kwargs(self) -> dict:
        """Keyword arguments shared by every request to the management API"""
        return {
            "auth": self.auth,
            "verify": self._verify_ssl,
            "timeout": self._timeout,
        }

    def add_vhost(self, vhost: str) -> bool:
        """
        Add a vhost to the server
        :param vhost: vhost to add
        :return: True if request was successful
        """
        status = requests.put(
            f"{self.console_url}/api/vhosts/{quote_plus(vhost)}",
            **self._request_kwargs,
        )
        # Return the success flag (as declared), not the response object
        return status.ok

    def add_user(self, user: str, password: str, tags: str = "") -> bool:
        """
        Add a user to the server
        :param user: username to add
        :param password: password for user
        :param tags: comma-delimited list of tags to assign to new user
        :return: True if request was successful
        """
        tags = tags or ""
        body = {"password": password, "tags": tags}
        status = requests.put(
            f"{self.console_url}/api/users/{quote_plus(user)}",
            data=json.dumps(body),
            **self._request_kwargs,
        )
        return status.ok

    def delete_user(self, user: str) -> bool:
        """
        Delete a user from the server
        :param user: username to remove
        :return: True if request was successful
        """
        status = requests.delete(
            f"{self.console_url}/api/users/{quote_plus(user)}",
            **self._request_kwargs,
        )
        return status.ok

    def configure_vhost_user_permissions(
        self,
        vhost: str,
        user: str,
        configure: str = ".*",
        write: str = ".*",
        read: str = ".*",
    ) -> bool:
        """
        Configure user's access to vhost. See RabbitMQ docs:
        https://www.rabbitmq.com/access-control.html#authorisation
        :param vhost: vhost to set/modify permissions for
        :param user: user to set/modify permissions of
        :param configure: regex configure permissions
        :param write: regex write permissions
        :param read: regex read permissions
        :return: True if request was successful
        """
        url = (
            f"{self.console_url}/api/permissions/{quote_plus(vhost)}/"
            f"{quote_plus(user)}"
        )
        body = {"configure": configure, "write": write, "read": read}
        status = requests.put(url, data=json.dumps(body), **self._request_kwargs)
        return status.ok

    def get_definitions(self):
        """
        Get the server definitions for RabbitMQ; these are used to persist
        configuration between container restarts
        :return: decoded JSON body of the definitions response
        """
        resp = requests.get(
            f"{self.console_url}/api/definitions",
            **self._request_kwargs,
        )
        data = json.loads(resp.content)
        return data

    def create_default_users(self, users: list) -> dict:
        """
        Creates the passed list of users with random passwords and returns a
        dict of users to passwords
        :param users: list of usernames to create
        :return: Dict of requested usernames and associated passwords
        """
        import secrets

        credentials = dict()
        for user in users:
            passwd = secrets.token_urlsafe(32)
            credentials[user] = passwd
            # NOTE: the result of add_user is not checked, so credentials are
            # returned for every requested user even if its creation failed
            self.add_user(user, passwd)
        return credentials

    def configure_admin_account(self, username: str, password: str) -> bool:
        """
        Configures an administrator with the passed credentials and removes
        the default account
        :param username: New administrator's username
        :param password: New administrator's password
        :return: True if action was successful
        """
        create = self.add_user(username, password, "administrator")
        # Switch to the new credentials before removing the "guest" account
        self.login(username, password)
        delete = self.delete_user("guest")
        return create and delete
def create_from_pattern(source: UserPatterns, override_defaults: dict = None) -> dict:
    """
    Builds a user record out of a UserPatterns member

    The pattern's value is deep-copied, overlaid with `override_defaults`,
    and any of "_id", "password", "date_created", "is_tmp" still missing
    afterwards is filled with a generated fallback value.

    :param source: source pattern from UserPatterns
    :param override_defaults: to override default values (optional)
    :returns user data populated with default values where necessary
    """
    user_record = dict(copy.deepcopy(source.value))
    user_record.update(override_defaults or {})
    fallback_values = (
        ("_id", generate_uuid(length=20)),
        ("password", get_hash(generate_uuid())),
        ("date_created", int(time())),
        ("is_tmp", True),
    )
    for key, fallback in fallback_values:
        user_record.setdefault(key, fallback)
    return user_record
(optional) - - :return Matching bot data - """ - if not context: - context = {} - full_nickname = nickname - nickname = nickname.split('-')[0] - bot_data = db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ONE, - document=MongoDocuments.USERS, - filters=MongoFilter(key='nickname', value=nickname))) - if not bot_data: - bot_data = dict(_id=generate_uuid(length=20), - first_name=context.get('first_name', nickname.capitalize()), - last_name=context.get('last_name', ''), - avatar=context.get('avatar', ''), - password=get_hash(generate_uuid()), - nickname=nickname, - is_bot='1', - full_nickname=full_nickname, - date_created=int(time()), - is_tmp=False) - db_controller.exec_query(MongoQuery(command=MongoCommands.INSERT_ONE, - document=MongoDocuments.USERS, - data=bot_data)) - elif not bot_data.get('is_bot') == '1': - db_controller.exec_query(MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.USERS, - filters=MongoFilter('_id', bot_data['_id']), - data={'is_bot': '1'})) - return bot_data diff --git a/chat_server/services/popularity_counter.py b/chat_server/services/popularity_counter.py index 5a8a956c..77ec66bb 100644 --- a/chat_server/services/popularity_counter.py +++ b/chat_server/services/popularity_counter.py @@ -25,26 +25,31 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from collections import Counter from dataclasses import dataclass from time import time from typing import List -from utils.database_utils.mongo_utils import MongoQuery, MongoCommands, MongoDocuments, MongoFilter, \ - MongoLogicalOperators +from utils.database_utils.mongo_utils import ( + MongoFilter, + MongoLogicalOperators, +) +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG @dataclass class ChatPopularityRecord: - """ Dataclass representing single chat popularity data """ + """Dataclass representing single chat popularity data""" + cid: str name: str popularity: int = 0 class PopularityCounter: - """ Handler for ordering chats popularity """ + """Handler for ordering chats popularity""" __DATA: List[ChatPopularityRecord] = [] # sorted popularity data __EXPIRATION_PERIOD = 3600 @@ -52,72 +57,101 @@ class PopularityCounter: @classmethod def get_data(cls): - """ Retrieves popularity data""" + """Retrieves popularity data""" ts = int(time()) if not cls.__DATA or ts - cls.last_updated_ts > cls.__EXPIRATION_PERIOD: cls.init_data() return cls.__DATA @classmethod - def add_new_chat(cls, cid, name, popularity: int = 0): - """ Adds new chat to the tracked chat popularity records """ - cls.__DATA.append(ChatPopularityRecord(cid=cid, - name=name, - popularity=popularity)) + def add_new_chat(cls, cid, popularity: int = 0): + """Adds new chat to the tracked chat popularity records""" + name = MongoDocumentsAPI.CHATS.get_item(item_id=cid).get("conversation_name") + cls.__DATA.append( + ChatPopularityRecord(cid=cid, name=name, popularity=popularity) + ) @classmethod def init_data(cls, actuality_days: int = 7): """ - Initialise items popularity from DB - Current implementation considers length of number of message container under given conversation + Initialise items popularity from DB + Current implementation considers length of number of message container under given conversation - :param actuality_days: 
number of days for message to affect the chat popularity + :param actuality_days: number of days for message to affect the chat popularity """ - from chat_server.server_utils.db_utils import DbUtils curr_time = int(time()) - chats = DbUtils.db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ALL, - document=MongoDocuments.CHATS, - filters=MongoFilter(key='is_private', - value=False)), - as_cursor=False) - relevant_shouts = DbUtils.db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ALL, - document=MongoDocuments.SHOUTS, - filters=MongoFilter(key='created_on', - logical_operator=MongoLogicalOperators.GTE, - value=curr_time - 3600 * 24 * actuality_days))) - relevant_shouts = set(x['_id'] for x in relevant_shouts) + oldest_timestamp = curr_time - 3600 * 24 * actuality_days + chats = MongoDocumentsAPI.CHATS.list_items( + filters=[ + MongoFilter( + key="last_shout_ts", + logical_operator=MongoLogicalOperators.GTE, + value=oldest_timestamp, + ) + ], + include_private=False, + result_as_cursor=False, + ) + relevant_shouts = MongoDocumentsAPI.SHOUTS.list_items( + filters=[ + MongoFilter( + key="created_on", + logical_operator=MongoLogicalOperators.GTE, + value=oldest_timestamp, + ), + MongoFilter( + key="cid", + value=[chat["_id"] for chat in chats], + logical_operator=MongoLogicalOperators.IN, + ), + ] + ) + cids_popularity_counter = Counter() + for shout in relevant_shouts: + cids_popularity_counter[str(shout["cid"])] += 1 formatted_chats = [] - for chat in chats: - chat_flow = set(chat.get('chat_flow', [])) - popularity = len(chat_flow.intersection(relevant_shouts)) - if chat['_id'] is not None: - formatted_chats.append(ChatPopularityRecord(cid=str(chat['_id']), - name=chat['conversation_name'], - popularity=popularity)) + for cid in cids_popularity_counter: + relevant_chat = [ + chat for chat in chats if str(chat.get("_id", "")) == str(cid) + ][0] + formatted_chats.append( + ChatPopularityRecord( + cid=cid, + 
name=relevant_chat["conversation_name"], + popularity=cids_popularity_counter[cid], + ) + ) cls.last_updated_ts = int(time()) cls.__DATA = sorted(formatted_chats, key=lambda x: x.popularity, reverse=True) @classmethod def increment_cid_popularity(cls, cid): - """ Increments popularity of specified conversation id """ + """Increments popularity of specified conversation id""" try: matching_item = [item for item in cls.get_data() if item.cid == cid][0] matching_item.popularity += 1 except IndexError: - LOG.error(f'No cid matching = {cid}') + LOG.debug(f"No cid matching = {cid}") + cls.add_new_chat(cid=cid, popularity=1) @classmethod def get_first_n_items(cls, search_str, exclude_items: list = None, limit: int = 10): """ - Returns first N items matching searched string + Returns first N items matching searched string - :param search_str: Substring to match - :param exclude_items: list of conversation ids to exclude from search - :param limit: number of the highest rated results to return + :param search_str: Substring to match + :param exclude_items: list of conversation ids to exclude from search + :param limit: number of the highest rated results to return """ if not exclude_items: exclude_items = [] - data = [{'_id': item.cid, 'conversation_name': item.name, 'popularity': item.popularity} - for item in cls.get_data() if search_str.lower() in item.name.lower() - and item.cid not in exclude_items] - return sorted(data, key=lambda item: item['popularity'], reverse=True)[:limit] + data = [ + { + "_id": item.cid, + "conversation_name": item.name, + "popularity": item.popularity, + } + for item in cls.get_data() + if search_str.lower() in item.name.lower() and item.cid not in exclude_items + ] + return sorted(data, key=lambda item: item["popularity"], reverse=True)[:limit] diff --git a/chat_server/sio.py b/chat_server/sio.py deleted file mode 100644 index 49b283ec..00000000 --- a/chat_server/sio.py +++ /dev/null @@ -1,616 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software 
Development Kit & Application Framework -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2022 Neongecko.com Inc. -# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, -# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo -# BSD-3 License -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import json -import os -import socketio - -from functools import wraps -from time import time -from typing import List, Optional - -from cachetools import LRUCache -from utils.logging_utils import LOG - -from utils.common import generate_uuid, deep_merge, buffer_to_base64 -from chat_server.server_utils.auth import validate_session -from chat_server.server_utils.cache_utils import CacheFactory -from chat_server.server_utils.db_utils import DbUtils, MongoCommands, MongoDocuments, MongoQuery, MongoFilter -from chat_server.server_utils.prompt_utils import handle_prompt_message -from chat_server.server_utils.user_utils import get_neon_data, get_bot_data -from chat_server.server_utils.languages import LanguageSettings -from chat_server.server_config import db_controller, sftp_connector -from chat_server.services.popularity_counter import PopularityCounter - -sio = socketio.AsyncServer(cors_allowed_origins='*', async_mode='asgi') - - -def list_current_headers(sid: str) -> list: - return sio.environ.get(sio.manager.rooms['/'].get(sid, {}).get(sid), {}).get('asgi.scope', {}).get('headers', []) - - -def get_header(sid: str, match_str: str): - for header_tuple in list_current_headers(sid): - if header_tuple[0].decode() == match_str.lower(): - return header_tuple[1].decode() - - -def login_required(*outer_args, **outer_kwargs): - """ - Decorator that validates current authorization token - """ - - no_args = False - func = None - if len(outer_args) == 1 and not outer_kwargs and callable(outer_args[0]): - # Function was called with no arguments - no_args = True - func = outer_args[0] - - outer_kwargs.setdefault('tmp_allowed', True) - - def outer(func): - - @wraps(func) - async def wrapper(sid, *args, **kwargs): - if os.environ.get('DISABLE_AUTH_CHECK', '0') != '1': - auth_token = get_header(sid, 'session') - session_validation_output = (None, None,) - if auth_token: - session_validation_output = validate_session(auth_token, - check_tmp=not outer_kwargs['tmp_allowed'], - 
sio_request=True) - if session_validation_output[1] != 200: - return await sio.emit('auth_expired', data={}, to=sid) - return await func(sid, *args, **kwargs) - - return wrapper - - if no_args: - return outer(func) - else: - return outer - - -@sio.event -async def connect(sid, environ: dict, auth): - """ - SIO event fired on client connect - :param sid: client session id - :param environ: connection environment dict - :param auth: authorization method (None if was not provided) - """ - LOG.info(f'{sid} connected') - - -@sio.event -async def ping(sid, data): - """ - SIO event fired on client ping request - :param sid: client session id - :param data: user message data - """ - LOG.info(f'Received ping request from "{sid}"') - await sio.emit('pong', data={'msg': 'hello from sio server'}) - - -@sio.event -async def disconnect(sid): - """ - SIO event fired on client disconnect - - :param sid: client session id - """ - LOG.info(f'{sid} disconnected') - - -@sio.event -# @login_required -async def user_message(sid, data): - """ - SIO event fired on new user message in chat - :param sid: client session id - :param data: user message data - Example: - ``` - data = {'cid':'conversation id', - 'userID': 'emitted user id', - 'messageID': 'id of emitted message', - 'promptID': 'id of related prompt (optional)', - 'source': 'declared name of the source that shouted given user message' - 'messageText': 'content of the user message', - 'repliedMessage': 'id of replied message (optional)', - 'bot': 'if the message is from bot (defaults to False)', - 'lang': 'language of the message (defaults to "en")' - 'attachments': 'list of filenames that were send with message', - 'context': 'message context (optional)', - 'test': 'is test message (defaults to False)', - 'isAudio': '1 if current message is audio message 0 otherwise', - 'messageTTS': received tts mapping of type: {language: {gender: (audio data base64 encoded)}}, - 'isAnnouncement': if received message is the announcement, - 
'timeCreated': 'timestamp on which message was created'} - ``` - """ - LOG.debug(f'Got new user message from {sid}: {data}') - try: - filter_expression = dict(_id=data['cid']) - cid_data = DbUtils.get_conversation_data(data['cid'], column_identifiers=['_id']) - if not cid_data: - msg = 'Shouting to non-existent conversation, skipping further processing' - await emit_error(sids=[sid], message=msg) - return - - LOG.info(f'Received user message data: {data}') - data['messageID'] = data.get('messageID') - if data['messageID']: - existing_shout = DbUtils.fetch_shouts(shout_ids=[data['messageID']], fetch_senders=False) - if existing_shout: - raise ValueError(f'messageID value="{data["messageID"]}" already exists') - else: - data['messageID'] = generate_uuid() - data['is_bot'] = data.pop('bot', '0') - if data['userID'].startswith('neon'): - neon_data = get_neon_data(db_controller=db_controller) - data['userID'] = neon_data['_id'] - elif data['is_bot'] == '1': - bot_data = get_bot_data(db_controller=db_controller, nickname=data['userID'], - context=data.get('context', None)) - data['userID'] = bot_data['_id'] - - is_audio = data.get('isAudio', '0') - - if is_audio != '1': - is_audio = '0' - - audio_path = f'{data["messageID"]}_audio.wav' - try: - if is_audio == '1': - message_text = data['messageText'].split(',')[-1] - sftp_connector.put_file_object(file_object=message_text, save_to=f'audio/{audio_path}') - # for audio messages "message_text" references the name of the audio stored - data['messageText'] = audio_path - except Exception as ex: - LOG.error(f'Failed to located file - {ex}') - return -1 - - is_announcement = data.get('isAnnouncement', '0') or '0' - - if is_announcement != '1': - is_announcement = '0' - - lang = data.get('lang', 'en') - data['prompt_id'] = data.pop('promptID', '') - - new_shout_data = {'_id': data['messageID'], - 'cid': data['cid'], - 'user_id': data['userID'], - 'prompt_id': data['prompt_id'], - 'message_text': data['messageText'], - 
'message_lang': lang, - 'attachments': data.get('attachments', []), - 'replied_message': data.get('repliedMessage', ''), - 'is_audio': is_audio, - 'is_announcement': is_announcement, - 'is_bot': data['is_bot'], - 'translations': {}, - 'created_on': int(data.get('timeCreated', time()))} - - # in case message is received in some foreign language - - # message text is kept in that language unless English translation received - if lang != 'en': - new_shout_data['translations'][lang] = data['messageText'] - - db_controller.exec_query(MongoQuery(command=MongoCommands.INSERT_ONE, - document=MongoDocuments.SHOUTS, - data=new_shout_data)) - db_controller.exec_query(query=MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.CHATS, - filters=filter_expression, - data={'chat_flow': new_shout_data['_id']}, - data_action='push')) - if is_announcement == '0' and data['prompt_id']: - is_ok = handle_prompt_message(data) - if is_ok: - await sio.emit('new_prompt_message', data={'cid': data['cid'], - 'userID': data['userID'], - 'messageText': data['messageText'], - 'promptID': data['prompt_id'], - 'promptState': data['promptState']}) - - message_tts = data.get('messageTTS', {}) - for language, gender_mapping in message_tts.items(): - for gender, audio_data in gender_mapping.items(): - sftp_connector.put_file_object(file_object=audio_data, save_to=f'audio/{audio_path}') - DbUtils.save_tts_response(shout_id=data['messageID'], audio_file_name=audio_path, - lang=language, gender=gender) - - data['bound_service'] = cid_data.get('bound_service', '') - await sio.emit('new_message', data=data, skip_sid=[sid]) - PopularityCounter.increment_cid_popularity(new_shout_data['cid']) - except Exception as ex: - LOG.error(f'Exception on sio processing: {ex}') - await emit_error(sids=[sid], message=f'Unable to process request "user_message" with data: {data}') - - -@sio.event -# @login_required -async def new_prompt(sid, data): - """ - SIO event fired on new prompt data saving request - 
:param sid: client session id - :param data: user message data - Example: - ``` - data = {'cid':'conversation id', - 'promptID': 'id of related prompt', - 'context': 'message context (optional)', - 'timeCreated': 'timestamp on which message was created' - } - ``` - """ - prompt_id = data['prompt_id'] - cid = data['cid'] - prompt_text = data['prompt_text'] - created_on = int(data.get('created_on') or time()) - try: - formatted_data = {'_id': prompt_id, - 'cid': cid, - 'is_completed': '0', - 'data': {'prompt_text': prompt_text}, - 'created_on': created_on} - db_controller.exec_query(MongoQuery(command=MongoCommands.INSERT_ONE, - document=MongoDocuments.PROMPTS, - data=formatted_data)) - await sio.emit('new_prompt_created', data=formatted_data) - except Exception as ex: - LOG.error(f'Prompt "{prompt_id}" was not created due to exception - {ex}') - - -@sio.event -# @login_required -async def prompt_completed(sid, data): - """ - SIO event fired upon prompt completion - :param sid: client session id - :param data: user message data - """ - prompt_id = data['context']['prompt']['prompt_id'] - prompt_summary_keys = ['winner', 'votes_per_submind'] - prompt_summary_agg = {f'data.{k}': v for k, v in data['context'].items() if k in prompt_summary_keys} - prompt_summary_agg['is_completed'] = '1' - try: - db_controller.exec_query(MongoQuery(command=MongoCommands.UPDATE, - document=MongoDocuments.PROMPTS, - filters=MongoFilter(key='_id', value=prompt_id), - data=prompt_summary_agg, - data_action='set')) - formatted_data = {'winner': data['context'].get('winner', ''), 'prompt_id': prompt_id} - await sio.emit('set_prompt_completed', data=formatted_data) - except Exception as ex: - LOG.error(f'Prompt "{prompt_id}" was not updated due to exception - {ex}') - - -@sio.event -# @login_required -async def get_prompt_data(sid, data): - """ - SIO event fired getting prompt data request - :param sid: client session id - :param data: user message data - Example: - ``` - data = {'userID': 
'emitted user id', - 'cid':'conversation id', - 'promptID': 'id of related prompt'} - ``` - """ - prompt_id = data.get('prompt_id') - _prompt_data = DbUtils.fetch_prompt_data(cid=data['cid'], - limit=data.get('limit', 5), - prompt_ids=[prompt_id], - fetch_user_data=True) - if prompt_id: - prompt_data = {'_id': _prompt_data[0]['_id'], - 'is_completed': _prompt_data[0].get('is_completed', '1'), - **_prompt_data[0].get('data')} - else: - prompt_data = [] - for item in _prompt_data: - prompt_data.append({'_id': item['_id'], 'created_on': item['created_on'], - 'is_completed': item.get('is_completed', '1'), **item['data']}) - result = dict(data=prompt_data, receiver=data['nick'], cid=data['cid'], request_id=data['request_id'], ) - await sio.emit('prompt_data', data=result) - - -@sio.event -# @login_required -async def request_translate(sid, data): - """ - Handles requesting for cid translation - :param sid: client session id - :param data: mapping of cid to desired translation language - """ - if not data: - LOG.warning('Missing request translate data, skipping...') - else: - input_type = data.get('inputType', 'incoming') - - populated_translations, missing_translations = DbUtils.get_translations( - translation_mapping=data.get('chat_mapping', {})) - if populated_translations and not missing_translations: - await sio.emit('translation_response', data={'translations': populated_translations, - 'input_type': input_type}, to=sid) - else: - LOG.info('Not every translation is contained in db, sending out request to Neon') - request_id = generate_uuid() - caching_instance = {'translations': populated_translations, 'sid': sid, - 'input_type': input_type} - CacheFactory.get('translation_cache', cache_type=LRUCache).put(key=request_id, value=caching_instance) - await sio.emit('request_neon_translations', data={'request_id': request_id, 'data': missing_translations}, ) - - -@sio.event -async def get_neon_translations(sid, data): - """ - Handles received translations from Neon 
Translation Service - :param sid: client session id - :param data: received translations data - Example of translations data: - ``` - data = { - 'request_id': (emitted request id), - 'translations':(dictionary containing mapping of shout id to translations) - } - ``` - """ - request_id = data.get('request_id') - if not request_id: - LOG.error('Missing "request id" in response dict') - else: - try: - cached_data = CacheFactory.get('translation_cache').get(key=request_id) - if not cached_data: - LOG.warning('Failed to get matching cached data') - return - sid = cached_data.get('sid') - input_type = cached_data.get('input_type') - updated_shouts = DbUtils.save_translations(data.get('translations', {})) - populated_translations = deep_merge(data.get('translations', {}), cached_data.get('translations', {})) - await sio.emit('translation_response', data={'translations': populated_translations, - 'input_type': input_type}, to=sid) - if updated_shouts: - send_dict = { - 'input_type': input_type, - 'translations': updated_shouts, - } - await sio.emit('updated_shouts', data=send_dict, skip_sid=[sid]) - except KeyError as err: - LOG.error(f'No translation cache detected under request_id={request_id} (err={err})') - - -@sio.event -# @login_required -async def request_tts(sid, data): - """ - Handles request to Neon TTS service - - :param sid: client session id - :param data: received tts request data - Example of tts request data: - ``` - data = { - 'message_id': (target message id), - 'message_text':(target message text), - 'lang': (target message lang) - } - ``` - """ - required_keys = ('cid', 'message_id', 'user_id',) - if not all(key in list(data) for key in required_keys): - LOG.error(f'Missing one of the required keys - {required_keys}') - else: - lang = data.get('lang', 'en') - message_id = data['message_id'] - user_id = data['user_id'] - cid = data['cid'] - matching_messages = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if not matching_messages: 
- LOG.error('Failed to request TTS - matching message not found') - else: - matching_message = matching_messages[0] - - # Trying to get existing audio data - preferred_gender = DbUtils.get_user_preferences(user_id=user_id).get('tts', {}).get(lang, {}).get('gender', - 'female') - existing_audio_file = matching_message.get('audio', {}).get(lang, {}).get(preferred_gender) - if not existing_audio_file: - LOG.info(f'File was not detected for cid={cid}, message_id={message_id}, lang={lang}') - message_text = matching_message.get('message_text') - formatted_data = { - 'cid': cid, - 'sid': sid, - 'message_id': message_id, - 'text': message_text, - 'lang': LanguageSettings.to_neon_lang(lang) - } - await sio.emit('get_tts', data=formatted_data) - else: - try: - file_location = f'audio/{existing_audio_file}' - LOG.info(f'Fetching existing file from: {file_location}') - fo = sftp_connector.get_file_object(file_location) - if fo.getbuffer().nbytes > 0: - LOG.info(f'File detected for cid={cid}, message_id={message_id}, lang={lang}') - audio_data = buffer_to_base64(fo) - response_data = { - 'cid': cid, - 'message_id': message_id, - 'lang': lang, - 'gender': preferred_gender, - 'audio_data': audio_data - } - await sio.emit('incoming_tts', data=response_data, to=sid) - else: - LOG.error(f'Empty file detected for cid={cid}, message_id={message_id}, lang={lang}') - except Exception as ex: - LOG.error(f'Failed to send TTS response - {ex}') - - -@sio.event -async def tts_response(sid, data): - """ Handle TTS Response from Observer """ - mq_context = data.get('context', {}) - cid = mq_context.get('cid') - message_id = mq_context.get('message_id') - sid = mq_context.get('sid') - lang = LanguageSettings.to_system_lang(data.get('lang', 'en-us')) - lang_gender = data.get('gender', 'undefined') - matching_shouts = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if not matching_shouts: - LOG.warning(f'Skipping TTS Response for message_id={message_id} - matching shout does 
not exist') - else: - audio_data = data.get('audio_data') - if not audio_data: - LOG.warning(f'Skipping TTS Response for message_id={message_id} - audio data is empty') - else: - is_ok = DbUtils.save_tts_response(shout_id=message_id, audio_data=audio_data, - lang=lang, gender=lang_gender) - if is_ok: - response_data = { - 'cid': cid, - 'message_id': message_id, - 'lang': lang, - 'gender': lang_gender, - 'audio_data': audio_data - } - await sio.emit('incoming_tts', data=response_data, to=sid) - else: - to = None - if sid: - to = [sid] - await emit_error(message='Failed to get TTS response', context={'message_id': message_id, - 'cid': cid}, sids=to) - - -@sio.event -async def stt_response(sid, data): - """ Handle STT Response from Observer """ - mq_context = data.get('context', {}) - message_id = mq_context.get('message_id') - matching_shouts = DbUtils.fetch_shouts(shout_ids=[message_id], fetch_senders=False) - if not matching_shouts: - LOG.warning(f'Skipping STT Response for message_id={message_id} - matching shout does not exist') - else: - try: - message_text = data.get('transcript') - lang = LanguageSettings.to_system_lang(data['lang']) - DbUtils.save_stt_response(shout_id=message_id, message_text=message_text, lang=lang) - sid = mq_context.get('sid') - cid = mq_context.get('cid') - response_data = { - 'cid': cid, - 'message_id': message_id, - 'lang': lang, - 'message_text': message_text - } - await sio.emit('incoming_stt', data=response_data, to=sid) - except Exception as ex: - LOG.error(f'Failed to save received transcript due to exception {ex}') - - -@sio.event -# @login_required -async def request_stt(sid, data): - """ - Handles request to Neon STT service - - :param sid: client session id - :param data: received tts request data - Example of tts request data: - ``` - data = { - 'cid': (target conversation id) - 'message_id': (target message id), - 'audio_data':(target audio data base64 encoded), - (optional) 'lang': (target message lang) - } - ``` - """ - 
required_keys = ('message_id',) - if not all(key in list(data) for key in required_keys): - LOG.error(f'Missing one of the required keys - {required_keys}') - else: - cid = data.get('cid', '') - message_id = data.get('message_id', '') - # TODO: process received language - lang = 'en' - # lang = data.get('lang', 'en') - existing_shouts = DbUtils.fetch_shouts(shout_ids=[message_id]) - if existing_shouts: - existing_transcript = existing_shouts[0].get('transcripts', {}).get(lang) - if existing_transcript: - response_data = { - 'cid': cid, - 'message_id': message_id, - 'lang': lang, - 'message_text': existing_transcript - } - return await sio.emit('incoming_stt', data=response_data, to=sid) - audio_data = data.get('audio_data') or DbUtils.fetch_audio_data_from_message(message_id) - if not audio_data: - LOG.error('Failed to fetch audio data') - else: - lang = LanguageSettings.to_neon_lang(lang) - formatted_data = { - 'cid': cid, - 'sid': sid, - 'message_id': message_id, - 'audio_data': audio_data, - 'lang': lang, - } - await sio.emit('get_stt', data=formatted_data) - - -async def emit_error(message: str, context: Optional[dict] = None, sids: Optional[List[str]] = None): - """ - Emits error message to provided sid - - :param message: message to emit - :param sids: client session ids (optional) - :param context: context to emit (optional) - """ - if not context: - context = {} - await sio.emit(context.pop('callback_event', 'klatchat_sio_error'), - data={'msg': message}, - to=sids) - - -async def emit_session_expired(sid: str): - """ Wrapper to emit session expired session event to desired client session """ - await emit_error(message='Session Expired', context={'callback_event': 'auth_expired'}, sids=[sid]) diff --git a/chat_server/sio/__init__.py b/chat_server/sio/__init__.py new file mode 100644 index 00000000..b8622a2d --- /dev/null +++ b/chat_server/sio/__init__.py @@ -0,0 +1,37 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All 
trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +# Import used sio server instance and handlers here +from .server import sio +from .handlers import ( + session, + stt, + tts, + translation, + user_message as languages_blueprint, +) diff --git a/chat_server/sio/handlers/prompt.py b/chat_server/sio/handlers/prompt.py new file mode 100644 index 00000000..14c7c35d --- /dev/null +++ b/chat_server/sio/handlers/prompt.py @@ -0,0 +1,137 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
@sio.event
# @login_required
async def new_prompt(sid, data):
    """
    SIO event fired on new prompt data saving request

    Stores the prompt through ``MongoDocumentsAPI.PROMPTS.add_item`` and emits
    "new_prompt_created" carrying the stored document. Storage/emit failures
    are logged and not re-raised.

    :param sid: client session id
    :param data: prompt data
    Example:
    ```
        data = {'cid': 'conversation id',
                'prompt_id': 'id of the prompt to create',
                'prompt_text': 'text of the prompt',
                'created_on': 'timestamp on which prompt was created (optional)'
                }
    ```
    """
    # Validate up-front (same convention as the STT/TTS request handlers)
    # instead of letting a KeyError escape the handler.
    required_keys = ("prompt_id", "cid", "prompt_text")
    if not all(key in data for key in required_keys):
        LOG.error(f"Missing one of the required keys - {required_keys}")
        return

    prompt_id = data["prompt_id"]
    # Falls back to the current time when "created_on" is missing or empty
    created_on = int(data.get("created_on") or time())
    try:
        formatted_data = {
            "_id": prompt_id,
            "cid": data["cid"],
            "is_completed": "0",
            "data": {"prompt_text": data["prompt_text"]},
            "created_on": created_on,
        }
        MongoDocumentsAPI.PROMPTS.add_item(data=formatted_data)
        await sio.emit("new_prompt_created", data=formatted_data)
    except Exception as ex:
        LOG.error(f'Prompt "{prompt_id}" was not created due to exception - {ex}')


@sio.event
# @login_required
async def prompt_completed(sid, data):
    """
    SIO event fired upon prompt completion

    Marks the prompt referenced by ``data["context"]["prompt"]["prompt_id"]``
    as completed and emits "set_prompt_completed" with its id and winner.

    :param sid: client session id
    :param data: completion data; must contain "context" with a nested
        "prompt" -> "prompt_id", may contain "context" -> "winner"
    """
    prompt_context = data["context"]
    prompt_id = prompt_context["prompt"]["prompt_id"]

    LOG.info(f"setting {prompt_id = } as completed")

    MongoDocumentsAPI.PROMPTS.set_completed(
        prompt_id=prompt_id, prompt_context=prompt_context
    )
    formatted_data = {
        "winner": prompt_context.get("winner", ""),
        "prompt_id": prompt_id,
    }
    await sio.emit("set_prompt_completed", data=formatted_data)


@sio.event
# @login_required
async def get_prompt_data(sid, data):
    """
    SIO event fired getting prompt data request

    Emits "prompt_data" with either a single prompt (when "prompt_id" is
    given) or a list of prompts of the conversation.

    :param sid: client session id
    :param data: request data
    Example:
    ```
        data = {'cid': 'conversation id',
                'nick': 'nickname of the receiver',
                'request_id': 'id of the request',
                'prompt_id': 'id of related prompt (optional)',
                'limit': 'max number of prompts to fetch (optional, defaults to 5)'}
    ```
    """
    prompt_id = data.get("prompt_id")
    # NOTE(review): without a prompt id this still passes ``[None]`` as
    # ``prompt_ids`` - confirm that fetch_prompt_data treats it as "no filter".
    _prompt_data = mongo_queries.fetch_prompt_data(
        cid=data["cid"],
        limit=data.get("limit", 5),
        prompt_ids=[prompt_id],
        fetch_user_data=True,
    )
    if prompt_id:
        try:
            matching_prompt = _prompt_data[0]
        except IndexError:
            # Requested prompt does not exist - nothing to send back
            LOG.warning(
                f'No prompt data found for prompt_id={prompt_id}, cid={data["cid"]}'
            )
            return
        prompt_data = {
            "_id": matching_prompt["_id"],
            "is_completed": matching_prompt.get("is_completed", "1"),
            # "data" may be absent from the stored document
            **(matching_prompt.get("data") or {}),
        }
    else:
        prompt_data = [
            {
                "_id": item["_id"],
                "created_on": item["created_on"],
                "is_completed": item.get("is_completed", "1"),
                **item["data"],
            }
            for item in _prompt_data
        ]
    result = dict(
        data=prompt_data,
        receiver=data["nick"],
        cid=data["cid"],
        request_id=data["request_id"],
    )
    await sio.emit("prompt_data", data=result)
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
@sio.event
async def connect(sid, environ: dict, auth):
    """
    Log every newly connected client.

    :param sid: client session id
    :param environ: connection environment dict (not inspected here)
    :param auth: authorization method (None if was not provided)
    """
    connected_msg = f"{sid} connected"
    LOG.info(connected_msg)


@sio.event
async def ping(sid, data):
    """
    Answer a client ping request with a "pong" event.

    :param sid: client session id
    :param data: ping payload (not inspected here)
    """
    LOG.info(f'Received ping request from "{sid}"')
    pong_payload = {"msg": "hello from sio server"}
    # NOTE(review): emitted without `to=`, unlike the targeted replies in the
    # STT/TTS handlers - confirm that this is intended.
    await sio.emit("pong", data=pong_payload)


@sio.event
async def disconnect(sid):
    """
    Log every disconnected client.

    :param sid: client session id
    """
    disconnected_msg = f"{sid} disconnected"
    LOG.info(disconnected_msg)
@sio.event
async def stt_response(sid, data):
    """
    Handle STT Response from Observer

    Saves the received transcript for the shout referenced in
    ``data["context"]`` and emits "incoming_stt" to the session id stored in
    that context.

    :param sid: session id of the emitter (replaced by the context "sid")
    :param data: STT response; reads "context", "transcript" and "lang"
    """
    mq_context = data.get("context", {})
    message_id = mq_context.get("message_id")
    if not MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id):
        LOG.warning(
            f"Skipping STT Response for message_id={message_id} - matching shout does not exist"
        )
        return

    try:
        message_text = data.get("transcript")
        lang = LanguageSettings.to_system_lang(data["lang"])
        MongoDocumentsAPI.SHOUTS.save_stt_response(
            shout_id=message_id, message_text=message_text, lang=lang
        )
        # Reply goes to the client remembered in the request context
        sid = mq_context.get("sid")
        await sio.emit(
            "incoming_stt",
            data={
                "cid": mq_context.get("cid"),
                "message_id": message_id,
                "lang": lang,
                "message_text": message_text,
            },
            to=sid,
        )
    except Exception as ex:
        LOG.error(f"Failed to save received transcript due to exception {ex}")


@sio.event
# @login_required
async def request_stt(sid, data):
    """
    Handles request to Neon STT service

    When the target shout already exists, its stored transcript is sent back
    (or an error is emitted if the transcript is missing); otherwise audio
    data is forwarded in a "get_stt" event.

    :param sid: client session id
    :param data: received stt request data
    Example of stt request data:
    ```
        data = {
                    'cid': (target conversation id)
                    'message_id': (target message id),
                    'audio_data':(target audio data base64 encoded),
                    (optional) 'lang': (target message lang)
               }
    ```
    """
    required_keys = ("message_id",)
    if any(key not in data for key in required_keys):
        LOG.error(f"Missing one of the required keys - {required_keys}")
        return

    cid = data.get("cid", "")
    message_id = data.get("message_id", "")
    # TODO: process received language
    lang = "en"
    # lang = data.get('lang', 'en')

    shout_data = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id)
    if shout_data:
        message_transcript = shout_data.get("transcripts", {}).get(lang)
        if not message_transcript:
            err_msg = "Message transcript was missing"
            LOG.error(err_msg)
            return await emit_error(message=err_msg, sids=[sid])
        response_data = {
            "cid": cid,
            "message_id": message_id,
            "lang": lang,
            "message_text": message_transcript,
        }
        return await sio.emit("incoming_stt", data=response_data, to=sid)

    # Audio from the request takes precedence over audio stored for the message
    audio_data = data.get(
        "audio_data"
    ) or MongoDocumentsAPI.SHOUTS.fetch_audio_data(message_id=message_id)
    if not audio_data:
        LOG.error("Failed to fetch audio data")
        return

    lang = LanguageSettings.to_neon_lang(lang)
    await sio.emit(
        "get_stt",
        data={
            "cid": cid,
            "sid": sid,
            "message_id": message_id,
            "audio_data": audio_data,
            "lang": lang,
        },
    )
NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
@sio.event
# @login_required
async def request_translate(sid, data):
    """
    Handles requesting for cid translation

    Translations already present in the database are returned right away;
    if some are missing, the known part is cached under a fresh request id
    and "request_neon_translations" is emitted for the rest.

    :param sid: client session id
    :param data: mapping of cid to desired translation language
        (read from "chat_mapping"), optionally with "inputType"
    """
    if not data:
        LOG.warning("Missing request translate data, skipping...")
        return

    input_type = data.get("inputType", "incoming")
    populated_translations, missing_translations = mongo_queries.get_translations(
        translation_mapping=data.get("chat_mapping", {})
    )

    if populated_translations and not missing_translations:
        response = {"translations": populated_translations, "input_type": input_type}
        await sio.emit("translation_response", data=response, to=sid)
        return

    LOG.info("Not every translation is contained in db, sending out request to Neon")
    request_id = generate_uuid()
    # Remember what is already known so the response handler can merge it
    translation_cache = CacheFactory.get("translation_cache", cache_type=LRUCache)
    translation_cache[request_id] = {
        "translations": populated_translations,
        "sid": sid,
        "input_type": input_type,
    }
    await sio.emit(
        "request_neon_translations",
        data={"request_id": request_id, "data": missing_translations},
    )


@sio.event
async def get_neon_translations(sid, data):
    """
    Handles received translations from Neon Translation Service

    Merges the received translations with the part cached by
    "request_translate", answers the original requester and emits
    "updated_shouts" (skipping the requester) when shouts were updated.

    :param sid: session id of the emitter (replaced by the cached "sid")
    :param data: received translations data
    Example of translations data:
    ```
        data = {
                'request_id': (emitted request id),
                'translations':(dictionary containing mapping of shout id to translations)
               }
    ```
    """
    request_id = data.get("request_id")
    if not request_id:
        LOG.error('Missing "request id" in response dict')
        return

    try:
        cached_data = CacheFactory.get("translation_cache").get(key=request_id)
        if not cached_data:
            LOG.warning("Failed to get matching cached data")
            return

        sid = cached_data.get("sid")
        input_type = cached_data.get("input_type")
        updated_shouts = MongoDocumentsAPI.SHOUTS.save_translations(
            translation_mapping=data.get("translations", {})
        )
        populated_translations = deep_merge(
            data.get("translations", {}), cached_data.get("translations", {})
        )
        response = {"translations": populated_translations, "input_type": input_type}
        await sio.emit("translation_response", data=response, to=sid)

        if not updated_shouts:
            return
        await sio.emit(
            "updated_shouts",
            data={"input_type": input_type, "translations": updated_shouts},
            skip_sid=[sid],
        )
    except KeyError as err:
        LOG.error(
            f"No translation cache detected under request_id={request_id} (err={err})"
        )
@sio.event
# @login_required
async def request_tts(sid, data):
    """
    Handles request to Neon TTS service

    Sends stored audio back in "incoming_tts" when the message already has an
    audio file for the language, otherwise emits "get_tts" to have it
    synthesized.

    :param sid: client session id
    :param data: received tts request data
    Example of tts request data:
    ```
        data = {
                    'cid': (target conversation id),
                    'message_id': (target message id),
                    'lang': (target message lang, defaults to "en")
               }
    ```
    """
    required_keys = (
        "cid",
        "message_id",
    )
    if any(key not in data for key in required_keys):
        LOG.error(f"Missing one of the required keys - {required_keys}")
        return

    lang = data.get("lang", "en")
    message_id = data["message_id"]
    cid = data["cid"]
    matching_message = MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id)
    if not matching_message:
        LOG.error("Failed to request TTS - matching message not found")
        return

    # TODO: support for multiple genders in TTS
    # Trying to get existing audio data
    # preferred_gender = (
    #     MongoDocumentsAPI.USERS.get_preferences(user_id=user_id)
    #     .get("tts", {})
    #     .get(lang, {})
    #     .get("gender", "female")
    # )
    preferred_gender = "female"
    audio_by_lang = matching_message.get("audio", {})
    audio_file = audio_by_lang.get(lang, {}).get(preferred_gender)

    if not audio_file:
        LOG.info(
            f"File was not detected for cid={cid}, message_id={message_id}, lang={lang}"
        )
        await sio.emit(
            "get_tts",
            data={
                "cid": cid,
                "sid": sid,
                "message_id": message_id,
                "text": matching_message.get("message_text"),
                "lang": LanguageSettings.to_neon_lang(lang),
            },
        )
        return

    try:
        file_location = f"audio/{audio_file}"
        LOG.info(f"Fetching existing file from: {file_location}")
        fo = sftp_connector.get_file_object(file_location)
        if fo.getbuffer().nbytes <= 0:
            LOG.error(
                f"Empty file detected for cid={cid}, message_id={message_id}, lang={lang}"
            )
            return
        LOG.info(
            f"File detected for cid={cid}, message_id={message_id}, lang={lang}"
        )
        response_data = {
            "cid": cid,
            "message_id": message_id,
            "lang": lang,
            "gender": preferred_gender,
            "audio_data": buffer_to_base64(fo),
        }
        await sio.emit("incoming_tts", data=response_data, to=sid)
    except Exception as ex:
        LOG.error(f"Failed to send TTS response - {ex}")


@sio.event
async def tts_response(sid, data):
    """
    Handle TTS Response from Observer

    Saves the received audio for the shout referenced in ``data["context"]``
    and emits "incoming_tts" to the session id stored in that context; emits
    an error when saving fails.

    :param sid: session id of the emitter (replaced by the context "sid")
    :param data: TTS response; reads "context", "lang", "gender", "audio_data"
    """
    mq_context = data.get("context", {})
    cid = mq_context.get("cid")
    message_id = mq_context.get("message_id")
    sid = mq_context.get("sid")
    lang = LanguageSettings.to_system_lang(data.get("lang", "en-us"))
    lang_gender = data.get("gender", "undefined")

    if not MongoDocumentsAPI.SHOUTS.get_item(item_id=message_id):
        LOG.warning(
            f"Skipping TTS Response for message_id={message_id} - matching shout does not exist"
        )
        return

    audio_data = data.get("audio_data")
    if not audio_data:
        LOG.warning(
            f"Skipping TTS Response for message_id={message_id} - audio data is empty"
        )
        return

    is_ok = MongoDocumentsAPI.SHOUTS.save_tts_response(
        shout_id=message_id,
        audio_data=audio_data,
        lang=lang,
        gender=lang_gender,
    )
    if not is_ok:
        # Without a known requester the error is emitted with no explicit sids
        await emit_error(
            message="Failed to get TTS response",
            context={"message_id": message_id, "cid": cid},
            sids=[sid] if sid else None,
        )
        return

    await sio.emit(
        "incoming_tts",
        data={
            "cid": cid,
            "message_id": message_id,
            "lang": lang,
            "gender": lang_gender,
            "audio_data": audio_data,
        },
        to=sid,
    )
@sio.event
# @login_required
async def user_message(sid, data):
    """
    SIO event fired on new user message in chat

    Persists the message as a shout, attaches it to its prompt (if any),
    stores any TTS audio received with it and emits "new_message" to every
    client except the sender. Processing failures are logged and reported to
    the sender through ``emit_error``.

    :param sid: client session id
    :param data: user message data
    Example:
    ```
        data = {'cid':'conversation id',
                'userID': 'emitted user id',
                'promptID': 'id of related prompt (optional)',
                'promptState': 'state of related prompt (required with promptID)',
                'source': 'declared name of the source that shouted given user message'
                'messageText': 'content of the user message',
                'repliedMessage': 'id of replied message (optional)',
                'bot': 'if the message is from bot (defaults to False)',
                'lang': 'language of the message (defaults to "en")'
                'attachments': 'list of filenames that were send with message',
                'context': 'message context (optional)',
                'test': 'is test message (defaults to False)',
                'isAudio': '1 if current message is audio message 0 otherwise',
                'messageTTS': received tts mapping of type: {language: {gender: (audio data base64 encoded)}},
                'isAnnouncement': if received message is the announcement,
                'timeCreated': 'timestamp on which message was created'}
    ```
    :returns: -1 when the audio payload could not be stored, None otherwise
    """
    LOG.debug(f"Got new user message from {sid}: {data}")
    try:
        cid_data = MongoDocumentsAPI.CHATS.get_conversation_data(
            search_str=data["cid"],
            column_identifiers=["_id"],
        )
        if not cid_data:
            msg = "Shouting to non-existent conversation, skipping further processing"
            await emit_error(sids=[sid], message=msg)
            return

        LOG.info(f"Received user message data: {data}")
        data["message_id"] = generate_uuid()
        data["is_bot"] = data.pop("bot", "0")
        # Replace the declared user id with the id of the stored Neon/bot user
        if data["userID"].startswith("neon"):
            neon_data = MongoDocumentsAPI.USERS.get_neon_data(skill_name="neon")
            data["userID"] = neon_data["_id"]
        elif data["is_bot"] == "1":
            bot_data = MongoDocumentsAPI.USERS.get_bot_data(
                user_id=data["userID"], context=data.get("context")
            )
            data["userID"] = bot_data["_id"]

        # Flags are kept as "0"/"1" strings; anything but "1" counts as "0"
        is_audio = data.get("isAudio", "0")
        if is_audio != "1":
            is_audio = "0"

        audio_path = f'{data["message_id"]}_audio.wav'
        try:
            if is_audio == "1":
                # keep only the part after the last comma of the audio payload
                message_text = data["messageText"].split(",")[-1]
                sftp_connector.put_file_object(
                    file_object=message_text, save_to=f"audio/{audio_path}"
                )
                # for audio messages "message_text" references the name of the audio stored
                data["messageText"] = audio_path
        except Exception as ex:
            LOG.error(f"Failed to located file - {ex}")
            return -1

        is_announcement = data.get("isAnnouncement", "0") or "0"
        if is_announcement != "1":
            is_announcement = "0"

        lang = data.get("lang", "en")
        data["prompt_id"] = data.pop("promptID", "")

        new_shout_data = {
            "_id": data["message_id"],
            "cid": data["cid"],
            "user_id": data["userID"],
            "prompt_id": data["prompt_id"],
            "message_text": data["messageText"],
            "message_lang": lang,
            "attachments": data.get("attachments", []),
            "replied_message": data.get("repliedMessage", ""),
            "is_audio": is_audio,
            "is_announcement": is_announcement,
            "is_bot": data["is_bot"],
            "translations": {},
            # `or` also covers an explicit None/empty "timeCreated"
            "created_on": int(data.get("timeCreated") or time()),
        }

        # in case message is received in some foreign language -
        # message text is kept in that language unless English translation received
        if lang != "en":
            new_shout_data["translations"][lang] = data["messageText"]

        mongo_queries.add_shout(data=new_shout_data)
        if is_announcement == "0" and data.get("prompt_id"):
            is_ok = MongoDocumentsAPI.PROMPTS.add_shout_to_prompt(
                prompt_id=data["prompt_id"],
                user_id=data["userID"],
                message_id=data["message_id"],
                prompt_state=data["promptState"],
            )
            if is_ok:
                await sio.emit(
                    "new_prompt_message",
                    data={
                        "cid": data["cid"],
                        "userID": data["userID"],
                        "messageText": data["messageText"],
                        "promptID": data["prompt_id"],
                        "promptState": data["promptState"],
                    },
                )

        # `or {}` tolerates an explicit None so an already stored shout is
        # still announced to the other clients
        message_tts = data.get("messageTTS") or {}
        for language, gender_mapping in message_tts.items():
            for gender, audio_data in gender_mapping.items():
                MongoDocumentsAPI.SHOUTS.save_tts_response(
                    shout_id=data["message_id"],
                    audio_data=audio_data,
                    lang=language,
                    gender=gender,
                )

        data["bound_service"] = cid_data.get("bound_service", "")
        await sio.emit("new_message", data=data, skip_sid=[sid])
        PopularityCounter.increment_cid_popularity(new_shout_data["cid"])
    except Exception as ex:
        LOG.error(f"Exception on sio processing: {ex}")
        await emit_error(
            sids=[sid],
            message=f'Unable to process request "user_message" with data: {data}',
        )


@sio.event
# @login_required
async def broadcast(sid, data):
    """
    Forwards received broadcast message from client

    The event name is taken from "msg_type" and the receivers from "to";
    both keys are removed from ``data`` before it is forwarded.

    :param sid: client session id
    :param data: message to forward; must contain "msg_type", may contain "to"
    """
    # TODO: introduce certification mechanism to forward messages only from trusted entities
    msg_type = data.pop("msg_type", None)
    msg_receivers = data.pop("to", None)
    if not msg_type:
        LOG.error(f'data={data} skipped - no "msg_type" provided')
        return
    await sio.emit(
        msg_type,
        data=data,
        to=msg_receivers,
    )
import socketio

# Shared Socket.IO server instance; the handler modules import it and
# register their events through ``@sio.event``.
# NOTE(review): cors_allowed_origins="*" does not restrict request origins -
# confirm this is intended outside of development.
sio = socketio.AsyncServer(cors_allowed_origins="*", async_mode="asgi")
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import os +from functools import wraps +from typing import Optional, List + +from utils.logging_utils import LOG +from .server import sio +from ..server_utils.auth import validate_session + + +def list_current_headers(sid: str) -> list: + return ( + sio.environ.get(sio.manager.rooms["/"].get(sid, {}).get(sid), {}) + .get("asgi.scope", {}) + .get("headers", []) + ) + + +def get_header(sid: str, match_str: str): + for header_tuple in list_current_headers(sid): + if header_tuple[0].decode() == match_str.lower(): + return header_tuple[1].decode() + + +def login_required(*outer_args, **outer_kwargs): + """ + Decorator that validates current authorization token + """ + + no_args = False + func = None + if len(outer_args) == 1 and not outer_kwargs and callable(outer_args[0]): + # Function was called with no arguments + no_args = True + func = outer_args[0] + + outer_kwargs.setdefault("tmp_allowed", True) + + def outer(func): + @wraps(func) + async def wrapper(sid, *args, **kwargs): + if os.environ.get("DISABLE_AUTH_CHECK", "0") != "1": + auth_token = get_header(sid, "session") + session_validation_output = ( + None, + None, + ) + if auth_token: + session_validation_output = validate_session( + 
auth_token, + check_tmp=not outer_kwargs["tmp_allowed"], + sio_request=True, + ) + if session_validation_output[1] != 200: + return await sio.emit("auth_expired", data={}, to=sid) + return await func(sid, *args, **kwargs) + + return wrapper + + if no_args: + return outer(func) + else: + return outer + + +async def emit_error( + message: str, context: Optional[dict] = None, sids: Optional[List[str]] = None +): + """ + Emits error message to provided sid + + :param message: message to emit + :param sids: client session ids (optional) + :param context: context to emit (optional) + """ + if not context: + context = {} + LOG.error(message) + await sio.emit( + context.pop("callback_event", "klatchat_sio_error"), + data={"msg": message}, + to=sids, + ) + + +async def emit_session_expired(sid: str): + """Wrapper to emit session expired session event to desired client session""" + await emit_error( + message="Session Expired", + context={"callback_event": "auth_expired"}, + sids=[sid], + ) diff --git a/chat_server/tests/test_sio.py b/chat_server/tests/test_sio.py index 06856429..8ff5d39c 100644 --- a/chat_server/tests/test_sio.py +++ b/chat_server/tests/test_sio.py @@ -37,18 +37,23 @@ from chat_server.constants.users import ChatPatterns from chat_server.tests.beans.server import ASGITestServer -from chat_server.server_utils.auth import generate_uuid from chat_server.server_config import db_controller from utils.logging_utils import LOG +from utils.common import generate_uuid SERVER_ADDRESS = "http://127.0.0.1:8888" -TEST_CID = '-1' +TEST_CID = "-1" @pytest.fixture(scope="session") def create_server(): """Creates ASGI server for testing""" - config = Config('chat_server.tests.utils.app_utils:get_test_app', port=8888, log_level="info", factory=True) + config = Config( + "chat_server.tests.utils.app_utils:get_test_app", + port=8888, + log_level="info", + factory=True, + ) app_server = ASGITestServer(config=config) with app_server.run_in_thread(): yield @@ -63,106 +68,151 @@ 
class TestSIO(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - os.environ['DISABLE_AUTH_CHECK'] = '1' - matching_conversation = db_controller.exec_query(query={'command': 'find_one', - 'document': 'chats', - 'data': {'_id': TEST_CID}}) + from chat_server.server_config import database_config_path + + assert os.path.isfile(database_config_path) + os.environ["DISABLE_AUTH_CHECK"] = "1" + matching_conversation = db_controller.exec_query( + query={ + "command": "find_one", + "document": "chats", + "data": {"_id": TEST_CID}, + } + ) if not matching_conversation: - db_controller.exec_query(query={'document': 'chats', - 'command': 'insert_one', - 'data': ChatPatterns.TEST_CHAT.value}) + db_controller.exec_query( + query={ + "document": "chats", + "command": "insert_one", + "data": ChatPatterns.TEST_CHAT.value, + } + ) @classmethod def tearDownClass(cls) -> None: - db_controller.exec_query(query={'document': 'chats', - 'command': 'delete_one', - 'data': {'_id': ChatPatterns.TEST_CHAT.value}}) + db_controller.exec_query( + query={ + "document": "chats", + "command": "delete_one", + "data": {"_id": ChatPatterns.TEST_CHAT.value}, + } + ) def handle_pong(self, _data): """Handles pong from sio server""" - LOG.info('Received pong') + LOG.info("Received pong") self.pong_received = True self.pong_event.set() def setUp(self) -> None: self.sio = socketio.Client() self.sio.connect(url=SERVER_ADDRESS) - LOG.info(f'Socket IO client connected to {SERVER_ADDRESS}') + LOG.info(f"Socket IO client connected to {SERVER_ADDRESS}") def tearDown(self) -> None: if self.sio: self.sio.disconnect() - @pytest.mark.usefixtures('create_server') + @pytest.mark.usefixtures("create_server") def test_01_ping_server(self): - self.sio.on('pong', self.handle_pong) + self.sio.on("pong", self.handle_pong) self.pong_event = Event() time.sleep(5) - self.sio.emit('ping', data={'knock': 'knock'}) - LOG.info(f'Socket IO client connected to {SERVER_ADDRESS}') + self.sio.emit("ping", data={"knock":
"knock"}) + LOG.info(f"Socket IO client connected to {SERVER_ADDRESS}") self.pong_event.wait(5) self.assertEqual(self.pong_received, True) - @pytest.mark.usefixtures('create_server') + @pytest.mark.usefixtures("create_server") def test_neon_message(self): - message_id = f'test_neon_{generate_uuid()}' - user_id = 'neon' - message_data = {'userID': 'neon', - 'messageID': message_id, - 'messageText': 'Neon Test 123', - 'bot': '0', - 'cid': '-1', - 'test': True, - 'timeCreated': int(time.time())} - self.sio.emit('user_message', data=message_data) + message_id = f"test_neon_{generate_uuid()}" + user_id = "neon" + message_data = { + "userID": "neon", + "messageText": "Neon Test 123", + "bot": "0", + "cid": "-1", + "test": True, + "timeCreated": int(time.time()), + } + self.sio.emit("user_message", data=message_data) time.sleep(2) - neon = db_controller.exec_query(query={'command': 'find_one', - 'document': 'users', - 'data': {'nickname': user_id}}) + neon = db_controller.exec_query( + query={ + "command": "find_one", + "document": "users", + "data": {"nickname": user_id}, + } + ) self.assertIsNotNone(neon) self.assertIsInstance(neon, dict) - shout = db_controller.exec_query(query={'command': 'find_one', - 'document': 'shouts', - 'data': {'_id': message_id}}) + shout = db_controller.exec_query( + query={ + "command": "find_one", + "document": "shouts", + "data": {"user_id": neon["_id"]}, + } + ) self.assertIsNotNone(shout) self.assertIsInstance(shout, dict) - db_controller.exec_query(query={'command': 'delete_one', - 'document': 'shouts', - 'data': {'_id': message_id}}) - - @pytest.mark.usefixtures('create_server') + db_controller.exec_query( + query={ + "command": "delete_many", + "document": "shouts", + "data": {"user_id": neon["_id"]}, + } + ) + + @pytest.mark.usefixtures("create_server") def test_bot_message(self): - message_id = f'test_bot_message_{generate_uuid()}' - user_id = f'test_bot_{generate_uuid()}' - message_data = {'userID': user_id, - 'messageID': message_id, 
- 'messageText': 'Bot Test 123', - 'bot': '1', - 'cid': '-1', - 'context': dict(first_name='The', last_name='Bot'), - 'test': True, - 'timeCreated': int(time.time())} - self.sio.emit('user_message', data=message_data) + user_id = f"test_bot_{generate_uuid()}" + message_text = f"Bot Test {generate_uuid()}" + message_data = { + "userID": user_id, + "messageText": message_text, + "bot": "1", + "cid": "-1", + "context": dict(first_name="The", last_name="Bot"), + "test": True, + "timeCreated": int(time.time()), + } + self.sio.emit("user_message", data=message_data) time.sleep(2) - bot = db_controller.exec_query(query={'command': 'find_one', - 'document': 'users', - 'data': {'nickname': user_id}}) + bot = db_controller.exec_query( + query={ + "command": "find_one", + "document": "users", + "data": {"nickname": user_id}, + } + ) self.assertIsNotNone(bot) self.assertIsInstance(bot, dict) - self.assertTrue(bot['first_name'] == 'The') - self.assertTrue(bot['last_name'] == 'Bot') - - shout = db_controller.exec_query(query={'command': 'find_one', - 'document': 'shouts', - 'data': {'_id': message_id}}) + self.assertTrue(bot["first_name"] == "The") + self.assertTrue(bot["last_name"] == "Bot") + + shout = db_controller.exec_query( + query={ + "command": "find_one", + "document": "shouts", + "data": {"user_id": bot["_id"]}, + } + ) self.assertIsNotNone(shout) self.assertIsInstance(shout, dict) - db_controller.exec_query(query={'command': 'delete_one', - 'document': 'shouts', - 'data': {'_id': message_id}}) - db_controller.exec_query(query={'command': 'delete_one', - 'document': 'users', - 'data': {'nickname': user_id}}) + db_controller.exec_query( + query={ + "command": "delete_many", + "document": "shouts", + "data": {"user_id": bot["_id"]}, + } + ) + db_controller.exec_query( + query={ + "command": "delete_one", + "document": "users", + "data": {"nickname": user_id}, + } + ) diff --git a/chat_server/tests/utils/app_utils.py b/chat_server/tests/utils/app_utils.py index 
8d6bf62f..d64e222d 100644 --- a/chat_server/tests/utils/app_utils.py +++ b/chat_server/tests/utils/app_utils.py @@ -31,11 +31,10 @@ from typing import Union from fastapi import FastAPI -from chat_server.app import create_app +from chat_server.sio import sio +from chat_server.wsgi import create_app def get_test_app() -> Union[FastAPI, socketio.ASGIApp]: """Returns test application instance""" - return create_app(testing_mode=True) - - + return create_app(testing_mode=True, sio_server=sio) diff --git a/chat_server/wsgi.py b/chat_server/wsgi.py index 70de229b..808e410d 100644 --- a/chat_server/wsgi.py +++ b/chat_server/wsgi.py @@ -26,5 +26,6 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from .app import create_app +from .sio import sio -app = create_app() +app = create_app(sio_server=sio) diff --git a/config.py b/config.py index 710fad2b..4ce20210 100644 --- a/config.py +++ b/config.py @@ -39,14 +39,14 @@ def load_config() -> dict: """ - Load and return a configuration object, + Load and return a configuration object, """ legacy_config_path = "/app/app/config.json" if isfile(legacy_config_path): LOG.warning(f"Deprecated configuration found at {legacy_config_path}") with open(legacy_config_path) as f: config = json.load(f) - LOG.debug(f'Loaded config - {config}') + LOG.debug(f"Loaded config - {config}") return config config = OVOSConfiguration() if not config: @@ -58,9 +58,9 @@ def load_config() -> dict: class Configuration: - """ Generic configuration module""" + """Generic configuration module""" - KLAT_ENV = os.environ.get('KLAT_ENV', 'DEV') + KLAT_ENV = os.environ.get("KLAT_ENV", "DEV") db_controllers = dict() def __init__(self, from_files: List[str]): @@ -71,28 +71,30 @@ def __init__(self, from_files: List[str]): @staticmethod def extract_config_from_path(file_path: str) -> dict: """ - Extracts configuration dictionary from desired file path + Extracts configuration dictionary 
from desired file path - :param file_path: desired file path + :param file_path: desired file path - :returns dictionary containing configs from target file, empty dict otherwise + :returns dictionary containing configs from target file, empty dict otherwise """ try: with open(os.path.expanduser(file_path)) as input_file: extraction_result = json.load(input_file) except Exception as ex: - LOG.error(f'Exception occurred while extracting data from {file_path}: {ex}') + LOG.error( + f"Exception occurred while extracting data from {file_path}: {ex}" + ) extraction_result = dict() # LOG.info(f'Extracted config: {extraction_result}') return extraction_result def add_new_config_properties(self, new_config_dict: dict, at_key: str = None): """ - Adds new configuration properties to existing configuration dict + Adds new configuration properties to existing configuration dict - :param new_config_dict: dictionary containing new configuration - :param at_key: the key at which to append new dictionary - (optional but setting that will reduce possible future key conflicts) + :param new_config_dict: dictionary containing new configuration + :param at_key: the key at which to append new dictionary + (optional but setting that will reduce possible future key conflicts) """ if at_key: self.config_data[at_key] = new_config_dict @@ -118,27 +120,29 @@ def config_data(self) -> dict: @config_data.setter def config_data(self, value): if not isinstance(value, dict): - raise TypeError(f'Type: {type(value)} not supported') + raise TypeError(f"Type: {type(value)} not supported") self._config_data = value def get_db_config_from_key(self, key: str): """Gets DB configuration by key""" - return self.config_data.get('DATABASE_CONFIG', {}).get(self.KLAT_ENV, {}).get(key, {}) - - def get_db_controller(self, name: str, - override: bool = False, - override_args: dict = None): + return ( + self.config_data.get("DATABASE_CONFIG", {}) + .get(self.KLAT_ENV, {}) + .get(key, {}) + ) + + def 
get_db_controller( + self, name: str, override: bool = False, override_args: dict = None + ): """ - Returns an new instance of Database Controller for specified dialect (creates new one if not present) + Returns an new instance of Database Controller for specified dialect (creates new one if not present) - :param name: db connection name from config - :param override: to override existing instance under :param dialect (defaults to False) - :param override_args: dict with arguments to override (optional) + :param name: db connection name from config + :param override: to override existing instance under :param dialect (defaults to False) + :param override_args: dict with arguments to override (optional) - :returns instance of Database Controller + :returns instance of Database Controller """ - from chat_server.server_utils.db_utils import DbUtils - db_controller = self.db_controllers.get(name, None) if not db_controller or override: db_config = self.get_db_config_from_key(key=name) @@ -147,12 +151,11 @@ def get_db_controller(self, name: str, override_args = {} db_config = {**db_config, **override_args} - dialect = db_config.pop('dialect', None) + dialect = db_config.pop("dialect", None) if dialect: from utils.database_utils import DatabaseController db_controller = DatabaseController(config_data=db_config) db_controller.attach_connector(dialect=dialect) db_controller.connect() - DbUtils.init(db_controller) return db_controller diff --git a/dockerfiles/Dockerfile.base b/dockerfiles/Dockerfile.base index e367e0cb..2802140f 100644 --- a/dockerfiles/Dockerfile.base +++ b/dockerfiles/Dockerfile.base @@ -1,4 +1,8 @@ -FROM python:3.8-slim-buster +FROM python:3.10-slim-buster + +ENV OVOS_CONFIG_BASE_FOLDER neon +ENV OVOS_CONFIG_FILENAME klat.yaml +ENV XDG_CONFIG_HOME /config COPY . 
/app/ diff --git a/dockerfiles/Dockerfile.server b/dockerfiles/Dockerfile.server index 3cc62cdd..a79a441e 100644 --- a/dockerfiles/Dockerfile.server +++ b/dockerfiles/Dockerfile.server @@ -1,6 +1,7 @@ FROM ghcr.io/neongeckocom/pyklatchat_base:dev COPY /utils /app +COPY version.py /app COPY . /app/chat_server/ ENV KLAT_ENV PROD diff --git a/migration_scripts/__main__.py b/migration_scripts/__main__.py index 7515fe26..89920d82 100644 --- a/migration_scripts/__main__.py +++ b/migration_scripts/__main__.py @@ -41,39 +41,51 @@ def main(migration_id: str = None, dump_dir=os.getcwd(), time_since: int = 1677829600): """ - Main migration scripts entry point + Main migration scripts entry point - :param migration_id: migration id to run - :param dump_dir: directory for dumping files to - :param time_since: timestamp since which to do a migration + :param migration_id: migration id to run + :param dump_dir: directory for dumping files to + :param time_since: timestamp since which to do a migration """ migration_id = migration_id or uuid.uuid4().hex - considered_path = os.path.join(dump_dir, 'passed_migrations', migration_id) + considered_path = os.path.join(dump_dir, "passed_migrations", migration_id) if not os.path.exists(considered_path): os.makedirs(considered_path, exist_ok=True) LOG.info(f'Initiating migration id: "{migration_id}"') - LOG.info(f'Considered time since: {time_since}') + LOG.info(f"Considered time since: {time_since}") - config_source_files = [os.environ.get('CONFIG_PATH', 'config.json'), os.environ.get('SSH_CONFIG', None)] + config_source_files = [ + os.environ.get("CONFIG_PATH", "config.json"), + os.environ.get("SSH_CONFIG", None), + ] configuration = Configuration(from_files=config_source_files) - mysql_connector, mongo_connector = setup_db_connectors(configuration=configuration, - old_db_key=os.environ.get('OLD_DB_KEY', None), - new_db_key=os.environ.get('NEW_DB_KEY', None)) - - LOG.info('Established connections with dbs') - - if 
all(os.path.exists(os.path.join(considered_path, f.value)) for f in (MigrationFiles.NICK_MAPPING, - MigrationFiles.CIDS, - MigrationFiles.NICKS)): - LOG.info('Skipping conversations migrations') - - with open(os.path.join(considered_path, MigrationFiles.NICK_MAPPING.value)) as f: + mysql_connector, mongo_connector = setup_db_connectors( + configuration=configuration, + old_db_key=os.environ.get("OLD_DB_KEY", None), + new_db_key=os.environ.get("NEW_DB_KEY", None), + ) + + LOG.info("Established connections with dbs") + + if all( + os.path.exists(os.path.join(considered_path, f.value)) + for f in ( + MigrationFiles.NICK_MAPPING, + MigrationFiles.CIDS, + MigrationFiles.NICKS, + ) + ): + LOG.info("Skipping conversations migrations") + + with open( + os.path.join(considered_path, MigrationFiles.NICK_MAPPING.value) + ) as f: nick_to_uuid_mapping = json.load(f) with open(os.path.join(considered_path, MigrationFiles.NICKS.value)) as f: @@ -82,37 +94,55 @@ def main(migration_id: str = None, dump_dir=os.getcwd(), time_since: int = 16778 with open(os.path.join(considered_path, MigrationFiles.CIDS.value)) as f: cids = [x.strip() for x in f.readlines()] else: - LOG.info('Starting conversations migration') - cids, nick_to_uuid_mapping, nicks_to_consider = migrate_conversations(old_db_controller=mysql_connector, - new_db_controller=mongo_connector, - time_since=time_since) - - with open(os.path.join(considered_path, MigrationFiles.NICK_MAPPING.value), 'w', encoding="utf-8") as f: + LOG.info("Starting conversations migration") + cids, nick_to_uuid_mapping, nicks_to_consider = migrate_conversations( + old_db_controller=mysql_connector, + new_db_controller=mongo_connector, + time_since=time_since, + ) + + with open( + os.path.join(considered_path, MigrationFiles.NICK_MAPPING.value), + "w", + encoding="utf-8", + ) as f: json.dump(nick_to_uuid_mapping, f) - LOG.info(f'Stored nicks mapping in {MigrationFiles.NICK_MAPPING.value}') + LOG.info(f"Stored nicks mapping in 
{MigrationFiles.NICK_MAPPING.value}") if nicks_to_consider: - with open(os.path.join(considered_path, MigrationFiles.NICKS.value), 'w', encoding="utf-8") as f: - nicks = [str(nick) + '\n' for nick in nicks_to_consider] + with open( + os.path.join(considered_path, MigrationFiles.NICKS.value), + "w", + encoding="utf-8", + ) as f: + nicks = [str(nick) + "\n" for nick in nicks_to_consider] f.writelines(nicks) - LOG.info(f'Stored nicks list in {MigrationFiles.NICKS.value}') + LOG.info(f"Stored nicks list in {MigrationFiles.NICKS.value}") if cids: - with open(os.path.join(considered_path, MigrationFiles.CIDS.value), 'w', encoding="utf-8") as f: - cids = [str(cid) + '\n' for cid in cids] + with open( + os.path.join(considered_path, MigrationFiles.CIDS.value), + "w", + encoding="utf-8", + ) as f: + cids = [str(cid) + "\n" for cid in cids] f.writelines(cids) - LOG.info(f'Stored cid list in {MigrationFiles.CIDS.value}') + LOG.info(f"Stored cid list in {MigrationFiles.CIDS.value}") - migrate_users(old_db_controller=mysql_connector, - new_db_controller=mongo_connector, - nick_to_uuid_mapping=nick_to_uuid_mapping, - nicks_to_consider=nicks_to_consider) + migrate_users( + old_db_controller=mysql_connector, + new_db_controller=mongo_connector, + nick_to_uuid_mapping=nick_to_uuid_mapping, + nicks_to_consider=nicks_to_consider, + ) - migrate_shouts(old_db_controller=mysql_connector, - new_db_controller=mongo_connector, - nick_to_uuid_mapping=nick_to_uuid_mapping, - from_cids=cids) + migrate_shouts( + old_db_controller=mysql_connector, + new_db_controller=mongo_connector, + nick_to_uuid_mapping=nick_to_uuid_mapping, + from_cids=cids, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/migration_scripts/constants/migration_constants.py b/migration_scripts/constants/migration_constants.py index 51b5d6ef..31fa1d21 100644 --- a/migration_scripts/constants/migration_constants.py +++ b/migration_scripts/constants/migration_constants.py @@ -32,6 +32,6 @@ class 
MigrationFiles(Enum): """Enum containing migration files""" - NICK_MAPPING = 'nick_mapping.json' - CIDS = 'cids.txt' - NICKS = 'nicks.txt' + NICK_MAPPING = "nick_mapping.json" + CIDS = "cids.txt" + NICKS = "nicks.txt" diff --git a/migration_scripts/conversations.py b/migration_scripts/conversations.py index 2f730535..98e7971f 100644 --- a/migration_scripts/conversations.py +++ b/migration_scripts/conversations.py @@ -29,70 +29,85 @@ from pymongo import ReplaceOne -from migration_scripts.utils.conversation_utils import clean_conversation_name, index_nicks +from migration_scripts.utils.conversation_utils import ( + clean_conversation_name, + index_nicks, +) from utils.logging_utils import LOG -def migrate_conversations(old_db_controller, new_db_controller, - time_since: int = 1577829600) -> Tuple[List[str], Dict[str, str], List[str]]: +def migrate_conversations( + old_db_controller, new_db_controller, time_since: int = 1577829600 +) -> Tuple[List[str], Dict[str, str], List[str]]: """ - Migrating conversations from old database to new one - :param old_db_controller: old database connector - :param new_db_controller: new database connector - :param time_since: timestamp for conversation activity + Migrating conversations from old database to new one + :param old_db_controller: old database connector + :param new_db_controller: new database connector + :param time_since: timestamp for conversation activity """ - LOG.info(f'Starting chats migration') + LOG.info(f"Starting chats migration") - get_cids_query = f""" + get_cids_query = f""" select * from shoutbox_conversations where updated>{time_since}; """ result = old_db_controller.exec_query(get_cids_query) - result_cids = [str(r['cid']) for r in result] + result_cids = [str(r["cid"]) for r in result] - existing_cids = list(new_db_controller.exec_query(query=dict(document='chats', command='find', data={ - '_id': {'$in': result_cids} - }))) + existing_cids = list( + new_db_controller.exec_query( + query=dict( + 
document="chats", command="find", data={"_id": {"$in": result_cids}} + ) + ) + ) - existing_cids = [r['_id'] for r in existing_cids] + existing_cids = [r["_id"] for r in existing_cids] - all_cids_in_scope = list(set(existing_cids+result_cids)) + all_cids_in_scope = list(set(existing_cids + result_cids)) - LOG.info(f'Found {len(existing_cids)} existing cids') + LOG.info(f"Found {len(existing_cids)} existing cids") if existing_cids: - result = list(filter(lambda x: str(x['cid']) not in existing_cids, result)) - - LOG.info(f'Received {len(result)} new cids') - - received_nicks = [record['creator'] for record in result if record['creator'] is not None] - - nicknames_mapping, nicks_to_consider = index_nicks(mongo_controller=new_db_controller, - received_nicks=received_nicks) - - LOG.debug(f'Records to process: {len(result)}') - - formed_result = [ReplaceOne({'_id': str(record['cid'])}, - { - '_id': str(record['cid']), - 'is_private': int(record['private']) == 1, - 'domain': record['domain'], - 'image': record['image_url'], - 'password': record['password'], - 'conversation_name': f"{clean_conversation_name(record['title'])}_{record['cid']}", - 'chat_flow': [], - 'creator': nicknames_mapping.get(record['creator'], record['creator']), - 'created_on': int(record['created']) - }, upsert=True) for record in result - ] + result = list(filter(lambda x: str(x["cid"]) not in existing_cids, result)) + + LOG.info(f"Received {len(result)} new cids") + + received_nicks = [ + record["creator"] for record in result if record["creator"] is not None + ] + + nicknames_mapping, nicks_to_consider = index_nicks( + mongo_controller=new_db_controller, received_nicks=received_nicks + ) + + LOG.debug(f"Records to process: {len(result)}") + + formed_result = [ + ReplaceOne( + {"_id": str(record["cid"])}, + { + "_id": str(record["cid"]), + "is_private": int(record["private"]) == 1, + "domain": record["domain"], + "image": record["image_url"], + "password": record["password"], + 
"conversation_name": f"{clean_conversation_name(record['title'])}_{record['cid']}", + "chat_flow": [], + "creator": nicknames_mapping.get(record["creator"], record["creator"]), + "created_on": int(record["created"]), + }, + upsert=True, + ) + for record in result + ] if len(formed_result) > 0: - - new_db_controller.exec_query(query=dict(document='chats', - command='bulk_write', - data=formed_result)) + new_db_controller.exec_query( + query=dict(document="chats", command="bulk_write", data=formed_result) + ) else: - LOG.info('All chats are already in new deb, skipping chat migration') + LOG.info("All chats are already in new deb, skipping chat migration") return all_cids_in_scope, nicknames_mapping, nicks_to_consider diff --git a/migration_scripts/shouts.py b/migration_scripts/shouts.py index 53b2dc0f..275a9f52 100644 --- a/migration_scripts/shouts.py +++ b/migration_scripts/shouts.py @@ -29,31 +29,48 @@ from pymongo import ReplaceOne, UpdateOne -from chat_server.server_utils.db_utils import DbUtils, MongoQuery, MongoCommands, MongoDocuments +from chat_server.server_utils.db_utils import ( + MongoQuery, + MongoCommands, + MongoDocuments, +) from migration_scripts.utils.shout_utils import prepare_nicks_for_sql from migration_scripts.utils.sql_utils import iterable_to_sql_array, sql_arr_is_null +from utils.database_utils.mongo_utils.queries.wrapper import MongoDocumentsAPI from utils.logging_utils import LOG -def migrate_shouts(old_db_controller, new_db_controller, nick_to_uuid_mapping: dict, from_cids: list): +def migrate_shouts( + old_db_controller, new_db_controller, nick_to_uuid_mapping: dict, from_cids: list +): """ - Migrating users from old database to new one - :param old_db_controller: old database connector - :param new_db_controller: new database connector - :param nick_to_uuid_mapping: mapping of nicks to uuid - :param from_cids: list of considered conversation ids + Migrating users from old database to new one + :param old_db_controller: old database 
connector + :param new_db_controller: new database connector + :param nick_to_uuid_mapping: mapping of nicks to uuid + :param from_cids: list of considered conversation ids """ - existing_shouts = list(new_db_controller.exec_query(query=dict(document='shouts', command='find', data={}))) + existing_shouts = list( + new_db_controller.exec_query( + query=dict(document="shouts", command="find", data={}) + ) + ) - nick_to_uuid_mapping = {k.strip().lower(): v for k, v in copy.deepcopy(nick_to_uuid_mapping).items() if k} + nick_to_uuid_mapping = { + k.strip().lower(): v + for k, v in copy.deepcopy(nick_to_uuid_mapping).items() + if k + } - LOG.info('Starting shouts migration') + LOG.info("Starting shouts migration") users = iterable_to_sql_array(prepare_nicks_for_sql(list(nick_to_uuid_mapping))) filter_str = f"WHERE nick IN {users} " - existing_shout_ids = iterable_to_sql_array([str(shout['_id']) for shout in list(existing_shouts)]) + existing_shout_ids = iterable_to_sql_array( + [str(shout["_id"]) for shout in list(existing_shouts)] + ) if not sql_arr_is_null(existing_shout_ids): filter_str += f"AND shout_id NOT IN {existing_shout_ids}" @@ -64,87 +81,102 @@ def migrate_shouts(old_db_controller, new_db_controller, nick_to_uuid_mapping: d get_shouts_query = f""" SELECT * FROM shoutbox {filter_str};""" result = old_db_controller.exec_query(get_shouts_query) - LOG.info(f'Received {len(list(result))} shouts') + LOG.info(f"Received {len(list(result))} shouts") formed_result = [] for record in result: - for k in list(record): if isinstance(record[k], (bytearray, bytes)): - record[k] = str(record[k].decode('utf-8')) - - formed_result.append(ReplaceOne({'_id': str(record['shout_id'])}, - { - '_id': str(record['shout_id']), - 'domain': record['domain'], - 'user_id': nick_to_uuid_mapping.get(record['nick'], 'undefined'), - 'created_on': int(record['created']), - 'message_text': record['shout'], - 'language': record['language'], - 'cid': str(record['cid']) - }, upsert=True)) + 
record[k] = str(record[k].decode("utf-8")) + + formed_result.append( + ReplaceOne( + {"_id": str(record["shout_id"])}, + { + "_id": str(record["shout_id"]), + "domain": record["domain"], + "user_id": nick_to_uuid_mapping.get(record["nick"], "undefined"), + "created_on": int(record["created"]), + "message_text": record["shout"], + "language": record["language"], + "cid": str(record["cid"]), + }, + upsert=True, + ) + ) if len(formed_result) > 0: + new_db_controller.exec_query( + query=dict(document="shouts", command="bulk_write", data=formed_result) + ) - new_db_controller.exec_query(query=dict(document='shouts', - command='bulk_write', - data=formed_result)) - - LOG.info('Starting inserting shouts in conversations') + LOG.info("Starting inserting shouts in conversations") for record in result: - try: - new_db_controller.exec_query(query=dict(document='chats', - command='update', - data=({'_id': record['cid']}, - {'$push': {'chat_flow': str(record['shout_id'])}}))) + new_db_controller.exec_query( + query=dict( + document="chats", + command="update", + data=( + {"_id": record["cid"]}, + {"$push": {"chat_flow": str(record["shout_id"])}}, + ), + ) + ) except Exception as ex: - LOG.error(f'Skipping processing of shout data "{record}" due to exception: {ex}') + LOG.error( + f'Skipping processing of shout data "{record}" due to exception: {ex}' + ) continue def remap_creation_timestamp(db_controller): - """ Remaps creation timestamp from millis to seconds """ - filter_stage = { - '$match': { - 'created_on': { - '$gte': 10 ** 12 - } - } - } + """Remaps creation timestamp from millis to seconds""" + filter_stage = {"$match": {"created_on": {"$gte": 10**12}}} bulk_update = [] - res = list(DbUtils.db_controller.connector.connection["shouts"].aggregate([filter_stage])) + res = list( + MongoDocumentsAPI.db_controller.connector.connection["shouts"].aggregate( + [filter_stage] + ) + ) for item in res: - bulk_update.append(UpdateOne({'_id': item['_id']}, - {'$set': {'created_on': 
item['created_on'] // 10 ** 3}})) - db_controller.exec_query(query=MongoQuery(command=MongoCommands.BULK_WRITE, - document=MongoDocuments.SHOUTS, - data=bulk_update)) + bulk_update.append( + UpdateOne( + {"_id": item["_id"]}, + {"$set": {"created_on": item["created_on"] // 10**3}}, + ) + ) + db_controller.exec_query( + query=MongoQuery( + command=MongoCommands.BULK_WRITE, + document=MongoDocuments.SHOUTS, + data=bulk_update, + ) + ) def set_cid_to_shouts(db_controller): - """ Sets correspondent cid to new shouts """ - conversion_stage = { - '$addFields': {'str_id': {'$toString': "$_id"}} - } - add_str_length = { - '$addFields': {'length': {'$strLenCP': "$str_id"}} - } - filter_stage = { - '$match': { - 'length': { - '$gt': 5 - } - } - } + """Sets correspondent cid to new shouts""" + conversion_stage = {"$addFields": {"str_id": {"$toString": "$_id"}}} + add_str_length = {"$addFields": {"length": {"$strLenCP": "$str_id"}}} + filter_stage = {"$match": {"length": {"$gt": 5}}} bulk_update = [] - res = list(db_controller.connector.connection["chats"].aggregate([conversion_stage, add_str_length, filter_stage])) + res = list( + db_controller.connector.connection["chats"].aggregate( + [conversion_stage, add_str_length, filter_stage] + ) + ) for item in res: - for shout in item.get('chat_flow', []): - bulk_update.append(UpdateOne({'_id': shout}, - {'$set': {'cid': item['_id']}})) - DbUtils.db_controller.exec_query(query=MongoQuery(command=MongoCommands.BULK_WRITE, - document=MongoDocuments.SHOUTS, - data=bulk_update)) + for shout in item.get("chat_flow", []): + bulk_update.append( + UpdateOne({"_id": shout}, {"$set": {"cid": item["_id"]}}) + ) + MongoDocumentsAPI.db_controller.exec_query( + query=MongoQuery( + command=MongoCommands.BULK_WRITE, + document=MongoDocuments.SHOUTS, + data=bulk_update, + ) + ) diff --git a/migration_scripts/users.py b/migration_scripts/users.py index 0c91d76b..84af83fd 100644 --- a/migration_scripts/users.py +++ b/migration_scripts/users.py @@ 
-34,43 +34,53 @@ from utils.logging_utils import LOG -def migrate_users(old_db_controller, new_db_controller, nick_to_uuid_mapping, nicks_to_consider): +def migrate_users( + old_db_controller, new_db_controller, nick_to_uuid_mapping, nicks_to_consider +): """ - Migrating users from old database to new one - :param old_db_controller: old database connector - :param new_db_controller: new database connector - :param nick_to_uuid_mapping: mapping of nicks to uuid - :param nicks_to_consider: list of nicknames to consider + Migrating users from old database to new one + :param old_db_controller: old database connector + :param new_db_controller: new database connector + :param nick_to_uuid_mapping: mapping of nicks to uuid + :param nicks_to_consider: list of nicknames to consider """ - LOG.info('Starting users migration') + LOG.info("Starting users migration") existing_nicks = get_existing_nicks_to_id(mongo_controller=new_db_controller) - nick_to_uuid_mapping = {k.strip().lower(): v for k, v in copy.deepcopy(nick_to_uuid_mapping).items() - if k not in list(existing_nicks)} + nick_to_uuid_mapping = { + k.strip().lower(): v + for k, v in copy.deepcopy(nick_to_uuid_mapping).items() + if k not in list(existing_nicks) + } - LOG.info(f'Nicks to consider: {nicks_to_consider}') + LOG.info(f"Nicks to consider: {nicks_to_consider}") - users = ', '.join(["'" + nick.replace("'", "") + "'" for nick in nicks_to_consider - if nick.strip().lower() not in list(existing_nicks)]) + users = ", ".join( + [ + "'" + nick.replace("'", "") + "'" + for nick in nicks_to_consider + if nick.strip().lower() not in list(existing_nicks) + ] + ) if len(nicks_to_consider) == 0: - LOG.info('All nicks are already in new db, skipping user migration') + LOG.info("All nicks are already in new db, skipping user migration") return - get_user_query = f""" SELECT color, nick, avatar_url, pass, - mail, login, timezone, logout, about_me, speech_rate, - speech_pitch, speech_voice, ai_speech_voice, stt_language, - 
tts_language, tts_voice_gender, tts_secondary_language, time_format, - unit_measure, date_format, location_city, location_state, - location_country, first_name, middle_name, last_name, preferred_name, - birthday, age, display_nick, utc, email_verified, phone_verified, - ignored_brands, favorite_brands, specially_requested, stt_region, - alt_languages, secondary_tts_gender, secondary_neon_voice, - username, phone, synonyms, full_name, share_my_recordings, - use_client_stt, show_recordings_from_others, speakers_on, - allow_audio_recording, volume, use_multi_line_shout + get_user_query = f""" SELECT color, nick, avatar_url, pass, + mail, login, timezone, logout, about_me, speech_rate, + speech_pitch, speech_voice, ai_speech_voice, stt_language, + tts_language, tts_voice_gender, tts_secondary_language, time_format, + unit_measure, date_format, location_city, location_state, + location_country, first_name, middle_name, last_name, preferred_name, + birthday, age, display_nick, utc, email_verified, phone_verified, + ignored_brands, favorite_brands, specially_requested, stt_region, + alt_languages, secondary_tts_gender, secondary_neon_voice, + username, phone, synonyms, full_name, share_my_recordings, + use_client_stt, show_recordings_from_others, speakers_on, + allow_audio_recording, volume, use_multi_line_shout FROM shoutbox_users WHERE nick IN ({users}); """ result = old_db_controller.exec_query(get_user_query) @@ -80,49 +90,61 @@ def migrate_users(old_db_controller, new_db_controller, nick_to_uuid_mapping, ni if isinstance(value, Decimal): record[key] = int(value) - formed_result = [ReplaceOne({'_id': nick_to_uuid_mapping[record['nick'].strip().lower()]}, - { - '_id': nick_to_uuid_mapping[record['nick'].strip().lower()], - 'first_name': record['first_name'], - 'last_name': record['last_name'], - 'avatar': record['avatar_url'], - 'nickname': record['nick'], - 'password': record['pass'], - 'about_me': record['about_me'], - 'date_created': int(record['login']), - 'email': 
record['mail'], - 'phone': record['phone'] - }, upsert=True) for record in result - ] - - new_db_controller.exec_query(query=dict(document='users', - command='bulk_write', - data=formed_result)) - - formed_result = [ReplaceOne({'_id': nick_to_uuid_mapping[record['nick'].strip().lower()]}, - { - '_id': nick_to_uuid_mapping[record['nick'].strip().lower()], - 'display_nick': record['display_nick'], - 'stt_language': record['stt_language'], - 'use_client_stt': record['use_client_stt'], - 'tts_language': record['tts_language'], - 'tts_voice_gender': record['tts_voice_gender'], - 'tts_secondary_language': record['tts_secondary_language'], - 'speech_rate': record['speech_rate'], - 'speech_pitch': record['speech_pitch'], - 'speech_voice': record['speech_voice'], - 'share_my_recordings': record['share_my_recordings'], - 'ai_speech_voice': record['ai_speech_voice'], - 'preferred_name': record['preferred_name'], - 'ignored_brands': record['ignored_brands'], - 'favorite_brands': record['favorite_brands'], - 'volume': record['volume'], - 'use_multi_line_shout': record['use_multi_line_shout'] - }, upsert=True) for record in result - ] - - new_db_controller.exec_query(query=dict(document='user_preferences', - command='bulk_write', - data=formed_result)) - - LOG.info(f'Received {len(list(result))} new users') + formed_result = [ + ReplaceOne( + {"_id": nick_to_uuid_mapping[record["nick"].strip().lower()]}, + { + "_id": nick_to_uuid_mapping[record["nick"].strip().lower()], + "first_name": record["first_name"], + "last_name": record["last_name"], + "avatar": record["avatar_url"], + "nickname": record["nick"], + "password": record["pass"], + "about_me": record["about_me"], + "date_created": int(record["login"]), + "email": record["mail"], + "phone": record["phone"], + }, + upsert=True, + ) + for record in result + ] + + new_db_controller.exec_query( + query=dict(document="users", command="bulk_write", data=formed_result) + ) + + formed_result = [ + ReplaceOne( + {"_id": 
nick_to_uuid_mapping[record["nick"].strip().lower()]}, + { + "_id": nick_to_uuid_mapping[record["nick"].strip().lower()], + "display_nick": record["display_nick"], + "stt_language": record["stt_language"], + "use_client_stt": record["use_client_stt"], + "tts_language": record["tts_language"], + "tts_voice_gender": record["tts_voice_gender"], + "tts_secondary_language": record["tts_secondary_language"], + "speech_rate": record["speech_rate"], + "speech_pitch": record["speech_pitch"], + "speech_voice": record["speech_voice"], + "share_my_recordings": record["share_my_recordings"], + "ai_speech_voice": record["ai_speech_voice"], + "preferred_name": record["preferred_name"], + "ignored_brands": record["ignored_brands"], + "favorite_brands": record["favorite_brands"], + "volume": record["volume"], + "use_multi_line_shout": record["use_multi_line_shout"], + }, + upsert=True, + ) + for record in result + ] + + new_db_controller.exec_query( + query=dict( + document="user_preferences", command="bulk_write", data=formed_result + ) + ) + + LOG.info(f"Received {len(list(result))} new users") diff --git a/migration_scripts/utils/__init__.py b/migration_scripts/utils/__init__.py index ae31e9cf..9a6b86c6 100644 --- a/migration_scripts/utils/__init__.py +++ b/migration_scripts/utils/__init__.py @@ -32,17 +32,20 @@ def setup_db_connectors(configuration: Configuration, old_db_key: str, new_db_key: str): """ - Migrating users from old database to new one - :param configuration: active configuration - :param old_db_key: old database key - :param new_db_key: new database key + Migrating users from old database to new one + :param configuration: active configuration + :param old_db_key: old database key + :param new_db_key: new database key """ - ssh_configs = configuration.config_data.get('SSH_CONFIG') - tunnel_connection = create_ssh_tunnel(server_address=ssh_configs['ADDRESS'], - username=ssh_configs['USER'], - password=ssh_configs['PASSWORD'], - remote_bind_address=('127.0.0.1', 
3306)) - mysql_connector = configuration.get_db_controller(name=old_db_key, - override_args={'port': tunnel_connection.local_bind_address[1]}) + ssh_configs = configuration.config_data.get("SSH_CONFIG") + tunnel_connection = create_ssh_tunnel( + server_address=ssh_configs["ADDRESS"], + username=ssh_configs["USER"], + password=ssh_configs["PASSWORD"], + remote_bind_address=("127.0.0.1", 3306), + ) + mysql_connector = configuration.get_db_controller( + name=old_db_key, override_args={"port": tunnel_connection.local_bind_address[1]} + ) mongo_connector = configuration.get_db_controller(name=new_db_key) return mysql_connector, mongo_connector diff --git a/migration_scripts/utils/conversation_utils.py b/migration_scripts/utils/conversation_utils.py index 675f5876..cb110218 100644 --- a/migration_scripts/utils/conversation_utils.py +++ b/migration_scripts/utils/conversation_utils.py @@ -37,10 +37,10 @@ def index_nicks(mongo_controller, received_nicks: List[str]) -> Tuple[dict, List[str]]: """ - Assigns unique id to each nick that is not present in new db + Assigns unique id to each nick that is not present in new db - :param mongo_controller: controller to active mongo collection - :param received_nicks: received nicks from mysql controller + :param mongo_controller: controller to active mongo collection + :param received_nicks: received nicks from mysql controller """ # Excluding existing nicks from loop @@ -52,16 +52,16 @@ def index_nicks(mongo_controller, received_nicks: List[str]) -> Tuple[dict, List for nick in nicks_to_consider: nicks_mapping[nick] = uuid.uuid4().hex - LOG.info(f'Created nicks mapping for {len(list(nicks_mapping))} records') + LOG.info(f"Created nicks mapping for {len(list(nicks_mapping))} records") return nicks_mapping, nicks_to_consider def clean_conversation_name(conversation_title: str): """ - Cleans up conversation names excluding all the legacy special chars + Cleans up conversation names excluding all the legacy special chars - :param 
conversation_title: Conversation title to clean + :param conversation_title: Conversation title to clean """ regex = re.search("-\[(.*?)\](.*)$", conversation_title) if regex is not None: diff --git a/migration_scripts/utils/shout_utils.py b/migration_scripts/utils/shout_utils.py index f8d03fcb..0f3dcb44 100644 --- a/migration_scripts/utils/shout_utils.py +++ b/migration_scripts/utils/shout_utils.py @@ -31,9 +31,9 @@ def prepare_nicks_for_sql(nicks: List[str]) -> list: """ - Prepares nicks to be used in SQL query + Prepares nicks to be used in SQL query - :param nicks: list of nicks to be used + :param nicks: list of nicks to be used """ processed_nicks = nicks.copy() return [nick.replace("'", "") for nick in processed_nicks] diff --git a/migration_scripts/utils/sql_utils.py b/migration_scripts/utils/sql_utils.py index 5b895def..b7439ab5 100644 --- a/migration_scripts/utils/sql_utils.py +++ b/migration_scripts/utils/sql_utils.py @@ -31,9 +31,9 @@ def iterable_to_sql_array(i: Iterable) -> str: """Converts python iterable to SQL array""" - return f'({str(list(i))[1:-1]})' + return f"({str(list(i))[1:-1]})" def sql_arr_is_null(sql_arr: str): """Checks if SQL array is null""" - return sql_arr and sql_arr == '()' + return sql_arr and sql_arr == "()" diff --git a/requirements/legacy_migration_requirements.txt b/requirements/legacy_migration_requirements.txt new file mode 100644 index 00000000..03dcdfb7 --- /dev/null +++ b/requirements/legacy_migration_requirements.txt @@ -0,0 +1,3 @@ +# These requirements are needed to migrated db data from legacy Klatchat v1 +mysql-connector==2.2.9 +sshtunnel~=0.4 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 3c70c766..f15de5b3 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,22 +1,23 @@ -aiofiles==0.7.0 -bidict==0.21.2 -cachetools==5.1.0 -fastapi==0.65.1 -fastapi-socketio==0.0.9 -Jinja2==3.0.1 +aiofiles==23.2.1 +bidict==0.22.1 +cachetools==5.3.1 +fastapi==0.103.2 
+fastapi-socketio==0.0.10 +httpx==0.25.0 # required by FastAPI +Jinja2==3.1.2 jsbeautifier==1.14.7 -mysql-connector==2.2.9 +kubernetes==28.1.0 neon-mq-connector~=0.7 neon-sftp~=0.1 ovos_config==0.0.10 ovos_utils==0.0.35 -pydantic==1.8.2 -PyJWT==2.1.0 -pymongo==3.12.0 -python-multipart==0.0.5 -python-socketio==5.7.1 -s3transfer==0.4.2 -sshtunnel~=0.4 -uvicorn==0.14.0 -websocket-client==0.54.0 -websockets==9.1 +pre-commit==3.4.0 +pydantic==2.4.2 +PyJWT==2.8.0 +pymongo==4.5.0 +python-multipart==0.0.6 +python-socketio==5.9.0 +starlette==0.27.0 +uvicorn==0.23.2 +websocket-client==1.6.3 +websockets==11.0.3 diff --git a/scripts/file_merger.py b/scripts/file_merger.py index ceed3adf..62581d70 100644 --- a/scripts/file_merger.py +++ b/scripts/file_merger.py @@ -44,65 +44,89 @@ class ParseKwargs(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for value in values: - key, value = value.split('=') + key, value = value.split("=") value = eval(value) getattr(namespace, self.dest)[key] = value class FileMerger(FilesManipulator): """ - File Merger is a convenience class for merging files dependencies - into the single considering the order of insertion + File Merger is a convenience class for merging files dependencies + into the single considering the order of insertion """ - DEFAULT_FILE_EXTENSION = '.js' - - def __init__(self, - working_dir: str, - weighted_dirs: Dict[str, tuple], - weighted_files: Optional[Dict[str, tuple]] = None, - skip_files: Optional[List[str]] = None, - save_to: str = None, - beautify: bool = False): + + DEFAULT_FILE_EXTENSION = ".js" + + def __init__( + self, + working_dir: str, + weighted_dirs: Dict[str, tuple], + weighted_files: Optional[Dict[str, tuple]] = None, + skip_files: Optional[List[str]] = None, + save_to: str = None, + beautify: bool = False, + ): super().__init__(working_dir=working_dir, skip_files=skip_files) self.weighted_dirs = weighted_dirs or {} 
self.weighted_files = weighted_files or {} - self.save_to = save_to or f'output{self.DEFAULT_FILE_EXTENSION}' + self.save_to = save_to or f"output{self.DEFAULT_FILE_EXTENSION}" self.current_content = "" self.beautify = beautify @staticmethod def build_from_args(): """ - Parsing user-entered arguments - Currently accepts: - - weighted dirs: serialized dictionary of weighted dirs - - weighted_files: serialized dictionary of weighted files - - skip_files: list of files to skip + Parsing user-entered arguments + Currently accepts: + - weighted dirs: serialized dictionary of weighted dirs + - weighted_files: serialized dictionary of weighted files + - skip_files: list of files to skip """ parser = argparse.ArgumentParser() - parser.add_argument('-wdir', '--working_dir', metavar='.') - parser.add_argument('-dirs', '--weighted_dirs', nargs='*', action=ParseKwargs, - metavar='key1=["val1","val2"] ...') - parser.add_argument('-files', '--weighted_files', nargs='*', action=ParseKwargs, - metavar='key1=["val1","val2"] ...') - parser.add_argument('-skip', '--skip_files', nargs='+', help='list of filenames to skip') - parser.add_argument('-dest', '--save_to', type=str, help='name of destination file') - parser.add_argument('-b', '--beautify', type=str, help='"1" to beautify the output file (does not work for css)') + parser.add_argument("-wdir", "--working_dir", metavar=".") + parser.add_argument( + "-dirs", + "--weighted_dirs", + nargs="*", + action=ParseKwargs, + metavar='key1=["val1","val2"] ...', + ) + parser.add_argument( + "-files", + "--weighted_files", + nargs="*", + action=ParseKwargs, + metavar='key1=["val1","val2"] ...', + ) + parser.add_argument( + "-skip", "--skip_files", nargs="+", help="list of filenames to skip" + ) + parser.add_argument( + "-dest", "--save_to", type=str, help="name of destination file" + ) + parser.add_argument( + "-b", + "--beautify", + type=str, + help='"1" to beautify the output file (does not work for css)', + ) script_args = 
parser.parse_args() - return FileMerger(working_dir=script_args.working_dir, - weighted_dirs=script_args.weighted_dirs, - weighted_files=script_args.weighted_files, - skip_files=script_args.skip_files, - save_to=script_args.save_to, - beautify=script_args.beautify == '1') + return FileMerger( + working_dir=script_args.working_dir, + weighted_dirs=script_args.weighted_dirs, + weighted_files=script_args.weighted_files, + skip_files=script_args.skip_files, + save_to=script_args.save_to, + beautify=script_args.beautify == "1", + ) def get_content(self, from_file) -> str: """ - Gets content from file - :param from_file: file to get content from - :returns extracted content + Gets content from file + :param from_file: file to get content from + :returns extracted content """ with open(join(self.working_dir, from_file)) as f: lines = f.readlines() @@ -114,13 +138,16 @@ def on_valid_file(self, file_path): content = self.get_content(self.full_path(file_path)) if self.beautify: content = jsbeautifier.beautify(content) - self.current_content += '\n' + content + self.current_content += "\n" + content def run(self): """ - Runs merging based on provided attributes + Runs merging based on provided attributes """ - weights_list = [int(x) for x in list(set(list(self.weighted_dirs) + list(self.weighted_files)))] + weights_list = [ + int(x) + for x in list(set(list(self.weighted_dirs) + list(self.weighted_files))) + ] weights_list = sorted(weights_list, reverse=True) @@ -131,25 +158,25 @@ def run(self): content = self.get_content(file) if self.beautify: content = jsbeautifier.beautify(content) - self.current_content += '\n' + content + self.current_content += "\n" + content matching_dirs = self.weighted_dirs.get(str(weight), ()) for folder in matching_dirs: self.walk_tree(folder) - with open(os.path.join(self.working_dir, self.save_to), 'w') as f: + with open(os.path.join(self.working_dir, self.save_to), "w") as f: f.write(self.current_content) self.current_content = "" def 
merge_files_by_arguments(): """ - Executes files merging based on the provided CMD arguments + Executes files merging based on the provided CMD arguments - Example invocation for reference: - python build_widget.py --weighted_dirs 1=['js'] --weighted_files 0=['nano_builder.js'] --save_to output.js --skip_files meta.js + Example invocation for reference: + python build_widget.py --weighted_dirs 1=['js'] --weighted_files 0=['nano_builder.js'] --save_to output.js --skip_files meta.js """ file_merger = FileMerger.build_from_args() file_merger.run() -if __name__ == '__main__': +if __name__ == "__main__": merge_files_by_arguments() diff --git a/scripts/files_manipulator.py b/scripts/files_manipulator.py index fe00e885..6ecf0c4b 100644 --- a/scripts/files_manipulator.py +++ b/scripts/files_manipulator.py @@ -39,9 +39,11 @@ class FilesManipulator(ABC): - """ Base class to manipulate files under specified dir """ + """Base class to manipulate files under specified dir""" - def __init__(self, working_dir: str, skip_files: list = None, skip_dirs: list = None): + def __init__( + self, working_dir: str, skip_files: list = None, skip_dirs: list = None + ): self.working_dir = working_dir or os.getcwd() self.skip_files = skip_files or [] self.skip_dirs = skip_dirs or [] @@ -50,7 +52,7 @@ def __init__(self, working_dir: str, skip_files: list = None, skip_dirs: list = @abstractmethod def build_from_args(): """ - Building instances from CLI arguments + Building instances from CLI arguments """ pass @@ -58,27 +60,33 @@ def full_path(self, file_path): return os.path.join(self.working_dir, file_path) def is_valid_processing_file(self, file_path) -> bool: - """ Condition to validate if given file is appropriate for processing """ - return isfile(join(self.working_dir, file_path)) and os.path.split(file_path)[-1] not in self.skip_files + """Condition to validate if given file is appropriate for processing""" + return ( + isfile(join(self.working_dir, file_path)) + and 
os.path.split(file_path)[-1] not in self.skip_files + ) def on_valid_file(self, file_path): - """ Implement to handle valid files """ + """Implement to handle valid files""" pass def on_failed_file(self, file_path): - """ Implement to handle failed files """ + """Implement to handle failed files""" pass - def walk_tree(self, folder: str = ''): - """ Walks towards specified folder and processes files""" + def walk_tree(self, folder: str = ""): + """Walks towards specified folder and processes files""" if folder: target_folder = join(self.working_dir, folder) else: target_folder = self.working_dir - print(f'walking through folder: {target_folder}') + print(f"walking through folder: {target_folder}") for item in listdir(target_folder): - print(f'checking path: {item}') - if os.path.isdir(os.path.join(self.working_dir, item)) and item not in self.skip_dirs: + print(f"checking path: {item}") + if ( + os.path.isdir(os.path.join(self.working_dir, item)) + and item not in self.skip_dirs + ): self.walk_tree(item) else: file_path = os.path.join(folder, item) diff --git a/scripts/minifier.py b/scripts/minifier.py index e22c3435..7c2ebf4a 100644 --- a/scripts/minifier.py +++ b/scripts/minifier.py @@ -36,17 +36,19 @@ class FilesMinifier(FilesManipulator): - """ Intelligent minifier of frontend modules """ + """Intelligent minifier of frontend modules""" __css_lib_installed = False __js_lib_installed = False - def __init__(self, - working_dir: str, - processing_pattern: str, - skipping_pattern: str = '', - skip_files: list = None, - skip_dirs: list = None): + def __init__( + self, + working_dir: str, + processing_pattern: str, + skipping_pattern: str = "", + skip_files: list = None, + skip_dirs: list = None, + ): super().__init__(working_dir, skip_files=skip_files, skip_dirs=skip_dirs) self.processing_patter = re.compile(pattern=processing_pattern) self.skipping_pattern = re.compile(pattern=skipping_pattern) @@ -54,25 +56,49 @@ def __init__(self, @staticmethod def 
build_from_args(): parser = argparse.ArgumentParser() - parser.add_argument('-wdir', '--working_dir', metavar='.', default='.') - parser.add_argument('-ppattern', '--processing_pattern', help='regex string of files that must be processed') - parser.add_argument('-spattern', '--skipping_pattern', help='regex string of files that must be skipped') - parser.add_argument('-skip', '--skip_files', nargs='+', help='list of filenames to skip', default=None) - parser.add_argument('-dskip', '--skip_dirs', nargs='+', help='list of directories to skip', default=None) + parser.add_argument("-wdir", "--working_dir", metavar=".", default=".") + parser.add_argument( + "-ppattern", + "--processing_pattern", + help="regex string of files that must be processed", + ) + parser.add_argument( + "-spattern", + "--skipping_pattern", + help="regex string of files that must be skipped", + ) + parser.add_argument( + "-skip", + "--skip_files", + nargs="+", + help="list of filenames to skip", + default=None, + ) + parser.add_argument( + "-dskip", + "--skip_dirs", + nargs="+", + help="list of directories to skip", + default=None, + ) script_args = parser.parse_args() - return FilesMinifier(working_dir=script_args.working_dir, - processing_pattern=script_args.processing_pattern, - skipping_pattern=script_args.skipping_pattern, - skip_files=script_args.skip_files, - skip_dirs=script_args.skip_dirs) + return FilesMinifier( + working_dir=script_args.working_dir, + processing_pattern=script_args.processing_pattern, + skipping_pattern=script_args.skipping_pattern, + skip_files=script_args.skip_files, + skip_dirs=script_args.skip_dirs, + ) def is_valid_processing_file(self, file_path) -> bool: - return super().is_valid_processing_file(file_path) and self.processing_patter.match(file_path) + return super().is_valid_processing_file( + file_path + ) and self.processing_patter.match(file_path) def on_valid_file(self, file_path): - dest_path = os.path.join(self.working_dir, 'build') + dest_path = 
os.path.join(self.working_dir, "build") file_path_lst = os.path.normpath(file_path).split(os.sep) dest_path = os.path.join(dest_path, *file_path_lst[:-1]) if not os.path.exists(dest_path): @@ -83,18 +109,20 @@ def on_valid_file(self, file_path): shutil.copyfile(source_path, dest_path) else: # dest_path = dest_path.replace('.js', '.min.js') - if source_path.endswith('css'): + if source_path.endswith("css"): if not self.__css_lib_installed: - os.system('npm install uglifycss -g') + os.system("npm install uglifycss -g") self.__css_lib_installed = True - command = f'uglifycss --ugly-comments --output {dest_path} {source_path}' - elif source_path.endswith('js'): + command = ( + f"uglifycss --ugly-comments --output {dest_path} {source_path}" + ) + elif source_path.endswith("js"): if not self.__js_lib_installed: os.system(f"npm install uglify-js -g") self.__js_lib_installed = True command = f"uglifyjs --compress --mangle --output {dest_path} -- {source_path}" else: - print(f'{source_path} is skipped') + print(f"{source_path} is skipped") return os.system(command) @@ -103,6 +131,6 @@ def run(self): return self.walk_tree() -if __name__ == '__main__': +if __name__ == "__main__": instance = FilesMinifier.build_from_args() instance.run() diff --git a/services/klatchat_observer/__main__.py b/services/klatchat_observer/__main__.py index ebc2e6d9..338557e5 100644 --- a/services/klatchat_observer/__main__.py +++ b/services/klatchat_observer/__main__.py @@ -36,22 +36,30 @@ def main(config: Optional[dict] = None): - connector = ChatObserver(config=config, scan_neon_service=os.environ.get('SCAN_NEON_SERVICE', False)) + connector = ChatObserver( + config=config, scan_neon_service=os.environ.get("SCAN_NEON_SERVICE", False) + ) connector.run(run_sync=False) -if __name__ == '__main__': +if __name__ == "__main__": try: config_data = load_config() - if not config_data.get('MQ'): - LOG.warning('Failed to load MQ settings from OVOS config, legacy flow will be applied') - config_data = 
Configuration(from_files=[os.environ.get('KLATCHAT_OBSERVER_CONFIG', 'config.json')]).config_data + if not config_data.get("MQ"): + LOG.warning( + "Failed to load MQ settings from OVOS config, legacy flow will be applied" + ) + config_data = Configuration( + from_files=[os.environ.get("KLATCHAT_OBSERVER_CONFIG", "config.json")] + ).config_data except Exception as e: LOG.error(e) config_data = dict() - LOG.info(f'Starting Chat Observer Listener (pid: {os.getpid()})...') + LOG.info(f"Starting Chat Observer Listener (pid: {os.getpid()})...") try: main(config=config_data) except Exception as ex: - LOG.info(f'Chat Observer Execution Interrupted (pid: {os.getpid()}) due to exception: {ex}') + LOG.info( + f"Chat Observer Execution Interrupted (pid: {os.getpid()}) due to exception: {ex}" + ) sys.exit(-1) diff --git a/services/klatchat_observer/constants/neon_api_constants.py b/services/klatchat_observer/constants/neon_api_constants.py index 7f41e386..07482bd1 100644 --- a/services/klatchat_observer/constants/neon_api_constants.py +++ b/services/klatchat_observer/constants/neon_api_constants.py @@ -31,13 +31,13 @@ class NeonServices(Enum): - OWM = 'open_weather_map' - WOLFRAM = 'wolfram_alpha' - ALPHA_VANTAGE = 'alpha_vantage' + OWM = "open_weather_map" + WOLFRAM = "wolfram_alpha" + ALPHA_VANTAGE = "alpha_vantage" neon_service_tokens: Dict[NeonServices, List[str]] = { - NeonServices.OWM: ['lat', 'lng', 'lon'], - NeonServices.ALPHA_VANTAGE: ['symbol'], - NeonServices.WOLFRAM: ['query'] -} \ No newline at end of file + NeonServices.OWM: ["lat", "lng", "lon"], + NeonServices.ALPHA_VANTAGE: ["symbol"], + NeonServices.WOLFRAM: ["query"], +} diff --git a/services/klatchat_observer/controller.py b/services/klatchat_observer/controller.py index d012f485..0f8dd7ed 100644 --- a/services/klatchat_observer/controller.py +++ b/services/klatchat_observer/controller.py @@ -28,14 +28,21 @@ import json import re import time +import cachetools.func + from threading import Event, Timer 
+import requests import socketio from enum import Enum +from neon_mq_connector.utils import retry from neon_mq_connector.utils.rabbit_utils import create_mq_callback from neon_mq_connector.connector import MQConnector +from requests import Response + +from utils.exceptions import KlatAPIAuthorizationError from utils.logging_utils import LOG from version import __version__ @@ -43,226 +50,298 @@ class Recipients(Enum): """Enumeration of possible recipients""" - NEON = 'neon' - CHATBOT_CONTROLLER = 'chatbot_controller' - UNRESOLVED = 'unresolved' + + NEON = "neon" + CHATBOT_CONTROLLER = "chatbot_controller" + UNRESOLVED = "unresolved" class ChatObserver(MQConnector): """Observer of conversations states""" recipient_prefixes = { - Recipients.NEON: ['neon', '@neon'], - Recipients.UNRESOLVED: ['undefined'] + Recipients.NEON: ["neon", "@neon"], + Recipients.UNRESOLVED: ["undefined"], } vhosts = { - 'neon_api': '/neon_chat_api', - 'chatbots': '/chatbots', - 'translation': '/translation' + "neon_api": "/neon_chat_api", + "chatbots": "/chatbots", + "translation": "/translation", + "llm": "/llm", } - def __init__(self, config: dict, service_name: str = 'chat_observer', vhosts: dict = None, - scan_neon_service: bool = False): - super().__init__(config['MQ'], service_name) + def __init__( + self, + config: dict, + service_name: str = "chat_observer", + vhosts: dict = None, + scan_neon_service: bool = False, + ): + super().__init__(config["MQ"], service_name) if not vhosts: vhosts = {} self.vhosts = {**vhosts, **self.vhosts} self.__translation_requests = {} - self.__neon_service_id = '' + self.__neon_service_id = "" self.neon_detection_enabled = scan_neon_service self.neon_service_event = None self.last_neon_request: int = 0 self.neon_service_refresh_interval = 60 # seconds - self.mention_separator = ',' + self.mention_separator = "," self.recipient_to_handler_method = { Recipients.NEON: self.__handle_neon_recipient, - Recipients.CHATBOT_CONTROLLER: 
self.__handle_chatbot_recipient + Recipients.CHATBOT_CONTROLLER: self.__handle_chatbot_recipient, } - self._sio = None - self.sio_url = config['SIO_URL'] - try: - self.connect_sio() - except Exception as ex: - err = f'Failed to connect Socket IO at {self.sio_url} due to exception={str(ex)}, observing will not be run' - LOG.warning(err) - if not self.testing_mode: - raise ConnectionError(err) - self.register_consumer(name='neon_response', - vhost=self.get_vhost('neon_api'), - queue='neon_chat_api_response', - callback=self.handle_neon_response, - on_error=self.default_error_handler) - self.register_consumer(name='neon_response_error', - vhost=self.get_vhost('neon_api'), - queue='neon_chat_api_error', - callback=self.handle_neon_error, - on_error=self.default_error_handler) - self.register_consumer(name='neon_stt_response', - vhost=self.get_vhost('neon_api'), - queue='neon_stt_response', - callback=self.on_stt_response, - on_error=self.default_error_handler) - self.register_consumer(name='neon_tts_response', - vhost=self.get_vhost('neon_api'), - queue='neon_tts_response', - callback=self.on_tts_response, - on_error=self.default_error_handler) - self.register_consumer(name='submind_shout', - vhost=self.get_vhost('chatbots'), - queue='submind_response', - callback=self.handle_submind_shout, - on_error=self.default_error_handler) - self.register_consumer(name='save_prompt_data', - vhost=self.get_vhost('chatbots'), - queue='save_prompt', - callback=self.handle_saving_prompt_data, - on_error=self.default_error_handler) - self.register_consumer(name='new_prompt', - vhost=self.get_vhost('chatbots'), - queue='new_prompt', - callback=self.handle_new_prompt, - on_error=self.default_error_handler) - self.register_consumer(name='get_prompt_data', - vhost=self.get_vhost('chatbots'), - queue='get_prompt', - callback=self.handle_get_prompt, - on_error=self.default_error_handler) - self.register_consumer(name='get_neon_translation_response', - vhost=self.get_vhost('translation'), - 
queue='get_libre_translations', - callback=self.on_neon_translations_response, - on_error=self.default_error_handler) + self._sio: socketio.Client = None + self.sio_url = config["SIO_URL"] + self.server_url = self.sio_url + self._klat_session_token = None + self.klat_auth_credentials = config.get("KLAT_AUTH_CREDENTIALS", {}) + self.default_persona_llms = dict() + self.connect_sio() + self.register_consumer( + name="neon_response", + vhost=self.get_vhost("neon_api"), + queue="neon_chat_api_response", + callback=self.handle_neon_response, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="neon_response_error", + vhost=self.get_vhost("neon_api"), + queue="neon_chat_api_error", + callback=self.handle_neon_error, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="neon_stt_response", + vhost=self.get_vhost("neon_api"), + queue="neon_stt_response", + callback=self.on_stt_response, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="neon_tts_response", + vhost=self.get_vhost("neon_api"), + queue="neon_tts_response", + callback=self.on_tts_response, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="submind_shout", + vhost=self.get_vhost("chatbots"), + queue="submind_response", + callback=self.handle_submind_shout, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="save_prompt_data", + vhost=self.get_vhost("chatbots"), + queue="save_prompt", + callback=self.handle_saving_prompt_data, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="new_prompt", + vhost=self.get_vhost("chatbots"), + queue="new_prompt", + callback=self.handle_new_prompt, + on_error=self.default_error_handler, + ) + self.register_consumer( + name="get_prompt_data", + vhost=self.get_vhost("chatbots"), + queue="get_prompt", + callback=self.handle_get_prompt, + on_error=self.default_error_handler, + ) + self.register_consumer( + 
name="get_neon_translation_response", + vhost=self.get_vhost("translation"), + queue="get_libre_translations", + callback=self.on_neon_translations_response, + ) + self.register_consumer( + name="get_configured_personas", + vhost=self.get_vhost("llm"), + queue="get_configured_personas", + callback=self.on_get_configured_personas, + ) + self.register_subscriber( + name="subminds_state_receiver", + vhost=self.get_vhost("chatbots"), + exchange="subminds_state", + callback=self.on_subminds_state, + on_error=self.default_error_handler, + ) @classmethod def get_recipient_from_prefix(cls, message: str) -> dict: """ - Gets recipient from incoming message + Gets recipient from incoming message - :param message: incoming message - :returns extracted recipient + :param message: incoming message + :returns extracted recipient """ callback = dict(recipient=Recipients.UNRESOLVED, context={}) message_formatted = message.upper().strip() - if message_formatted.startswith('!PROMPT:'): - callback['recipient'] = Recipients.CHATBOT_CONTROLLER - callback['context'] = dict(requested_participants=['proctor']) - elif message_formatted.startswith('SHOW SCORE:'): - callback['recipient'] = Recipients.CHATBOT_CONTROLLER - callback['context'] = dict(requested_participants=['scorekeeper']) - elif any(message_formatted.startswith(command) for command in ('!START_AUTO_PROMPTS', '!STOP_AUTO_PROMPTS',)): - callback['recipient'] = Recipients.CHATBOT_CONTROLLER + if message_formatted.startswith("!PROMPT:"): + callback["recipient"] = Recipients.CHATBOT_CONTROLLER + callback["context"] = dict(requested_participants=["proctor"]) + elif message_formatted.startswith("SHOW SCORE:"): + callback["recipient"] = Recipients.CHATBOT_CONTROLLER + callback["context"] = dict(requested_participants=["scorekeeper"]) + elif any( + message_formatted.startswith(command) + for command in ( + "!START_AUTO_PROMPTS", + "!STOP_AUTO_PROMPTS", + ) + ): + callback["recipient"] = Recipients.CHATBOT_CONTROLLER else: for recipient, 
recipient_prefixes in cls.recipient_prefixes.items(): - if any(message_formatted.startswith(x.upper()) for x in recipient_prefixes): - callback['recipient'] = recipient + if any( + message_formatted.startswith(x.upper()) for x in recipient_prefixes + ): + callback["recipient"] = recipient break return callback - @staticmethod - def get_recipient_from_body(message: str) -> dict: + def get_recipient_from_body(self, message: str) -> dict: """ - Gets recipients from message body + Gets recipients from message body - :param message: user's message - :returns extracted recipient + :param message: user's message + :returns extracted recipient - Example: - >>> assert ChatObserver.get_recipient_from_body('@Proctor hello dsfdsfsfds @Prompter') == {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': {'proctor', 'prompter'}} + Example: + >>> assert self.get_recipient_from_body('@Proctor hello dsfdsfsfds @Prompter') == {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': {'proctor', 'prompter'}} """ - message = ' ' + message - bot_mentioning_regexp = r'[\s]+@[a-zA-Z]+[\w]+' + message = " " + message + bot_mentioning_regexp = r"[\s]+@[a-zA-Z]+[\w]+" bots = re.findall(bot_mentioning_regexp, message) - bots = set([bot.strip().replace('@', '').lower() for bot in bots]) + bots = set([bot.strip().replace("@", "").lower() for bot in bots]) if len(bots) > 0: recipient = Recipients.CHATBOT_CONTROLLER else: recipient = Recipients.UNRESOLVED - return {'recipient': recipient, 'context': {'requested_participants': bots}} + return { + "recipient": recipient, + "context": { + "requested_participants": [ + self.default_persona_llms.get(bot, bot) for bot in bots + ] + }, + } - @staticmethod - def get_recipient_from_bound_service(bound_service) -> dict: - """ Gets recipient in case bounded service is received in data """ + def get_recipient_from_bound_service(self, bound_service) -> dict: + """Gets recipient in case bounded service is 
received in data""" response = {} - if bound_service.startswith('chatbots'): - response = {'recipient': Recipients.CHATBOT_CONTROLLER, 'context': {'requested_participants': bound_service.split('.')[1].split(',')}} - elif bound_service.startswith('neon'): - service = '.'.join(bound_service.split('.')[1]) - if service == 'assistant': - response = {'recipient': Recipients.NEON, 'context': {}} + if bound_service.startswith("chatbots"): + bot = bound_service.split(".")[1].split(",") + response = { + "recipient": Recipients.CHATBOT_CONTROLLER, + "context": { + "requested_participants": self.default_persona_llms.get(bot, bot) + }, + } + elif bound_service.startswith("neon"): + service = ".".join(bound_service.split(".")[1]) + if service == "assistant": + response = {"recipient": Recipients.NEON, "context": {}} else: - response = {'recipient': Recipients.NEON, 'context': {'requested_service_name': service}} + response = { + "recipient": Recipients.NEON, + "context": {"requested_service_name": service}, + } return response def get_recipient_from_message(self, message: str) -> dict: """ - Gets recipient based on message + Gets recipient based on message - :param message: message text + :param message: message text - :returns Dictionary of type: {"recipient": (instance of Recipients), - "context": dictionary with supportive context} + :returns Dictionary of type: {"recipient": (instance of Recipients), + "context": dictionary with supportive context} """ # Parsing message prefix response_body = self.get_recipient_from_prefix(message=message) # Parsing message body - if response_body['recipient'] == Recipients.UNRESOLVED: + if response_body["recipient"] == Recipients.UNRESOLVED: response_body = self.get_recipient_from_body(message=message) return response_body @property def neon_service_id(self): """Gets neon service id / detects the one from synchronization loop if neon_detection enabled""" - if not self.__neon_service_id \ - or int(time.time()) - self.last_neon_request >= 
self.neon_service_refresh_interval \ - and self.neon_detection_enabled: + if ( + not self.__neon_service_id + or int(time.time()) - self.last_neon_request + >= self.neon_service_refresh_interval + and self.neon_detection_enabled + ): self.get_neon_service() return self.__neon_service_id def get_neon_service(self, wait_timeout: int = 10) -> None: """ - Scans neon service synchronization loop for neon service id + Scans neon service synchronization loop for neon service id """ self.neon_service_event = Event() - self.register_consumer(name='neon_service_sync_consumer', - callback=self.handle_neon_sync, - vhost=self.get_vhost('neon_api'), - on_error=self.default_error_handler, - queue='chat_api_proxy_sync') - sync_consumer = self.consumers['neon_service_sync_consumer'] + self.register_consumer( + name="neon_service_sync_consumer", + callback=self.handle_neon_sync, + vhost=self.get_vhost("neon_api"), + on_error=self.default_error_handler, + queue="chat_api_proxy_sync", + ) + sync_consumer = self.consumers["neon_service_sync_consumer"] sync_consumer.start() self.neon_service_event.wait(wait_timeout) - LOG.info('Joining sync consumer') + LOG.info("Joining sync consumer") sync_consumer.join() if not self.neon_service_event.is_set(): - LOG.warning(f'Failed to get neon_service in {wait_timeout} seconds') - self.__neon_service_id = '' + LOG.warning(f"Failed to get neon_service in {wait_timeout} seconds") + self.__neon_service_id = "" def register_sio_handlers(self): """Convenience method for setting up Socket IO listeners""" - self._sio.on('new_message', handler=self.handle_message) - self._sio.on('get_stt', handler=self.handle_get_stt) - self._sio.on('get_tts', handler=self.handle_get_tts) - self._sio.on('prompt_data', handler=self.forward_prompt_data) - self._sio.on('request_neon_translations', handler=self.request_neon_translations) - - def connect_sio(self, refresh=False): - """ - Method for establishing connection with Socket IO server - - :param refresh: To refresh an 
existing instance - """ - if not self._sio or refresh: - self._sio = socketio.Client() - self._sio.connect(url=self.sio_url) - self.register_sio_handlers() + self._sio.on("new_message", handler=self.handle_message) + self._sio.on("get_stt", handler=self.handle_get_stt) + self._sio.on("get_tts", handler=self.handle_get_tts) + self._sio.on("prompt_data", handler=self.forward_prompt_data) + self._sio.on( + "request_neon_translations", handler=self.request_neon_translations + ) + self._sio.on("ban_submind", handler=self.request_ban_submind) + self._sio.on( + "ban_submind_from_conversation", + handler=self.request_ban_submind_from_conversation, + ) + self._sio.on("revoke_submind_ban", handler=self.request_revoke_submind_ban) + self._sio.on( + "revoke_submind_ban_from_conversation", + handler=self.request_revoke_submind_ban_from_conversation, + ) + + @retry(use_self=True) + def connect_sio(self): + """ + Method for establishing connection with Socket IO server + """ + self._sio = socketio.Client() + self._sio.connect(url=self.sio_url, namespaces=["/"]) + self.register_sio_handlers() @property def sio(self): """ - Creates socket io client if none is present + Creates socket io client if none is present - :return: connected async socket io instance + :return: connected async socket io instance """ if not self._sio: self.connect_sio() @@ -274,322 +353,501 @@ def apply_testing_prefix(self, vhost): :param vhost: MQ virtual host to validate """ # TODO: implement this method in the base class - if self.testing_mode and self.testing_prefix not in vhost.split('_')[0]: - vhost = f'/{self.testing_prefix}_{vhost[1:]}' - if vhost.endswith('_'): + if self.testing_mode and self.testing_prefix not in vhost.split("_")[0]: + vhost = f"/{self.testing_prefix}_{vhost[1:]}" + if vhost.endswith("_"): vhost = vhost[:-1] return vhost def get_vhost(self, name: str): - """ Gets actual vhost based on provided string """ + """Gets actual vhost based on provided string""" if name not in 
list(self.vhosts): - LOG.error(f'Invalid vhost specified - {name}') + LOG.error(f"Invalid vhost specified - {name}") return name else: return self.apply_testing_prefix(vhost=self.vhosts.get(name)) @staticmethod def get_neon_request_structure(msg_data: dict): - """ Gets Neon API message structure based on received request skill type """ - request_skills = msg_data.get('request_skills', 'default').lower() - if request_skills == 'tts': - utterance = msg_data.pop('utterance', '') or msg_data.pop('text', '') + """Gets Neon API message structure based on received request skill type""" + requested_skill = msg_data.get("requested_skill", "recognizer").lower() + if requested_skill == "tts": + utterance = msg_data.pop("utterance", "") or msg_data.pop("text", "") request_dict = { - 'data': { - 'utterance': utterance, - 'text': utterance, + "data": { + "utterance": utterance, + "text": utterance, }, - 'context': { - 'sender_context': msg_data - } + "context": {"sender_context": msg_data}, } - elif request_skills == 'stt': + elif requested_skill == "stt": request_dict = { - 'data': { - 'audio_data': msg_data.pop('audio_data', msg_data['message_body']), + "data": { + "audio_data": msg_data.pop("audio_data", msg_data["message_body"]), } } else: request_dict = { - 'data': { - 'utterances': [msg_data['message_body']], + "data": { + "utterances": [msg_data["message_body"]], }, - 'context': { - 'sender_context': msg_data - } + "context": {"sender_context": msg_data}, } # TODO: any specific structure per wolfram/duckduckgo, etc... 
return request_dict def __handle_neon_recipient(self, recipient_data: dict, msg_data: dict): - msg_data.setdefault('message_body', msg_data.pop('messageText', '')) - msg_data.setdefault('message_id', msg_data.pop('messageID', '')) - recipient_data.setdefault('context', {}) + msg_data.setdefault("message_body", msg_data.pop("messageText", "")) + msg_data.setdefault("message_id", msg_data.pop("messageID", "")) + recipient_data.setdefault("context", {}) pattern = re.compile("Neon", re.IGNORECASE) - msg_data['message_body'] = pattern.sub("", msg_data['message_body'], 1).strip('<>@,.:|- ').capitalize() - msg_data['request_skills'] = recipient_data['context'].pop('service', 'default') + msg_data["message_body"] = ( + pattern.sub("", msg_data["message_body"], 1).strip("<>@,.:|- ").capitalize() + ) + msg_data.setdefault( + "requested_skill", recipient_data["context"].pop("service", "recognizer") + ) request_dict = self.get_neon_request_structure(msg_data) - request_dict['data']['lang'] = msg_data.get('lang', 'en-us') - request_dict['context'] = {**recipient_data.get('context', {}), - **{'source': 'mq_api', - 'message_id': msg_data.get('message_id'), - 'sid': msg_data.get('sid'), - 'cid': msg_data.get('cid'), - 'agent': msg_data.get('agent', f'pyklatchat v{__version__}'), - 'requested_service_name': recipient_data['context'].get('requested_service_name', ''), - 'request_skills': [msg_data.get('request_skills', 'recognizer').lower()], - 'username': msg_data.pop('nick', 'guest')}} - input_queue = 'neon_chat_api_request' + request_dict["data"]["lang"] = msg_data.get("lang", "en-us") + request_dict["context"] = { + **recipient_data.get("context", {}), + **{ + "source": "mq_api", + "message_id": msg_data.get("message_id"), + "sid": msg_data.get("sid"), + "cid": msg_data.get("cid"), + "agent": msg_data.get("agent", f"pyklatchat v{__version__}"), + "requested_service_name": recipient_data["context"].get( + "requested_service_name", "" + ), + "request_skills": 
[msg_data["requested_skill"].lower()], + "username": msg_data.pop("nick", "guest"), + }, + } + input_queue = "neon_chat_api_request" if self.neon_detection_enabled: neon_service_id = self.neon_service_id if neon_service_id: - input_queue = f'{input_queue}_{neon_service_id}' + input_queue = f"{input_queue}_{neon_service_id}" self.last_neon_request = int(time.time()) - self.send_message(request_data=request_dict, vhost=self.get_vhost('neon_api'), queue=input_queue) + self.send_message( + request_data=request_dict, + vhost=self.get_vhost("neon_api"), + queue=input_queue, + ) def __handle_chatbot_recipient(self, recipient_data: dict, msg_data: dict): - LOG.info(f'Emitting message to Chatbot Controller: {recipient_data}') - queue = 'external_shout' - msg_data['requested_participants'] = json.dumps(list(recipient_data.setdefault('context', {}) - .setdefault('requested_participants', []))) - self.send_message(request_data=msg_data, vhost=self.get_vhost('chatbots'), queue=queue, expiration=3000) + LOG.info(f"Emitting message to Chatbot Controller: {recipient_data}") + queue = "external_shout" + msg_data["requested_participants"] = json.dumps( + list( + recipient_data.setdefault("context", {}).setdefault( + "requested_participants", [] + ) + ) + ) + self.send_message( + request_data=msg_data, + vhost=self.get_vhost("chatbots"), + queue=queue, + expiration=3000, + ) def handle_get_stt(self, data): - """ Handler for get STT request from Socket IO channel """ - data['recipient'] = Recipients.NEON - data['skip_recipient_detection'] = True - data['request_skills'] = 'stt' + """Handler for get STT request from Socket IO channel""" + data["recipient"] = Recipients.NEON + data["skip_recipient_detection"] = True + data["requested_skill"] = "stt" self.handle_message(data=data) def handle_get_tts(self, data): - """ Handler for get TTS request from Socket IO channel """ - data['recipient'] = Recipients.NEON - data['skip_recipient_detection'] = True - data['request_skills'] = 'tts' + 
"""Handler for get TTS request from Socket IO channel""" + data["recipient"] = Recipients.NEON + data["skip_recipient_detection"] = True + data["requested_skill"] = "tts" self.handle_message(data=data) def handle_message(self, data: dict): """ - Handles input requests from MQ to Neon API + Handles input requests from MQ to Neon API - :param data: Received user data + :param data: Received user data """ if data and isinstance(data, dict): recipient_data = {} - if not data.get('skip_recipient_detection'): - recipient_data = self.get_recipient_from_bound_service(data.get('bound_service', '')) or \ - self.get_recipient_from_message(message=data.get('messageText', - data.get('message_body'))) - recipient = recipient_data.get('recipient') or data.get('recipient') or Recipients.UNRESOLVED + if not data.get("skip_recipient_detection"): + recipient_data = self.get_recipient_from_bound_service( + data.get("bound_service", "") + ) or self.get_recipient_from_message( + message=data.get("messageText", data.get("message_body")) + ) + recipient = ( + recipient_data.get("recipient") + or data.get("recipient") + or Recipients.UNRESOLVED + ) handler_method = self.recipient_to_handler_method.get(recipient) if not handler_method: - LOG.warning(f'Failed to get handler message for recipient={recipient}') + LOG.warning(f"Failed to get handler message for recipient={recipient}") else: handler_method(recipient_data=recipient_data, msg_data=data) else: - raise TypeError(f'Malformed data received: {data}') + raise TypeError(f"Malformed data received: {data}") def forward_prompt_data(self, data: dict): """Forwards received prompt data to the destination observer""" - requested_nick = data.pop('receiver', None) + requested_nick = data.pop("receiver", None) if not requested_nick: - LOG.warning('Forwarding to unknown recipient, skipping') + LOG.warning("Forwarding to unknown recipient, skipping") return -1 - self.send_message(request_data=data, - vhost=self.get_vhost('chatbots'), - 
queue=f'{requested_nick}_prompt_data', - expiration=3000) + self.send_message( + request_data=data, + vhost=self.get_vhost("chatbots"), + queue=f"{requested_nick}_prompt_data", + expiration=3000, + ) def request_neon_translations(self, data: dict): - """ Requests translations from neon """ - request_id = data.pop('request_id', None) + """Requests translations from neon""" + request_id = data.pop("request_id", None) if request_id: - default_callback = { - 'data': { - 'request_id': request_id, - 'translations': {} - } + default_callback = {"data": {"request_id": request_id, "translations": {}}} + self.__translation_requests[request_id] = { + "void_callback_timer": Timer( + interval=2 * 60, + function=self.send_translation_response, + kwargs=default_callback, + ) } - self.__translation_requests[request_id] = {'void_callback_timer': Timer(interval=2 * 60, - function=self.send_translation_response, - kwargs=default_callback)} - self.__translation_requests[request_id]['void_callback_timer'].start() - self.send_message(request_data={'data': data['data'], - 'request_id': request_id}, - vhost=self.get_vhost('translation'), - queue='request_libre_translations', expiration=3000) + self.__translation_requests[request_id]["void_callback_timer"].start() + self.send_message( + request_data={"data": data["data"], "request_id": request_id}, + vhost=self.get_vhost("translation"), + queue="request_libre_translations", + expiration=3000, + ) @create_mq_callback() def on_neon_translations_response(self, body: dict): """ - Translations response from neon + Translations response from neon - :param body: request body (dict) + :param body: request body (dict) """ - self.send_translation_response(data={'request_id': body['request_id'], - 'translations': body['data']}) + self.send_translation_response( + data={"request_id": body["request_id"], "translations": body["data"]} + ) def send_translation_response(self, data: dict): """ - Sends translation response back to klatchat - :param data: 
translation data to send + Sends translation response back to klatchat + :param data: translation data to send """ - request_id = data.get('request_id', None) + request_id = data.get("request_id", None) if request_id and self.__translation_requests.pop(request_id, None): - self.sio.emit('get_neon_translations', data=data) + self.sio.emit("get_neon_translations", data=data) else: - LOG.warning(f'Neon translation response was not sent, ' - f'as request_id={request_id} was not found among translation requests') + LOG.warning( + f"Neon translation response was not sent, " + f"as request_id={request_id} was not found among translation requests" + ) @create_mq_callback() def handle_get_prompt(self, body: dict): """ - Handles get request for the prompt data + Handles get request for the prompt data - :param body: request body (dict) + :param body: request body (dict) """ - requested_nick = body.get('nick', None) + requested_nick = body.get("nick", None) if not requested_nick: - LOG.warning('Request from unknown sender, skipping') + LOG.warning("Request from unknown sender, skipping") return -1 - self.sio.emit('get_prompt_data', data=body) + self.sio.emit("get_prompt_data", data=body) @create_mq_callback() def handle_neon_sync(self, body: dict): """ - Handles input neon api sync requests from MQ + Handles input neon api sync requests from MQ - :param body: request body (dict) + :param body: request body (dict) """ - service_id = body.get('service_id', None) + service_id = body.get("service_id", None) if service_id: - LOG.info(f'Received neon service id: {service_id}') + LOG.info(f"Received neon service id: {service_id}") self.__neon_service_id = service_id self.neon_service_event.set() else: - LOG.error('No service id specified - neon api is not synchronized') + LOG.error("No service id specified - neon api is not synchronized") @create_mq_callback() def handle_neon_response(self, body: dict): """ - Handles responses from Neon API - :param body: request body (dict) + 
Handles responses from Neon API + :param body: request body (dict) """ try: - LOG.info(f'Received Neon Response: {body}') - msg_type = body['msg_type'] - data = body['data'] - context = body['context'] - neon_chat_skill_pattern = re.compile(r'chat.[a-z]([a-z]|[0-9]|_)*([a-z]|[0-9]).response', re.IGNORECASE) + LOG.info(f"Received Neon Response: {body}") + msg_type = body["msg_type"] + data = body["data"] + context = body["context"] + neon_chat_skill_pattern = re.compile( + r"chat.[a-z]([a-z]|[0-9]|_)*([a-z]|[0-9]).response", re.IGNORECASE + ) response_languages = [] if neon_chat_skill_pattern.match(msg_type): - response = data.get('response', 'No idea.') - service_name = msg_type.split('chat.')[1].split('.response')[0] - user_id = f'neon.{service_name}' + response = data.get("response", "No idea.") + service_name = msg_type.split("chat.")[1].split(".response")[0] + user_id = f"neon.{service_name}" else: - response_languages = list(data['responses']) - response = data['responses'][response_languages[0]]['sentence'] - user_id = 'neon' + response_languages = list(data["responses"]) + response = data["responses"][response_languages[0]]["sentence"] + user_id = "neon" # TODO: multilingual support send_data = { - 'cid': context['cid'], - 'userID': user_id, - 'repliedMessage': context.get('message_id', ''), - 'messageText': response, - 'messageTTS': {}, - 'source': 'klat_observer', - 'timeCreated': int(time.time()) + "cid": context["cid"], + "userID": user_id, + "repliedMessage": context.get("message_id", ""), + "messageText": response, + "messageTTS": {}, + "source": "klat_observer", + "timeCreated": int(time.time()), } - response_audio_genders = data.get('genders', []) + response_audio_genders = data.get("genders", []) for language in response_languages: for gender in response_audio_genders: try: - send_data['messageTTS'].setdefault(language, {})[gender] = \ - data['responses'][language]['audio'][gender] + send_data["messageTTS"].setdefault(language, {})[gender] = data[ + 
"responses" + ][language]["audio"][gender] except Exception as ex: - LOG.error(f'Failed to set messageTTS with language={language}, gender={gender} - {ex}') - self.sio.emit('user_message', data=send_data) + LOG.error( + f"Failed to set messageTTS with language={language}, gender={gender} - {ex}" + ) + self.sio.emit("user_message", data=send_data) except Exception as ex: - LOG.error(f'Failed to emit Neon Chat API response: {ex}') + LOG.error(f"Failed to emit Neon Chat API response: {ex}") @create_mq_callback() def handle_neon_error(self, body: dict): """ - Handles responses from Neon API + Handles responses from Neon API - :param body: request body (bytes) + :param body: request body (bytes) """ - LOG.error(f'Error response from Neon API: {body}') + LOG.error(f"Error response from Neon API: {body}") @create_mq_callback() def handle_submind_shout(self, body: dict): """ - Handles shouts from subminds outside the PyKlatchat + Handles shouts from subminds outside the PyKlatchat - :param body: request body (dict) + :param body: request body (dict) """ - response_required_keys = ('userID', 'cid', 'messageText',) + response_required_keys = ( + "userID", + "cid", + "messageText", + ) if all(required_key in list(body) for required_key in response_required_keys): - body.setdefault('timeCreated', int(time.time())) - body.setdefault('source', 'klat_observer') - self.sio.emit('user_message', data=body) + body.setdefault("timeCreated", int(time.time())) + body.setdefault("source", "klat_observer") + self.sio.emit("user_message", data=body) self.handle_message(data=body) else: - error_msg = f'Skipping received data {body} as it lacks one of the required keys: ' \ - f'({",".join(response_required_keys)})' + error_msg = ( + f"Skipping received data {body} as it lacks one of the required keys: " + f'({",".join(response_required_keys)})' + ) LOG.warning(error_msg) - self.send_message(request_data={'msg': error_msg}, vhost=self.get_vhost('chatbots'), - queue='chatbot_response_error', 
expiration=3000) + self.send_message( + request_data={"msg": error_msg}, + vhost=self.get_vhost("chatbots"), + queue="chatbot_response_error", + expiration=3000, + ) @create_mq_callback() def handle_new_prompt(self, body: dict): """ - Handles announcement of new prompt - :param body: new prompt body - :return: + Handles announcement of new prompt + :param body: new prompt body + :return: """ - response_required_keys = ('cid', 'prompt_id', 'prompt_text',) + response_required_keys = ( + "cid", + "prompt_id", + "prompt_text", + ) if all(required_key in list(body) for required_key in response_required_keys): - self.sio.emit('new_prompt', data=body) + self.sio.emit("new_prompt", data=body) else: - error_msg = f'Skipping received data {body} as it lacks one of the required keys: ' \ - f'({",".join(response_required_keys)})' + error_msg = ( + f"Skipping received data {body} as it lacks one of the required keys: " + f'({",".join(response_required_keys)})' + ) LOG.error(error_msg) - self.send_message(request_data={'msg': error_msg}, vhost=self.get_vhost('chatbots'), - queue='chatbot_response_error', expiration=3000) + self.send_message( + request_data={"msg": error_msg}, + vhost=self.get_vhost("chatbots"), + queue="chatbot_response_error", + expiration=3000, + ) @create_mq_callback() def handle_saving_prompt_data(self, body: dict): """ - Handles requests for saving prompt data + Handles requests for saving prompt data - :param body: request body (dict) + :param body: request body (dict) """ - response_required_keys = ('userID', 'cid', 'messageText', 'bot', 'timeCreated', 'context',) + response_required_keys = ( + "userID", + "cid", + "messageText", + "bot", + "timeCreated", + "context", + ) if all(required_key in list(body) for required_key in response_required_keys): - self.sio.emit('prompt_completed', data=body) + self.sio.emit("prompt_completed", data=body) else: - error_msg = f'Skipping received data {body} as it lacks one of the required keys: ' \ - 
f'({",".join(response_required_keys)})' + error_msg = ( + f"Skipping received data {body} as it lacks one of the required keys: " + f'({",".join(response_required_keys)})' + ) LOG.error(error_msg) - self.send_message(request_data={'msg': error_msg}, vhost=self.get_vhost('chatbots'), - queue='chatbot_response_error', expiration=3000) + self.send_message( + request_data={"msg": error_msg}, + vhost=self.get_vhost("chatbots"), + queue="chatbot_response_error", + expiration=3000, + ) @create_mq_callback() def on_stt_response(self, body: dict): - """ Handles receiving STT response """ - LOG.info(f'Received STT Response: {body}') - self.sio.emit('stt_response', data=body) + """Handles receiving STT response""" + LOG.debug(f"Received STT Response: {body}") + self.sio.emit("stt_response", data=body) @create_mq_callback() def on_tts_response(self, body: dict): - """ Handles receiving TTS response """ - LOG.info(f'Received TTS Response: {body}') - self.sio.emit('tts_response', data=body) + """Handles receiving TTS response""" + LOG.debug(f"Received TTS Response: {body}") + self.sio.emit("tts_response", data=body) + + @create_mq_callback() + def on_subminds_state(self, body: dict): + """Handles receiving subminds state message""" + LOG.debug(f"Received submind state: {body}") + body["msg_type"] = "subminds_state" + self.sio.emit("broadcast", data=body) + + @create_mq_callback() + def on_get_configured_personas(self, body: dict): + response_data = self._fetch_persona_api(user_id=body.get("user_id")) + response_data["items"] = [ + item + for item in response_data["items"] + if body["service_name"] in item["supported_llms"] + ] + response_data.setdefault("context", {}).setdefault("mq", {}).setdefault( + "message_id", body["message_id"] + ) + self.send_message( + request_data=response_data, + vhost=self.get_vhost("llm"), + queue=body["routing_key"], + expiration=5000, + ) + + @cachetools.func.ttl_cache(ttl=15) + def _fetch_persona_api(self, user_id: str) -> dict: + query_string = 
self._build_persona_api_query(user_id=user_id) + url = f"{self.server_url}/personas/list?{query_string}" + try: + response = self._fetch_klat_server(url=url) + data = response.json() + self._refresh_default_persona_llms(data=data) + except KlatAPIAuthorizationError: + LOG.error(f"Failed to fetch personas from {url = }") + data = {"items": []} + return data + + def _refresh_default_persona_llms(self, data): + for item in data["items"]: + if default_llm := item.get("default_llm"): + self.default_persona_llms[item["id"]] = item["id"] + "_" + default_llm + + def _build_persona_api_query(self, user_id: str) -> str: + url_query_params = f"only_enabled=true" + if user_id: + url_query_params += f"&user_id={user_id}" + return url_query_params + + def request_ban_submind(self, data: dict): + self.send_message( + request_data=data, + vhost=self.get_vhost("chatbots"), + queue="ban_submind", + expiration=3000, + ) + + def _fetch_klat_server(self, url: str) -> Response: + # only getter method is supported, for POST/PUT/DELETE operations using Socket IO is preferable channel + if self._klat_session_token: + response = self._send_get_request_to_klat(url=url) + if response.ok: + return response + elif response.status_code != 403: + raise KlatAPIAuthorizationError("Klat API unavailable") + self._login_to_klat_server() + return self._send_get_request_to_klat(url=url) + + def _send_get_request_to_klat(self, url: str) -> Response: + return requests.get( + url=url, headers={"Authorization": self._klat_session_token} + ) + + def _login_to_klat_server(self): + response = requests.post( + f"{self.server_url}/auth/login", data=self.klat_auth_credentials + ) + if response.ok: + self._klat_session_token = response.json()["token"] + else: + LOG.error( + f"Klat API authorization error: [{response.status_code}] {response.text}" + ) + raise KlatAPIAuthorizationError + + def request_ban_submind_from_conversation(self, data: dict): + self.send_message( + request_data=data, + 
vhost=self.get_vhost("chatbots"), + queue="ban_submind_from_conversation", + expiration=3000, + ) + + def request_revoke_submind_ban(self, data: dict): + self.send_message( + request_data=data, + vhost=self.get_vhost("chatbots"), + queue="revoke_submind_ban", + expiration=3000, + ) + + def request_revoke_submind_ban_from_conversation(self, data: dict): + self.send_message( + request_data=data, + vhost=self.get_vhost("chatbots"), + queue="revoke_submind_ban_from_conversation", + expiration=3000, + ) diff --git a/services/klatchat_observer/utils/neon_api_utils.py b/services/klatchat_observer/utils/neon_api_utils.py index 6703aad9..5d70b608 100644 --- a/services/klatchat_observer/utils/neon_api_utils.py +++ b/services/klatchat_observer/utils/neon_api_utils.py @@ -28,18 +28,20 @@ from ..constants.neon_api_constants import NeonServices, neon_service_tokens -def resolve_neon_service(message_data: dict, bypass_threshold: float = 0.5) -> NeonServices: +def resolve_neon_service( + message_data: dict, bypass_threshold: float = 0.5 +) -> NeonServices: """ - Resolves desired neon service based on the data content from message + Resolves desired neon service based on the data content from message - :param message_data: dictionary containing data for message - :param bypass_threshold: edge value to consider valid match + :param message_data: dictionary containing data for message + :param bypass_threshold: edge value to consider valid match - :returns neon service from NeonServices + :returns neon service from NeonServices """ # TODO: parse message text into lexemes for neon_service, tokens in neon_service_tokens.items(): - match_percentage = len(set(list(message_data)) & set(tokens))/len(tokens) + match_percentage = len(set(list(message_data)) & set(tokens)) / len(tokens) if match_percentage > bypass_threshold: return neon_service return NeonServices.WOLFRAM diff --git a/setup.py b/setup.py index 1843ec0f..a504aed0 100644 --- a/setup.py +++ b/setup.py @@ -31,15 +31,22 @@ def 
get_requirements(requirements_filename: str): - requirements_file = path.join(path.abspath(path.dirname(__file__)), "requirements", requirements_filename) - with open(requirements_file, 'r', encoding='utf-8') as r: + requirements_file = path.join( + path.abspath(path.dirname(__file__)), "requirements", requirements_filename + ) + with open(requirements_file, "r", encoding="utf-8") as r: requirements = r.readlines() - requirements = [r.strip() for r in requirements if r.strip() and not r.strip().startswith("#")] + requirements = [ + r.strip() for r in requirements if r.strip() and not r.strip().startswith("#") + ] for i in range(0, len(requirements)): r = requirements[i] if "@" in r: - parts = [p.lower() if p.strip().startswith("git+http") else p for p in r.split('@')] + parts = [ + p.lower() if p.strip().startswith("git+http") else p + for p in r.split("@") + ] r = "@".join(parts) if getenv("GITHUB_TOKEN"): if "github.com" in r: @@ -60,25 +67,25 @@ def get_requirements(requirements_filename: str): version = line.split("'")[1] setup( - name='pyklatchat', + name="pyklatchat", version=version, - description='Klatchat v2.0', - url='https://github.com/NeonGeckoCom/pyklatchat', - author='NeonGecko', - author_email='developers@neon.ai', - license='BSD-3', - packages=['chat_server', 'chat_client', 'services.klatchat_observer'], + description="Klatchat v2.0", + url="https://github.com/NeonGeckoCom/pyklatchat", + author="NeonGecko", + author_email="developers@neon.ai", + license="BSD-3", + packages=["chat_server", "chat_client", "services.klatchat_observer"], install_requires=get_requirements("requirements.txt"), zip_safe=True, classifiers=[ - 'Intended Audience :: Developers', - 'Programming Language :: Python :: 3.8', + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.10", ], entry_points={ - 'console_scripts': [ - 'chat_server=chat_server.__main__:main', - 'chat_client=chat_client.__main__:main', - 
'klatchat_observer=services.klatchat_observer.__main__:main' + "console_scripts": [ + "chat_server=chat_server.__main__:main", + "chat_client=chat_client.__main__:main", + "klatchat_observer=services.klatchat_observer.__main__:main", ] - } -) \ No newline at end of file + }, +) diff --git a/tests/mock.py b/tests/mock.py index 8578230a..4b3ee9c5 100644 --- a/tests/mock.py +++ b/tests/mock.py @@ -30,6 +30,5 @@ class MQConnectorChild(MQConnector): - - def __init__(self, config: dict = None, service_name: str = 'test'): + def __init__(self, config: dict = None, service_name: str = "test"): super().__init__(config=config, service_name=service_name) diff --git a/tests/test_db_utils.py b/tests/test_db_utils.py index f8af5bcb..9bd8494d 100644 --- a/tests/test_db_utils.py +++ b/tests/test_db_utils.py @@ -31,7 +31,11 @@ import unittest -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) +sys.path.append( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + ) +) from config import Configuration from utils.connection_utils import create_ssh_tunnel @@ -40,67 +44,96 @@ class TestDBController(unittest.TestCase): - @classmethod def setUpClass(cls) -> None: - db_config_file_path = os.environ.get('DATABASE_CONFIG', '~/.local/share/neon/credentials.json') - ssh_config_file_path = os.environ.get('SSH_CONFIG', '~/.local/share/neon/credentials.json') + db_config_file_path = os.environ.get( + "DATABASE_CONFIG", "~/.local/share/neon/credentials.json" + ) + ssh_config_file_path = os.environ.get( + "SSH_CONFIG", "~/.local/share/neon/credentials.json" + ) - cls.configuration = Configuration(from_files=[db_config_file_path, ssh_config_file_path]) + cls.configuration = Configuration( + from_files=[db_config_file_path, ssh_config_file_path] + ) - @unittest.skip('legacy db is not supported') + @unittest.skip("legacy db is not supported") def test_simple_interaction_mysql(self): - 
ssh_configs = self.configuration.config_data.get('SSH_CONFIG', None) + ssh_configs = self.configuration.config_data.get("SSH_CONFIG", None) override_configs = dict() if ssh_configs: - tunnel_connection = create_ssh_tunnel(server_address=ssh_configs['ADDRESS'], - username=ssh_configs['USER'], - password=ssh_configs['PASSWORD'], - remote_bind_address=('127.0.0.1', 3306)) - override_configs = {'host': '127.0.0.1', - 'port': tunnel_connection.local_bind_address[1]} - self.db_controller = self.configuration.get_db_controller(name='klatchat_2222', - override_args=override_configs) + tunnel_connection = create_ssh_tunnel( + server_address=ssh_configs["ADDRESS"], + username=ssh_configs["USER"], + password=ssh_configs["PASSWORD"], + remote_bind_address=("127.0.0.1", 3306), + ) + override_configs = { + "host": "127.0.0.1", + "port": tunnel_connection.local_bind_address[1], + } + self.db_controller = self.configuration.get_db_controller( + name="klatchat_2222", override_args=override_configs + ) - simple_query = """SELECT name, created, last_updated_cid,value from shoutbox_cache;""" + simple_query = ( + """SELECT name, created, last_updated_cid,value from shoutbox_cache;""" + ) result = self.db_controller.exec_query(query=simple_query) self.assertIsNotNone(result) def test_simple_interaction_mongo(self): - self.db_controller = self.configuration.get_db_controller(name='pyklatchat_3333') + self.db_controller = self.configuration.get_db_controller( + name="pyklatchat_3333" + ) self.assertIsNotNone(self.db_controller) test_data = {"name": "John", "address": "Highway 37"} - self.db_controller.exec_query(query={'command': 'insert_one', - 'document': 'test', - 'data': test_data}) - inserted_data = self.db_controller.exec_query(query={'command': 'find_one', - 'document': 'test', - 'data': test_data}) - LOG.debug(f'Received inserted data: {inserted_data}') + self.db_controller.exec_query( + query={"command": "insert_one", "document": "test", "data": test_data} + ) + inserted_data = 
self.db_controller.exec_query( + query={"command": "find_one", "document": "test", "data": test_data} + ) + LOG.debug(f"Received inserted data: {inserted_data}") self.assertIsNotNone(inserted_data) self.assertIsInstance(inserted_data, dict) - self.db_controller.exec_query(query={'command': 'delete_many', - 'document': 'test', - 'data': test_data}) + self.db_controller.exec_query( + query={"command": "delete_many", "document": "test", "data": test_data} + ) def test_simple_interaction_mongo_new_design(self): - self.db_controller = self.configuration.get_db_controller(name='pyklatchat_3333') + self.db_controller = self.configuration.get_db_controller( + name="pyklatchat_3333" + ) self.assertIsNotNone(self.db_controller) test_data = {"name": "John", "address": "Highway 37"} - self.db_controller.exec_query(MongoQuery(command=MongoCommands.INSERT_ONE, - document=MongoDocuments.TEST, - data=test_data)) - inserted_data = self.db_controller.exec_query(MongoQuery(command=MongoCommands.FIND_ONE, - document=MongoDocuments.TEST, - filters=[MongoFilter(key='name', value='John'), - MongoFilter(key='address', - value='Highway 37')])) - LOG.debug(f'Received inserted data: {inserted_data}') + self.db_controller.exec_query( + MongoQuery( + command=MongoCommands.INSERT_ONE, + document=MongoDocuments.TEST, + data=test_data, + ) + ) + inserted_data = self.db_controller.exec_query( + MongoQuery( + command=MongoCommands.FIND_ONE, + document=MongoDocuments.TEST, + filters=[ + MongoFilter(key="name", value="John"), + MongoFilter(key="address", value="Highway 37"), + ], + ) + ) + LOG.debug(f"Received inserted data: {inserted_data}") self.assertIsNotNone(inserted_data) self.assertIsInstance(inserted_data, dict) - self.db_controller.exec_query(MongoQuery(command=MongoCommands.DELETE_MANY, - document=MongoDocuments.TEST, - filters=[MongoFilter(key='name', value='John'), - MongoFilter(key='address', value='Highway 37')])) - - + self.db_controller.exec_query( + MongoQuery( + 
command=MongoCommands.DELETE_MANY, + document=MongoDocuments.TEST, + filters=[ + MongoFilter(key="name", value="John"), + MongoFilter(key="address", value="Highway 37"), + ], + ) + ) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..718d1b00 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,27 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/utils/common.py b/utils/common.py index 0c24ea78..42a35c62 100644 --- a/utils/common.py +++ b/utils/common.py @@ -33,33 +33,33 @@ def generate_uuid(length=10) -> str: """ - Generates UUID string of desired length + Generates UUID string of desired length - :param length: length of the output UUID string + :param length: length of the output UUID string - :returns UUID string of the desired length + :returns UUID string of the desired length """ return uuid4().hex[:length] -def get_hash(input_str: str, encoding='utf-8', algo='sha512') -> str: +def get_hash(input_str: str, encoding="utf-8", algo="sha512") -> str: """ - Returns hashed version of input string corresponding to specified algorithm + Returns hashed version of input string corresponding to specified algorithm - :param input_str: input string to hash - :param encoding: encoding for string to be conformed to (defaults to UTF-8) - :param algo: hashing algorithm to use (defaults to SHA-512), - should correspond to hashlib hashing methods, - refer to: https://docs.python.org/3/library/hashlib.html + :param input_str: input string to hash + :param encoding: encoding for string to be conformed to (defaults to UTF-8) + :param algo: hashing algorithm to use (defaults to SHA-512), + should correspond to hashlib hashing methods, + refer to: https://docs.python.org/3/library/hashlib.html - :returns hashed string from the provided input + :returns hashed string from the provided 
input """ return getattr(hashlib, algo)(input_str.encode(encoding)).hexdigest() def get_version(from_path: str = None): """Gets version from provided path - :param from_path: path to get version from""" + :param from_path: path to get version from""" with open(from_path, "r", encoding="utf-8") as v: for line in v.readlines(): if line.startswith("__version__"): @@ -71,7 +71,7 @@ def get_version(from_path: str = None): def deep_merge(source: dict, destination: dict) -> dict: - """ Deeply merges source dict into destination """ + """Deeply merges source dict into destination""" for key, value in source.items(): if isinstance(value, dict): # get node or create one @@ -83,12 +83,12 @@ def deep_merge(source: dict, destination: dict) -> dict: return destination -def buffer_to_base64(b: BytesIO, encoding: str = 'utf-8') -> str: - """ Encodes buffered value to base64 string based on provided encoding""" +def buffer_to_base64(b: BytesIO, encoding: str = "utf-8") -> str: + """Encodes buffered value to base64 string based on provided encoding""" b.seek(0) return base64.b64encode(b.read()).decode(encoding) def base64_to_buffer(b64_encoded_string: str) -> BytesIO: - """ Decodes buffered value to base64 string based on provided encoding""" + """Decodes buffered value to base64 string based on provided encoding""" return BytesIO(base64.b64decode(b64_encoded_string)) diff --git a/utils/connection_utils.py b/utils/connection_utils.py index 907ca59f..abcbd88b 100644 --- a/utils/connection_utils.py +++ b/utils/connection_utils.py @@ -29,19 +29,23 @@ from sshtunnel import SSHTunnelForwarder -def create_ssh_tunnel(server_address: str, username: str, password: str = None, - private_key: str = None, - private_key_password: str = None, - remote_bind_address: tuple = ('127.0.0.1', 8080)) -> SSHTunnelForwarder: +def create_ssh_tunnel( + server_address: str, + username: str, + password: str = None, + private_key: str = None, + private_key_password: str = None, + remote_bind_address: tuple = 
("127.0.0.1", 8080), +) -> SSHTunnelForwarder: """ - Creates tunneled SSH connection to dedicated address + Creates tunneled SSH connection to dedicated address - :param server_address: ssh server address - :param username: server username - :param password: server password (mutually exclusive with :param private_key) - :param private_key: private key to server (mutually exclusive with :param password) - :param private_key_password: private key password to server (optional) - :param remote_bind_address: remote address to bind to + :param server_address: ssh server address + :param username: server username + :param password: server password (mutually exclusive with :param private_key) + :param private_key: private key to server (mutually exclusive with :param password) + :param private_key_password: private key password to server (optional) + :param remote_bind_address: remote address to bind to """ server = SSHTunnelForwarder( server_address, @@ -49,7 +53,7 @@ def create_ssh_tunnel(server_address: str, username: str, password: str = None, ssh_password=password, ssh_pkey=private_key, ssh_private_key_password=private_key_password, - remote_bind_address=remote_bind_address + remote_bind_address=remote_bind_address, ) server.start() return server diff --git a/utils/database_utils/base_connector.py b/utils/database_utils/base_connector.py index 9b4f9369..607c11af 100644 --- a/utils/database_utils/base_connector.py +++ b/utils/database_utils/base_connector.py @@ -63,9 +63,11 @@ def abort_connection(self): pass @abstractmethod - def exec_raw_query(self, query: Union[str, dict], *args, **kwargs) -> Optional[Union[list, dict]]: + def exec_raw_query( + self, query: Union[str, dict], *args, **kwargs + ) -> Optional[Union[list, dict]]: """ - Executes raw query returns result if needed - :param query: query to execute + Executes raw query returns result if needed + :param query: query to execute """ pass diff --git a/utils/database_utils/db_controller.py 
b/utils/database_utils/db_controller.py index f3a67443..30307dbc 100644 --- a/utils/database_utils/db_controller.py +++ b/utils/database_utils/db_controller.py @@ -26,20 +26,23 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from utils.database_utils.mongodb_connector import MongoDBConnector -from utils.database_utils.mysql_connector import MySQLConnector - from utils.database_utils.base_connector import DatabaseConnector, DatabaseTypes from utils.logging_utils import LOG +try: + from utils.database_utils.mysql_connector import MySQLConnector +except ModuleNotFoundError: + LOG.info("MySQL dependency was not installed") + MySQLConnector = None + class DatabaseController: """ - Database Controller class acting as a single point of any incoming db connection - Allows to encapsulate particular database type / dialect with abstract database API + Database Controller class acting as a single point of any incoming db connection + Allows to encapsulate particular database type / dialect with abstract database API """ - database_class_mapping = {'mongo': MongoDBConnector, - 'mysql': MySQLConnector} + database_class_mapping = {"mongo": MongoDBConnector, "mysql": MySQLConnector} def __init__(self, config_data: dict): self._connector = None @@ -47,32 +50,34 @@ def __init__(self, config_data: dict): @property def connector(self) -> DatabaseConnector: - """ Database connector instance """ + """Database connector instance""" return self._connector @connector.setter def connector(self, val): if self._connector: - LOG.error('DB Connection is already established - detach connector first') + LOG.error("DB Connection is already established - detach connector first") else: self._connector = val def attach_connector(self, dialect: str): """ - Creates database connector instance base on the given class + Creates database connector instance base on the given class - :param dialect: name of the dialect 
to for connection + :param dialect: name of the dialect to for connection """ db_class = self.database_class_mapping.get(dialect) if not db_class: - raise AssertionError(f'Invalid dialect provided, supported are: {list(self.database_class_mapping)}') + raise AssertionError( + f"Invalid dialect provided, supported are: {list(self.database_class_mapping)}" + ) self.connector = db_class(config_data=self.config_data) def detach_connector(self, graceful_termination_func: callable = None): """ - Drops current database connector connection + Drops current database connector connection - :param graceful_termination_func: function causing graceful termination of connector instance (optional) + :param graceful_termination_func: function causing graceful termination of connector instance (optional) """ if graceful_termination_func: graceful_termination_func(self._connector) @@ -80,17 +85,17 @@ def detach_connector(self, graceful_termination_func: callable = None): self._connector = None def exec_query(self, query, *args, **kwargs): - """ Executes query on connector's database """ + """Executes query on connector's database""" return self.connector.exec_raw_query(query=query, *args, **kwargs) def connect(self): - """ Connects attached connector """ + """Connects attached connector""" self.connector.create_connection() def disconnect(self): - """ Disconnects attached connector """ + """Disconnects attached connector""" self.connector.abort_connection() def get_type(self) -> DatabaseTypes: - """ Gets type of Database connected to given controller """ + """Gets type of Database connected to given controller""" return self.connector.database_type diff --git a/utils/database_utils/mongo_utils/__init__.py b/utils/database_utils/mongo_utils/__init__.py index 7cba9abc..ca1fe54b 100644 --- a/utils/database_utils/mongo_utils/__init__.py +++ b/utils/database_utils/mongo_utils/__init__.py @@ -26,4 +26,10 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, 
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from .structures import MongoFilter, MongoCommands, MongoQuery, MongoDocuments, MongoLogicalOperators +from .structures import ( + MongoFilter, + MongoCommands, + MongoQuery, + MongoDocuments, + MongoLogicalOperators, +) diff --git a/utils/database_utils/mongo_utils/queries/__init__.py b/utils/database_utils/mongo_utils/queries/__init__.py new file mode 100644 index 00000000..718d1b00 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/__init__.py @@ -0,0 +1,27 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/utils/database_utils/mongo_utils/queries/constants.py b/utils/database_utils/mongo_utils/queries/constants.py new file mode 100644 index 00000000..95ae07ba --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/constants.py @@ -0,0 +1,54 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from enum import Enum + + +class UserPatterns(Enum): + """Collection of user patterns used for commonly in conversations""" + + UNRECOGNIZED_USER = { + "first_name": "Deleted", + "last_name": "User", + "nickname": "deleted_user", + } + GUEST = {"first_name": "Klat", "last_name": "Guest"} + NEON = { + "first_name": "Neon", + "last_name": "AI", + "nickname": "neon", + "avatar": "neon.webp", + } + GUEST_NANO = {"first_name": "Nano", "last_name": "Guest", "tokens": []} + + +class ConversationSkins: + """List of supported conversation skins""" + + BASE = "base" + PROMPTS = "prompts" diff --git a/utils/database_utils/mongo_utils/queries/dao/__init__.py b/utils/database_utils/mongo_utils/queries/dao/__init__.py new file mode 100644 index 00000000..718d1b00 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/__init__.py @@ -0,0 +1,27 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. 
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/utils/database_utils/mongo_utils/queries/dao/abc.py b/utils/database_utils/mongo_utils/queries/dao/abc.py new file mode 100644 index 00000000..ded0bcf9 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/abc.py @@ -0,0 +1,224 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
from abc import ABC, abstractmethod

import pymongo
from neon_sftp import NeonSFTPConnector

from utils.database_utils import DatabaseController
from utils.database_utils.mongo_utils import (
    MongoQuery,
    MongoCommands,
    MongoFilter,
    MongoLogicalOperators,
)


class MongoDocumentDAO(ABC):
    """
    Base data access object bound to a single Mongo document

    Subclasses provide :attr:`document`; every query built here is executed
    through the supplied database controller against that document.
    """

    def __init__(
        self,
        db_controller: DatabaseController,
        sftp_connector: NeonSFTPConnector = None,
    ):
        """
        :param db_controller: controller used to execute the built queries
        :param sftp_connector: SFTP connector kept for subclasses that need file access (optional)
        """
        self.db_controller = db_controller
        self.sftp_connector = sftp_connector

    @property
    @abstractmethod
    def document(self):
        """Document (collection) the DAO operates on"""
        pass

    def list_contains(
        self,
        key: str = "_id",
        source_set: list = None,
        aggregate_result: bool = True,
        *args,
        **kwargs
    ) -> dict[str, list] | list[str]:
        """
        Lists items that are members of :param source_set under the :param key
        :param key: attribute to query
        :param source_set: collection of values to lookup
        :param aggregate_result: to apply aggregation by key on result (defaults to True)
        :return matching items (mapping key -> items if aggregated, plain listing otherwise)
        """
        contains_filter = self._build_contains_filter(key=key, lookup_set=source_set)
        if not contains_filter:
            # Nothing to look up: return an empty result of the requested shape
            return {} if aggregate_result else []
        # "filters" may be passed explicitly as None; a copy keeps the caller's list untouched
        filters = list(kwargs.pop("filters", None) or []) + [contains_filter]
        items = self.list_items(filters=filters, *args, **kwargs)
        if aggregate_result:
            items = self.aggregate_items_by_key(key=key, items=items)
        return items

    def list_items(
        self,
        filters: list[MongoFilter] = None,
        limit: int = None,
        ordering_expression: dict[str, int] | None = None,
        result_as_cursor: bool = True,
    ) -> dict:
        """
        Lists items under provided document belonging to source set of provided column values

        :param filters: filters to consider (optional)
        :param limit: limit number of returned attributes (optional)
        :param ordering_expression: items ordering expression, attribute -> direction
            where -1 means descending and any other value ascending (optional)
        :param result_as_cursor: to return result as cursor (defaults to True)
        :returns results of FIND operation over the desired document according to applied filters
        """
        result_filters = {}
        if limit:
            result_filters["limit"] = limit
        if ordering_expression:
            result_filters["sort"] = [
                (attr, pymongo.DESCENDING if order == -1 else pymongo.ASCENDING)
                for attr, order in ordering_expression.items()
            ]
        return self._execute_query(
            command=MongoCommands.FIND_ALL,
            filters=filters,
            result_filters=result_filters,
            result_as_cursor=result_as_cursor,
        )

    def aggregate_items_by_key(self, key: str, items: list[dict]) -> dict:
        """
        Aggregates list of dictionaries according to the provided key

        The key is popped from every item; items with a missing or falsy key are skipped.

        :param key: attribute to group by
        :param items: items to aggregate
        :return dictionary mapping id -> list of matching items
        """
        aggregated_data = {}
        # TODO: consider Mongo DB aggregation API
        for item in items:
            items_key = item.pop(key, None)
            if items_key:
                aggregated_data.setdefault(items_key, []).append(item)
        return aggregated_data

    def _build_list_items_filter(
        self, key, lookup_set, additional_filters: list[MongoFilter]
    ) -> list[MongoFilter] | None:
        """Combines additional filters with the "contains" filter for provided key (if any)"""
        # Copy so the caller's list is not extended as a side effect
        mongo_filters = list(additional_filters or [])
        contains_filter = self._build_contains_filter(key=key, lookup_set=lookup_set)
        if contains_filter:
            mongo_filters.append(contains_filter)
        return mongo_filters

    def _build_contains_filter(self, key, lookup_set) -> MongoFilter | None:
        """Builds IN filter over de-duplicated lookup set; None if key or lookup set is empty"""
        if not (key and lookup_set):
            return None
        return MongoFilter(
            key=key,
            value=list(set(lookup_set)),
            logical_operator=MongoLogicalOperators.IN,
        )

    def add_item(self, data: dict) -> bool:
        """Inserts provided data into the object's document"""
        return self._execute_query(command=MongoCommands.INSERT_ONE, data=data)

    def update_item(
        self, filters: list[dict | MongoFilter], data: dict, data_action: str = "set"
    ) -> bool:
        """Updates provided data into the object's document"""
        return self._execute_query(
            command=MongoCommands.UPDATE_ONE,
            filters=filters,
            data=data,
            data_action=data_action,
        )

    def update_items(
        self, filters: list[dict | MongoFilter], data: dict, data_action: str = "set"
    ) -> bool:
        """Updates provided data into the object's documents"""
        return self._execute_query(
            command=MongoCommands.UPDATE_MANY,
            filters=filters,
            data=data,
            data_action=data_action,
        )

    def get_item(
        self, item_id: str = None, filters: list[dict | MongoFilter] = None
    ) -> dict | None:
        """
        Gets single item matching provided id and/or filters

        :param item_id: value of "_id" to match (optional)
        :param filters: filters to consider (optional)
        :returns result of FIND_ONE query, None if no selection criteria were provided
        """
        filters = self._build_item_selection_filters(item_id=item_id, filters=filters)
        if not filters:
            return
        return self._execute_query(command=MongoCommands.FIND_ONE, filters=filters)

    def delete_item(
        self, item_id: str = None, filters: list[dict | MongoFilter] = None
    ) -> None:
        """
        Deletes single item matching provided id and/or filters

        :param item_id: value of "_id" to match (optional)
        :param filters: filters to consider (optional)
        :raises RuntimeError: if no selection criteria were provided
        """
        filters = self._build_item_selection_filters(item_id=item_id, filters=filters)
        if not filters:
            # A bare "raise" here surfaced as an opaque RuntimeError;
            # same exception type is kept, now with an explicit reason
            raise RuntimeError(
                "Refusing to delete item: neither item_id nor filters were provided"
            )
        return self._execute_query(command=MongoCommands.DELETE_ONE, filters=filters)

    def _build_item_selection_filters(
        self, item_id: str = None, filters: list[dict | MongoFilter] = None
    ) -> list[dict | MongoFilter] | None:
        """Merges provided filters with "_id" filter for item id (if any)"""
        if not filters:
            filters = []
        if item_id:
            # Work on a fresh list so the caller's filters are never mutated
            filters = list(filters) if isinstance(filters, list) else [filters]
            filters.append(MongoFilter(key="_id", value=item_id))
        return filters

    def _execute_query(
        self,
        command: MongoCommands,
        filters: list[MongoFilter] = None,
        data: dict = None,
        data_action: str = "set",
        result_filters: dict = None,
        result_as_cursor: bool = True,
        *args,
        **kwargs
    ):
        """
        Builds MongoQuery over the object's document and executes it via db controller

        :param command: command to execute
        :param filters: query filters (optional)
        :param data: query data (optional)
        :param data_action: action to apply to data (defaults to "set")
        :param result_filters: filters applied to query result, e.g. "limit" or "sort" (optional)
        :param result_as_cursor: forwarded to the controller as "as_cursor" (defaults to True)
        :returns whatever the controller's exec_query returns
        """
        return self.db_controller.exec_query(
            MongoQuery(
                command=command,
                document=self.document,
                filters=filters,
                data=data,
                data_action=data_action,
                result_filters=result_filters,
            ),
            as_cursor=result_as_cursor,
            *args,
            **kwargs
        )
import re
from typing import Union, List

from bson import ObjectId

from utils.database_utils.mongo_utils import (
    MongoDocuments,
    MongoCommands,
    MongoFilter,
    MongoLogicalOperators,
)
from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO
from utils.logging_utils import LOG


class ChatsDAO(MongoDocumentDAO):
    """Data access object for chats (conversations)"""

    @property
    def document(self):
        return MongoDocuments.CHATS

    def get_conversation_data(
        self,
        search_str: Union[list, str],
        column_identifiers: List[str] = None,
        limit: int = 1,
        allow_regex_search: bool = False,
        include_private: bool = False,
        requested_user_id: str = None,
    ) -> Union[None, dict]:
        """
        Gets matching conversation data
        :param search_str: search string to lookup
        :param column_identifiers: desired column identifiers to look up
        :param limit: limit found conversations
        :param allow_regex_search: to allow search for matching entries that CONTAIN :param search_str
        :param include_private: to include private conversations (defaults to False)
        :param requested_user_id: id of the requested user (defaults to None) - used to find owned private conversations
        :returns single conversation if limit is 1 and a match was found, listing of matches otherwise
        """
        if isinstance(search_str, str):
            search_str = [search_str]
        if not column_identifiers:
            column_identifiers = ["_id", "conversation_name"]
        or_expression = []
        for _keyword in search_str:
            if _keyword is None:
                continue
            # The lookup value is computed once per keyword and kept apart from
            # the keyword itself: rebinding the keyword to the compiled pattern
            # made every following identifier search for the pattern's repr.
            if allow_regex_search:
                if not _keyword:
                    expression = ".*"
                else:
                    # Escaped so the keyword is matched literally ("contains")
                    expression = f".*{re.escape(str(_keyword))}.*"
                lookup_value = re.compile(expression, re.IGNORECASE)
            else:
                lookup_value = _keyword
            for identifier in column_identifiers:
                if identifier == "_id" and isinstance(_keyword, str):
                    # ids may be stored as ObjectId, so that form is matched as well
                    try:
                        or_expression.append({identifier: ObjectId(_keyword)})
                    except Exception as ex:
                        LOG.debug(f"Failed to add {_keyword = }| {ex = }")
                or_expression.append({identifier: lookup_value})

        chats = self.list_items(
            filters=[
                MongoFilter(
                    value=or_expression, logical_operator=MongoLogicalOperators.OR
                )
            ],
            limit=limit,
            result_as_cursor=False,
            include_private=include_private,
            # previously accepted but dropped, so owned private chats were never considered
            requested_user_id=requested_user_id,
        )
        for chat in chats:
            chat["_id"] = str(chat["_id"])
        if chats and limit == 1:
            chats = chats[0]
        return chats

    def list_items(
        self,
        filters: list[MongoFilter] = None,
        limit: int = None,
        result_as_cursor: bool = True,
        include_private: bool = False,
        requested_user_id: str = None,
    ) -> dict:
        """
        Lists chats according to provided filters and privacy constraints

        :param filters: filters to consider (optional)
        :param limit: limit number of returned chats (optional)
        :param result_as_cursor: to return result as cursor (defaults to True)
        :param include_private: to include private conversations (defaults to False)
        :param requested_user_id: id of the requested user, added to the privacy
            expression when private conversations are excluded (optional)
        :returns result of the base class listing
        """
        # Copy so the privacy filter is not appended to the caller's list
        filters = list(filters or [])
        if not include_private:
            expression = {"is_private": False}
            if requested_user_id:
                expression["user_id"] = requested_user_id
            filters.append(
                MongoFilter(
                    value=expression, logical_operator=MongoLogicalOperators.OR
                )
            )
        return super().list_items(
            filters=filters,
            limit=limit,
            result_as_cursor=result_as_cursor,
        )
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from chat_server.server_utils.exceptions import ItemNotFoundException
from utils.database_utils.mongo_utils import MongoDocuments, MongoFilter
from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO
from utils.logging_utils import LOG


class ConfigsDAO(MongoDocumentDAO):
    """Data access object for named, versioned configs"""

    @property
    def document(self):
        return MongoDocuments.CONFIGS

    @staticmethod
    def _name_version_filters(config_name: str, version: str) -> list:
        """Builds filters selecting a config by its name and version"""
        return [
            MongoFilter(key="name", value=config_name),
            MongoFilter(key="version", value=version),
        ]

    def get_by_name(self, config_name: str, version: str = "latest"):
        """
        Gets value of the config stored under provided name and version

        :param config_name: name of the config
        :param version: config version (defaults to "latest")
        :returns "value" attribute of the matching config
        :raises ItemNotFoundException: if there is no matching config
        """
        config = self.get_item(
            filters=self._name_version_filters(config_name, version)
        )
        if not config:
            LOG.error(f"Failed to get config by {config_name = }, {version = }")
            raise ItemNotFoundException
        return config.get("value")

    def update_by_name(self, config_name: str, data: dict, version: str = "latest"):
        """
        Sets value of the config stored under provided name and version

        :param config_name: name of the config
        :param data: data to store as the config "value"
        :param version: config version (defaults to "latest")
        :returns result of the update query
        """
        selection = self._name_version_filters(config_name, version)
        return self.update_item(filters=selection, data={"value": data})
from utils.database_utils.mongo_utils import MongoDocuments
from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO


class PersonasDAO(MongoDocumentDAO):
    """Data access object for personas; adds nothing beyond the base class queries"""

    @property
    def document(self):
        """Document holding personas"""
        personas_document = MongoDocuments.PERSONAS
        return personas_document
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from enum import IntEnum
from typing import List

import pymongo

from utils.database_utils.mongo_utils import (
    MongoDocuments,
    MongoCommands,
    MongoFilter,
    MongoLogicalOperators,
)
from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO
from utils.logging_utils import LOG


class PromptStates(IntEnum):
    """Prompt States"""

    IDLE = 0  # No active prompt
    RESP = 1  # Gathering responses to prompt
    DISC = 2  # Discussing responses
    VOTE = 3  # Voting on responses
    PICK = 4  # Proctor will select response
    # Bot is waiting for the proctor to ask them to respond (not participating)
    WAIT = 5


class PromptsDAO(MongoDocumentDAO):
    """Data access object for prompts"""

    @property
    def document(self):
        return MongoDocuments.PROMPTS

    def set_completed(self, prompt_id: str, prompt_context: dict):
        """
        Marks prompt as completed and stores its summary

        :param prompt_id: id of the prompt to complete
        :param prompt_context: prompt context; only "winner" and "votes_per_submind" are stored
        """
        prompt_summary_keys = ["winner", "votes_per_submind"]
        prompt_summary_agg = {
            f"data.{k}": v
            for k, v in prompt_context.items()
            if k in prompt_summary_keys
        }
        prompt_summary_agg["is_completed"] = "1"
        self._execute_query(
            command=MongoCommands.UPDATE_MANY,
            filters=MongoFilter(key="_id", value=prompt_id),
            data=prompt_summary_agg,
        )

    def get_prompts(
        self,
        cid: str,
        limit: int = 100,
        id_from: str = None,
        prompt_ids: List[str] = None,
        created_from: int = None,
    ) -> List[dict]:
        """
        Fetches prompts of the conversation, newest first

        :param cid: target conversation id
        :param limit: number of prompts to fetch
        :param id_from: prompt id to start from (only prompts created before it are returned)
        :param prompt_ids: prompt ids to fetch
        :param created_from: timestamp to filter prompts from

        :returns list of matching prompt data
        """
        filters = [MongoFilter("cid", cid)]
        if id_from:
            checkpoint_prompt = self._execute_query(
                command=MongoCommands.FIND_ONE,
                filters=MongoFilter("_id", id_from),
            )
            # An unknown checkpoint id is ignored rather than treated as an error
            if checkpoint_prompt:
                filters.append(
                    MongoFilter(
                        "created_on",
                        checkpoint_prompt["created_on"],
                        MongoLogicalOperators.LT,
                    )
                )
        if prompt_ids:
            if isinstance(prompt_ids, str):
                prompt_ids = [prompt_ids]
            filters.append(MongoFilter("_id", prompt_ids, MongoLogicalOperators.IN))
        if created_from:
            filters.append(
                MongoFilter("created_on", created_from, MongoLogicalOperators.GT)
            )
        return self._execute_query(
            command=MongoCommands.FIND_ALL,
            filters=filters,
            result_filters={
                "sort": [("created_on", pymongo.DESCENDING)],
                "limit": limit,
            },
            result_as_cursor=False,
        )

    def add_shout_to_prompt(
        self, prompt_id: str, user_id: str, message_id: str, prompt_state: PromptStates
    ) -> bool:
        """
        Attaches message of the user to the prompt under the column matching prompt state

        Nothing is stored if the prompt does not exist, is already completed,
        the state has no store properties, or the user already has an entry
        under the target column.

        :param prompt_id: id of the target prompt
        :param user_id: id of the user who sent the message
        :param message_id: id of the message to attach
        :param prompt_state: state of the prompt the message was sent in
        :returns True (regardless of whether anything was stored)
        """
        prompt = self.get_item(item_id=prompt_id)
        if not (prompt and prompt["is_completed"] == "0"):
            return True
        prompt_data = prompt.get("data", {})
        if (
            user_id not in prompt_data.get("participating_subminds", [])
            and prompt_state == PromptStates.RESP
        ):
            self._add_participant(prompt_id=prompt_id, user_id=user_id)
        prompt_state_structure = self._get_prompt_state_structure(
            prompt_state=prompt_state, user_id=user_id, message_id=message_id
        )
        if not prompt_state_structure:
            LOG.warning(
                f"Prompt State - {prompt_state.name} has no db store properties"
            )
            return True
        store_key = prompt_state_structure["key"]
        store_type = prompt_state_structure["type"]
        store_data = prompt_state_structure["data"]
        # store_key is a dotted update path ("<column>.<user_id>"); looking it up
        # as a literal key of the fetched data never matched, so duplicates were
        # never detected. The check is done against the column itself instead.
        store_column = store_key.split(".", 1)[0]
        if user_id in prompt_data.get(store_column, {}):
            LOG.error(
                f"user_id={user_id} tried to duplicate data to prompt_id={prompt_id}, store_key={store_key}"
            )
        else:
            self._execute_query(
                command=MongoCommands.UPDATE_MANY,
                filters=MongoFilter(key="_id", value=prompt_id),
                data={f"data.{store_key}": store_data},
                data_action="push" if store_type == list else "set",
            )
        return True

    def _add_participant(self, prompt_id: str, user_id: str):
        """Pushes user id to the participating subminds of the prompt"""
        return self._execute_query(
            command=MongoCommands.UPDATE_MANY,
            filters=MongoFilter(key="_id", value=prompt_id),
            data={"data.participating_subminds": user_id},
            data_action="push",
        )

    @staticmethod
    def _get_prompt_state_structure(
        prompt_state: PromptStates, user_id: str, message_id: str
    ):
        """
        Gets db store properties for the message sent in provided prompt state

        :returns dictionary with "key" (update path relative to prompt data),
            "type" (type of the stored column) and "data" (value to store);
            None if the state has no store properties
        """
        prompt_state_mapping = {
            # PromptStates.WAIT: {'key': 'participating_subminds', 'type': list},
            PromptStates.RESP: {
                "key": f"proposed_responses.{user_id}",
                "type": dict,
                "data": message_id,
            },
            PromptStates.DISC: {
                "key": f"submind_opinions.{user_id}",
                "type": dict,
                "data": message_id,
            },
            PromptStates.VOTE: {
                "key": f"votes.{user_id}",
                "type": dict,
                "data": message_id,
            },
        }
        return prompt_state_mapping.get(prompt_state)
from typing import List, Dict

from ovos_utils import LOG
from pymongo import UpdateOne

from utils.common import buffer_to_base64
from utils.database_utils.mongo_utils import (
    MongoDocuments,
    MongoCommands,
    MongoFilter,
    MongoLogicalOperators,
    MongoQuery,
)
from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO


class ShoutsDAO(MongoDocumentDAO):
    """Data access object for shouts (messages)"""

    @property
    def document(self):
        return MongoDocuments.SHOUTS

    def fetch_shouts(self, shout_ids: List[str] = None) -> List[dict]:
        """
        Fetches shout data from provided shouts list
        :param shout_ids: list of shout ids to fetch

        :returns data of the requested shouts (not aggregated)
        """
        return self.list_contains(
            source_set=shout_ids, aggregate_result=False, result_as_cursor=False
        )

    def fetch_messages_from_prompt(self, prompt: dict):
        """
        Fetches messages referenced by provided prompt

        :param prompt: prompt data; message ids are collected from the values of
            "proposed_responses", "submind_opinions" and "votes" under prompt["data"]
        :returns matching messages aggregated by id
        """
        prompt_data = prompt["data"]
        message_ids = []
        for column in (
            "proposed_responses",
            "submind_opinions",
            "votes",
        ):
            message_ids.extend(prompt_data.get(column, {}).values())
        return self.list_contains(source_set=message_ids)

    def fetch_audio_data(self, message_id: str) -> str | None:
        """
        Fetches audio data from message
        :param message_id: message id to fetch
        :returns base64 encoded audio data if any, empty string otherwise
        """
        shout_data = self.get_item(item_id=message_id)
        if not shout_data:
            LOG.warning("Requested shout does not exist")
        elif shout_data.get("is_audio") != "1":
            LOG.warning("Failed to fetch audio data from non-audio message")
        else:
            # for audio messages the message text holds the stored file name
            file_location = f'audio/{shout_data["message_text"]}'
            LOG.info(f"Fetching existing file from: {file_location}")
            fo = self.sftp_connector.get_file_object(file_location)
            if fo.getbuffer().nbytes > 0:
                return buffer_to_base64(fo)
            LOG.error(
                f"Empty buffer received while fetching audio of message id = {message_id}"
            )
        return ""

    def save_translations(self, translation_mapping: dict) -> Dict[str, List[str]]:
        """
        Saves translations in DB
        :param translation_mapping: mapping of cid to translation data:
            "shouts" (shout id -> translated text) and "lang" (defaults to "en")
        :returns dictionary containing updated shouts (those which were translated to English)
        """
        updated_shouts = {}
        for cid, shout_data in translation_mapping.items():
            translations = shout_data.get("shouts", {})
            shouts = self._execute_query(
                command=MongoCommands.FIND_ALL,
                filters=MongoFilter(
                    "_id", list(translations), MongoLogicalOperators.IN
                ),
                result_as_cursor=False,
            )
            # Index by id once instead of scanning the listing per translation;
            # setdefault keeps the first occurrence, as the scan did
            shouts_by_id = {}
            for shout in shouts:
                shouts_by_id.setdefault(shout["_id"], shout)
            bulk_update = []
            for shout_id, translation in translations.items():
                matching_instance = shouts_by_id.get(shout_id)
                if matching_instance is None:
                    # Used to fail with AttributeError and abort the whole batch
                    LOG.warning(
                        f"Skipping translation of unknown shout_id={shout_id}"
                    )
                    continue
                filter_expression = MongoFilter("_id", shout_id)
                if not matching_instance.get("translations"):
                    # make sure the parent attribute exists before nested keys are set
                    self._execute_query(
                        command=MongoCommands.UPDATE_MANY,
                        filters=filter_expression,
                        data={"translations": {}},
                    )
                # English is the default language, so it is treated as message text
                if shout_data.get("lang", "en") == "en":
                    updated_shouts.setdefault(cid, []).append(shout_id)
                    self._execute_query(
                        command=MongoCommands.UPDATE_MANY,
                        filters=filter_expression,
                        data={"message_lang": "en"},
                    )
                    bulk_update_setter = {
                        "message_text": translation,
                        "message_lang": "en",
                    }
                else:
                    bulk_update_setter = {
                        f'translations.{shout_data["lang"]}': translation
                    }
                # TODO: make a convenience wrapper to make bulk insertion easier to follow
                bulk_update.append(
                    UpdateOne({"_id": shout_id}, {"$set": bulk_update_setter})
                )
            if bulk_update:
                self._execute_query(
                    command=MongoCommands.BULK_WRITE,
                    data=bulk_update,
                )
        return updated_shouts

    def save_tts_response(
        self, shout_id, audio_data: str, lang: str = "en", gender: str = "female"
    ) -> bool:
        """
        Saves TTS Response under corresponding shout id

        :param shout_id: message id to consider
        :param audio_data: base64 encoded audio data received
        :param lang: language of speech (defaults to English)
        :param gender: language gender (defaults to female)

        :return bool if saving was successful
        """
        audio_file_name = f"{shout_id}_{lang}_{gender}.wav"
        try:
            self.sftp_connector.put_file_object(
                file_object=audio_data, save_to=f"audio/{audio_file_name}"
            )
            self._execute_query(
                command=MongoCommands.UPDATE_MANY,
                filters=MongoFilter("_id", shout_id),
                data={f"audio.{lang}.{gender}": audio_file_name},
            )
        except Exception as ex:
            # best effort: failure is reported to the caller via return value
            LOG.error(f"Failed to save TTS response to db - {ex}")
            return False
        return True

    def save_stt_response(self, shout_id, message_text: str, lang: str = "en"):
        """
        Saves STT Response under corresponding shout id

        :param shout_id: message id to consider
        :param message_text: STT result transcript
        :param lang: language of speech (defaults to English)
        """
        try:
            self._execute_query(
                command=MongoCommands.UPDATE_MANY,
                filters=MongoFilter("_id", shout_id),
                data={f"transcripts.{lang}": message_text},
            )
        except Exception as ex:
            LOG.error(f"Failed to save STT response to db - {ex}")
a/utils/database_utils/mongo_utils/queries/dao/users.py b/utils/database_utils/mongo_utils/queries/dao/users.py new file mode 100644 index 00000000..d96e07c1 --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/dao/users.py @@ -0,0 +1,206 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
import copy
from time import time
from typing import Union

from utils.common import generate_uuid, get_hash
from utils.logging_utils import LOG
from utils.database_utils.mongo_utils import (
    MongoCommands,
    MongoDocuments,
    MongoQuery,
    MongoFilter,
)
from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO
from utils.database_utils.mongo_utils.queries.constants import UserPatterns


class UsersDAO(MongoDocumentDAO):
    """Data access object for users"""

    # Template only: it is copied before being attached to a user
    _default_user_preferences = {"tts": {}, "chat_language_mapping": {}}

    @property
    def document(self):
        return MongoDocuments.USERS

    def get_user(self, user_id=None, nickname=None) -> Union[dict, None]:
        """
        Gets user data based on provided params

        Users without preferences get the default ones, which are stored as well.

        :param user_id: target user id
        :param nickname: target user nickname
        :returns matching user data, None if no criteria were provided
        """
        if not (user_id or nickname):
            LOG.warning("Neither user_id nor nickname was provided")
            return
        filter_data = {}
        if user_id:
            filter_data["_id"] = user_id
        if nickname:
            filter_data["nickname"] = nickname
        user = self.get_item(filters=filter_data)
        if user and not user.get("preferences"):
            # Copy, so changes to one user's preferences never leak into the shared defaults
            user["preferences"] = copy.deepcopy(self._default_user_preferences)
            # The id of the fetched user is used: the user_id argument is empty
            # for lookups by nickname, which made this update a silent no-op
            self.set_preferences(
                user_id=user.get("_id", user_id),
                preferences_mapping=user["preferences"],
            )
        return user

    def fetch_users_from_prompt(self, prompt: dict) -> dict[str, list]:
        """
        Fetches users participating in provided prompt

        :param prompt: prompt data; user ids are read from "participating_subminds" under prompt["data"]
        :returns matching users aggregated by id
        """
        prompt_data = prompt["data"]
        user_ids = prompt_data.get("participating_subminds", [])
        return self.list_contains(source_set=user_ids)

    @staticmethod
    def create_from_pattern(
        source: UserPatterns, override_defaults: dict = None
    ) -> dict:
        """
        Creates user record based on provided pattern from UserPatterns

        :param source: source pattern from UserPatterns
        :param override_defaults: to override default values (optional)
        :returns user data populated with default values where necessary
        """
        if not override_defaults:
            override_defaults = {}

        # pattern value is copied so the shared pattern is never modified
        matching_data = {**copy.deepcopy(source.value), **override_defaults}

        matching_data.setdefault("_id", generate_uuid(length=20))
        matching_data.setdefault("password", get_hash(generate_uuid()))
        matching_data.setdefault("date_created", int(time()))
        matching_data.setdefault("is_tmp", True)

        return matching_data

    def get_neon_data(self, skill_name: str = "neon") -> dict:
        """
        Gets a user profile for the user 'Neon' and adds it to the users db if not already present

        :param skill_name: Neon Skill to consider (defaults to neon - Neon Assistant)

        :return Neon AI data
        """
        neon_data = self.get_user(nickname=skill_name)
        if not neon_data:
            neon_data = self._register_neon_skill_user(skill_name=skill_name)
        return neon_data

    def _register_neon_skill_user(self, skill_name: str):
        """Creates and stores user record for provided Neon skill based on the NEON pattern"""
        last_name = "AI" if skill_name == "neon" else skill_name.capitalize()
        nickname = skill_name
        neon_data = self.create_from_pattern(
            source=UserPatterns.NEON,
            override_defaults={"last_name": last_name, "nickname": nickname},
        )
        self.add_item(data=neon_data)
        return neon_data

    def get_bot_data(self, user_id: str, context: dict = None) -> dict:
        """
        Gets a user profile for the requested bot instance and adds it to the users db if not already present

        :param user_id: user id of the bot provided; the part before the first "-" is used as nickname
        :param context: context with additional bot information (optional)

        :return Matching bot data
        """
        if not context:
            context = {}
        nickname = user_id.split("-")[0]
        bot_data = self.get_user(nickname=nickname)
        if not bot_data:
            bot_data = self._create_bot(nickname=nickname, context=context)
        elif bot_data.get("is_bot") != "1":
            self._execute_query(
                command=MongoCommands.UPDATE_MANY,
                filters=MongoFilter("_id", bot_data["_id"]),
                data={"is_bot": "1"},
            )
            # keep the returned record in line with what was just stored
            bot_data["is_bot"] = "1"
        return bot_data

    def _create_bot(self, nickname: str, context: dict) -> dict:
        """Creates and stores bot user record; "last_name" and "avatar" are taken from context if present"""
        bot_data = dict(
            _id=generate_uuid(length=20),
            first_name="Bot",
            last_name=context.get("last_name", nickname.capitalize()),
            avatar=context.get("avatar", ""),
            password=get_hash(generate_uuid()),
            nickname=nickname,
            is_bot="1",
            full_nickname=nickname,  # we treat each bot instance with equal nickname as same instance
            date_created=int(time()),
            is_tmp=False,
        )
        self.add_item(data=bot_data)
        return bot_data

    def set_preferences(self, user_id, preferences_mapping: dict):
        """
        Sets user preferences for specified user according to preferences mapping

        Does nothing if user id or mapping is empty; failures are logged, not raised.
        """
        if user_id and preferences_mapping:
            try:
                update_mapping = {
                    f"preferences.{key}": val
                    for key, val in preferences_mapping.items()
                }
                self._execute_query(
                    command=MongoCommands.UPDATE_MANY,
                    filters=MongoFilter("_id", user_id),
                    data=update_mapping,
                )
            except Exception as ex:
                LOG.error(f"Failed to update preferences for user_id={user_id} - {ex}")

    def create_guest(self, nano_token: str = None) -> dict:
        """
        Creates and stores guest (unauthorized) user

        :param nano_token: nano token to append to user on creation

        :returns: generated user data
        """
        guest_nickname = f"guest_{generate_uuid(length=8)}"

        if nano_token:
            new_user = self.create_from_pattern(
                source=UserPatterns.GUEST_NANO,
                override_defaults=dict(nickname=guest_nickname, tokens=[nano_token]),
            )
        else:
            new_user = self.create_from_pattern(
                source=UserPatterns.GUEST,
                override_defaults=dict(nickname=guest_nickname),
            )
        # TODO: consider adding partial TTL index for guest users
        # https://www.mongodb.com/docs/manual/core/index-ttl/
        self.add_item(data=new_user)
        return new_user
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from time import time +from typing import List, Tuple + +from ..structures import MongoFilter +from .constants import UserPatterns, ConversationSkins +from .wrapper import MongoDocumentsAPI +from utils.logging_utils import LOG + + +def get_translations(translation_mapping: dict) -> Tuple[dict, dict]: + """ + Gets translation from db based on provided mapping + + :param translation_mapping: mapping of cid to desired translation language + + :return translations fetched from db + """ + populated_translations = {} + missing_translations = {} + for cid, cid_data in translation_mapping.items(): + lang = cid_data.get("lang", "en") + shout_ids = cid_data.get("shouts", []) + conversation_data = MongoDocumentsAPI.CHATS.get_conversation_data( + search_str=cid + ) + if not conversation_data: + LOG.error(f"Failed to fetch conversation data - {cid}") + continue + shout_data = fetch_shout_data( + conversation_data=conversation_data, + shout_ids=shout_ids, + fetch_senders=False, + ) + shout_lang = "en" + if len(shout_data) == 1: + shout_lang = shout_data[0].get("message_lang", "en") + for shout in shout_data: + message_text = shout.get("message_text") + if shout_lang != "en" and lang == "en": + shout_text = message_text + else: + shout_text = shout.get("translations", {}).get(lang) + if shout_text and lang != "en": + populated_translations.setdefault(cid, {}).setdefault("shouts", {})[ + shout["_id"] + ] = shout_text + elif message_text: + 
missing_translations.setdefault(cid, {}).setdefault("shouts", {})[ + shout["_id"] + ] = message_text + if missing_translations.get(cid): + missing_translations[cid]["lang"] = lang + missing_translations[cid]["source_lang"] = shout_lang + return populated_translations, missing_translations + + +def fetch_message_data( + skin: ConversationSkins, + conversation_data: dict, + limit: int = 100, + fetch_senders: bool = True, + creation_time_filter: MongoFilter = None, +) -> list[dict]: + """Fetches message data based on provided conversation skin""" + message_data = fetch_shout_data( + conversation_data=conversation_data, + fetch_senders=fetch_senders, + limit=limit, + creation_time_filter=creation_time_filter, + ) + for message in message_data: + message["message_type"] = "plain" + if skin == ConversationSkins.PROMPTS: + detected_prompts = list( + set(item.get("prompt_id") for item in message_data if item.get("prompt_id")) + ) + prompt_data = fetch_prompt_data( + cid=conversation_data["_id"], prompt_ids=detected_prompts + ) + if prompt_data: + detected_prompt_ids = [] + for prompt in prompt_data: + prompt["message_type"] = "prompt" + detected_prompt_ids.append(prompt["_id"]) + message_data = [ + message + for message in message_data + if message.get("prompt_id") not in detected_prompt_ids + ] + message_data.extend(prompt_data) + return sorted(message_data, key=lambda shout: int(shout["created_on"])) + + +def fetch_shout_data( + conversation_data: dict, + limit: int = 100, + fetch_senders: bool = True, + creation_time_filter: MongoFilter = None, + shout_ids: list = None, +): + query_filters = [MongoFilter(key="cid", value=conversation_data["_id"])] + if creation_time_filter: + query_filters.append(creation_time_filter) + if shout_ids: + shouts = MongoDocumentsAPI.SHOUTS.list_contains( + source_set=shout_ids, + aggregate_result=False, + result_as_cursor=False, + filters=query_filters, + limit=limit, + ordering_expression={"created_on": -1}, + ) + else: + shouts = 
MongoDocumentsAPI.SHOUTS.list_items( + filters=query_filters, + limit=limit, + ordering_expression={"created_on": -1}, + result_as_cursor=False, + ) + if shouts and fetch_senders: + shouts = _attach_senders_data(shouts=shouts) + return sorted(shouts, key=lambda user_shout: int(user_shout["created_on"])) + + +def _attach_senders_data(shouts: list[dict]): + result = list() + users_from_shouts = MongoDocumentsAPI.USERS.list_contains( + source_set=[shout["user_id"] for shout in shouts] + ) + for shout in shouts: + matching_user = users_from_shouts.get(shout["user_id"], {}) + if not matching_user: + matching_user = MongoDocumentsAPI.USERS.create_from_pattern( + UserPatterns.UNRECOGNIZED_USER + ) + else: + matching_user = matching_user[0] + matching_user.pop("password", None) + matching_user.pop("is_tmp", None) + shout["message_id"] = shout["_id"] + shout_data = {**shout, **matching_user} + result.append(shout_data) + return result + + +def fetch_prompt_data( + cid: str, + limit: int = 100, + id_from: str = None, + prompt_ids: List[str] = None, + fetch_user_data: bool = False, + created_from: int = None, +) -> List[dict]: + """ + Fetches prompt data out of conversation data + + :param cid: target conversation id + :param limit: number of prompts to fetch + :param id_from: prompt id to start from + :param prompt_ids: prompt ids to fetch + :param fetch_user_data: to fetch user data in the + :param created_from: timestamp to filter messages from + + :returns list of matching prompt data along with matching messages and users + """ + matching_prompts = MongoDocumentsAPI.PROMPTS.get_prompts( + cid=cid, + limit=limit, + id_from=id_from, + prompt_ids=prompt_ids, + created_from=created_from, + ) + for prompt in matching_prompts: + prompt["user_mapping"] = MongoDocumentsAPI.USERS.fetch_users_from_prompt(prompt) + prompt["message_mapping"] = MongoDocumentsAPI.SHOUTS.fetch_messages_from_prompt( + prompt + ) + if fetch_user_data: + for user in prompt.get("data", 
{}).get("participating_subminds", []): + try: + nick = prompt["user_mapping"][user][0]["nickname"] + except KeyError: + LOG.warning( + f'user_id - "{user}" was not detected setting it as nick' + ) + nick = user + for k in ( + "proposed_responses", + "submind_opinions", + "votes", + ): + msg_id = prompt["data"][k].pop(user, "") + if msg_id: + prompt["data"][k][nick] = ( + prompt["message_mapping"] + .get(msg_id, [{}])[0] + .get("message_text") + or msg_id + ) + prompt["data"]["participating_subminds"] = [ + prompt["user_mapping"][x][0]["nickname"] + for x in prompt["data"]["participating_subminds"] + ] + return sorted(matching_prompts, key=lambda _prompt: int(_prompt["created_on"])) + + +def add_shout(data: dict): + """Records shout data and pushes its id to the relevant conversation flow""" + MongoDocumentsAPI.SHOUTS.add_item(data=data) + MongoDocumentsAPI.CHATS.update_item( + filters=MongoFilter(key="_id", value=data["cid"]), + data={"last_shout_ts": int(time())}, + ) diff --git a/utils/database_utils/mongo_utils/queries/wrapper.py b/utils/database_utils/mongo_utils/queries/wrapper.py new file mode 100644 index 00000000..26a04edc --- /dev/null +++ b/utils/database_utils/mongo_utils/queries/wrapper.py @@ -0,0 +1,72 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. +# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +# DAO Imports +from utils.database_utils.mongo_utils.queries.dao.abc import MongoDocumentDAO +from utils.database_utils.mongo_utils.queries.dao.configs import ConfigsDAO +from utils.database_utils.mongo_utils.queries.dao.users import UsersDAO +from utils.database_utils.mongo_utils.queries.dao.chats import ChatsDAO +from utils.database_utils.mongo_utils.queries.dao.shouts import ShoutsDAO +from utils.database_utils.mongo_utils.queries.dao.prompts import PromptsDAO +from utils.database_utils.mongo_utils.queries.dao.personas import PersonasDAO + + +class MongoDAOGateway(type): + def __getattribute__(self, name): + item = super().__getattribute__(name) + try: + if issubclass(item, MongoDocumentDAO): + item = item( + db_controller=self.db_controller, sftp_connector=self.sftp_connector + ) + except: + pass + return item + + +class MongoDocumentsAPI(metaclass=MongoDAOGateway): + """ + Wrapper for DB commands execution + If getting attribute is triggered, initialises relevant instance of DAO handler and returns it + """ + + db_controller = None + sftp_connector = None + + USERS = UsersDAO + CHATS = ChatsDAO + SHOUTS = ShoutsDAO + PROMPTS = PromptsDAO + PERSONAS = PersonasDAO + CONFIGS = ConfigsDAO + + @classmethod + def init(cls, db_controller, sftp_connector=None): + """Inits Singleton with specified database controller""" + cls.db_controller = db_controller + cls.sftp_connector = sftp_connector diff --git a/utils/database_utils/mongo_utils/structures.py b/utils/database_utils/mongo_utils/structures.py index a6dff7ed..b14bb1b5 100644 --- a/utils/database_utils/mongo_utils/structures.py +++ b/utils/database_utils/mongo_utils/structures.py @@ -33,82 +33,101 @@ class MongoCommands(Enum): - """Enumeration of possible commands supported by MongoDB API """ + """Enumeration of possible commands supported by MongoDB API""" + # Selection Operations - FIND = 'find' - FIND_ALL = 'find' - FIND_ONE = 'find_one' + FIND = "find" + FIND_ALL = "find" + FIND_ONE = "find_one" # Insertion 
operations ## Basic Insertion - INSERT = 'insert_many' - INSERT_ONE = 'insert_one' - INSERT_MANY = 'insert_many' + INSERT = "insert_many" + INSERT_ONE = "insert_one" + INSERT_MANY = "insert_many" ## Bulk Write - BULK_WRITE = 'bulk_write' + BULK_WRITE = "bulk_write" # Deletion Operations - DELETE = 'delete_many' - DELETE_ONE = 'delete_one' - DELETE_MANY = 'delete_many' + DELETE = "delete_many" + DELETE_ONE = "delete_one" + DELETE_MANY = "delete_many" # Update operation - UPDATE = 'update' + UPDATE = "update_many" + UPDATE_MANY = "update_many" + UPDATE_ONE = "update_one" class MongoDocuments(Enum): - """ Supported Mongo DB documents """ - USERS = 'users' - USER_PREFERENCES = 'user_preferences' - CHATS = 'chats' - SHOUTS = 'shouts' - PROMPTS = 'prompts' - TEST = 'test' + """Supported Mongo DB documents""" + + USERS = "users" + USER_PREFERENCES = "user_preferences" + CHATS = "chats" + SHOUTS = "shouts" + PROMPTS = "prompts" + PERSONAS = "personas" + CONFIGS = "configs" + TEST = "test" class MongoLogicalOperators(Enum): - """ Enumeration of supported logical operators""" - EQ = 'equal' - LT = 'lt' - LTE = 'lte' - GT = 'gt' - GTE = 'gte' - IN = 'in' - ALL = 'all' - ANY = 'any' - OR = 'or' - AND = 'and' + """Enumeration of supported logical operators""" + + EQ = "equal" + LT = "lt" + LTE = "lte" + GT = "gt" + GTE = "gte" + IN = "in" + ALL = "all" + OR = "or" + AND = "and" @dataclass class MongoFilter: - """ Class representing logical conditions supported by Mongo""" - key: str = '' + """Class representing logical conditions supported by Mongo""" + + key: str = "" value: Any = None logical_operator: MongoLogicalOperators = MongoLogicalOperators.EQ def to_dict(self): - """ Converts object to the dictionary """ + """Converts object to the dictionary""" if self.logical_operator.value == MongoLogicalOperators.EQ.value: return {self.key: self.value} - elif self.logical_operator.value in (MongoLogicalOperators.OR.value, MongoLogicalOperators.AND.value,): - return 
{f'${self.logical_operator.value}': self.value} + elif self.logical_operator.value in ( + MongoLogicalOperators.OR.value, + MongoLogicalOperators.AND.value, + ): + return {f"${self.logical_operator.value}": self.value} else: - return {self.key: {f'${self.logical_operator.value}': self.value}} + return {self.key: {f"${self.logical_operator.value}": self.value}} @dataclass class MongoQuery: - """ Object to represent Mongo Query data""" + """Object to represent Mongo Query data""" + command: MongoCommands document: MongoDocuments filters: List[Union[dict, MongoFilter]] = None data: dict = None - data_action: str = 'set' - result_filters: dict = None # To apply some filters on the resulting data e.g. limit or sort + data_action: str = "set" + result_filters: dict = ( + None # To apply some filters on the resulting data e.g. limit or sort + ) def build_filters(self): - """ Builds filters for Mongo Query """ + """Builds filters for Mongo Query""" res = {} if self.filters: - if any(isinstance(self.filters, _type) for _type in (MongoFilter, dict,)): + if any( + isinstance(self.filters, _type) + for _type in ( + MongoFilter, + dict, + ) + ): self.filters = [self.filters] for condition in self.filters: if isinstance(condition, MongoFilter): @@ -117,16 +136,22 @@ def build_filters(self): return res def build_setter(self) -> dict: - """ Builds setter for Mongo Query """ + """Builds setter for Mongo Query""" res = None - if self.command.value == MongoCommands.UPDATE.value: - res = {f'${self.data_action.lower()}': self.data} - elif self.command.value in (MongoCommands.INSERT_ONE.value, MongoCommands.BULK_WRITE.value,): + if self.command.value in ( + MongoCommands.UPDATE_MANY.value, + MongoCommands.UPDATE_ONE.value, + ): + res = {f"${self.data_action.lower()}": self.data} + elif self.command.value in ( + MongoCommands.INSERT_ONE.value, + MongoCommands.BULK_WRITE.value, + ): res = self.data return res def to_dict(self) -> dict: - """ Converts object to dictionary """ + """Converts 
object to dictionary""" data = list() filters = self.build_filters() if filters: @@ -135,10 +160,12 @@ def to_dict(self) -> dict: if setter: data.append(setter) data = tuple(data) - res = dict(document=self.document.value, - command=self.command.value, - data=data, - filters=self.result_filters) + res = dict( + document=self.document.value, + command=self.command.value, + data=data, + filters=self.result_filters, + ) if self.command.value == MongoCommands.INSERT_MANY.value: - res['documents'] = res.pop('data') + res["documents"] = res.pop("data") return res diff --git a/utils/database_utils/mongo_utils/user_utils.py b/utils/database_utils/mongo_utils/user_utils.py index a4d31f55..76b1bf5c 100644 --- a/utils/database_utils/mongo_utils/user_utils.py +++ b/utils/database_utils/mongo_utils/user_utils.py @@ -31,15 +31,18 @@ def get_existing_nicks_to_id(mongo_controller) -> dict: """ - Gets existing nicknames to id mapping from provided mongo db + Gets existing nicknames to id mapping from provided mongo db - :param mongo_controller: controller to active mongo collection + :param mongo_controller: controller to active mongo collection - :returns List of dict containing filtered items + :returns List of dict containing filtered items """ - retrieved_data = list(mongo_controller.exec_query(query=dict(document='users', command='find', data={}))) + retrieved_data = list( + mongo_controller.exec_query( + query=dict(document="users", command="find", data={}) + ) + ) - LOG.info(f'Retrieved {len(retrieved_data)} existing nicknames from new db') - - return {record['nickname']: record['_id'] for record in list(retrieved_data)} + LOG.info(f"Retrieved {len(retrieved_data)} existing nicknames from new db") + return {record["nickname"]: record["_id"] for record in list(retrieved_data)} diff --git a/utils/database_utils/mongodb_connector.py b/utils/database_utils/mongodb_connector.py index 6af83daa..d40cd5e8 100644 --- a/utils/database_utils/mongodb_connector.py +++ 
b/utils/database_utils/mongodb_connector.py @@ -25,57 +25,68 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + from typing import Optional, Union from pymongo import MongoClient from utils.database_utils.base_connector import DatabaseConnector, DatabaseTypes -from utils.database_utils.mongo_utils.structures import MongoQuery +from utils.database_utils.mongo_utils.structures import MongoQuery, MongoCommands +from utils.logging_utils import LOG class MongoDBConnector(DatabaseConnector): - """ Connector implementing interface for interaction with Mongo DB API """ + """Connector implementing interface for interaction with Mongo DB API""" - mongo_recognised_commands = ('insert_one', 'insert_many', 'delete_one', 'bulk_write', - 'delete_many', 'find', 'find_one', 'update') + mongo_recognised_commands = set(cmd.value for cmd in MongoCommands) @property def database_type(self) -> DatabaseTypes: return DatabaseTypes.NOSQL def create_connection(self): - database = self.config_data.pop('database') + database = self.config_data.pop("database") self._cnx = MongoClient(**self.config_data)[database] def abort_connection(self): self._cnx.close() - def exec_raw_query(self, query: Union[MongoQuery, dict], as_cursor: bool = True, *args, **kwargs) -> Optional[dict]: + def exec_raw_query( + self, query: Union[MongoQuery, dict], as_cursor: bool = True, *args, **kwargs + ) -> Optional[dict]: """ - Generic method for executing query over mongo db + Generic method for executing query over mongo db - :param query: dictionary with query instruction has to contain following parameters: - - "document": target document for query - - "command": member of the self.mongo_recognised_commands - - "data": query data, represented as a tuple of (List[dict] if bulk insert, dict otherwise) - - "filters": mapping of filters to apply in chain after the 
main command (e.g. limit or sort ) - :param as_cursor: to return query result as cursor + :param query: dictionary with query instruction has to contain following parameters: + - "document": target document for query + - "command": member of the self.mongo_recognised_commands + - "data": query data, represented as a tuple of (List[dict] if bulk insert, dict otherwise) + - "filters": mapping of filters to apply in chain after the main command (e.g. limit or sort ) + :param as_cursor: to return query result as cursor - :returns result of the query execution if any + :returns result of the query execution if any """ if isinstance(query, MongoQuery): query = query.to_dict() - received_command = query.get('command', 'find') + received_command = query.get("command", "find") if received_command not in self.mongo_recognised_commands: - raise NotImplementedError(f'Query command: {received_command} is not supported, ' - f'please use one of the following: ' - f'{self.mongo_recognised_commands}') - db_command = getattr(self.connection[query.get('document')], query.get('command')) - if not isinstance(query.get('data'), tuple): - # LOG.warning('Received wrong param type for query data, using default conversion to tuple') - query['data'] = (query.get('data', {}),) - query_output = db_command(*query.get('data'), *args, **kwargs) - if received_command == 'find': - filters = query.get('filters', {}) + raise NotImplementedError( + f"Query command: {received_command} is not supported, " + f"please use one of the following: " + f"{self.mongo_recognised_commands}" + ) + db_command = getattr(self.connection[query.get("document")], + received_command) + if not isinstance(query.get("data"), tuple): + LOG.debug(f'Casting data from {type(query["data"])} to tuple') + query["data"] = (query.get("data", {}),) + try: + query_output = db_command(*query.get("data"), *args, **kwargs) + except Exception as e: + LOG.error(f"Query failed: {query}|args={args}|kwargs={kwargs}") + raise e + + if 
received_command == "find": + filters = query.get("filters", {}) if filters: for name, value in filters.items(): query_output = getattr(query_output, name)(value) diff --git a/utils/database_utils/mysql_connector.py b/utils/database_utils/mysql_connector.py index a6e410b2..28334e3a 100644 --- a/utils/database_utils/mysql_connector.py +++ b/utils/database_utils/mysql_connector.py @@ -26,7 +26,7 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from mysql.connector import (connection) +from mysql.connector import connection from typing import Optional from utils.database_utils.base_connector import DatabaseConnector, DatabaseTypes @@ -46,13 +46,15 @@ def create_connection(self): def abort_connection(self): self._cnx.close() - def exec_raw_query(self, query: str, generator: bool = False, *args, **kwargs) -> Optional[list]: + def exec_raw_query( + self, query: str, generator: bool = False, *args, **kwargs + ) -> Optional[list]: """Executes raw string query and returns its results - :param query: valid SQL query string - :param generator: to return cursor as generator object (defaults to False) + :param query: valid SQL query string + :param generator: to return cursor as generator object (defaults to False) - :returns query result if any + :returns query result if any """ cursor = self.connection.cursor(dictionary=True) cursor.execute(query, *args, **kwargs) diff --git a/utils/exceptions.py b/utils/exceptions.py new file mode 100644 index 00000000..9f0a1c14 --- /dev/null +++ b/utils/exceptions.py @@ -0,0 +1,35 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2022 Neongecko.com Inc. 
+# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds, +# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo +# BSD-3 License +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +class KlatAPIAuthorizationError(Exception): + MESSAGE = "Failed to authenticate in klat with provided credentials" + + def __init__(self, message=None): + message = self.MESSAGE if message is None else message + super().__init__(message) diff --git a/utils/http_utils.py b/utils/http_utils.py index 5e9e6a12..b9d7f562 100644 --- a/utils/http_utils.py +++ b/utils/http_utils.py @@ -30,11 +30,14 @@ def respond(msg: str, status_code: int = 200) -> JSONResponse: """ - Sending responds with unified pattern + Sending responds with unified pattern - :param msg: message to send - :param status_code: HTTP status code + :param msg: message to send + :param status_code: HTTP status code - :returns JSON response containing provided message + :returns JSON response containing provided message """ - return JSONResponse({'msg': msg}, status_code) + return JSONResponse({"msg": msg}, status_code) + + +response_ok = respond("OK") diff --git a/utils/logging_utils.py b/utils/logging_utils.py index 0e2da6c2..653b742b 100644 --- a/utils/logging_utils.py +++ b/utils/logging_utils.py @@ -26,10 +26,26 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import importlib import logging +import os + +from ovos_utils.log import LOG as ovos_default_logger combo_lock_logger = logging.getLogger("combo_lock") combo_lock_logger.disabled = True -LOG = getattr(importlib.import_module('ovos_utils'), 'LOG') + +def _init_app_logger(): + logger = ovos_default_logger + logger.name = os.environ.get("LOG_NAME", "klat_server_log") + logger.base_path = os.environ.get("LOG_BASE_PATH", ".") + logger.init( + config={ + "level": os.environ.get("LOG_LEVEL", "INFO"), + "path": os.environ.get("LOG_PATH", os.getcwd()), + } + ) + return logger + + +LOG = _init_app_logger() diff --git a/utils/template_utils.py b/utils/template_utils.py index d44dd211..79621c1a 100644 --- a/utils/template_utils.py +++ b/utils/template_utils.py @@ -32,19 +32,23 @@ from starlette.templating import Jinja2Templates -component_templates = Jinja2Templates(directory=os.environ.get('TEMPLATES_DIR', "chat_client/templates")) +component_templates = Jinja2Templates( + directory=os.environ.get("TEMPLATES_DIR", "chat_client/templates") +) def callback_template(request: Request, template_name: str, context: dict = None): """ - Returns template response based on provided params - :param request: FastAPI request object - :param template_name: name of template to render - :param context: supportive context to add + Returns template response based on provided params + :param request: FastAPI request object + :param template_name: name of template to render + :param context: supportive context to add """ if not context: context = {} - context['request'] = request + context["request"] = request # Preventing exiting to the source code files - template_name = template_name.replace('../', '').replace('.', '/') - return component_templates.TemplateResponse(f"components/{template_name}.html", context) + template_name = template_name.replace("../", "").replace(".", "/") + return component_templates.TemplateResponse( + f"components/{template_name}.html", context + ) diff --git 
a/version.py b/version.py index 2b6d144d..9d885af4 100644 --- a/version.py +++ b/version.py @@ -26,4 +26,4 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -__version__ = "0.3.3a9" +__version__ = "0.4.5a10" diff --git a/version_bump.py b/version_bump.py index cfb9b3b3..5c56dae6 100644 --- a/version_bump.py +++ b/version_bump.py @@ -38,9 +38,9 @@ version = line.split("'")[1] if "a" not in version: - parts = version.split('.') + parts = version.split(".") parts[-1] = str(int(parts[-1]) + 1) - version = '.'.join(parts) + version = ".".join(parts) version = f"{version}a0" else: post = version.split("a")[1] @@ -49,6 +49,6 @@ for line in fileinput.input(join(dirname(__file__), "version.py"), inplace=True): if line.startswith("__version__"): - print(f"__version__ = \"{version}\"") + print(f'__version__ = "{version}"') else: - print(line.rstrip('\n')) \ No newline at end of file + print(line.rstrip("\n"))