Skip to content

Commit

Permalink
🐛 when updating comp_task table always give actual wallet info (#4955)
Browse files Browse the repository at this point in the history
  • Loading branch information
matusdrobuliak66 authored Oct 31, 2023
1 parent 809a4d6 commit e04a211
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from aiohttp import web
from models_library.projects import ProjectID
from models_library.users import UserID
from models_library.wallets import ZERO_CREDITS, WalletID, WalletInfo
from pydantic import parse_obj_as

from ..application_settings import get_settings
from ..products.api import Product
from ..projects import api as projects_api
from ..users import preferences_api as user_preferences_api
from ..users.exceptions import UserDefaultWalletNotFoundError
from ..wallets import api as wallets_api
from ..wallets.errors import WalletNotEnoughCreditsError


async def get_wallet_info(
app: web.Application,
*,
product: Product,
user_id: UserID,
project_id: ProjectID,
product_name: str,
) -> WalletInfo | None:
app_settings = get_settings(app)
if not (
product.is_payment_enabled and app_settings.WEBSERVER_CREDIT_COMPUTATION_ENABLED
):
return None
project_wallet = await projects_api.get_project_wallet(app, project_id=project_id)
if project_wallet is None:
user_default_wallet_preference = await user_preferences_api.get_frontend_user_preference(
app,
user_id=user_id,
product_name=product_name,
preference_class=user_preferences_api.PreferredWalletIdFrontendUserPreference,
)
if user_default_wallet_preference is None:
raise UserDefaultWalletNotFoundError(uid=user_id)
project_wallet_id = parse_obj_as(WalletID, user_default_wallet_preference.value)
await projects_api.connect_wallet_to_project(
app,
product_name=product_name,
project_id=project_id,
user_id=user_id,
wallet_id=project_wallet_id,
)
else:
project_wallet_id = project_wallet.wallet_id

# Check whether user has access to the wallet
wallet = await wallets_api.get_wallet_with_available_credits_by_user_and_wallet(
app,
user_id=user_id,
wallet_id=project_wallet_id,
product_name=product_name,
)
if wallet.available_credits <= ZERO_CREDITS:
raise WalletNotEnoughCreditsError(
reason=f"Wallet {wallet.wallet_id} credit balance {wallet.available_credits}"
)
return WalletInfo(wallet_id=project_wallet_id, wallet_name=wallet.name)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from servicelib.logging_utils import log_decorator
from settings_library.utils_cli import create_json_encoder_wo_secrets

from ..products.api import get_product
from ._api_utils import get_wallet_info
from ._core_base import DataType, request_director_v2
from .exceptions import (
ClusterAccessForbidden,
Expand Down Expand Up @@ -113,6 +115,13 @@ async def create_or_update_pipeline(
"user_id": user_id,
"project_id": f"{project_id}",
"product_name": product_name,
"wallet_info": await get_wallet_info(
app,
product=get_product(app, product_name),
user_id=user_id,
project_id=project_id,
product_name=product_name,
),
}
# request to director-v2
try:
Expand All @@ -123,7 +132,7 @@ async def create_or_update_pipeline(
return computation_task_out

except DirectorServiceError as exc:
_logger.error(
_logger.error( # noqa: TRY400
"could not create pipeline from project %s: %s",
project_id,
exc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from models_library.clusters import ClusterID
from models_library.projects import ProjectID
from models_library.users import UserID
from models_library.wallets import ZERO_CREDITS, WalletID, WalletInfo
from pydantic import BaseModel, Field, ValidationError, parse_obj_as
from pydantic.types import NonNegativeInt
from servicelib.aiohttp.rest_responses import create_error_response, get_http_error
Expand All @@ -24,19 +23,16 @@

from .._constants import RQ_PRODUCT_KEY
from .._meta import API_VTAG as VTAG
from ..application_settings import get_settings
from ..db.plugin import get_database_engine
from ..login.decorators import login_required
from ..products import api as products_api
from ..projects import api as projects_api
from ..security.decorators import permission_required
from ..users import preferences_api as user_preferences_api
from ..users.exceptions import UserDefaultWalletNotFoundError
from ..utils_aiohttp import envelope_json_response
from ..version_control.models import CommitID
from ..wallets import api as wallets_api
from ..wallets.errors import WalletNotEnoughCreditsError
from ._abc import get_project_run_policy
from ._api_utils import get_wallet_info
from ._core_computations import ComputationsApi
from .exceptions import DirectorServiceError

Expand Down Expand Up @@ -98,54 +94,14 @@ async def start_computation(request: web.Request) -> web.Response:
)

# Get wallet information
wallet_info = None
product = products_api.get_current_product(request)
app_settings = get_settings(request.app)
if (
product.is_payment_enabled
and app_settings.WEBSERVER_CREDIT_COMPUTATION_ENABLED
):
project_wallet = await projects_api.get_project_wallet(
request.app, project_id=project_id
)
if project_wallet is None:
user_default_wallet_preference = await user_preferences_api.get_frontend_user_preference(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
preference_class=user_preferences_api.PreferredWalletIdFrontendUserPreference,
)
if user_default_wallet_preference is None:
raise UserDefaultWalletNotFoundError(uid=req_ctx.user_id)
project_wallet_id = parse_obj_as(
WalletID, user_default_wallet_preference.value
)
await projects_api.connect_wallet_to_project(
request.app,
product_name=req_ctx.product_name,
project_id=project_id,
user_id=req_ctx.user_id,
wallet_id=project_wallet_id,
)
else:
project_wallet_id = project_wallet.wallet_id

# Check whether user has access to the wallet
wallet = (
await wallets_api.get_wallet_with_available_credits_by_user_and_wallet(
request.app,
user_id=req_ctx.user_id,
wallet_id=project_wallet_id,
product_name=req_ctx.product_name,
)
)
if wallet.available_credits <= ZERO_CREDITS:
raise WalletNotEnoughCreditsError(
reason=f"Wallet {wallet.wallet_id} credit balance {wallet.available_credits}"
)
wallet_info = WalletInfo(
wallet_id=project_wallet_id, wallet_name=wallet.name
)
wallet_info = await get_wallet_info(
request.app,
product=product,
user_id=req_ctx.user_id,
project_id=project_id,
product_name=req_ctx.product_name,
)

options = {
"start_pipeline": True,
Expand Down

0 comments on commit e04a211

Please sign in to comment.