Skip to content

Commit

Permalink
feat: allow extending toEqual (fix #2875) (#4880)
Browse files Browse the repository at this point in the history
Co-authored-by: Vladimir <[email protected]>
  • Loading branch information
tigranmk and sheremet-va authored Jan 12, 2024
1 parent 7f59a1b commit 463bee3
Show file tree
Hide file tree
Showing 11 changed files with 380 additions and 43 deletions.
52 changes: 52 additions & 0 deletions docs/api/expect.md
Original file line number Diff line number Diff line change
Expand Up @@ -1405,3 +1405,55 @@ Don't forget to include the ambient declaration file in your `tsconfig.json`.
:::tip
If you want to know more, checkout [guide on extending matchers](/guide/extending-matchers).
:::

## expect.addEqualityTesters <Badge type="info">1.2.0+</Badge>

- **Type:** `(tester: Array<Tester>) => void`

You can use this method to define custom testers, which are methods used by matchers, to test if two objects are equal. It is compatible with Jest's `expect.addEqualityTesters`.

```ts
import { expect, test } from 'vitest'

class AnagramComparator {
public word: string

constructor(word: string) {
this.word = word
}

equals(other: AnagramComparator): boolean {
const cleanStr1 = this.word.replace(/ /g, '').toLowerCase()
const cleanStr2 = other.word.replace(/ /g, '').toLowerCase()

const sortedStr1 = cleanStr1.split('').sort().join('')
const sortedStr2 = cleanStr2.split('').sort().join('')

return sortedStr1 === sortedStr2
}
}

function isAnagramComparator(a: unknown): a is AnagramComparator {
return a instanceof AnagramComparator
}

function areAnagramsEqual(a: unknown, b: unknown): boolean | undefined {
const isAAnagramComparator = isAnagramComparator(a)
const isBAnagramComparator = isAnagramComparator(b)

if (isAAnagramComparator && isBAnagramComparator)
return a.equals(b)

else if (isAAnagramComparator === isBAnagramComparator)
return undefined

else
return false
}

expect.addEqualityTesters([areAnagramsEqual])

test('custom equality tester', () => {
expect(new AnagramComparator('listen')).toEqual(new AnagramComparator('silent'))
})
```
1 change: 1 addition & 0 deletions packages/expect/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ export * from './constants'
export * from './types'
export { getState, setState } from './state'
export { JestChaiExpect } from './jest-expect'
export { addCustomEqualityTesters } from './jest-matcher-utils'
export { JestExtend } from './jest-extend'
export { setupColors } from '@vitest/utils'
10 changes: 6 additions & 4 deletions packages/expect/src/jest-asymmetric-matchers.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { ChaiPlugin, MatcherState } from './types'
import { GLOBAL_EXPECT } from './constants'
import { getState } from './state'
import { diff, getMatcherUtils, stringify } from './jest-matcher-utils'
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'

import { equals, isA, iterableEquality, pluralize, subsetEquality } from './jest-utils'

Expand All @@ -26,7 +26,7 @@ export abstract class AsymmetricMatcher<
...getState(expect || (globalThis as any)[GLOBAL_EXPECT]),
equals,
isNot: this.inverse,
customTesters: [],
customTesters: getCustomEqualityTesters(),
utils: {
...getMatcherUtils(),
diff,
Expand Down Expand Up @@ -116,8 +116,9 @@ export class ObjectContaining extends AsymmetricMatcher<Record<string, unknown>>

let result = true

const matcherContext = this.getMatcherContext()
for (const property in this.sample) {
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property])) {
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property], matcherContext.customTesters)) {
result = false
break
}
Expand Down Expand Up @@ -149,11 +150,12 @@ export class ArrayContaining<T = unknown> extends AsymmetricMatcher<Array<T>> {
)
}

const matcherContext = this.getMatcherContext()
const result
= this.sample.length === 0
|| (Array.isArray(other)
&& this.sample.every(item =>
other.some(another => equals(item, another)),
other.some(another => equals(item, another, matcherContext.customTesters)),
))

return this.inverse ? !result : result
Expand Down
21 changes: 12 additions & 9 deletions packages/expect/src/jest-expect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type { Test } from '@vitest/runner'
import type { Assertion, ChaiPlugin } from './types'
import { arrayBufferEquality, generateToBeMessage, iterableEquality, equals as jestEquals, sparseArrayEquality, subsetEquality, typeEquality } from './jest-utils'
import type { AsymmetricMatcher } from './jest-asymmetric-matchers'
import { diff, stringify } from './jest-matcher-utils'
import { diff, getCustomEqualityTesters, stringify } from './jest-matcher-utils'
import { JEST_MATCHERS_OBJECT } from './constants'
import { recordAsyncExpect, wrapSoft } from './utils'

Expand All @@ -23,6 +23,7 @@ declare class DOMTokenList {
export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const { AssertionError } = chai
const c = () => getColors()
const customTesters = getCustomEqualityTesters()

function def(name: keyof Assertion | (keyof Assertion)[], fn: ((this: Chai.AssertionStatic & Assertion, ...args: any[]) => any)) {
const addMethod = (n: keyof Assertion) => {
Expand Down Expand Up @@ -80,7 +81,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const equal = jestEquals(
actual,
expected,
[iterableEquality],
[...customTesters, iterableEquality],
)

return this.assert(
Expand All @@ -98,6 +99,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
obj,
expected,
[
...customTesters,
iterableEquality,
typeEquality,
sparseArrayEquality,
Expand Down Expand Up @@ -125,6 +127,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
actual,
expected,
[
...customTesters,
iterableEquality,
typeEquality,
sparseArrayEquality,
Expand All @@ -140,7 +143,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const toEqualPass = jestEquals(
actual,
expected,
[iterableEquality],
[...customTesters, iterableEquality],
)

if (toEqualPass)
Expand All @@ -159,7 +162,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
def('toMatchObject', function (expected) {
const actual = this._obj
return this.assert(
jestEquals(actual, expected, [iterableEquality, subsetEquality]),
jestEquals(actual, expected, [...customTesters, iterableEquality, subsetEquality]),
'expected #{this} to match object #{exp}',
'expected #{this} to not match object #{exp}',
expected,
Expand Down Expand Up @@ -208,7 +211,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
def('toContainEqual', function (expected) {
const obj = utils.flag(this, 'object')
const index = Array.from(obj).findIndex((item) => {
return jestEquals(item, expected)
return jestEquals(item, expected, customTesters)
})

this.assert(
Expand Down Expand Up @@ -339,7 +342,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
return utils.getPathInfo(actual, propertyName)
}
const { value, exists } = getValue()
const pass = exists && (args.length === 1 || jestEquals(expected, value))
const pass = exists && (args.length === 1 || jestEquals(expected, value, customTesters))

const valueString = args.length === 1 ? '' : ` with value ${utils.objDisplay(expected)}`

Expand Down Expand Up @@ -482,7 +485,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
def(['toHaveBeenCalledWith', 'toBeCalledWith'], function (...args) {
const spy = getSpy(this)
const spyName = spy.getMockName()
const pass = spy.mock.calls.some(callArg => jestEquals(callArg, args, [iterableEquality]))
const pass = spy.mock.calls.some(callArg => jestEquals(callArg, args, [...customTesters, iterableEquality]))
const isNot = utils.flag(this, 'negate') as boolean

const msg = utils.getMessage(
Expand All @@ -504,7 +507,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const nthCall = spy.mock.calls[times - 1]

this.assert(
jestEquals(nthCall, args, [iterableEquality]),
jestEquals(nthCall, args, [...customTesters, iterableEquality]),
`expected ${ordinalOf(times)} "${spyName}" call to have been called with #{exp}`,
`expected ${ordinalOf(times)} "${spyName}" call to not have been called with #{exp}`,
args,
Expand All @@ -517,7 +520,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
const lastCall = spy.mock.calls[spy.mock.calls.length - 1]

this.assert(
jestEquals(lastCall, args, [iterableEquality]),
jestEquals(lastCall, args, [...customTesters, iterableEquality]),
`expected last "${spyName}" call to have been called with #{exp}`,
`expected last "${spyName}" call to not have been called with #{exp}`,
args,
Expand Down
5 changes: 2 additions & 3 deletions packages/expect/src/jest-extend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { ASYMMETRIC_MATCHERS_OBJECT, JEST_MATCHERS_OBJECT } from './constants'
import { AsymmetricMatcher } from './jest-asymmetric-matchers'
import { getState } from './state'

import { diff, getMatcherUtils, stringify } from './jest-matcher-utils'
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'

import {
equals,
Expand All @@ -33,8 +33,7 @@ function getMatcherState(assertion: Chai.AssertionStatic & Chai.Assertion, expec

const matcherState: MatcherState = {
...getState(expect),
// TODO: implement via expect.addEqualityTesters
customTesters: [],
customTesters: getCustomEqualityTesters(),
isNot,
utils: jestUtils,
promise,
Expand Down
23 changes: 21 additions & 2 deletions packages/expect/src/jest-matcher-utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { getColors, stringify } from '@vitest/utils'
import type { MatcherHintOptions } from './types'
import { getColors, getType, stringify } from '@vitest/utils'
import type { MatcherHintOptions, Tester } from './types'
import { JEST_MATCHERS_OBJECT } from './constants'

export { diff } from '@vitest/utils/diff'
export { stringify }
Expand Down Expand Up @@ -101,3 +102,21 @@ export function getMatcherUtils() {
printExpected,
}
}

export function addCustomEqualityTesters(newTesters: Array<Tester>): void {
if (!Array.isArray(newTesters)) {
throw new TypeError(
`expect.customEqualityTesters: Must be set to an array of Testers. Was given "${getType(
newTesters,
)}"`,
)
}

(globalThis as any)[JEST_MATCHERS_OBJECT].customEqualityTesters.push(
...newTesters,
)
}

export function getCustomEqualityTesters(): Array<Tester> {
return (globalThis as any)[JEST_MATCHERS_OBJECT].customEqualityTesters
}
Loading

0 comments on commit 463bee3

Please sign in to comment.