Skip to content

Commit

Permalink
Merge pull request #806 from perazz/linalg_solve
Browse files Browse the repository at this point in the history
linalg: solve
  • Loading branch information
jvdp1 authored May 11, 2024
2 parents 60ce18f + 5832df5 commit 3bdcc82
Show file tree
Hide file tree
Showing 11 changed files with 653 additions and 3 deletions.
101 changes: 101 additions & 0 deletions doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,107 @@ Specifically, upper Hessenberg matrices satisfy `a_ij = 0` when `j < i-1`, and l
{!example/linalg/example_is_hessenberg.f90!}
```

## `solve` - Solves a linear matrix equation or a linear system of equations.

### Status

Experimental

### Description

This function computes the solution to a linear matrix equation \( A \cdot x = b \), where \( A \) is a square, full-rank, `real` or `complex` matrix.

Result vector or array `x` returns the exact solution to within numerical precision, provided that the matrix is not ill-conditioned.
An error is returned if the matrix is rank-deficient or singular to working precision.
The solver is based on LAPACK's `*GESV` backends.

### Syntax

`Pure` interface:

`x = ` [[stdlib_linalg(module):solve(interface)]] `(a, b)`

Expert interface:

`x = ` [[stdlib_linalg(module):solve(interface)]] `(a, b [, overwrite_a], err)`

### Arguments

`a`: Shall be a rank-2 `real` or `complex` square array containing the coefficient matrix. It is normally an `intent(in)` argument. If `overwrite_a=.true.`, it is an `intent(inout)` argument and is destroyed by the call.

`b`: Shall be a rank-1 or rank-2 array of the same kind as `a`, containing the right-hand-side vector(s). It is an `intent(in)` argument.

`overwrite_a` (optional): Shall be an input logical flag. if `.true.`, input matrix `a` will be used as temporary storage and overwritten. This avoids internal data allocation. This is an `intent(in)` argument.

`err` (optional): Shall be a `type(linalg_state_type)` value. This is an `intent(out)` argument. The function is not `pure` if this argument is provided.

### Return value

For a full-rank matrix, returns an array value that represents the solution to the linear system of equations.

Raises `LINALG_ERROR` if the matrix is singular to working precision.
Raises `LINALG_VALUE_ERROR` if the matrix and rhs vectors have invalid/incompatible sizes.
If `err` is not present, exceptions trigger an `error stop`.

### Example

```fortran
{!example/linalg/example_solve1.f90!}
{!example/linalg/example_solve2.f90!}
```

## `solve_lu` - Solves a linear matrix equation or a linear system of equations (subroutine interface).

### Status

Experimental

### Description

This subroutine computes the solution to a linear matrix equation \( A \cdot x = b \), where \( A \) is a square, full-rank, `real` or `complex` matrix.

Result vector or array `x` returns the exact solution to within numerical precision, provided that the matrix is not ill-conditioned.
An error is returned if the matrix is rank-deficient or singular to working precision.
If all optional arrays are provided by the user, no internal allocations take place.
The solver is based on LAPACK's `*GESV` backends.

### Syntax

Simple (`Pure`) interface:

`call ` [[stdlib_linalg(module):solve_lu(interface)]] `(a, b, x)`

Expert (`Pure`) interface:

`call ` [[stdlib_linalg(module):solve_lu(interface)]] `(a, b, x [, pivot, overwrite_a, err])`

### Arguments

`a`: Shall be a rank-2 `real` or `complex` square array containing the coefficient matrix. It is normally an `intent(in)` argument. If `overwrite_a=.true.`, it is an `intent(inout)` argument and is destroyed by the call.

`b`: Shall be a rank-1 or rank-2 array of the same kind as `a`, containing the right-hand-side vector(s). It is an `intent(in)` argument.

`x`: Shall be a rank-1 or rank-2 array of the same kind and size as `b`, that returns the solution(s) to the system. It is an `intent(inout)` argument, and must have the `contiguous` property.

`pivot` (optional): Shall be a rank-1 array of the same kind and matrix dimension as `a`, providing storage for the diagonal pivot indices. It is an `intent(inout)` arguments, and returns the diagonal pivot indices.

`overwrite_a` (optional): Shall be an input logical flag. if `.true.`, input matrix `a` will be used as temporary storage and overwritten. This avoids internal data allocation. This is an `intent(in)` argument.

### Return value

For a full-rank matrix, returns an array value that represents the solution to the linear system of equations.

Raises `LINALG_ERROR` if the matrix is singular to working precision.
Raises `LINALG_VALUE_ERROR` if the matrix and rhs vectors have invalid/incompatible sizes.
If `err` is not present, exceptions trigger an `error stop`.

### Example

```fortran
{!example/linalg/example_solve3.f90!}
```

## `lstsq` - Computes the least squares solution to a linear matrix equation.

### Status
Expand Down
3 changes: 3 additions & 0 deletions example/linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ ADD_EXAMPLE(blas_gemv)
ADD_EXAMPLE(lapack_getrf)
ADD_EXAMPLE(lstsq1)
ADD_EXAMPLE(lstsq2)
ADD_EXAMPLE(solve1)
ADD_EXAMPLE(solve2)
ADD_EXAMPLE(solve3)
ADD_EXAMPLE(determinant)
ADD_EXAMPLE(determinant2)
26 changes: 26 additions & 0 deletions example/linalg/example_solve1.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
program example_solve1
use stdlib_linalg_constants, only: sp
use stdlib_linalg, only: solve, linalg_state_type
implicit none

real(sp), allocatable :: A(:,:),b(:),x(:)

! Solve a system of 3 linear equations:
! 4x + 3y + 2z = 25
! -2x + 2y + 3z = -10
! 3x - 5y + 2z = -4

! Note: Fortran is column-major! -> transpose
A = transpose(reshape([ 4, 3, 2, &
-2, 2, 3, &
3,-5, 2], [3,3]))
b = [25,-10,-4]

! Get coefficients of y = coef(1) + x*coef(2) + x^2*coef(3)
x = solve(A,b)

print *, 'solution: ',x
! 5.0, 3.0, -2.0

end program example_solve1

26 changes: 26 additions & 0 deletions example/linalg/example_solve2.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
program example_solve2
use stdlib_linalg_constants, only: sp
use stdlib_linalg, only: solve, linalg_state_type
implicit none

complex(sp), allocatable :: A(:,:),b(:),x(:)

! Solve a system of 3 complex linear equations:
! 2x + iy + 2z = (5-i)
! -ix + (4-3i)y + 6z = i
! 4x + 3y + z = 1

! Note: Fortran is column-major! -> transpose
A = transpose(reshape([(2.0, 0.0),(0.0, 1.0),(2.0,0.0), &
(0.0,-1.0),(4.0,-3.0),(6.0,0.0), &
(4.0, 0.0),(3.0, 0.0),(1.0,0.0)] , [3,3]))
b = [(5.0,-1.0),(0.0,1.0),(1.0,0.0)]

! Get coefficients of y = coef(1) + x*coef(2) + x^2*coef(3)
x = solve(A,b)

print *, 'solution: ',x
! (1.0947,0.3674) (-1.519,-0.4539) (1.1784,-0.1078)

end program example_solve2

32 changes: 32 additions & 0 deletions example/linalg/example_solve3.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
program example_solve3
use stdlib_linalg_constants, only: sp,ilp
use stdlib_linalg, only: solve_lu, linalg_state_type
implicit none

integer(ilp) :: test
integer(ilp), allocatable :: pivot(:)
complex(sp), allocatable :: A(:,:),b(:),x(:)

! Solve a system of 3 complex linear equations:
! 2x + iy + 2z = (5-i)
! -ix + (4-3i)y + 6z = i
! 4x + 3y + z = 1

! Note: Fortran is column-major! -> transpose
A = transpose(reshape([(2.0, 0.0),(0.0, 1.0),(2.0,0.0), &
(0.0,-1.0),(4.0,-3.0),(6.0,0.0), &
(4.0, 0.0),(3.0, 0.0),(1.0,0.0)] , [3,3]))

! Pre-allocate x
allocate(b(size(A,2)),pivot(size(A,2)))
allocate(x,mold=b)

! Call system many times avoiding reallocation
do test=1,100
b = test*[(5.0,-1.0),(0.0,1.0),(1.0,0.0)]
call solve_lu(A,b,x,pivot)
print "(i3,'-th solution: ',*(1x,f12.6))", test,x
end do

end program example_solve3

1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set(fppFiles
stdlib_linalg_outer_product.fypp
stdlib_linalg_kronecker.fypp
stdlib_linalg_cross_product.fypp
stdlib_linalg_solve.fypp
stdlib_linalg_determinant.fypp
stdlib_linalg_state.fypp
stdlib_optval.fypp
Expand Down
96 changes: 96 additions & 0 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ module stdlib_linalg
public :: eye
public :: lstsq
public :: lstsq_space
public :: solve
public :: solve_lu
public :: solve_lstsq
public :: trace
public :: outer_product
Expand Down Expand Up @@ -228,6 +230,100 @@ module stdlib_linalg
#:endfor
end interface is_hessenberg

! Solve linear system system Ax=b.
interface solve
!! version: experimental
!!
!! Solves the linear system \( A \cdot x = b \) for the unknown vector \( x \) from a square matrix \( A \).
!! ([Specification](../page/specs/stdlib_linalg.html#solve-solves-a-linear-matrix-equation-or-a-linear-system-of-equations))
!!
!!### Summary
!! Interface for solving a linear system arising from a general matrix.
!!
!!### Description
!!
!! This interface provides methods for computing the solution of a linear matrix system.
!! Supported data types include `real` and `complex`. No assumption is made on the matrix
!! structure.
!! The function can solve simultaneously either one (from a 1-d right-hand-side vector `b(:)`)
!! or several (from a 2-d right-hand-side vector `b(:,:)`) systems.
!!
!!@note The solution is based on LAPACK's generic LU decomposition based solvers `*GESV`.
!!@note BLAS/LAPACK backends do not currently support extended precision (``xdp``).
!!
#:for nd,ndsuf,nde in ALL_RHS
#:for rk,rt,ri in RC_KINDS_TYPES
#:if rk!="xdp"
module function stdlib_linalg_${ri}$_solve_${ndsuf}$(a,b,overwrite_a,err) result(x)
!> Input matrix a[n,n]
${rt}$, intent(inout), target :: a(:,:)
!> Right hand side vector or array, b[n] or b[n,nrhs]
${rt}$, intent(in) :: b${nd}$
!> [optional] Can A data be overwritten and destroyed?
logical(lk), optional, intent(in) :: overwrite_a
!> [optional] state return flag. On error if not requested, the code will stop
type(linalg_state_type), intent(out) :: err
!> Result array/matrix x[n] or x[n,nrhs]
${rt}$, allocatable, target :: x${nd}$
end function stdlib_linalg_${ri}$_solve_${ndsuf}$
pure module function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$(a,b) result(x)
!> Input matrix a[n,n]
${rt}$, intent(in) :: a(:,:)
!> Right hand side vector or array, b[n] or b[n,nrhs]
${rt}$, intent(in) :: b${nd}$
!> Result array/matrix x[n] or x[n,nrhs]
${rt}$, allocatable, target :: x${nd}$
end function stdlib_linalg_${ri}$_pure_solve_${ndsuf}$
#:endif
#:endfor
#:endfor
end interface solve
! Solve linear system Ax = b using LU decomposition (subroutine interface).
interface solve_lu
!! version: experimental
!!
!! Solves the linear system \( A \cdot x = b \) for the unknown vector \( x \) from a square matrix \( A \).
!! ([Specification](../page/specs/stdlib_linalg.html#solve-lu-solves-a-linear-matrix-equation-or-a-linear-system-of-equations-subroutine-interface))
!!
!!### Summary
!! Subroutine interface for solving a linear system using LU decomposition.
!!
!!### Description
!!
!! This interface provides methods for computing the solution of a linear matrix system using
!! a subroutine. Supported data types include `real` and `complex`. No assumption is made on the matrix
!! structure. Preallocated space for the solution vector `x` is user-provided, and it may be provided
!! for the array of pivot indices, `pivot`. If all pre-allocated work spaces are provided, no internal
!! memory allocations take place when using this interface.
!! The function can solve simultaneously either one (from a 1-d right-hand-side vector `b(:)`)
!! or several (from a 2-d right-hand-side vector `b(:,:)`) systems.
!!
!!@note The solution is based on LAPACK's generic LU decomposition based solvers `*GESV`.
!!@note BLAS/LAPACK backends do not currently support extended precision (``xdp``).
!!
#:for nd,ndsuf,nde in ALL_RHS
#:for rk,rt,ri in RC_KINDS_TYPES
#:if rk!="xdp"
pure module subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$(a,b,x,pivot,overwrite_a,err)
!> Input matrix a[n,n]
${rt}$, intent(inout), target :: a(:,:)
!> Right hand side vector or array, b[n] or b[n,nrhs]
${rt}$, intent(in) :: b${nd}$
!> Result array/matrix x[n] or x[n,nrhs]
${rt}$, intent(inout), contiguous, target :: x${nd}$
!> [optional] Storage array for the diagonal pivot indices
integer(ilp), optional, intent(inout), target :: pivot(:)
!> [optional] Can A data be overwritten and destroyed?
logical(lk), optional, intent(in) :: overwrite_a
!> [optional] state return flag. On error if not requested, the code will stop
type(linalg_state_type), optional, intent(out) :: err
end subroutine stdlib_linalg_${ri}$_solve_lu_${ndsuf}$
#:endif
#:endfor
#:endfor
end interface solve_lu

! Least squares solution to system Ax=b, i.e. such that the 2-norm abs(b-Ax) is minimized.
interface lstsq
!! version: experimental
Expand Down
Loading

0 comments on commit 3bdcc82

Please sign in to comment.