Skip to content

Commit

Permalink
Merge branch '3078-rate-limit-users' into develop
Browse files Browse the repository at this point in the history
Issue #3078
PR #3220
  • Loading branch information
cakekoa committed Apr 11, 2023
2 parents c5616b0 + 91d177a commit 52636ef
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 7 deletions.
62 changes: 62 additions & 0 deletions envs/monkey_zoo/blackbox/test_blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from http import HTTPStatus
from threading import Thread
from time import sleep
from typing import List
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -205,6 +206,67 @@ def make_request(monkey_island_requests, request_callback):
assert response_codes.count(HTTPStatus.TOO_MANY_REQUESTS) == 1


RATE_LIMIT_AGENT1_ID = uuid4()
RATE_LIMIT_AGENT2_ID = uuid4()


@pytest.mark.parametrize(
"request_callback, successful_request_status, max_requests_per_second",
[
(lambda mir: mir.get(GET_AGENT_OTP_ENDPOINT), HTTPStatus.OK, MAX_OTP_REQUESTS_PER_SECOND),
],
)
def test_rate_limit__agent_user(
island,
monkey_island_requests,
request_callback,
successful_request_status,
max_requests_per_second,
):
monkey_island_requests.login()
response = monkey_island_requests.get(GET_AGENT_OTP_ENDPOINT)
otp1 = response.json()["otp"]
response = monkey_island_requests.get(GET_AGENT_OTP_ENDPOINT)
otp2 = response.json()["otp"]

agent1_requests = AgentRequests(island, RATE_LIMIT_AGENT1_ID, OTP(otp1))
agent1_requests.login()
agent2_requests = AgentRequests(island, RATE_LIMIT_AGENT2_ID, OTP(otp2))
agent2_requests.login()

threads = []
response_codes1: List[int] = []
response_codes2: List[int] = []

def make_request(agent_requests, request_callback, response_codes):
response = request_callback(agent_requests)
response_codes.append(response.status_code)

for _ in range(0, max_requests_per_second + 1):
t1 = Thread(
target=make_request,
args=(agent1_requests, request_callback, response_codes1),
daemon=True,
)
t1.start()
t2 = Thread(
target=make_request,
args=(agent2_requests, request_callback, response_codes2),
daemon=True,
)
t2.start()
threads.append(t1)
threads.append(t2)

for t in threads:
t.join()

assert response_codes1.count(successful_request_status) == max_requests_per_second
assert response_codes1.count(HTTPStatus.TOO_MANY_REQUESTS) == 1
assert response_codes2.count(successful_request_status) == max_requests_per_second
assert response_codes2.count(HTTPStatus.TOO_MANY_REQUESTS) == 1


def test_refresh_access_token(monkey_island_requests):
monkey_island_requests.login()
original_token = monkey_island_requests.token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from flask import make_response
from flask_limiter import Limiter, RateLimitExceeded
from flask_limiter.util import get_remote_address
from flask_login import current_user
from flask_security import auth_token_required, roles_accepted

from monkey_island.cc.flask_utils import AbstractResource
Expand Down Expand Up @@ -31,15 +31,14 @@ def __init__(self, otp_generator: IOTPGenerator, limiter: Limiter):
# we need to ensure that a single instance of the limiter is used. Hence
# the class variable.
#
# TODO: The limit is currently applied per IP address. We will want to change
# it to per-user, per-IP once we require authentication for this endpoint.
# Note that we do not want to limit to just per-user, otherwise this endpoint could be used
# to enumerate users/tokens.
# to enumerate users/tokens. This should already be captured by the role-based access
# control.
with AgentOTP.lock:
if AgentOTP.limiter is None:
AgentOTP.limiter = limiter.limit(
f"{MAX_OTP_REQUESTS_PER_SECOND}/second",
key_func=get_remote_address,
key_func=lambda: current_user.username,
per_method=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from flask import make_response
from flask_limiter import Limiter, RateLimitExceeded
from flask_limiter.util import get_remote_address
from flask_login import current_user
from flask_security import auth_token_required

Expand Down Expand Up @@ -38,7 +37,7 @@ def __init__(self, authentication_facade: AuthenticationFacade, limiter: Limiter
if RefreshAuthenticationToken.limiter is None:
RefreshAuthenticationToken.limiter = limiter.limit(
f"{MAX_REFRESH_AUTHENTICATION_TOKEN_REQUESTS_PER_SECOND}/second",
key_func=get_remote_address,
key_func=lambda: current_user.username,
per_method=True,
)

Expand Down

0 comments on commit 52636ef

Please sign in to comment.