-
Notifications
You must be signed in to change notification settings - Fork 0
/
Unification.py
141 lines (121 loc) · 4.46 KB
/
Unification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import itertools, re
from utils import *
class Expr:
    """A symbolic expression: an operator `op` applied to zero or more `args`.

    `op` is a string (a symbol such as 'x' or 'F', or an operator such as
    '+', '==>') or a number; `args` is a list of Expr instances.  An Expr
    with no args acts as an atom (symbol, variable or number).

    The Python operators are overloaded to *build* new Expr trees rather
    than compute values, so ``a & b`` yields ``Expr('&', a, b)``.  Note that
    this includes the ordering comparisons (<, <=, >, >=), which therefore
    return an Expr instead of a bool; only == and != perform a real
    structural comparison.
    """

    def __init__(self, op, *args):
        """Store the operator and coerce every argument to an Expr.

        A numeric `op` is only allowed when there are no arguments.
        NOTE(review): `isnumber` and `num_or_str` come from `utils`; their
        exact behaviour is not visible in this file.
        """
        assert isinstance(op, str) or (isnumber(op) and not args)
        self.op = num_or_str(op)
        # Build a real list.  On Python 3 `map` returns a one-shot iterator,
        # which is always truthy, cannot be indexed, compares by identity and
        # is exhausted after one pass -- breaking __repr__, __eq__, __call__
        # and every caller that tests or walks `args`.
        self.args = [expr(arg) for arg in args]

    def __call__(self, *args):
        """Apply an atomic symbol to arguments: Expr('F')(x, y) -> F(x, y)."""
        assert is_symbol(self.op) and not self.args
        return Expr(self.op, *args)

    def __repr__(self):
        """Render atoms bare, symbols in call form, operators in prefix/infix form."""
        if not self.args:
            # Atom: constant, variable or number.
            return str(self.op)
        elif is_symbol(self.op):
            # Functional or predicate form: F(x, y)
            return '%s(%s)' % (self.op, ', '.join(map(repr, self.args)))
        elif len(self.args) == 1:
            # Prefix operator: ~x
            return self.op + repr(self.args[0])
        else:
            # Infix operator, parenthesized: (x & y)
            return '(%s)' % (' ' + self.op + ' ').join(map(repr, self.args))

    def __eq__(self, other):
        """Structural equality: same operator and equal argument lists."""
        return (other is self) or (isinstance(other, Expr)
                                   and self.op == other.op
                                   and self.args == other.args)

    def __ne__(self, other):
        """Negation of __eq__ (defined explicitly for Python 2)."""
        return not self.__eq__(other)

    def __hash__(self):
        """Hash consistent with __eq__ so Exprs can be dict keys / set members."""
        return hash(self.op) ^ hash(tuple(self.args))

    # Each overload below returns a new Expr node; nothing is evaluated.
    def __lt__(self, other):
        return Expr('<', self, other)

    def __le__(self, other):
        return Expr('<=', self, other)

    def __ge__(self, other):
        return Expr('>=', self, other)

    def __gt__(self, other):
        return Expr('>', self, other)

    def __add__(self, other):
        return Expr('+', self, other)

    def __sub__(self, other):
        return Expr('-', self, other)

    def __and__(self, other):
        return Expr('&', self, other)

    def __div__(self, other):
        # Python 2 spelling of the '/' operator.
        return Expr('/', self, other)

    def __truediv__(self, other):
        # Python 3 (and `from __future__ import division`) spelling of '/'.
        return Expr('/', self, other)

    def __invert__(self):
        return Expr('~', self)

    def __lshift__(self, other):
        return Expr('<<', self, other)

    def __rshift__(self, other):
        return Expr('>>', self, other)

    def __mul__(self, other):
        return Expr('*', self, other)

    def __neg__(self):
        return Expr('-', self)

    def __or__(self, other):
        return Expr('|', self, other)

    def __pow__(self, other):
        return Expr('**', self, other)

    def __xor__(self, other):
        return Expr('^', self, other)

    def __mod__(self, other):
        # '%' stands in for the biconditional operator '<=>'.
        return Expr('<=>', self, other)
def expr(s):
    """Return an Expr for `s`.

    An Expr is returned unchanged and a number is wrapped as ``Expr(s)``.
    Anything else is treated as a string: the logic operators ==>, <==, <=>
    and =/= are rewritten to the Python operators >>, <<, % and ^ (which
    Expr overloads), every run of letters, digits, '_' and '.' is wrapped in
    an ``Expr("...")`` call, and the resulting text is evaluated.
    """
    if isinstance(s, Expr):
        return s
    if isnumber(s):
        return Expr(s)
    # Applied in this order, one after another.
    operator_rewrites = (
        ('==>', '>>'),
        ('<==', '<<'),
        ('<=>', '%'),
        ('=/=', '^'),
    )
    for logic_op, python_op in operator_rewrites:
        s = s.replace(logic_op, python_op)
    # Turn each token into an Expr constructor call, e.g.  P  ->  Expr("P")
    source = re.sub(r'([a-zA-Z0-9_.]+)', r'Expr("\1")', s)
    # NOTE(review): this evaluates the rewritten text with eval(); only pass
    # strings from trusted sources.
    return eval(source, {'Expr':Expr})
def is_symbol(s):
    """Return True when `s` is a string whose first character is alphabetic."""
    if not isinstance(s, str):
        return False
    # Slicing (rather than indexing) keeps the empty string safe: ''.isalpha()
    # is simply False.
    first = s[:1]
    return first.isalpha()
def is_var_symbol(s):
    """Return True when `s` is a symbol that starts with a lowercase letter.

    Lowercase-initial symbols are the ones treated as logic variables.
    """
    if not is_symbol(s):
        return False
    initial = s[0]
    return initial.islower()
def is_prop_symbol(s):
    """Return True when `s` is a proposition symbol.

    That is: a symbol starting with an uppercase letter, other than the two
    reserved names 'TRUE' and 'FALSE'.
    """
    if not is_symbol(s):
        return False
    if not s[0].isupper():
        return False
    return not (s == 'TRUE' or s == 'FALSE')
def variables(s):
    """Return the set of all variables occurring anywhere in expression `s`.

    Any node that is not itself a variable is expected to expose an `args`
    attribute whose elements are visited in turn.
    """
    found = set()
    pending = [s]
    # Explicit work-list traversal of the expression tree.
    while pending:
        node = pending.pop()
        if is_variable(node):
            found.add(node)
            continue
        pending.extend(node.args)
    return found
def unify(x, y, s):
    """Unify `x` with `y` under substitution `s`.

    `x` and `y` may be variables, Exprs, strings (operators) or sequences.
    `s` is a dict of variable bindings built so far, or None to signal that
    an earlier step already failed.  Returns a substitution (dict) that makes
    `x` and `y` equal, or None when they cannot be unified.
    """
    # Failure propagates straight through nested calls.
    if s is None:
        return None
    if x == y:
        return s
    if is_variable(x):
        return unify_var(x, y, s)
    if is_variable(y):
        return unify_var(y, x, s)
    if isinstance(x, Expr) and isinstance(y, Expr):
        # Operators must agree first; then the argument lists are unified
        # under whatever that produced.
        x_args, y_args = x.args, y.args
        after_ops = unify(x.op, y.op, s)
        return unify(x_args, y_args, after_ops)
    # Strings that were not equal above (e.g. differing operators) can never
    # unify; this check comes before the sequence case below.
    if isinstance(x, str) or isinstance(y, str):
        return None
    # NOTE(review): `issequence` comes from `utils`; presumably true for
    # lists/tuples -- confirm against its definition.
    if not (issequence(x) and issequence(y) and len(x) == len(y)):
        return None
    if not x:
        return s
    # Unify element-wise: first elements, then the remaining tails.
    x_rest, y_rest = x[1:], y[1:]
    after_head = unify(x[0], y[0], s)
    return unify(x_rest, y_rest, after_head)
def is_variable(x):
    """Return True when `x` is a variable: an argument-less Expr whose
    operator is a lowercase-initial symbol."""
    if not isinstance(x, Expr):
        return False
    if x.args:
        # Anything carrying arguments is a compound term, not a variable.
        return False
    return is_var_symbol(x.op)
def unify_var(var, x, s):
    """Unify variable `var` with `x` under substitution `s`.

    If `var` is already bound, its binding is unified with `x`.  Otherwise
    the binding var -> x is added to a copy of `s`, unless `var` occurs
    inside `x`, in which case unification fails and None is returned.
    """
    if var in s:
        bound_value = s[var]
        return unify(bound_value, x, s)
    if occur_check(var, x, s):
        return None
    return extend(s, var, x)
def occur_check(var, x, s):
    """Return true when `var` occurs anywhere inside `x`, following the
    bindings recorded in substitution `s`."""
    if var == x:
        return True
    if is_variable(x) and x in s:
        # x is a bound variable: look inside whatever it is bound to.
        return occur_check(var, s[x], s)
    if isinstance(x, Expr):
        # Check the operator first, then the argument list.
        found_in_op = occur_check(var, x.op, s)
        return found_in_op or occur_check(var, x.args, s)
    if isinstance(x, (list, tuple)):
        def occurs_in(element):
            return occur_check(var, element, s)
        # NOTE(review): `some` comes from `utils`; presumably it reports
        # whether the predicate holds for any element -- confirm.
        return some(occurs_in, x)
    return False
def extend(s, var, val):
    """Return a copy of substitution `s` in which `var` is bound to `val`.

    The input `s` is left unmodified; the copy is made with ``s.copy()``.
    """
    extended = s.copy()
    extended[var] = val
    return extended
def subst(s, x):
    """Apply substitution `s` to `x` and return the result.

    `s` is a dict mapping variables (Expr) to their replacements.  `x` may be:
      * a list or tuple -- substituted element-wise, preserving the container
        type;
      * a non-Expr value -- returned unchanged;
      * a variable -- replaced by its binding in `s`, or returned as-is when
        unbound;
      * any other Expr -- rebuilt with the substitution applied to each
        argument.
    """
    if isinstance(x, list):
        return [subst(s, xi) for xi in x]
    if isinstance(x, tuple):
        return tuple(subst(s, xi) for xi in x)
    if not isinstance(x, Expr):
        return x
    # Use the same notion of "variable" as unify()/variables()/occur_check():
    # an argument-less Expr with a lowercase-initial symbol.  Testing only the
    # operator name would treat a compound term such as f(x) as a variable
    # and skip substituting inside its arguments.
    if is_variable(x):
        return s.get(x, x)
    return Expr(x.op, *[subst(s, arg) for arg in x.args])