Skip to content

Commit

Permalink
set_handlers: models_to_fetch direct links support (#218)
Browse files Browse the repository at this point in the history
Closes: #217

---------

Signed-off-by: Alexander Piskun <[email protected]>
  • Loading branch information
bigcat88 authored Feb 5, 2024
1 parent 6d88c6e commit 526248a
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 20 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/analysis-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ jobs:
- nextcloud: "27.1.4"
python: "3.10"
php-version: "8.1"
timeout-minutes: 60

services:
mariadb:
Expand Down Expand Up @@ -209,6 +210,7 @@ jobs:
php-version: "8.2"
env:
NC_dbname: nextcloud_abz
timeout-minutes: 60

services:
postgres:
Expand Down Expand Up @@ -361,6 +363,7 @@ jobs:
needs: [analysis]
runs-on: ubuntu-22.04
name: stable27 • 🐘8.1 • 🐍3.11 • OCI
timeout-minutes: 60

services:
oracle:
Expand Down Expand Up @@ -483,6 +486,7 @@ jobs:
fail-fast: false
matrix:
nextcloud: [ 'stable27', 'stable28', 'master' ]
timeout-minutes: 60

services:
mariadb:
Expand Down Expand Up @@ -661,6 +665,7 @@ jobs:
nextcloud: [ 'stable27', 'stable28', 'master' ]
env:
NC_dbname: nextcloud_abz
timeout-minutes: 60

services:
postgres:
Expand Down Expand Up @@ -841,6 +846,7 @@ jobs:
nextcloud: [ 'stable26', 'stable27', 'stable28', 'master' ]
env:
NEXTCLOUD_URL: "http://localhost:8080/index.php"
timeout-minutes: 60

steps:
- name: Set up php
Expand Down
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

All notable changes to this project will be documented in this file.

## [0.10.0 - 2024-02-0x]

### Added

- set_handlers: `models_to_fetch` can now accept direct links to files to download. #217

### Changed

- adjusted code related to changes in AppAPI `2.0.3` #216

## [0.9.0 - 2024-01-25]

### Added
Expand Down
88 changes: 73 additions & 15 deletions nc_py_api/ex_app/integration_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""FastAPI directly related stuff."""

import asyncio
import builtins
import hashlib
import json
import os
import typing
from urllib.parse import urlparse

import httpx
from fastapi import (
BackgroundTasks,
Depends,
Expand All @@ -20,6 +24,7 @@
from .._misc import get_username_secret_from_headers
from ..nextcloud import AsyncNextcloudApp, NextcloudApp
from ..talk_bot import TalkBotMessage
from .defs import LogLvl
from .misc import persistent_storage


Expand Down Expand Up @@ -163,26 +168,79 @@ def __map_app_static_folders(fast_api_app: FastAPI):
fast_api_app.mount(f"/{mnt_dir}", staticfiles.StaticFiles(directory=mnt_dir_path), name=mnt_dir)


def __fetch_models_task(nc: NextcloudApp, models: dict[str, dict]) -> None:
    """Download every entry of ``models``, reporting aggregate progress via ``nc.set_init_status``.

    Keys starting with ``http://`` or ``https://`` are treated as direct file links and
    downloaded with :func:`__fetch_model_as_file`; any other key is treated as a model
    repository name and fetched with :func:`__fetch_model_as_snapshot`.
    """
    if models:
        current_progress = 0
        # Split the 0-99 progress range evenly between models; 100 is reserved
        # for the final "initialization complete" status below.
        percent_for_each = min(int(100 / len(models)), 99)
        for model in models:
            if model.startswith(("http://", "https://")):
                __fetch_model_as_file(current_progress, percent_for_each, nc, model, models[model])
            else:
                __fetch_model_as_snapshot(current_progress, percent_for_each, nc, model, models[model])
            current_progress += percent_for_each
    nc.set_init_status(100)  # signal that model fetching has finished


def __fetch_model_as_file(
    current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
) -> None:
    """Download a single file from the direct URL ``model_path``.

    :param current_progress: progress value already accumulated by previous models.
    :param progress_for_task: share of the total progress allotted to this download.
    :param nc: app instance used for status reporting and logging.
    :param model_path: direct ``http(s)`` link to the file.
    :param download_options: per-model options; ``save_path`` (destination file name,
        defaults to the last URL path segment) is consumed here.

    Errors are logged via ``nc.log`` and never raised to the caller.
    """
    result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
    try:
        with httpx.stream("GET", model_path, follow_redirects=True) as response:
            if not response.is_success:
                nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' returned {response.status_code} status.")
                return
            downloaded_size = 0
            # Prefer "X-Linked-ETag" from the redirect chain (e.g. HF resolve endpoints
            # redirect to a CDN); fall back to the final response's ETag headers.
            linked_etag = ""
            for each_history in response.history:
                linked_etag = each_history.headers.get("X-Linked-ETag", "")
                if linked_etag:
                    break
            if not linked_etag:
                linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", ""))
            # BUGFIX: the original called int() on headers.get("Content-Length") with no
            # default, raising TypeError (caught below, aborting the download) whenever
            # the server omitted the header. 0 now means "size unknown".
            total_size = int(response.headers.get("Content-Length", 0))
            try:
                existing_size = os.path.getsize(result_path)
            except OSError:
                existing_size = 0
            # Skip the download when a same-size local copy hashes to the served ETag.
            # NOTE(review): assumes the (X-Linked-)ETag is a quoted sha256 hex digest,
            # as served by HuggingFace — verify before relying on this for other hosts.
            if linked_etag and total_size and total_size == existing_size:
                with builtins.open(result_path, "rb") as file:
                    sha256_hash = hashlib.sha256()
                    for byte_block in iter(lambda: file.read(4096), b""):
                        sha256_hash.update(byte_block)
                    if f'"{sha256_hash.hexdigest()}"' == linked_etag:
                        nc.set_init_status(min(current_progress + progress_for_task, 99))
                        return

            with builtins.open(result_path, "wb") as file:
                last_progress = current_progress
                for chunk in response.iter_bytes(5 * 1024 * 1024):  # 5 MiB chunks
                    downloaded_size += file.write(chunk)
                    if total_size:  # progress is only computable with a known size
                        new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99)
                        if new_progress != last_progress:
                            nc.set_init_status(new_progress)
                            last_progress = new_progress
    except Exception as e:  # noqa pylint: disable=broad-exception-caught
        nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' raised an exception: {e}")


def __fetch_model_as_snapshot(
    current_progress: int, progress_for_task: int, nc: NextcloudApp, mode_name: str, download_options: dict
) -> None:
    """Download a model repository snapshot with ``huggingface_hub.snapshot_download``.

    :param current_progress: progress value already accumulated by previous models.
    :param progress_for_task: share of the total progress allotted to this snapshot.
    :param nc: app instance used for status reporting.
    :param mode_name: repository name to snapshot.
        NOTE(review): looks like a typo for ``model_name``; kept for interface stability.
    :param download_options: extra keyword arguments forwarded to ``snapshot_download``;
        ``max_workers`` (default 2) and ``cache_dir`` (default persistent storage) are
        consumed here.
    """
    from huggingface_hub import snapshot_download  # noqa isort:skip pylint: disable=C0415 disable=E0401
    from tqdm import tqdm  # noqa isort:skip pylint: disable=C0415 disable=E0401

    class TqdmProgress(tqdm):
        def display(self, msg=None, pos=None):
            # ROBUSTNESS: tqdm's "total" may be None or 0 before it is known; the
            # original divided by it unconditionally, risking TypeError/ZeroDivisionError.
            if self.total:
                nc.set_init_status(min(current_progress + int(progress_for_task * self.n / self.total), 99))
            return super().display(msg, pos)

    workers = download_options.pop("max_workers", 2)
    cache = download_options.pop("cache_dir", persistent_storage())
    snapshot_download(mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache)


class AppAPIAuthMiddleware:
"""Pure ASGI AppAPIAuth Middleware."""

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ profile = "black"
master.py-version = "3.10"
master.extension-pkg-allow-list = ["pydantic"]
design.max-attributes = 8
design.max-locals = 16
design.max-locals = 20
design.max-branches = 16
design.max-returns = 8
design.max-args = 7
Expand Down
26 changes: 22 additions & 4 deletions tests/_install_init_handler_models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI

from nc_py_api import NextcloudApp, ex_app

# Test fixtures for models_to_fetch: a mix of valid repo names, direct links
# (https and http), and intentionally-broken entries that must be skipped gracefully.
INVALID_URL = "https://invalid_url"
MODEL_NAME1 = "MBZUAI/LaMini-T5-61M"
MODEL_NAME2 = "https://huggingface.co/MBZUAI/LaMini-T5-61M/resolve/main/pytorch_model.bin"
MODEL_NAME2_http = "http://huggingface.co/MBZUAI/LaMini-T5-61M/resolve/main/pytorch_model.bin"
INVALID_PATH = "https://huggingface.co/invalid_path"
# BUGFIX: the scraped source had a garbled mirror domain "raw.githubusercontent.com";
# restored to the canonical GitHub raw-content host.
SOME_FILE = "https://raw.githubusercontent.com/cloud-py-api/nc_py_api/main/README.md"


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """FastAPI lifespan hook: register handlers and the set of models to fetch on startup.

    NOTE: the scraped diff duplicated the pre-commit one-line ``set_handlers`` call here;
    this is the coherent post-commit version with the full fixture dict.
    """
    ex_app.set_handlers(
        APP,
        enabled_handler,
        models_to_fetch={
            INVALID_URL: {},
            MODEL_NAME1: {},
            MODEL_NAME2: {},
            MODEL_NAME2_http: {},
            INVALID_PATH: {},
            SOME_FILE: {},
        },
    )
    yield


Expand All @@ -19,9 +36,10 @@ async def lifespan(_app: FastAPI):
def enabled_handler(enabled: bool, _nc: NextcloudApp) -> str:
if enabled:
try:
assert ex_app.get_model_path(MODEL_NAME)
assert ex_app.get_model_path(MODEL_NAME1)
except Exception: # noqa
return "model not found"
return "model1 not found"
assert Path("pytorch_model.bin").is_file()
return ""


Expand Down

0 comments on commit 526248a

Please sign in to comment.