Skip to content

Commit

Permalink
fix #3094: function derivative mutates the input expression when it…
Browse files Browse the repository at this point in the history
… fails
  • Loading branch information
josdejong committed Nov 15, 2023
1 parent 3d84b5b commit a1f3b7c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 50 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Fix #3087: extend function `mod` with support for negative divisors in when
using `BigNumber` or `Fraction`.
- Fix #3092: a typo in an error message when converting a string into a number.
- Fix #3094: function `derivative` mutates the input expression when it fails.


# 2023-10-26, 12.0.0
Expand Down
45 changes: 11 additions & 34 deletions src/function/algebra/derivative.js
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,6 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
},

'FunctionNode, Object': function (node, constNodes) {
if (node.args.length !== 1) {
funcArgsCheck(node)
}

if (constNodes[node] !== undefined) {
return createConstantNode(0)
}
Expand Down Expand Up @@ -303,9 +299,12 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
}
break
case 'pow':
constNodes[arg1] = constNodes[node.args[1]]
// Pass to pow operator node parser
return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), constNodes)
if (node.args.length === 2) {
constNodes[arg1] = constNodes[node.args[1]]
// Pass to pow operator node parser
return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), constNodes)
}
break
case 'exp':
// d/dx(e^x) = e^x
funcDerivative = new FunctionNode('exp', [arg0.clone()])
Expand Down Expand Up @@ -563,7 +562,9 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
])
break
case 'gamma': // Needs digamma function, d/dx(gamma(x)) = gamma(x)digamma(x)
default: throw new Error('Function "' + node.name + '" is not supported by derivative, or a wrong number of arguments is passed')
default:
throw new Error('Cannot process function "' + node.name + '" in derivative: ' +
'the function is not supported, undefined, or the number of arguments passed to it are not supported')
}

let op, func
Expand Down Expand Up @@ -740,35 +741,11 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
])
}

throw new Error('Operator "' + node.op + '" is not supported by derivative, or a wrong number of arguments is passed')
throw new Error('Cannot process operator "' + node.op + '" in derivative: ' +
'the operator is not supported, undefined, or the number of arguments passed to it are not supported')
}
})

/**
* Ensures the number of arguments for a function are correct,
* and will throw an error otherwise.
*
* @param {FunctionNode} node
*/
function funcArgsCheck (node) {
// TODO add min, max etc
if ((node.name === 'log' || node.name === 'nthRoot' || node.name === 'pow') && node.args.length === 2) {
return
}

// There should be an incorrect number of arguments if we reach here

// Change all args to constants to avoid unidentified
// symbol error when compiling function
for (let i = 0; i < node.args.length; ++i) {
node.args[i] = createConstantNode(0)
}

node.compile().evaluate()

throw new Error('Function "' + node.name + '" is not supported by derivative, or a wrong number of arguments is passed')
}

/**
* Helper function to create a constant node with a specific type
* (number, BigNumber, Fraction)
Expand Down
40 changes: 24 additions & 16 deletions test/unit-tests/function/algebra/derivative.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -229,30 +229,32 @@ describe('derivative', function () {
assert.throws(function () {
const node = new OperatorNode('/', 'myDivide', [c12, c4, x])
derivative(node, 'x')
}, /Error: Operator "\/" is not supported by derivative, or a wrong number of arguments is passed/)
}, /Error: Cannot process operator "\/" in derivative: the operator is not supported, undefined, or the number of arguments passed to it are not supported/)

assert.throws(function () {
const node = new OperatorNode('^', 'myPow', [c12, c4, x])
derivative(node, 'x')
}, /Error: Operator "\^" is not supported by derivative, or a wrong number of arguments is passed/)
}, /Error: Cannot process operator "\^" in derivative: the operator is not supported, undefined, or the number of arguments passed to it are not supported/)
})

it('should throw error if expressions contain unsupported operators or functions', function () {
assert.throws(function () { derivative('x << 2', 'x') }, /Error: Operator "<<" is not supported by derivative, or a wrong number of arguments is passed/)
assert.throws(function () { derivative('subset(x)', 'x') }, /Error: Function "subset" is not supported by derivative, or a wrong number of arguments is passed/)
assert.throws(function () { derivative('max(x)', 'x') }, /Error: Function "max" is not supported by derivative, or a wrong number of arguments is passed/)
assert.throws(function () { derivative('max(x, y)', 'x') }, /Error: Function "max" is not supported by derivative, or a wrong number of arguments is passed/)
assert.throws(function () { derivative('max(x, 1)', 'x') }, /Error: Function "max" is not supported by derivative, or a wrong number of arguments is passed/)
assert.throws(function () { derivative('add(2,3,x)', 'x') }, /Error: Function "add" is not supported by derivative, or a wrong number of arguments is passed/)
it('should not mutate the input expression', function () {
const expr = math.parse('min(x, y)')
assert.strictEqual(expr.toString(), 'min(x, y)')

assert.throws(() => { math.derivative(expr, 'x') },
/Error: Cannot process function "min" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/
)

assert.strictEqual(expr.toString(), 'min(x, y)')
})

it('should have controlled behavior on arguments errors', function () {
assert.throws(function () {
derivative('sqrt()', 'x')
}, /TypeError: Too few arguments in function sqrt \(expected: number or Complex or BigNumber or Unit or Fraction or string or boolean, index: 0\)/)
assert.throws(function () {
derivative('sqrt(12, 2x)', 'x')
}, /TypeError: Too many arguments in function sqrt \(expected: 1, actual: 2\)/)
it('should throw error if expressions contain unsupported operators or functions', function () {
assert.throws(function () { derivative('x << 2', 'x') }, /Error: Cannot process operator "<<" in derivative: the operator is not supported, undefined, or the number of arguments passed to it are not supported/)
assert.throws(function () { derivative('subset(x)', 'x') }, /Error: Cannot process function "subset" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/)
assert.throws(function () { derivative('max(x)', 'x') }, /Error: Cannot process function "max" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/)
assert.throws(function () { derivative('max(x, y)', 'x') }, /Error: Cannot process function "max" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/)
assert.throws(function () { derivative('max(x, 1)', 'x') }, /Error: Cannot process function "max" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/)
assert.throws(function () { derivative('add(2,3,x)', 'x') }, /Error: Cannot process function "add" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/)
})

it('should throw error for incorrect argument types', function () {
Expand All @@ -279,6 +281,12 @@ describe('derivative', function () {
}, /TypeError: Too many arguments in function derivative \(expected: 3, actual: 5\)/)
})

it('should throw error in case of an unknown function', function () {
assert.throws(function () {
derivative('foo(x)', 'x')
}, /Error: Cannot process function "foo" in derivative: the function is not supported, undefined, or the number of arguments passed to it are not supported/)
})

it('should LaTeX expressions involving derivative', function () {
compareString(math.parse('derivative(x*y,x)').toTex(), '{d\\over dx}\\left[ x\\cdot y\\right]')
compareString(math.parse('derivative("x*y",x)').toTex(), '{d\\over dx}\\left[x * y\\right]')
Expand Down

0 comments on commit a1f3b7c

Please sign in to comment.