Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atomic modify+get to MonadState #120

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .jvmopts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,4 @@
-Xmx6G
-XX:ReservedCodeCacheSize=250M
-XX:+TieredCompilation
-XX:-UseGCOverheadLimit
# effectively adds GC to Perm space
-XX:+CMSClassUnloadingEnabled
# must be enabled for CMSClassUnloadingEnabled to work
-XX:+UseConcMarkSweepGC
-XX:+UseParallelGC
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
version = "2.3.2"
maxColumn = 100
32 changes: 20 additions & 12 deletions core/src/main/scala/cats/mtl/MonadState.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ package mtl
*
* MonadState has four external laws:
* {{{
*
* def getThenSetDoesNothing = {
* get >>= set <-> pure(())
* }
* def setThenGetReturnsSetted(s: S) = {
* def setThenGetReturnsSet(s: S) = {
* set(s) *> get <-> set(s) *> pure(s)
* }
* def setThenSetSetsLast(s1: S, s2: S) = {
Expand All @@ -21,21 +22,27 @@ package mtl
* }
* }}}
*
* `MonadState` has two internal law:
* `MonadState` has three internal laws:
* {{{
* def modifyIsGetThenSet(f: S => S) = {
* modify(f) <-> (inspect(f) flatMap set)
* def setIsStateUnit(s: S) = {
* set(s) <-> state(_ => (s, ()))
* }

* def inpectIsState[A](f: S => A) = {
* inspect(f) <-> state(s => (s, f(s)))
* }
*
* def inspectLaw[A](f: S => A) = {
* inspect(f) <-> (get map f)
* def modifyIsState(f: S => S) = {
* modify(f) <-> state(s => (f(s), ()))
* }
* }}}
*
*/
trait MonadState[F[_], S] extends Serializable {
val monad: Monad[F]

def state[A](f: S => (S, A)): F[A]

def get: F[S]

def set(s: S): F[Unit]
Expand All @@ -45,7 +52,6 @@ trait MonadState[F[_], S] extends Serializable {
def modify(f: S => S): F[Unit]
}


object MonadState {
def get[F[_], S](implicit ev: MonadState[F, S]): F[S] =
ev.get
Expand All @@ -63,16 +69,18 @@ object MonadState {
def modify[F[_], S](f: S => S)(implicit state: MonadState[F, S]): F[Unit] =
state.modify(f)

def state[F[_], S, A](f: S => (S, A))(implicit state: MonadState[F, S]): F[A] =
state.state(f)

def inspect[F[_], S, A](f: S => A)(implicit state: MonadState[F, S]): F[A] =
state.inspect(f)

def apply[F[_], S](implicit monadState: MonadState[F, S]): MonadState[F, S] = monadState
}


trait DefaultMonadState[F[_], S] extends MonadState[F, S] {

def inspect[A](f: S => A): F[A] = monad.map(get)(f)

def modify(f: S => S): F[Unit] = monad.flatMap(inspect(f))(set)
def get: F[S] = state(s => (s, s))
def set(s: S): F[Unit] = state(_ => (s, ()))
def inspect[A](f: S => A): F[A] = state(s => (s, f(s)))
def modify(f: S => S): F[Unit] = state(s => (f(s), ()))
}
27 changes: 20 additions & 7 deletions core/src/main/scala/cats/mtl/instances/state.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ trait StateInstances extends StateInstancesLowPriority1 {
// this dependency on LayerFunctor is required because non-`LayerFunctor`s may not be lawful
// to lift MonadState into
implicit final def stateInd[M[_], Inner[_], E](implicit ml: MonadLayerFunctor[M, Inner],
under: MonadState[Inner, E]
): MonadState[M, E] = {
under: MonadState[Inner, E]): MonadState[M, E] = {
new MonadState[M, E] {
val monad: Monad[M] = ml.outerInstance

Expand All @@ -20,7 +19,9 @@ trait StateInstances extends StateInstancesLowPriority1 {

def modify(f: E => E): M[Unit] = ml.layer(under.modify(f))

def inspect[A](f: (E) => A): M[A] = ml.layer(under.inspect(f))
def inspect[A](f: E => A): M[A] = ml.layer(under.inspect(f))

def state[A](f: E => (E, A)): M[A] = ml.layer(under.state(f))
}
}

Expand All @@ -40,14 +41,20 @@ private[instances] trait StateInstancesLowPriority1 {

def modify(f: S => S): StateT[M, S, Unit] = StateT.modify(f)

def inspect[A](f: (S) => A): StateT[M, S, A] = StateT.inspect(f)
def inspect[A](f: S => A): StateT[M, S, A] = StateT.inspect(f)

def state[A](f: S => (S, A)): StateT[M, S, A] = StateT { s =>
M.pure(f(s))
}
}
}

implicit final def readerWriterStateState[M[_], R, L, S]
(implicit M: Monad[M], L: Monoid[L]): MonadState[ReaderWriterStateT[M, R, L, S, ?], S] =
implicit final def readerWriterStateState[M[_], R, L, S](
implicit M: Monad[M],
L: Monoid[L]): MonadState[ReaderWriterStateT[M, R, L, S, ?], S] =
new MonadState[ReaderWriterStateT[M, R, L, S, ?], S] {
val monad: Monad[ReaderWriterStateT[M, R, L, S, ?]] = IndexedReaderWriterStateT.catsDataMonadForRWST
val monad: Monad[ReaderWriterStateT[M, R, L, S, ?]] =
IndexedReaderWriterStateT.catsDataMonadForRWST

def get: ReaderWriterStateT[M, R, L, S, S] =
ReaderWriterStateT.get[M, R, L, S]
Expand All @@ -60,6 +67,12 @@ private[instances] trait StateInstancesLowPriority1 {

def inspect[A](f: S => A): ReaderWriterStateT[M, R, L, S, A] =
ReaderWriterStateT.inspect[M, R, L, S, A](f)

def state[A](f: S => (S, A)): ReaderWriterStateT[M, R, L, S, A] = ReaderWriterStateT.apply {
(e, s) =>
val (s2, a) = f(s)
M.pure((L.empty, s2, a))
}
}
}

Expand Down
17 changes: 12 additions & 5 deletions laws/src/main/scala/cats/mtl/laws/MonadStateLaws.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package laws

import cats.laws.IsEq
import cats.laws.IsEqArrow
import cats.syntax.functor._
import cats.syntax.flatMap._
import cats.syntax.apply._

Expand All @@ -20,7 +19,7 @@ trait MonadStateLaws[F[_], S] {
(get >>= set) <-> pure(())
}

def setThenGetReturnsSetted(s: S): IsEq[F[S]] = {
def setThenGetReturnsSet(s: S): IsEq[F[S]] = {
(set(s) *> get) <-> (set(s) *> pure(s))
}

Expand All @@ -32,9 +31,17 @@ trait MonadStateLaws[F[_], S] {
get *> get <-> get
}

// internal law:
def modifyIsGetThenSet(f: S => S): IsEq[F[Unit]] = {
modify(f) <-> ((get map f) flatMap set)
// internal laws:
def setIsStateUnit(s: S): IsEq[F[Unit]] = {
set(s) <-> state(_ => (s, ()))
}

def inpectIsState[A](f: S => A): IsEq[F[A]] = {
inspect(f) <-> state(s => (s, f(s)))
}

def modifyIsState(f: S => S): IsEq[F[Unit]] = {
modify(f) <-> state(s => (f(s), ()))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ trait MonadStateTests[F[_], S] extends Laws {
ArbS: Arbitrary[S],
CogenS: Cogen[S],
EqFU: Eq[F[Unit]],
EqFS: Eq[F[S]]
EqFS: Eq[F[S]],
EqFA: Eq[F[A]]
): RuleSet = {
new DefaultRuleSet(
name = "monadState",
parent = None,
"get then set has does nothing" -> laws.getThenSetDoesNothing,
"set then get returns the setted value" -> ∀(laws.setThenGetReturnsSetted _),
"set then get returns the set value" -> ∀(laws.setThenGetReturnsSet _),
"set then set sets the last value" -> ∀(laws.setThenSetSetsLast _),
"get then get gets once" -> laws.getThenGetGetsOnce,
"modify is get then set" -> ∀(laws.modifyIsGetThenSet _)
"set is state(unit)" -> ∀(laws.setIsStateUnit _),
"inspect is state" -> ∀(laws.inpectIsState[A] _),
"modify is state" -> ∀(laws.modifyIsState _)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,8 @@ class ApplicativeCensorDefaultTests extends StateTTestsBase {

class MonadStateDefaultTests extends StateTTestsBase {
val defaultListMonadState: MonadState[StateC[String]#l, String] = new DefaultMonadState[StateC[String]#l, String] {

val monad: Monad[StateC[String]#l] = implicitly

def get: StateC[String]#l[String] = State.get

def set(s: String): State[String, Unit] = State.set(s)
def state[A](f: String => (String, A)): State[String, A] = State(f)
}

checkAll("State[String, ?]",
Expand Down