dpda.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. ###############################################################
  4. # Learning dPDA from examples by constraint solving via SMT #
  5. ###############################################################
  6. # python3 dpda.py
  7. # python3 -m unittest dpda.py
  8. import z3
  9. import base64
  10. import random
  11. import itertools
  12. ###############
  13. # Utilities #
  14. ###############
  15. def unescape(bs):
  16. try:
  17. result = bs.decode('unicode_escape').encode('latin-1')
  18. return result
  19. except UnicodeDecodeError as e:
  20. print ("error", e)
  21. return None
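# Python lists are unhashable; List hashes by value (as a tuple) so stack
# configurations can be used as dict keys and set members.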
  23. class List(list):
  24. def __hash__(self):
  25. return hash(tuple(self))
  26. # Missing from Z3py:
  27. def Sequence(name, ctx=None):
  28. """Return a sequence constant named `name`. If `ctx=None`, then the global context is used.
  29. >>> x = Sequence('x')
  30. """
  31. ctx = z3.get_ctx(ctx)
  32. int_sort = z3.IntSort(ctx)
  33. return z3.SeqRef(
  34. z3.Z3_mk_const(ctx.ref(),
  35. z3.to_symbol(name, ctx),
  36. z3.SeqSortRef(z3.Z3_mk_seq_sort(int_sort.ctx_ref(), int_sort.ast)).ast),
  37. ctx)
  38. def shortlex(alphabet, prefix=b''):
  39. """
  40. Enumerate all strings over alphabet in shortlex order
  41. """
  42. assert(len(alphabet))
  43. iters = []
  44. for a in alphabet:
  45. iters += [shortlex(alphabet, prefix=prefix+bytes([a]))]
  46. yield prefix
  47. while iters != []:
  48. for x in iters:
  49. try:
  50. yield next(x)
  51. except StopIteration:
  52. iters.remove(x)
  53. class InfoTrie:
  54. """
  55. A word trie with information optionally attached to nodes
  56. (self.info == None indicates that the node is an inner node)
  57. @note only add is supported, not remove
  58. """
  59. def __init__(self):
  60. self.dict = {}
  61. self.info = None
  62. def add(self, s, info=True, accum=b''):
  63. """
  64. Adds a word, with info.
  65. Modifies the structure.
  66. @return the added suffix.
  67. """
  68. if len(s):
  69. i = InfoTrie()
  70. t = self.dict.setdefault(s[0], i)
  71. if t == i:
  72. accum += bytes([s[0]])
  73. return t.add(s[1:], info, accum)
  74. else:
  75. self.info = info
  76. return accum
  77. def get(self, s):
  78. """
  79. @return None if no info or no path
  80. """
  81. if len(s):
  82. t = self.dict.get(s[0])
  83. if t != None:
  84. return t.get(s[1:])
  85. return None
  86. return self.info
  87. def iter(self, prefix=b''):
  88. """
  89. @return iterate over the trie, depth-first, outputting all nodes.
  90. @param prefix optional
  91. """
  92. yield (prefix, self, self.info, self.has_children())
  93. for (k,v) in self.dict.items():
  94. yield from v.iter(b"%s%s" % (prefix, bytes([k])))
  95. def has_children(self):
  96. """
  97. Assuming the invariant that only inner nodes have info == None,
  98. this means the node has descendants with info != None
  99. """
  100. return len(self.dict) > 0
  101. def first_word_not_in(self, alphabet):
  102. '''
  103. (by shortlex)
  104. '''
  105. return next(itertools.filterfalse(lambda w: self.get(w) != None, shortlex(alphabet)))
  106. class Automaton:
  107. """
  108. Class for a specific type of real-time deterministic pushdown automata (dPDA)
  109. Acceptance criterion: stack empty or contaning just a "final state" symbol
  110. """
  111. def __init__(self):
  112. self.QF = set()
  113. self.D = dict()
  114. self.productive = None
  115. def construct_from_z3_model(self, m, d, Qf, alphabet):
  116. to_check = [0]
  117. checked = set(to_check)
  118. print ("Extracting tables")
  119. self.D = dict()
  120. self.QF = set()
  121. self.productive = None
  122. print ("m[d] = %s" % m[d])
  123. print ("m[qf] = %s" % m[Qf])
  124. symbols = set([0])
  125. while len(to_check):
  126. current = to_check.pop()
  127. conf = z3.Unit(z3.IntVal(current))
  128. for a in alphabet: # range(0, 256):
  129. y = m.evaluate(
  130. d(
  131. # z3.SubSeq(conf, z3.Length(conf)-1, 1),
  132. conf,
  133. z3.StringVal(bytes([a]))),
  134. model_completion = True)
  135. def extract_seq_as_list(y):
  136. result = List()
  137. for c in y.children():
  138. if isinstance(c, z3.SeqRef):
  139. result += extract_seq_as_list(c)
  140. else:
  141. result += List([c.as_long()])
  142. return result
  143. rhs = extract_seq_as_list(y)
  144. for symbol in rhs:
  145. symbols.add(symbol)
  146. Dq = self.D.setdefault(current, dict())
  147. Dq[a] = rhs
  148. for i in rhs:
  149. if not i in checked:
  150. checked.add(i)
  151. to_check += [i]
  152. if m.evaluate(Qf(z3.Empty(z3.SeqSort(z3.IntSort())))):
  153. self.QF.add(List([]))
  154. print("(stack/q) symbols encountered: %s" % symbols)
  155. for symbol in symbols:
  156. conf = z3.Unit(z3.IntVal(symbol))
  157. f = m.evaluate(Qf(conf))
  158. if f:
  159. self.QF.add(List([symbol]))
  160. self.symbols = symbols
  161. def productivity(self):
  162. if self.productive == None:
  163. self.compute_productivity()
  164. return self.productive
  165. def compute_productivity(self):
  166. """
  167. Determine for each stack/state symbol whether it leads to QF, or to decreasing length
  168. """
  169. productive = {'down':set(), 'toqf':set(), 'tononqf':set()}
  170. S = set()
  171. for k in self.D.keys():
  172. S.add(k)
  173. down_immediate = set()
  174. for s in S:
  175. if any(map(lambda x: len(x)<1, self.D.get(s).values())):
  176. down_immediate.add(s)
  177. toqf_immediate = set()
  178. for s in S:
  179. if List([s]) in self.QF:
  180. toqf_immediate.add(s)
  181. tononqf_immediate = set()
  182. for s in S:
  183. if List([s]) not in self.QF:
  184. tononqf_immediate.add(s)
  185. import copy
  186. eqs = copy.deepcopy(self.D)
  187. changed = True # iterate until no more changes
  188. down = down_immediate.copy()
  189. toqf = toqf_immediate.copy()
  190. tononqf = tononqf_immediate.copy()
  191. while changed:
  192. changed = False
  193. for s in eqs.keys():
  194. for (k,v) in eqs[s].items():
  195. new = List(filter(lambda x: x not in down, eqs[s][k]))
  196. if eqs[s][k] != new:
  197. changed = True
  198. eqs[s][k] = new
  199. if new == []:
  200. down.add(s)
  201. if len(v) == 1 and v[0] in toqf and s not in toqf:
  202. toqf.add(s)
  203. changed = True
  204. if len(v) == 1 and v[0] in tononqf and s not in tononqf:
  205. tononqf.add(s)
  206. changed = True
  207. if len(v) > 1 and s not in tononqf:
  208. tononqf.add(s)
  209. changed = True
  210. productive['down'] = down
  211. productive['toqf'] = toqf
  212. productive['tononqf'] = tononqf
  213. self.productive = productive # cache
  214. return productive
  215. def enumerate_words(self, alphabet, configurations_prefixes={List([0]) : {b''}}, mode = 'words'):
  216. """
  217. Short-Lex enumeration of L (or \Sigma^\ast - L),
  218. @param trie exclude these words
  219. """
  220. successors = {}
  221. productive = self.productivity()
  222. for (configuration, prefixes) in configurations_prefixes.items():
  223. # final and empty
  224. if len(configuration) == 0:
  225. yield from prefixes
  226. else:
  227. # final without being empty - the show can go on
  228. if len(configuration) == 1 and configuration in self.QF:
  229. yield from prefixes
  230. # otherwise, it goes on anyway but the word is not yielded.
  231. # detect empty languages
  232. if not configuration[-1] in productive["down"]:
  233. if not (len(configuration)==1 and configuration[-1] in productive["toqf"]):
  234. continue
  235. # compute next layer
  236. dmap = self.D.get(configuration[-1])
  237. if dmap != None:
  238. for (a, stack_suffix) in dmap.items():
  239. newstack = List(configuration[:-1] + stack_suffix)
  240. for prefix in prefixes:
  241. successors.setdefault(newstack, set()).add(b"%s%s" % (prefix, bytes([a])))
  242. if len(successors):
  243. yield from self.enumerate_words(alphabet, successors, mode)
  244. class Automaker:
  245. """
  246. attempt to guess "realtime dPDA", (no epsilon moves)
  247. no explicit set Q of control states, just a stack.
  248. accept when: A) stack empty or B) stack contains just one symbol from QF
  249. """
  250. def __init__(self, t, limitS, limitL):
  251. self.t = t
  252. self.s = z3.Solver()
  253. self.i = 0 # serial number generator
  254. self.alphabet = range(0, 256)
  255. # arbitrarily limit number of state/stack symbols
  256. self.limitS = limitS
  257. # arbitrarily limit length of added suffix
  258. self.limitL = limitL
  259. def set_alphabet(self, alphabet):
  260. self.alphabet = alphabet
  261. def addPath(self, prefix, rest, final):
  262. """
  263. """
  264. for l in range(0, len(rest)+1):
  265. w = prefix + rest[:l]
  266. sv = b"stack" + base64.b16encode(w)
  267. self.stackvars[w] = Sequence(sv)
  268. w = prefix + rest
  269. self.s_add_finalstate(w, final)
  270. if len(w):
  271. self.s_add_transition_to(w)
  272. def setupProblem(self):
  273. self.stackvars = {}
  274. self.Qf = z3.Function("final", z3.SeqSort(z3.IntSort()), z3.BoolSort())
  275. self.d = z3.Function("delta", z3.SeqSort(z3.IntSort()), z3.StringSort(), z3.SeqSort(z3.IntSort()))
  276. for (w, st, final, has_children) in self.t.iter():
  277. sv = b"stack" + base64.b16encode(w)
  278. self.stackvars[w] = Sequence(sv)
  279. self.s_add_finalstate(w, final)
  280. if final and has_children:
  281. self.s_add_nonemptystate(w)
  282. if len(w):
  283. self.s_add_transition_to(w)
  284. self.s.add(self.stackvars[b''] == z3.Unit(z3.IntVal(0)))
  285. # most useful convention:
  286. # accept by drained stack, but don't read any more and fail then
  287. self.s.add(self.Qf(z3.Empty(z3.SeqSort(z3.IntSort()))) == True)
  288. def s_isFinalWord(self, w):
  289. z3var = self.stackvars[w]
  290. return z3.And(z3.Length(z3var)<=1, self.Qf(z3var))
  291. def s_isFinalConfiguration(self, c):
  292. z3val = z3.StringVal(c)
  293. return z3.And(z3.Length(z3val)<=1, self.Qf(z3val))
  294. def s_add_finalstate(self, w, final=True):
  295. """
  296. final=True: the state reached by this word is final ($w \in L$)
  297. final=True: the state reached by this word is not final ($w \neg\in L$)
  298. """
  299. isFinal = self.s_isFinalWord(w)
  300. if final:
  301. self.s.add(isFinal)
  302. elif final == False:
  303. self.s.add(z3.Not(isFinal))
  304. def s_add_nonemptystate(self, w):
  305. z3var = self.stackvars[w]
  306. self.s.add(z3.Not(z3.Length(z3var)==0))
  307. def s_add_transition_to(self, w):
  308. """
  309. @pre len(w)>0
  310. """
  311. i = self.gennum()
  312. pre = Sequence("pre%d" % i)
  313. a = Sequence("a%d" % i)
  314. self.s.add(z3.Length(a) == 1)
  315. self.s.add(z3.Concat(pre, a) == self.stackvars[w[:-1]])
  316. x = self.d(a, z3.StringVal(w[-1:]))
  317. self.s.add(z3.Length(x) <= self.limitL)
  318. for i in range(self.limitL):
  319. self.s.add(
  320. z3.Implies(
  321. z3.Length(x) > i,
  322. z3.And(
  323. x[i] < z3.IntVal(self.limitS),
  324. x[i] >= 0)
  325. ))
  326. self.s.add(z3.Concat(pre, x) == self.stackvars[w])
  327. def gennum(self):
  328. i = self.i
  329. self.i += 1
  330. return i
  331. def askZ3(self):
  332. r = self.s.check()
  333. print(r)
  334. if z3.sat == r:
  335. self.m = self.s.model()
  336. self.extract_tables()
  337. print(self.m)
  338. return True
  339. else:
  340. return False
  341. def extract_tables(self):
  342. """
  343. Extract the automaton's information from the Z3 model
  344. """
  345. self.automaton = Automaton()
  346. self.automaton.construct_from_z3_model(self.m, self.d, self.Qf, self.alphabet)
  347. return self.automaton
  348. def enumerate_words_t(self, alphabet, prefix=b'', configuration = List([0]), mode = 'words'):
  349. yield from self.automaton.enumerate_words(alphabet, {configuration : {prefix}}, mode)
  350. ################
  351. # Unit tests #
  352. ################
  353. import unittest
  354. class TrieTest(unittest.TestCase):
  355. def test_trie(self):
  356. t = InfoTrie()
  357. self.assertEqual(t.info, None)
  358. t.add(b'')
  359. self.assertEqual(t.info, True)
  360. t.add(b'add')
  361. self.assertEqual(t.info, True)
  362. self.assertEqual(t.get(b''), True)
  363. self.assertEqual(t.get(b'a'), None)
  364. self.assertEqual(t.get(b'add'), True)
  365. t.add(b'a')
  366. self.assertEqual(t.get(b'a'), True)
  367. t.add(b'ah')
  368. self.assertEqual(t.get(b'a'), True)
  369. self.assertEqual(t.get(b'ah'), True)
  370. self.assertEqual(t.get(b'ad'), None)
  371. self.assertEqual(len([_ for (_,_,info,_) in t.iter() if info]), 4)
  372. self.assertEqual(len([_ for (_,_,_,_) in t.iter()]), 5)
  373. self.assertEqual(len([_ for (_,_,info,_) in t.iter(b'a') if info]), 4)
  374. self.assertEqual(len([_ for (_,_,_,_) in t.iter(b'a')]), 5)
  375. ##################
  376. # Main program #
  377. ##################
  378. import pprint
  379. import argparse
  380. import fileinput
  381. if __name__=='__main__':
  382. parser = argparse.ArgumentParser( description='realtime-deterministic-PDA constructor', )
  383. parser.add_argument('-m', action="store", dest="mode", type=str, default="simple")
  384. parser.add_argument('-i', action="append", dest="files", type=str, default=[])
  385. parser.add_argument('--version', action='version', version='%(prog)s 0.0.1')
  386. args = parser.parse_args()
  387. files = args.files
  388. mode = args.mode
  389. try:
  390. t = InfoTrie()
  391. # arbitrarily limit number of state/stack symbols
  392. limitS = 5
  393. # arbitrarily limit length of suffix added in any stack operation
  394. limitL = 2
  395. # limit search to the observed input alphabet
  396. inputalph = set()
  397. if mode == 'simple':
  398. for l in fileinput.input(files, mode='rb'):
  399. t.add(l[:-1], True)
  400. inputalph = inputalph.union(l[:-1])
  401. print(b"; ".join([w for (w, st, info, has_children) in t.iter()]))
  402. print("; ".join(['(%s,%s)' % (w, pprint.saferepr(i)) for (w, st, i, has_children) in t.iter()]))
  403. a = Automaker(t, limitS, limitL)
  404. a.setupProblem()
  405. a.set_alphabet(inputalph)
  406. a.askZ3()
  407. elif mode == 'advanced':
  408. a = Automaker(t, limitS, limitL)
  409. a.setupProblem()
  410. question = None
  411. prompt = "dPDA wizard :> "
  412. if len(files) < 1:
  413. print (prompt, end='', flush=True)
  414. for l in fileinput.input(files, mode='rb'):
  415. inp = None
  416. # assert a positive example:
  417. if l[0] == b'p'[0] or l[0] == b'y'[0]:
  418. inp = unescape(l[2:-1])
  419. if question != None and not inp:
  420. inp = question
  421. question = None
  422. pol = True
  423. # assert a negative example:
  424. elif l[0] == b'n'[0] or l[0] == b'n'[0]:
  425. inp = unescape(l[2:-1])
  426. if question != None and not inp:
  427. inp = question
  428. question = None
  429. pol = False
  430. # query
  431. elif l[0] == b'?'[0]:
  432. print ("asking Z3")
  433. sat = a.askZ3()
  434. # ...
  435. # print(a.automaton.D)
  436. # print(a.automaton.QF)
  437. if sat:
  438. checkalph = List(inputalph)
  439. print ("here's some words (alphabet %s)" % checkalph)
  440. try:
  441. en = a.enumerate_words_t(checkalph, b"")
  442. for w in itertools.islice(en, 0, 14):
  443. print (w)
  444. except StopIteration:
  445. print ("oops, language is not infinite")
  446. except RecursionError:
  447. print ("RecursionError (this is a bug)")
  448. question = t.first_word_not_in(checkalph)
  449. # quit
  450. elif l[0] == b'q'[0]:
  451. exit(0)
  452. elif l[0] == b's'[0]:
  453. limitS = int(l[2:])
  454. print ("number of internal symbols set to %d" % limitS)
  455. # process input in case of assertion
  456. if inp != None:
  457. suffix = t.add(inp, pol)
  458. inputalph = inputalph.union(inp)
  459. print ('asserting %s %s' % (pol, inp))
  460. a = Automaker(t, limitS, limitL)
  461. a.setupProblem()
  462. a.set_alphabet(inputalph)
  463. if len(files) < 1:
  464. if question == None:
  465. prompt = "dPDA wizard :> "
  466. else:
  467. prompt = "%s? :>" % repr(question)
  468. print (prompt, end='', flush=True)
  469. except KeyboardInterrupt as e:
  470. pass
  471. raise e