Skip to content

Commit

Permalink
feat: adjust /packages to QPPE
Browse files Browse the repository at this point in the history
  • Loading branch information
janbritz committed Aug 5, 2024
1 parent 2159c33 commit 918eb03
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 49 deletions.
32 changes: 22 additions & 10 deletions questionpy_server/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from typing import Annotated, Any

from pydantic import BaseModel, ByteSize, ConfigDict, Field, FilePath, HttpUrl
from pydantic import BaseModel, ByteSize, ConfigDict, Field, HttpUrl

Check failure on line 8 in questionpy_server/api/models.py

View workflow job for this annotation

GitHub Actions / ci / ruff-lint

Ruff (F401)

questionpy_server/api/models.py:8:62: F401 `pydantic.HttpUrl` imported but unused

from questionpy_common.api.attempt import AttemptModel
from questionpy_common.api.question import QuestionModel
Expand All @@ -16,10 +16,21 @@
class PackageInfo(BaseModel):
model_config = ConfigDict(use_enum_values=True)

package_hash: str
short_name: str
namespace: str
name: dict[str, str]
type: PackageType
author: str | None
url: str | None
languages: set[str] | None
description: dict[str, str] | None
icon: str | None
license: str | None
tags: set[str] | None


class PackageVersionSpecificInfo(BaseModel):
package_hash: str
version: Annotated[
str,
Field(
Expand All @@ -29,14 +40,15 @@ class PackageInfo(BaseModel):
r"(\+([0-9a-zA-Z-]+(\.[0-9a-zA-Z-]+)*))?$"
),
]
type: PackageType
author: str | None
url: HttpUrl | None
languages: list[str] | None
description: dict[str, str] | None
icon: FilePath | HttpUrl | None
license: str | None
tags: list[str] | None


class PackageVersionInfo(PackageInfo, PackageVersionSpecificInfo):
pass


class PackageVersionsInfo(BaseModel):
manifest: PackageInfo
versions: list[PackageVersionSpecificInfo]


class MainBaseModel(BaseModel):
Expand Down
6 changes: 2 additions & 4 deletions questionpy_server/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@
async def get_packages(request: web.Request) -> web.Response:
qpyserver: "QPyServer" = request.app["qpy_server_app"]

packages = qpyserver.package_collection.get_packages()
data = [package.get_info() for package in packages]

return json_response(data=data)
package_versions_infos = qpyserver.package_collection.get_package_versions_infos()
return json_response(data=package_versions_infos)


@routes.get(r"/packages/{package_hash:\w+}")
Expand Down
42 changes: 38 additions & 4 deletions questionpy_server/collector/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import overload

from questionpy_server import WorkerPool
from questionpy_server.api.models import PackageInfo, PackageVersionsInfo, PackageVersionSpecificInfo
from questionpy_server.collector.abc import BaseCollector
from questionpy_server.collector.local_collector import LocalCollector
from questionpy_server.collector.repo_collector import RepoCollector
Expand All @@ -30,6 +31,8 @@ def __init__(self, worker_pool: WorkerPool):
self._index_by_identifier: dict[str, dict[SemVer, Package]] = {}
"""dict[identifier, dict[version, Package]]"""

self._package_versions_infos: list[PackageVersionsInfo] | None = None

self._lock: Lock | None = None

def get_by_hash(self, package_hash: str) -> Package | None:
Expand Down Expand Up @@ -66,13 +69,38 @@ def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> Pac
"""
return self._index_by_identifier.get(identifier, {}).get(version, None)

def get_packages(self) -> set[Package]:
"""Returns all packages in the index (excluding packages from LMSs).
def get_package_versions_infos(self) -> list[PackageVersionsInfo]:
"""Returns an overview of every package and its versions (excluding packages from LMSs).
TODO: optimize further?
Returns:
set of packages
list of PackageVersionsInfo
"""
return {package for packages in self._index_by_identifier.values() for package in packages.values()}
if self._package_versions_infos is not None:
return self._package_versions_infos

package_versions_infos = []

for package_versions in self._index_by_identifier.values():
versions = []
sorted_package_versions = sorted(package_versions, reverse=True)
for version in sorted_package_versions:
package_version = package_versions[version]
versions.append(PackageVersionSpecificInfo(package_hash=package_version.hash, version=str(version)))

# A package should always have at least one package version, we try-except just in case.
try:
latest_package_version = package_versions[sorted_package_versions[0]]
package_info = PackageInfo(**latest_package_version.manifest.model_dump())
except KeyError:
continue

package_versions_info = PackageVersionsInfo(manifest=package_info, versions=versions)
package_versions_infos.append(package_versions_info)

self._package_versions_infos = package_versions_infos
return self._package_versions_infos

@overload
async def register_package(
Expand Down Expand Up @@ -128,6 +156,9 @@ async def register_package(
else:
package_versions[package.manifest.version] = package

# Force recalculation of list[PackageVersionsInfo].
self._package_versions_infos = None

return package

async def unregister_package(self, package_hash: str, source: BaseCollector) -> None:
Expand Down Expand Up @@ -158,6 +189,9 @@ async def unregister_package(self, package_hash: str, source: BaseCollector) ->
if not package_versions:
self._index_by_identifier.pop(package.manifest.identifier, None)

# Force recalculation of list[PackageVersionsInfo].
self._package_versions_infos = None

if len(package.sources) == 0:
# Package has no more sources; remove it from the index.
self._index_by_hash.pop(package_hash, None)
9 changes: 5 additions & 4 deletions questionpy_server/collector/package_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import HttpUrl

from questionpy_server import WorkerPool
from questionpy_server.api.models import PackageVersionsInfo
from questionpy_server.cache import FileLimitLRU
from questionpy_server.collector.indexer import Indexer
from questionpy_server.collector.lms_collector import LMSCollector
Expand Down Expand Up @@ -121,10 +122,10 @@ def get_by_identifier_and_version(self, identifier: str, version: SemVer) -> "Pa

raise FileNotFoundError

def get_packages(self) -> set["Package"]:
"""Returns a set of all available packages.
def get_package_versions_infos(self) -> list[PackageVersionsInfo]:
"""Returns an overview of every package and its versions.
Returns:
set of packages
list of PackageVersionsInfo
"""
return self._indexer.get_packages()
return self._indexer.get_package_versions_infos()
4 changes: 2 additions & 2 deletions questionpy_server/factories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# (c) Technische Universität Berlin, innoCampus <[email protected]>

from .attempt import AttemptScoredFactory
from .package import PackageInfoFactory
from .package import PackageVersionInfoFactory
from .question_state import RequestBaseDataFactory

__all__ = ["AttemptScoredFactory", "PackageInfoFactory", "RequestBaseDataFactory"]
__all__ = ["AttemptScoredFactory", "PackageVersionInfoFactory", "RequestBaseDataFactory"]
6 changes: 3 additions & 3 deletions questionpy_server/factories/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from faker import Faker
from polyfactory.factories.pydantic_factory import ModelFactory

from questionpy_server.api.models import PackageInfo
from questionpy_server.api.models import PackageVersionInfo

languages = ["en", "de"]
fake = Faker()


class PackageInfoFactory(ModelFactory):
__model__ = PackageInfo
class PackageVersionInfoFactory(ModelFactory):
__model__ = PackageVersionInfo

@staticmethod
def author() -> str:
Expand Down
8 changes: 4 additions & 4 deletions questionpy_server/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import contextlib
from pathlib import Path

from questionpy_server.api.models import PackageInfo
from questionpy_server.api.models import PackageVersionInfo
from questionpy_server.collector.abc import BaseCollector
from questionpy_server.collector.lms_collector import LMSCollector
from questionpy_server.collector.local_collector import LocalCollector
Expand Down Expand Up @@ -99,7 +99,7 @@ class Package:

sources: PackageSources

_info: PackageInfo | None
_info: PackageVersionInfo | None
_path: Path | None

def __init__(
Expand Down Expand Up @@ -127,7 +127,7 @@ def __eq__(self, other: object) -> bool:
return NotImplemented
return self.hash == other.hash

def get_info(self) -> PackageInfo:
def get_info(self) -> PackageVersionInfo:
"""Returns the package info.
Returns:
Expand All @@ -136,7 +136,7 @@ def get_info(self) -> PackageInfo:
if not self._info:
tmp = self.manifest.model_dump()
tmp["version"] = str(tmp["version"])
self._info = PackageInfo(**tmp, package_hash=self.hash)
self._info = PackageVersionInfo(**tmp, package_hash=self.hash)
return self._info

async def get_path(self) -> Path:
Expand Down
9 changes: 4 additions & 5 deletions questionpy_server/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from collections.abc import Sequence
from hashlib import sha256
from io import BytesIO
from json import JSONDecodeError, loads
from json import JSONDecodeError, dumps, loads
from typing import TYPE_CHECKING, Literal, NamedTuple, overload

from aiohttp import BodyPartReader
from aiohttp.abc import Request
from aiohttp.log import web_logger
from aiohttp.web_exceptions import HTTPBadRequest, HTTPRequestEntityTooLarge
from aiohttp.web_response import Response
from aiohttp.web_response import json_response as aiohttp_json_response
from pydantic import BaseModel, ValidationError
from pydantic_core import to_jsonable_python

from questionpy_common import constants
from questionpy_common.constants import KiB
Expand All @@ -36,10 +38,7 @@ def json_response(data: Sequence[BaseModel] | BaseModel, status: int = 200) -> R
Returns:
Response: A response object.
"""
if isinstance(data, Sequence):
json_list = f'[{",".join(x.json() for x in data)}]'
return Response(text=json_list, status=status, content_type="application/json")
return Response(text=data.model_dump_json(), status=status, content_type="application/json")
return aiohttp_json_response(data, status=status, dumps=lambda model: dumps(model, default=to_jsonable_python))


def create_model_from_json(json: object | str, param_class: type[M]) -> M:
Expand Down
74 changes: 69 additions & 5 deletions tests/questionpy_server/api/test_models.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,87 @@
# This file is part of the QuestionPy Server. (https://questionpy.org)
# The QuestionPy Server is free software released under terms of the MIT license. See LICENSE.md.
# (c) Technische Universität Berlin, innoCampus <[email protected]>

from hashlib import sha256
from io import BytesIO
from itertools import pairwise, starmap
from operator import ge
from unittest.mock import Mock

import pytest
from aiohttp import FormData
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import TestClient
from pydantic import TypeAdapter

from questionpy_server.api.models import PackageInfo
from questionpy_server.api.models import PackageVersionInfo, PackageVersionsInfo
from questionpy_server.app import QPyServer
from questionpy_server.collector.local_collector import LocalCollector
from questionpy_server.utils.manifest import ComparableManifest
from tests.conftest import PACKAGE
from tests.test_data.factories import ManifestFactory


@pytest.mark.parametrize(
"packages",
[
# No packages.
{},
# One package.
{"ns1": {"0.1.0"}},
# Two packages.
{"ns1": {"0.1.0"}, "ns2": {"0.1.0"}},
# Multiple versions.
{"ns1": {"1.0.0", "0.0.1"}, "ns2": {"1.0.0", "0.1.0", "0.0.1"}},
# Multiple versions, unsorted.
{"ns1": {"0.0.1", "1.0.0"}, "ns2": {"0.1.0", "0.0.1", "1.0.0"}},
],
)
async def test_packages(qpy_server: QPyServer, aiohttp_client: AiohttpClient, packages: dict[str, set[str]]) -> None:
async def add_package_version(server: QPyServer, manifest: ComparableManifest) -> None:
package_hash = sha256((manifest.short_name + manifest.namespace + str(manifest.version)).encode()).hexdigest()
await server.package_collection._indexer.register_package(package_hash, manifest, Mock(spec=LocalCollector))

manifests: dict[str, dict[str, ComparableManifest]] = {}
for namespace, versions in packages.items():
for version in versions:
expected_manifest = ManifestFactory.build(namespace=namespace, short_name=namespace, version=version)
manifests.setdefault(namespace, {})[version] = expected_manifest
await add_package_version(qpy_server, expected_manifest)

async def test_packages(client: TestClient) -> None:
client = await aiohttp_client(qpy_server.web_app)
res = await client.request("GET", "/packages")

# Assert that a valid list of PackageVersionsInfo is returned.
assert res.status == 200
data = await res.json()
TypeAdapter(list[PackageInfo]).validate_python(data)
package_versions_infos: list[PackageVersionsInfo] = TypeAdapter(list[PackageVersionsInfo]).validate_python(data)

expected_package_count = len(packages)
assert len(package_versions_infos) == expected_package_count

if expected_package_count <= 0:
return

actual_namespaces = []

# Iterate over all actual packages.
for package_versions_info in package_versions_infos:
actual_package_info = package_versions_info.manifest
actual_versions = [version.version for version in package_versions_info.versions]
# Assert that each package version is available and in the correct order.
assert set(actual_versions) == packages[actual_package_info.namespace]
assert all(starmap(ge, pairwise(actual_versions))), "The package versions are not sorted in descending order."
# Assert that the actual package info is a subset of the manifest of the latest package version.
actual_package_info_items = actual_package_info.model_dump().items()
latest_manifest_items = manifests[actual_package_info.namespace][actual_versions[0]].model_dump().items()
assert actual_package_info_items <= latest_manifest_items, (
"Actual package info was not derived from the " "latest package version."
)

actual_namespaces.append(actual_package_info.namespace)

# Assert that every expected package is returned.
assert set(actual_namespaces) == packages.keys()


async def test_extract_info(client: TestClient) -> None:
Expand All @@ -29,7 +93,7 @@ async def test_extract_info(client: TestClient) -> None:

assert res.status == 201
data = await res.json()
PackageInfo.model_validate(data)
PackageVersionInfo.model_validate(data)


async def test_extract_info_faulty(client: TestClient) -> None:
Expand Down
Loading

0 comments on commit 918eb03

Please sign in to comment.