Skip to content

Commit

Permalink
UPDATE v1.0.3 - refactoring & falls model
Browse files Browse the repository at this point in the history
  • Loading branch information
detsutut committed Oct 12, 2020
1 parent f4bb564 commit 3ab6dde
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 71 deletions.
51 changes: 3 additions & 48 deletions scripts/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,6 @@ bntools.fit = function(dag, data,method=c("bayes"),priorWeight = 1, verbose = FA
return(bn)
}

#' Query a target node given some evidence on other nodes, comparing the probability distribution of the target node before and after conditioning on the given evidence.
#'
#' @param bn the fully-specified bayesian network to query
#' @param target target node of the query
#' @param evidenceNodes nodes where the evidence is set
#' @param evidenceStates values of the evidence
#' @return table with the results of the query
#' @examples bntools.query(bn,target = "A", evidenceNodes = c("B","C"), evidenceStates = c("b1","c2"))
bntools.query = function(bn, target = NULL, evidenceNodes = c(), evidenceStates = c()){
junction_tree = compile(as.grain(bn))
if(is.null(target)) target = select.list(nodes(bn), preselect = NULL, multiple = FALSE, title = "Query target node:", graphics = TRUE)
if(length(evidenceNodes)==0){
selected = select.list(setdiff(nodes(bn), target),
preselect = NULL,
multiple = TRUE,
title = "Set evidence on:",
graphics = TRUE)
for(node in selected){
evidenceNodes = c(evidenceNodes,node)
levels = dimnames(bn[[node]]$prob)[[node]]
if(is.null(levels)) levels = dimnames(bn[[node]]$prob)[[1]]
state = select.list(levels,
preselect = NULL,
multiple = FALSE,
title = paste(toupper(node),"observed:"),
graphics = TRUE)
evidenceStates = c(evidenceStates,state)
}
}
junction_tree_evidence = setEvidence(junction_tree, nodes=evidenceNodes, states = evidenceStates)
queries = cbind(querygrain(junction_tree,nodes = target)[[target]],
querygrain(junction_tree_evidence,nodes = target)[[target]])
colnames(queries) = c(paste("P(",toupper(target),")",collapse = ""),paste("P(",toupper(target),"| Evidence* )",collapse = ""))
barplot(queries,
main=paste(toupper(target),"distributions"),
sub = paste("*Evidence :",paste(evidenceNodes,"=",evidenceStates,collapse = ", ")),
ylab="Probability",
legend = rownames(queries),
col = rainbow(n = length(rownames(queries)), s = 0.5),
args.legend = list(x = "bottomright", cex=0.8, title = toupper(target)),
beside=TRUE, horiz=FALSE)
tryCatch(shinyjs::hideElement(id = 'loading3'),error = function(e) print(e))
# return(queries)
}

#' Check if two nodes are d-separated given some evidence on other nodes. If no evidence is given, a greedy search will look for all the possible combination of nodes that,
#' when given, d-separate the source node and the target node. On complex network where the greedy search would be computationally expensive, the user may set the maximum subset
#' size to explore. If the maximum subset size is negative, the algorithm will stop when the minimum subset of d-separating features is detected.
Expand Down Expand Up @@ -102,7 +57,7 @@ bntools.dsep = function(bn, source=NULL, target=NULL, given = NULL, maxSize = NU
}
allCombos = allCombos[which(lapply(allCombos,nrow)<=maxSize)]
for (comboList in allCombos) {
results = pbapply(comboList, 2, function(z) {
results = apply(comboList, 2, function(z) {
dsep(dag,
x = source,
y = target,
Expand Down Expand Up @@ -265,7 +220,7 @@ dagtools.dsep = function(dag, source=NULL, target=NULL, given = NULL, maxSize =
}
allCombos = allCombos[which(lapply(allCombos,nrow)<=maxSize)]
for (comboList in allCombos) {
results = pbapply(comboList, 2, function(z) {
results = apply(comboList, 2, function(z) {
dsep(dag,
x = source,
y = target,
Expand Down Expand Up @@ -299,7 +254,7 @@ dagtools.findIc = function(dag, given = NULL) {
nodesToCheck = setdiff(nodes(dag), given)
combos = combn(x = nodesToCheck, 2)
positiveCombos = list()
results = pbapply(combos,2,function(x){
results = apply(combos,2,function(x){
dsep(dag,x=x[1],y=x[2],z=given)
})
if(length(which(results==TRUE))==0) return(NULL)
Expand Down
3 changes: 2 additions & 1 deletion server.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ function(input, output, session) {
bn = NULL
debug = FALSE
debugCounter = 0
queryRepeat = 19
evidenceMenuUiInjected = FALSE
shinyjs::runjs(
"if(getCookie('BN_tutorial') != 'true'){
Expand Down Expand Up @@ -396,7 +397,7 @@ function(input, output, session) {
queryData = eval(parse(text = paste("table(cpdist(bn, ", queryNodeString, ", ", #merge together and run the query
queryEvidenceString, "))", sep = "")))
#for loop to get more stable results
for (i in 1:4){
for (i in 1:queryRepeat){
queryData = rbind(queryData,eval(parse(text = paste("table(cpdist(bn, ", queryNodeString, ", ", #merge together and run the query
queryEvidenceString, "))", sep = ""))))
}
Expand Down
46 changes: 24 additions & 22 deletions ui.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
############## 0 ) INIT ##############
version = "1.0.3"

library("shiny")
library("shinyjs")
library("shinydashboard")
library("shiny") #Main package for web app development
library("shinyjs") #Adds JavaScript support
library("shinydashboard") #Adds dashboards support
library("shinyBS") #Adds Bootstrap support
library("ggplot2")
library("plotly")
library("dplyr")
library("shinyBS")
library("visNetwork")
library("bnlearn") #bayesian networks handler
library("gRain") #bayesian networks visualizer
library("pbapply") #adds progress bars to the apply family
library("DT")
library("visNetwork") #Tools for nteworks visualization
library("bnlearn") #Bayesian networks handler
library("DT") #Provides a visual interface for data tables

source("scripts/utilities.R") #load utilities
source("scripts/utilities.R") #Loads utilities

##### 0.1 ) Input list #####

Expand Down Expand Up @@ -64,30 +62,34 @@ fluidPage(
##### 1.2.1 ) File Loading #####
bsCollapse(id = "collapseLoad", open = "Learn The Network",
bsCollapsePanel("Learn The Network",
div(id="fileInput2",fileInput(inputId = "edgesFile", "Load edges",width = "95%", multiple = FALSE)),
div(id="fileInput3",fileInput(inputId = "dataFile", "Load data",width = "95%", multiple = FALSE)),
actionButton(inputId = "preTrained",
class = "debugElement",
label = "Load Car Insurance Example",
width = "86%"),
actionButton(inputId = "preTrainedFalls",
class = "debugElement",
label = "Load Pretrained Falls Network",
width = "86%")
div(id="fileInput2",fileInput(inputId = "edgesFile", "Load edges",width = "95%", multiple = FALSE),style="color:#b3bec2"),
div(id="fileInput3",fileInput(inputId = "dataFile", "Load data",width = "95%", multiple = FALSE),style="color:#b3bec2"),
div(class="row", style="margin-left: 0px; margin-right:35px",
h5("Preloaded Examples",style="text-align:center; color:#b3bec2"),
div(class="col-xs-6", style="padding-left: 0px; padding-right:0px",
actionButton(inputId = "preTrained",
class = "debugElement",
label = "Car Insurance",
width = "93%")),
div(class="col-xs-6", style="padding-left: 0px; padding-right:0px",
actionButton(inputId = "preTrainedFalls",
class = "debugElement",
label = "Falls Network",
width = "93%")))
)
),
hr(style="border-top-color:rgba(0,0,0,.1"),
##### 1.2.2 ) Query #####
bsCollapse(id = "collapseQuery", open = "Network Inference",
bsCollapsePanel("Network Inference",
div(id="querySection",
div(id="querySection", style="color:#b3bec2",
selectInput(inputId = "nodeToQuery",
label = "Selected node",
choices = c(""),
selected = NULL,
multiple = FALSE,
selectize = TRUE,
width ="95%",
width ="95%",
size = NULL),
actionButton(inputId = "query",
label = "Query",
Expand Down

0 comments on commit 3ab6dde

Please sign in to comment.