-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEquivalenceChecker.scala
184 lines (170 loc) · 8.39 KB
/
EquivalenceChecker.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import scala.collection.mutable
import scala.math.Numeric.IntIsIntegral
/**
 * An EquivalenceChecker is an object that detects equivalence between formulas in the
 * theory of Orthocomplemented Bisemilattices (OCBSL).
 * This lets proof checkers and writers avoid having to deal with a class of "easy" equivalences.
 * For example, by considering "x ∨ y" to be the same formula as "y ∨ x", we can avoid frustrating errors.
 * This relation is always a subrelation of the usual FOL implication.
 *
 * Formulas are normalized bottom-up; every normal form receives an integer code from a
 * per-checker table of signatures, and two formulas are reported equivalent when their
 * top-level codes coincide.
 */
object EquivalenceChecker {

  /**
   * Input formulas: named constants, negation, n-ary disjunction and boolean literals.
   *
   * Each node caches its normal form in `normalForm` once computed. The codes inside a cached
   * normal form come from the [[LocalEquivalenceChecker]] instance that computed them.
   */
  sealed abstract class SimpleFormula {

    /** Number of nodes in this formula, counted as a tree. */
    val size: Int
    private[EquivalenceChecker] var normalForm: Option[NormalFormula] = None
  }

  /** An atomic formula identified by `id`. */
  case class SConstant(id: String) extends SimpleFormula {
    val size = 1
  }

  /** Negation of `child`. */
  case class SNeg(child: SimpleFormula) extends SimpleFormula {
    val size: Int = 1 + child.size
  }

  /** Disjunction of `children`; the empty disjunction normalizes to false. */
  case class SOr(children: List[SimpleFormula]) extends SimpleFormula {
    val size: Int = children.foldLeft(1)(_ + _.size)
  }

  /** The boolean literal `b`; its normal form is known at construction time. */
  case class SLiteral(b: Boolean) extends SimpleFormula {
    val size = 1
    normalForm = Some(NLiteral(b))
  }

  /** Normalized formulas; `code` identifies the formula within one checker's code table. */
  sealed abstract class NormalFormula {
    val code: Int
  }
  case class NConstant(id: String, code: Int) extends NormalFormula
  case class NNeg(child: NormalFormula, code: Int) extends NormalFormula

  /** A disjunction; as built by [[LocalEquivalenceChecker]], `children` are flattened, sorted by code and deduplicated. */
  case class NOr(children: List[NormalFormula], code: Int) extends NormalFormula

  /** `false` has code 0 and `true` has code 1. */
  case class NLiteral(b: Boolean) extends NormalFormula {
    val code: Int = if (b) 1 else 0
  }

  /**
   * Computes OCBSL normal forms and codes. All codes produced by one instance share a single
   * signature table, so codes are only comparable within the same instance.
   */
  class LocalEquivalenceChecker {
    // Signature table: (operator label, children codes) -> code. Codes 0 and 1 are reserved so
    // that they agree with the codes of NLiteral(false) and NLiteral(true).
    private val codesSig: mutable.HashMap[(String, Seq[Int]), Int] = mutable.HashMap()
    codesSig.update(("zero", Nil), 0)
    codesSig.update(("one", Nil), 1)

    /** True when `sf` and all of its sub-formulas already carry a cached normal form. */
    def hasNormaleRecComputed(sf: SimpleFormula): Boolean =
      sf.normalForm.nonEmpty && (sf match {
        case SNeg(child)   => hasNormaleRecComputed(child)
        case SOr(children) => children.forall(hasNormaleRecComputed)
        case _             => true
      })

    /**
     * Detects whether a flattened disjunction is trivially valid.
     *
     * @param children the disjuncts paired with their codes
     * @return true if some disjunct `x` also appears negated (`x ∨ ¬x`), or if some negated
     *         disjunction `¬(a1 ∨ … ∨ an)` appears together with every `ai`
     */
    def checkForContradiction(children: List[(NormalFormula, Int)]): Boolean = {
      // Separate operands of negated disjuncts from the remaining (positive) disjuncts.
      val (negatedOperands, positives) = children.partitionMap[NormalFormula, NormalFormula] {
        case (NNeg(child, _), _) => Left(child)
        case (other, _)          => Right(other)
      }
      // Set lookups do not depend on the order of either side (the former two-pointer merge
      // silently required both lists sorted ascending, which the positives were not).
      val positiveCodes = positives.map(_.code).toSet
      val childrenCodes = children.map(_._2).toSet
      negatedOperands.exists { operand =>
        positiveCodes.contains(operand.code) || (operand match {
          case NOr(grandChildren, _) => grandChildren.forall(gc => childrenCodes.contains(gc.code))
          case _                     => false
        })
      }
    }

    /** Returns the code of signature `sig`, allocating the next unused code if it is new. */
    def updateCodesSig(sig: (String, Seq[Int])): Int =
      codesSig.get(sig) match {
        case Some(code) => code
        case None =>
          val fresh = codesSig.size
          codesSig.update(sig, fresh)
          fresh
      }

    /** Pairs disjuncts with their codes, sorts by code, drops duplicate codes and drops `false` (code 0). */
    private def canonicalDisjuncts(disjuncts: List[NormalFormula]): List[(NormalFormula, Int)] =
      disjuncts.map(d => (d, d.code)).sortBy(_._2).distinctBy(_._2).filterNot(_._2 == 0)

    /**
     * Computes the normal form of `phi`, caches it on `phi`, and returns its code.
     * [[checkEquivalence]] treats two formulas as equivalent when these codes are equal.
     */
    def OCBSLCode(phi: SimpleFormula): Int = phi.normalForm match {
      case Some(nf) => nf.code
      case None =>
        val disjuncts = canonicalDisjuncts(pDisj(phi, Nil))
        val nf: NormalFormula = disjuncts match {
          case Nil                => NLiteral(false)
          case (single, _) :: Nil => single
          case _ if disjuncts.exists(_._2 == 1) || checkForContradiction(disjuncts) => NLiteral(true)
          case _ => NOr(disjuncts.map(_._1), updateCodesSig(("or", disjuncts.map(_._2))))
        }
        phi.normalForm = Some(nf)
        nf.code
    }

    /**
     * Prepends to `acc` the normalized disjuncts of `phi`, flattening nested disjunctions.
     *
     * @param phi formula appearing in positive position
     * @param acc disjuncts collected so far
     * @return `acc` extended with the disjuncts of `phi`
     */
    def pDisj(phi: SimpleFormula, acc: List[NormalFormula]): List[NormalFormula] = phi.normalForm match {
      case Some(nf) => pDisjNormal(nf, acc)
      case None =>
        phi match {
          case SConstant(id) =>
            val atom = NConstant(id, updateCodesSig(("pred_" + id, Nil)))
            phi.normalForm = Some(atom)
            atom :: acc
          case SNeg(child)   => pNeg(child, phi, acc)
          case SOr(children) => children.foldLeft(acc)((collected, c) => pDisj(c, collected))
          case SLiteral(b) =>
            val literal = NLiteral(b)
            phi.normalForm = Some(literal)
            literal :: acc
        }
    }

    /**
     * Prepends to `acc` the normalized disjuncts of `¬phi`, caching the result on `parent`
     * (the `SNeg` node whose child is `phi`) when it is a single formula.
     *
     * @param phi    formula under the negation
     * @param parent the negation node wrapping `phi`
     * @param acc    disjuncts collected so far
     */
    def pNeg(phi: SimpleFormula, parent: SimpleFormula, acc: List[NormalFormula]): List[NormalFormula] =
      phi.normalForm match {
        case Some(nf) => pNegNormal(nf, parent, acc)
        case None =>
          phi match {
            case SConstant(id) =>
              val atom = NConstant(id, updateCodesSig(("pred_" + id, Nil)))
              phi.normalForm = Some(atom)
              val negation = NNeg(atom, updateCodesSig(("neg", List(atom.code))))
              parent.normalForm = Some(negation)
              negation :: acc
            case SNeg(child) => pDisj(child, acc)
            case SLiteral(b) =>
              val negated = NLiteral(!b)
              parent.normalForm = Some(negated)
              negated :: acc
            case SOr(Nil) =>
              // ¬(empty disjunction) = ¬false = true; previously this threw on `Nil.tail`.
              val top = NLiteral(true)
              parent.normalForm = Some(top)
              top :: acc
            case SOr(children) =>
              // Normalize every disjunct except the smallest first: if they all reduce to
              // false, the negation is simply the negation of the smallest disjunct.
              val sorted = children.sortBy(_.size)
              val rest = sorted.tail.foldLeft(List.empty[NormalFormula])((collected, c) => pDisj(c, collected))
              if (canonicalDisjuncts(rest).isEmpty) pNeg(sorted.head, parent, acc)
              else {
                val all = canonicalDisjuncts(pDisj(sorted.head, rest))
                if (all.exists(_._2 == 1) || checkForContradiction(all)) {
                  // The disjunction is valid, so its negation is false.
                  phi.normalForm = Some(NLiteral(true))
                  val bottom = NLiteral(false)
                  parent.normalForm = Some(bottom)
                  bottom :: acc
                } else if (all.lengthCompare(1) == 0) {
                  pNegNormal(all.head._1, parent, acc)
                } else {
                  val disjunction = NOr(all.map(_._1), updateCodesSig(("or", all.map(_._2))))
                  phi.normalForm = Some(disjunction)
                  val negation = NNeg(disjunction, updateCodesSig(("neg", List(disjunction.code))))
                  parent.normalForm = Some(negation)
                  negation :: acc
                }
              }
          }
      }

    /** Prepends the disjuncts of an already-normalized formula `f`, flattening a top-level `NOr`. */
    def pDisjNormal(f: NormalFormula, acc: List[NormalFormula]): List[NormalFormula] = f match {
      case NOr(children, _) => children ++ acc
      case _                => f :: acc
    }

    /**
     * Prepends the disjuncts of `¬f` for an already-normalized `f`: double negations cancel,
     * literals are flipped, and anything else is wrapped in `NNeg` (cached on `parent`).
     */
    def pNegNormal(f: NormalFormula, parent: SimpleFormula, acc: List[NormalFormula]): List[NormalFormula] = f match {
      case NNeg(child, _) => pDisjNormal(child, acc)
      case NLiteral(b) =>
        // Without this case ¬true/¬false got a fresh code instead of 0/1.
        val negated = NLiteral(!b)
        parent.normalForm = Some(negated)
        negated :: acc
      case _ =>
        val negation = NNeg(f, updateCodesSig(("neg", List(f.code))))
        parent.normalForm = Some(negation)
        negation :: acc
    }

    /** True when both formulas receive the same code from this checker. */
    def checkEquivalence(formula1: SimpleFormula, formula2: SimpleFormula): Boolean =
      getCode(formula1) == getCode(formula2)

    /** The OCBSL code of `formula` in this checker's code table. */
    def getCode(formula: SimpleFormula): Int = OCBSLCode(formula)
  }

  /**
   * Checks OCBSL equivalence of two formulas using a fresh [[LocalEquivalenceChecker]].
   *
   * NOTE(review): normal forms are cached on the SimpleFormula nodes themselves, but codes are
   * only meaningful for the checker that assigned them. Passing nodes already normalized by an
   * earlier call may compare codes from different tables — presumably callers build fresh
   * formulas per call; confirm.
   */
  def isSame(formula1: SimpleFormula, formula2: SimpleFormula): Boolean =
    (new LocalEquivalenceChecker).checkEquivalence(formula1, formula2)
}