From d9d7159d6fc6d1982ca6f4067fb04c878dda3ceb Mon Sep 17 00:00:00 2001
From: Quentin Gérôme
Date: Thu, 19 Sep 2024 14:44:53 +0200
Subject: [PATCH] feat(Storage): Filesystem-based storage backend for files

* chore(docker): Add .ruff_cache & .venv to ignore folder for docker
* chore(CI): Copy the .env.dist to .env for the tests
* feat(Storage): Implement a filesystem storage backend to store the dataset files in a mounted directory
* Fix datasets worker
* Manage proxy settings in base.py settings and docker image build args (#810)

  This allows reducing the number of env vars to pass and managing secured connections.

  Co-authored-by: Christophe Philemotte <1059797+toch@users.noreply.github.com>

* fix: enforce file size to be an integer
* chore: Adapt code based on Nazar's comments
* chore: Do not silently handle the ValueError in case of wrong value

---------

Co-authored-by: Christophe Philemotte
Co-authored-by: Christophe Philemotte <1059797+toch@users.noreply.github.com>
Co-authored-by: nazarfil
---
 .dockerignore | 3 +-
 .env.dist | 154 ++++-
 .github/workflows/build_docker_image.yml | 4 +
 .github/workflows/test.yml | 9 +-
 Dockerfile | 63 ++-
 config/settings/base.py | 242 ++++----
 config/settings/dev.py | 5 -
 config/settings/local.py | 45 ++
 config/settings/production.py | 39 ++
 config/settings/test.py | 12 +-
 config/settings_test.py | 6 -
 config/urls.py | 1 +
 docker-compose.yaml | 98 +---
 hexa/analytics/tests/test_analytics.py | 2 -
 hexa/core/views_utils.py | 2 +-
 hexa/databases/tests/test_schema.py | 2 -
 hexa/databases/tests/test_utils.py | 2 -
 hexa/datasets/api.py | 28 +-
 .../management/commands/prepare_datasets.py | 23 +
 .../0008_alter_datasetfilemetadata_options.py | 16 +
 hexa/datasets/models.py | 32 +-
 hexa/datasets/queue.py | 98 ++--
 hexa/datasets/schema/mutations.py | 6 +-
 hexa/datasets/schema/types.py | 6 +-
 hexa/datasets/tests/test_metadata.py | 224 +++-----
 hexa/datasets/tests/test_models.py | 13 +-
 hexa/datasets/tests/test_schema.py | 28 +-
 hexa/datasets/tests/testutils.py | 2 -
 hexa/files/__init__.py | 6 +
 hexa/files/api.py | 16 -
 hexa/files/apps.py | 9 +-
 hexa/files/backends/__init__.py | 12 +
 hexa/files/{basefs.py => backends/base.py} | 63 ++-
 hexa/files/backends/dummy.py | 170 ++++++
 hexa/files/backends/exceptions.py | 16 +
 hexa/files/backends/fs.py | 330 +++++++++++
 hexa/files/{ => backends}/gcp.py | 158 +++---
 hexa/files/graphql/schema.graphql | 1 +
 hexa/files/s3.py | 438 --------------
 hexa/files/schema/mutations.py | 20 +-
 hexa/files/schema/types.py | 13 +-
 hexa/files/tests/backends/__init__.py | 0
 hexa/files/tests/backends/test_fs.py | 271 +++++++++
 hexa/files/tests/backends/test_gcp.py | 59 ++
 hexa/files/tests/mocks/backend.py | 55 --
 hexa/files/tests/mocks/bucket.py | 22 +-
 hexa/files/tests/mocks/client.py | 42 +-
 hexa/files/tests/mocks/mockgcp.py | 9 -
 hexa/files/tests/test_api.py | 534 ------------------
 hexa/files/tests/test_schema.py | 5 -
 hexa/files/urls.py | 10 +
 hexa/files/views.py | 29 +-
 hexa/notebooks/tests/test_schema.py | 2 -
 .../management/commands/pipelines_runner.py | 16 +-
 hexa/pipelines/schema/mutations.py | 8 +-
 hexa/pipelines/schema/types.py | 7 +-
 .../test_schema/test_pipeline_versions.py | 2 -
 .../tests/test_schema/test_pipelines.py | 8 +-
 hexa/pipelines/tests/test_views.py | 2 -
 hexa/pipelines/utils.py | 2 +-
 hexa/user_management/tests/test_schema.py | 4 +-
 hexa/workspaces/models.py | 13 +-
 hexa/workspaces/tests/test_models.py | 23 +-
 .../tests/test_schema/test_workspace.py | 12 +-
 .../test_schema/test_workspace_connection.py | 3 -
hexa/workspaces/tests/test_views.py | 38 +- hexa/workspaces/utils.py | 4 +- hexa/workspaces/views.py | 24 +- 68 files changed, 1827 insertions(+), 1794 deletions(-) create mode 100644 config/settings/local.py delete mode 100644 config/settings_test.py create mode 100644 hexa/datasets/management/commands/prepare_datasets.py create mode 100644 hexa/datasets/migrations/0008_alter_datasetfilemetadata_options.py delete mode 100644 hexa/files/api.py create mode 100644 hexa/files/backends/__init__.py rename hexa/files/{basefs.py => backends/base.py} (59%) create mode 100644 hexa/files/backends/dummy.py create mode 100644 hexa/files/backends/exceptions.py create mode 100644 hexa/files/backends/fs.py rename hexa/files/{ => backends}/gcp.py (74%) delete mode 100644 hexa/files/s3.py create mode 100644 hexa/files/tests/backends/__init__.py create mode 100644 hexa/files/tests/backends/test_fs.py create mode 100644 hexa/files/tests/backends/test_gcp.py delete mode 100644 hexa/files/tests/mocks/backend.py delete mode 100644 hexa/files/tests/mocks/mockgcp.py delete mode 100644 hexa/files/tests/test_api.py create mode 100644 hexa/files/urls.py diff --git a/.dockerignore b/.dockerignore index 0af4d38ed..f9b617deb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,7 +3,8 @@ .idea .devcontainer .env - +.ruff_cache +.venv Dockerfile docker-compose.yml .dockerignore diff --git a/.env.dist b/.env.dist index 5107180b3..837158ae3 100644 --- a/.env.dist +++ b/.env.dist @@ -1,23 +1,149 @@ -DEBUG_LOGGING=false -DEBUG_TOOLBAR=false +# General +################### + +DEBUG=true -# Required to run with default storage mode GCP -GCS_SERVICE_ACCOUNT_KEY= +# Settings module for Django in dev +DJANGO_SETTINGS_MODULE=config.settings.dev -# To run it in AWS mode or in LocalHosting mode set the variable to s3 -# WORKSPACE_STORAGE_ENGINE=s3 -WORKSPACE_DATASETS_BUCKET= -WORKSPACE_STORAGE_ENGINE_AWS_ACCESS_KEY_ID= -WORKSPACE_STORAGE_ENGINE_AWS_SECRET_ACCESS_KEY= +# Django debugging settings +DEBUG_LOGGING=true +DEBUG_TOOLBAR=false -# Not required +# Encryption settings +SECRET_KEY="))dodw9%n)7q86l-q1by4e-2z#vonph50!%ep7_je)_=x0m2v-" +ENCRYPTION_KEY="oT7DKt8zf0vsnbBcJ0R36SHkBzbjF2agFIK3hSAVvko=" + +# Email settings EMAIL_HOST= EMAIL_PORT= EMAIL_HOST_USER= EMAIL_USE_TLS= EMAIL_HOST_PASSWORD= -# Required to run legacy part of the code +# Database settings for Django +DATABASE_HOST=db +DATABASE_PORT=5432 +DATABASE_NAME=hexa-app +DATABASE_USER=hexa-app +DATABASE_PASSWORD=hexa-app +# Database settings for Postgres +POSTGRES_DB=hexa-app +POSTGRES_USER=hexa-app +POSTGRES_PASSWORD=hexa-app + +# Networking +############ + +# To enable TLS/SSL directly on the app +# TLS="false" + +# The hostname on which the services are published / bound +BASE_HOSTNAME=localhost +# The port number to access the backend +BASE_PORT=8000 +# URL to use for the communication between pipelines, workers & the backend's API +# If not set, it falls back to BASE_HOSTNAME:BASE_PORT +INTERNAL_BASE_URL=http://app:8000 + +# NextJS Frontend +# If not set, it falls back to either PROXY_HOSTNAME_AND_PORT or +# BASE_HOSTNAME:FRONTEND_PORT +# NEW_FRONTEND_DOMAIN=http://localhost:3000 + +# Jupyter Hub +# If not set, it falls back to either PROXY_HOSTNAME_AND_PORT or +# BASE_HOSTNAME:JUPYTERHUB_PORT +# NOTEBOOKS_URL=http://localhost:8001 + +# The port number to access the frontend +FRONTEND_PORT=3000 +# The port number to access the Jupyter hub +JUPYTERHUB_PORT=8001 + +# I'd put that directly in the compose manifest file +OPENHEXA_BACKEND_URL=http://app:8000 + +# if it's 
behind a reverse proxy +# PROXY_HOSTNAME_AND_PORT=example.com +# If TLS/SSL is set up on a reverse proxy routing to the app +# TRUST_FORWARDED_PROTO="no" + +# MixPanel +########## + +# mixpanel analytics +MIXPANEL_TOKEN= + + +# Pipelines +############ + +DEFAULT_WORKSPACE_IMAGE=blsq/openhexa-base-environment:latest # Change this to the image of the workspace you want to use by default +PIPELINE_SCHEDULER_SPAWNER=docker # Change to kubernetes to use kubernetes spawner + +# Kubernetes resources settings (used only in kubernetes spawner mode +PIPELINE_DEFAULT_CONTAINER_CPU_LIMIT=2 +PIPELINE_DEFAULT_CONTAINER_MEMORY_LIMIT=4G +PIPELINE_DEFAULT_CONTAINER_CPU_REQUEST=0.05 +PIPELINE_DEFAULT_CONTAINER_MEMORY_REQUEST=100M + + +# Notebooks +############ + +NOTEBOOKS_HUB_URL=http://jupyterhub:8000/hub +HUB_API_TOKEN=cbb352d6a412e266d7494fb014dd699373645ec8d353e00c7aa9dc79ca87800d # Change this to the token of the jupyterhub service + +# Workspaces +############# + +# Workspaces' DB settings +WORKSPACES_DATABASE_HOST=db +WORKSPACES_DATABASE_PORT=5432 +WORKSPACES_DATABASE_ROLE=hexa-app +WORKSPACES_DATABASE_DEFAULT_DB=hexa-app +WORKSPACES_DATABASE_PASSWORD=hexa-app +WORKSPACES_DATABASE_PROXY_HOST=db + + +# Workspace storage options +# -------------------------- + +# Add a prefix to the bucket name (may be useful to separate dev and prod workspaces inside a shared Google Cloud Project) +WORKSPACE_BUCKET_PREFIX= + +# Local FS: Define the root location where the workspaces files will be stored +# Absolute path to the directory where the workspaces data will be stored +WORKSPACE_STORAGE_LOCATION=$WORKSPACE_STORAGE_LOCATION +# Uncomment to disable the check of the file size before uploading it to the workspace (only for local storage) +#DISABLE_UPLOAD_MAX_SIZE_CHECK=false + +## GCP: Mandatory to run with GCS +WORKSPACE_STORAGE_BACKEND_GCS_SERVICE_ACCOUNT_KEY= +# The region where the buckets will be created +# WORKSPACE_BUCKET_REGION= + +## AWS: To run it in AWS mode or in LocalHosting mode set the variable to s3 +WORKSPACE_STORAGE_BACKEND_AWS_ENDPOINT_URL= +WORKSPACE_STORAGE_BACKEND_AWS_PUBLIC_ENDPOINT_URL= +WORKSPACE_STORAGE_BACKEND_AWS_SECRET_ACCESS_KEY= +WORKSPACE_STORAGE_BACKEND_AWS_ACCESS_KEY_ID= +WORKSPACE_STORAGE_BACKEND_AWS_BUCKET_REGION= +# The region where the buckets will be created +# WORKSPACE_BUCKET_REGION= + +# Datasets +########### + +# Bucket to store datasets for all workspaces +WORKSPACE_DATASETS_BUCKET=hexa-datasets +WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE=50 + +# Legacy +######### + +# Required for the `connector_s3` django app to work AWS_USERNAME= AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= @@ -27,5 +153,7 @@ AWS_USER_ARN= AWS_APP_ROLE_ARN= AWS_PERMISSIONS_BOUNDARY_POLICY_ARN= -# mixpanel analytics -MIXPANEL_TOKEN= \ No newline at end of file +# Accessmod settings +ACCESSMOD_BUCKET_NAME=s3://hexa-demo-accessmod +ACCESSMOD_MANAGE_REQUESTS_URL=http://localhost:3000/admin/access-requests +ACCESSMOD_SET_PASSWORD_URL=http://localhost:3000/account/set-password diff --git a/.github/workflows/build_docker_image.yml b/.github/workflows/build_docker_image.yml index ab7d2eecf..c7fd3710f 100644 --- a/.github/workflows/build_docker_image.yml +++ b/.github/workflows/build_docker_image.yml @@ -71,6 +71,8 @@ jobs: context: . target: app file: Dockerfile + build-args: | + DJANGO_SETTINGS_MODULE=config.settings.dev cache-from: type=registry,ref=blsq/openhexa-app:buildcache cache-to: type=registry,ref=blsq/openhexa-app:buildcache,mode=max tags: | @@ -85,6 +87,8 @@ jobs: context: . 
target: app file: Dockerfile + build-args: | + DJANGO_SETTINGS_MODULE=config.settings.dev cache-from: type=registry,ref=blsq/openhexa-app:buildcache cache-to: type=registry,ref=blsq/openhexa-app:buildcache,mode=max tags: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d42f6b7c4..9adf5c8fb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,11 +50,14 @@ jobs: - name: Create docker network run: docker network create openhexa + + - name: Copy .env file + run: cp .env.dist .env - name: Build docker app image env: DOCKER_BUILDKIT: 1 - run: docker compose build + run: docker compose build --build-arg DJANGO_SETTINGS_MODULE="config.settings.dev" - name: Run Django tests run: docker compose run -e DEBUG=false app coveraged-test @@ -63,10 +66,12 @@ jobs: - name: Build and push (on main) uses: docker/build-push-action@v6 with: - push: true + push: ${{ github.event_name != 'pull_request' }} context: . target: app file: Dockerfile + build-args: | + DJANGO_SETTINGS_MODULE=config.settings.dev cache-from: type=registry,ref=blsq/openhexa-app:buildcache cache-to: type=registry,ref=blsq/openhexa-app:buildcache,mode=max tags: | diff --git a/Dockerfile b/Dockerfile index 305ac71fa..f08c0d206 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.12-slim as deps +FROM python:3.12-slim AS deps RUN \ --mount=type=cache,target=/var/cache/apt,sharing=locked \ @@ -8,33 +8,58 @@ RUN \ apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -RUN pip install --upgrade pip - + # Set up work directory RUN mkdir /code WORKDIR /code -RUN \ - --mount=type=cache,target=/root/.cache \ - --mount=type=bind,source=requirements.txt,target=/code/requirements.txt \ - pip install setuptools==68.0.0 && pip install -r requirements.txt +# Upgrade pip and install requirements +RUN --mount=type=cache,target=/root/.cache \ + pip install --upgrade pip setuptools==68.0.0 + +# Install project dependencies from requirements.txt +COPY requirements.txt /code/ +RUN --mount=type=cache,target=/root/.cache \ + pip install -r requirements.txt && \ + apt-get remove -y build-essential && \ + apt-get autoremove -y + +# Copy the rest of the application COPY . 
/code/ -ENV SECRET_KEY="collectstatic" -ENV DJANGO_SETTINGS_MODULE config.settings.production +ARG DJANGO_SETTINGS_MODULE + +# Entry point +ARG WORKSPACE_STORAGE_LOCATION +ENV DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE} +ENV WORKSPACE_STORAGE_LOCATION=${WORKSPACE_STORAGE_LOCATION} ENTRYPOINT ["/code/docker-entrypoint.sh"] CMD start -FROM deps as app -ENV DJANGO_SETTINGS_MODULE config.settings.production +FROM deps AS app +ARG DJANGO_SETTINGS_MODULE +ARG WORKSPACE_STORAGE_LOCATION +ENV DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE} +ENV WORKSPACE_STORAGE_LOCATION=${WORKSPACE_STORAGE_LOCATION} RUN python manage.py collectstatic --noinput # Staged used to run the pipelines scheduler and runner -FROM app as pipelines -ENV DJANGO_SETTINGS_MODULE config.settings.production -RUN mkdir -m 0755 -p /etc/apt/keyrings -RUN curl -fsSL https://download.docker.com/linux/debian/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg -RUN echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/debian \ - $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null -RUN apt-get update && apt-get install -y docker-ce-cli \ No newline at end of file +FROM app AS pipelines +ARG DJANGO_SETTINGS_MODULE +ARG WORKSPACE_STORAGE_LOCATION +ENV DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE} +ENV WORKSPACE_STORAGE_LOCATION=${WORKSPACE_STORAGE_LOCATION} +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + gnupg && \ + mkdir -m 0755 -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/debian/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/debian \ + $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + apt-get update && \ + apt-get install -y --no-install-recommends docker-ce-cli && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* diff --git a/config/settings/base.py b/config/settings/base.py index 5f35ce9ec..867a42423 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -11,7 +11,9 @@ """ import os +import re from pathlib import Path +from urllib.parse import urlparse from corsheaders.defaults import default_headers from django.utils.translation import gettext_lazy as _ @@ -31,10 +33,89 @@ # SECURITY WARNING: don't run with debug turned on in production! 
DEBUG = os.environ.get("DEBUG", "false") == "true" -ALLOWED_HOSTS = os.environ.get("ALLOWED_HOSTS", "").split(",") +# Trust the X_FORWARDED_PROTO header from the proxy or load balancer so Django is aware it is accessed by https +if "TRUST_FORWARDED_PROTO" in os.environ: + SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") + +TLS = os.environ.get("TLS", "false") == "true" +SCHEME = "https" if TLS else "http" + +SECURE_REDIRECT_EXEMPT = [r"^ready$"] +if TLS or "TRUST_FORWARDED_PROTO" in os.environ: + SESSION_COOKIE_SECURE = True + CSRF_COOKIE_SECURE = True + SECURE_SSL_REDIRECT = TLS +else: + SESSION_COOKIE_SECURE = False + CSRF_COOKIE_SECURE = False + SECURE_SSL_REDIRECT = False + +BASE_HOSTNAME = os.environ.get("BASE_HOSTNAME", "localhost") +BASE_PORT = os.environ.get("BASE_PORT", 8000) + +# Needed so that external component know how to hit us back +# Do not add a trailing slash +BASE_URL = os.environ.get("BASE_URL", f"{SCHEME}://{BASE_HOSTNAME}:{BASE_PORT}") +INTERNAL_BASE_URL = os.environ.get("INTERNAL_BASE_URL", BASE_URL) + +ALLOWED_HOSTS = ( + os.environ.get("ADDITIONAL_ALLOWED_HOSTS").split(",") + if "ADDITIONAL_ALLOWED_HOSTS" in os.environ + else [] +) + +CORS_ALLOWED_ORIGINS = [] +CSRF_TRUSTED_ORIGINS = [] +if "PROXY_HOSTNAME_AND_PORT" in os.environ: + SCHEME = "https" if "TRUST_FORWARDED_PROTO" in os.environ else SCHEME + PROXY_URL = f'{SCHEME}://{os.environ.get("PROXY_HOSTNAME_AND_PORT")}' + NEW_FRONTEND_DOMAIN = os.environ.get( + "NEW_FRONTEND_DOMAIN", os.environ.get("PROXY_HOSTNAME_AND_PORT") + ) + if not re.match(r"^https?://", NEW_FRONTEND_DOMAIN, re.IGNORECASE): + # Add the scheme if it is missing + NEW_FRONTEND_DOMAIN = f"{SCHEME}://{NEW_FRONTEND_DOMAIN}" + NOTEBOOKS_URL = os.environ.get("NOTEBOOKS_URL", PROXY_URL) + CORS_ALLOWED_ORIGINS = [PROXY_URL] + CSRF_TRUSTED_ORIGINS = [PROXY_URL] + ALLOWED_HOSTS += [urlparse(NEW_FRONTEND_DOMAIN).netloc.split(":")[0]] +else: + NEW_FRONTEND_DOMAIN = os.environ.get( + "NEW_FRONTEND_DOMAIN", + f'{BASE_HOSTNAME}:{os.environ.get("FRONTEND_PORT", 3000)}', + ) + NOTEBOOKS_URL = os.environ.get( + "NOTEBOOKS_URL", + f'{SCHEME}://{BASE_HOSTNAME}:{os.environ.get("JUPYTERHUB_PORT", 8001)}', + ) + if not re.match(r"^https?://", NEW_FRONTEND_DOMAIN, re.IGNORECASE): + # Add the scheme if it is missing + NEW_FRONTEND_DOMAIN = f"{SCHEME}://{NEW_FRONTEND_DOMAIN}" + CORS_ALLOWED_ORIGINS = [NEW_FRONTEND_DOMAIN] + CSRF_TRUSTED_ORIGINS = [NEW_FRONTEND_DOMAIN] + ALLOWED_HOSTS += [BASE_HOSTNAME] + +# CORS (For GraphQL) +# https://github.com/adamchainz/django-cors-headers +if "CORS_ALLOWED_ORIGINS" in os.environ: + CORS_ALLOWED_ORIGINS += os.environ.get("CORS_ALLOWED_ORIGINS").split(",") + +CORS_URLS_REGEX = r"^[/graphql/(\w+\/)?|/analytics/track]$" +CORS_ALLOW_CREDENTIALS = True + +CORS_ALLOW_HEADERS = list(default_headers) + [ + "sentry-trace", +] -# Domain of the new frontend (it is used to redirect the user to new pages not implemented in Django) -NEW_FRONTEND_DOMAIN = os.environ.get("NEW_FRONTEND_DOMAIN") +# CSRF +if "CSRF_TRUSTED_ORIGINS" in os.environ: + CSRF_TRUSTED_ORIGINS += os.environ.get("CSRF_TRUSTED_ORIGINS").split(",") + +SESSION_COOKIE_DOMAIN = os.environ.get("SESSION_COOKIE_DOMAIN", None) +CSRF_COOKIE_DOMAIN = os.environ.get("CSRF_COOKIE_DOMAIN", None) +SECURE_HSTS_SECONDS = os.environ.get( + "SECURE_HSTS_SECONDS", 60 * 60 +) # TODO: increase to one year if ok # Application definition INSTALLED_APPS = [ @@ -131,7 +212,6 @@ } } - # Auth settings LOGIN_URL = "core:login" LOGOUT_REDIRECT_URL = "core:login" @@ -167,45 +247,9 @@ 
"hexa.user_management.backends.PermissionsBackend", ] - -# Additional security settings -SESSION_COOKIE_SECURE = os.environ.get("SESSION_COOKIE_SECURE", "true") != "false" -CSRF_COOKIE_SECURE = os.environ.get("CSRF_COOKIE_SECURE", "true") != "false" -SECURE_SSL_REDIRECT = os.environ.get("SECURE_SSL_REDIRECT", "true") != "false" -SECURE_REDIRECT_EXEMPT = [r"^ready$"] - -RAW_CSRF_TRUSTED_ORIGINS = os.environ.get("CSRF_TRUSTED_ORIGINS") -if RAW_CSRF_TRUSTED_ORIGINS is not None: - CSRF_TRUSTED_ORIGINS = RAW_CSRF_TRUSTED_ORIGINS.split(",") - -SESSION_COOKIE_DOMAIN = os.environ.get("SESSION_COOKIE_DOMAIN", None) -CSRF_COOKIE_DOMAIN = os.environ.get("CSRF_COOKIE_DOMAIN", None) -SECURE_HSTS_SECONDS = os.environ.get( - "SECURE_HSTS_SECONDS", 60 * 60 -) # TODO: increase to one year if ok - - # by default users need to login every 2 weeks -> update to 1 year SESSION_COOKIE_AGE = 365 * 24 * 3600 -# Trust the X_FORWARDED_PROTO header from the GCP load balancer so Django is aware it is accessed by https -if "TRUST_FORWARDED_PROTO" in os.environ: - SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") - -# CORS (For GraphQL) -# https://github.com/adamchainz/django-cors-headers - -RAW_CORS_ALLOWED_ORIGINS = os.environ.get("CORS_ALLOWED_ORIGINS") -if RAW_CORS_ALLOWED_ORIGINS is not None: - CORS_ALLOWED_ORIGINS = RAW_CORS_ALLOWED_ORIGINS.split(",") - CORS_URLS_REGEX = r"^[/graphql/(\w+\/)?|/analytics/track]$" - CORS_ALLOW_CREDENTIALS = True - - -CORS_ALLOW_HEADERS = list(default_headers) + [ - "sentry-trace", -] - # Internationalization # https://docs.djangoproject.com/en/4.0/topics/i18n/ @@ -230,13 +274,23 @@ STATIC_URL = "/static/" STATIC_ROOT = BASE_DIR / "static" STATICFILES_DIRS = [BASE_DIR / "hexa" / "static"] +MEDIA_ROOT = BASE_DIR / "static" / "uploads" # Whitenoise # http://whitenoise.evans.io/en/stable/django.html#add-compression-and-caching-support STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage" +if os.environ.get("DEBUG_TOOLBAR", "false") == "true": + INSTALLED_APPS.append("debug_toolbar") # noqa: F405 + MIDDLEWARE.append("debug_toolbar.middleware.DebugToolbarMiddleware") # noqa: F405 + # Django Debug Toolbar specifically ask for INTERNAL_IPS to be set + INTERNAL_IPS = ["127.0.0.1"] + + DEBUG_TOOLBAR_CONFIG = { + "SHOW_TOOLBAR_CALLBACK": lambda request: request.user.is_staff, + } + # Notebooks component -NOTEBOOKS_URL = os.environ.get("NOTEBOOKS_URL", "http://localhost:8001") NOTEBOOKS_HUB_URL = os.environ.get("NOTEBOOKS_HUB_URL", "http://jupyterhub:8000/hub") HUB_API_TOKEN = os.environ.get("HUB_API_TOKEN", "") @@ -256,11 +310,16 @@ "loggers": { "": { "handlers": ["console"], - "level": "INFO", + "level": "DEBUG", }, }, } +# Disabling the check on the size of the request body when using the file system storage backend +# This is needed to allow the upload of large files when they are not stored by an external storage backend +if os.environ.get("DISABLE_UPLOAD_MAX_SIZE_CHECK", "false") == "true": + DATA_UPLOAD_MAX_MEMORY_SIZE = None + # Email settings EMAIL_HOST = os.environ.get("EMAIL_HOST") EMAIL_PORT = os.environ.get("EMAIL_PORT") @@ -279,47 +338,16 @@ # Sync settings: sync datasource with a worker (good for scaling) or in the web serv (good for dev) EXTERNAL_ASYNC_REFRESH = os.environ.get("EXTERNAL_ASYNC_REFRESH") == "true" - -if os.environ.get("DEBUG_TOOLBAR", "false") == "true": - INSTALLED_APPS.append("debug_toolbar") - MIDDLEWARE.append("debug_toolbar.middleware.DebugToolbarMiddleware") - # Django Debug Toolbar specifically ask for INTERNAL_IPS to 
be set - INTERNAL_IPS = ["127.0.0.1"] - - DEBUG_TOOLBAR_CONFIG = { - "SHOW_TOOLBAR_CALLBACK": lambda request: request.user.is_staff, - } - -if os.environ.get("STORAGE", "local") == "google-cloud": - # activate google cloud storage, used for dashboard screenshot, ... - # user generated content - DEFAULT_FILE_STORAGE = "storages.backends.gcloud.GoogleCloudStorage" - GS_BUCKET_NAME = os.environ.get("STORAGE_BUCKET") - GS_FILE_OVERWRITE = False -else: - MEDIA_ROOT = BASE_DIR / "static" / "uploads" - -# Accessmod settings -ACCESSMOD_BUCKET_NAME = os.environ.get("ACCESSMOD_BUCKET_NAME") -ACCESSMOD_MANAGE_REQUESTS_URL = os.environ.get("ACCESSMOD_MANAGE_REQUESTS_URL") -ACCESSMOD_SET_PASSWORD_URL = os.environ.get("ACCESSMOD_SET_PASSWORD_URL") - -# Specific settings for airflow plugins +## Specific settings for airflow plugins # number of second of airflow dag reloading setting AIRFLOW_SYNC_WAIT = 61 -GCS_TOKEN_LIFETIME = os.environ.get("GCS_TOKEN_LIFETIME") - -# Needed so that external component know how to hit us back -# Do not add a trailing slash -BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000") - +GCS_TOKEN_LIFETIME = os.environ.get("GCS_TOKEN_LIFETIME", 3600) # Pipeline settings -PIPELINE_SCHEDULER_SPAWNER = os.environ.get("PIPELINE_SCHEDULER_SPAWNER", "kubernetes") -PIPELINE_API_URL = os.environ.get("PIPELINE_API_URL", BASE_URL) +PIPELINE_SCHEDULER_SPAWNER = os.environ.get("PIPELINE_SCHEDULER_SPAWNER", "docker") DEFAULT_WORKSPACE_IMAGE = os.environ.get( - "DEFAULT_WORKSPACE_IMAGE", "blsq/openhexa-blsq-environment:latest" + "DEFAULT_WORKSPACE_IMAGE", "blsq/openhexa-base-environment:latest" ) PIPELINE_DEFAULT_CONTAINER_CPU_LIMIT = os.environ.get( "PIPELINE_DEFAULT_CONTAINER_CPU_LIMIT", "2" @@ -349,47 +377,47 @@ WORKSPACES_DATABASE_DEFAULT_DB = os.environ.get("WORKSPACES_DATABASE_DEFAULT_DB") WORKSPACES_DATABASE_PROXY_HOST = os.environ.get("WORKSPACES_DATABASE_PROXY_HOST") +# Datasets config +WORKSPACE_DATASETS_BUCKET = os.environ.get("WORKSPACE_DATASETS_BUCKET", "hexa-datasets") +WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE = int( + os.environ.get("WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE", 50) +) # Filesystem configuration -WORKSPACE_BUCKET_PREFIX = os.environ.get("WORKSPACE_BUCKET_PREFIX", "hexa-") +WORKSPACE_STORAGE_LOCATION = os.environ.get("WORKSPACE_STORAGE_LOCATION") +WORKSPACE_STORAGE_BACKEND = { + "engine": "hexa.files.backends.fs.FileSystemStorage", + "options": { + "data_dir": "/data", + "ext_bind_path": WORKSPACE_STORAGE_LOCATION, + "file_permissions_mode": 0o777, + "directory_permissions_mode": 0o777, + }, +} + WORKSPACE_BUCKET_REGION = os.environ.get("WORKSPACE_BUCKET_REGION", "europe-west1") -WORKSPACE_STORAGE_ENGINE = os.environ.get("WORKSPACE_STORAGE_ENGINE", "gcp") -WORKSPACE_BUCKET_VERSIONING_ENABLED = ( - os.environ.get("WORKSPACE_BUCKET_VERSIONING_ENABLED", "false") == "true" -) +WORKSPACE_BUCKET_PREFIX = os.environ.get("WORKSPACE_BUCKET_PREFIX", "") +WORKSPACE_BUCKET_VERSIONING_ENABLED = False -WORKSPACE_STORAGE_ENGINE_AWS_ENDPOINT_URL = os.environ.get( - "WORKSPACE_STORAGE_ENGINE_AWS_ENDPOINT_URL" +### AWS S3 Settings if using AWS S3 as a storage backend ### +WORKSPACE_STORAGE_BACKEND_AWS_ENDPOINT_URL = os.environ.get( + "WORKSPACE_STORAGE_BACKEND_AWS_ENDPOINT_URL" ) - # This is the endpoint URL used when generating presigned URLs called by the client since the client # does not have access to storage engine in local mode (http://minio:9000) -WORKSPACE_STORAGE_ENGINE_AWS_PUBLIC_ENDPOINT_URL = os.environ.get( - "WORKSPACE_STORAGE_ENGINE_AWS_PUBLIC_ENDPOINT_URL" 
+WORKSPACE_STORAGE_BACKEND_AWS_PUBLIC_ENDPOINT_URL = os.environ.get( + "WORKSPACE_STORAGE_BACKEND_AWS_PUBLIC_ENDPOINT_URL" ) -WORKSPACE_STORAGE_ENGINE_AWS_ACCESS_KEY_ID = os.environ.get( - "WORKSPACE_STORAGE_ENGINE_AWS_ACCESS_KEY_ID" +WORKSPACE_STORAGE_BACKEND_AWS_ACCESS_KEY_ID = os.environ.get( + "WORKSPACE_STORAGE_BACKEND_AWS_ACCESS_KEY_ID" ) -WORKSPACE_STORAGE_ENGINE_AWS_SECRET_ACCESS_KEY = os.environ.get( - "WORKSPACE_STORAGE_ENGINE_AWS_SECRET_ACCESS_KEY" +WORKSPACE_STORAGE_BACKEND_AWS_SECRET_ACCESS_KEY = os.environ.get( + "WORKSPACE_STORAGE_BACKEND_AWS_SECRET_ACCESS_KEY" ) -WORKSPACE_STORAGE_ENGINE_AWS_BUCKET_REGION = os.environ.get( - "WORKSPACE_STORAGE_ENGINE_AWS_BUCKET_REGION" +WORKSPACE_STORAGE_BACKEND_AWS_BUCKET_REGION = os.environ.get( + "WORKSPACE_STORAGE_BACKEND_AWS_BUCKET_REGION" ) -# Datasets config -WORKSPACE_DATASETS_BUCKET = os.environ.get("WORKSPACE_DATASETS_BUCKET") -WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE = os.environ.get( - "WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE", 50 -) - -# Base64 encoded service account key -# To generate a service account key, follow the instructions here: -# import base64 -# import json -# base64.b64encode(json.dumps(service_account_key_content).encode("utf-8")) -GCS_SERVICE_ACCOUNT_KEY = os.environ.get("GCS_SERVICE_ACCOUNT_KEY", "") - -# S3 settings +# S3 settings (Used by OpenHEXA Legacy) AWS_USERNAME = os.environ.get("AWS_USERNAME", "") AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID", "") AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", "") diff --git a/config/settings/dev.py b/config/settings/dev.py index 48747fdb5..ba4021905 100644 --- a/config/settings/dev.py +++ b/config/settings/dev.py @@ -1,6 +1,5 @@ from .base import * # noqa: F403, F401 -DEBUG = True SESSION_COOKIE_SECURE = False CSRF_COOKIE_SECURE = False SECURE_SSL_REDIRECT = False @@ -9,7 +8,3 @@ CORS_URLS_REGEX = r"^/graphql/(\w+\/)?$" CORS_ALLOW_CREDENTIALS = True CSRF_TRUSTED_ORIGINS = ["http://localhost:3000"] -PIPELINE_SCHEDULER_SPAWNER = "docker" -SECRET_KEY = "))dodw9%n)7q86l-q1by4e-2z#vonph50!%ep7_je)_=x0m2v-" -ENCRYPTION_KEY = "oT7DKt8zf0vsnbBcJ0R36SHkBzbjF2agFIK3hSAVvko=" -GCS_TOKEN_LIFETIME = 3600 diff --git a/config/settings/local.py b/config/settings/local.py new file mode 100644 index 000000000..3d569bf04 --- /dev/null +++ b/config/settings/local.py @@ -0,0 +1,45 @@ +import os + +from .base import * # noqa: F403, F401 + +LOGGING = { + "version": 1, + "disable_existing_loggers": True, + "formatters": {}, + "handlers": { + "console": { + "class": "logging.StreamHandler", + }, + }, + "loggers": { + "django.security.DisallowedHost": { + "level": "CRITICAL", + "propagate": True, + }, + "django": { + "level": "INFO", + "propagate": True, + }, + "gunicorn": { + "level": "INFO", + "propagate": True, + }, + "": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + + +# Filesystem configuration +WORKSPACE_STORAGE_BACKEND = { + "engine": "hexa.files.backends.fs.FileSystemStorage", + "options": { + "data_dir": "/data", + "ext_bind_path": os.environ.get("WORKSPACE_STORAGE_LOCATION"), + "file_permissions_mode": 0o777, + "directory_permissions_mode": 0o777, + }, +} diff --git a/config/settings/production.py b/config/settings/production.py index 136ff7995..f363557e7 100644 --- a/config/settings/production.py +++ b/config/settings/production.py @@ -3,6 +3,14 @@ from .base import * # noqa: F403, F401 +if os.environ.get("STORAGE", "local") == "google-cloud": + # activate google cloud storage, used for dashboard screenshot, ... 
+ # user generated content + DEFAULT_FILE_STORAGE = "storages.backends.gcloud.GoogleCloudStorage" + GS_BUCKET_NAME = os.environ.get("STORAGE_BUCKET") + GS_FILE_OVERWRITE = False + MEDIA_ROOT = None + SENTRY_DSN = os.environ.get("SENTRY_DSN") if SENTRY_DSN: @@ -76,3 +84,34 @@ def sentry_tracer_sampler(sampling_context): send_default_pii=True, environment=os.environ.get("SENTRY_ENVIRONMENT"), ) + + +PIPELINE_SCHEDULER_SPAWNER = "kubernetes" + +# GCP Settings if using GCS as a storage backend ### +# Base64 encoded service account key +# To generate a service account key, follow the instructions here: +# import base64 +# import json +# WORKSPACE_STORAGE_BACKEND_GCS_SERVICE_ACCOUNT_KEY = base64.b64encode(json.dumps(service_account_key_content).encode("utf-8")) +WORKSPACE_STORAGE_BACKEND = { + "engine": "hexa.files.backends.gcp.GoogleCloudStorage", + "options": { + "service_account_key": os.environ.get( + "WORKSPACE_STORAGE_BACKEND_GCS_SERVICE_ACCOUNT_KEY" + ), + "region": os.environ.get("WORKSPACE_BUCKET_REGION", "europe-west1"), + "enable_versioning": True, + }, +} + + +# Trust the X_FORWARDED_PROTO header from the GCP load balancer so Django is aware it is accessed by https +if "TRUST_FORWARDED_PROTO" in os.environ: + SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") + + +# Legacy +ACCESSMOD_BUCKET_NAME = os.environ.get("ACCESSMOD_BUCKET_NAME") +ACCESSMOD_MANAGE_REQUESTS_URL = os.environ.get("ACCESSMOD_MANAGE_REQUESTS_URL") +ACCESSMOD_SET_PASSWORD_URL = os.environ.get("ACCESSMOD_SET_PASSWORD_URL") diff --git a/config/settings/test.py b/config/settings/test.py index bc53f7e53..05066d43e 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -1,10 +1,16 @@ from .dev import * # noqa: F403, F401 -from .dev import INSTALLED_APPS DEBUG = False -WORKSPACE_STORAGE_ENGINE = "gcp" # Custom test runner TEST_RUNNER = "hexa.core.test.runner.DiscoverRunner" +WORKSPACE_STORAGE_BACKEND = {"engine": "hexa.files.backends.dummy.DummyStorageClient"} -INSTALLED_APPS += ["hexa", "hexa.core.tests.soft_delete"] +if "hexa.plugins.connector_accessmod" in INSTALLED_APPS: # noqa: F405 + # Accessmod settings + ACCESSMOD_BUCKET_NAME = "s3://hexa-demo-accessmod" + ACCESSMOD_MANAGE_REQUESTS_URL = "http://localhost:3000/admin/access-requests" + ACCESSMOD_SET_PASSWORD_URL = "http://localhost:3000/account/set-password" + +NEW_FRONTEND_DOMAIN = "http://localhost:3000" +NOTEBOOKS_URL = "http://localhost:8001" diff --git a/config/settings_test.py b/config/settings_test.py deleted file mode 100644 index 38506f100..000000000 --- a/config/settings_test.py +++ /dev/null @@ -1,6 +0,0 @@ -from .settings import * # noqa - -# since most existing test assumed that gcp was used in conjunction of the mocked storage -# we make sure the .env doesn't interfere with the test and enforce gcp by default in the tests - -WORKSPACE_STORAGE_ENGINE = "gcp" diff --git a/config/urls.py b/config/urls.py index 596db6fe8..8caa0b953 100644 --- a/config/urls.py +++ b/config/urls.py @@ -37,6 +37,7 @@ path("notebooks/", include("hexa.notebooks.urls", namespace="notebooks")), path("pipelines/", include("hexa.pipelines.urls", namespace="pipelines")), path("workspaces/", include("hexa.workspaces.urls", namespace="workspaces")), + path("files/", include("hexa.files.urls", namespace="files")), path("analytics/", include("hexa.analytics.urls", namespace="analytics")), # Order matters, we override the default logout view defined later # We do this to logout the user from jupyterhub at the end of the openhexa diff --git a/docker-compose.yaml 
b/docker-compose.yaml index f7d2e3757..9fae8b360 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,78 +1,34 @@ # Defines a service that can be reused multiple times later x-app: &common + platform: linux/amd64 build: context: . dockerfile: Dockerfile target: app - platform: linux/amd64 + args: + - DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE} + - WORKSPACE_STORAGE_LOCATION=${WORKSPACE_STORAGE_LOCATION} + env_file: + - path: ./.env + required: true + environment: + - ADDITIONAL_ALLOWED_HOSTS=app,frontend networks: - openhexa - environment: - - DEBUG=true - - DJANGO_SETTINGS_MODULE=config.settings.dev - - DATABASE_HOST=db - - DATABASE_PORT=5432 - - DATABASE_NAME=hexa-app - - DATABASE_USER=hexa-app - - DATABASE_PASSWORD=hexa-app - - ACCESSMOD_BUCKET_NAME=s3://hexa-demo-accessmod - - ACCESSMOD_MANAGE_REQUESTS_URL=http://localhost:3000/admin/access-requests - - ACCESSMOD_SET_PASSWORD_URL=http://localhost:3000/account/set-password - - NEW_FRONTEND_DOMAIN=localhost:3000 - - PIPELINE_SCHEDULER_SPAWNER=docker - - PIPELINE_API_URL=http://app:8000 - - DEFAULT_WORKSPACE_IMAGE - - MIXPANEL_TOKEN - # The following variables are optional and can be set in your local .env file for testing purposes - - DEBUG_LOGGING - - DEBUG_TOOLBAR - - EMAIL_HOST - - EMAIL_PORT - - EMAIL_HOST_USER - - EMAIL_USE_TLS - - EMAIL_HOST_PASSWORD - - WORKSPACES_DATABASE_ROLE=hexa-app - - WORKSPACES_DATABASE_PASSWORD=hexa-app - - WORKSPACES_DATABASE_HOST=db - - WORKSPACES_DATABASE_PORT=5432 - - WORKSPACES_DATABASE_DEFAULT_DB=hexa-app - - WORKSPACES_DATABASE_PROXY_HOST=db - - WORKSPACE_DATASETS_BUCKET - - WORKSPACE_STORAGE_ENGINE=gcp - - WORKSPACE_BUCKET_VERSIONING_ENABLED - - WORKSPACE_STORAGE_ENGINE_AWS_ENDPOINT_URL - - WORKSPACE_STORAGE_ENGINE_AWS_PUBLIC_ENDPOINT_URL - - WORKSPACE_STORAGE_ENGINE_AWS_ACCESS_KEY_ID - - WORKSPACE_STORAGE_ENGINE_AWS_SECRET_ACCESS_KEY - - WORKSPACE_STORAGE_ENGINE_AWS_BUCKET_REGION - - NOTEBOOKS_HUB_URL=http://jupyterhub:8000/hub - - HUB_API_TOKEN=cbb352d6a412e266d7494fb014dd699373645ec8d353e00c7aa9dc79ca87800d - - GCS_SERVICE_ACCOUNT_KEY - - WORKSPACE_BUCKET_PREFIX=hexa-test- - - WORKSPACE_BUCKET_REGION=europe-west1 - - AWS_USERNAME - - AWS_ACCESS_KEY_ID - - AWS_SECRET_ACCESS_KEY - - AWS_ENDPOINT_URL - - AWS_DEFAULT_REGION - - AWS_USER_ARN - - AWS_APP_ROLE_ARN - - AWS_PERMISSIONS_BOUNDARY_POLICY_ARN volumes: # only used for Github Codespaces - "${LOCAL_WORKSPACE_FOLDER:-.}:/code" - + - "${WORKSPACE_STORAGE_LOCATION:-/data/openhexa}:/data" services: db: image: postgis/postgis:12-3.2 + env_file: + - path: ./.env + required: true networks: - openhexa volumes: - pgdata:/var/lib/postgresql/data - environment: - - POSTGRES_DB=hexa-app - - POSTGRES_USER=hexa-app - - POSTGRES_PASSWORD=hexa-app ports: - "5434:5432" @@ -97,6 +53,7 @@ services: depends_on: - db + # This service is only used for the connector_accessmod app. dataworker: <<: *common command: "manage validate_fileset_worker" @@ -108,14 +65,15 @@ services: frontend: image: "blsq/openhexa-frontend:${FRONTEND_VERSION:-main}" + env_file: + - path: ./.env + required: true platform: linux/amd64 networks: - openhexa container_name: frontend ports: - - "3000:3000" - environment: - - OPENHEXA_BACKEND_URL=http://app:8000 + - "${FRONTEND_PORT:-3000}:3000" profiles: - frontend restart: unless-stopped @@ -128,6 +86,8 @@ services: context: . 
dockerfile: Dockerfile target: pipelines + args: + - DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE} command: "manage pipelines_runner" restart: unless-stopped profiles: @@ -144,6 +104,8 @@ services: context: . dockerfile: Dockerfile target: pipelines + args: + - DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE} command: "manage pipelines_scheduler" restart: unless-stopped profiles: @@ -151,22 +113,6 @@ services: depends_on: - db - minio: - image: quay.io/minio/minio - command: server --address 0.0.0.0:9000 --console-address ":9001" /data - volumes: - - minio_data:/data - profiles: - - minio - ports: - - '9000:9000' - - '9001:9001' - networks: - - openhexa - environment: - - MINIO_ACCESS_KEY=${WORKSPACE_STORAGE_ENGINE_AWS_ACCESS_KEY_ID} - - MINIO_SECRET_KEY=${WORKSPACE_STORAGE_ENGINE_AWS_SECRET_ACCESS_KEY} - networks: openhexa: diff --git a/hexa/analytics/tests/test_analytics.py b/hexa/analytics/tests/test_analytics.py index 0fe375801..8e67696d0 100644 --- a/hexa/analytics/tests/test_analytics.py +++ b/hexa/analytics/tests/test_analytics.py @@ -4,7 +4,6 @@ from hexa.analytics.api import set_user_properties, track from hexa.core.test import TestCase -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.pipelines.models import Pipeline, PipelineRunTrigger from hexa.user_management.models import User from hexa.workspaces.models import Workspace @@ -12,7 +11,6 @@ class AnalyticsTest(TestCase): @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER: User = User.objects.create_user( "user@bluesquarehub.com", "user", analytics_enabled=True, is_superuser=True diff --git a/hexa/core/views_utils.py b/hexa/core/views_utils.py index 63bcac403..6f457bcf2 100644 --- a/hexa/core/views_utils.py +++ b/hexa/core/views_utils.py @@ -9,7 +9,7 @@ def redirect_to_new_frontend(request: HttpRequest, *_, **__): if settings.NEW_FRONTEND_DOMAIN is not None: return redirect( request.build_absolute_uri( - f"//{settings.NEW_FRONTEND_DOMAIN}{request.get_full_path()}" + f"{settings.NEW_FRONTEND_DOMAIN}{request.get_full_path()}" ) ) raise Http404("Page not found") diff --git a/hexa/databases/tests/test_schema.py b/hexa/databases/tests/test_schema.py index 7e7aeded9..3b770981a 100644 --- a/hexa/databases/tests/test_schema.py +++ b/hexa/databases/tests/test_schema.py @@ -5,7 +5,6 @@ from hexa.core.test import GraphQLTestCase from hexa.databases.utils import TableRowsPage -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.plugins.connector_postgresql.models import Database from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( @@ -19,7 +18,6 @@ class DatabaseTest(GraphQLTestCase): USER_SABRINA = None @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER_SABRINA = User.objects.create_user( "sabrina@bluesquarehub.com", "standardpassword" diff --git a/hexa/databases/tests/test_utils.py b/hexa/databases/tests/test_utils.py index a8268ec9a..d8177ab50 100644 --- a/hexa/databases/tests/test_utils.py +++ b/hexa/databases/tests/test_utils.py @@ -13,7 +13,6 @@ get_table_definition, get_table_rows, ) -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.plugins.connector_postgresql.models import Database from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import Workspace @@ -45,7 +44,6 @@ class DatabaseUtilsTest(TestCase): USER_SABRINA = None @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.DB1 = Database.objects.create( hostname="host", username="user", 
password="pwd", database="db1" diff --git a/hexa/datasets/api.py b/hexa/datasets/api.py index 38684eb1e..e4445fac4 100644 --- a/hexa/datasets/api.py +++ b/hexa/datasets/api.py @@ -1,29 +1,29 @@ from django.conf import settings -from google.api_core import exceptions -from hexa.files import basefs -from hexa.files.api import get_storage +from hexa.files import storage -def generate_upload_url(uri, content_type): - return get_storage().generate_upload_url( +def generate_upload_url(file_uri, content_type: str, host: str | None = None): + return storage.generate_upload_url( settings.WORKSPACE_DATASETS_BUCKET, - uri, - content_type, + file_uri, + content_type=content_type, + host=host, raise_if_exists=True, ) -def generate_download_url(file): - return get_storage().generate_download_url( - settings.WORKSPACE_DATASETS_BUCKET, file.uri, force_attachment=True +def generate_download_url(version_file, host: str | None = None): + return storage.generate_download_url( + settings.WORKSPACE_DATASETS_BUCKET, + version_file.uri, + force_attachment=True, + host=host, ) def get_blob(uri): try: - return get_storage().get_bucket_object(settings.WORKSPACE_DATASETS_BUCKET, uri) - except exceptions.NotFound: - return None - except basefs.NotFound: + return storage.get_bucket_object(settings.WORKSPACE_DATASETS_BUCKET, uri) + except storage.exceptions.NotFound: return None diff --git a/hexa/datasets/management/commands/prepare_datasets.py b/hexa/datasets/management/commands/prepare_datasets.py new file mode 100644 index 000000000..281a3dbb6 --- /dev/null +++ b/hexa/datasets/management/commands/prepare_datasets.py @@ -0,0 +1,23 @@ +from django.conf import settings +from django.core.management.base import BaseCommand + +from hexa.files import storage + + +class Command(BaseCommand): + help = "Creates the datasets bucket if it does not exist" + + def handle(self, *args, **options): + if storage.bucket_exists(settings.WORKSPACE_DATASETS_BUCKET): + self.stdout.write( + self.style.SUCCESS( + f"Bucket '{settings.WORKSPACE_DATASETS_BUCKET}' already exists" + ) + ) + else: + storage.create_bucket(settings.WORKSPACE_DATASETS_BUCKET) + self.stdout.write( + self.style.SUCCESS( + f"Bucket '{settings.WORKSPACE_DATASETS_BUCKET}' created" + ) + ) diff --git a/hexa/datasets/migrations/0008_alter_datasetfilemetadata_options.py b/hexa/datasets/migrations/0008_alter_datasetfilemetadata_options.py new file mode 100644 index 000000000..10807191b --- /dev/null +++ b/hexa/datasets/migrations/0008_alter_datasetfilemetadata_options.py @@ -0,0 +1,16 @@ +# Generated by Django 5.0.8 on 2024-09-13 09:52 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("datasets", "0007_alter_dataset_slug_and_more"), + ] + + operations = [ + migrations.AlterModelOptions( + name="datasetfilemetadata", + options={"ordering": ["-created_at"]}, + ), + ] diff --git a/hexa/datasets/models.py b/hexa/datasets/models.py index 7bf6dcf99..ea02a894f 100644 --- a/hexa/datasets/models.py +++ b/hexa/datasets/models.py @@ -1,3 +1,4 @@ +import logging import secrets from django.contrib.auth.models import AnonymousUser @@ -11,6 +12,8 @@ from hexa.core.models.base import Base, BaseQuerySet from hexa.user_management.models import User +logger = logging.getLogger(__name__) + def create_dataset_slug(name: str, workspace): suffix = "" @@ -266,8 +269,30 @@ def filename(self): return self.uri.split("/")[-1] @property - def latest_metadata(self): - return self.metadata_entries.order_by("-created_at").first() + def sample_entry(self): 
+ entry = self.metadata_entries.first() + if entry is None: + logger.info("No sample found for file %s, generating one", self.uri) + self.generate_sample() + return entry + + @property + def full_uri(self): + return self.dataset_version.get_full_uri(self.uri) + + def generate_sample(self): + from hexa.datasets.queue import dataset_file_metadata_queue, is_sample_supported + + if not is_sample_supported(self.filename): + logger.info("Sample generation not supported for file %s", self.uri) + return + logger.info("Generating sample for file %s", self.uri) + dataset_file_metadata_queue.enqueue( + "generate_file_metadata", + { + "file_id": str(self.id), + }, + ) class Meta: ordering = ["uri"] @@ -299,6 +324,9 @@ class DatasetFileMetadata(Base): related_name="metadata_entries", ) + class Meta: + ordering = ["-created_at"] + class DatasetLinkQuerySet(BaseQuerySet): def filter_for_user(self, user: AnonymousUser | User): diff --git a/hexa/datasets/queue.py b/hexa/datasets/queue.py index 6885f8799..9d66ee844 100644 --- a/hexa/datasets/queue.py +++ b/hexa/datasets/queue.py @@ -1,10 +1,7 @@ -import json from logging import getLogger import pandas as pd from django.conf import settings -from django.core.exceptions import ObjectDoesNotExist, ValidationError -from django.db import DatabaseError, IntegrityError from dpq.queue import AtLeastOnceQueue from hexa.core import mimetypes @@ -17,8 +14,10 @@ logger = getLogger(__name__) +SAMPLING_SEED = 22 -def is_supported_mimetype(filename: str) -> bool: + +def is_sample_supported(filename: str) -> bool: supported_mimetypes = [ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "application/vnd.ms-excel", @@ -27,17 +26,20 @@ def is_supported_mimetype(filename: str) -> bool: ] supported_extensions = ["parquet"] suffix = filename.split(".")[-1] - mime_type, encoding = mimetypes.guess_type(filename, strict=False) + mime_type, _ = mimetypes.guess_type(filename, strict=False) return mime_type in supported_mimetypes or suffix in supported_extensions -def download_file_as_dataframe( - dataset_version_file: DatasetVersionFile, -) -> pd.DataFrame | None: - mime_type, encoding = mimetypes.guess_type( - dataset_version_file.filename, strict=False - ) - download_url = generate_download_url(dataset_version_file) +def get_df(dataset_version_file: DatasetVersionFile) -> pd.DataFrame: + mime_type, _ = mimetypes.guess_type(dataset_version_file.filename, strict=False) + try: + download_url = generate_download_url( + dataset_version_file, host=settings.INTERNAL_BASE_URL + ) + except Exception as e: + logger.error(e) + raise e + if mime_type == "text/csv": return pd.read_csv(download_url) elif ( @@ -50,59 +52,38 @@ def download_file_as_dataframe( or dataset_version_file.filename.split(".")[-1] == "parquet" ): return pd.read_parquet(download_url) + else: + raise ValueError(f"Unsupported file format: {dataset_version_file.filename}") -def generate_dataset_file_sample_task( - queue: AtLeastOnceQueue, job: DatasetFileMetadataJob -): - dataset_version_file_id = job.args["file_id"] - try: - dataset_version_file = DatasetVersionFile.objects.get( - id=dataset_version_file_id - ) - except ObjectDoesNotExist as e: - logger.error( - f"DatasetVersionFile with id {dataset_version_file_id} does not exist: {e}" - ) - return - - if not is_supported_mimetype(dataset_version_file.filename): - logger.info(f"Unsupported file format: {dataset_version_file.filename}") - return +def generate_sample(version_file: DatasetVersionFile) -> DatasetFileMetadata: + if not 
is_sample_supported(version_file.filename): + raise ValueError(f"Unsupported file format: {version_file.filename}") - logger.info(f"Creating dataset sample for version file {dataset_version_file.id}") - try: - dataset_file_metadata = DatasetFileMetadata.objects.create( - dataset_version_file=dataset_version_file, - status=DatasetFileMetadata.STATUS_PROCESSING, - ) - except (IntegrityError, DatabaseError, ValidationError) as e: - logger.error(f"Error creating DatasetFileMetadata: {e}") - return + logger.info(f"Creating dataset sample for version file {version_file.id}") + dataset_file_metadata = DatasetFileMetadata.objects.create( + dataset_version_file=version_file, + status=DatasetFileMetadata.STATUS_PROCESSING, + ) try: - file_content = download_file_as_dataframe(dataset_version_file) - if not file_content.empty: - random_seed = 22 - file_sample = file_content.sample( + df = get_df(version_file) + if df.empty is False: + sample = df.sample( settings.WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE, - random_state=random_seed, + random_state=SAMPLING_SEED, replace=True, ) - dataset_file_metadata.sample = file_sample.to_json(orient="records") - else: - dataset_file_metadata.sample = json.dumps([]) - logger.info(f"Dataset sample saved for file {dataset_version_file_id}") + dataset_file_metadata.sample = sample.to_dict(orient="records") dataset_file_metadata.status = DatasetFileMetadata.STATUS_FINISHED - dataset_file_metadata.save() - logger.info(f"Dataset sample created for file {dataset_version_file_id}") + logger.info(f"Sample saved for file {version_file.id}") except Exception as e: + logger.exception(f"Sample creation failed for file {version_file.id}: {e}") dataset_file_metadata.status = DatasetFileMetadata.STATUS_FAILED dataset_file_metadata.status_reason = str(e) + finally: dataset_file_metadata.save() - logger.exception( - f"Dataset file sample creation failed for file {dataset_version_file_id}: {e}" - ) + return dataset_file_metadata class DatasetsFileMetadataQueue(AtLeastOnceQueue): @@ -111,16 +92,9 @@ class DatasetsFileMetadataQueue(AtLeastOnceQueue): dataset_file_metadata_queue = DatasetsFileMetadataQueue( tasks={ - "generate_file_metadata": generate_dataset_file_sample_task, + "generate_file_metadata": lambda _, job: generate_sample( + DatasetVersionFile.objects.get(id=job.args["file_id"]) + ), }, notify_channel="dataset_file_metadata_queue", ) - - -def load_file_metadata(file_id: str): - dataset_file_metadata_queue.enqueue( - "generate_file_metadata", - { - "file_id": str(file_id), - }, - ) diff --git a/hexa/datasets/schema/mutations.py b/hexa/datasets/schema/mutations.py index 3ed6c0bf4..dcb4bd5c0 100644 --- a/hexa/datasets/schema/mutations.py +++ b/hexa/datasets/schema/mutations.py @@ -3,12 +3,12 @@ from django.db import IntegrityError, transaction from hexa.analytics.api import track +from hexa.files import storage from hexa.pipelines.authentication import PipelineRunUser from hexa.workspaces.models import Workspace from ..api import generate_download_url, generate_upload_url, get_blob from ..models import Dataset, DatasetLink, DatasetVersion, DatasetVersionFile -from ..queue import load_file_metadata mutations = MutationType() @@ -249,7 +249,7 @@ def resolve_create_version_file(_, info, **kwargs): content_type=mutation_input["contentType"], ) - load_file_metadata(file_id=file.id) + file.generate_sample() return { "success": True, "errors": [], @@ -284,7 +284,7 @@ def resolve_version_file_download(_, info, **kwargs): return {"success": False, "errors": ["FILE_NOT_UPLOADED"]} return 
{"success": True, "errors": [], "download_url": download_url} - except DatasetVersionFile.DoesNotExist: + except (DatasetVersionFile.DoesNotExist, storage.exceptions.NotFound): return {"success": False, "errors": ["FILE_NOT_FOUND"]} except PermissionDenied: return {"success": False, "errors": ["PERMISSION_DENIED"]} diff --git a/hexa/datasets/schema/types.py b/hexa/datasets/schema/types.py index 491f42f33..2f59cceb5 100644 --- a/hexa/datasets/schema/types.py +++ b/hexa/datasets/schema/types.py @@ -13,7 +13,7 @@ DatasetVersion, DatasetVersionFile, ) -from hexa.files.basefs import BucketObjectAlreadyExists +from hexa.files import storage from hexa.workspaces.models import Workspace from hexa.workspaces.schema.types import workspace_object, workspace_permissions @@ -214,7 +214,7 @@ def resolve_upload_url(obj, info, **kwargs): file = obj["file"] upload_url = generate_upload_url(file.uri, file.content_type) return upload_url - except BucketObjectAlreadyExists as exc: + except storage.exceptions.AlreadyExists as exc: logging.error(f"Upload URL generation failed: {exc.message}") return None @@ -222,7 +222,7 @@ def resolve_upload_url(obj, info, **kwargs): @dataset_version_file_object.field("fileMetadata") def resolve_version_file_metadata(obj: DatasetVersionFile, info, **kwargs): try: - return obj.latest_metadata + return obj.sample_entry except DatasetFileMetadata.DoesNotExist: logging.error(f"No metadata found for file {obj.filename} with id {obj.id}") return None diff --git a/hexa/datasets/tests/test_metadata.py b/hexa/datasets/tests/test_metadata.py index 9a27ea2f4..3228762d3 100644 --- a/hexa/datasets/tests/test_metadata.py +++ b/hexa/datasets/tests/test_metadata.py @@ -1,169 +1,123 @@ import os -from unittest import mock +from unittest.mock import patch -from pandas.errors import ParserError +from django.test import override_settings from hexa.core.test import TestCase -from hexa.datasets.models import DatasetFileMetadata -from hexa.datasets.queue import generate_dataset_file_sample_task -from hexa.files.api import get_storage +from hexa.datasets.models import Dataset, DatasetFileMetadata, DatasetVersionFile +from hexa.datasets.queue import generate_sample +from hexa.files import storage +from hexa.user_management.models import User +from hexa.workspaces.models import Workspace class TestCreateDatasetFileMetadataTask(TestCase): - @mock.patch("hexa.datasets.queue.DatasetVersionFile.objects.get") - @mock.patch("hexa.datasets.queue.DatasetFileMetadata.objects.create") - @mock.patch("hexa.datasets.queue.generate_download_url") - def test_create_dataset_file_metadata_task_success( + @classmethod + def setUpTestData(cls): + storage.reset() + cls.USER_SERENA = User.objects.create_user( + "serena@bluesquarehub.com", "serena's password", is_superuser=True + ) + cls.WORKSPACE = Workspace.objects.create_if_has_perm( + cls.USER_SERENA, name="My Workspace", description="Test workspace" + ) + + cls.DATASET = Dataset.objects.create_if_has_perm( + cls.USER_SERENA, + cls.WORKSPACE, + name="Dataset", + description="Dataset's description", + ) + cls.DATASET_VERSION = cls.DATASET.create_version( + principal=cls.USER_SERENA, name="v1" + ) + + @override_settings(WORKSPACE_DATASETS_FILE_SNAPSHOT_SIZE=3) + def test_generate_sample( self, - mock_generate_download_url, - mock_DatasetFileMetadata_create, - mock_DatasetVersionFile_get, ): - test_cases = [ + CASES = [ + # It fails because the file is empty (no columns to parse) + ( + "example_empty_file.csv", + DatasetFileMetadata.STATUS_FAILED, + [], + "No columns to 
parse from file", + ), ( "example_names.csv", DatasetFileMetadata.STATUS_FINISHED, - '[{"name":"Jack","surname":"Howard"},{"name":"Olivia","surname":"Brown"},{"name":"Lily","surname":"Evan', + [ + {"name": "Jack", "surname": "Howard"}, + {"name": "Olivia", "surname": "Brown"}, + {"name": "Lily", "surname": "Evans"}, + ], + None, ), + # The CSV only contains 2 lines so it's going to add existing lines to achieve the desired sample size ( "example_names_2_lines.csv", DatasetFileMetadata.STATUS_FINISHED, - '[{"name":"Liam","surname":"Smith"},{"name":"Joe","surname":"Doe"},{"name":"Joe","surname":"Doe"},{"nam', + [ + {"name": "Liam", "surname": "Smith"}, + {"name": "Joe", "surname": "Doe"}, + {"name": "Joe", "surname": "Doe"}, + ], + None, ), ( "example_names_0_lines.csv", DatasetFileMetadata.STATUS_FINISHED, - "[]", + [], + None, ), ( "example_names.parquet", DatasetFileMetadata.STATUS_FINISHED, - '[{"name":"Jack","surname":"Howard"},{"name":"Olivia","surname":"Brown"},{"name":"Lily","surname":"Evan', + [ + {"name": "Jack", "surname": "Howard"}, + {"name": "Olivia", "surname": "Brown"}, + {"name": "Lily", "surname": "Evans"}, + ], + None, ), ( "example_names.xlsx", DatasetFileMetadata.STATUS_FINISHED, - '[{"name":"Jack","surname":"Howard"},{"name":"Olivia","surname":"Brown"},{"name":"Lily","surname":"Evan', + [ + {"name": "Jack", "surname": "Howard"}, + {"name": "Olivia", "surname": "Brown"}, + {"name": "Lily", "surname": "Evans"}, + ], + None, ), ] - for filename, expected_status, expected_content in test_cases: - with self.subTest(filename=filename): - dataset_version_file = mock.Mock() - dataset_version_file.id = 1 - dataset_version_file.filename = f"{filename}" - mock_DatasetVersionFile_get.return_value = dataset_version_file - - dataset_file_metadata = mock.Mock() - mock_DatasetFileMetadata_create.return_value = dataset_file_metadata - + for ( + fixture_name, + expected_status, + expected_sample, + expected_status_reason, + ) in CASES: + with self.subTest(fixture_name=fixture_name): fixture_file_path = os.path.join( - os.path.dirname(__file__), f"./fixtures/{filename}" + os.path.dirname(__file__), f"./fixtures/{fixture_name}" ) - mock_generate_download_url.return_value = fixture_file_path - - job = mock.Mock() - job.args = {"file_id": dataset_version_file.id} - - generate_dataset_file_sample_task(mock.Mock(), job) - - mock_generate_download_url.assert_called_once_with(dataset_version_file) - mock_DatasetVersionFile_get.assert_called_once_with( - id=dataset_version_file.id - ) - mock_DatasetFileMetadata_create.assert_called_once_with( - dataset_version_file=dataset_version_file, - status=DatasetFileMetadata.STATUS_PROCESSING, - ) - dataset_file_metadata.save.assert_called() - self.assertEqual(dataset_file_metadata.status, expected_status) - self.assertEqual( - dataset_file_metadata.sample[0 : len(expected_content)], - expected_content, + version_file = DatasetVersionFile.objects.create_if_has_perm( + self.USER_SERENA, + self.DATASET_VERSION, + uri=fixture_file_path, + content_type="application/octect-stream", ) - mock_generate_download_url.reset_mock() - mock_DatasetVersionFile_get.reset_mock() - mock_DatasetFileMetadata_create.reset_mock() - dataset_file_metadata.save.reset_mock() - - @mock.patch("hexa.datasets.queue.DatasetVersionFile.objects.get") - @mock.patch("hexa.datasets.queue.DatasetFileMetadata.objects.create") - @mock.patch("hexa.datasets.queue.generate_download_url") - def test_create_dataset_file_metadata_task_failure( - self, - mock_generate_download_url, - 
mock_DatasetFileMetadata_create, - mock_DatasetVersionFile_get, - ): - test_cases = [ - (get_storage().exceptions.NotFound, DatasetFileMetadata.STATUS_FAILED), - (ValueError, DatasetFileMetadata.STATUS_FAILED), - (ParserError, DatasetFileMetadata.STATUS_FAILED), - ] - for exception, expected_status in test_cases: - with self.subTest(exception=exception): - dataset_version_file = mock.Mock() - dataset_version_file.id = 1 - dataset_version_file.filename = "example_names.csv" - mock_DatasetVersionFile_get.return_value = dataset_version_file - - dataset_file_metadata = mock.Mock() - mock_DatasetFileMetadata_create.return_value = dataset_file_metadata - - mock_generate_download_url.side_effect = exception - - job = mock.Mock() - job.args = {"file_id": dataset_version_file.id} - generate_dataset_file_sample_task(mock.Mock(), job) - - mock_DatasetVersionFile_get.assert_called_with( - id=dataset_version_file.id - ) - dataset_file_metadata.save.assert_called() - self.assertEqual(dataset_file_metadata.status, expected_status) - - mock_generate_download_url.reset_mock() - mock_DatasetVersionFile_get.reset_mock() - mock_DatasetFileMetadata_create.reset_mock() - dataset_file_metadata.save.reset_mock() - - @mock.patch("hexa.datasets.queue.DatasetVersionFile.objects.get") - @mock.patch("hexa.datasets.queue.DatasetFileMetadata.objects.create") - @mock.patch("hexa.datasets.queue.generate_download_url") - def test_create_dataset_file_metadata_task_failure_empty_file( - self, - mock_generate_download_url, - mock_DatasetFileMetadata_create, - mock_DatasetVersionFile_get, - ): - dataset_version_file = mock.Mock() - dataset_version_file.id = 1 - dataset_version_file.filename = "example_empty_file.csv" - mock_DatasetVersionFile_get.return_value = dataset_version_file - - dataset_file_metadata = mock.Mock() - mock_DatasetFileMetadata_create.return_value = dataset_file_metadata - - fixture_file_path = os.path.join( - os.path.dirname(__file__), "./fixtures/example_empty_file.csv" - ) - mock_generate_download_url.return_value = fixture_file_path - - job = mock.Mock() - job.args = {"file_id": dataset_version_file.id} - - generate_dataset_file_sample_task(mock.Mock(), job) - - mock_generate_download_url.assert_called_once_with(dataset_version_file) - mock_DatasetVersionFile_get.assert_called_once_with(id=dataset_version_file.id) - mock_DatasetFileMetadata_create.assert_called_once_with( - dataset_version_file=dataset_version_file, - status=DatasetFileMetadata.STATUS_PROCESSING, - ) - dataset_file_metadata.save.assert_called() - self.assertEqual( - dataset_file_metadata.status, DatasetFileMetadata.STATUS_FAILED - ) - self.assertEqual( - dataset_file_metadata.status_reason, "No columns to parse from file" - ) + with patch( + "hexa.datasets.queue.generate_download_url" + ) as mock_generate_download_url: + mock_generate_download_url.return_value = fixture_file_path + sample_entry = generate_sample(version_file) + self.assertEqual(sample_entry.status, expected_status) + self.assertEqual(sample_entry.sample, expected_sample) + + if expected_status_reason: + self.assertEqual( + sample_entry.status_reason, expected_status_reason + ) diff --git a/hexa/datasets/tests/test_models.py b/hexa/datasets/tests/test_models.py index a7820792d..4571988de 100644 --- a/hexa/datasets/tests/test_models.py +++ b/hexa/datasets/tests/test_models.py @@ -7,8 +7,7 @@ from hexa.core.test import TestCase from hexa.datasets.models import Dataset, DatasetVersion, DatasetVersionFile -from hexa.files.api import get_storage -from 
hexa.files.tests.mocks.mockgcp import backend +from hexa.files import storage from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( Workspace, @@ -23,9 +22,8 @@ class BaseTestMixin: WORKSPACE = None @classmethod - @backend.mock_storage def setUpTestData(cls): - backend.reset() + storage.reset() cls.USER_SERENA = User.objects.create_user( "serena@bluesquarehub.com", "serena's password", @@ -180,7 +178,6 @@ class DatasetVersionTest(BaseTestMixin, TestCase): DATASET = None @classmethod - @backend.mock_storage def setUpTestData(cls): BaseTestMixin.setUpTestData() cls.DATASET = Dataset.objects.create_if_has_perm( @@ -190,9 +187,8 @@ def setUpTestData(cls): description="Description of dataset", ) - get_storage().create_bucket(settings.WORKSPACE_DATASETS_BUCKET) + storage.create_bucket(settings.WORKSPACE_DATASETS_BUCKET) - @backend.mock_storage def test_create_dataset_version( self, name="Dataset's version", description="Version's description" ): @@ -246,10 +242,9 @@ def test_get_file_by_name(self): @override_settings(WORKSPACE_DATASETS_BUCKET="hexa-datasets") class DatasetLinkTest(BaseTestMixin, TestCase): @classmethod - @backend.mock_storage def setUpTestData(cls): BaseTestMixin.setUpTestData() - get_storage().create_bucket(settings.WORKSPACE_DATASETS_BUCKET) + storage.create_bucket(settings.WORKSPACE_DATASETS_BUCKET) cls.DATASET = Dataset.objects.create_if_has_perm( cls.USER_ADMIN, diff --git a/hexa/datasets/tests/test_schema.py b/hexa/datasets/tests/test_schema.py index c03ef401d..161f0ea96 100644 --- a/hexa/datasets/tests/test_schema.py +++ b/hexa/datasets/tests/test_schema.py @@ -1,11 +1,11 @@ import json +from io import BytesIO from django.conf import settings from django.db import IntegrityError from hexa.core.test import GraphQLTestCase -from hexa.files.api import get_storage -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage +from hexa.files import storage from hexa.user_management.models import User from hexa.workspaces.models import WorkspaceMembershipRole @@ -326,9 +326,9 @@ def test_workspace_datasets(self): class DatasetVersionTest(GraphQLTestCase, DatasetTestMixin): @classmethod - @mock_gcp_storage def setUpTestData(cls): - get_storage().create_bucket(settings.WORKSPACE_DATASETS_BUCKET) + storage.reset() + storage.create_bucket(settings.WORKSPACE_DATASETS_BUCKET) def test_create_dataset_version(self): superuser = self.create_user("superuser@blsq.com", is_superuser=True) @@ -379,7 +379,6 @@ def test_create_duplicate(self): with self.assertRaises(IntegrityError): dataset.create_version(principal=superuser, name="Version 1") - @mock_gcp_storage def test_generate_upload_url(self): superuser = self.create_user("superuser@blsq.com", is_superuser=True) workspace = self.create_workspace( @@ -413,7 +412,7 @@ def test_generate_upload_url(self): self.assertEqual( r["data"]["generateDatasetUploadUrl"], { - "uploadUrl": f"http://signed-url/{str(dataset.id)}/{str(dataset_version.id)}/uri_file.csv", + "uploadUrl": f"http://mockstorage.com/{settings.WORKSPACE_DATASETS_BUCKET}/{str(dataset.id)}/{str(dataset_version.id)}/uri_file.csv/upload", "success": True, "errors": [], }, @@ -505,7 +504,6 @@ def test_get_file_metadata(self): r["data"], ) - @mock_gcp_storage def test_prepare_version_file_download(self): serena = self.create_user("sereba@blsq.org", is_superuser=True) workspace = self.create_workspace( @@ -535,6 +533,11 @@ def test_prepare_version_file_download(self): version_file = self.create_dataset_version_file( olivia, 
dataset_version=dataset_version ) + storage.save_object( + settings.WORKSPACE_DATASETS_BUCKET, + version_file.uri, + BytesIO(b"some content"), + ) for user in [robert, olivia]: self.client.force_login(user) @@ -554,17 +557,15 @@ def test_prepare_version_file_download(self): } }, ) - self.assertEqual( { "success": True, "errors": [], - "downloadUrl": "http://signed-url/some-uri.csv", + "downloadUrl": f"http://mockstorage.com/{settings.WORKSPACE_DATASETS_BUCKET}/{version_file.uri}", }, r["data"]["prepareVersionFileDownload"], ) - @mock_gcp_storage def test_prepare_version_file_download_linked_dataset(self): serena = self.create_user("sereba@blsq.org", is_superuser=True) src_workspace = self.create_workspace( @@ -580,6 +581,11 @@ def test_prepare_version_file_download_linked_dataset(self): version_file = self.create_dataset_version_file( serena, dataset_version=dataset_version ) + storage.save_object( + settings.WORKSPACE_DATASETS_BUCKET, + version_file.uri, + BytesIO(b"some content"), + ) tgt_workspace = self.create_workspace( serena, "Target Workspace", "Test workspace" @@ -639,7 +645,7 @@ def test_prepare_version_file_download_linked_dataset(self): { "success": True, "errors": [], - "downloadUrl": "http://signed-url/some-uri.csv", + "downloadUrl": f"http://mockstorage.com/{settings.WORKSPACE_DATASETS_BUCKET}/{version_file.uri}", }, r["data"]["prepareVersionFileDownload"], ) diff --git a/hexa/datasets/tests/testutils.py b/hexa/datasets/tests/testutils.py index cb50dc21d..83338bca4 100644 --- a/hexa/datasets/tests/testutils.py +++ b/hexa/datasets/tests/testutils.py @@ -1,4 +1,3 @@ -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( Workspace, @@ -22,7 +21,6 @@ def create_feature_flag(*, code: str, user: User): feature, _ = Feature.objects.get_or_create(code=code) FeatureFlag.objects.create(feature=feature, user=user) - @mock_gcp_storage def create_workspace(self, principal: User, name, description, *args, **kwargs): workspace = Workspace.objects.create_if_has_perm( principal=principal, name=name, description=description, *args, **kwargs diff --git a/hexa/files/__init__.py b/hexa/files/__init__.py index e69de29bb..ce363e35b 100644 --- a/hexa/files/__init__.py +++ b/hexa/files/__init__.py @@ -0,0 +1,6 @@ +from django.utils.functional import SimpleLazyObject + +from .backends import get_storage_backend +from .backends.base import Storage + +storage: Storage = SimpleLazyObject(get_storage_backend) diff --git a/hexa/files/api.py b/hexa/files/api.py deleted file mode 100644 index d77ef752f..000000000 --- a/hexa/files/api.py +++ /dev/null @@ -1,16 +0,0 @@ -from django.conf import settings - -from .basefs import NotFound -from .gcp import GCPClient -from .s3 import S3Client - - -def get_storage(mode=settings.WORKSPACE_STORAGE_ENGINE): - if mode == "gcp": - return GCPClient() - if mode == "s3": - return S3Client() - raise Exception(f"unsupported filesystem {mode}") - - -NotFound = NotFound diff --git a/hexa/files/apps.py b/hexa/files/apps.py index 0f6862c89..94c9af6a3 100644 --- a/hexa/files/apps.py +++ b/hexa/files/apps.py @@ -1,6 +1,11 @@ -from django.apps import AppConfig +from hexa.app import CoreAppConfig -class FilesConfig(AppConfig): +class FilesConfig(CoreAppConfig): default_auto_field = "django.db.models.BigAutoField" name = "hexa.files" + + ANONYMOUS_URLS = [ + "files:upload_file", + "files:download_file", + ] diff --git a/hexa/files/backends/__init__.py 
b/hexa/files/backends/__init__.py new file mode 100644 index 000000000..e4dd91872 --- /dev/null +++ b/hexa/files/backends/__init__.py @@ -0,0 +1,12 @@ +from django.conf import settings +from django.utils.module_loading import import_string + + +def get_storage_backend(): + try: + backend_class = import_string(settings.WORKSPACE_STORAGE_BACKEND["engine"]) + return backend_class(**settings.WORKSPACE_STORAGE_BACKEND.get("options", {})) + except ImportError as e: + raise ImportError( + f"Could not import storage backend '{settings.WORKSPACE_STORAGE_BACKEND['engine']}'." + ) from e diff --git a/hexa/files/basefs.py b/hexa/files/backends/base.py similarity index 59% rename from hexa/files/basefs.py rename to hexa/files/backends/base.py index b35e8134d..fc9882220 100644 --- a/hexa/files/basefs.py +++ b/hexa/files/backends/base.py @@ -1,20 +1,15 @@ +import io import os import typing from abc import ABC, abstractmethod from dataclasses import dataclass from os.path import dirname, isfile, join - -class NotFound(Exception): - pass +from .exceptions import AlreadyExists, NotFound, SuspiciousFileOperation -class BucketObjectAlreadyExists(Exception): - def __init__(self, target_key): - self.message = ( - f"File already exists. Choose a different object key {target_key}." - ) - super().__init__(self.message) +class BadRequest(Exception): + pass @dataclass @@ -29,29 +24,59 @@ def load_bucket_sample_data_with(bucket_name: str, client_storage): """ Init bucket with default content """ - static_files_dir = join(dirname(__file__), "static") + static_files_dir = join(dirname(__file__), "..", "static") files = [ f for f in os.listdir(static_files_dir) if isfile(join(static_files_dir, f)) ] for file in files: - client_storage.upload_object(bucket_name, file, join(static_files_dir, file)) + client_storage.save_object( + bucket_name, file, open(join(static_files_dir, file), "rb") + ) -class BaseClient(ABC): +@dataclass +class StorageObject: + name: str + key: str + path: str + updated: str + type: str + size: int = 0 + content_type: str = None + + +class Storage(ABC): + storage_type = None + class exceptions: + BadRequest = BadRequest NotFound = NotFound + AlreadyExists = AlreadyExists + SuspiciousFileOperation = SuspiciousFileOperation + + @abstractmethod + def __init__(self, *args, **kwargs): + pass + + @abstractmethod + def bucket_exists(self, bucket_name: str): + pass @abstractmethod def create_bucket(self, bucket_name: str, *args, **kwargs): pass @abstractmethod - def delete_bucket(self, bucket_name: str, fully: bool = False): + def delete_object(self, bucket_name: str, object_key: str): + pass + + @abstractmethod + def delete_bucket(self, bucket_name: str, force: bool = False): pass @abstractmethod - def upload_object(self, bucket_name: str, file_name: str, source: str): + def save_object(self, bucket_name: str, file_path: str, file: io.BufferedReader): pass @abstractmethod @@ -60,7 +85,7 @@ def create_bucket_folder(self, bucket_name: str, folder_key: str): @abstractmethod def generate_download_url( - self, bucket_name: str, target_key: str, force_attachment=False + self, bucket_name: str, target_key: str, force_attachment=False, *args, **kwargs ): pass @@ -80,10 +105,6 @@ def list_bucket_objects( ): pass - @abstractmethod - def get_short_lived_downscoped_access_token(self, bucket_name): - pass - @abstractmethod def generate_upload_url( self, @@ -91,9 +112,11 @@ def generate_upload_url( target_key: str, content_type: str, raise_if_exists=False, + *args, + **kwargs, ): pass @abstractmethod - def 
get_token_as_env_variables(self, token): + def get_bucket_mount_config(self, bucket_name): pass diff --git a/hexa/files/backends/dummy.py b/hexa/files/backends/dummy.py new file mode 100644 index 000000000..3bac61cde --- /dev/null +++ b/hexa/files/backends/dummy.py @@ -0,0 +1,170 @@ +import io + +from .base import ObjectsPage, Storage, StorageObject + +dummy_buckets = {} + + +class DummyStorageClient(Storage): + storage_type = "dummy" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def buckets(self): + return dummy_buckets + + def reset(self): + dummy_buckets.clear() + + def bucket_exists(self, bucket_name: str): + # Mock checking if the bucket exists + return bucket_name in dummy_buckets + + def create_bucket(self, bucket_name: str, *args, **kwargs): + # Mock bucket creation + if bucket_name in dummy_buckets: + raise self.exceptions.AlreadyExists( + f"Bucket '{bucket_name}' already exists." + ) + dummy_buckets[bucket_name] = {} + return bucket_name + + def delete_object(self, bucket_name: str, object_key: str): + # Mock object deletion + if ( + bucket_name not in dummy_buckets + or object_key not in dummy_buckets[bucket_name] + ): + raise self.exceptions.NotFound( + f"Object '{object_key}' not found in bucket '{bucket_name}'." + ) + del dummy_buckets[bucket_name][object_key] + + def delete_bucket(self, bucket_name: str, force: bool = False): + # Mock bucket deletion + if bucket_name not in dummy_buckets: + raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") + if force is False and dummy_buckets[bucket_name]: + raise self.exceptions.BadRequest(f"Bucket '{bucket_name}' is not empty.") + del dummy_buckets[bucket_name] + + def save_object(self, bucket_name: str, file_path: str, file: io.BufferedReader): + # Mock saving an object in a bucket + if bucket_name not in dummy_buckets: + raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") + dummy_buckets[bucket_name][file_path] = file.read() + + def create_bucket_folder(self, bucket_name: str, folder_key: str): + # Mock creating a folder in a bucket + if bucket_name not in dummy_buckets: + raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") + if folder_key in dummy_buckets[bucket_name]: + raise self.exceptions.AlreadyExists( + f"Folder '{folder_key}' already exists." + ) + dummy_buckets[bucket_name][folder_key] = {} + + def generate_download_url( + self, bucket_name: str, target_key: str, force_attachment=False, *args, **kwargs + ): + # Mock generating a download URL + if ( + bucket_name not in dummy_buckets + or target_key not in dummy_buckets[bucket_name] + ): + raise self.exceptions.NotFound( + f"Object '{target_key}' not found in bucket '{bucket_name}'." + ) + return f"http://mockstorage.com/{bucket_name}/{target_key}" + + def _to_storage_object(self, bucket_name: str, object_key: str): + if not self.bucket_exists(bucket_name): + raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") + if object_key not in dummy_buckets[bucket_name]: + raise self.exceptions.NotFound( + f"Object '{object_key}' not found in bucket '{bucket_name}'." 
+            ) +        obj = dummy_buckets[bucket_name][object_key] +        if obj == {}: +            return StorageObject( +                name=object_key.split("/")[-1], +                key=object_key, +                path=f"{bucket_name}/{object_key}", +                updated="", +                type="directory", +            ) +        else: +            return StorageObject( +                name=object_key.split("/")[-1], +                key=object_key, +                path=f"{bucket_name}/{object_key}", +                updated="", +                type="file", +                size=len(obj), +            ) + +    def get_bucket_object(self, bucket_name: str, object_key: str): +        # Mock retrieving an object from a bucket +        if ( +            bucket_name not in dummy_buckets +            or object_key not in dummy_buckets[bucket_name] +        ): +            raise self.exceptions.NotFound( +                f"Object '{object_key}' not found in bucket '{bucket_name}'." +            ) +        return self._to_storage_object(bucket_name, object_key) + +    def list_bucket_objects( +        self, +        bucket_name, +        prefix=None, +        page: int = 1, +        per_page=30, +        query=None, +        ignore_hidden_files=True, +    ): +        # Mock listing objects in a bucket +        if bucket_name not in dummy_buckets: +            raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") +        object_keys = [] +        for key in dummy_buckets[bucket_name].keys(): +            if key.startswith(prefix or "") and (not query or query in key): +                object_keys.append(key) + +        start = (page - 1) * per_page +        end = start + per_page +        objects = [ +            self._to_storage_object(bucket_name, key) for key in object_keys[start:end] +        ] +        return ObjectsPage( +            items=objects, +            page_number=page, +            has_previous_page=page > 1, +            has_next_page=end < len(object_keys), +        ) + +    def generate_upload_url( +        self, +        bucket_name: str, +        target_key: str, +        content_type: str, +        raise_if_exists=False, +        *args, +        **kwargs, +    ): +        # Mock generating an upload URL +        if bucket_name not in dummy_buckets: +            raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") +        if raise_if_exists and target_key in dummy_buckets[bucket_name]: +            raise self.exceptions.AlreadyExists( +                f"Object '{target_key}' already exists." 
+            ) +        return f"http://mockstorage.com/{bucket_name}/{target_key}/upload" + +    def get_bucket_mount_config(self, bucket_name): +        # Mock retrieving bucket mount config +        if bucket_name not in dummy_buckets: +            raise self.exceptions.NotFound(f"Bucket '{bucket_name}' not found.") +        return {} diff --git a/hexa/files/backends/exceptions.py b/hexa/files/backends/exceptions.py new file mode 100644 index 000000000..a07993608 --- /dev/null +++ b/hexa/files/backends/exceptions.py @@ -0,0 +1,16 @@ +from django.core.exceptions import SuspiciousFileOperation + + +class StorageException(Exception): +    pass + + +class NotFound(StorageException): +    pass + + +class AlreadyExists(StorageException): +    pass + + +__all__ = ["StorageException", "NotFound", "AlreadyExists", "SuspiciousFileOperation"] diff --git a/hexa/files/backends/fs.py b/hexa/files/backends/fs.py new file mode 100644 index 000000000..ab09d7f73 --- /dev/null +++ b/hexa/files/backends/fs.py @@ -0,0 +1,330 @@ +import io +import os +import shutil +from datetime import datetime +from mimetypes import guess_type +from pathlib import Path + +from django.conf import settings +from django.core.files import locks +from django.core.signing import BadSignature, TimestampSigner +from django.urls import reverse +from django.utils._os import safe_join as django_safe_join +from django.utils.encoding import force_bytes, force_str +from django.utils.http import urlsafe_base64_decode, urlsafe_base64_encode +from django.utils.text import get_valid_filename + +from .base import ObjectsPage, Storage, StorageObject, load_bucket_sample_data_with + + +def safe_join(base, *paths): +    """ +    A version of django.utils._os.safe_join that returns a Path object. +    """ +    return Path(django_safe_join(base, *paths)) + + +class FileSystemStorage(Storage): +    storage_type = "local" + +    def __init__( +        self, +        data_dir: str, +        ext_bind_path: str = None, +        file_permissions_mode: int | None = None, +        directory_permissions_mode: int | None = None, +    ): +        """Initialises the FileSystemStorage backend. + +        Args: +            data_dir (str): Directory where the data will be stored. +            ext_bind_path (str, optional): When running in a Docker container, the path to data_dir as seen from the Docker engine's context. Defaults to None. +            file_permissions_mode (int, optional): File permissions mode. Defaults to None. +            directory_permissions_mode (int, optional): Directory permissions mode. Defaults to None. 
+ """ + self.data_dir = Path(data_dir) + self.ext_bind_path = Path(ext_bind_path) if ext_bind_path is not None else None + self.file_permissions_mode = file_permissions_mode + self.directory_permissions_mode = directory_permissions_mode + self._token_max_age = 60 * 60 # 1 hour + + def _ensure_location_group_id(self, full_path): + if os.name == "posix": + file_gid = os.stat(full_path).st_gid + data_dir_gid = os.stat(self.data_dir).st_gid + if file_gid != data_dir_gid: + try: + os.chown(full_path, uid=-1, gid=data_dir_gid) + except PermissionError: + pass + + def load_bucket_sample_data(self, bucket_name: str): + load_bucket_sample_data_with(bucket_name, self) + + def bucket_exists(self, bucket_name: str): + return self.exists(bucket_name) + + def is_directory_empty(self, bucket_name: str, *paths): + return not bool(next(os.scandir(self.path(bucket_name, *paths)), None)) + + def exists(self, name): + try: + exists = os.path.lexists(self.path(name)) + return exists + except self.exceptions.SuspiciousFileOperation: + raise + + def path(self, *paths): + return safe_join(self.data_dir, *paths) + + def size(self, name): + return os.path.getsize(name) + + def _get_payload_from_token(self, token): + try: + signer = TimestampSigner() + decoded_token = force_str(urlsafe_base64_decode(token)) + payload = signer.unsign_object(decoded_token, max_age=self._token_max_age) + return payload + except (UnicodeDecodeError, BadSignature, ValueError): + raise self.exceptions.BadRequest("Invalid token") + + def _create_token_for_payload(self, payload: dict): + signer = TimestampSigner() + signed_payload = signer.sign_object(payload, compress=True) + return urlsafe_base64_encode(force_bytes(signed_payload)) + + def to_storage_object(self, bucket_name: str, object_key: Path): + full_path = self.path(bucket_name, object_key) + if not self.exists(full_path): + raise self.exceptions.NotFound(f"Object {object_key} not found") + if full_path.is_file(): + return StorageObject( + name=object_key.name, + key=object_key, + updated=datetime.fromtimestamp( + os.path.getmtime(str(full_path)) + ).isoformat(), + size=self.size(full_path), + path=Path(bucket_name) / object_key, + type="file", + content_type=guess_type(full_path)[0] or "application/octet-stream", + ) + else: + return StorageObject( + name=object_key.name, + key=object_key, + updated=datetime.fromtimestamp(os.path.getmtime(str(full_path))), + path=Path(bucket_name) / object_key, + type="directory", + ) + + def get_valid_filepath(self, path: str | Path): + """Returns a path where all the directories and the filename are valid. + + Args: + path (str|Path): A path + """ + return "/".join(get_valid_filename(part) for part in str(path).split("/")) + + def create_directory(self, directory_path: str): + if self.exists(directory_path): + raise self.exceptions.AlreadyExists("Directory already exists") + directory = self.path(directory_path) + + try: + if self.directory_permissions_mode is not None: + # Set the umask because os.makedirs() doesn't apply the "mode" + # argument to intermediate-level directories. + old_umask = os.umask(0o777 & ~self.directory_permissions_mode) + try: + os.makedirs( + directory, self.directory_permissions_mode, exist_ok=True + ) + finally: + os.umask(old_umask) + else: + os.makedirs(directory, exist_ok=True) + except FileExistsError: + raise FileExistsError("%s exists and is not a directory." 
% directory) + + def create_bucket(self, bucket_name: str, *args, **kwargs): + if "/" in bucket_name: + raise self.exceptions.SuspiciousFileOperation( + "Bucket name cannot contain '/'" + ) + valid_bucket_name = get_valid_filename(bucket_name) + self.create_directory(valid_bucket_name) + return valid_bucket_name + + def delete_bucket(self, bucket_name: str, force: bool = False): + if not self.exists(bucket_name): + raise self.exceptions.NotFound(f"Bucket {bucket_name} not found") + if not force and not self.is_directory_empty(bucket_name): + raise self.exceptions.BadRequest("Bucket is not empty") + return shutil.rmtree(self.path(bucket_name)) + + def get_bucket_object_by_token(self, token: str): + payload = self._get_payload_from_token(token) + return self.get_bucket_object(payload["bucket_name"], payload["file_path"]) + + def save_object_by_token(self, token: str, file: io.BufferedReader): + payload = self._get_payload_from_token(token) + return self.save_object(payload["bucket_name"], payload["file_path"], file) + + def save_object( + self, bucket_name: str, file_path: str, file: io.BufferedReader | bytes + ): + full_path = self.path(bucket_name, file_path) + + # Create any intermediate directories that do not exist. + if not self.exists(full_path.parent): + self.create_directory(full_path.parent) + + f = open(full_path, "wb") + locks.lock(f, locks.LOCK_EX) + try: + if isinstance(file, bytes): + f.write(file) + else: + f.write(file.read()) + finally: + locks.unlock(f) + f.close() + + if self.file_permissions_mode is not None: + os.chmod(full_path, self.file_permissions_mode) + + # Ensure the moved file has the same gid as the storage root. + self._ensure_location_group_id(full_path) + + def create_bucket_folder(self, bucket_name: str, folder_key: str): + if not self.exists(bucket_name): + raise self.exceptions.NotFound(f"Bucket {bucket_name} not found") + folder_key = self.get_valid_filepath(folder_key) + self.create_directory(f"{bucket_name}/{folder_key}") + return self.get_bucket_object(bucket_name, folder_key) + + def get_bucket_object(self, bucket_name: str, object_key: str): + if not self.exists(bucket_name): + raise self.exceptions.NotFound(f"Bucket {bucket_name} not found") + full_path = self.path(bucket_name, object_key) + if not self.exists(full_path): + raise self.exceptions.NotFound(f"Object {object_key} not found") + return self.to_storage_object(bucket_name, Path(object_key)) + + def list_bucket_objects( + self, + bucket_name, + prefix="", + page: int = 1, + per_page=30, + query: str = None, + ignore_hidden_files=True, + ): + if prefix is None: + prefix = "" + full_path = self.path(bucket_name, prefix) + if not os.path.exists(full_path): + raise self.exceptions.NotFound(f"Bucket {bucket_name} not found") + if not os.path.isdir(full_path): + raise self.exceptions.NotFound(f"Bucket {bucket_name} is not a directory") + + def does_object_match(name): + if ignore_hidden_files and name.startswith("."): + return False + if query: + return query.lower() in name.lower() + return True + + objects = [] + root, dirs, files = next(os.walk(full_path)) + for dir in dirs: + if does_object_match(dir) is False: + continue + dir_key = Path(prefix) / dir + objects.append(self.to_storage_object(bucket_name, dir_key)) + for file in files: + if does_object_match(file) is False: + continue + object_key = Path(prefix) / file + objects.append(self.to_storage_object(bucket_name, object_key)) + + offset = (page - 1) * per_page + return ObjectsPage( + items=objects[offset : offset + per_page], + 
page_number=page, + has_previous_page=page > 1, + has_next_page=len(objects) > page * per_page, + ) + + def delete_object(self, bucket_name: str, object_key: str): + full_path = self.path(bucket_name, object_key) + if not self.exists(full_path): + raise self.exceptions.NotFound(f"Object {object_key} not found") + obj = self.get_bucket_object(bucket_name, object_key) + if obj.type == "directory": + shutil.rmtree(full_path) + else: + os.remove(full_path) + + def generate_upload_url( + self, + bucket_name: str, + target_key: str, + raise_if_exists=False, + host: str | None = None, + *args, + **kwargs, + ): + if not self.exists(bucket_name): + raise self.exceptions.NotFound(f"Bucket {bucket_name} not found") + full_path = self.path(bucket_name, target_key) + if self.exists(full_path) and raise_if_exists: + raise self.exceptions.AlreadyExists(f"Object {target_key} already exist") + + token = self._create_token_for_payload( + {"bucket_name": bucket_name, "file_path": target_key} + ) + internal_url = reverse("files:upload_file", args=(token,)) + if host is None: + host = settings.NEW_FRONTEND_DOMAIN + + return f"{host}{internal_url}" + + def generate_download_url( + self, + bucket_name: str, + target_key: str, + force_attachment=False, + host: str | None = None, + *args, + **kwargs, + ): + if not self.exists(bucket_name): + raise self.exceptions.NotFound(f"Bucket {bucket_name} not found") + full_path = self.path(bucket_name, target_key) + if not self.exists(full_path): + raise self.exceptions.NotFound(f"Object {target_key} not found") + + token = self._create_token_for_payload( + {"bucket_name": bucket_name, "file_path": target_key} + ) + endpoint = reverse("files:download_file", args=(token,)) + if force_attachment: + endpoint += "?attachment=true" + + if host is None: + host = settings.BASE_URL + + return f"{host}{endpoint}" + + def get_bucket_mount_config(self, bucket_name): + return { + "WORKSPACE_STORAGE_MOUNT_PATH": str( + safe_join( + self.ext_bind_path if self.ext_bind_path else self.data_dir, + bucket_name, + ) + ), + } diff --git a/hexa/files/gcp.py b/hexa/files/backends/gcp.py similarity index 74% rename from hexa/files/gcp.py rename to hexa/files/backends/gcp.py index a3199db6c..72bb3bc8d 100644 --- a/hexa/files/gcp.py +++ b/hexa/files/backends/gcp.py @@ -1,33 +1,28 @@ import base64 +import io import json import requests from django.conf import settings from django.core.exceptions import ValidationError from google.cloud import storage -from google.cloud.exceptions import Conflict +from google.cloud.exceptions import Conflict, NotFound from google.cloud.iam_credentials_v1 import IAMCredentialsClient from google.cloud.storage.blob import Blob from google.oauth2 import service_account from google.protobuf import duration_pb2 -from .basefs import ( - BaseClient, - BucketObjectAlreadyExists, - NotFound, - ObjectsPage, - load_bucket_sample_data_with, -) +from .base import ObjectsPage, Storage, StorageObject, load_bucket_sample_data_with -def get_credentials(): - decoded_creds = base64.b64decode(settings.GCS_SERVICE_ACCOUNT_KEY) +def get_credentials(service_account_key): + decoded_creds = base64.b64decode(service_account_key) json_creds = json.loads(decoded_creds, strict=False) return service_account.Credentials.from_service_account_info(json_creds) -def get_storage_client(): - credentials = get_credentials() +def get_storage_client(service_account_key): + credentials = get_credentials(service_account_key) return storage.Client(credentials=credentials) @@ -35,26 +30,25 @@ def _is_dir(blob): 
return blob.size == 0 and blob.name.endswith("/") -def _blob_to_dict(blob: Blob): - return { - "name": blob.name.split("/")[-2] if _is_dir(blob) else blob.name.split("/")[-1], - "key": blob.name, - "path": "/".join([blob.bucket.name, blob.name]), - "content_type": blob.content_type, - "updated": blob.updated, - "size": blob.size, - "type": "directory" if _is_dir(blob) else "file", - } +def _blob_to_obj(blob: Blob): + return StorageObject( + name=blob.name.split("/")[-2] if _is_dir(blob) else blob.name.split("/")[-1], + key=blob.name, + path="/".join([blob.bucket.name, blob.name]), + content_type=blob.content_type, + updated=blob.updated, + size=blob.size, + type="directory" if _is_dir(blob) else "file", + ) -def _prefix_to_dict(bucket_name, name: str): - return { - "name": name.split("/")[-2], - "key": name, - "path": "/".join([bucket_name, name]), - "size": 0, - "type": "directory", - } +def _prefix_to_obj(bucket_name, name: str): + return StorageObject( + name=name.split("/")[-2], + key=name, + path="/".join([bucket_name, name]), + type="directory", + ) def iter_request_results(bucket_name, request): @@ -69,14 +63,14 @@ def iter_request_results(bucket_name, request): current_page = next(pages) for prefix in sorted(prefixes): - yield _prefix_to_dict(bucket_name, prefix) + yield _prefix_to_obj(bucket_name, prefix) while True: for blob in current_page: if not _is_dir(blob): # We ignore objects that are directories (object with a size = 0 and ending with a /) # because they are already listed in the prefixes - yield _blob_to_dict(blob) + yield _blob_to_obj(blob) try: current_page = next(pages) except StopIteration: @@ -89,17 +83,39 @@ def ensure_is_folder(object_key: str): return object_key -class GCPClient(BaseClient): - def create_bucket(self, bucket_name: str, labels: dict = None, *args, **kwargs): - client = get_storage_client() +class GoogleCloudStorage(Storage): + storage_type = "gcp" + _client = None + + def __init__(self, service_account_key: str, region: str, enable_versioning=False): + super().__init__() + self._service_account_key = service_account_key + self.region = region + self.enable_versioning = enable_versioning + + @property + def client(self): + if self._client is None: + self._client = get_storage_client(self._service_account_key) + return self._client + + def bucket_exists(self, bucket_name: str): try: - bucket = client.create_bucket( - bucket_name, location=settings.WORKSPACE_BUCKET_REGION + self.client.get_bucket(bucket_name) + return True + except NotFound: + return False + + def create_bucket(self, bucket_name: str, labels: dict = None, *args, **kwargs): + if self.bucket_exists(bucket_name): + raise self.exceptions.AlreadyExists( + f"GCS: Bucket {bucket_name} already exists!" ) + try: + bucket = self.client.create_bucket(bucket_name, location=self.region) bucket.storage_class = "STANDARD" # Default storage class - if settings.WORKSPACE_BUCKET_VERSIONING_ENABLED: - bucket.versioning_enabled = True + bucket.versioning_enabled = self.enable_versioning # Set lifecycle rules # 1. Transition to "Nearline" Storage: Objects that haven't been accessed for 30 days can be moved to "Nearline" storage, which is cost-effective for data accessed less than once a month. 
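Note on configuration: the backend classes in this patch are instantiated by get_storage_backend() from the WORKSPACE_STORAGE_BACKEND setting (an "engine" dotted path plus an "options" dict passed to the constructor, see hexa/files/backends/__init__.py above). A minimal sketch of such a setting, assuming dotted paths that follow the new module layout and purely illustrative option values:

# Illustrative only: option names mirror FileSystemStorage.__init__ / GoogleCloudStorage.__init__
WORKSPACE_STORAGE_BACKEND = {
    "engine": "hexa.files.backends.fs.FileSystemStorage",
    "options": {
        "data_dir": "/data/workspaces",  # where buckets live on disk (example path)
        "ext_bind_path": "/home/openhexa/workspaces",  # path as seen by the Docker engine (optional)
    },
}

# Or, for the Google Cloud Storage backend defined in this file:
# WORKSPACE_STORAGE_BACKEND = {
#     "engine": "hexa.files.backends.gcp.GoogleCloudStorage",
#     "options": {
#         "service_account_key": "<base64-encoded service account JSON>",
#         "region": "europe-west1",  # example region
#         "enable_versioning": False,
#     },
# }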
@@ -147,32 +163,29 @@ def create_bucket(self, bucket_name: str, labels: dict = None, *args, **kwargs): ] bucket.patch() - return bucket + return bucket.name except Conflict: raise ValidationError(f"GCS: Bucket {bucket_name} already exists!") - def upload_object(self, bucket_name: str, file_name: str, source: str): - client = get_storage_client() - bucket = client.bucket(bucket_name) - blob = bucket.blob(file_name) - blob.upload_from_filename(source) + def save_object(self, bucket_name: str, file_path: str, file: io.BufferedReader): + bucket = self.client.bucket(bucket_name) + blob = bucket.blob(file_path) + blob.upload_from_file(file) def create_bucket_folder(self, bucket_name: str, folder_key: str): - client = get_storage_client() - bucket = client.get_bucket(bucket_name) + bucket = self.client.get_bucket(bucket_name) object = bucket.blob(ensure_is_folder(folder_key)) object.upload_from_string( "", content_type="application/x-www-form-urlencoded;charset=UTF-8" ) - return _blob_to_dict(object) + return _blob_to_obj(object) def generate_download_url( - self, bucket_name: str, target_key: str, force_attachment=False + self, bucket_name: str, target_key: str, force_attachment=False, *args, **kwargs ): - client = get_storage_client() - gcs_bucket = client.get_bucket(bucket_name) - blob: Blob = gcs_bucket.get_blob(target_key) + gcs_bucket = self.client.get_bucket(bucket_name) + blob = gcs_bucket.get_blob(target_key) if blob is None: return None @@ -193,24 +206,24 @@ def generate_upload_url( target_key: str, content_type: str = None, raise_if_exists: bool = False, + *args, + **kwargs, ): - client = get_storage_client() - gcs_bucket = client.get_bucket(bucket_name) + gcs_bucket = self.client.get_bucket(bucket_name) if raise_if_exists and gcs_bucket.get_blob(target_key) is not None: - raise BucketObjectAlreadyExists(target_key) + raise self.exceptions.AlreadyExists(target_key) blob = gcs_bucket.blob(target_key) return blob.generate_signed_url( expiration=3600, version="v4", method="PUT", content_type=content_type ) def get_bucket_object(self, bucket_name: str, object_key: str): - client = get_storage_client() - bucket = client.get_bucket(bucket_name) + bucket = self.client.get_bucket(bucket_name) object = bucket.get_blob(object_key) if object is None: - raise NotFound("Object not found") + raise self.exceptions.NotFound("Object not found") - return _blob_to_dict(object) + return _blob_to_obj(object) def list_bucket_objects( self, @@ -233,9 +246,7 @@ def list_bucket_objects( ignore_hidden_files (bool, optional): Returns the hidden files and directories if `False`. Defaults to True. 
""" - client = get_storage_client() - - request = client.list_blobs( + request = self.client.list_blobs( bucket_name, prefix=prefix, # We take twice the number of items to be sure to have enough @@ -250,11 +261,11 @@ def list_bucket_objects( objects = [] def is_object_match_query(obj): - if ignore_hidden_files and obj["name"].startswith("."): + if ignore_hidden_files and obj.name.startswith("."): return False if not query: return True - return query.lower() in obj["name"].lower() + return query.lower() in obj.name.lower() iterator = iter_request_results(bucket_name, request) while True: @@ -288,7 +299,7 @@ def get_short_lived_downscoped_access_token(self, bucket_name): token_lifetime = 3600 if settings.GCS_TOKEN_LIFETIME is not None: token_lifetime = int(settings.GCS_TOKEN_LIFETIME) - source_credentials = get_credentials() + source_credentials = get_credentials(self._service_account_key) iam_credentials = IAMCredentialsClient(credentials=source_credentials) iam_token = iam_credentials.generate_access_token( @@ -323,27 +334,28 @@ def get_short_lived_downscoped_access_token(self, bucket_name): }, ) payload = response.json() - return [payload["access_token"], payload["expires_in"], "gcp"] + return payload["access_token"], payload["expires_in"] - def delete_bucket(self, bucket_name: str, fully: bool = False): - return get_storage_client().delete_bucket(bucket_name) + def delete_bucket(self, bucket_name: str, force: bool = False): + return self.client.delete_bucket(bucket_name) def delete_object(self, bucket_name: str, file_name: str): - client = get_storage_client() - bucket = client.get_bucket(bucket_name) + bucket = self.client.get_bucket(bucket_name) blob = bucket.get_blob(file_name) + if blob is None: + raise self.exceptions.NotFound("Object not found") if _is_dir(blob): blobs = list(bucket.list_blobs(prefix=file_name)) bucket.delete_blobs(blobs) else: bucket.delete_blob(file_name) - return def load_bucket_sample_data(self, bucket_name: str): return load_bucket_sample_data_with(bucket_name, self) - def get_token_as_env_variables(self, token): + def get_bucket_mount_config(self, bucket_name): + token, _ = self.get_short_lived_downscoped_access_token(bucket_name) return { - "GCS_TOKEN": token, # FIXME: Once we have deployed the new openhexa-bslq-environment image and upgraded the openhexa-app, we can remove this line + "WORKSPACE_STORAGE_ENGINE_GCP_BUCKET_NAME": bucket_name, "WORKSPACE_STORAGE_ENGINE_GCP_ACCESS_TOKEN": token, } diff --git a/hexa/files/graphql/schema.graphql b/hexa/files/graphql/schema.graphql index a3e814895..ee910af88 100644 --- a/hexa/files/graphql/schema.graphql +++ b/hexa/files/graphql/schema.graphql @@ -131,6 +131,7 @@ Errors that can occur when creating a folder in a workspace's bucket. 
""" enum CreateBucketFolderError { ALREADY_EXISTS + NOT_FOUND PERMISSION_DENIED } diff --git a/hexa/files/s3.py b/hexa/files/s3.py deleted file mode 100644 index ba1ff676e..000000000 --- a/hexa/files/s3.py +++ /dev/null @@ -1,438 +0,0 @@ -import base64 -import json - -import boto3 -from django.conf import settings -from django.core.exceptions import ValidationError - -from .basefs import ( - BaseClient, - BucketObjectAlreadyExists, - NotFound, - ObjectsPage, - load_bucket_sample_data_with, -) - -default_region = "eu-central-1" - - -def get_storage_client(type="s3", endpoint_url=None): - """Type is the boto client type s3 by default but can be sts or other client api""" - if endpoint_url is None: - endpoint_url = settings.WORKSPACE_STORAGE_ENGINE_AWS_ENDPOINT_URL - - s3 = boto3.client( - type, - endpoint_url=endpoint_url, - aws_access_key_id=settings.WORKSPACE_STORAGE_ENGINE_AWS_ACCESS_KEY_ID, - aws_secret_access_key=settings.WORKSPACE_STORAGE_ENGINE_AWS_SECRET_ACCESS_KEY, - region_name=settings.WORKSPACE_STORAGE_ENGINE_AWS_BUCKET_REGION, - ) - return s3 - - -def _is_dir(blob): - return blob["Size"] == 0 and blob["Key"].endswith("/") - - -def _is_dir_object(blob, object_key): - return blob["ContentLength"] == 0 and object_key.endswith("/") - - -def _blob_to_dict(blob, bucket_name): - name = blob["Key"] - return { - "name": name.split("/")[-2] if _is_dir(blob) else name.split("/")[-1], - "key": name, - "path": "/".join([bucket_name, name]), - "content_type": blob.get("ContentType"), - "updated": blob["LastModified"], - "size": blob["Size"], - "type": "directory" if _is_dir(blob) else "file", - } - - -def _object_to_dict(blob, bucket_name, object_key): - name = object_key - return { - "name": ( - name.split("/")[-2] - if _is_dir_object(blob, object_key) - else name.split("/")[-1] - ), - "key": name, - "path": "/".join([bucket_name, name]), - "content_type": blob.get("ContentType"), - "updated": blob["LastModified"], - "size": blob["ContentLength"], - "type": "directory" if _is_dir_object(blob, object_key) else "file", - } - - -def _prefix_to_dict(bucket_name: str, name: str): - return { - "name": name.split("/")[-2], - "key": name, - "path": "/".join([bucket_name, name]), - "size": 0, - "type": "directory", - } - - -# allows to keep the test compatible between gcp and s3 -# for the fixture, the tests creates blobs -class S3BucketWrapper: - def __init__(self, bucket_name) -> None: - self.bucket_name = bucket_name - self.name = bucket_name # keep backward compat with gcp - - def blob(self, file_name, size=None, content_type="text/plain"): - get_storage_client().put_object( - Body="file_name", - Bucket=self.bucket_name, - Key=file_name, - ContentType=content_type, - ) - - -def ensure_is_folder(object_key: str): - if object_key.endswith("/") is False: - return object_key + "/" - return object_key - - -def _get_bucket_object(bucket_name: str, object_key: str): - client = get_storage_client() - try: - object = client.head_object(Bucket=bucket_name, Key=object_key) - - except Exception as e: - if "the HeadObject operation: Not Found" in str(e): - raise NotFound(f"{bucket_name} {object_key} not found") - # else just throw the initial error - raise e - - return _object_to_dict(object, bucket_name=bucket_name, object_key=object_key) - - -class S3Client(BaseClient): - def create_bucket(self, bucket_name: str, *args, **kwargs): - s3 = get_storage_client() - try: - s3.create_bucket( - Bucket=bucket_name, - CreateBucketConfiguration={"LocationConstraint": default_region}, - ) - # Define the configuration 
rules - cors_configuration = { - "CORSRules": [ - { - "AllowedHeaders": ["*"], - "AllowedMethods": ["GET", "PUT"], - "AllowedOrigins": ["*"], - "ExposeHeaders": ["*"], - "MaxAgeSeconds": 3000, - } - ] - } - - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration=cors_configuration) - return S3BucketWrapper(bucket_name) - except s3.exceptions.ClientError as exc: - # https://github.com/VeemsHQ/veems/blob/3e2e75c3407bc1f98395fe94c0e03367a82852c9/veems/media/upload_manager.py#L51C1-L51C1 - - if "NotImplemented" in str(exc): - from warnings import warn - - warn( - "Put bucket CORS failed. " - "Only if using Minio S3 backend is this okay, " - "otherwise investigate. %s", - DeprecationWarning, - stacklevel=2, - ) - else: - if "BucketAlreadyOwnedByYou" in str(exc): - raise ValidationError(f"{bucket_name} already exist") - - raise exc - - except s3.exceptions.BucketAlreadyOwnedByYou: - # https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html - raise ValidationError(f"{bucket_name} already exist") - - return S3BucketWrapper(bucket_name) - - def upload_object(self, bucket_name: str, file_name: str, source: str): - return get_storage_client().upload_file(source, bucket_name, file_name) - - def create_bucket_folder(self, bucket_name: str, folder_key: str): - s3 = get_storage_client() - - object = { - "Body": "", - "Bucket": bucket_name, - "Key": ensure_is_folder(folder_key), - "ContentType": "application/x-www-form-urlencoded;charset=UTF-8", - } - s3.put_object(**object) - - final_object = s3.get_object(Bucket=bucket_name, Key=object["Key"]) - - # the get_object isn't the same payload as the list :( - final_object["Key"] = object["Key"] - final_object["Size"] = 0 - return _blob_to_dict(final_object, bucket_name) - - def generate_download_url( - self, bucket_name: str, target_key: str, force_attachment=False - ): - url = self.generate_presigned_url( - "get_object", - Params={"Bucket": bucket_name, "Key": target_key}, - ExpiresIn=3600, - ) - return url - - def generate_presigned_url( - self, ClientMethod: str, Params=None, ExpiresIn=3600, HttpMethod=None - ): - # Since this URL will be used by the client, we need to use the client endpoint - s3_client = get_storage_client( - endpoint_url=settings.WORKSPACE_STORAGE_ENGINE_AWS_PUBLIC_ENDPOINT_URL - ) - url = s3_client.generate_presigned_url( - ClientMethod=ClientMethod, - Params=Params, - ExpiresIn=ExpiresIn, # URL expiration time in seconds - HttpMethod=HttpMethod, - ) - return url - - def generate_upload_url( - self, - bucket_name: str, - target_key: str, - content_type: str = None, - raise_if_exists: bool = False, - ): - s3_client = get_storage_client() - - if raise_if_exists: - try: - s3_client.head_object(Bucket=bucket_name, Key=target_key) - raise BucketObjectAlreadyExists(target_key) - except s3_client.exceptions.ClientError as e: - if e.response["Error"]["Code"] != "404": - # don't hide non "not found errors" - raise e - - params = { - "Bucket": bucket_name, - "Key": target_key, - } - - if content_type: - params["ContentType"] = content_type - - url = self.generate_presigned_url( - ClientMethod="put_object", - Params=params, - ExpiresIn=3600, # URL expiration time in seconds - ) - - return url - - def get_bucket_object(self, bucket_name: str, object_key: str): - return _get_bucket_object(bucket_name, object_key) - - def list_bucket_objects( - self, - bucket_name, - prefix=None, - page: int = 1, - per_page=30, - query=None, - ignore_hidden_files=True, - ): - prefix = prefix or "" - - start_offset = (page - 1) * per_page - 
end_offset = page * per_page - paginator = get_storage_client().get_paginator("list_objects_v2") - - pages = paginator.paginate( - Bucket=bucket_name, - Delimiter="/", - Prefix=prefix, - ) - - def is_object_match_query(obj): - if ignore_hidden_files and obj["name"].startswith("."): - return False - if not query: - return True - return query.lower() in obj["name"].lower() - - pageIndex = 0 - for response in pages: - pageIndex = pageIndex + 1 - - prefixes = ( - sorted([x["Prefix"] for x in response["CommonPrefixes"]]) - if "CommonPrefixes" in response - else [] - ) - objects = [] - - for current_prefix in prefixes: - res = _prefix_to_dict(bucket_name, current_prefix) - if is_object_match_query(res): - objects.append(res) - - files = response.get("Contents", []) - - for file in files: - if _is_dir(file): - # We ignore objects that are directories (object with a size = 0 and ending with a /) - # because they are already listed in the prefixes - continue - - res = _blob_to_dict(file, bucket_name) - - if res["key"] == prefix and prefix.endswith("/"): - continue - - if is_object_match_query(res): - objects.append(res) - - sorted(objects, key=lambda x: x["key"]) - items = objects[start_offset:end_offset] - - return ObjectsPage( - items=items, - page_number=page, - has_previous_page=page > 1, - has_next_page=len(objects) > (page * per_page), - ) - - # TODO handle read-only mode. - def get_short_lived_downscoped_access_token(self, bucket_name): - # highly inspired by https://gist.github.com/manics/305f4cc56d0ac6431893cde17b1ba8c4 - - token_lifetime = 3600 - if settings.GCS_TOKEN_LIFETIME is not None: - token_lifetime = int(settings.GCS_TOKEN_LIFETIME) - - sts_service = get_storage_client("sts") - prefix = "*" - - # Access policies - # https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html#policies_session - # https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html - # https://aws.amazon.com/premiumsupport/knowledge-center/s3-folder-user-access/ - - # TODO adapt/verify that the policy matches the gcp token one - policy = { - "Version": "2012-10-17", - "Statement": [ - { - "Sid": "ListObjectsInBucket", - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": [f"arn:aws:s3:::{bucket_name}"], - "Condition": {"StringLike": {"s3:prefix": [prefix]}}, - }, - { - "Sid": "ManageObjectsInBucket", - "Effect": "Allow", - "Action": "s3:*", - "Resource": [f"arn:aws:s3:::{bucket_name}/{prefix}"], - }, - ], - } - - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts.html#STS.Client.assume_role - # TODO what if we really want a real s3 see *ignored* ? - # - we probably need the roleArn and RoleSessionName ? - # - is this AWS_USER_ARN and AWS_APP_ROLE_ARN ? 
- response = sts_service.assume_role( - RoleArn=settings.AWS_APP_ROLE_ARN or "arn:x:ignored:by:minio:", - RoleSessionName="ignored-by-minio", - Policy=json.dumps(policy), - DurationSeconds=token_lifetime, - ) - - return [ - { - "endpoint_url": sts_service.__dict__["meta"].__dict__["_endpoint_url"], - "aws_access_key_id": response["Credentials"]["AccessKeyId"], - "aws_secret_access_key": response["Credentials"]["SecretAccessKey"], - "aws_session_token": response["Credentials"]["SessionToken"], - }, - response["Credentials"]["Expiration"], - "s3", - ] - - def delete_bucket(self, bucket_name: str, fully: bool = False): - s3 = get_storage_client() - try: - response = s3.list_objects_v2( - Bucket=bucket_name, - ) - - if fully: - while response["KeyCount"] > 0: - # print("Deleting %d objects from bucket %s" % (len(response["Contents"]), bucket_name)) - response = s3.delete_objects( - Bucket=bucket_name, - Delete={ - "Objects": [ - {"Key": obj["Key"]} for obj in response["Contents"] - ] - }, - ) - response = s3.list_objects_v2( - Bucket=bucket_name, - ) - return s3.delete_bucket(Bucket=bucket_name) - except s3.exceptions.NoSuchBucket: - return - except s3.exceptions.ClientError as exc: - if "InvalidBucketName" in str(exc): - return - - raise exc - - def delete_object(self, bucket_name: str, file_name: str): - client = get_storage_client() - blob = _get_bucket_object(bucket_name=bucket_name, object_key=file_name) - if blob["type"] == "directory": - blobs = self.list_bucket_objects(bucket_name=bucket_name, prefix=file_name) - client.delete_objects( - Bucket=bucket_name, - Delete={"Objects": [{"Key": b["key"]} for b in blobs]}, - ) - else: - client.delete_object(Bucket=bucket_name, Key=file_name) - return - - def load_bucket_sample_data(self, bucket_name: str): - return load_bucket_sample_data_with(bucket_name, self) - - def get_token_as_env_variables(self, token): - # the fuse config - json_config = { - "AWS_ENDPOINT": token["endpoint_url"], - "AWS_ACCESS_KEY_ID": token["aws_access_key_id"], - "AWS_SECRET_ACCESS_KEY": token["aws_secret_access_key"], - "AWS_SESSION_TOKEN": token["aws_session_token"], - "AWS_DEFAULT_REGION": token.get("default_region", ""), - "AWS_ENDPOINT_URL": token["endpoint_url"], - } - - return { - "WORKSPACE_STORAGE_ENGINE_S3_FUSE_CONFIG": base64.b64encode( - json.dumps(json_config).encode() - ).decode(), - } diff --git a/hexa/files/schema/mutations.py b/hexa/files/schema/mutations.py index af94190ac..38d8bcbae 100644 --- a/hexa/files/schema/mutations.py +++ b/hexa/files/schema/mutations.py @@ -1,7 +1,7 @@ from ariadne import MutationType from hexa.analytics.api import track -from hexa.files.api import get_storage +from hexa.files import storage from hexa.workspaces.models import Workspace mutations = MutationType() @@ -18,9 +18,9 @@ def resolve_delete_bucket_object(_, info, **kwargs): if not request.user.has_perm("files.delete_object", workspace): return {"success": False, "errors": ["PERMISSION_DENIED"]} - get_storage().delete_object(workspace.bucket_name, mutation_input["objectKey"]) + storage.delete_object(workspace.bucket_name, mutation_input["objectKey"]) return {"success": True, "errors": []} - except (get_storage().exceptions.NotFound, Workspace.DoesNotExist): + except (storage.exceptions.NotFound, Workspace.DoesNotExist): return {"success": False, "errors": ["NOT_FOUND"]} @@ -36,7 +36,7 @@ def resolve_prepare_download_object(_, info, **kwargs): if not request.user.has_perm("files.download_object", workspace): return {"success": False, "errors": ["PERMISSION_DENIED"]} 
object_key = mutation_input["objectKey"] - download_url = get_storage().generate_download_url( + download_url = storage.generate_download_url( workspace.bucket_name, object_key, force_attachment=True ) track( @@ -46,7 +46,7 @@ def resolve_prepare_download_object(_, info, **kwargs): ) return {"success": True, "download_url": download_url, "errors": []} - except (get_storage().exceptions.NotFound, Workspace.DoesNotExist): + except (storage.exceptions.NotFound, Workspace.DoesNotExist): return {"success": False, "errors": ["NOT_FOUND"]} @@ -62,12 +62,12 @@ def resolve_prepare_upload_object(_, info, **kwargs): if not request.user.has_perm("files.create_object", workspace): return {"success": False, "errors": ["PERMISSION_DENIED"]} object_key = mutation_input["objectKey"] - upload_url = get_storage().generate_upload_url( + upload_url = storage.generate_upload_url( workspace.bucket_name, object_key, mutation_input.get("contentType") ) return {"success": True, "upload_url": upload_url, "errors": []} - except (get_storage().exceptions.NotFound, Workspace.DoesNotExist): + except (storage.exceptions.NotFound, Workspace.DoesNotExist): return {"success": False, "errors": ["NOT_FOUND"]} @@ -83,12 +83,10 @@ def resolve_create_bucket_folder(_, info, **kwargs): if not request.user.has_perm("files.create_object", workspace): return {"success": False, "errors": ["PERMISSION_DENIED"]} folder_key = mutation_input["folderKey"] - folder_object = get_storage().create_bucket_folder( - workspace.bucket_name, folder_key - ) + folder_object = storage.create_bucket_folder(workspace.bucket_name, folder_key) return {"success": True, "folder": folder_object, "errors": []} - except (get_storage().exceptions.NotFound, Workspace.DoesNotExist): + except (storage.exceptions.NotFound, Workspace.DoesNotExist): return {"success": False, "errors": ["NOT_FOUND"]} diff --git a/hexa/files/schema/types.py b/hexa/files/schema/types.py index 9a978d52e..acbbfa1f4 100644 --- a/hexa/files/schema/types.py +++ b/hexa/files/schema/types.py @@ -2,7 +2,8 @@ from django.core.exceptions import ImproperlyConfigured from django.http import HttpRequest -from hexa.files.api import NotFound, get_storage +from hexa.files import storage +from hexa.files.backends.base import StorageObject from hexa.workspaces.models import Workspace from hexa.workspaces.schema.types import workspace_object, workspace_permissions @@ -53,7 +54,7 @@ def resolve_bucket_objects( ): if workspace.bucket_name is None: raise ImproperlyConfigured("Workspace does not have a bucket") - page = get_storage().list_bucket_objects( + page = storage.list_bucket_objects( workspace.bucket_name, prefix=prefix, page=page, @@ -70,8 +71,8 @@ def resolve_bucket_object(workspace, info, key, **kwargs): if workspace.bucket_name is None: raise ImproperlyConfigured("Workspace does not have a bucket") try: - return get_storage().get_bucket_object(workspace.bucket_name, key) - except NotFound: + return storage.get_bucket_object(workspace.bucket_name, key) + except storage.exceptions.NotFound: return None @@ -79,8 +80,8 @@ def resolve_bucket_object(workspace, info, key, **kwargs): @bucket_object_object.field("type") -def resolve_object_type(obj, info): - return obj["type"].upper() +def resolve_object_type(obj: StorageObject, info): + return obj.type.upper() bindables = [bucket_object, bucket_object_object] diff --git a/hexa/files/tests/backends/__init__.py b/hexa/files/tests/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/hexa/files/tests/backends/test_fs.py 
b/hexa/files/tests/backends/test_fs.py new file mode 100644 index 000000000..1752c3ec7 --- /dev/null +++ b/hexa/files/tests/backends/test_fs.py @@ -0,0 +1,271 @@ +import os +import shutil +from io import BytesIO +from pathlib import Path +from tempfile import mkdtemp +from unittest.mock import patch + +from django.core.files.uploadedfile import SimpleUploadedFile +from django.test import override_settings + +from hexa.core.test import TestCase +from hexa.files.backends.fs import FileSystemStorage + + +class FileSystemStorageTest(TestCase): + storage = None + + def setUp(self): + super().setUp() + self.data_directory = Path(mkdtemp()) + self.storage = FileSystemStorage( + data_dir=self.data_directory, + ) + + def tearDown(self) -> None: + super().tearDown() + shutil.rmtree(self.data_directory) + + def test_create_bucket(self): + self.storage.create_bucket("test") + self.assertTrue(self.storage.exists("test")) + self.assertTrue(os.path.lexists(self.storage.path("test"))) + self.assertEqual(self.storage.path("test"), self.data_directory / "test") + + def test_suspicious_create_bucket(self): + for path in ["../test", "/test", "../../test", "dir/subdir"]: + with self.assertRaises(self.storage.exceptions.SuspiciousFileOperation): + self.storage.create_bucket(path) + + def test_path(self): + self.assertEqual( + self.storage.path("my-dir/my-subdir/my-file.png"), + self.data_directory / "my-dir/my-subdir/my-file.png", + ) + self.assertEqual( + self.storage.path("my-dir/../my-dir/my-subdir/my-file.png"), + self.data_directory / "my-dir/my-subdir/my-file.png", + ) + + with self.assertRaises(self.storage.exceptions.SuspiciousFileOperation): + self.storage.path("../my-dir/my-subdir/my-file.png") + + def test_create_bucket_folder(self): + self.storage.create_bucket("my-bucket") + self.storage.create_bucket("my-second-bucket") + self.storage.create_bucket_folder("my-bucket", "my-dir") + + self.assertTrue(self.storage.exists("my-bucket/my-dir")) + self.assertTrue(os.path.lexists(self.storage.path("my-bucket/my-dir"))) + + self.assertTrue(self.storage.exists("my-second-bucket")) + self.assertFalse(self.storage.exists("my-second-bucket/my-dir")) + + with self.assertRaises(self.storage.exceptions.SuspiciousFileOperation): + self.storage.create_bucket_folder( + "my-bucket", "../my-second-bucket/my-second-dir" + ) + self.assertFalse( + os.path.lexists(self.storage.path("my-second-bucket/my-second-dir")) + ) + self.assertFalse( + os.path.lexists(self.data_directory / "my-bucket/my-second-dir") + ) + + def test_valid_filenames(self): + self.storage.create_bucket("default-bucket") + dir_obj = self.storage.create_bucket_folder("default-bucket", "éà_?_d 1") + self.assertEqual(dir_obj.name, "éà__d_1") + + def test_save_object(self): + self.storage.create_bucket("default-bucket") + self.storage.save_object("default-bucket", "file.txt", b"Hello, world!") + + self.assertTrue((self.data_directory / "default-bucket/file.txt").exists()) + self.assertEqual( + open(self.data_directory / "default-bucket/file.txt").read(), + "Hello, world!", + ) + + def test_deep_save_object(self): + self.storage.create_bucket("default-bucket") + self.storage.save_object( + "default-bucket", "dir1/dir2/file.txt", b"Hello, world!" 
+ ) + self.assertTrue( + (self.data_directory / "default-bucket/dir1/dir2/file.txt").exists() + ) + + self.storage.save_object("default-bucket", "dir1/file2.txt", b"Hello, world!") + self.assertTrue( + (self.data_directory / "default-bucket/dir1/file2.txt").exists() + ) + + def test_overwrite_object(self): + self.storage.create_bucket("default-bucket") + self.storage.save_object("default-bucket", "file.txt", b"Hello, world!") + + self.assertTrue((self.data_directory / "default-bucket/file.txt").exists()) + self.assertEqual( + open(self.data_directory / "default-bucket/file.txt").read(), + "Hello, world!", + ) + + # Overwrite the file + self.storage.save_object("default-bucket", "file.txt", b"OVERWRITTEN") + self.assertEqual( + open(self.data_directory / "default-bucket/file.txt").read(), + "OVERWRITTEN", + ) + + def test_list_bucket_objects(self): + self.storage.create_bucket("default-bucket") + for i in range(100): + self.storage.save_object( + "default-bucket", f"file-{i}.txt", b"Hello, world!" + ) + res = self.storage.list_bucket_objects("default-bucket", page=1, per_page=5) + self.assertEqual(len(res.items), 5) + self.assertEqual(res.has_next_page, True) + self.assertEqual(res.has_previous_page, False) + self.assertEqual(res.page_number, 1) + + first_item = res.items[0] + self.assertEqual(first_item.type, "file") + self.assertEqual(first_item.size, 13) + self.assertEqual(first_item.content_type, "text/plain") + res = self.storage.list_bucket_objects("default-bucket", page=2, per_page=100) + self.assertEqual(len(res.items), 0) + self.assertEqual(res.has_next_page, False) + self.assertEqual(res.has_previous_page, True) + + def test_list_bucket_objects_with_query(self): + self.storage.create_bucket("default-bucket") + for i in range(100): + self.storage.save_object( + "default-bucket", f"file-{i}.txt", b"Hello, world!" 
+ ) + self.storage.save_object("default-bucket", "found.txt", b"Hello, world!") + + res_found = self.storage.list_bucket_objects( + "default-bucket", query="found", per_page=100 + ) + + self.assertEqual(len(res_found.items), 1) + self.assertEqual(res_found.items[0].name, "found.txt") + + def test_list_bucket_objects_with_prefix_and_query(self): + self.storage.create_bucket("default-bucket") + self.storage.save_object("default-bucket", "prefix/found.txt", b"Hello, world!") + self.storage.save_object("default-bucket", "tada/found-2.txt", b"Hello, world!") + + self.assertEqual( + 0, + len( + self.storage.list_bucket_objects("default-bucket", query="found").items + ), + ) + + res = self.storage.list_bucket_objects( + "default-bucket", prefix="prefix", query="found" + ) + self.assertEqual(len(res.items), 1) + self.assertEqual(res.items[0].name, "found.txt") + + res = self.storage.list_bucket_objects( + "default-bucket", prefix="tada", query="found" + ) + self.assertEqual(len(res.items), 1) + self.assertEqual(res.items[0].name, "found-2.txt") + + @override_settings(NEW_FRONTEND_DOMAIN="http://localhost") + def test_generate_upload_url(self): + self.storage.create_bucket("default-bucket") + url = self.storage.generate_upload_url( + "default-bucket", "file.txt", content_type="text/plain" + ) + self.assertTrue(url.startswith("http://localhost/files/up/")) + token = url.split("/")[-2] + + self.assertEqual( + self.storage._get_payload_from_token(token), + {"bucket_name": "default-bucket", "file_path": "file.txt"}, + ) + + def test_upload_file(self): + self.storage.create_bucket("default-bucket") + + file_data = BytesIO(b"This is a test file2.") + uploaded_file = SimpleUploadedFile( + "test_file.txt", file_data.getvalue(), content_type="text/plain" + ) + + url = self.storage.generate_upload_url( + "default-bucket", "test_file.txt", "text/plain" + ) + with patch("hexa.files.views.storage", self.storage): + resp = self.client.post( + url, + data={"file": uploaded_file}, + format="multipart", + HTTP_X_METHOD_OVERRIDE="PUT", + ) + self.assertEqual(resp.status_code, 201) + + def test_upload_file_expired_token(self): + self.storage.create_bucket("default-bucket") + + file_data = BytesIO(b"This is a test file.") + uploaded_file = SimpleUploadedFile( + "test_file.txt", file_data.getvalue(), content_type="text/plain" + ) + + # Expire the token + self.storage._token_max_age = 0 + + url = self.storage.generate_upload_url( + "default-bucket", "test_file.txt", "text/plain" + ) + with patch("hexa.files.views.storage", self.storage): + resp = self.client.post( + url, + data={"file": uploaded_file}, + format="multipart", + HTTP_X_METHOD_OVERRIDE="PUT", + ) + self.assertEqual(resp.status_code, 400) + + # Put back a greater expiration time + self.storage._token_max_age = 60 * 60 + with patch("hexa.files.views.storage", self.storage): + resp = self.client.post( + url, + data={"file": uploaded_file}, + format="multipart", + HTTP_X_METHOD_OVERRIDE="PUT", + ) + self.assertEqual(resp.status_code, 201) + + def test_get_mount_config(self): + self.assertEqual( + self.storage.get_bucket_mount_config("my_bucket"), + {"WORKSPACE_STORAGE_MOUNT_PATH": str(self.data_directory / "my_bucket")}, + ) + + def test_delete_bucket(self): + self.storage.create_bucket("default-bucket") + self.assertTrue((self.data_directory / "default-bucket").exists()) + self.storage.delete_bucket( + "default-bucket", + ) + self.assertFalse(self.storage.exists("default-bucket")) + + def test_delete_bucket_not_empty(self): + 
self.storage.create_bucket("default-bucket") + self.storage.save_object("default-bucket", "file.txt", b"Hello, world!") + with self.assertRaises(self.storage.exceptions.BadRequest): + self.storage.delete_bucket("default-bucket") + self.assertTrue(self.storage.exists("default-bucket")) + + self.storage.delete_bucket("default-bucket", force=True) + self.assertFalse(self.storage.exists("default-bucket")) diff --git a/hexa/files/tests/backends/test_gcp.py b/hexa/files/tests/backends/test_gcp.py new file mode 100644 index 000000000..fb6552884 --- /dev/null +++ b/hexa/files/tests/backends/test_gcp.py @@ -0,0 +1,59 @@ +from unittest.mock import patch + +from google.cloud.exceptions import NotFound + +from hexa.core.test import TestCase +from hexa.files.backends.gcp import GoogleCloudStorage +from hexa.files.tests.mocks.client import MockClient + + +class GoogleCloudStorageTest(TestCase): + storage = None + + def setUp(self): + super().setUp() + self.get_storage_client_patch = patch( + "hexa.files.backends.gcp.get_storage_client" + ) + self.mock_get_client = self.get_storage_client_patch.start() + self.storage_client = MockClient() + self.mock_get_client.return_value = self.storage_client + self.storage = GoogleCloudStorage( + service_account_key="service_account_key", region="europe-west1" + ) + + self.addCleanup(self.get_storage_client_patch.stop) + + def test_mock_client(self): + self.assertIsInstance(self.storage.client, MockClient) + with self.assertRaises(NotFound): + self.storage.client.get_bucket("test-bucket") + + def test_create_bucket(self): + self.storage.create_bucket("my-bucket") + self.assertTrue(self.storage.bucket_exists("my-bucket")) + + def test_create_bucket_already_exists(self): + self.storage.create_bucket("my-bucket") + with self.assertRaises(self.storage.exceptions.AlreadyExists): + self.storage.create_bucket("my-bucket") + + def test_list_bucket_objects(self): + bucket_name = self.storage.create_bucket("my-bucket") + bucket = self.storage_client.get_bucket(bucket_name) + bucket.blob("test.txt").upload_from_string(b"test") + items = self.storage.list_bucket_objects("my-bucket").items + self.assertEqual(len(items), 1) + first = items[0] + self.assertEqual(first.name, "test.txt") + + def test_delete_object(self): + bucket_name = self.storage.create_bucket("my-bucket") + bucket = self.storage.client.get_bucket(bucket_name) + bucket.blob("my_blob.txt") + + self.assertTrue(self.storage.get_bucket_object("my-bucket", "my_blob.txt")) + self.storage.delete_object("my-bucket", "my_blob.txt") + + with self.assertRaises(self.storage.exceptions.NotFound): + self.storage.get_bucket_object("my-bucket", "my_blob.txt") diff --git a/hexa/files/tests/mocks/backend.py b/hexa/files/tests/mocks/backend.py deleted file mode 100644 index 9d029e3a8..000000000 --- a/hexa/files/tests/mocks/backend.py +++ /dev/null @@ -1,55 +0,0 @@ -import functools -import uuid -from unittest.mock import MagicMock, patch - - -class BucketAlreadyOwnedByYou(Exception): - pass - - -class StorageBackend: - def __init__(self, project=None): - if project is None: - project = "test-project-" + str(uuid.uuid1()) - self.project = project - self.buckets = {} - - def reset(self): - self.buckets = {} - - def delete_bucket(self, bucket_name): - if bucket_name in self.buckets: - del self.buckets[bucket_name] - - def mock_storage(self, func): - from .client import MockClient - - def create_mock_client(*args, **kwargs): - client = MockClient(backend=self, *args, **kwargs) - return client - - def wrapper(*args, **kwargs): - with 
patch("hexa.files.gcp.get_storage_client", create_mock_client): - return func(*args, **kwargs) - - functools.update_wrapper(wrapper, func) - wrapper.__wrapped__ = func - return wrapper - - def mock_s3_storage(self, func): - from .client import MockClient - - def create_mock_client(*args, **kwargs): - client = MockClient(backend=self, *args, **kwargs) - client.exceptions = MagicMock() - client.exceptions.BucketAlreadyOwnedByYou = BucketAlreadyOwnedByYou - client.exceptions.NoSuchBucket = BucketAlreadyOwnedByYou - return client - - def wrapper(*args, **kwargs): - with patch("hexa.files.s3.get_storage_client", create_mock_client): - return func(*args, **kwargs) - - functools.update_wrapper(wrapper, func) - wrapper.__wrapped__ = func - return wrapper diff --git a/hexa/files/tests/mocks/bucket.py b/hexa/files/tests/mocks/bucket.py index e42358874..380c87235 100644 --- a/hexa/files/tests/mocks/bucket.py +++ b/hexa/files/tests/mocks/bucket.py @@ -1,6 +1,6 @@ from google.cloud.storage._helpers import _validate_name -from hexa.files.api import NotFound +from hexa.files.backends.exceptions import NotFound from .blob import MockBlob @@ -33,14 +33,12 @@ def exists(self, client=None): def list_blobs(self, *args, **kwargs): return self.client.list_blobs(self, *args, **kwargs) - def get_blob(self, blob_name, *args, **kwargs): - if any( - filename in blob_name - for filename in ["test", "demo", "mock", "data", "some-uri"] - ): - return MockBlob(blob_name, self) - else: - return None + def get_blob(self, name, *args, **kwargs): + existing = [b for b in self._blobs if b.name == name] + if len(existing) == 0: + raise NotFound("key not found") + + return existing[0] def blob(self, *args, **kwargs): b = MockBlob(*args, bucket=self, **kwargs) @@ -50,9 +48,9 @@ def blob(self, *args, **kwargs): def patch(self): pass - def delete_blob(self, key): - existing = [b for b in self._blobs if b.name == key] + def delete_blob(self, name): + existing = [b for b in self._blobs if b.name == name] if len(existing) == 0: raise NotFound("key not found") - self._blobs = [b for b in self._blobs if b.name != key] + self._blobs = [b for b in self._blobs if b.name != name] diff --git a/hexa/files/tests/mocks/client.py b/hexa/files/tests/mocks/client.py index 10df5d7d9..998485fda 100644 --- a/hexa/files/tests/mocks/client.py +++ b/hexa/files/tests/mocks/client.py @@ -112,7 +112,6 @@ def _has_next_page(self): class MockClient: def __init__( self, - backend, credentials=None, _http=None, client_info=None, @@ -120,25 +119,12 @@ def __init__( *args, **kwargs, ): - self.backend = backend - self.project = backend.project + self.buckets = {} self.credentials = credentials self._http = _http self.client_info = client_info self.client_options = client_options - @classmethod - def create_anonymous_client(cls): - raise NotImplementedError - - @property - def _connection(self): - raise NotImplementedError - - @_connection.setter - def _connection(self, value): - raise NotImplementedError - def _push_batch(self, batch): raise NotImplementedError @@ -149,7 +135,7 @@ def _bucket_arg_to_bucket(self, bucket_or_name): if isinstance(bucket_or_name, MockBucket): bucket = bucket_or_name else: - bucket = self.backend.buckets.get(bucket_or_name) + bucket = self.buckets.get(bucket_or_name) return bucket @property @@ -163,7 +149,12 @@ def bucket(self, bucket_name, user_project=None): return MockBucket(client=self, name=bucket_name, user_project=user_project) def delete_bucket(self, bucket_name): - self.backend.delete_bucket(bucket_name) + if bucket_name in 
self.buckets: + del self.buckets[bucket_name] + else: + raise NotFound( + f"404 DELETE https://storage.googleapis.com/storage/v1/b/{bucket_name}" + ) def batch(self): raise NotImplementedError @@ -171,13 +162,12 @@ def batch(self): def get_bucket(self, bucket_or_name): bucket = self._bucket_arg_to_bucket(bucket_or_name) - # TODO: Use bucket.reload(client=self) when MockBucket class is implemented - if bucket.name in self.backend.buckets.keys(): - return self.backend.buckets[bucket.name] - else: + if bucket is None or bucket.name not in self.buckets.keys(): raise NotFound( - f"404 GET https://storage.googleapis.com/storage/v1/b/{bucket.name}?projection=noAcl" + f"404 GET https://storage.googleapis.com/storage/v1/b/{bucket_or_name}?projection=noAcl" ) + else: + return self.buckets[bucket.name] def lookup_bucket(self, bucket_name): try: @@ -197,14 +187,14 @@ def create_bucket( labels=labels, ) - if bucket.name in self.backend.buckets.keys(): + if bucket.name in self.buckets.keys(): raise Conflict( "409 POST https://storage.googleapis.com/storage/v1/b?project={}: You already own this bucket. Please select another name.".format( self.project ) ) else: - self.backend.buckets[bucket.name] = bucket + self.buckets[bucket.name] = bucket return bucket def download_blob_to_file(self, blob_or_uri, file_obj, start=None, end=None): @@ -256,9 +246,9 @@ def list_buckets( **kwargs, ): if isinstance(max_results, int): - buckets = list(self.backend.buckets.values())[:max_results] + buckets = list(self.buckets.values())[:max_results] else: - buckets = list(self.backend.buckets.values()) + buckets = list(self.buckets.values()) if isinstance(prefix, str): buckets = [bucket for bucket in buckets if bucket.name.startswith(prefix)] diff --git a/hexa/files/tests/mocks/mockgcp.py b/hexa/files/tests/mocks/mockgcp.py deleted file mode 100644 index 720bb6a27..000000000 --- a/hexa/files/tests/mocks/mockgcp.py +++ /dev/null @@ -1,9 +0,0 @@ -from .backend import StorageBackend - - -def create_storage_mock(project=None): - return StorageBackend(project=project) - - -backend = StorageBackend() -mock_gcp_storage = backend.mock_storage diff --git a/hexa/files/tests/test_api.py b/hexa/files/tests/test_api.py deleted file mode 100644 index 28123be0d..000000000 --- a/hexa/files/tests/test_api.py +++ /dev/null @@ -1,534 +0,0 @@ -from unittest.mock import patch - -import boto3 -import botocore -from django.core.exceptions import ValidationError -from django.test import override_settings -from moto import mock_aws - -from hexa.core.test import TestCase - -from ..api import NotFound, get_storage -from ..basefs import BucketObjectAlreadyExists -from .mocks.mockgcp import backend - - -class APITestCase: - def get_client(self): - return get_storage(self.get_type()) - - def to_keys(self, page): - return [x["key"] for x in page.items] - - def setUp(self): - if self.get_type() == "gcp": - from .mocks.client import MockClient - - def create_mock_client(*args, **kwargs): - return MockClient(backend=backend, *args, **kwargs) - - patcher = patch("hexa.files.gcp.get_storage_client", create_mock_client) - self.mock_backend = patcher.start() - self.addCleanup(patcher.stop) - - if self.get_type() == "s3": - mock = mock_aws() - - self.mock_backend = mock.start() - self.addCleanup(mock.stop) - - backend.reset() - # since I call a real minio, I delete the content and bucket upfront - buckets = [ - "my_bucket", - "my-bucket", - "test-bucket", - "empty-bucket", - "not-empty-bucket", - "bucket", - ] - for bucket_name in buckets: - 
self.get_client().delete_bucket(bucket_name=bucket_name, fully=True) - - def test_create_bucket(self): - self.assertEqual(backend.buckets, {}) - self.get_client().create_bucket("test-bucket") - self.assertEqual(self.get_client().list_bucket_objects("test-bucket").items, []) - - def test_create_same_bucket(self): - self.assertEqual(backend.buckets, {}) - self.get_client().create_bucket("test-bucket") - with self.assertRaises(ValidationError): - self.get_client().create_bucket("test-bucket") - - def test_list_blobs_empty(self): - bucket = self.get_client().create_bucket("empty-bucket") - self.assertEqual(self.get_client().list_bucket_objects(bucket.name).items, []) - - def test_list_blobs(self): - bucket = self.get_client().create_bucket("not-empty-bucket") - bucket.blob( - "test.txt", - size=123, - content_type="text/plain", - ) - bucket.blob( - "readme.md", - size=2103, - content_type="text/plain", - ) - bucket.blob( - "other_file.md", - size=2102, - content_type="text/plain", - ) - bucket.blob("folder/", size=0) - bucket.blob( - "folder/readme.md", - size=1, - content_type="text/plain", - ) - self.assertEqual( - self.to_keys( - self.get_client().list_bucket_objects(bucket.name, page=1, per_page=2) - ), - [ - "folder/", - "other_file.md", - ], - ) - - def test_list_blobs_with_query(self): - bucket = self.get_client().create_bucket("not-empty-bucket") - bucket.blob( - "test.txt", - size=123, - content_type="text/plain", - ) - bucket.blob( - "readme.md", - size=2103, - content_type="text/plain", - ) - bucket.blob( - "file.md", - size=1, - content_type="text/plain", - ) - bucket.blob( - "other_file.md", - size=2102, - content_type="text/plain", - ) - bucket.blob("folder/", size=0) - bucket.blob( - "folder/readme.md", - size=1, - content_type="text/plain", - ) - self.assertEqual( - [ - x["key"] - for x in self.get_client() - .list_bucket_objects(bucket.name, page=1, per_page=10, query="readme") - .items - ], - [ - "readme.md", - ], - ) - - self.assertEqual( - [ - x["key"] - for x in self.get_client() - .list_bucket_objects(bucket.name, page=1, per_page=10, query="file") - .items - ], - ["file.md", "other_file.md"], - ) - self.assertEqual( - [ - x["key"] - for x in self.get_client() - .list_bucket_objects(bucket.name, page=2, per_page=10, query="file") - .items - ], - [], - ) - - def test_list_hide_hidden_files(self): - bucket = self.get_client().create_bucket("bucket") - bucket.blob( - "test.txt", - size=123, - content_type="text/plain", - ) - bucket.blob( - ".gitconfig", - size=2103, - content_type="text/plain", - ) - bucket.blob( - ".gitignore", - size=2102, - content_type="text/plain", - ) - bucket.blob(".git/", size=0) - bucket.blob(".git/config", size=1, content_type="text/plain") - - self.assertEqual( - self.to_keys( - self.get_client().list_bucket_objects(bucket.name, page=1, per_page=10) - ), - [ - "test.txt", - ], - ) - - self.assertEqual( - self.to_keys( - self.get_client().list_bucket_objects( - bucket.name, page=1, per_page=10, ignore_hidden_files=False - ) - ), - [".git/", ".gitconfig", ".gitignore", "test.txt"], - ) - - def test_list_blobs_with_prefix(self): - bucket = self.get_client().create_bucket("bucket") - bucket.blob( - "test.txt", - size=123, - content_type="text/plain", - ) - bucket.blob( - "dir/", - size=0, - ) - bucket.blob( - "dir/readme.md", - size=2102, - content_type="text/plain", - ) - bucket.blob("dir/b/", size=0) - bucket.blob("dir/b/image.jpg", size=1, content_type="image/jpeg") - bucket.blob("other_dir/", size=0) - - self.assertEqual( - self.to_keys( - 
self.get_client().list_bucket_objects( - bucket.name, page=1, per_page=10, prefix="dir/" - ) - ), - [ - "dir/b/", - "dir/readme.md", - ], - ) - - def test_list_blobs_hidden_files(self): - bucket = self.get_client().create_bucket("bucket") - bucket.blob( - "dir/", - size=0, - ) - bucket.blob( - "dir/readme.md", - size=2102, - content_type="text/plain", - ) - bucket.blob( - "dir/.checkpoint.ipynb", - size=2102, - content_type="text/plain", - ) - bucket.blob("dir/.b/", size=0) - bucket.blob("dir/.b/image.jpg", size=1, content_type="image/jpeg") - - self.assertEqual( - self.to_keys( - self.get_client().list_bucket_objects( - bucket.name, - page=1, - per_page=10, - prefix="dir/", - ignore_hidden_files=True, - ) - ), - [ - "dir/readme.md", - ], - ) - - def test_list_blobs_pagination(self): - bucket = self.get_client().create_bucket("my-bucket") - for i in range(0, 12): - bucket.blob(f"test_{i}.txt", size=(123 * i), content_type="text/plain") - - res = self.get_client().list_bucket_objects(bucket.name, page=1, per_page=10) - - self.assertEqual( - self.to_keys(res), - [ - "test_0.txt", - "test_1.txt", - "test_10.txt", - "test_11.txt", - "test_2.txt", - "test_3.txt", - "test_4.txt", - "test_5.txt", - "test_6.txt", - "test_7.txt", - ], - ) - - self.assertTrue(res.has_next_page) - self.assertFalse(res.has_previous_page) - self.assertEqual(res.page_number, 1) - - res = self.get_client().list_bucket_objects(bucket.name, page=1, per_page=20) - self.assertEqual( - self.to_keys(res), - [ - "test_0.txt", - "test_1.txt", - "test_10.txt", - "test_11.txt", - "test_2.txt", - "test_3.txt", - "test_4.txt", - "test_5.txt", - "test_6.txt", - "test_7.txt", - "test_8.txt", - "test_9.txt", - ], - ) - self.assertFalse(res.has_next_page) - - res = self.get_client().list_bucket_objects(bucket.name, page=2, per_page=10) - self.assertEqual(self.to_keys(res), ["test_8.txt", "test_9.txt"]) - self.assertFalse(res.has_next_page) - self.assertTrue(res.has_previous_page) - self.assertEqual(res.page_number, 2) - - res = self.get_client().list_bucket_objects(bucket.name, page=2, per_page=5) - self.assertEqual( - self.to_keys(res), - ["test_3.txt", "test_4.txt", "test_5.txt", "test_6.txt", "test_7.txt"], - ) - - self.assertTrue(res.has_next_page) - self.assertTrue(res.has_previous_page) - self.assertEqual(res.page_number, 2) - - def test_delete_object_working(self): - bucket = self.get_client().create_bucket("bucket") - bucket.blob( - "test.txt", - size=123, - content_type="text/plain", - ) - res = self.get_client().list_bucket_objects("bucket") - self.assertEqual(self.to_keys(res), ["test.txt"]) - - self.get_client().delete_object(bucket_name=bucket.name, file_name="test.txt") - res = self.get_client().list_bucket_objects("bucket") - - self.assertEqual(self.to_keys(res), []) - - def test_delete_object_non_existing(self): - bucket = self.get_client().create_bucket("bucket") - with self.assertRaises(NotFound): - self.get_client().delete_object( - bucket_name=bucket.name, file_name="test.txt" - ) - - def test_generate_download_url(self): - self.get_client().create_bucket("bucket") - url = self.get_client().generate_download_url("bucket", "demo.txt") - assert "demo.txt" in url, f"Expected to be in '{url}'" - - def test_generate_upload_url(self): - self.get_client().create_bucket("bucket") - url = self.get_client().generate_upload_url("bucket", "demo.txt") - assert "demo.txt" in url, f"Expected to be in '{url}'" - - def test_generate_upload_url_raise_existing(self): - bucket = self.get_client().create_bucket("bucket") - bucket.blob( - 
"demo.txt", - size=123, - content_type="text/plain", - ) - with self.assertRaises(BucketObjectAlreadyExists): - self.get_client().generate_upload_url( - bucket_name="bucket", target_key="demo.txt", raise_if_exists=True - ) - - -class OnlyGCP: - @override_settings(WORKSPACE_BUCKET_VERSIONING_ENABLED="true") - def test_create_bucket_configuration(self): - bucket = self.get_client().create_bucket("bucket") - - self.assertEqual(bucket.versioning_enabled, True) - self.assertEqual( - bucket.lifecycle_rules, - [ - { - "action": {"type": "SetStorageClass", "storageClass": "NEARLINE"}, - "condition": {"age": 30}, - }, - { - "action": {"type": "SetStorageClass", "storageClass": "COLDLINE"}, - "condition": {"age": 90}, - }, - { - "action": {"type": "SetStorageClass", "storageClass": "ARCHIVE"}, - "condition": {"age": 365}, - }, - { - "action": {"type": "Delete"}, - "condition": {"isLive": False, "numNewerVersions": 3}, - }, - ], - ) - self.assertEqual(bucket.storage_class, "STANDARD") - - def test_create_bucket_labels(self): - bucket = self.get_client().create_bucket("bucket", labels={"key": "value"}) - self.assertEqual(bucket.labels, {"key": "value"}) - - -class OnlyS3: - def test_generate_upload_url_raise_existing_dont_raise(self): - self.get_client().delete_bucket("bucket") - self.get_client().create_bucket("bucket") - url = self.get_client().generate_upload_url( - bucket_name="bucket", target_key="demo.txt", raise_if_exists=True - ) - - assert "demo.txt" in url, f"Expected to be in '{url}'" - - def test_load_bucket_sample_data(self): - self.get_client().create_bucket("bucket") - self.get_client().load_bucket_sample_data(bucket_name="bucket") - res = self.get_client().list_bucket_objects("bucket") - - self.assertEqual( - self.to_keys(res), ["README.MD", "covid_data.csv", "demo.ipynb"] - ) - - def test_create_bucket_folder(self): - self.get_client().create_bucket("bucket") - self.assertEqual(self.get_client().list_bucket_objects("bucket").items, []) - self.get_client().create_bucket_folder(bucket_name="bucket", folder_key="demo") - self.assertEqual( - self.to_keys(self.get_client().list_bucket_objects("bucket")), - ["demo/"], - ) - - def test_generate_client_upload_url(self): - self.get_client().create_bucket("bucket") - - url = self.get_client().generate_upload_url( - bucket_name="bucket", target_key="demo.txt" - ) - self.assertFalse(url.startswith("https://custom-s3.local")) - - with override_settings( - WORKSPACE_STORAGE_ENGINE_AWS_PUBLIC_ENDPOINT_URL="https://custom-s3.local" - ): - url = self.get_client().generate_upload_url( - bucket_name="bucket", target_key="demo.txt" - ) - self.assertTrue(url.startswith("https://custom-s3.local")) - - def test_generate_client_download_url(self): - bucket = self.get_client().create_bucket("bucket") - bucket.blob("demo.txt", size=123, content_type="text/plain") - - url = self.get_client().generate_download_url("bucket", "demo.txt") - self.assertFalse(url.startswith("https://custom-s3.local")) - - with override_settings( - WORKSPACE_STORAGE_ENGINE_AWS_PUBLIC_ENDPOINT_URL="https://custom-s3.local" - ): - url = self.get_client().generate_download_url("bucket", "demo.txt") - self.assertTrue(url.startswith("https://custom-s3.local")) - - -class OnlyOnline: - def test_short_lived_downscoped_access_token(self): - # TODO make that test work for gcp and s3 - bucket = self.get_client().create_bucket("bucket") - for i in range(0, 2): - bucket.blob( - f"test_{i}.txt", - size=123 * i, - content_type="text/plain", - ) - - bucket = 
self.get_client().create_bucket("test-bucket") - - for i in range(0, 2): - bucket.blob( - f"test_{i}.txt", - size=123 * i, - content_type="text/plain", - ) - - token = self.get_client().get_short_lived_downscoped_access_token("bucket") - - env_vars = self.get_client().get_token_as_env_variables(token[0]) - - if self.get_type() == "s3": - self.assertEqual( - list(env_vars.keys()), - [ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_ENDPOINT_URL", - "AWS_SESSION_TOKEN", - "AWS_S3_FUSE_CONFIG", - ], - ) - - # create a s3 client with the downscoped token - s3 = boto3.client("s3", **token[0]) - - objects = s3.list_objects(Bucket="bucket") - print(objects) - self.assertEqual( - [x["Key"] for x in objects["Contents"]], - ["test_0.txt", "test_1.txt"], - ) - # TODO unified exception ? - with self.assertRaisesMessage( - botocore.exceptions.ClientError, - "An error occurred (AccessDenied) when calling the ListObjects operation: Access Denied.", - ): - # should blow up not allowed on that bucket - objects = s3.list_objects(Bucket="test-bucket") - - with self.assertRaisesMessage( - botocore.exceptions.ClientError, - "An error occurred (AccessDenied) when calling the CreateBucket operation: Access Denied.", - ): - # should blow up not allowed to create new bucket - s3.create_bucket(Bucket="not-empty-bucket") - - -# MOTO the lib to mock s3 doesn't work when you set and endpoint url -@override_settings(WORKSPACE_STORAGE_ENGINE_AWS_ENDPOINT_URL=None) -class APIS3TestCase(APITestCase, OnlyS3, TestCase): - def get_type(self): - return "s3" - - -class APIGcpTestCase(APITestCase, OnlyGCP, TestCase): - def get_type(self): - return "gcp" diff --git a/hexa/files/tests/test_schema.py b/hexa/files/tests/test_schema.py index a671c5ee0..4622c1553 100644 --- a/hexa/files/tests/test_schema.py +++ b/hexa/files/tests/test_schema.py @@ -2,17 +2,13 @@ from hexa.user_management.models import Feature, User from hexa.workspaces.models import Workspace -from .mocks.mockgcp import backend, mock_gcp_storage - class FilesTest(GraphQLTestCase): USER_WORKSPACE_ADMIN = None WORKSPACE = None @classmethod - @mock_gcp_storage def setUpTestData(cls): - backend.reset() Feature.objects.create(code="workspaces", force_activate=True) cls.USER_WORKSPACE_ADMIN = User.objects.create_user( @@ -33,7 +29,6 @@ def setUpTestData(cls): countries=[{"code": "AD"}], ) - @mock_gcp_storage def test_workspace_objects_authorized(self): self.client.force_login(self.USER_WORKSPACE_ADMIN) diff --git a/hexa/files/urls.py b/hexa/files/urls.py new file mode 100644 index 000000000..aebbe1e86 --- /dev/null +++ b/hexa/files/urls.py @@ -0,0 +1,10 @@ +from django.urls import path + +from . import views + +app_name = "files" + +urlpatterns = [ + path("up//", views.upload_file, name="upload_file"), + path("dl//", views.download_file, name="download_file"), +] diff --git a/hexa/files/views.py b/hexa/files/views.py index 60f00ef0e..c07b2b733 100644 --- a/hexa/files/views.py +++ b/hexa/files/views.py @@ -1 +1,28 @@ -# Create your views here. 
+from django.http import FileResponse, HttpRequest, HttpResponse, HttpResponseBadRequest +from django.views.decorators.csrf import csrf_exempt +from django.views.decorators.http import require_http_methods + +from hexa.files import storage + + +def download_file(request: HttpRequest, token: str) -> HttpResponse: + if hasattr(storage, "get_bucket_object_by_token") is False: + return HttpResponseBadRequest("Storage does not support token-based access") + object = storage.get_bucket_object_by_token(token) + full_path = storage.path(object.path) + as_attachment = request.GET.get("attachment", "false") + return FileResponse(open(full_path, "rb"), as_attachment=as_attachment == "true") + + +@require_http_methods(["POST", "PUT"]) +@csrf_exempt +def upload_file(request: HttpRequest, token: str) -> HttpResponse: + if hasattr(storage, "save_object_by_token") is False: + return HttpResponseBadRequest("Storage does not support token-based access") + try: + storage.save_object_by_token(token, request.body) + return HttpResponse(status=201) + except storage.exceptions.AlreadyExists: + return HttpResponseBadRequest("Object already exists") + except storage.exceptions.BadRequest: + return HttpResponseBadRequest("Invalid token") diff --git a/hexa/notebooks/tests/test_schema.py b/hexa/notebooks/tests/test_schema.py index 91cd24e95..382ac15e5 100644 --- a/hexa/notebooks/tests/test_schema.py +++ b/hexa/notebooks/tests/test_schema.py @@ -2,7 +2,6 @@ from django.conf import settings from hexa.core.test import GraphQLTestCase -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( Workspace, @@ -13,7 +12,6 @@ class NotebooksTest(GraphQLTestCase): @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER_SABRINA = User.objects.create_user( "sabrina@bluesquarehub.com", diff --git a/hexa/pipelines/management/commands/pipelines_runner.py b/hexa/pipelines/management/commands/pipelines_runner.py index 4b82bfd6c..01a76507f 100644 --- a/hexa/pipelines/management/commands/pipelines_runner.py +++ b/hexa/pipelines/management/commands/pipelines_runner.py @@ -14,6 +14,7 @@ from django.core.signing import Signer from django.utils import timezone +from hexa.files import storage from hexa.pipelines.models import PipelineRun, PipelineRunState, PipelineType from hexa.pipelines.utils import generate_pipeline_container_name, mail_run_recipients @@ -244,6 +245,18 @@ def run_pipeline_docker(run: PipelineRun, image: str, env_vars: dict): "HEXA_RUN_ID": str(run.id), } ) + volumes = None + if storage.storage_type == "local": + # FIXME Get this from the storage directly + workspace_folder = os.path.join( + settings.WORKSPACE_STORAGE_LOCATION, run.pipeline.workspace.bucket_name + ) + volumes = { + workspace_folder: { + "bind": "/home/hexa/workspace", + "mode": "rw", + } + } container = docker_client.containers.run( image=image, command=cmd, @@ -251,6 +264,7 @@ def run_pipeline_docker(run: PipelineRun, image: str, env_vars: dict): network="openhexa", platform="linux/amd64", environment=env_vars, + volumes=volumes, detach=True, ) logger.debug("Container %s started", container.id) @@ -295,7 +309,7 @@ def run_pipeline(run: PipelineRun): run.refresh_from_db() env_vars = { - "HEXA_SERVER_URL": f"{settings.PIPELINE_API_URL}", + "HEXA_SERVER_URL": f"{settings.INTERNAL_BASE_URL}", "HEXA_TOKEN": Signer().sign_object(run.access_token), "HEXA_WORKSPACE": run.pipeline.workspace.slug, "HEXA_RUN_ID": str(run.id), diff --git 
a/hexa/pipelines/schema/mutations.py b/hexa/pipelines/schema/mutations.py index 02e860e3d..0244244d9 100644 --- a/hexa/pipelines/schema/mutations.py +++ b/hexa/pipelines/schema/mutations.py @@ -8,7 +8,7 @@ from hexa.analytics.api import track from hexa.databases.utils import get_table_definition -from hexa.files.api import NotFound, get_storage +from hexa.files import storage from hexa.pipelines.authentication import PipelineRunUser from hexa.pipelines.models import ( InvalidTimeoutValueError, @@ -28,7 +28,7 @@ # ease mocking def get_bucket_object(bucket_name, file): - return get_storage().get_bucket_object(bucket_name, file) + return storage.get_bucket_object(bucket_name, file) @pipelines_mutations.field("createPipeline") @@ -73,7 +73,7 @@ def resolve_create_pipeline(_, info, **kwargs): event_properties, ) - except NotFound: + except storage.exceptions.NotFound: return {"success": False, "errors": ["FILE_NOT_FOUND"]} except IntegrityError: @@ -466,7 +466,7 @@ def resolve_add_pipeline_output(_, info, **kwargs): workspace.bucket_name, input["uri"][len(f"gs://{workspace.bucket_name}/") :], ) - except NotFound: + except storage.exceptions.NotFound: return {"success": False, "errors": ["FILE_NOT_FOUND"]} elif input.get("type") == "db" and not get_table_definition( workspace, input.get("name") diff --git a/hexa/pipelines/schema/types.py b/hexa/pipelines/schema/types.py index 7102ee07b..e925afd07 100644 --- a/hexa/pipelines/schema/types.py +++ b/hexa/pipelines/schema/types.py @@ -6,8 +6,7 @@ from hexa.core.graphql import result_page from hexa.databases.utils import get_table_definition -from hexa.files.api import get_storage -from hexa.files.basefs import NotFound +from hexa.files import storage from hexa.pipelines.models import Pipeline, PipelineRun, PipelineVersion from hexa.workspaces.models import Workspace from hexa.workspaces.schema.types import workspace_permissions @@ -30,7 +29,7 @@ def get_bucket_object(bucket_name, file): - return get_storage().get_bucket_object(bucket_name, file) + return storage.get_bucket_object(bucket_name, file) @workspace_permissions.field("createPipeline") @@ -255,7 +254,7 @@ def resolve_pipeline_run_outputs(run: PipelineRun, info, **kwargs): ) else: result.append(output) - except NotFound: + except storage.exceptions.NotFound: # File object might be deleted continue except Exception as e: diff --git a/hexa/pipelines/tests/test_schema/test_pipeline_versions.py b/hexa/pipelines/tests/test_schema/test_pipeline_versions.py index 1b77f164c..73fddc434 100644 --- a/hexa/pipelines/tests/test_schema/test_pipeline_versions.py +++ b/hexa/pipelines/tests/test_schema/test_pipeline_versions.py @@ -1,7 +1,6 @@ from unittest.mock import patch from hexa.core.test import GraphQLTestCase -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.pipelines.models import Pipeline from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( @@ -18,7 +17,6 @@ class PipelineVersionsTest(GraphQLTestCase): PIPELINE = None @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER_ROOT = User.objects.create_user( "root@bluesquarehub.com", diff --git a/hexa/pipelines/tests/test_schema/test_pipelines.py b/hexa/pipelines/tests/test_schema/test_pipelines.py index fef3c0b99..7896305c6 100644 --- a/hexa/pipelines/tests/test_schema/test_pipelines.py +++ b/hexa/pipelines/tests/test_schema/test_pipelines.py @@ -10,8 +10,7 @@ from django.core.signing import Signer from hexa.core.test import GraphQLTestCase -from hexa.files.api import 
NotFound -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage +from hexa.files.backends.exceptions import NotFound from hexa.pipelines.models import ( Pipeline, PipelineRecipient, @@ -42,7 +41,6 @@ class PipelinesV2Test(GraphQLTestCase): WS2 = None @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER_ROOT = User.objects.create_user( "root@bluesquarehub.com", @@ -1423,7 +1421,7 @@ def test_mail_run_recipients_manual_trigger(self): ) self.assertListEqual([self.USER_ROOT.email], mail.outbox[0].recipients()) self.assertTrue( - f"https://{settings.NEW_FRONTEND_DOMAIN}/workspaces/{pipeline.workspace.slug}/pipelines/{pipeline.code}/runs/{run.id}" + f"{settings.NEW_FRONTEND_DOMAIN}/workspaces/{pipeline.workspace.slug}/pipelines/{pipeline.code}/runs/{run.id}" in mail.outbox[0].body ) @@ -1458,7 +1456,7 @@ def test_mail_run_recipients_scheduled_trigger(self): any(self.USER_ROOT.email in email.recipients() for email in mail.outbox) ) self.assertTrue( - f"https://{settings.NEW_FRONTEND_DOMAIN}/workspaces/{pipeline.workspace.slug}/pipelines/{pipeline.code}/runs/{run.id}" + f"{settings.NEW_FRONTEND_DOMAIN}/workspaces/{pipeline.workspace.slug}/pipelines/{pipeline.code}/runs/{run.id}" in mail.outbox[0].body ) diff --git a/hexa/pipelines/tests/test_views.py b/hexa/pipelines/tests/test_views.py index 1be644f06..4f89705c7 100644 --- a/hexa/pipelines/tests/test_views.py +++ b/hexa/pipelines/tests/test_views.py @@ -8,7 +8,6 @@ from django.urls import reverse from hexa.core.test import TestCase -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.pipelines.models import Pipeline, PipelineRunTrigger, PipelineType from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( @@ -20,7 +19,6 @@ class ViewsTest(TestCase): @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.WORKSPACE_FEATURE = Feature.objects.create(code="workspaces") cls.USER_JANE = User.objects.create_user( diff --git a/hexa/pipelines/utils.py b/hexa/pipelines/utils.py index f24ca6a21..e56fdb8af 100644 --- a/hexa/pipelines/utils.py +++ b/hexa/pipelines/utils.py @@ -35,7 +35,7 @@ def mail_run_recipients(run: PipelineRun): if run.duration is not None else datetime.timedelta(seconds=0) ), - "run_url": f"https://{settings.NEW_FRONTEND_DOMAIN}/workspaces/{workspace_slug}/pipelines/{run.pipeline.code}/runs/{run.id}", + "run_url": f"{settings.NEW_FRONTEND_DOMAIN}/workspaces/{workspace_slug}/pipelines/{run.pipeline.code}/runs/{run.id}", }, recipient_list=[recipient.email], ) diff --git a/hexa/user_management/tests/test_schema.py b/hexa/user_management/tests/test_schema.py index 097753590..9d8d0447e 100644 --- a/hexa/user_management/tests/test_schema.py +++ b/hexa/user_management/tests/test_schema.py @@ -542,7 +542,9 @@ def test_reset_password(self): ) self.assertEqual(len(mail.outbox), 1) - self.assertEqual(mail.outbox[0].subject, "Password reset on localhost:3000") + self.assertEqual( + mail.outbox[0].subject, "Password reset on http://localhost:3000" + ) def test_reset_password_wrong_email(self): r = self.run_query( diff --git a/hexa/workspaces/models.py b/hexa/workspaces/models.py index 0208d0fd7..03a0fd754 100644 --- a/hexa/workspaces/models.py +++ b/hexa/workspaces/models.py @@ -30,7 +30,7 @@ update_database_password, ) from hexa.datasets.models import Dataset -from hexa.files.api import get_storage +from hexa.files import storage from hexa.user_management.models import User @@ -57,7 +57,7 @@ def generate_database_name(): # ease patching def 
load_bucket_sample_data(bucket_name): - get_storage().load_bucket_sample_data(bucket_name) + storage.load_bucket_sample_data(bucket_name) validate_workspace_slug = RegexValidator( @@ -69,7 +69,6 @@ def load_bucket_sample_data(bucket_name): def create_workspace_bucket(workspace_slug: str): - storage = get_storage() while True: suffix = get_random_string( 4, allowed_chars=string.ascii_lowercase + string.digits @@ -78,7 +77,7 @@ def create_workspace_bucket(workspace_slug: str): # Bucket names must be unique across all of Google Cloud, so we add a suffix to the workspace slug # When separated by a dot, each segment can be up to 63 characters long return storage.create_bucket( - f"{(settings.WORKSPACE_BUCKET_PREFIX + workspace_slug)[:63]}-{suffix}", + f"{(settings.WORKSPACE_BUCKET_PREFIX + workspace_slug)[:63]}.{suffix}", labels={"hexa-workspace": workspace_slug}, ) except ValidationError: @@ -117,12 +116,12 @@ def create_if_has_perm( create_kwargs["db_name"] = db_name create_database(db_name, db_password) - bucket = create_workspace_bucket(slug) - create_kwargs["bucket_name"] = bucket.name + bucket_name = create_workspace_bucket(slug) + create_kwargs["bucket_name"] = bucket_name if load_sample_data: load_database_sample_data(db_name) - load_bucket_sample_data(bucket.name) + load_bucket_sample_data(bucket_name) workspace = self.create(**create_kwargs) diff --git a/hexa/workspaces/tests/test_models.py b/hexa/workspaces/tests/test_models.py index 56b9c91df..1b23345d2 100644 --- a/hexa/workspaces/tests/test_models.py +++ b/hexa/workspaces/tests/test_models.py @@ -5,7 +5,7 @@ from django.core.exceptions import ObjectDoesNotExist, PermissionDenied from hexa.core.test import TestCase -from hexa.files.tests.mocks.mockgcp import backend +from hexa.files import storage from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( Connection, @@ -23,9 +23,7 @@ class WorkspaceTest(TestCase): USER_JULIA = None @classmethod - @backend.mock_storage def setUpTestData(cls): - backend.reset() cls.USER_SERENA = User.objects.create_user( "serena@bluesquarehub.com", "serena's password", @@ -42,7 +40,10 @@ def setUpTestData(cls): user=cls.USER_JULIA, ) - @backend.mock_storage + def setUp(self) -> None: + storage.reset() + return super().setUp() + def test_create_workspace_regular_user(self): with self.assertRaises(PermissionDenied): workspace = Workspace.objects.create_if_has_perm( @@ -51,9 +52,8 @@ def test_create_workspace_regular_user(self): description="This is test for creating workspace", ) workspace.save() - self.assertTrue("hexa-test-senegal-workspace" in backend.buckets) + self.assertTrue("hexa-test-senegal-workspace" in storage.buckets) - @backend.mock_storage def test_create_workspace_no_slug(self): with patch("secrets.token_hex", lambda _: "mock"), patch( "hexa.workspaces.models.create_database" @@ -66,7 +66,6 @@ def test_create_workspace_no_slug(self): self.assertEqual(workspace.slug, "this-is-a-very-long-workspace-name") self.assertTrue(len(workspace.slug) <= 63) - @backend.mock_storage def test_create_workspace_with_underscore(self): with patch("secrets.token_hex", lambda _: "mock"), patch( "hexa.workspaces.models.create_database" @@ -76,10 +75,10 @@ def test_create_workspace_with_underscore(self): name="Worksp?ace_wi😱th_und~ersc!/ore", description="Description", ) + self.assertEqual(workspace.slug, "worksp-ace-with-und-ersc-ore") - self.assertTrue(workspace.bucket_name in backend.buckets) + self.assertTrue(storage.bucket_exists(workspace.bucket_name)) - 
@backend.mock_storage def test_create_workspace_with_random_characters(self): with patch("secrets.token_hex", lambda _: "mock"), patch( "hexa.workspaces.models.create_database" @@ -92,7 +91,6 @@ def test_create_workspace_with_random_characters(self): self.assertEqual(workspace.slug, "1workspace-with-random-char") self.assertEqual(16, len(workspace.db_name)) - @backend.mock_storage def test_create_workspace_admin_user(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -104,7 +102,6 @@ def test_create_workspace_admin_user(self): ) self.assertEqual(1, Workspace.objects.all().count()) - @backend.mock_storage def test_get_workspace_by_id(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -120,7 +117,6 @@ def test_get_workspace_by_id_failed(self): with self.assertRaises(ObjectDoesNotExist): Workspace.objects.get(pk="7bf4c750-f74b-4ed6-b7f7-b23e4cac4e2c") - @backend.mock_storage def test_create_workspaces_same_name(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -140,7 +136,6 @@ def test_create_workspaces_same_name(self): self.assertEqual(workspace_2.slug, "my-workspace-mock") - @backend.mock_storage def test_add_member(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -167,7 +162,6 @@ def test_add_member(self): ).notebooks_server_hash, ) - @backend.mock_storage def test_add_external_user(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -194,7 +188,6 @@ class ConnectionTest(TestCase): USER_ADMIN = None @classmethod - @backend.mock_storage def setUpTestData(cls): cls.USER_SERENA = User.objects.create_user( "serena@bluesquarehub.com", diff --git a/hexa/workspaces/tests/test_schema/test_workspace.py b/hexa/workspaces/tests/test_schema/test_workspace.py index 1ff42c7aa..c456b467b 100644 --- a/hexa/workspaces/tests/test_schema/test_workspace.py +++ b/hexa/workspaces/tests/test_schema/test_workspace.py @@ -9,7 +9,6 @@ from hexa.core.test import GraphQLTestCase from hexa.databases.utils import TableNotFound -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( Workspace, @@ -29,7 +28,6 @@ class WorkspaceTest(GraphQLTestCase): WORKSPACE = None @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER_SABRINA = User.objects.create_user( "sabrina@bluesquarehub.com", @@ -124,7 +122,6 @@ def setUpTestData(cls): status=WorkspaceInvitationStatus.ACCEPTED, ) - @mock_gcp_storage def test_create_workspace_denied(self): self.client.force_login(self.USER_SABRINA) r = self.run_query( @@ -152,7 +149,6 @@ def test_create_workspace_denied(self): r["data"]["createWorkspace"], ) - @mock_gcp_storage def test_create_workspace_if_feature_flag_enabled(self): self.client.force_login(self.USER_JOE) r = self.run_query( @@ -217,7 +213,6 @@ def test_create_workspace_if_feature_flag_enabled(self): r["data"]["createWorkspace"], ) - @mock_gcp_storage def test_create_workspace_with_demo_data(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -264,7 +259,6 @@ def test_create_workspace_with_demo_data(self): self.assertTrue(mocked_load_bucket_sample.called) 
self.assertTrue(mocked_load_database_sample.called) - @mock_gcp_storage def test_create_workspace_without_demo_data(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -310,7 +304,6 @@ def test_create_workspace_without_demo_data(self): self.assertFalse(mocked_load_bucket_sample.called) self.assertFalse(mocked_load_database_sample.called) - @mock_gcp_storage def test_create_workspace_with_country(self): with patch("hexa.workspaces.models.create_database"), patch( "hexa.workspaces.models.load_database_sample_data" @@ -354,7 +347,6 @@ def test_create_workspace_with_country(self): r["data"]["createWorkspace"], ) - @mock_gcp_storage def test_get_workspace_not_member(self): self.client.force_login(self.USER_SABRINA) r = self.run_query( @@ -996,7 +988,7 @@ def test_invite_workspace_member_external_user(self): ) self.assertListEqual([user_email], mail.outbox[0].recipients()) self.assertIn( - f"https://{settings.NEW_FRONTEND_DOMAIN}/register?{urlencode({'email': user_email, 'token': encoded})}", + f"{settings.NEW_FRONTEND_DOMAIN}/register?{urlencode({'email': user_email, 'token': encoded})}", mail.outbox[0].body, ) @@ -1579,7 +1571,7 @@ def test_resend_workspace_member_invitation(self): ) self.assertListEqual([user_email], mail.outbox[0].recipients()) self.assertIn( - f"https://{settings.NEW_FRONTEND_DOMAIN}/register?{urlencode({'email': user_email, 'token': encoded})}", + f"{settings.NEW_FRONTEND_DOMAIN}/register?{urlencode({'email': user_email, 'token': encoded})}", mail.outbox[0].body, ) diff --git a/hexa/workspaces/tests/test_schema/test_workspace_connection.py b/hexa/workspaces/tests/test_schema/test_workspace_connection.py index ab20c6d6d..5b6c0d4d0 100644 --- a/hexa/workspaces/tests/test_schema/test_workspace_connection.py +++ b/hexa/workspaces/tests/test_schema/test_workspace_connection.py @@ -1,5 +1,4 @@ from hexa.core.test import GraphQLTestCase -from hexa.files.tests.mocks.mockgcp import backend from hexa.user_management.models import User from hexa.workspaces.models import ( Connection, @@ -15,9 +14,7 @@ class ConnectionTest(GraphQLTestCase): USER_ADMIN = None @classmethod - @backend.mock_storage def setUpTestData(cls): - backend.reset() cls.USER_SABRINA = User.objects.create_user( "sabrina@bluesquarehub.com", "standardpassword", diff --git a/hexa/workspaces/tests/test_views.py b/hexa/workspaces/tests/test_views.py index c96547ddc..9280fe74c 100644 --- a/hexa/workspaces/tests/test_views.py +++ b/hexa/workspaces/tests/test_views.py @@ -1,11 +1,8 @@ -from unittest.mock import patch - from django.core.signing import Signer from django.urls import reverse from hexa.core.test import TestCase from hexa.databases.api import get_db_server_credentials -from hexa.files.tests.mocks.mockgcp import mock_gcp_storage from hexa.pipelines.models import Pipeline, PipelineRunTrigger from hexa.user_management.models import Feature, FeatureFlag, User from hexa.workspaces.models import ( @@ -17,7 +14,6 @@ class ViewsTest(TestCase): @classmethod - @mock_gcp_storage def setUpTestData(cls): cls.USER_JANE = User.objects.create_user( "jane@bluesquarehub.com", @@ -95,13 +91,7 @@ def test_workspace_credentials_401(self): ) self.assertEqual(response.status_code, 401) - @patch( - "hexa.workspaces.views.get_short_lived_downscoped_access_token", - return_value=("gcs-token", 3600, "gcp"), - ) - def test_workspace_credentials_200( - self, mock_get_short_lived_downscoped_access_token - ): + def test_workspace_credentials_200(self): 
self.client.force_login(self.USER_JULIA) response = self.client.post( reverse("workspaces:credentials"), @@ -114,6 +104,7 @@ def test_workspace_credentials_200( response_data = response.json() self.assertEqual(response.status_code, 200) + self.maxDiff = None self.assertEqual( response_data["env"], { @@ -124,9 +115,7 @@ def test_workspace_credentials_200( "WORKSPACE_DATABASE_USERNAME": self.WORKSPACE.db_name, "WORKSPACE_DATABASE_PASSWORD": self.WORKSPACE.db_password, "WORKSPACE_DATABASE_URL": self.WORKSPACE.db_url, - "WORKSPACE_STORAGE_ENGINE": "gcp", - "WORKSPACE_STORAGE_ENGINE_GCP_ACCESS_TOKEN": "gcs-token", - "GCS_TOKEN": "gcs-token", + "WORKSPACE_STORAGE_ENGINE": "dummy", "HEXA_TOKEN": Signer().sign_object( self.WORKSPACE_MEMBERSHIP_JULIA.access_token ), @@ -139,12 +128,8 @@ def test_workspace_credentials_200( ).notebooks_server_hash, ) - @patch( - "hexa.workspaces.views.get_short_lived_downscoped_access_token", - return_value=("gcs-token", 3600, "gcp"), - ) def test_pipeline_invalid_credentials_404( - self, mock_get_short_lived_downscoped_access_token + self, ): run = self.PIPELINE.run( self.USER_JULIA, self.PIPELINE.last_version, PipelineRunTrigger.MANUAL, {} @@ -159,17 +144,13 @@ def test_pipeline_invalid_credentials_404( self.assertEqual(response.status_code, 404) - @patch( - "hexa.workspaces.views.get_short_lived_downscoped_access_token", - return_value=("gcs-token", 3600, "gcp"), - ) - def test_pipeline_credentials_200( - self, mock_get_short_lived_downscoped_access_token - ): + def test_pipeline_credentials_200(self): run = self.PIPELINE.run( self.USER_JULIA, self.PIPELINE.last_version, PipelineRunTrigger.MANUAL, {} ) + token = Signer().sign_object(run.access_token) + response = self.client.post( reverse("workspaces:credentials"), data={"workspace": self.WORKSPACE.slug}, @@ -180,6 +161,7 @@ def test_pipeline_credentials_200( response_data = response.json() self.assertEqual(response.status_code, 200) + self.maxDiff = None self.assertEqual( response_data["env"], { @@ -190,9 +172,7 @@ def test_pipeline_credentials_200( "WORKSPACE_DATABASE_USERNAME": self.WORKSPACE.db_name, "WORKSPACE_DATABASE_PASSWORD": self.WORKSPACE.db_password, "WORKSPACE_DATABASE_URL": self.WORKSPACE.db_url, - "WORKSPACE_STORAGE_ENGINE": "gcp", - "WORKSPACE_STORAGE_ENGINE_GCP_ACCESS_TOKEN": "gcs-token", - "GCS_TOKEN": "gcs-token", + "WORKSPACE_STORAGE_ENGINE": "dummy", "HEXA_TOKEN": token, }, ) diff --git a/hexa/workspaces/utils.py b/hexa/workspaces/utils.py index e2f2b9ab8..d41539a32 100644 --- a/hexa/workspaces/utils.py +++ b/hexa/workspaces/utils.py @@ -20,12 +20,12 @@ def send_workspace_invitation_email( title = gettext_lazy( f"You've been added to the workspace {invitation.workspace.name}" ) - action_url = f"https://{settings.NEW_FRONTEND_DOMAIN}/user/account" + action_url = f"{settings.NEW_FRONTEND_DOMAIN}/user/account" else: title = gettext_lazy( f"You've been invited to join the workspace {invitation.workspace.name} on OpenHEXA" ) - action_url = f"https://{settings.NEW_FRONTEND_DOMAIN}/register?{urlencode({'email': invitation.email, 'token': token})}" + action_url = f"{settings.NEW_FRONTEND_DOMAIN}/register?{urlencode({'email': invitation.email, 'token': token})}" send_mail( title=title, diff --git a/hexa/workspaces/views.py b/hexa/workspaces/views.py index 258bce689..6c90d5e21 100644 --- a/hexa/workspaces/views.py +++ b/hexa/workspaces/views.py @@ -5,23 +5,13 @@ from django.views.decorators.http import require_POST from hexa.databases.api import get_db_server_credentials -from hexa.files.api import 
get_storage +from hexa.files import storage from hexa.pipelines.models import PipelineRun from hexa.workspaces.models import Workspace, WorkspaceMembership # ease patching -def get_short_lived_downscoped_access_token(bucket_name): - return get_storage().get_short_lived_downscoped_access_token( - bucket_name=bucket_name - ) - - -def get_token_as_env_variables(token): - return get_storage().get_token_as_env_variables(token) - - @require_POST @csrf_exempt def credentials(request: HttpRequest, workspace_slug: str = None) -> HttpResponse: @@ -90,7 +80,9 @@ def credentials(request: HttpRequest, workspace_slug: str = None) -> HttpRespons ) # Populate the environment variables with the connections of the workspace - env = {} + env = { + "WORKSPACE_BUCKET_NAME": workspace.bucket_name, + } # Database credentials db_credentials = get_db_server_credentials() @@ -106,14 +98,10 @@ def credentials(request: HttpRequest, workspace_slug: str = None) -> HttpRespons ) # Bucket credentials - token, _expires_in, engine_key = get_short_lived_downscoped_access_token( - workspace.bucket_name - ) - env.update(get_token_as_env_variables(token)) env.update( { - "WORKSPACE_STORAGE_ENGINE": engine_key, - "WORKSPACE_BUCKET_NAME": workspace.bucket_name, + "WORKSPACE_STORAGE_ENGINE": storage.storage_type, + **storage.get_bucket_mount_config(workspace.bucket_name), } )
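
For reviewers of this patch, here is a minimal usage sketch of the new filesystem backend and its token-based upload flow, modelled on hexa/files/tests/backends/test_fs.py. It is not part of the change itself: the example class name, the "demo-bucket" bucket, and the file paths are illustrative assumptions; only FileSystemStorage, save_object, list_bucket_objects, generate_upload_url and the /files/up/<token>/ view come from the patch.

# Sketch only (not included in the patch): exercises the FileSystemStorage
# backend end-to-end, mirroring the tests added in test_fs.py.
import shutil
from pathlib import Path
from tempfile import mkdtemp
from unittest.mock import patch

from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import override_settings

from hexa.core.test import TestCase
from hexa.files.backends.fs import FileSystemStorage


@override_settings(NEW_FRONTEND_DOMAIN="http://localhost")
class FileSystemStorageUsageExample(TestCase):  # hypothetical example test
    def test_round_trip(self):
        data_dir = Path(mkdtemp())
        self.addCleanup(shutil.rmtree, data_dir)
        storage = FileSystemStorage(data_dir=data_dir)
        storage.create_bucket("demo-bucket")

        # Server-side writes land directly under <data_dir>/<bucket>/...
        storage.save_object("demo-bucket", "report/data.csv", b"id,value\n1,42\n")
        page = storage.list_bucket_objects("demo-bucket", prefix="report", per_page=10)
        self.assertEqual([obj.name for obj in page.items], ["data.csv"])

        # Client uploads go through a signed, short-lived URL handled by
        # hexa.files.views.upload_file (POST with an X-Method-Override: PUT header).
        url = storage.generate_upload_url("demo-bucket", "upload.txt", "text/plain")
        uploaded = SimpleUploadedFile("upload.txt", b"hello", content_type="text/plain")
        with patch("hexa.files.views.storage", storage):
            response = self.client.post(
                url, data={"file": uploaded}, HTTP_X_METHOD_OVERRIDE="PUT"
            )
        self.assertEqual(response.status_code, 201)

As in the tests above, the upload URL embeds a signed token that the view resolves back to a bucket and object path, so workers and the frontend never need direct access to the mounted data directory.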