Skip to content

Commit

Permalink
Add usolveAll and lsolveAll methods (#1916)
Browse files Browse the repository at this point in the history
* refactor solveValidation

* refactor usolve

* usolve algorithm implemented (for square mat.)

* added lsolve, consistent return type, fixed tests

* fixed lusolve and its tests, fixed linting issues

* added tests for usolve&lsolve, try-catch in lusolve

* put changes into separate files (u-/lsolveAll), revert changes to u-, l- and lusolve

* made *solveAll return [] for non-solvable, implemented sparse algorithms

* improved documentation for *solve(All)

Co-authored-by: Jos de Jong <[email protected]>
  • Loading branch information
cshaa and josdejong authored Aug 29, 2020
1 parent 8e00dc3 commit ba4ff2f
Show file tree
Hide file tree
Showing 14 changed files with 939 additions and 222 deletions.
4 changes: 4 additions & 0 deletions src/expression/embeddedDocs/embeddedDocs.js
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,13 @@ import { addDocs } from './function/arithmetic/add'
import { absDocs } from './function/arithmetic/abs'
import { qrDocs } from './function/algebra/qr'
import { usolveDocs } from './function/algebra/usolve'
import { usolveAllDocs } from './function/algebra/usolveAll'
import { sluDocs } from './function/algebra/slu'
import { rationalizeDocs } from './function/algebra/rationalize'
import { simplifyDocs } from './function/algebra/simplify'
import { lupDocs } from './function/algebra/lup'
import { lsolveDocs } from './function/algebra/lsolve'
import { lsolveAllDocs } from './function/algebra/lsolveAll'
import { derivativeDocs } from './function/algebra/derivative'
import { versionDocs } from './constants/version'
import { trueDocs } from './constants/true'
Expand Down Expand Up @@ -310,12 +312,14 @@ export const embeddedDocs = {
// functions - algebra
derivative: derivativeDocs,
lsolve: lsolveDocs,
lsolveAll: lsolveAllDocs,
lup: lupDocs,
lusolve: lusolveDocs,
simplify: simplifyDocs,
rationalize: rationalizeDocs,
slu: sluDocs,
usolve: usolveDocs,
usolveAll: usolveAllDocs,
qr: qrDocs,

// functions - arithmetic
Expand Down
4 changes: 2 additions & 2 deletions src/expression/embeddedDocs/function/algebra/lsolve.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ export const lsolveDocs = {
'x=lsolve(L, b)'
],
description:
'Solves the linear system L * x = b where L is an [n x n] lower triangular matrix and b is a [n] column vector.',
'Finds one solution of the linear system L * x = b where L is an [n x n] lower triangular matrix and b is a [n] column vector.',
examples: [
'a = [-2, 3; 2, 1]',
'b = [11, 9]',
'x = lsolve(a, b)'
],
seealso: [
'lup', 'lusolve', 'usolve', 'matrix', 'sparse'
'lsolveAll', 'lup', 'lusolve', 'usolve', 'matrix', 'sparse'
]
}
17 changes: 17 additions & 0 deletions src/expression/embeddedDocs/function/algebra/lsolveAll.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
export const lsolveAllDocs = {
name: 'lsolveAll',
category: 'Algebra',
syntax: [
'x=lsolveAll(L, b)'
],
description:
'Finds all solutions of the linear system L * x = b where L is an [n x n] lower triangular matrix and b is a [n] column vector.',
examples: [
'a = [-2, 3; 2, 1]',
'b = [11, 9]',
'x = lsolve(a, b)'
],
seealso: [
'lsolve', 'lup', 'lusolve', 'usolve', 'matrix', 'sparse'
]
}
4 changes: 2 additions & 2 deletions src/expression/embeddedDocs/function/algebra/usolve.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ export const usolveDocs = {
'x=usolve(U, b)'
],
description:
'Solves the linear system U * x = b where U is an [n x n] upper triangular matrix and b is a [n] column vector.',
'Finds one solution of the linear system U * x = b where U is an [n x n] upper triangular matrix and b is a [n] column vector.',
examples: [
'x=usolve(sparse([1, 1, 1, 1; 0, 1, 1, 1; 0, 0, 1, 1; 0, 0, 0, 1]), [1; 2; 3; 4])'
],
seealso: [
'lup', 'lusolve', 'lsolve', 'matrix', 'sparse'
'usolveAll', 'lup', 'lusolve', 'lsolve', 'matrix', 'sparse'
]
}
15 changes: 15 additions & 0 deletions src/expression/embeddedDocs/function/algebra/usolveAll.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export const usolveAllDocs = {
name: 'usolveAll',
category: 'Algebra',
syntax: [
'x=usolve(U, b)'
],
description:
'Finds all solutions of the linear system U * x = b where U is an [n x n] upper triangular matrix and b is a [n] column vector.',
examples: [
'x=usolve(sparse([1, 1, 1, 1; 0, 1, 1, 1; 0, 0, 1, 1; 0, 0, 0, 1]), [1; 2; 3; 4])'
],
seealso: [
'usolve', 'lup', 'lusolve', 'lsolve', 'matrix', 'sparse'
]
}
2 changes: 2 additions & 0 deletions src/factoriesAny.js
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ export { createDotPow } from './function/arithmetic/dotPow'
export { createDotDivide } from './function/arithmetic/dotDivide'
export { createLsolve } from './function/algebra/solver/lsolve'
export { createUsolve } from './function/algebra/solver/usolve'
export { createLsolveAll } from './function/algebra/solver/lsolveAll'
export { createUsolveAll } from './function/algebra/solver/usolveAll'
export { createLeftShift } from './function/bitwise/leftShift'
export { createRightArithShift } from './function/bitwise/rightArithShift'
export { createRightLogShift } from './function/bitwise/rightLogShift'
Expand Down
110 changes: 52 additions & 58 deletions src/function/algebra/solver/lsolve.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export const createLsolve = /* #__PURE__ */ factory(name, dependencies, ({ typed
const solveValidation = createSolveValidation({ DenseMatrix })

/**
* Solves the linear equation system by forwards substitution. Matrix must be a lower triangular matrix.
* Finds one solution of a linear equation system by forwards substitution. Matrix must be a lower triangular matrix. Throws an error if there's no solution.
*
* `L * x = b`
*
Expand All @@ -32,7 +32,7 @@ export const createLsolve = /* #__PURE__ */ factory(name, dependencies, ({ typed
*
* See also:
*
* lup, slu, usolve, lusolve
* lsolveAll, lup, slu, usolve, lusolve
*
* @param {Matrix, Array} L A N x N matrix or array (L)
* @param {Matrix, Array} b A column vector with the b values
Expand All @@ -42,67 +42,61 @@ export const createLsolve = /* #__PURE__ */ factory(name, dependencies, ({ typed
return typed(name, {

'SparseMatrix, Array | Matrix': function (m, b) {
// process matrix
return _sparseForwardSubstitution(m, b)
},

'DenseMatrix, Array | Matrix': function (m, b) {
// process matrix
return _denseForwardSubstitution(m, b)
},

'Array, Array | Matrix': function (a, b) {
// create dense matrix from array
const m = matrix(a)
// use matrix implementation
const r = _denseForwardSubstitution(m, b)
// result
return r.valueOf()
}
})

function _denseForwardSubstitution (m, b) {
// validate matrix and vector, return copy of column vector b
b = solveValidation(m, b, true)
// column vector data
const bdata = b._data
// rows & columns

const rows = m._size[0]
const columns = m._size[1]

// result
const x = []
// data
const data = m._data
// forward solve m * x = b, loop columns

const mdata = m._data

// loop columns
for (let j = 0; j < columns; j++) {
// b[j]
const bj = bdata[j][0] || 0
// x[j]
let xj
// forward substitution (outer product) avoids inner looping when bj === 0

if (!equalScalar(bj, 0)) {
// value @ [j, j]
const vjj = data[j][j]
// check vjj
// non-degenerate row, find solution

const vjj = mdata[j][j]

if (equalScalar(vjj, 0)) {
// system cannot be solved
throw new Error('Linear system cannot be solved since matrix is singular')
}
// calculate xj

xj = divideScalar(bj, vjj)

// loop rows
for (let i = j + 1; i < rows; i++) {
// update copy of b
bdata[i] = [subtract(bdata[i][0] || 0, multiplyScalar(xj, data[i][j]))]
bdata[i] = [subtract(bdata[i][0] || 0, multiplyScalar(xj, mdata[i][j]))]
}
} else {
// zero @ j
// degenerate row, we can choose any value
xj = 0
}
// update x

x[j] = [xj]
}
// return vector

return new DenseMatrix({
data: x,
size: [rows, 1]
Expand All @@ -112,68 +106,68 @@ export const createLsolve = /* #__PURE__ */ factory(name, dependencies, ({ typed
function _sparseForwardSubstitution (m, b) {
// validate matrix and vector, return copy of column vector b
b = solveValidation(m, b, true)
// column vector data

const bdata = b._data
// rows & columns

const rows = m._size[0]
const columns = m._size[1]
// matrix arrays

const values = m._values
const index = m._index
const ptr = m._ptr
// vars
let i, k

// result
const x = []
// forward solve m * x = b, loop columns

// loop columns
for (let j = 0; j < columns; j++) {
// b[j]
const bj = bdata[j][0] || 0
// forward substitution (outer product) avoids inner looping when bj === 0

if (!equalScalar(bj, 0)) {
// value @ [j, j]
// non-degenerate row, find solution

let vjj = 0
// lower triangular matrix values & index (column j)
const jvalues = []
const jindex = []
// last index in column
let l = ptr[j + 1]
// values in column, find value @ [j, j]
for (k = ptr[j]; k < l; k++) {
// row
i = index[k]
// matrix values & indices (column j)
const jValues = []
const jIndices = []

// first and last index in the column
const firstIndex = ptr[j]
const lastIndex = ptr[j + 1]

// values in column, find value at [j, j]
for (let k = firstIndex; k < lastIndex; k++) {
const i = index[k]

// check row (rows are not sorted!)
if (i === j) {
// update vjj
vjj = values[k]
} else if (i > j) {
// store lower triangular
jvalues.push(values[k])
jindex.push(i)
jValues.push(values[k])
jIndices.push(i)
}
}
// at this point we must have a value @ [j, j]

// at this point we must have a value in vjj
if (equalScalar(vjj, 0)) {
// system cannot be solved, there is no value @ [j, j]
throw new Error('Linear system cannot be solved since matrix is singular')
}
// calculate xj

const xj = divideScalar(bj, vjj)
// loop lower triangular
for (k = 0, l = jindex.length; k < l; k++) {
// row
i = jindex[k]
// update copy of b
bdata[i] = [subtract(bdata[i][0] || 0, multiplyScalar(xj, jvalues[k]))]

for (let k = 0, l = jIndices.length; k < l; k++) {
const i = jIndices[k]
bdata[i] = [subtract(bdata[i][0] || 0, multiplyScalar(xj, jValues[k]))]
}
// update x

x[j] = [xj]
} else {
// update x
// degenerate row, we can choose any value
x[j] = [0]
}
}
// return vector

return new DenseMatrix({
data: x,
size: [rows, 1]
Expand Down
Loading

0 comments on commit ba4ff2f

Please sign in to comment.