Skip to content

Commit

Permalink
Merge pull request #130 from b-mehta/dirichlet-characters
Browse files Browse the repository at this point in the history
Discrete log
  • Loading branch information
Bodigrim authored Sep 12, 2018
2 parents 2536f72 + 1393fdf commit dd59557
Show file tree
Hide file tree
Showing 13 changed files with 333 additions and 27 deletions.
9 changes: 4 additions & 5 deletions Math/NumberTheory/EisensteinIntegers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ import Data.Ord (comparing)
import GHC.Generics (Generic)

import qualified Math.NumberTheory.Euclidean as ED
import qualified Math.NumberTheory.Moduli as Moduli
import Math.NumberTheory.Moduli.Sqrt (FieldCharacteristic(..))
import Math.NumberTheory.Moduli.Sqrt (FieldCharacteristic(..), sqrtModMaybe)
import qualified Math.NumberTheory.Primes.Factorisation as Factorisation
import Math.NumberTheory.Primes.Types (PrimeNat(..))
import qualified Math.NumberTheory.Primes.Sieve as Sieve
Expand Down Expand Up @@ -176,12 +175,12 @@ divideByThree = go 0
-- @(z+3k)^2 ≡ 9k^2-1 (mod 6k+1)@
-- @z+3k = sqrtMod(9k^2-1)@
-- @z = sqrtMod(9k^2-1) - 3k@
--
--
-- * For example, let @p = 7@, then @k = 1@. Square root of @9*1^2-1 modulo 7@ is @1@.
-- * And @z = 1 - 3*1 = -2 ≡ 5 (mod 7)@.
-- * Truly, @norm (5 :+ 1) = 25 - 5 + 1 = 21 ≡ 0 (mod 7)@.
findPrime :: Integer -> EisensteinInteger
findPrime p = case Moduli.sqrtModMaybe (9*k*k - 1) (FieldCharacteristic (PrimeNat . integerToNatural $ p) 1) of
findPrime p = case sqrtModMaybe (9*k*k - 1) (FieldCharacteristic (PrimeNat . integerToNatural $ p) 1) of
Nothing -> error "findPrime: argument must be prime p = 6k + 1"
Just sqrtMod -> ED.gcd (p :+ 0) ((sqrtMod - 3 * k) :+ 1)
where
Expand Down Expand Up @@ -313,4 +312,4 @@ quotEvenI (x :+ y) n
| otherwise = Nothing
where
(xq, xr) = x `quotRem` n
(yq, yr) = y `quotRem` n
(yq, yr) = y `quotRem` n
6 changes: 3 additions & 3 deletions Math/NumberTheory/GaussianIntegers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import GHC.Generics


import qualified Math.NumberTheory.Euclidean as ED
import qualified Math.NumberTheory.Moduli as Moduli
import Math.NumberTheory.Moduli.Sqrt (FieldCharacteristic(..))
import Math.NumberTheory.Moduli.Sqrt (FieldCharacteristic(..), sqrtModMaybe)
import Math.NumberTheory.Powers (integerSquareRoot)
import Math.NumberTheory.Primes.Types (PrimeNat(..))
import qualified Math.NumberTheory.Primes.Factorisation as Factorisation
Expand Down Expand Up @@ -140,7 +140,7 @@ gcdG' = ED.gcd
-- of form 4k + 1 using
-- <http://www.ams.org/journals/mcom/1972-26-120/S0025-5718-1972-0314745-6/S0025-5718-1972-0314745-6.pdf Hermite-Serret algorithm>.
findPrime :: Integer -> GaussianInteger
findPrime p = case Moduli.sqrtModMaybe (-1) (FieldCharacteristic (PrimeNat . integerToNatural $ p) 1) of
findPrime p = case sqrtModMaybe (-1) (FieldCharacteristic (PrimeNat . integerToNatural $ p) 1) of
Nothing -> error "findPrime: an argument must be prime p = 4k + 1"
Just z -> go p z -- Effectively we calculate gcdG' (p :+ 0) (z :+ 1)
where
Expand Down
4 changes: 4 additions & 0 deletions Math/NumberTheory/Moduli.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
module Math.NumberTheory.Moduli
( module Math.NumberTheory.Moduli.Class
, module Math.NumberTheory.Moduli.Chinese
, module Math.NumberTheory.Moduli.DiscreteLogarithm
, module Math.NumberTheory.Moduli.Jacobi
, module Math.NumberTheory.Moduli.PrimitiveRoot
, module Math.NumberTheory.Moduli.Sqrt
) where

import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.Class
import Math.NumberTheory.Moduli.DiscreteLogarithm
import Math.NumberTheory.Moduli.Jacobi
import Math.NumberTheory.Moduli.PrimitiveRoot
import Math.NumberTheory.Moduli.Sqrt
40 changes: 39 additions & 1 deletion Math/NumberTheory/Moduli/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
-- Safe modular arithmetic with modulo on type level.
--

{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
Expand All @@ -29,6 +28,11 @@ module Math.NumberTheory.Moduli.Class
, invertMod
, powMod
, (^%)
-- * Multiplicative group
, MultMod
, multElement
, isMultElement
, invertGroup
-- * Unknown modulo
, SomeMod(..)
, modulo
Expand All @@ -40,6 +44,7 @@ module Math.NumberTheory.Moduli.Class

import Data.Proxy
import Data.Ratio
import Data.Semigroup
import Data.Type.Equality
import GHC.Integer.GMP.Internals
import GHC.TypeNats.Compat
Expand Down Expand Up @@ -181,6 +186,39 @@ infixr 8 ^%
-- of type classes in Core.
-- {-# RULES "^%Mod" forall (x :: KnownNat m => Mod m) p. x ^ p = x ^% p #-}

-- | This type represents elements of the multiplicative group mod m, i.e.
-- those elements which are coprime to m. Use @toMultElement@ to construct.
newtype MultMod m = MultMod { multElement :: Mod m }
deriving (Eq, Ord, Show)

instance KnownNat m => Semigroup (MultMod m) where
MultMod a <> MultMod b = MultMod (a * b)
stimes k a@(MultMod a')
| k >= 0 = MultMod (powMod a' k)
| otherwise = invertGroup $ stimes (-k) a
-- ^ This Semigroup is in fact a group, so @stimes@ can be called with a negative first argument.

instance KnownNat m => Monoid (MultMod m) where
mempty = MultMod 1
mappend = (<>)

instance KnownNat m => Bounded (MultMod m) where
minBound = MultMod 1
maxBound = MultMod (-1)

-- | Attempt to construct a multiplicative group element.
isMultElement :: KnownNat m => Mod m -> Maybe (MultMod m)
isMultElement a = if getNatVal a `gcd` getNatMod a == 1
then Just $ MultMod a
else Nothing

-- | For elements of the multiplicative group, we can safely perform the inverse
-- without needing to worry about failure.
invertGroup :: KnownNat m => MultMod m -> MultMod m
invertGroup (MultMod a) = case invertMod a of
Just b -> MultMod b
Nothing -> error "Math.NumberTheory.Moduli.invertGroup: failed to invert element"

-- | This type represents residues with unknown modulo and rational numbers.
-- One can freely combine them in arithmetic expressions, but each operation
-- will spend time on modulo's recalculation:
Expand Down
119 changes: 119 additions & 0 deletions Math/NumberTheory/Moduli/DiscreteLogarithm.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
-- |
-- Module: Math.NumberTheory.Moduli.DiscreteLogarithm
-- Copyright: (c) 2018 Bhavik Mehta
-- License: MIT
-- Maintainer: Andrew Lelechenko <[email protected]>
-- Stability: Provisional
-- Portability: Non-portable
--

{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}

module Math.NumberTheory.Moduli.DiscreteLogarithm
( discreteLogarithm
) where

import qualified Data.IntMap.Strict as M
import Data.Maybe (maybeToList)
import Numeric.Natural (Natural)
import GHC.Integer.GMP.Internals (recipModInteger, powModInteger)

import Math.NumberTheory.Moduli.Chinese (chineseRemainder2)
import Math.NumberTheory.Moduli.Class (KnownNat, MultMod(..), Mod, getVal)
import Math.NumberTheory.Moduli.Equations (solveLinear')
import Math.NumberTheory.Moduli.PrimitiveRoot (PrimitiveRoot(..), CyclicGroup(..))
import Math.NumberTheory.Powers.Squares (integerSquareRoot)
import Math.NumberTheory.UniqueFactorisation (unPrime)

-- | Computes the discrete logarithm. Currently uses a combination of the baby-step
-- giant-step method and Pollard's rho algorithm, with Bach reduction.
discreteLogarithm :: KnownNat m => PrimitiveRoot m -> MultMod m -> Natural
discreteLogarithm a b = discreteLogarithm' (getGroup a) (multElement $ unPrimitiveRoot a) (multElement b)

discreteLogarithm'
:: KnownNat m
=> CyclicGroup Natural -- ^ group structure (must be the multiplicative group mod m)
-> Mod m -- ^ a
-> Mod m -- ^ b
-> Natural -- ^ result
discreteLogarithm' cg a b =
case cg of
CG2 -> 0
-- the only valid input was a=1, b=1
CG4 -> if b == 1 then 0 else 1
-- the only possible input here is a=3 with b = 1 or 3
CGOddPrimePower (unPrime -> p) k -> discreteLogarithmPP p k (getVal a) (getVal b)
CGDoubleOddPrimePower (unPrime -> p) k -> discreteLogarithmPP p k (getVal a `rem` p^k) (getVal b `rem` p^k)
-- we have the isomorphism t -> t `rem` p^k from (Z/2p^kZ)* -> (Z/p^kZ)*

-- Implementation of Bach reduction (https://www2.eecs.berkeley.edu/Pubs/TechRpts/1984/CSD-84-186.pdf)
{-# INLINE discreteLogarithmPP #-}
discreteLogarithmPP :: Integer -> Word -> Integer -> Integer -> Natural
discreteLogarithmPP p 1 a b = discreteLogarithmPrime p a b
discreteLogarithmPP p k a b = fromInteger result
where
baseSol = toInteger $ discreteLogarithmPrime p (a `rem` p) (b `rem` p)
thetaA = theta p pkMinusOne a
thetaB = theta p pkMinusOne b
pkMinusOne = p^(k-1)
c = (recipModInteger thetaA pkMinusOne * thetaB) `rem` pkMinusOne
result = chineseRemainder2 (baseSol, p-1) (c, pkMinusOne)

-- compute the homomorphism theta given in https://math.stackexchange.com/a/1864495/418148
{-# INLINE theta #-}
theta :: Integer -> Integer -> Integer -> Integer
theta p pkMinusOne a = (numerator `quot` pk) `rem` pkMinusOne
where
pk = pkMinusOne * p
p2kMinusOne = pkMinusOne * pk
numerator = (powModInteger a (pk - pkMinusOne) p2kMinusOne - 1) `rem` p2kMinusOne

-- TODO: Use Pollig-Hellman to reduce the problem further into groups of prime order.
-- While Bach reduction simplifies the problem into groups of the form (Z/pZ)*, these
-- have non-prime order, and the Pollig-Hellman algorithm can reduce the problem into
-- smaller groups of prime order.
-- In addition, the gcd check before solveLinear is applied in Pollard below will be
-- made redundant, since n would be prime.
discreteLogarithmPrime :: Integer -> Integer -> Integer -> Natural
discreteLogarithmPrime p a b
| p < 100000000 = fromIntegral $ discreteLogarithmPrimeBSGS (fromInteger p) (fromInteger a) (fromInteger b)
| otherwise = discreteLogarithmPrimePollard p a b

discreteLogarithmPrimeBSGS :: Int -> Int -> Int -> Int
discreteLogarithmPrimeBSGS p a b = head [i*m + j | (v,i) <- zip giants [0..m-1], j <- maybeToList (M.lookup v table)]
where
m = integerSquareRoot (p - 2) + 1 -- simple way of ceiling (sqrt (p-1))
babies = iterate (.* a) 1
table = M.fromList (zip babies [0..m-1])
aInv = recipModInteger (toInteger a) (toInteger p)
bigGiant = fromInteger $ powModInteger aInv (toInteger m) (toInteger p)
giants = iterate (.* bigGiant) b
x .* y = x * y `rem` p

-- TODO: Use more advanced walks, in order to reduce divisions, cf
-- https://maths-people.anu.edu.au/~brent/pd/rpb231.pdf
-- This will slightly improve the expected time to collision, and can reduce the
-- number of divisions performed.
discreteLogarithmPrimePollard :: Integer -> Integer -> Integer -> Natural
discreteLogarithmPrimePollard p a b =
case concatMap runPollard [(x,y) | x <- [0..n], y <- [0..n]] of
(t:_) -> fromInteger t
[] -> error ("discreteLogarithm: pollard's rho failed, please report this as a bug. inputs " ++ show [p,a,b])
where
n = p-1 -- order of the cyclic group
halfN = n `quot` 2
mul2 m = if m < halfN then m * 2 else m * 2 - n
sqrtN = integerSquareRoot n
step (xi,!ai,!bi) = case xi `rem` 3 of
0 -> (xi*xi `rem` p, mul2 ai, mul2 bi)
1 -> ( a*xi `rem` p, ai+1, bi)
_ -> ( b*xi `rem` p, ai, bi+1)
initialise (x,y) = (powModInteger a x n * powModInteger b y n `rem` n, x, y)
begin t = go (step t) (step (step t))
check t = powModInteger a t p == b
go tort@(xi,ai,bi) hare@(x2i,a2i,b2i)
| xi == x2i, gcd (bi - b2i) n < sqrtN = solveLinear' n (bi - b2i) (ai - a2i)
| xi == x2i = []
| otherwise = go (step tort) (step (step hare))
runPollard = filter check . begin . initialise
1 change: 1 addition & 0 deletions Math/NumberTheory/Moduli/Equations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module Math.NumberTheory.Moduli.Equations
( solveLinear
, solveLinear'
, solveQuadratic
) where

Expand Down
44 changes: 32 additions & 12 deletions Math/NumberTheory/Moduli/PrimitiveRoot.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,36 @@
{-# LANGUAGE UndecidableInstances #-}

module Math.NumberTheory.Moduli.PrimitiveRoot
( isPrimitiveRoot
-- * Cyclic groups
, CyclicGroup(..)
( -- * Cyclic groups
CyclicGroup(..)
, cyclicGroupFromModulo
, cyclicGroupToModulo
, groupSize
-- * Primitive roots
, PrimitiveRoot
, unPrimitiveRoot
, getGroup
, isPrimitiveRoot
, isPrimitiveRoot'
) where

import Control.DeepSeq
#if __GLASGOW_HASKELL__ < 803
import Data.Semigroup
#endif

import Math.NumberTheory.ArithmeticFunctions (totient)
import Math.NumberTheory.GCD as Coprimes
import Math.NumberTheory.Moduli (Mod, getNatMod, getNatVal, KnownNat)
import Math.NumberTheory.Moduli.Class (getNatMod, getNatVal, KnownNat, Mod, MultMod, isMultElement)
import Math.NumberTheory.Powers.General (highestPower)
import Math.NumberTheory.Powers.Modular
import Math.NumberTheory.Prefactored
import Math.NumberTheory.UniqueFactorisation
import Math.NumberTheory.Utils.FromIntegral

import Control.DeepSeq
import Control.Monad (guard)
import GHC.Generics
import Numeric.Natural

-- | A multiplicative group of residues is called cyclic,
-- if there is a primitive root @g@,
Expand All @@ -58,8 +65,8 @@ data CyclicGroup a

instance NFData (Prime a) => NFData (CyclicGroup a)

deriving instance Eq (Prime a) => Eq (CyclicGroup a)
deriving instance Show (Prime a) => Show (CyclicGroup a)
deriving instance Eq (Prime a) => Eq (CyclicGroup a)
deriving instance Show (Prime a) => Show (CyclicGroup a)

-- | Check whether a multiplicative group of residues,
-- characterized by its modulo, is cyclic and, if yes, return its form.
Expand Down Expand Up @@ -112,6 +119,13 @@ cyclicGroupToModulo = fromFactors . \case
CGOddPrimePower p k -> Coprimes.singleton (unPrime p) k
CGDoubleOddPrimePower p k -> Coprimes.singleton 2 1 <> Coprimes.singleton (unPrime p) k

-- | 'PrimitiveRoot m' is a type which is only inhabited by primitive roots of n.
data PrimitiveRoot m = PrimitiveRoot
{ unPrimitiveRoot :: MultMod m -- ^ Extract primitive root value.
, getGroup :: CyclicGroup Natural -- ^ Get cyclic group structure.
}
deriving (Eq, Show)

-- | 'isPrimitiveRoot'' @cg@ @a@ checks whether @a@ is
-- a <https://en.wikipedia.org/wiki/Primitive_root_modulo_n primitive root>
-- of a given cyclic multiplicative group of residues @cg@.
Expand Down Expand Up @@ -153,15 +167,21 @@ isPrimitiveRoot' cg r =
--
-- Here is how to list all primitive roots:
--
-- >>> filter isPrimitiveRoot [minBound .. maxBound] :: [Mod 13]
-- >>> mapMaybe isPrimitiveRoot [minBound .. maxBound] :: [Mod 13]
-- [(2 `modulo` 13), (6 `modulo` 13), (7 `modulo` 13), (11 `modulo` 13)]
--
-- This function is a convenient wrapper around 'isPrimitiveRoot''. The latter
-- provides better control and performance, if you need them.
isPrimitiveRoot
:: KnownNat n
=> Mod n
-> Bool
isPrimitiveRoot r = case cyclicGroupFromModulo (getNatMod r) of
Nothing -> False
Just cg -> isPrimitiveRoot' cg (getNatVal r)
-> Maybe (PrimitiveRoot n)
isPrimitiveRoot r = do
r' <- isMultElement r
cg <- cyclicGroupFromModulo (getNatMod r)
guard $ isPrimitiveRoot' cg (getNatVal r)
return $ PrimitiveRoot r' cg

-- | Calculate the size of a given cyclic group.
groupSize :: (Integral a, UniqueFactorisation a) => CyclicGroup a -> Prefactored a
groupSize = totient . cyclicGroupToModulo
3 changes: 3 additions & 0 deletions arithmoi.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ library
Math.NumberTheory.Moduli
Math.NumberTheory.Moduli.Chinese
Math.NumberTheory.Moduli.Class
Math.NumberTheory.Moduli.DiscreteLogarithm
Math.NumberTheory.Moduli.Equations
Math.NumberTheory.Moduli.Jacobi
Math.NumberTheory.Moduli.PrimitiveRoot
Expand Down Expand Up @@ -145,6 +146,7 @@ test-suite spec
Math.NumberTheory.GaussianIntegersTests
Math.NumberTheory.GCDTests
Math.NumberTheory.Moduli.ChineseTests
Math.NumberTheory.Moduli.DiscreteLogarithmTests
Math.NumberTheory.Moduli.ClassTests
Math.NumberTheory.Moduli.EquationsTests
Math.NumberTheory.Moduli.JacobiTests
Expand Down Expand Up @@ -194,6 +196,7 @@ benchmark criterion
semigroups >=0.8
other-modules:
Math.NumberTheory.ArithmeticFunctionsBench
Math.NumberTheory.DiscreteLogarithmBench
Math.NumberTheory.EisensteinIntegersBench
Math.NumberTheory.GaussianIntegersBench
Math.NumberTheory.GCDBench
Expand Down
2 changes: 2 additions & 0 deletions benchmark/Bench.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Main where
import Gauge.Main

import Math.NumberTheory.ArithmeticFunctionsBench as ArithmeticFunctions
import Math.NumberTheory.DiscreteLogarithmBench as DiscreteLogarithm
import Math.NumberTheory.EisensteinIntegersBench as Eisenstein
import Math.NumberTheory.GaussianIntegersBench as Gaussian
import Math.NumberTheory.GCDBench as GCD
Expand All @@ -18,6 +19,7 @@ import Math.NumberTheory.SmoothNumbersBench as SmoothNumbers
main :: IO ()
main = defaultMain
[ ArithmeticFunctions.benchSuite
, DiscreteLogarithm.benchSuite
, Eisenstein.benchSuite
, Gaussian.benchSuite
, GCD.benchSuite
Expand Down
Loading

0 comments on commit dd59557

Please sign in to comment.