Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matrix binary power #63

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 83 additions & 29 deletions Sources/Towel/Mathematics/Matrix.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Towel.Measurements;
Expand Down Expand Up @@ -1134,68 +1135,121 @@ public Matrix<T> Divide(T b)

#region Power (Matrix ^ Scalar)

private class SquareMatrixFactory
{
private List<Matrix<T>> cache = new List<Matrix<T>>();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is probably better to replace List with your type to keep the "ecosystem"

private int pointer = 0;
public readonly int DiagonalLength;
public SquareMatrixFactory(int diagonalLength)
{
this.DiagonalLength = diagonalLength;
}

public Matrix<T> Get()
{
if (pointer == cache.Count)
cache.Add(new Matrix<T>(DiagonalLength, DiagonalLength));
var res = cache[pointer];
pointer++;
return res;
}

public void Return()
{
pointer--;
}
}

[ThreadStatic]
private static SquareMatrixFactory squareMatrixFactory;


// Approach to
// needed: a ^ (-13)
// b = a ^ -1
// needed: b ^ 13
// mp2 = b ^ 6 (goto beginning)
// needed: mp2 * mp2 * b
// that's it, works for O(log(power))
internal static void PowerPositiveSafe(Matrix<T> a, int power, ref Matrix<T> destination)
{
var dest = squareMatrixFactory.Get();
Power(a, power / 2, ref dest);
var mp2 = squareMatrixFactory.Get();
Multiply(dest, dest, ref mp2);
if (power % 2 == 1)
{
var tmp = squareMatrixFactory.Get();
Multiply(mp2, a, ref tmp);
mp2 = tmp;
squareMatrixFactory.Return();
}
destination = mp2;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof that was supposed to be .Clone I guess

squareMatrixFactory.Return();
squareMatrixFactory.Return();
}

/// <summary>Applies a power to a square matrix.</summary>
/// <param name="a">The matrix to be powered by.</param>
/// <param name="b">The power to apply to the matrix.</param>
/// <param name="c">The resulting matrix of the power operation.</param>
public static void Power(Matrix<T> a, int b, ref Matrix<T> c)
/// <param name="power">The power to apply to the matrix.</param>
/// <param name="destination">The resulting matrix of the power operation.</param>
public static void Power(Matrix<T> a, int power, ref Matrix<T> destination)
{
_ = a ?? throw new ArgumentNullException(nameof(a));
if (!a.IsSquare)
{
throw new MathematicsException("Invalid power (!" + nameof(a) + ".IsSquare)");
}
if (b < 0)
if (power < 0)
{
throw new ArgumentOutOfRangeException(nameof(b), b, "!(" + nameof(b) + " >= 0)");
Power(a.Inverse(), -power, ref destination);
return;
}
if (b == 0)
if (power == 0)
{
if (!(c is null) && c._matrix.Length == a._matrix.Length)
if (!(destination is null) && destination.IsSquare)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better use IsSquare to keep the code readable, the power operation is extremely pricy, 1-2 ns isn't worth saving

{
c._rows = a._rows;
c._columns = a._columns;
Format(c, (x, y) => x == y ? Constant<T>.One : Constant<T>.Zero);
destination._rows = a._rows;
destination._columns = a._columns;
Format(destination, (x, y) => x == y ? Constant<T>.One : Constant<T>.Zero);
}
else
{
c = Matrix<T>.FactoryIdentity(a._rows, a._columns);
destination = Matrix<T>.FactoryIdentity(a._rows, a._columns);
}
return;
}
if (!(c is null) && c._matrix.Length == a._matrix.Length)
if (!(destination is null) && destination._matrix.Length == a._matrix.Length)
{
c._rows = a._rows;
c._columns = a._columns;
destination._rows = a._rows;
destination._columns = a._columns;
T[] A = a._matrix;
T[] C = c._matrix;
T[] C = destination._matrix;
for (int i = 0; i < a._matrix.Length; i++)
{
C[i] = A[i];
}
}
else
{
c = a.Clone();
}
Matrix<T> d = new Matrix<T>(a._rows, a._columns, a._matrix.Length);
for (int i = 0; i < b; i++)
{
Multiply(c, a, ref d);
Matrix<T> temp = d;
d = c;
c = d;
destination = a.Clone();
}
}
if (power == 1)
return;
if (squareMatrixFactory is null || squareMatrixFactory.DiagonalLength != a._rows)
squareMatrixFactory = new SquareMatrixFactory(a._rows);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you already have a square matrix factory, this should be replaced with your type. Also my factory won't work if we change matrix' size after each operation, e. g.

Power(A, 3);
Power(A, 4);
Power(A, -13);

Is fast because we pull square matrices from the factory, while

Power(A, 3);
Power(B, 4);
Power(A, -13);

Because we re-create the factory in the second and third lines

PowerPositiveSafe(a, power, ref destination);
destination = destination.Clone();
}

/// <summary>Applies a power to a square matrix.</summary>
/// <param name="a">The matrix to be powered by.</param>
/// <param name="b">The power to apply to the matrix.</param>
/// <param name="power">The power to apply to the matrix.</param>
/// <returns>The resulting matrix of the power operation.</returns>
public static Matrix<T> Power(Matrix<T> a, int b)
public static Matrix<T> Power(Matrix<T> a, int power)
{
Matrix<T> c = null;
Power(a, b, ref c);
Power(a, power, ref c);
return c;
}

Expand Down Expand Up @@ -2393,7 +2447,7 @@ public bool Equal(Matrix<T> b)
return this == b;
}

#endregion
#endregion

#region Equal (+leniency)

Expand Down
6 changes: 3 additions & 3 deletions Sources/Towel/Towel.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16849,13 +16849,13 @@
<member name="M:Towel.Mathematics.Matrix`1.Power(Towel.Mathematics.Matrix{`0},System.Int32,Towel.Mathematics.Matrix{`0}@)">
<summary>Applies a power to a square matrix.</summary>
<param name="a">The matrix to be powered by.</param>
<param name="b">The power to apply to the matrix.</param>
<param name="c">The resulting matrix of the power operation.</param>
<param name="power">The power to apply to the matrix.</param>
<param name="destination">The resulting matrix of the power operation.</param>
</member>
<member name="M:Towel.Mathematics.Matrix`1.Power(Towel.Mathematics.Matrix{`0},System.Int32)">
<summary>Applies a power to a square matrix.</summary>
<param name="a">The matrix to be powered by.</param>
<param name="b">The power to apply to the matrix.</param>
<param name="power">The power to apply to the matrix.</param>
<returns>The resulting matrix of the power operation.</returns>
</member>
<member name="M:Towel.Mathematics.Matrix`1.op_ExclusiveOr(Towel.Mathematics.Matrix{`0},System.Int32)">
Expand Down
94 changes: 74 additions & 20 deletions Tools/Towel_Testing/Mathematics/Matrix.cs
Original file line number Diff line number Diff line change
Expand Up @@ -704,57 +704,111 @@ [TestMethod] public void Power()
{ 1, 2, },
{ 3, 4, },
};
Matrix<int> B = new int[,]
Matrix<int> A3 = new int[,]
{
{ 37, 54, },
{ 81, 118, },
};
Assert.IsTrue((A ^ 3) == B);
Matrix<int> A5 = new int[,]
{
{ 1069, 1558, },
{ 2337, 3406, },
};
Matrix<int> Am6 = new int[,]
{
{ 169, -70, },
{ -70, 29, },
};
var act3 = A ^ 3;
var act5 = A ^ 5;
var actm6 = A ^ (-6);
Assert.AreEqual(A3, act3);
Assert.AreEqual(A5, act5);
Assert.AreEqual(Am6, actm6);
}

// float
{
Matrix<float> A = new float[,]
{
{ 1f, 2f, },
{ 3f, 4f, },
{ 1, 2, },
{ 3, 4, },
};
Matrix<float> B = new float[,]
Matrix<float> A3 = new float[,]
{
{ 37, 54, },
{ 81, 118, },
};
Matrix<float> A5 = new float[,]
{
{ 37f, 54f, },
{ 81f, 118f, },
{ 1069, 1558, },
{ 2337, 3406, },
};
Assert.IsTrue((A ^ 3) == B);
var act3 = A ^ 3;
var act5 = A ^ 5;
var actm6 = A ^ (-6);
var act6 = A ^ 6;
Assert.AreEqual(A3, act3);
Assert.AreEqual(A5, act5);
var I = Matrix<float>.FactoryIdentity(2, 2);
var multiplied = actm6 * act6;
Assert.AreEqual(I, multiplied);
}

// double
{
Matrix<double> A = new double[,]
{
{ 1d, 2d, },
{ 3d, 4d, },
{ 1, 2, },
{ 3, 4, },
};
Matrix<double> B = new double[,]
Matrix<double> A3 = new double[,]
{
{ 37, 54, },
{ 81, 118, },
};
Matrix<double> A5 = new double[,]
{
{ 37d, 54d, },
{ 81d, 118d, },
{ 1069, 1558, },
{ 2337, 3406, },
};
Assert.IsTrue((A ^ 3) == B);
var act3 = A ^ 3;
var act5 = A ^ 5;
var actm6 = A ^ (-6);
var act6 = A ^ 6;
Assert.AreEqual(A3, act3);
Assert.AreEqual(A5, act5);
var I = Matrix<double>.FactoryIdentity(2, 2);
var multiplied = actm6 * act6;
Assert.AreEqual(I, multiplied);
}

// decimal
{
Matrix<decimal> A = new decimal[,]
{
{ 1m, 2m, },
{ 3m, 4m, },
{ 1, 2, },
{ 3, 4, },
};
Matrix<decimal> B = new decimal[,]
Matrix<decimal> A3 = new decimal[,]
{
{ 37, 54, },
{ 81, 118, },
};
Matrix<decimal> A5 = new decimal[,]
{
{ 37m, 54m, },
{ 81m, 118m, },
{ 1069, 1558, },
{ 2337, 3406, },
};
Assert.IsTrue((A ^ 3) == B);
var act3 = A ^ 3;
var act5 = A ^ 5;
var actm6 = A ^ (-6);
var act6 = A ^ 6;
Assert.AreEqual(A3, act3);
Assert.AreEqual(A5, act5);
var I = Matrix<decimal>.FactoryIdentity(2, 2);
var multiplied = actm6 * act6;
Assert.AreEqual(I, multiplied);
}

// Exceptions
Expand Down