1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- from z3 import *
def all_smt(s, initial_terms):
    """Yield one model of solver *s* per distinct assignment to *initial_terms*.

    After a model ``m`` is found, the rest of the search space is split into
    disjoint regions: region ``i`` requires ``terms[i]`` to differ from its
    value in ``m`` while ``terms[0..i-1]`` keep their values from ``m``.
    Each region is explored recursively inside its own ``s.push()`` /
    ``s.pop()`` scope. Every pushed scope is popped again when the generator
    is exhausted *or* closed early (e.g. by ``itertools.islice`` or a
    ``break``), so the caller's solver is left as it was found.

    A branch whose ``s.check()`` is not ``sat`` (unsat or unknown) simply
    yields nothing.

    Args:
        s: A z3 ``Solver``; only ``check``, ``model``, ``add``, ``push`` and
            ``pop`` are used.
        initial_terms: Iterable of terms whose value combinations
            distinguish one model from another.

    Yields:
        The model returned by ``s.model()`` for each distinct assignment.
    """
    def block_term(s, m, t):
        # Exclude the value t takes in m. model_completion=True makes eval
        # return a concrete value even if m leaves t unconstrained.
        s.add(t != m.eval(t, model_completion=True))

    def fix_term(s, m, t):
        # Pin t to the value it takes in m.
        s.add(t == m.eval(t, model_completion=True))

    def all_smt_rec(terms):
        if s.check() != sat:
            return
        m = s.model()
        yield m
        for i in range(len(terms)):
            s.push()
            try:
                block_term(s, m, terms[i])
                for j in range(i):
                    fix_term(s, m, terms[j])
                # terms[:i] are pinned in this region, so only terms[i:] can
                # still vary and need to be considered further down.
                yield from all_smt_rec(terms[i:])
            finally:
                # Runs on normal completion and on GeneratorExit (consumer
                # stopped early), so no scopes or constraints leak into s.
                s.pop()

    yield from all_smt_rec(list(initial_terms))
def main(max_n=4):
    """Print, for n = 2 .. max_n, how many models ``all_smt`` enumerates.

    For each n, n "left" and n "right" integer variables are created, every
    variable is constrained to a single decimal digit (0-9), and the left
    digits are required to have the same sum as the right digits. The
    constraints for each n live in their own solver scope, so they do not
    carry over to the next n.

    Model counts grow quickly with n, so the larger cases can take a long
    time.

    Args:
        max_n: Largest n to try (inclusive). The default covers n = 2, 3, 4.
    """
    s = Solver()
    for n in range(2, max_n + 1):
        s.push()
        try:
            left = [Int(str(i)) for i in range(n)]
            right = [Int(str(i + n)) for i in range(n)]
            # Constrain every variable to be a single decimal digit.
            for digit in left + right:
                s.add(digit >= 0, digit < 10)
            # Built-in sum builds the z3 expression 0 + d0 + d1 + ...
            s.add(sum(left) == sum(right))
            # Count lazily: building a list of every model just to take its
            # length would keep all of them in memory at once.
            count = sum(1 for _ in all_smt(s, left + right))
            print(f'for n={n} there are {count} models')
        finally:
            s.pop()
if __name__ == "__main__":
    # Only run the model-counting demo when executed as a script, not on import.
    main()
|