Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate client secret for confidential clients #15

Merged
merged 4 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aioauth_fastapi_demo/admin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class ClientCreate(BaseModel):
client_id: str
client_secret: str
client_secret: Optional[str]
grant_types: List[GrantType]
response_types: List[ResponseType]
redirect_uris: List[str]
Expand Down
2 changes: 1 addition & 1 deletion aioauth_fastapi_demo/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class Client(BaseTable, table=True): # type: ignore
client_id: str
client_secret: str
client_secret: Optional[str]
grant_types: List[str] = Field(sa_column=Column(ARRAY(String)))
response_types: List[str] = Field(sa_column=Column(ARRAY(String)))
redirect_uris: List[str] = Field(sa_column=Column(ARRAY(String)))
Expand Down
5 changes: 5 additions & 0 deletions aioauth_fastapi_demo/oauth2/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ async def get_client(
if not client_record:
return None

if client_secret is not None and client_record.client_secret is not None:
# validate the client_secret
if client_secret != client_record.client_secret:
return None

return Client(
client_id=client_record.client_id,
client_secret=client_record.client_secret,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Initial migrations

Revision ID: 07a7ace268a7
Revision ID: 21eb480f8485
Revises:
Create Date: 2021-10-02 22:50:10.418498
Create Date: 2024-08-03 21:35:56.307146

"""
import sqlalchemy as sa
import sqlmodel
from alembic import op

# revision identifiers, used by Alembic.
revision = "07a7ace268a7"
revision = "21eb480f8485"
down_revision = None
branch_labels = None
depends_on = None
Expand Down Expand Up @@ -125,7 +125,7 @@ def upgrade():
sa.Column("redirect_uris", sa.ARRAY(sa.String()), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("client_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("scope", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
Expand Down
32 changes: 25 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from uuid import uuid4
from typing import Optional
from uuid import UUID, uuid4

import pytest
from alembic.config import main
Expand Down Expand Up @@ -65,16 +66,19 @@ async def user(db: "SQLAlchemyStorage", user_password: str) -> User:
return user


@pytest.fixture
async def client(db: "SQLAlchemyStorage", user: "User") -> Client:
client_id = uuid4()
client_secret = uuid4()
async def _create_client(
db: "SQLAlchemyStorage",
user: "User",
client_id: UUID,
client_secret: Optional[UUID],
) -> Client:
grant_types = [
"authorization_code",
"client_credentials",
"password",
"refresh_token",
]
if client_secret is not None:
grant_types.append("client_credentials")
response_types = [
"code",
"id_token",
Expand All @@ -88,7 +92,7 @@ async def client(db: "SQLAlchemyStorage", user: "User") -> Client:

client = Client(
client_id=str(client_id),
client_secret=str(client_secret),
client_secret=str(client_secret) if client_secret is not None else None,
response_types=response_types,
grant_types=grant_types,
redirect_uris=redirect_uris,
Expand All @@ -99,3 +103,17 @@ async def client(db: "SQLAlchemyStorage", user: "User") -> Client:
await db.add(client)

return client


@pytest.fixture
async def client(db: "SQLAlchemyStorage", user: "User") -> Client:
client_id = uuid4()
client_secret = uuid4()
return await _create_client(db, user, client_id, client_secret)


@pytest.fixture
async def public_client(db: "SQLAlchemyStorage", user: "User") -> Client:
client_id = uuid4()
client_secret = None
return await _create_client(db, user, client_id, client_secret)
91 changes: 86 additions & 5 deletions tests/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from aioauth_fastapi_demo.users.models import User


@pytest.mark.asyncio
async def test_authorization_code_flow(
http_client: TestClient, user: "User", client: "Client"
):
async def _get_authorization_code(http_client, user, client):
access_token, _ = get_jwt(user)

response = await http_client.get(
Expand All @@ -37,14 +34,23 @@ async def test_authorization_code_flow(
assert "scope" in parsed_qs.keys()
assert "id_token" in parsed_qs.keys()

return parsed_qs["code"][0]


@pytest.mark.asyncio
async def test_authorization_code_flow(
http_client: TestClient, user: "User", client: "Client"
):
authorization_code = await _get_authorization_code(http_client, user, client)

response = await http_client.post(
"/oauth2/token",
form={
"grant_type": "authorization_code",
"redirect_uri": client.redirect_uris[0],
"client_id": client.client_id,
"client_secret": client.client_secret,
"code": parsed_qs["code"][0],
"code": authorization_code,
},
)

Expand Down Expand Up @@ -79,6 +85,81 @@ async def test_authorization_code_flow(
), "re-trying to revoke an already revoked token should be rejected"


@pytest.mark.asyncio
async def test_authorization_code_no_secret(
http_client: TestClient, user: "User", client: "Client"
):
authorization_code = await _get_authorization_code(http_client, user, client)

response = await http_client.post(
"/oauth2/token",
form={
"grant_type": "authorization_code",
"redirect_uri": client.redirect_uris[0],
"client_id": client.client_id,
"code": authorization_code,
},
)

assert (
response.status_code == HTTPStatus.UNAUTHORIZED
), "no client secret for a confidential client should be rejected"


@pytest.mark.asyncio
async def test_authorization_code_wrong_secret(
http_client: TestClient, user: "User", client: "Client"
):
authorization_code = await _get_authorization_code(http_client, user, client)

response = await http_client.post(
"/oauth2/token",
form={
"grant_type": "authorization_code",
"redirect_uri": client.redirect_uris[0],
"client_id": client.client_id,
"client_secret": f"not {client.client_secret}",
"code": authorization_code,
},
)

assert (
response.status_code == HTTPStatus.UNAUTHORIZED
), "wrong client secret for a confidential client should be rejected"


@pytest.mark.asyncio
async def test_authorization_code_public_client(
http_client: TestClient, user: "User", public_client: "Client"
):
authorization_code = await _get_authorization_code(http_client, user, public_client)

response = await http_client.post(
"/oauth2/token",
form={
"grant_type": "authorization_code",
"redirect_uri": public_client.redirect_uris[0],
"client_id": public_client.client_id,
"code": authorization_code,
},
)

assert "access_token" in response.json()

refresh_token = response.json()["refresh_token"]

response = await http_client.post(
"/oauth2/token",
form={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": public_client.client_id,
},
)

assert response.status_code == HTTPStatus.OK


@pytest.mark.asyncio
async def test_implicit_flow(http_client: TestClient, user: "User", client: "Client"):
access_token, _ = get_jwt(user)
Expand Down
Loading