Skip to content

Commit

Permalink
feat: add onConnect handler to <ConnectWallet /> (#1529)
Browse files Browse the repository at this point in the history
Co-authored-by: dschlabach <[email protected]>
  • Loading branch information
dschlabach and dschlabach authored Nov 1, 2024
1 parent 2c4e0a1 commit d4724e9
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .changeset/good-beans-invent.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@coinbase/onchainkit': minor
---

feat: add `onConnect` handler to `<ConnectWallet />`. By @dschlabach #1529
105 changes: 96 additions & 9 deletions src/wallet/components/ConnectWallet.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ describe('ConnectWallet', () => {
expect(connectedText).toBeInTheDocument();
});

it('should calls connect function when connect button is clicked', () => {
it('should call connect function when connect button is clicked', () => {
const connectMock = vi.fn();
vi.mocked(useConnect).mockReturnValue({
connectors: [{ id: 'mockConnector' }],
Expand All @@ -98,9 +98,14 @@ describe('ConnectWallet', () => {
render(<ConnectWallet text="Connect Wallet" />);
const button = screen.getByTestId('ockConnectButton');
fireEvent.click(button);
expect(connectMock).toHaveBeenCalledWith({
connector: { id: 'mockConnector' },
});
expect(connectMock).toHaveBeenCalledWith(
{
connector: { id: 'mockConnector' },
},
{
onSuccess: expect.any(Function),
},
);
});

it('should toggle wallet modal on button click when connected', () => {
Expand Down Expand Up @@ -162,6 +167,56 @@ describe('ConnectWallet', () => {
expect(screen.queryByText('Not Render')).not.toBeInTheDocument();
});

it('should call onConnect callback when connect button is clicked', async () => {
const mockUseAccount = vi.mocked(useAccount);
const connectMock = vi.fn();
const onConnectMock = vi.fn();

// Initial state: disconnected
mockUseAccount.mockReturnValue({
address: undefined,
status: 'disconnected',
});

vi.mocked(useConnect).mockReturnValue({
connectors: [{ id: 'mockConnector' }],
connect: connectMock,
status: 'idle',
});

render(<ConnectWallet text="Connect Wallet" onConnect={onConnectMock} />);

const button = screen.getByTestId('ockConnectButton');
fireEvent.click(button);

// Simulate successful connection
connectMock.mock.calls[0][1].onSuccess();

// Update account status to connected
mockUseAccount.mockReturnValue({
address: '0x123',
status: 'connected',
});

// Force a re-render to trigger the useEffect
render(<ConnectWallet text="Connect Wallet" onConnect={onConnectMock} />);

expect(onConnectMock).toHaveBeenCalledTimes(1);
});

it('should not call onConnect callback when component is first mounted', () => {
const mockUseAccount = vi.mocked(useAccount);
mockUseAccount.mockReturnValue({
address: '0x123',
status: 'connected',
});

const onConnectMock = vi.fn();
render(<ConnectWallet text="Connect Wallet" onConnect={onConnectMock} />);

expect(onConnectMock).toHaveBeenCalledTimes(0);
});

describe('withWalletAggregator', () => {
beforeEach(() => {
vi.mocked(useAccount).mockReturnValue({
Expand All @@ -175,7 +230,7 @@ describe('ConnectWallet', () => {
});
});

it('should render ConnectButtonRainboKit when withWalletAggregator is true', () => {
it('should render ConnectButtonRainbowKit when withWalletAggregator is true', () => {
render(
<ConnectWallet text="Connect Wallet" withWalletAggregator={true} />,
);
Expand All @@ -198,12 +253,17 @@ describe('ConnectWallet', () => {
);
const connectButton = screen.getByTestId('ockConnectButton');
fireEvent.click(connectButton);
expect(connectMock).toHaveBeenCalledWith({
connector: { id: 'mockConnector' },
});
expect(connectMock).toHaveBeenCalledWith(
{
connector: { id: 'mockConnector' },
},
{
onSuccess: expect.any(Function),
},
);
});

it('should calls openConnectModal function when connect button is clicked', () => {
it('should call openConnectModal function when connect button is clicked', () => {
vi.mocked(useWalletContext).mockReturnValue({
isOpen: false,
setIsOpen: vi.fn(),
Expand All @@ -215,5 +275,32 @@ describe('ConnectWallet', () => {
fireEvent.click(button);
expect(openConnectModalMock).toHaveBeenCalled();
});

it('should call onConnect callback when connect button is clicked', () => {
const mockUseAccount = vi.mocked(useAccount);
mockUseAccount.mockReturnValue({
address: undefined,
status: 'disconnected',
});

const onConnectMock = vi.fn();
render(
<ConnectWallet
text="Connect Wallet"
onConnect={onConnectMock}
withWalletAggregator={true}
/>,
);
const button = screen.getByTestId('ockConnectButton');

mockUseAccount.mockReturnValue({
address: '0x123',
status: 'connected',
});

fireEvent.click(button);

expect(onConnectMock).toHaveBeenCalledTimes(1);
});
});
});
29 changes: 27 additions & 2 deletions src/wallet/components/ConnectWallet.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { ConnectButton as ConnectButtonRainbowKit } from '@rainbow-me/rainbowkit';
import { Children, isValidElement, useCallback, useMemo } from 'react';
import type { ReactNode } from 'react';
import { useEffect, useState } from 'react';
import { useAccount, useConnect } from 'wagmi';
import { IdentityProvider } from '../../identity/components/IdentityProvider';
import { Spinner } from '../../internal/components/Spinner';
Expand All @@ -24,12 +25,16 @@ export function ConnectWallet({
// but for now we will keep it for backward compatibility.
text = 'Connect Wallet',
withWalletAggregator = false,
onConnect,
}: ConnectWalletReact) {
// Core Hooks
const { isOpen, setIsOpen } = useWalletContext();
const { address: accountAddress, status } = useAccount();
const { connectors, connect, status: connectStatus } = useConnect();

// State
const [hasClickedConnect, setHasClickedConnect] = useState(false);

// Get connectWalletText from children when present,
// this is used to customize the connect wallet button text
const { connectWalletText } = useMemo(() => {
Expand Down Expand Up @@ -58,6 +63,14 @@ export function ConnectWallet({
setIsOpen(!isOpen);
}, [isOpen, setIsOpen]);

// Effects
useEffect(() => {
if (hasClickedConnect && status === 'connected' && onConnect) {
onConnect();
setHasClickedConnect(false);
}
}, [status, hasClickedConnect, onConnect]);

if (status === 'disconnected') {
if (withWalletAggregator) {
return (
Expand All @@ -67,7 +80,10 @@ export function ConnectWallet({
<ConnectButton
className={className}
connectWalletText={connectWalletText}
onClick={() => openConnectModal()}
onClick={() => {
openConnectModal();
setHasClickedConnect(true);
}}
text={text}
/>
</div>
Expand All @@ -80,7 +96,16 @@ export function ConnectWallet({
<ConnectButton
className={className}
connectWalletText={connectWalletText}
onClick={() => connect({ connector })}
onClick={() => {
connect(
{ connector },
{
onSuccess: () => {
onConnect?.();
},
},
);
}}
text={text}
/>
</div>
Expand Down
1 change: 1 addition & 0 deletions src/wallet/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export type ConnectWalletReact = {
/** @deprecated Prefer `ConnectWalletText component` */
text?: string; // Optional text override for button
withWalletAggregator?: boolean; // Optional flag to enable the wallet aggregator like RainbowKit
onConnect?: () => void; // Optional callback function to execute when the wallet is connected.
};

/**
Expand Down

0 comments on commit d4724e9

Please sign in to comment.