Skip to content

Commit

Permalink
Fixes multilabel classifiers and multilabel generator (#249)
Browse files Browse the repository at this point in the history
* Fixes multilabel classifiers and multilabel generator

* Formatting
  • Loading branch information
Alberto Cano authored Apr 14, 2022
1 parent ff3efe3 commit eab15f7
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class Range implements Serializable {

public Range(String range) {
this.rangeText = range;
//this.setRange(range); //needs upperLimit
this.setRange(range); //needs upperLimit
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
* @author Jesse Read
* @version $Revision: 1 $
*/
public class MEKAClassifier extends AbstractMultiLabelLearner implements MultiTargetRegressor, Serializable {
public class MEKAClassifier extends AbstractMultiLabelLearner implements MultiLabelLearner, MultiTargetRegressor, Serializable {

private static final long serialVersionUID = 1L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.yahoo.labs.samoa.instances.MultiLabelPrediction;
import com.yahoo.labs.samoa.instances.Prediction;
import moa.classifiers.AbstractMultiLabelLearner;
import moa.classifiers.MultiLabelLearner;
import moa.classifiers.MultiTargetRegressor;
import moa.core.StringUtils;

Expand All @@ -38,7 +39,7 @@
* @author Jesse Read (jesse@tsc.uc3m.es)
* @version $Revision: 1 $
*/
public class MajorityLabelset extends AbstractMultiLabelLearner implements MultiTargetRegressor {
public class MajorityLabelset extends AbstractMultiLabelLearner implements MultiLabelLearner {
//AbstractClassifier {

private static final long serialVersionUID = 1L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ public MultilabelLearningNodeClassifier(double[] initialClassObservations, Class
super(initialClassObservations);

if (cl== null) {
this.classifier = ((Classifier) getPreparedClassOption(ht.learnerOption)).copy();
this.classifier.resetLearning();
MEKAClassifier learner = new MEKAClassifier();
learner.baseLearnerOption.setValueViaCLIString("meka.classifiers.multilabel.incremental.PSUpdateable");
learner.prepareForUse();
learner.setModelContext(ht.getModelContext());

InstancesHeader raw_header = ht.getModelContext();
this.classifier.setModelContext(raw_header);
this.classifier = learner;
this.classifier.resetLearning();
}
else{
this.classifier = cl.copy();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1481,4 +1481,4 @@ public void log(String s) {
// endregion --- Processing methods

// endregion ================ METHODS ================
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import moa.streams.InstanceStream;
import moa.streams.MultiTargetInstanceStream;
import moa.tasks.TaskMonitor;
import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.instances.Range;
import com.yahoo.labs.samoa.instances.SparseInstance;
import moa.core.FastVector;
import moa.core.Utils;
Expand All @@ -43,7 +46,7 @@
* @author Jesse Read ((jesse@tsc.uc3m.es))
* @version $Revision: 7 $
*/
public class MetaMultilabelGenerator extends AbstractOptionHandler implements InstanceStream {
public class MetaMultilabelGenerator extends AbstractOptionHandler implements MultiTargetInstanceStream {

private static final long serialVersionUID = 1L;

Expand Down Expand Up @@ -153,19 +156,25 @@ public void restart() {
* @param si single-label Instances
*/
protected MultilabelInstancesHeader generateMultilabelHeader(Instances si) {
Instances mi = new Instances(si, 0, 0);
mi.setClassIndex(-1);
mi.deleteAttributeAt(mi.numAttributes() - 1);
FastVector bfv = new FastVector();
bfv.addElement("0");
bfv.addElement("1");
for (int i = 0; i < this.m_L; i++) {
mi.insertAttributeAt(new Attribute("class" + i, bfv), i);
}
this.multilabelStreamTemplate = mi;
this.multilabelStreamTemplate.setRelationName("SYN_Z" + this.labelCardinalityOption.getValue() + "L" + this.m_L + "X" + m_A + "S" + metaRandomSeedOption.getValue() + ": -C " + this.m_L);
this.multilabelStreamTemplate.setClassIndex(this.m_L);
return new MultilabelInstancesHeader(multilabelStreamTemplate, m_L);
Instances mi = new Instances(si, 0, 0);
mi.deleteAttributeAt(mi.numAttributes() - 1);
FastVector bfv = new FastVector();
bfv.addElement("0");
bfv.addElement("1");
for (int i = 0; i < this.m_L; i++) {
mi.insertAttributeAt(new Attribute("class" + i, bfv), i);
}

Range range = new Range(Integer.toString((numLabelsOption.getValue())));

this.multilabelStreamTemplate = mi;
this.multilabelStreamTemplate.setRelationName("SYN_Z" + this.labelCardinalityOption.getValue() + "L" + this.m_L + "X" + m_A + "S" + metaRandomSeedOption.getValue() + ": -C " + this.m_L);
this.multilabelStreamTemplate.setClassIndex(Integer.MAX_VALUE);
this.multilabelStreamTemplate.setRangeOutputIndices(range);

MultilabelInstancesHeader header = new MultilabelInstancesHeader(multilabelStreamTemplate, m_L);
header.setRangeOutputIndices(range);
return header;
}

/**
Expand Down Expand Up @@ -267,7 +276,7 @@ private double joint(int k, int y[]) {
private Instance generateMLInstance(HashSet<Integer> Y) {

// create a multi-label instance:
Instance x_ml = new SparseInstance(this.multilabelStreamTemplate.numAttributes());
Instance x_ml = new DenseInstance(this.multilabelStreamTemplate.numAttributes());
x_ml.setDataset(this.multilabelStreamTemplate);

// set classes
Expand Down Expand Up @@ -472,7 +481,7 @@ public int compare(HashSet Y1, HashSet Y2) {
}

// shuffle
Collections.shuffle(Arrays.asList(map_set));
Collections.shuffle(Arrays.asList(map_set), m_MetaRandom);

// return
return map_set;
Expand Down Expand Up @@ -545,7 +554,7 @@ private ArrayList<Integer> getShuffledListToLWithoutK(int L, int k) {
list.add(j);
}
}
Collections.shuffle(list);
Collections.shuffle(list, m_MetaRandom);
return list;
}

Expand Down

0 comments on commit eab15f7

Please sign in to comment.