-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathobdd_compiler.py
184 lines (156 loc) · 6.24 KB
/
obdd_compiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import copy
import itertools
class BDD:
def __init__(self, var, low, high):
self.var = var
self.low = low
self.high = high
self.explore_id = 0
def is_sink(self):
return self.low is None and self.high is None
def _print_info(self, current_id, rank, output_file=None):
if self.explore_id > 0:
return current_id, rank
if self.is_sink():
if output_file is not None:
out = open(output_file, 'a')
if self.var == True:
out.write(' {} [label="True", color=green, shape=square];\n'.format(current_id+1))
elif self.var == False:
out.write(' {} [label="False", color=red, shape=square];\n'.format(current_id+1))
out.close()
else:
print('{}-SINK : {}'.format(current_id+1,self.var))
else:
left_current_id, rank = self.low._print_info(current_id, rank, output_file)
current_id, rank = self.high._print_info(left_current_id, rank, output_file)
if output_file is not None:
out = open(output_file, 'a')
out.write(' {} [label="{}"];\n'.format(current_id+1, self.var))
out.write(' {} -> {} [style=dotted];\n'.format(current_id+1, self.low.explore_id))
out.write(' {} -> {};\n'.format(current_id+1, self.high.explore_id))
out.close()
else:
print('{}-Var: {}'.format(current_id+1,self.var))
self.explore_id = current_id+1
if self.is_sink():
rank[0].append(self.explore_id)
else:
rank[self.var].append(self.explore_id)
return current_id+1, rank
def print_info(self, nvars, output_file=None):
rank = [[] for i in range(nvars+1)]
_, rank = copy.deepcopy(self)._print_info(0, rank, output_file)
# for i in range(len(rank)):
# print(i, ': ', rank[i])
return rank
class BDD_Compiler:
def __init__(self, n_vars, clausal_form):
self.clausal_form = clausal_form
self.n_vars = n_vars
self.unique = {}
self.cache = {}
for i in range(n_vars+1):
self.cache[i] = {}
self.cutset_cache = self._generate_cutset_cache()
self.separator_cache = self._generate_separator_cache()
self.F_SINK = BDD(False, None, None)
self.T_SINK = BDD(True, None, None)
def bcp(self, formula, literal):
modified = []
for clause in formula:
if literal in clause:
modified.append([])
elif -literal in clause:
c = [x for x in clause if x != -literal]
if len(c) == 0:
return -1
modified.append(c)
else:
modified.append(clause)
return modified
'''
Functions used for computing cutset key and cache
'''
def _compute_cutset(self, clausal_form, var):
cutset = []
for i, clause in enumerate(clausal_form):
if len(clause) == 0:
continue
atoms = [abs(l) for l in clause]
if min(atoms) <= var and var < max(atoms):
cutset.append(i)
return cutset
def _generate_cutset_cache(self):
cutset_cache = []
print('CUTSET CACHE:')
for i in range(self.n_vars):
cutset_i = self._compute_cutset(self.clausal_form, i+1)
cutset_cache.append(cutset_i)
print('-cutset {} : {}'.format(i+1, cutset_i))
return cutset_cache
def compute_cutset_key(self, clausal_form, var):
cutset_var = self.cutset_cache[var-1]
cutset_key = 0
for i, c in enumerate(cutset_var):
if len(clausal_form[c]) == 0:
cutset_key += 2**i
return cutset_key
'''
Functions used for compute separator key and cache
'''
def _compute_separator(self, clausal_form, var):
sep = []
for ci in self.cutset_cache[var-1]:
sep += self.clausal_form[ci]
sep = [abs(l) for l in sep if abs(l) <= var]
sep = list(set(sep))
return sep
def _generate_separator_cache(self):
sep_cache = []
print('SEPARATOR CACHE:')
for i in range(self.n_vars):
sep_i = self._compute_separator(self.clausal_form, i+1)
sep_cache.append(sep_i)
print('-sep {} : {}'.format(i+1, sep_i))
return sep_cache
def compute_separator_key(self, clausal_form, var):
sep_var = self.separator_cache[var-1]
sep_key = 0
for v in sep_var:
sep_key += 2**v
return sep_key
'''
Core functions
'''
def get_nodes(self, var, low, high):
if low == high:
return low
if (var, low, high) in self.unique: # and low == self.unique[i].low and high == self.unique[i].high:
# print('Unique node {} found!'.format(var))
return self.unique[(var, low, high)]
result = BDD(var, low, high)
self.unique[(var, low, high)] = result
return result
def cnf2obdd(self, clausal_form, i, key_type='cutset'):
assert key_type == 'cutset' or key_type == 'separator'
if clausal_form == -1:
return self.F_SINK
elif len(list(itertools.chain(*clausal_form))) == 0:
return self.T_SINK
assert i <= self.n_vars+1
if key_type == 'cutset':
key = self.compute_cutset_key(clausal_form, i-1)
elif key_type == 'separator':
key = self.compute_separator_key(clausal_form, i-1)
if key in self.cache[i-1]:
print('This node is already in cache {} with key {}'.format(i-1,key))
return self.cache[i-1][key]
low = self.cnf2obdd(self.bcp(clausal_form, -i), i+1)
high = self.cnf2obdd(self.bcp(clausal_form, i), i+1)
result = self.get_nodes(i, low, high)
self.cache[i-1][key] = result
# print('This node is stored in cache {} with key {}'.format(i-1, key))
return result
def compile(self, key_type='cutset'):
return self.cnf2obdd(self.clausal_form, 1, key_type)