-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_batch.R
82 lines (68 loc) · 2.69 KB
/
predict_batch.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Command-line interface: exactly three trailing arguments, taken positionally
# as script_path, input_file and output_path. script_path and output_path are
# later glued directly to file names with paste0(), so they act as prefixes.
args <- commandArgs(trailingOnly = TRUE)
if (length(args) != 3) {
  stop("Invalid number of arguments to Rscript.")
}
script_path <- args[[1]]
input_file <- args[[2]]
output_path <- args[[3]]
# Convert a probability gap into an integer reliability index (RI).
#
# @param diff Numeric scalar; the caller passes the gap between the highest
#   and the runner-up class probability.
# @return 1 when diff < 0.1; otherwise diff is scaled to tenths, rounded,
#   halved and rounded up (so diff in [0.1, 1] yields 1..5). A NA/NaN diff
#   raises an error from `if`.
convertRI <- function(diff) {
  if (diff >= 0.1) {
    tenths <- round(diff * 10, digits = 0)
    ceiling(tenths / 2)
  } else {
    1
  }
}
library(reshape2, quietly=TRUE)
library(caret, quietly=TRUE)
# Load the tab-separated variant table (header row, no comment character) and
# keep only the columns used below.
Input <- read.table(input_file, sep = "\t", comment.char = "", header = TRUE,
                    stringsAsFactors = FALSE)
Ind_batch <- Input[, c("POS", "REF", "ALT", "QUAL", "SAMPLE", "AO.1", "DP.1")]
# A "." in the AO.1 / DP.1 count columns is replaced by 0 before the columns
# are converted to numeric.
Ind_batch$AO.1[Ind_batch$AO.1 == "."] <- 0
Ind_batch$DP.1[Ind_batch$DP.1 == "."] <- 0
Ind_batch$DP.1 <- as.numeric(Ind_batch$DP.1)
Ind_batch$AO.1 <- as.numeric(Ind_batch$AO.1)
# Mutation key REF + POS + ALT, e.g. "C" / 1234 / "T" -> "C1234T".
Ind_batch$MUT <- paste0(Ind_batch$REF, Ind_batch$POS, Ind_batch$ALT)
# Per-variant value AO.1 / DP.1 (presumably alt-read fraction -- confirm).
# 0/0 yields NaN, which is recorded as 0.
Ind_batch$VAL <- Ind_batch$AO.1 / Ind_batch$DP.1
Ind_batch$VAL[is.nan(Ind_batch$VAL)] <- 0
Ind_batch$VAL <- round(Ind_batch$VAL, digits = 4)
# Wide table: one row per SAMPLE, one column per MUT. fill = 0 so that a
# mutation seen in some samples but not others gives 0 (not NA) for the rest,
# matching how mutations absent from the whole batch are treated downstream.
# NOTE(review): duplicate SAMPLE/MUT pairs make dcast fall back to counting
# (length) instead of keeping VAL -- confirm the input never contains them.
Ind_batch_l <- dcast(Ind_batch[, c("SAMPLE", "MUT", "VAL")], SAMPLE ~ MUT,
                     value.var = "VAL", fill = 0)
# Feature names expected by the model: every field of A_col.txt, flattened.
A_col <- unlist(read.table(paste0(script_path, "A_col.txt"),
                           stringsAsFactors = FALSE))
svm_L <- readRDS(paste0(script_path, "svm_L_model.rds"))
# Model input: one row per row of Ind_batch_l (same order), one column per
# name in A_col. Start from all zeros so features absent from this batch are 0.
data_2_bat <- as.data.frame(matrix(0, nrow = nrow(Ind_batch_l),
                                   ncol = length(A_col)))
colnames(data_2_bat) <- A_col
# Copy over every feature the batch actually has, whole columns at a time.
shared_cols <- intersect(A_col, colnames(Ind_batch_l))
data_2_bat[shared_cols] <- Ind_batch_l[shared_cols]
# A missing cell (mutation not observed for that sample) is treated the same as
# an absent feature; NA would otherwise reach predict().
data_2_bat[is.na(data_2_bat)] <- 0
# Score every sample with the loaded model and assemble the report table.
# CLASS is the most probable class; RI grades the gap between the best and the
# runner-up class probability via convertRI().
pred_prob_bat <- predict(svm_L, data_2_bat, type = "prob")
prob_mat <- as.matrix(pred_prob_bat)
# Predictions are matched to samples purely by position, so a row-count
# mismatch (e.g. if predict() dropped rows it could not score) would silently
# mislabel samples. Fail loudly instead.
if (nrow(prob_mat) != nrow(Ind_batch_l)) {
  stop("Expected ", nrow(Ind_batch_l), " predictions but got ",
       nrow(prob_mat), ".", call. = FALSE)
}
# Column index of the most probable class per row (which.max: first wins ties).
top_idx <- vapply(seq_len(nrow(prob_mat)),
                  function(r) which.max(prob_mat[r, ]), integer(1))
top_label <- colnames(prob_mat)[top_idx]
# Gap between the best probability and the best of the remaining classes.
margin <- vapply(seq_len(nrow(prob_mat)), function(r) {
  p <- prob_mat[r, ]
  max(p) - max(p[-top_idx[r]])
}, numeric(1))
# NOTE(review): probability columns 1:3 are assumed to be the M, S and XDR
# classes in that order -- confirm against the model's class levels.
probs <- round(prob_mat[, 1:3, drop = FALSE], digits = 4)
output <- data.frame(
  SAMPLE = Ind_batch_l$SAMPLE,
  MDR = unname(probs[, 1]),
  Susceptible = unname(probs[, 2]),
  XDR = unname(probs[, 3]),
  # Any label other than "M" or "S" is reported as XDR.
  CLASS = ifelse(top_label == "M", "MDR",
                 ifelse(top_label == "S", "Susceptible", "XDR")),
  RI = vapply(margin, convertRI, numeric(1)),
  stringsAsFactors = FALSE
)
# Echo the report to the console between blank lines, then save it as
# "prediction.tsv" appended to the output_path prefix.
cat("\n")
print(output)
cat("\n")
# sep = "\t": the file is named .tsv, but write.table's default separator is a
# single space, which produced a space-delimited file.
write.table(output, file = paste0(output_path, "prediction.tsv"),
            sep = "\t", row.names = FALSE)