From 56ff88eeec52baffa6f93c2150e9c3e8d2bf4065 Mon Sep 17 00:00:00 2001 From: Amit Kumar Jaiswal Date: Wed, 22 Feb 2017 10:06:33 +0530 Subject: [PATCH] Create HMC.scala Dislocation of implicits --- .../mandar2812/dynaml/probability/HMC.scala | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/probability/HMC.scala diff --git a/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/probability/HMC.scala b/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/probability/HMC.scala new file mode 100644 index 000000000..8c8805dd0 --- /dev/null +++ b/dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/probability/HMC.scala @@ -0,0 +1,51 @@ +package io.github.mandar2812.dynaml.probability + +import org.apache.commons.math3.linear.{Array2DRowRealMatrix, CholeskyDecomposition, LUDecomposition} +import shapeless.Witness +import spire.algebra._ +import spire.random.{Dist, Gaussian, Generator, Uniform} +import spire.std.seq._ +import spire.syntax.innerProductSpace._ +import spire.syntax.order._ + +abstract class HMC[R : Uniform : Gaussian, N, G, D <: Int with Singleton : Witness.Aux](val posterior: Tree[R, N] => (R, G), val M: Matrix[D, R], val alpha: R, val eps: R, val L: Int, val RToDouble: R => Double)(implicit val rng: Generator, implicit val f: Field[R], implicit val trig: Trig[R], implicit val n: NRoot[R], implicit val s: Signed[R], implicit val o: Order[R]) extends (Z[R, N, G] => Z[R, N, G]) { + + type ZZ = Z[R, N, G] + + val (invM, choleskyL): (Matrix[D, R], Matrix[D, R]) = { + val apacheM = new Array2DRowRealMatrix(M.size, M.size) + M.indices.foreach(Function.tupled((i, j) => apacheM.setEntry(i, j, RToDouble(M(i, j))))) + val apacheInvM = new LUDecomposition(apacheM).getSolver.getInverse + val apacheCholeskyL = new CholeskyDecomposition(apacheM).getL + (Matrix[D, R]((i: Int, j: Int) => Field[R].fromDouble(apacheInvM.getEntry(i, j))), Matrix[D, R]((i: Int, j: Int) => Field[R].fromDouble(apacheCholeskyL.getEntry(i, j)))) + } + + val uniform = Dist.uniform(Field[R].zero, Field[R].one) + val gaussian = Dist.gaussian(Field[R].zero, Field[R].one) + val sqrtalpha = NRoot[R].sqrt(alpha) + val sqrt1malpha = NRoot[R].sqrt(1 - alpha) + + def U(q: Tree[R, N]): (R, G) + + def K(p: IndexedSeq[R]): (R, G) + + def flipMomentum(z: ZZ): ZZ = { + val pp = -z.p + z.copy(p = pp)(_K = K(pp)) + } + + def corruptMomentum(z: ZZ): ZZ = { + val r = IndexedSeq.fill(z.p.size)(rng.next(gaussian)) + val pp = sqrt1malpha *: z.p + sqrtalpha *: (choleskyL * r) + z.copy(p = pp)(_K = K(pp)) + } + + def simulateDynamics(z: ZZ): ZZ + + override def apply(z: ZZ): ZZ = { + val zp = flipMomentum(simulateDynamics(z)) + val a = Trig[R].exp(z.H - zp.H) min 1 + corruptMomentum(flipMomentum(if (rng.next(uniform) < a) zp else z)) + } + +}