bif2asp.jl
4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
using PythonCall
"""
    is_truthy(x)

Return `false` when the printed form of `x` starts (case-insensitively, after
optional leading whitespace) with one of the falsy tokens `false`, `0`, `-0`,
`nul`, `undef`, `no` or `nan`; return `true` otherwise.

The tokens `nul` and `undef` are truncated so that they also match `null` and
`undefined`. Likewise any text beginning with `no` (e.g. `none`, `not_present`,
but also `normal`) counts as falsy.
"""
function is_truthy(x)
    # Anchor at the start of the string: an unanchored search also fired inside
    # unrelated words and numbers ("10", "1.0", "unknown", "abnormal").
    falsy_re = r"^\s*(?:false|0|-0|nul|undef|no|nan)"i
    return !occursin(falsy_re, string(x))
end
BIFReader = pyimport("pgmpy.readwrite").BIFReader
"""
    cpd_indexes(x)

Enumerate every positive/negative assignment of the parents listed in `x`.

Each element `p` of `x` supplies `p[2]` (the index used for its positive case)
and `p[3]` (the index used for its negative case). The result maps a sign
vector such as `[:pos, :neg]` (one symbol per element of `x`, in the same order)
to the vector of corresponding indices. An empty `x` yields an empty `Dict()`.
"""
function cpd_indexes(x)
    isempty(x) && return Dict()
    # Start from the last parent alone, then prepend each earlier parent,
    # doubling the table each time.
    tail = x[end]
    table = Dict()
    table[[:pos]] = [tail[2]]
    table[[:neg]] = [tail[3]]
    for k in (lastindex(x) - 1):-1:firstindex(x)
        parent = x[k]
        grown = Dict()
        for (signs, positions) in table
            grown[[:pos, signs...]] = [parent[2], positions...]
            grown[[:neg, signs...]] = [parent[3], positions...]
        end
        table = grown
    end
    return table
end
"""
    compile_nodecpd(states, parents, cpd)

Extract from `cpd` the probability of the node's positive state for every
positive/negative assignment of its `parents`.

- `states`: the node's state names. The first one accepted by `is_truthy` is
  taken as the positive state.
- `parents`: `(name, pos_idx, neg_idx)` tuples, in the order of `cpd`'s parent
  axes.
- `cpd`: probabilities indexed as `cpd[node_state, parent_state...]`. This layout
  is assumed by how the index vectors are built below; confirm against the
  caller.

Returns a `Dict` keyed by `[:pos]` for a root node, or by `[:pos, s_1, ..., s_k]`
with each `s_i` being `:pos` or `:neg` for a node with `k` parents.

Throws an `ArgumentError` if no state of the node is truthy, because the
positive outcome could not be chosen.
"""
function compile_nodecpd(states, parents, cpd)
    pos_idx = findfirst(is_truthy, states)
    if pos_idx === nothing
        # Previously this surfaced as an opaque getindex(cpd, nothing, ...) failure.
        throw(ArgumentError(
            "no truthy state among [$(join(states, ", "))]; " *
            "cannot pick the positive outcome"))
    end
    par_idxs = cpd_indexes(parents)
    idxs = if isempty(par_idxs)
        Dict([:pos] => [pos_idx])
    else
        Dict([:pos, signs...] => [pos_idx, ix...] for (signs, ix) in par_idxs)
    end
    return Dict(conf => cpd[idx...] for (conf, idx) in idxs)
end
"""
    get_bif(filename)

Read the BIF file `filename` with pgmpy's `BIFReader` and return a
`Dict{String,Any}` with these keys:

- `"source"`: the file's base name.
- `"nodes"`: one `Dict{String,Any}` per network node, holding `"node"` (its
  name), `"parents"` (parent names, in the CPD's evidence order) and `"cpd"`
  (the result of `compile_nodecpd`).

Throws an `ArgumentError` when a parent variable has no truthy state or no falsy
state, since its positive or negative literal could not be chosen.
"""
function get_bif(filename)
    dag = Dict{String, Any}("source" => basename(filename))
    model = BIFReader(filename).get_model()
    bn_info = Dict{String, Any}[]
    for node in model.nodes()
        node_name = pyconvert(String, node)
        states = pyconvert(Vector{String}, model.states[node])
        prob_tab = model.get_cpds(node)
        # `variables` is [node, evidence...] in the same order as the axes of
        # `values`. Nothing here guarantees the graph's predecessor order matches
        # it, so the parents (and hence the index positions) come from the CPD.
        parents_names = pyconvert(Vector{String}, prob_tab.variables)[2:end]
        parents = Tuple{String, Int, Int}[]
        for parent in parents_names
            parent_states = pyconvert(Vector{String}, model.states[parent])
            pos_idx = findfirst(is_truthy, parent_states)
            neg_idx = findfirst(!is_truthy, parent_states)
            if pos_idx === nothing || neg_idx === nothing
                throw(ArgumentError(
                    "parent $parent of $node_name needs both a truthy and a falsy " *
                    "state, got [$(join(parent_states, ", "))]"))
            end
            push!(parents, (parent, pos_idx, neg_idx))
        end
        prob_cpd = pyconvert(PyArray, prob_tab.values)
        push!(bn_info, Dict{String, Any}(
            "node" => node_name,
            "parents" => parents_names,
            "cpd" => compile_nodecpd(states, parents, prob_cpd),
        ))
    end
    dag["nodes"] = bn_info
    return dag
end
"""
    pretty_print(d::Dict, pre=1)

Print each `key => value` pair of `d` on its own line, indented by `pre` spaces,
using `repr` for keys and values.

Values that are themselves `Dict`s are printed recursively on the following
lines, indented past their key's label. Returns `nothing`.
"""
function pretty_print(d::Dict, pre=1)
    for (key, val) in d
        indent = " "^pre
        if val isa Dict
            label = "$(repr(key)) => "
            println(indent * label)
            pretty_print(val, pre + 1 + length(label))
        else
            println(indent * "$(repr(key)) => $(repr(val))")
        end
    end
    return nothing
end
"""
    make_rule(head, parents, case, prob; with_annotations=false)

Render one ASP rule for `head`. The body has one literal per sign in
`case[2:end]`: `parents[i]` when the i-th sign is `:pos`, `-parents[i]`
otherwise, joined by `", "`.

With `with_annotations`, the rule is prefixed by `prob` (right-padded to 5
characters) and `::`. Otherwise a `% prob(...)` comment line precedes a
`head ; -head :- body.` rule.
"""
function make_rule(head, parents, case, prob; with_annotations=false)
    literals = [
        sign == :pos ? parents[i] : "-$(parents[i])"
        for (i, sign) in enumerate(case[2:end])
    ]
    body_str = join(literals, ", ")
    with_annotations && return "$(rpad(prob, 5)) :: $head :- $body_str."
    return "% prob($head, [$body_str], \"$prob\").\n$head ; -$head :- $body_str."
end
"""
    make_fact(head, prob; with_annotations=false)

Render an unconditional ASP fact for `head` with probability `prob`.

Annotated form: `prob :: head.`, with `prob` right-padded to 5 characters.
Plain form: a `% prob(...)` comment line followed by the disjunction
`head ; -head.`.
"""
function make_fact(head, prob; with_annotations=false)
    if with_annotations
        return "$(rpad(prob, 5)) :: $head."
    end
    return "% prob($head, \"$prob\").\n$head ; -$head."
end
"""
    bif2ast(dag; with_annotations=false, logging=false)

Translate the network description produced by `get_bif` into ASP source text,
one rule or fact per line (joined with `"\\n"`).

- Root nodes (no parents) become a single fact weighted by `cpd[[:pos]]`.
- Every other node yields one rule per parent configuration stored in its
  `"cpd"`, emitted in sorted order starting from the all-`:pos` configuration.

`with_annotations` is forwarded to `make_fact`/`make_rule`. `logging=true`
prints each node's name, parents and cpd, plus the running rule count.
"""
function bif2ast(dag; with_annotations=false, logging=false)
    rules = String[]
    for node in dag["nodes"]
        parents = node["parents"]
        node_head = node["node"]
        cpd = node["cpd"]
        if logging
            println("Node $node_head | $(join(parents, " "))\n\t$(cpd)")
        end
        if isempty(parents)
            push!(rules, make_fact(node_head, cpd[[:pos]]; with_annotations))
        else
            # Dict iteration order is a hash-table implementation detail; sort
            # the configurations (descending, so [:pos, :pos, ...] comes first)
            # to give the generated program a defined, reproducible order.
            for case in sort!(collect(keys(cpd)); rev=true)
                rule = make_rule(node_head, parents, case, cpd[case]; with_annotations)
                push!(rules, rule)
            end
        end
        if logging
            println("\t Rules: $(length(rules))")
        end
    end
    return join(rules, "\n")
end
using ArgParse
# Command-line entry point: convert the given BIF file to ASP and write it to
# --outfile when given, otherwise to stdout.
arg_parser = ArgParseSettings()
@add_arg_table arg_parser begin
    "filename"
        help = "A BIF file to convert."
        required = true
    "--outfile", "-o"
        help = "Target ASP file."
    "--annotated", "-a"
        help = "Generate ASP extended with probabilistic annotations."
        action = :store_true
end
args = parse_args(arg_parser)
filename = args["filename"]
annotations = args["annotated"]
has_astfile = args["outfile"] !== nothing
isfile(filename) || error("Missing file [$filename].")
dag = get_bif(filename)
ast = bif2ast(dag; with_annotations=annotations)
if has_astfile
    # `println` rather than `write`, so the file ends with a newline exactly like
    # the stdout output below.
    open(args["outfile"], "w") do ast_file
        println(ast_file, ast)
    end
else
    println(ast)
end