grammar: add include directive for plp's
RenatoGeh committed Aug 26, 2024
1 parent 12f2b34 commit 1352f74
Showing 2 changed files with 33 additions and 7 deletions.
5 changes: 4 additions & 1 deletion pasp/grammar.lark
@@ -158,6 +158,9 @@ SEMANTICS_OPT_PROB: "maxent" | "credal"
_semantics_exp: ((SEMANTICS_OPT_LOGIC ("," SEMANTICS_OPT_PROB)?) | (SEMANTICS_OPT_PROB ("," SEMANTICS_OPT_LOGIC)?))
semantics: "#semantics" (("(" _semantics_exp ")") | (_semantics_exp)) "."

+// Include directive.
+include: "#include" "\"" LOCAL_DATA "\"" ("," "\"" LOCAL_DATA "\"")* "."

// Inference directive.
exact_inf: "exact"
aseo_inf: "aseo" "," "nmodels" "=" ID
@@ -178,7 +181,7 @@ query: "#query" (("(" _interp_exp ")") | ( _interp_exp )) "."?
// Constant definition.
constdef: "#const" WORD "=" ID "."

-plp: (constdef | _fact | _rule | _ad | _neural | data | python | constraint | query | learn | semantics | _aggr | inference)*
+plp: (constdef | _fact | _rule | _ad | _neural | data | python | constraint | query | learn | semantics | _aggr | inference | include)*

COMMENT: "%" /[^\n]*/ NEWLINE
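
As a usage sketch (not part of this commit), the new rule accepts one or more comma-separated, double-quoted paths terminated by a period; the file names below are hypothetical, and the #semantics line just reuses a directive already defined in this grammar:

    % main.plp (hypothetical example):
    #include "facts.plp", "rules.plp".
    #semantics maxent.

During parsing, the statements of the included files are spliced into the including program's parse tree (see _flatten_includes in grammar.py below), so the directive behaves as if the included rules had been written inline.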

35 changes: 29 additions & 6 deletions pasp/grammar.py
@@ -1,4 +1,4 @@
-import pathlib, enum, math, collections.abc
+import pathlib, enum, math, collections.abc, os
import lark, lark.reconstruct, clingo, numpy
from numpy import ascontiguousarray as contiguous
from .program import ProbFact, Query, VarQuery, ProbRule, Program, CredalFact, unique_fact, \
@@ -49,6 +49,7 @@ class PreparsingTransformer(lark.Transformer):
def __init__(self):
super().__init__()
self.consts = {}
+self.includes = set()
def __default__(self, _, __, ___): return lark.visitors.Discard
def SEMANTICS_OPT_LOGIC(self, O): return str(O)
def SEMANTICS_OPT_PROB(self, _): return lark.visitors.Discard
@@ -60,7 +61,10 @@ def constdef(self, C):
return lark.visitors.Discard
"Verify which logic semantic should be used and record constant definitions."
def plp(self, S):
-return S[0] if len(S) > 0 else None, self.consts
+return S[0] if len(S) > 0 else None, self.consts, self.includes
+def include(self, P):
+self.includes.update(map(lambda x: os.path.abspath(str(x)), P))
+return lark.visitors.Discard

class StableTransformer(lark.Transformer):
class Pack(tuple):
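
A side note on the include handler added above: paths are normalized with os.path.abspath before being stored, so different relative spellings of the same file collapse into a single entry and the flattening loop further below reads each included program only once. A minimal standalone illustration (file names hypothetical, not pasp code):

    import os

    # Two relative spellings of the same (hypothetical) file...
    paths = ["lib/rules.plp", "./lib/../lib/rules.plp"]
    # ...normalize to the same absolute path, so the set keeps a single entry.
    includes = set(map(os.path.abspath, paths))
    assert len(includes) == 1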
@@ -71,10 +75,10 @@ def __new__(cls, tp: str, r: str = None, v = None, sc: dict = {}):
def __str__(self): return self[1]
def __repr__(self): return f"<{self[0]}: {self.__str__()}>"

-def __init__(self, _, consts: dict = {}):
+def __init__(self, _, consts: dict = {}, scope: dict = None):
super().__init__()
self.sem = Semantics.STABLE
-self.torch_scope = {}
+self.torch_scope = {} if scope is None else scope
self.n_prules = 0
self.consts = consts
self.varquery_id = 0
@@ -474,6 +478,9 @@ def learn(self, L):
data = self.torch_scope[L[0][1]] if L[0][0] == "PY_FUNC" else StableTransformer.path2obs(L[0][1])
return self.pack("directive", "", ("learn", data, A))

+# Include directive.
+def include(self, F): return lark.visitors.Discard

def exact_inf(self, I): return ("inference", "exact", tuple())
def aseo_inf(self, I): return ("inference", "aseo", (I[0][2],))
def inference(self, I): return self.pack("directive", "", I[0])
@@ -639,13 +646,29 @@ def transform(self, tree):
self.stable_p = StableTransformer(self.sem).transform(tree)
return super().transform(tree)

+def _flatten_includes(*files: str, G: lark.Lark = None, from_str: bool = False) -> tuple:
+transf = PreparsingTransformer()
+T = read(*files, G=G, from_str=from_str)
+sem, consts, includes = transf.transform(T)
+_files = set() if from_str else set(map(os.path.abspath, set(files)))
+to_parse = includes.difference(_files)
+_files.update(includes)
+while len(to_parse) > 0:
+_T = read(*to_parse, G=G, from_str=False)
+_sem, _consts, includes = transf.transform(_T)
+if _sem is not None: sem = _sem
+consts.update(_consts)
+T.children.extend(u for u in _T.children if u not in T.children)
+to_parse = includes.difference(_files)
+_files.update(includes)
+return sem, consts, T

def parse(*files: str, G: lark.Lark = None, from_str: bool = False, semantics: str = "stable") -> Program:
"""Either parses `streams` as blocks of text containing the PLP when `from_str = True`, or
interprets `streams` as filenames to be read and parsed into a `Program`."""
if semantics not in parse.trans_map:
raise ValueError("semantics not supported (must either be 'stable', 'partial' or 'lstable')!")
-T = read(*files, G = G, from_str = from_str)
-sem, consts = PreparsingTransformer().transform(T)
+sem, consts, T = _flatten_includes(*files, G=G, from_str=from_str)
if sem is not None: semantics = sem
return parse.trans_map[semantics](semantics, consts).transform(T)
parse.trans_map = {}
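
The new _flatten_includes above is a worklist fixpoint: starting from the root files, it keeps reading includes that have not been seen yet and splices their statements into the main tree, so chains (and even cycles) of #include's terminate with each file read once. A self-contained sketch of that loop (not pasp code), using a toy include graph in place of actual .plp files and hypothetical file names:

    def flatten(roots, include_graph):
      # include_graph maps a file name to the set of files it #include's;
      # "reading" a file here just records its name, standing in for read/transform.
      seen = set(roots)
      statements = []                # stand-in for T.children
      to_parse = set()
      for f in roots:
        statements.append(f)
        to_parse |= include_graph.get(f, set())
      to_parse -= seen
      seen |= to_parse
      while to_parse:
        new = set()
        for f in to_parse:
          statements.append(f)
          new |= include_graph.get(f, set())
        to_parse = new - seen        # only files we have not read yet
        seen |= new
      return statements

    # Even a cyclic graph terminates: a.plp includes b.plp, which includes a.plp back.
    print(flatten(["a.plp"], {"a.plp": {"b.plp"}, "b.plp": {"a.plp"}}))  # ['a.plp', 'b.plp']

From the caller's point of view parse is unchanged: it now obtains sem, consts and the merged tree T from _flatten_includes instead of calling read and PreparsingTransformer directly, and hands T to the per-semantics transformer as before.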
