diff --git a/src/Perceptron/Perceptron_MLJ.jl b/src/Perceptron/Perceptron_MLJ.jl index 5d9ec5c9..1420ebb7 100644 --- a/src/Perceptron/Perceptron_MLJ.jl +++ b/src/Perceptron/Perceptron_MLJ.jl @@ -73,11 +73,12 @@ PegasosClassifier(; function MMI.fit(model::PerceptronClassifier, verbosity, X, y) x = MMI.matrix(X) # convert table to matrix + allClasses = levels(y) initialθ = length(model.initialθ) == 0 ? zeros(size(x,2)) : model.initialθ fitresult = perceptron(x, y; θ=initialθ, θ₀=model.initialθ₀, T=model.maxEpochs, nMsgs=0, shuffle=model.shuffle, forceOrigin=model.forceOrigin, returnMeanHyperplane=model.returnMeanHyperplane,rng=model.rng) cache=nothing report=nothing - return fitresult, cache, report + return (fitresult,allClasses), cache, report end function MMI.fit(model::KernelPerceptronClassifier, verbosity, X, y) @@ -92,34 +93,36 @@ end function MMI.fit(model::PegasosClassifier, verbosity, X, y) x = MMI.matrix(X) # convert table to matrix + allClasses = levels(y) initialθ = length(model.initialθ) == 0 ? zeros(size(x,2)) : model.initialθ fitresult = pegasos(x, y; θ=initialθ,θ₀=model.initialθ₀, λ=model.λ,η=model.η, T=model.maxEpochs, nMsgs=0, shuffle=model.shuffle, forceOrigin=model.forceOrigin, returnMeanHyperplane=model.returnMeanHyperplane,rng=model.rng) cache=nothing report=nothing - return fitresult, cache, report + return (fitresult,allClasses), cache, report end # ------------------------------------------------------------------------------ # Predict functions.... function MMI.predict(model::Union{PerceptronClassifier,PegasosClassifier}, fitresult, Xnew) - fittedModel = fitresult + fittedModel = fitresult[1] #classes = CategoricalVector(fittedModel.classes) classes = fittedModel.classes - nLevels = length(classes) + allClasses = fitresult[2] # as classes do not includes classes unsees at training time + nLevels = length(allClasses) nRecords = MMI.nrows(Xnew) modelPredictions = Perceptron.predict(MMI.matrix(Xnew), fittedModel.θ, fittedModel.θ₀, fittedModel.classes) predMatrix = zeros(Float64,(nRecords,nLevels)) # Transform the predictions from a vector of dictionaries to a matrix # where the rows are the PMF of each record for n in 1:nRecords - for (c,cl) in enumerate(classes) + for (c,cl) in enumerate(allClasses) predMatrix[n,c] = get(modelPredictions[n],cl,0.0) end end #predictions = [MMI.UnivariateFinite(classes, predMatrix[i,:]) # for i in 1:nRecords] - predictions = MMI.UnivariateFinite(fittedModel.classes,predMatrix,pool=missing) + predictions = MMI.UnivariateFinite(allClasses,predMatrix,pool=missing) return predictions end