diff --git a/dashboard-controller.go b/dashboard-controller.go index c83d925..2f33c46 100644 --- a/dashboard-controller.go +++ b/dashboard-controller.go @@ -2,17 +2,18 @@ package main import ( "context" - "database/sql" + "github.com/abjrcode/swervo/favorites" "github.com/abjrcode/swervo/internal/logging" + "github.com/abjrcode/swervo/providers" "github.com/rs/zerolog" ) type DashboardController struct { - ctx context.Context - logger *zerolog.Logger - errorHandler logging.ErrorHandler - db *sql.DB + ctx context.Context + logger *zerolog.Logger + errorHandler logging.ErrorHandler + favoritesRepo favorites.FavoritesRepo } type Provider struct { @@ -26,21 +27,12 @@ type FavoriteInstance struct { InstanceId string `json:"instanceId"` } -var ( - SupportedProviders = map[string]Provider{ - "aws-iam-idc": { - Code: "aws-iam-idc", - Name: "AWS IAM IDC", - }, - } -) - var supportedProviders []Provider -func NewDashboardController(db *sql.DB) *DashboardController { +func NewDashboardController(favoritesRepo favorites.FavoritesRepo) *DashboardController { return &DashboardController{ - db: db, + favoritesRepo: favoritesRepo, } } @@ -50,35 +42,32 @@ func (c *DashboardController) Init(ctx context.Context, errorHandler logging.Err c.logger = &enrichedLogger c.errorHandler = errorHandler - supportedProviders = make([]Provider, 0, len(SupportedProviders)) - for _, provider := range SupportedProviders { - supportedProviders = append(supportedProviders, provider) + supportedProviders = make([]Provider, 0, len(providers.SupportedProviders)) + for _, provider := range providers.SupportedProviders { + supportedProviders = append(supportedProviders, Provider{ + Code: provider.Code, + Name: provider.Name, + }) } } func (c *DashboardController) ListFavorites() ([]FavoriteInstance, error) { - rows, err := c.db.QueryContext(c.ctx, `SELECT * FROM favorite_instances`) + favorites, err := c.favoritesRepo.ListAll(c.ctx) if err != nil { - if err == sql.ErrNoRows { - return []FavoriteInstance{}, nil - } - c.errorHandler.Catch(c.logger, err) } - favorites := make([]FavoriteInstance, 0, 10) + favoriteInstances := make([]FavoriteInstance, 0, len(favorites)) - for rows.Next() { - var favorite FavoriteInstance - err := rows.Scan(&favorite.ProviderCode, &favorite.InstanceId) - if err != nil { - c.errorHandler.Catch(c.logger, err) - } - favorites = append(favorites, favorite) + for _, favorite := range favorites { + favoriteInstances = append(favoriteInstances, FavoriteInstance{ + ProviderCode: favorite.ProviderCode, + InstanceId: favorite.InstanceId, + }) } - return favorites, nil + return favoriteInstances, nil } func (c *DashboardController) ListProviders() []Provider { diff --git a/dashboard-controller_test.go b/dashboard-controller_test.go index 56356e1..0137aa8 100644 --- a/dashboard-controller_test.go +++ b/dashboard-controller_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/abjrcode/swervo/favorites" "github.com/abjrcode/swervo/internal/migrations" "github.com/abjrcode/swervo/internal/testhelpers" "github.com/rs/zerolog" @@ -12,10 +13,13 @@ import ( func initDashboardController(t *testing.T) *DashboardController { db, err := migrations.NewInMemoryMigratedDatabase(t, "dashboard-controller-tests.db") - require.NoError(t, err) + + logger := zerolog.Nop() + favoritesRepo := favorites.NewFavorites(db, &logger) + controller := &DashboardController{ - db: db, + favoritesRepo: favoritesRepo, } ctx := zerolog.Nop().WithContext(context.Background()) errHandler := testhelpers.NewMockErrorHandler(t) diff --git a/favorites/repository.go b/favorites/repository.go new file mode 100644 index 0000000..50566dd --- /dev/null +++ b/favorites/repository.go @@ -0,0 +1,109 @@ +package favorites + +import ( + "context" + "database/sql" + "errors" + + "github.com/rs/zerolog" +) + +type Favorite struct { + ProviderCode string + InstanceId string +} + +type FavoritesRepo interface { + ListAll(ctx context.Context) ([]*Favorite, error) + IsFavorite(ctx context.Context, favorite *Favorite) (bool, error) + Add(ctx context.Context, favorite *Favorite) error + Remove(ctx context.Context, favorite *Favorite) error +} + +type favoritesImpl struct { + logger *zerolog.Logger + db *sql.DB +} + +func NewFavorites(db *sql.DB, logger *zerolog.Logger) FavoritesRepo { + enrichedLogger := logger.With().Str("component", "favorites_repo").Logger() + + return &favoritesImpl{ + db: db, + logger: &enrichedLogger, + } +} + +func (f *favoritesImpl) ListAll(ctx context.Context) ([]*Favorite, error) { + rows, err := f.db.QueryContext(ctx, `SELECT * FROM favorite_instances`) + + if err != nil { + if err == sql.ErrNoRows { + return []*Favorite{}, nil + } + + return nil, err + } + + favorites := make([]*Favorite, 0, 10) + + for rows.Next() { + var favorite Favorite + err := rows.Scan(&favorite.ProviderCode, &favorite.InstanceId) + + if err != nil { + return nil, err + } + + favorites = append(favorites, &favorite) + } + + return favorites, nil +} + +func (f *favoritesImpl) IsFavorite(ctx context.Context, favorite *Favorite) (bool, error) { + row := f.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM favorite_instances WHERE provider_code = ? AND instance_id = ? `, favorite.ProviderCode, favorite.InstanceId) + + var count int + err := row.Scan(&count) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + + return false, err + } + + return count > 0, nil +} + +func (f *favoritesImpl) Add(ctx context.Context, favorite *Favorite) error { + _, err := f.db.ExecContext(ctx, `INSERT INTO favorite_instances (provider_code, instance_id) VALUES (?, ?) `, favorite.ProviderCode, favorite.InstanceId) + + if err != nil { + return err + } + + return nil +} + +func (f *favoritesImpl) Remove(ctx context.Context, favorite *Favorite) error { + res, err := f.db.ExecContext(ctx, `DELETE FROM favorite_instances WHERE provider_code = ? AND instance_id = ? `, favorite.ProviderCode, favorite.InstanceId) + + if err != nil { + return err + } + + rowsAffected, err := res.RowsAffected() + + if err != nil { + return err + } + + if rowsAffected == 0 { + return sql.ErrNoRows + } + + return nil +} diff --git a/favorites/repository_test.go b/favorites/repository_test.go new file mode 100644 index 0000000..39dfd9d --- /dev/null +++ b/favorites/repository_test.go @@ -0,0 +1,95 @@ +package favorites + +import ( + "context" + "testing" + + "github.com/abjrcode/swervo/internal/migrations" + "github.com/abjrcode/swervo/providers" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" +) + +func TestAddFavorite(t *testing.T) { + db, err := migrations.NewInMemoryMigratedDatabase(t, "favorites-repo-tests.db") + require.NoError(t, err) + + logger := zerolog.Nop() + repo := NewFavorites(db, &logger) + + favorite := &Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: "some-nice-id", + } + + ctx := context.Background() + err = repo.Add(ctx, favorite) + require.NoError(t, err) + + favorites, err := repo.ListAll(ctx) + require.NoError(t, err) + + require.Len(t, favorites, 1) + require.Equal(t, favorite, favorites[0]) +} + +func TestRemoveFavorite(t *testing.T) { + db, err := migrations.NewInMemoryMigratedDatabase(t, "favorites-repo-tests.db") + require.NoError(t, err) + + logger := zerolog.Nop() + + repo := NewFavorites(db, &logger) + ctx := context.Background() + + favorite := &Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: "some-nice-id", + } + + err = repo.Add(ctx, favorite) + require.NoError(t, err) + + favorites, err := repo.ListAll(ctx) + require.NoError(t, err) + + require.Len(t, favorites, 1) + require.Equal(t, favorite, favorites[0]) + + err = repo.Remove(ctx, favorite) + require.NoError(t, err) + + favorites, err = repo.ListAll(ctx) + require.NoError(t, err) + + require.Len(t, favorites, 0) +} + +func TestIsFavorite(t *testing.T) { + db, err := migrations.NewInMemoryMigratedDatabase(t, "favorites-repo-tests.db") + require.NoError(t, err) + + logger := zerolog.Nop() + + repo := NewFavorites(db, &logger) + ctx := context.Background() + + favorite := &Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: "some-nice-id", + } + + err = repo.Add(ctx, favorite) + require.NoError(t, err) + + isFavorite, err := repo.IsFavorite(ctx, favorite) + require.NoError(t, err) + require.True(t, isFavorite) + + isFavorite, err = repo.IsFavorite(ctx, &Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: "some-nice-id-2", + }) + require.NoError(t, err) + require.False(t, isFavorite) +} diff --git a/frontend/src/components/toast.tsx b/frontend/src/components/toast.tsx index fd97f2c..d6831d0 100644 --- a/frontend/src/components/toast.tsx +++ b/frontend/src/components/toast.tsx @@ -46,8 +46,8 @@ export function Toast({ }` const toastAnimationClass = `${ - isEntering.current ? "animate-[slideInLeft_0.5s_ease-in_both]" : "" - } ${isExiting ? "animate-[slideOutLeft_0.3s_ease-out_both]" : ""}` + isEntering.current ? "animate-[slideInRight_0.5s_ease-in_both]" : "" + } ${isExiting ? "animate-[slideOutRight_0.3s_ease-out_both]" : ""}` const exitAnimationTimer = useRef(null) const toastRemovalTimer = useRef(null) diff --git a/frontend/src/layout.tsx b/frontend/src/layout.tsx index 846419b..0b4ea28 100644 --- a/frontend/src/layout.tsx +++ b/frontend/src/layout.tsx @@ -1,8 +1,13 @@ import { useEffect, useState } from "react" import { useToaster } from "./toast-provider/toast-context" +import { useAuth } from "./auth-provider/auth-context" +import { LockVault } from "../wailsjs/go/main/AuthController" +import { Link, useNavigate } from "react-router-dom" export function Layout({ children }: { children: React.ReactNode }) { const toaster = useToaster() + const authContext = useAuth() + const navigate = useNavigate() const [toasts, setToasts] = useState([]) @@ -14,15 +19,53 @@ export function Layout({ children }: { children: React.ReactNode }) { toaster.iRenderToasts(updateToasts) }, [toaster]) + async function attemptLock() { + await LockVault() + + navigate("/") + authContext.onVaultLocked() + } + return ( <>
+ className="fixed right-5 top-5 w-[512px] z-50"> {...toasts}
-
- {children} +
+ +
{children}
+
+ +
+ + + +
+
) diff --git a/frontend/src/main.tsx b/frontend/src/main.tsx index 070ca6b..d02b16b 100644 --- a/frontend/src/main.tsx +++ b/frontend/src/main.tsx @@ -19,6 +19,8 @@ import { WailsProvider } from "./wails-provider/wails-provider" import { AwsIamIdcDeviceAuth } from "./routes/aws-iam-idc/aws-iam-idc-device-auth" import { awsIamIdcDeviceAuthAction } from "./routes/aws-iam-idc/aws-iam-idc-device-auth-data" import { ToastProvider } from "./toast-provider/toast-provider" +import { AwsIamIdcInstances } from "./routes/aws-iam-idc/aws-iam-idc-instances" +import { awsIamIdcInstancesData } from "./routes/aws-iam-idc/aws-iam-idc-instances-data" const devMode = import.meta.env.DEV @@ -47,6 +49,11 @@ void (async function main() { { path: "aws-iam-idc", children: [ + { + index: true, + element: , + loader: awsIamIdcInstancesData, + }, { path: "setup", element: , diff --git a/frontend/src/routes/aws-iam-idc/aws-iam-idc-card.tsx b/frontend/src/routes/aws-iam-idc/aws-iam-idc-card.tsx index fc44cf1..e267634 100644 --- a/frontend/src/routes/aws-iam-idc/aws-iam-idc-card.tsx +++ b/frontend/src/routes/aws-iam-idc/aws-iam-idc-card.tsx @@ -1,6 +1,10 @@ import React from "react" -import { useFetcher, useNavigate } from "react-router-dom" -import { RefreshAccessToken } from "../../../wailsjs/go/awsiamidc/AwsIdentityCenterController" +import { useFetcher, useNavigate, useRevalidator } from "react-router-dom" +import { + MarkAsFavorite, + RefreshAccessToken, + UnmarkAsFavorite, +} from "../../../wailsjs/go/awsiamidc/AwsIdentityCenterController" import { AwsIamIdcCardDataError, @@ -10,6 +14,7 @@ import { export function AwsIamIdcCard({ instanceId }: { instanceId: string }) { const navigate = useNavigate() const fetcher = useFetcher() + const validator = useRevalidator() async function authorizeDevice(instanceId: string) { const deviceAuthFlowResult = await RefreshAccessToken(instanceId) @@ -37,20 +42,51 @@ export function AwsIamIdcCard({ instanceId }: { instanceId: string }) { } }, [instanceId, fetcher]) + async function markAsFavorite() { + await MarkAsFavorite(instanceId) + validator.revalidate() + } + + async function unmarkAsFavorite() { + await UnmarkAsFavorite(instanceId) + validator.revalidate() + } + const cardDataResult = fetcher.data as AwsIamIdcCardDataResult | undefined if (cardDataResult === undefined) { - return
Loading...
+ return ( +
+
+
+
+
+ ) } if (cardDataResult.success) { const cardData = cardDataResult.result return ( -
+
+ className="card-title"> +

{cardData.label}

diff --git a/frontend/src/routes/aws-iam-idc/aws-iam-idc-device-auth.tsx b/frontend/src/routes/aws-iam-idc/aws-iam-idc-device-auth.tsx index 66007f1..4b64595 100644 --- a/frontend/src/routes/aws-iam-idc/aws-iam-idc-device-auth.tsx +++ b/frontend/src/routes/aws-iam-idc/aws-iam-idc-device-auth.tsx @@ -79,7 +79,7 @@ export function AwsIamIdcDeviceAuth() { return (
+ className="h-screen flex flex-col items-center justify-center gap-4 border-2 p-6">

Please authorize the request by visiting{" "} +

AWS IAM Identity Center Instances

+ +
+
+ +
    + {loader.map((instance) => ( +
  • + +
  • + ))} + +
  • + + Add Instance + +
  • +
+
+ ) +} diff --git a/frontend/src/routes/aws-iam-idc/aws-iam-idc-setup.tsx b/frontend/src/routes/aws-iam-idc/aws-iam-idc-setup.tsx index 4451c0a..4c77015 100644 --- a/frontend/src/routes/aws-iam-idc/aws-iam-idc-setup.tsx +++ b/frontend/src/routes/aws-iam-idc/aws-iam-idc-setup.tsx @@ -57,50 +57,52 @@ export function AwsIamIdcSetup() { }, [navigate, setupResult, wails, toaster]) return ( - -

AWS IAM Identity Center

- - - - - - - - - +
+
+

AWS IAM Identity Center

+ + + + + + + + +
+
) } diff --git a/frontend/src/routes/dashboard/dashboard.tsx b/frontend/src/routes/dashboard/dashboard.tsx index fd17b07..87e0826 100644 --- a/frontend/src/routes/dashboard/dashboard.tsx +++ b/frontend/src/routes/dashboard/dashboard.tsx @@ -1,4 +1,4 @@ -import { Link, useLoaderData } from "react-router-dom" +import { useLoaderData } from "react-router-dom" import { AwsIamIdcCard } from "../aws-iam-idc/aws-iam-idc-card" import { main } from "../../../wailsjs/go/models" @@ -26,11 +26,6 @@ export function Dashboard() { /> ) })} - - New - ) } diff --git a/frontend/src/routes/providers/providers.tsx b/frontend/src/routes/providers/providers.tsx index 7362f49..e677e49 100644 --- a/frontend/src/routes/providers/providers.tsx +++ b/frontend/src/routes/providers/providers.tsx @@ -8,25 +8,20 @@ export function Providers() { return ( <> {outlet == null ? ( - <> +

Supported Providers

-
    +
      {providers.map((provider) => (
    • + to={`${provider.code}`}> {provider.name}
    • ))}
    - - ← dashboard - - +
) : null} diff --git a/frontend/src/routes/vault/vault.tsx b/frontend/src/routes/vault/vault.tsx index 8ce3122..b434d8c 100644 --- a/frontend/src/routes/vault/vault.tsx +++ b/frontend/src/routes/vault/vault.tsx @@ -1,9 +1,5 @@ -import { Outlet, useNavigate } from "react-router-dom" -import { - ConfigureVault, - UnlockVault, - LockVault, -} from "../../../wailsjs/go/main/AuthController" +import { Outlet } from "react-router-dom" +import { ConfigureVault, UnlockVault } from "../../../wailsjs/go/main/AuthController" import { Layout } from "../../layout" import { useAuth } from "../../auth-provider/auth-context" import { VaultBuilder } from "./vault-builder" @@ -11,7 +7,6 @@ import { VaultDoor } from "./vault-door" import { useState } from "react" export function Vault(props: { isVaultConfigured: boolean }) { - const navigate = useNavigate() const authContext = useAuth() const [isVaultConfigured, setIsVaultConfigured] = useState(props.isVaultConfigured) @@ -31,39 +26,25 @@ export function Vault(props: { isVaultConfigured: boolean }) { } } - async function attemptLock() { - await LockVault() - - navigate("/") - authContext.onVaultLocked() - } - if (authContext.isAuthenticated === false) { if (isVaultConfigured) { return ( - +
- +
) } return ( - +
- +
) } return ( - <> - - - - - + + + ) } diff --git a/frontend/tailwind.config.cjs b/frontend/tailwind.config.cjs index d249fd1..47cc7e2 100644 --- a/frontend/tailwind.config.cjs +++ b/frontend/tailwind.config.cjs @@ -76,6 +76,22 @@ module.exports = { transform: "translate3d(-110%, 0, 0)", }, }, + slideInRight: { + from: { + transform: "translate3d(100%, 0, 0)", + }, + to: { + transform: "translate3d(0, 0, 0)", + }, + }, + slideOutRight: { + from: { + transform: "translate3d(0, 0, 0)", + }, + to: { + transform: "translate3d(110%, 0, 0)", + }, + }, }, animation: { wiggle: "wiggle 1s ease-in-out infinite", diff --git a/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.d.ts b/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.d.ts index 692b45b..e59063e 100755 --- a/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.d.ts +++ b/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.d.ts @@ -12,6 +12,12 @@ export function GetInstanceData(arg1:string):Promise; +export function ListInstances():Promise>; + +export function MarkAsFavorite(arg1:string):Promise; + export function RefreshAccessToken(arg1:string):Promise; export function Setup(arg1:string,arg2:string,arg3:string):Promise; + +export function UnmarkAsFavorite(arg1:string):Promise; diff --git a/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.js b/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.js index 4943e78..a34b920 100755 --- a/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.js +++ b/frontend/wailsjs/go/awsiamidc/AwsIdentityCenterController.js @@ -18,6 +18,14 @@ export function Init(arg1, arg2) { return window['go']['awsiamidc']['AwsIdentityCenterController']['Init'](arg1, arg2); } +export function ListInstances() { + return window['go']['awsiamidc']['AwsIdentityCenterController']['ListInstances'](); +} + +export function MarkAsFavorite(arg1) { + return window['go']['awsiamidc']['AwsIdentityCenterController']['MarkAsFavorite'](arg1); +} + export function RefreshAccessToken(arg1) { return window['go']['awsiamidc']['AwsIdentityCenterController']['RefreshAccessToken'](arg1); } @@ -25,3 +33,7 @@ export function RefreshAccessToken(arg1) { export function Setup(arg1, arg2, arg3) { return window['go']['awsiamidc']['AwsIdentityCenterController']['Setup'](arg1, arg2, arg3); } + +export function UnmarkAsFavorite(arg1) { + return window['go']['awsiamidc']['AwsIdentityCenterController']['UnmarkAsFavorite'](arg1); +} diff --git a/frontend/wailsjs/go/models.ts b/frontend/wailsjs/go/models.ts index f5421b1..589039e 100755 --- a/frontend/wailsjs/go/models.ts +++ b/frontend/wailsjs/go/models.ts @@ -46,6 +46,7 @@ export namespace awsiamidc { instanceId: string; enabled: boolean; label: string; + isFavorite: boolean; isAccessTokenExpired: boolean; accessTokenExpiresIn: string; accounts: AwsIdentityCenterAccount[]; @@ -59,6 +60,7 @@ export namespace awsiamidc { this.instanceId = source["instanceId"]; this.enabled = source["enabled"]; this.label = source["label"]; + this.isFavorite = source["isFavorite"]; this.isAccessTokenExpired = source["isAccessTokenExpired"]; this.accessTokenExpiresIn = source["accessTokenExpiresIn"]; this.accounts = this.convertValues(source["accounts"], AwsIdentityCenterAccount); diff --git a/main.go b/main.go index 4902409..082e48c 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "os" "github.com/abjrcode/swervo/clients/awssso" + "github.com/abjrcode/swervo/favorites" "github.com/abjrcode/swervo/internal/config" "github.com/abjrcode/swervo/internal/datastore" "github.com/abjrcode/swervo/internal/logging" @@ -97,9 +98,11 @@ func main() { defer vault.Seal() authController := NewAuthController(vault) - dashboardController := NewDashboardController(sqlDb) - awsIdcController := awsiamidc.NewAwsIdentityCenterController(sqlDb, vault, awssso.NewAwsSsoOidcClient(), timeProvider) + favoritesRepo := favorites.NewFavorites(sqlDb, &logger) + dashboardController := NewDashboardController(favoritesRepo) + + awsIdcController := awsiamidc.NewAwsIdentityCenterController(sqlDb, favoritesRepo, vault, awssso.NewAwsSsoOidcClient(), timeProvider) logger.Info().Msgf("PID [%d] - launching Swervo", os.Getpid()) if err := wails.Run(&options.App{ diff --git a/providers/aws-iam-idc/aws-identity-center.go b/providers/aws-iam-idc/aws-identity-center.go index ac59ee9..88fb33d 100644 --- a/providers/aws-iam-idc/aws-identity-center.go +++ b/providers/aws-iam-idc/aws-identity-center.go @@ -9,9 +9,11 @@ import ( "time" "github.com/abjrcode/swervo/clients/awssso" + "github.com/abjrcode/swervo/favorites" "github.com/abjrcode/swervo/internal/logging" "github.com/abjrcode/swervo/internal/security/encryption" "github.com/abjrcode/swervo/internal/utils" + "github.com/abjrcode/swervo/providers" "github.com/dustin/go-humanize" "github.com/rs/zerolog" "github.com/segmentio/ksuid" @@ -33,14 +35,16 @@ type AwsIdentityCenterController struct { logger *zerolog.Logger errHandler logging.ErrorHandler db *sql.DB + favoritesRepo favorites.FavoritesRepo encryptionService encryption.EncryptionService awsSsoClient awssso.AwsSsoOidcClient timeHelper utils.Clock } -func NewAwsIdentityCenterController(db *sql.DB, encryptionService encryption.EncryptionService, awsSsoClient awssso.AwsSsoOidcClient, datetime utils.Clock) *AwsIdentityCenterController { +func NewAwsIdentityCenterController(db *sql.DB, favoritesRepo favorites.FavoritesRepo, encryptionService encryption.EncryptionService, awsSsoClient awssso.AwsSsoOidcClient, datetime utils.Clock) *AwsIdentityCenterController { return &AwsIdentityCenterController{ db: db, + favoritesRepo: favoritesRepo, encryptionService: encryptionService, awsSsoClient: awsSsoClient, timeHelper: datetime, @@ -63,11 +67,38 @@ type AwsIdentityCenterCardData struct { InstanceId string `json:"instanceId"` Enabled bool `json:"enabled"` Label string `json:"label"` + IsFavorite bool `json:"isFavorite"` IsAccessTokenExpired bool `json:"isAccessTokenExpired"` AccessTokenExpiresIn string `json:"accessTokenExpiresIn"` Accounts []AwsIdentityCenterAccount `json:"accounts"` } +func (c *AwsIdentityCenterController) ListInstances() ([]string, error) { + rows, err := c.db.QueryContext(c.ctx, "SELECT instance_id FROM aws_iam_idc_instances ORDER BY instance_id DESC") + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return make([]string, 0), nil + } + + c.errHandler.Catch(c.logger, err) + } + + instances := make([]string, 0) + + for rows.Next() { + var instanceId string + + if err := rows.Scan(&instanceId); err != nil { + c.errHandler.Catch(c.logger, err) + } + + instances = append(instances, instanceId) + } + + return instances, nil +} + func (c *AwsIdentityCenterController) GetInstanceData(instanceId string) (*AwsIdentityCenterCardData, error) { row := c.db.QueryRowContext(c.ctx, "SELECT region, label, access_token_enc, access_token_created_at, access_token_expires_in, enc_key_id FROM aws_iam_idc_instances WHERE instance_id = ?", instanceId) @@ -86,6 +117,13 @@ func (c *AwsIdentityCenterController) GetInstanceData(instanceId string) (*AwsId c.errHandler.Catch(c.logger, err) } + isFavorite, err := c.favoritesRepo.IsFavorite(c.ctx, &favorites.Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: instanceId, + }) + + c.errHandler.CatchWithMsg(c.logger, err, "failed to check if instance is favorite") + now := c.timeHelper.NowUnix() if now > accessTokenCreatedAt+accessTokenExpiresIn { c.logger.Info().Msgf("token for instance [%s] has expired", instanceId) @@ -94,6 +132,7 @@ func (c *AwsIdentityCenterController) GetInstanceData(instanceId string) (*AwsId Enabled: true, InstanceId: instanceId, Label: label, + IsFavorite: isFavorite, IsAccessTokenExpired: true, AccessTokenExpiresIn: humanize.Time(time.Unix(accessTokenCreatedAt+accessTokenExpiresIn, 0)), Accounts: make([]AwsIdentityCenterAccount, 0), @@ -126,6 +165,7 @@ func (c *AwsIdentityCenterController) GetInstanceData(instanceId string) (*AwsId Enabled: true, InstanceId: instanceId, Label: label, + IsFavorite: isFavorite, IsAccessTokenExpired: false, AccessTokenExpiresIn: humanize.Time(time.Unix(accessTokenCreatedAt+accessTokenExpiresIn, 0)), Accounts: accounts, @@ -277,12 +317,6 @@ func (c *AwsIdentityCenterController) FinalizeSetup(clientId, startUrl, region, return "", ErrTransientAwsClientError } - tx, err := c.db.BeginTx(ctx, nil) - - c.errHandler.CatchWithMsg(c.logger, err, "failed to start transaction") - - defer tx.Rollback() - idTokenEnc, keyId, err := c.encryptionService.Encrypt(tokenRes.IdToken) c.errHandler.CatchWithMsg(c.logger, err, "failed to encrypt id token") @@ -294,7 +328,9 @@ func (c *AwsIdentityCenterController) FinalizeSetup(clientId, startUrl, region, refreshTokenEnc, _, err := c.encryptionService.Encrypt(tokenRes.RefreshToken) c.errHandler.CatchWithMsg(c.logger, err, "failed to encrypt refresh token") - uniqueId, err := ksuid.NewRandom() + nowUnix := c.timeHelper.NowUnix() + + uniqueId, err := ksuid.NewRandomWithTime(time.Unix(nowUnix, 0)) c.errHandler.CatchWithMsg(c.logger, err, "failed to generate instance ID") instanceId := uniqueId.String() @@ -302,7 +338,7 @@ func (c *AwsIdentityCenterController) FinalizeSetup(clientId, startUrl, region, (instance_id, start_url, region, label, enabled, id_token_enc, access_token_enc, token_type, access_token_created_at, access_token_expires_in, refresh_token_enc, enc_key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` - _, err = tx.ExecContext(ctx, sql, + _, err = c.db.ExecContext(ctx, sql, instanceId, startUrl, region, @@ -311,22 +347,28 @@ func (c *AwsIdentityCenterController) FinalizeSetup(clientId, startUrl, region, idTokenEnc, accessTokenEnc, tokenRes.TokenType, - c.timeHelper.NowUnix(), + nowUnix, tokenRes.ExpiresIn, refreshTokenEnc, keyId) c.errHandler.CatchWithMsg(c.logger, err, "failed to save token to database") - _, err = tx.ExecContext(ctx, `INSERT INTO favorite_instances (provider_code, instance_id) VALUES (?, ?) `, "aws-iam-idc", instanceId) - - c.errHandler.CatchWithMsg(c.logger, err, "failed to add provider to list of configured providers") - - err = tx.Commit() + return instanceId, nil +} - c.errHandler.CatchWithMsg(c.logger, err, "failed to commit transaction") +func (c *AwsIdentityCenterController) MarkAsFavorite(instanceId string) error { + return c.favoritesRepo.Add(c.ctx, &favorites.Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: instanceId, + }) +} - return instanceId, nil +func (c *AwsIdentityCenterController) UnmarkAsFavorite(instanceId string) error { + return c.favoritesRepo.Remove(c.ctx, &favorites.Favorite{ + ProviderCode: providers.AwsIamIdc, + InstanceId: instanceId, + }) } func (c *AwsIdentityCenterController) RefreshAccessToken(instanceId string) (*AuthorizeDeviceFlowResult, error) { diff --git a/providers/aws-iam-idc/aws-identity-center_test.go b/providers/aws-iam-idc/aws-identity-center_test.go index 7e2cc57..138c9e6 100644 --- a/providers/aws-iam-idc/aws-identity-center_test.go +++ b/providers/aws-iam-idc/aws-identity-center_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/abjrcode/swervo/clients/awssso" + "github.com/abjrcode/swervo/favorites" "github.com/abjrcode/swervo/internal/migrations" "github.com/abjrcode/swervo/internal/security/vault" "github.com/abjrcode/swervo/internal/testhelpers" @@ -43,19 +44,19 @@ func (m *mockAwsSsoOidcClient) ListAccounts(ctx context.Context, accessToken str func initController(t *testing.T) (*AwsIdentityCenterController, *mockAwsSsoOidcClient, *testhelpers.MockClock) { db, err := migrations.NewInMemoryMigratedDatabase(t, "aws-iam-idc-controller-tests.db") - require.NoError(t, err) awsClient := new(mockAwsSsoOidcClient) mockDatetime := testhelpers.NewMockClock() logger := zerolog.Nop() + favoritesRepo := favorites.NewFavorites(db, &logger) errHandler := testhelpers.NewMockErrorHandler(t) ctx := logger.WithContext(context.Background()) vault := vault.NewVault(db, mockDatetime, &logger, errHandler) timeSetCall := mockDatetime.On("NowUnix").Return(1) err = vault.Configure(context.Background(), "abc") require.NoError(t, err) - controller := NewAwsIdentityCenterController(db, vault, awsClient, mockDatetime) + controller := NewAwsIdentityCenterController(db, favoritesRepo, vault, awsClient, mockDatetime) controller.Init(ctx, errHandler) timeSetCall.Unset() @@ -63,9 +64,7 @@ func initController(t *testing.T) (*AwsIdentityCenterController, *mockAwsSsoOidc return controller, awsClient, mockDatetime } -func simulateSuccessfulSetup(t *testing.T, startUrl, region string) (string, *AwsIdentityCenterController, *mockAwsSsoOidcClient, *testhelpers.MockClock) { - controller, mockAws, mockTimeProvider := initController(t) - +func simulateSuccessfulSetup(t *testing.T, controller *AwsIdentityCenterController, mockAws *mockAwsSsoOidcClient, mockTimeProvider *testhelpers.MockClock, startUrl, region, label string) string { mockRegRes := awssso.RegistrationResponse{ ClientId: "test-client-id", ClientSecret: "test-client-secret", @@ -82,7 +81,8 @@ func simulateSuccessfulSetup(t *testing.T, startUrl, region string) (string, *Aw } deviceAuthCall := mockAws.On("StartDeviceAuthorization", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockAuthRes, nil) - setupResult, err := controller.Setup(startUrl, region, "test-label") + mockTimeProvider.On("NowUnix").Once().Return(1) + setupResult, err := controller.Setup(startUrl, region, label) require.NoError(t, err) mockTokenRes := awssso.GetTokenResponse{ @@ -105,7 +105,7 @@ func simulateSuccessfulSetup(t *testing.T, startUrl, region string) (string, *Aw createTokenCall.Unset() timeSetCall.Unset() - return instanceId, controller, mockAws, mockTimeProvider + return instanceId } func TestNewAccountSetupErrorInvalidStartUrl(t *testing.T) { @@ -127,7 +127,7 @@ func TestNewAccountSetupErrorInvalidRegion(t *testing.T) { func TestNewAccountSetup_Error_InvalidLabel(t *testing.T) { controller, _, _ := initController(t) - _, err := controller.Setup("https://test-start-url.aws-apps.com/start", "region_mars", "") + _, err := controller.Setup("https://test-start-url.aws-apps.com/start", "region_mars", "i_am_a_very_long_label_that_is_longer_than_50_characters_and_therefore_invalid") require.Error(t, err, ErrInvalidLabel) } @@ -197,7 +197,7 @@ func TestNewAccount_FullSetup_Success(t *testing.T) { require.Equal(t, setupResult, &AuthorizeDeviceFlowResult{ StartUrl: "https://test-start-url.aws-apps.com/start", Region: "eu-west-1", - Label: "test_label", + Label: label, ClientId: "test-client-id", UserCode: "test-user-code", DeviceCode: "test-device-code", @@ -211,7 +211,9 @@ func TestNewAccountSetupErrorDoubleRegistration(t *testing.T) { region := "eu-west-1" label := "test_label" - _, controller, _, _ := simulateSuccessfulSetup(t, startUrl, region) + controller, mockAws, mockTimeProvider := initController(t) + + _ = simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) _, err := controller.Setup(startUrl, region, label) require.Error(t, err, ErrInstanceAlreadyRegistered) @@ -354,11 +356,61 @@ func TestFinalizeSetup_Error_DeviceAuthTimeout(t *testing.T) { require.Error(t, err, ErrDeviceAuthFlowTimedOut) } +func TestListInstances(t *testing.T) { + startUrl := "https://test-start-url.aws-apps.com/start" + region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) + + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) + + startUrl2 := "https://test-start-url-2.aws-apps.com/start" + region2 := "eu-west-2" + label2 := "test_label_2" + + mockAuthRes := awssso.AuthorizationResponse{ + DeviceCode: "test-device-code", + UserCode: "test-user-code", + VerificationUriComplete: "https://test-verification-url", + ExpiresIn: 10, + } + mockAws.On("StartDeviceAuthorization", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().Return(&mockAuthRes, nil) + + mockTimeProvider.On("NowUnix").Once().Return(2) + setupResult, err := controller.Setup(startUrl2, region2, label2) + require.NoError(t, err) + + mockTokenRes := awssso.GetTokenResponse{ + IdToken: "test-id-token", + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "test-token-type", + ExpiresIn: 5, + } + mockAws.On("CreateToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().Return(&mockTokenRes, nil) + + tokenCreatedAt := 3 + mockTimeProvider.On("NowUnix").Once().Return(tokenCreatedAt) + + instanceId2, err := controller.FinalizeSetup(setupResult.ClientId, setupResult.StartUrl, setupResult.Region, setupResult.Label, setupResult.UserCode, setupResult.DeviceCode) + require.NoError(t, err) + + instances, err := controller.ListInstances() + + require.NoError(t, err) + + require.Equal(t, instances, []string{instanceId2, instanceId}) +} + func TestGetInstanceData(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" - instanceId, controller, mockAws, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + controller, mockAws, mockTimeProvider := initController(t) + + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockTimeProvider.On("NowUnix").Return(3) @@ -382,7 +434,8 @@ func TestGetInstanceData(t *testing.T) { require.NoError(t, err) require.Equal(t, instanceId, instanceData.InstanceId) - require.Equal(t, "test-label", instanceData.Label) + require.Equal(t, label, instanceData.Label) + require.Equal(t, false, instanceData.IsFavorite) require.Equal(t, false, instanceData.IsAccessTokenExpired) require.Equal(t, "test-account-id", instanceData.Accounts[0].AccountId) require.Equal(t, "test-account-name", instanceData.Accounts[0].AccountName) @@ -394,8 +447,11 @@ func TestGetInstanceData(t *testing.T) { func TestGetInstance_AccessTokenExpired(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) - instanceId, controller, _, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockTimeProvider.On("NowUnix").Return(10) @@ -406,7 +462,7 @@ func TestGetInstance_AccessTokenExpired(t *testing.T) { require.NoError(t, err) require.Equal(t, instanceId, data.InstanceId) - require.Equal(t, "test-label", data.Label) + require.Equal(t, label, data.Label) require.Equal(t, true, data.IsAccessTokenExpired) require.Empty(t, data.Accounts) } @@ -426,11 +482,93 @@ func TestGetNonExistentInstance(t *testing.T) { require.Error(t, err, ErrInstanceWasNotFound) } +func TestMarkInstanceAsFavorite(t *testing.T) { + controller, mockAws, mockTimeProvider := initController(t) + + startUrl := "https://test-start-url.aws-apps.com/start" + region := "eu-west-1" + label := "test_label" + + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) + + err := controller.MarkAsFavorite(instanceId) + require.NoError(t, err) + + mockTimeProvider.On("NowUnix").Once().Return(4) + mockListAccountsRes := awssso.ListAccountsResponse{ + Accounts: []awssso.AwsAccount{ + { + AccountId: "test-account-id", + AccountName: "test-account-name", + AccountEmail: "test-account-email", + }, + { + AccountId: "test-account-id-2", + AccountName: "test-account-name-2", + AccountEmail: "test-account-email-2", + }, + }, + } + mockAws.On("ListAccounts", mock.Anything, mock.AnythingOfType("string")).Return(&mockListAccountsRes, nil) + + instanceData, err := controller.GetInstanceData(instanceId) + require.NoError(t, err) + + require.Equal(t, true, instanceData.IsFavorite) +} + +func TestUnmarkInstanceAsFavorite(t *testing.T) { + controller, mockAws, mockTimeProvider := initController(t) + + startUrl := "https://test-start-url.aws-apps.com/start" + region := "eu-west-1" + label := "test_label" + + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) + + err := controller.MarkAsFavorite(instanceId) + require.NoError(t, err) + + mockTimeProvider.On("NowUnix").Once().Return(4) + mockListAccountsRes := awssso.ListAccountsResponse{ + Accounts: []awssso.AwsAccount{ + { + AccountId: "test-account-id", + AccountName: "test-account-name", + AccountEmail: "test-account-email", + }, + { + AccountId: "test-account-id-2", + AccountName: "test-account-name-2", + AccountEmail: "test-account-email-2", + }, + }, + } + mockAws.On("ListAccounts", mock.Anything, mock.AnythingOfType("string")).Return(&mockListAccountsRes, nil) + + instanceData, err := controller.GetInstanceData(instanceId) + require.NoError(t, err) + require.Equal(t, true, instanceData.IsFavorite) + + err = controller.UnmarkAsFavorite(instanceId) + require.NoError(t, err) + + mockTimeProvider.On("NowUnix").Once().Return(5) + + instanceData, err = controller.GetInstanceData(instanceId) + require.NoError(t, err) + + require.Equal(t, false, instanceData.IsFavorite) +} + func TestRefreshAccessToken(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) - instanceId, controller, mockAws, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockAuthRes := awssso.AuthorizationResponse{ DeviceCode: "test-device-code-2", @@ -449,7 +587,7 @@ func TestRefreshAccessToken(t *testing.T) { ClientId: "test-client-id", StartUrl: "https://test-start-url.aws-apps.com/start", Region: "eu-west-1", - Label: "test-label", + Label: label, VerificationUri: "https://test-verification-url-2", UserCode: "test-user-code-2", DeviceCode: "test-device-code-2", @@ -460,8 +598,11 @@ func TestRefreshAccessToken(t *testing.T) { func TestFinalizeRefreshAccessToken(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) - instanceId, controller, mockAws, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockAuthRes := awssso.AuthorizationResponse{ DeviceCode: "test-device-code-2", @@ -493,8 +634,11 @@ func TestFinalizeRefreshAccessToken(t *testing.T) { func TestRefresh_NonExistentInstance(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) - _, controller, _, _ := simulateSuccessfulSetup(t, startUrl, region) + _ = simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) _, err := controller.RefreshAccessToken("well-if-u-can-find-me-it-sucks") require.Error(t, err, ErrInstanceWasNotFound) @@ -503,8 +647,11 @@ func TestRefresh_NonExistentInstance(t *testing.T) { func TestFinalizeRefreshAccessToken_InstanceDoesNotExist(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" - instanceId, controller, mockAws, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + controller, mockAws, mockTimeProvider := initController(t) + + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockAuthRes := awssso.AuthorizationResponse{ DeviceCode: "test-device-code-2", @@ -528,8 +675,11 @@ func TestFinalizeRefreshAccessToken_InstanceDoesNotExist(t *testing.T) { func TestFinalizeRefreshAccessToken_DeviceNotAuthorizedByUser(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) - instanceId, controller, mockAws, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockAuthRes := awssso.AuthorizationResponse{ DeviceCode: "test-device-code-2", @@ -552,8 +702,11 @@ func TestFinalizeRefreshAccessToken_DeviceNotAuthorizedByUser(t *testing.T) { func TestFinalizeRefreshAccessToken_DeviceAuthTimeout(t *testing.T) { startUrl := "https://test-start-url.aws-apps.com/start" region := "eu-west-1" + label := "test_label" + + controller, mockAws, mockTimeProvider := initController(t) - instanceId, controller, mockAws, mockTimeProvider := simulateSuccessfulSetup(t, startUrl, region) + instanceId := simulateSuccessfulSetup(t, controller, mockAws, mockTimeProvider, startUrl, region, label) mockAuthRes := awssso.AuthorizationResponse{ DeviceCode: "test-device-code-2", diff --git a/providers/metadata.go b/providers/metadata.go new file mode 100644 index 0000000..fe16cbf --- /dev/null +++ b/providers/metadata.go @@ -0,0 +1,20 @@ +package providers + +type ProviderMeta struct { + Code string + Name string + IconSvgBase64 string +} + +var ( + AwsIamIdc = "aws-iam-idc" +) + +var ( + SupportedProviders = map[string]ProviderMeta{ + AwsIamIdc: { + Code: AwsIamIdc, + Name: "AWS IAM IDC", + }, + } +)