Skip to content

Commit

Permalink
Improved MLJ interface also for PerceptronClassifier and PegasosClass…
Browse files Browse the repository at this point in the history
…ifier

This closes issue #31
  • Loading branch information
sylvaticus committed Jun 1, 2022
1 parent 4e2dc80 commit 767f697
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/Perceptron/Perceptron_MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ PegasosClassifier(;

function MMI.fit(model::PerceptronClassifier, verbosity, X, y)
x = MMI.matrix(X) # convert table to matrix
allClasses = levels(y)
initialθ = length(model.initialθ) == 0 ? zeros(size(x,2)) : model.initialθ
fitresult = perceptron(x, y; θ=initialθ, θ₀=model.initialθ₀, T=model.maxEpochs, nMsgs=0, shuffle=model.shuffle, forceOrigin=model.forceOrigin, returnMeanHyperplane=model.returnMeanHyperplane,rng=model.rng)
cache=nothing
report=nothing
return fitresult, cache, report
return (fitresult,allClasses), cache, report
end

function MMI.fit(model::KernelPerceptronClassifier, verbosity, X, y)
Expand All @@ -92,34 +93,36 @@ end

function MMI.fit(model::PegasosClassifier, verbosity, X, y)
x = MMI.matrix(X) # convert table to matrix
allClasses = levels(y)
initialθ = length(model.initialθ) == 0 ? zeros(size(x,2)) : model.initialθ
fitresult = pegasos(x, y; θ=initialθ,θ₀=model.initialθ₀, λ=model.λ,η=model.η, T=model.maxEpochs, nMsgs=0, shuffle=model.shuffle, forceOrigin=model.forceOrigin, returnMeanHyperplane=model.returnMeanHyperplane,rng=model.rng)
cache=nothing
report=nothing
return fitresult, cache, report
return (fitresult,allClasses), cache, report
end

# ------------------------------------------------------------------------------
# Predict functions....

function MMI.predict(model::Union{PerceptronClassifier,PegasosClassifier}, fitresult, Xnew)
fittedModel = fitresult
fittedModel = fitresult[1]
#classes = CategoricalVector(fittedModel.classes)
classes = fittedModel.classes
nLevels = length(classes)
allClasses = fitresult[2] # as classes do not includes classes unsees at training time
nLevels = length(allClasses)
nRecords = MMI.nrows(Xnew)
modelPredictions = Perceptron.predict(MMI.matrix(Xnew), fittedModel.θ, fittedModel.θ₀, fittedModel.classes)
predMatrix = zeros(Float64,(nRecords,nLevels))
# Transform the predictions from a vector of dictionaries to a matrix
# where the rows are the PMF of each record
for n in 1:nRecords
for (c,cl) in enumerate(classes)
for (c,cl) in enumerate(allClasses)
predMatrix[n,c] = get(modelPredictions[n],cl,0.0)
end
end
#predictions = [MMI.UnivariateFinite(classes, predMatrix[i,:])
# for i in 1:nRecords]
predictions = MMI.UnivariateFinite(fittedModel.classes,predMatrix,pool=missing)
predictions = MMI.UnivariateFinite(allClasses,predMatrix,pool=missing)
return predictions
end

Expand Down

2 comments on commit 767f697

@sylvaticus
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • bugfixes in MLJ interface, gmm clustering and other
  • API change for print(confusionMatrix) only

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61465

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.0 -m "<description of version>" 767f69779913dd8a12412be0259e6718c9b61e6f
git push origin v0.6.0

Please sign in to comment.