Skip to content

Commit

Permalink
-Fixed an issue where the number of trees in the forest could be set …
Browse files Browse the repository at this point in the history
…larger than the dataset size

-Tree depth is now an optional parameter. If it's <= 0 it will use the paper's calculation method
-The calculation for tree depth is now displayed
-The average sensitivity across all leaves of the forest is now displayed
-Documentation has been updated to reflect these changes
  • Loading branch information
michael committed Aug 15, 2020
1 parent e099300 commit 8333afb
Showing 1 changed file with 131 additions and 13 deletions.
144 changes: 131 additions & 13 deletions src/weka/classifiers/trees/SmoothPrivateForest.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@
* </pre>
*
* <pre>
* -T &lt;tree depth&gt;
* Tree depth option. If &lt;= 0 will be equal to the tree depth
* calculation from the paper
* (default -1)
* </pre>
*
* <pre>
* -E &lt;epsilon&gt;
* The privacy budget (epsilon) for the exponential mechanism.
* (default 1.0)
Expand Down Expand Up @@ -122,7 +129,7 @@ public class SmoothPrivateForest extends AbstractClassifier
* Privacy budget for differential privacy (epsilon in paper)
*/
protected float m_privacyBudget = 1.0f;

/**
* Whether or not to display flipped majorities, sensitivity information and
* true distributions in leaves. Set to False for outputting and sharing
Expand Down Expand Up @@ -156,7 +163,13 @@ public class SmoothPrivateForest extends AbstractClassifier
private Instances m_ds;

/**
* Stores tree depth calculated as per original paper
* Tree depth. If <= 0, will use the paper's balls-in-bins probability
* process to automatically select tree depth.
*/
protected int m_treeDepthOption = -1;

/**
* Calculated or option-specified tree depth.
*/
private int m_treeDepth = -1;

Expand Down Expand Up @@ -186,6 +199,10 @@ public class SmoothPrivateForest extends AbstractClassifier
private long m_SPFTime;

private boolean m_untrained = true;

private String errorMessage = "";

private int m_numberOfNumericalAttributes = -1;

/**
* Default classifier to setup SysFor.
Expand Down Expand Up @@ -317,7 +334,15 @@ public void buildClassifier(Instances ds) throws Exception {
ds.deleteWithMissingClass();

//if this is a dataset with only the class attribute
if (ds.numAttributes() == 1) {
if (ds.numAttributes() == 1 || this.m_forestSize > ds.numInstances()) {

if(ds.numAttributes() == 1) {
errorMessage += "This is a dataset with only the class value and cannot be used.\n";
}
if(this.m_forestSize > ds.numInstances()) {
errorMessage += "The forest size has been set higher than the number of instances in the dataset so the forest cannot be build.\n";
}

ZeroR zr = new ZeroR();
zr.buildClassifier(ds);
m_ensemble = new Classifier[1];
Expand Down Expand Up @@ -348,7 +373,13 @@ public void buildClassifier(Instances ds) throws Exception {
//start timing the SPF and start building it.
long startTimeSPF = System.currentTimeMillis();
m_ensemble = new Classifier[m_forestSize];
m_treeDepth = calculateTreeDepth(m_ds);

if(m_treeDepthOption <= 0) {
m_treeDepth = calculateTreeDepth(m_ds);
} //otherwise we will use the user selected value
else {
m_treeDepth = m_treeDepthOption;
}

//generate ur-domains
m_urDomains = new HashMap<>();
Expand Down Expand Up @@ -433,6 +464,7 @@ public int calculateTreeDepth(Instances ds) {
m++;
}
}
m_numberOfNumericalAttributes = m;

if (m == 0) {
retValue = (ds.numAttributes() - 1) / 2; //half the number of categorical atts
Expand Down Expand Up @@ -554,6 +586,10 @@ public String toString() {
if (m_ensemble == null) {
return "Forest not built!";
}

if(!"".equals(errorMessage)) {
return errorMessage;
}

StringBuilder sb = new StringBuilder();

Expand All @@ -564,6 +600,13 @@ public String toString() {
.append("\n\n");

}

//Calculate the average sensitivity across all the leaves of all the trees
double averageSensitivityAcrossAllTrees = 0;
for (int t = 0; t < m_forestSize; t++) {
averageSensitivityAcrossAllTrees += ((SmoothPrivateTree)m_ensemble[t]).averageSensitivity;
}
averageSensitivityAcrossAllTrees /= m_forestSize;

if (m_numDisplayTrees < m_forestSize) {
int tmp = m_forestSize - m_numDisplayTrees;
Expand All @@ -575,7 +618,26 @@ public String toString() {
.append(m_SPFTime).append(" ms.")
.append("\n").append("With privacy budget: ")
.append(m_privacyBudget)
.append("\n")
.append("And an average sensitivity across all leaves of: ")
.append(averageSensitivityAcrossAllTrees)
.append("\n");

if(m_treeDepthOption <= 0) {
//TODO
sb.append("Paper calculated tree depth via ")
.append("d = argmin E[X|d] (d:X<")
.append(m_numberOfNumericalAttributes)
.append("/2), E[X|d] = ")
.append(m_numberOfNumericalAttributes)
.append("* (")
.append(m_numberOfNumericalAttributes)
.append(" - 1 /").append(m_numberOfNumericalAttributes)
.append(" )^d")
.append(" gave the result ")
.append(m_treeDepth)
.append("\n");
}

sb.append("Accuracy information will appear below. For comparison, the selected ")
.append(m_comparisonClassifier.getClass().getName())
Expand All @@ -597,6 +659,11 @@ public String toString() {
*/
protected class SmoothPrivateTree extends AbstractClassifier implements Serializable {

/**
* The average sensitivity in the leaves.
*/
protected double averageSensitivity = 0;

/**
* The root node of the tree
*/
Expand Down Expand Up @@ -765,6 +832,7 @@ private void filterTrainingDataAndCount(Instances data, double epsilon) throws E

//set all noisy majorities
numFlippedMajorities = this.setAllNoisyMajorities(epsilon, root);
averageSensitivity /= numLeaves;

}

Expand Down Expand Up @@ -881,7 +949,9 @@ private int setAllNoisyMajorities(double epsilon, Node node) {

//do the noisy majority calculation if its a leaf, otherwise filter down to the leaves
if (node.children == null || node.splittingAttribute == -1 || node.children.isEmpty()) {
return node.setNoisyMajority(epsilon);
double[] returnedArray = node.setNoisyMajority(epsilon);
averageSensitivity += returnedArray[1];
return (int)returnedArray[0];
}

int sum = 0;
Expand Down Expand Up @@ -1039,10 +1109,12 @@ public void incrementClassCount(int classValue) {
/**
* Use exponential mechanism to estimate class
* @param epsilon - privacy budget
* @return estimated class
* @return flipped majorities and sensitivity
*/
public int setNoisyMajority(double epsilon) {
public double[] setNoisyMajority(double epsilon) {

double[] returnArray = new double[2];

//only run this once per leaf
if (this.noisyMajority == Integer.MIN_VALUE && this.children == null) {

Expand All @@ -1059,7 +1131,6 @@ public int setNoisyMajority(double epsilon) {
if (classCounts[i] > secondCC && classCounts[i] < maxCC) {
secondCC = classCounts[i];
}
//TODO Check this
}

//assign noisy majority
Expand All @@ -1069,27 +1140,30 @@ public int setNoisyMajority(double epsilon) {

this.empty = true;
this.noisyMajority = m_random.nextInt(classCounts.length);
return 0; //we don't want to count purely random flips
return returnArray;

} else {

int countDifference = maxCC - secondCC; //j in paper
this.sensitivity = Math.exp(-1.0 * countDifference * epsilon);
returnArray[1] = this.sensitivity;
this.sensOfSens = 1.0;
this.noisySensitivity = 1.0;

this.noisyMajority = this.expoMech(epsilon, this.sensitivity, this.classCounts);
if (this.noisyMajority != Utils.maxIndex(classCounts)) {
return 1; //we're summing flipped majorities
returnArray[0] = 1;
return returnArray;
} else {
return 0; //this means the exponential mechanism got it right
returnArray[0] = 0; //this means the exponential mechanism got it right
return returnArray;
}
}

}
} //end the "run once" if statement

return 0;
return returnArray;

}

Expand Down Expand Up @@ -1193,7 +1267,8 @@ public String toString(int indent, DecimalFormat df) {
}

sb.append("} Sensitivity: ")
.append(df.format(this.sensitivity));
//.append(df.format(this.sensitivity));
.append(String.format("%."+getNumDecimalPlaces()+"e", this.sensitivity));
}

}
Expand Down Expand Up @@ -1285,6 +1360,30 @@ public void setNumDisplayTrees(int m_numDisplayTrees) {
public String numDisplayTreesTipText() {
return "Amount of decision trees to display in the output.";
}

/**
* Get the tree depth option. A value <= 0 means the tree depth is
* calculated using the paper's method instead of being user-specified.
* @return the tree depth option
*/
public int getTreeDepthOption() {
return m_treeDepthOption;
}

/**
* Set the tree depth option (<= 0 uses paper's calculation)
* @param m_treeDepthOption - the tree depth (<= 0 uses paper's calculation)
*/
public void setTreeDepthOption(int m_treeDepthOption) {
this.m_treeDepthOption = m_treeDepthOption;
}

/**
* Weka tooltip for the tree depth option.
* @return the tooltip text describing the tree depth option
*/
public String treeDepthOptionTipText() {
return "The tree depth (<= 0 uses paper's calculation).";
}

/**
* Get the privacy budget (epsilon) for differential privacy
Expand Down Expand Up @@ -1397,6 +1496,13 @@ protected String defaultComparisonClassifierString() {
* Number of trees to display in the output.
* (default 3)
* </pre>
*
* <pre>
* -T &lt;tree depth&gt;
* Tree depth option. If &lt;= 0 will be equal to the tree depth
* calculation from the paper
* (default -1)
* </pre>
*
* <pre>
* -E &lt;epsilon&gt;
Expand Down Expand Up @@ -1436,6 +1542,13 @@ public void setOptions(String[] options) throws Exception {
} else {
setForestSize(10);
}

// Parse the -T (tree depth) option. The emptiness check must be made on the
// -T value itself: checking another option's string either ignores a supplied
// -T or attempts to parse an empty string when -T is absent.
String sTreeDepthOption = Utils.getOption('T', options);
if (sTreeDepthOption.length() != 0) {
    setTreeDepthOption(Integer.parseInt(sTreeDepthOption));
} else {
    // -T not supplied: fall back to the default (-1 selects the paper's calculation)
    setTreeDepthOption(-1);
}

String sPrivacyBudget = Utils.getOption('E', options);
if (sPrivacyBudget.length() != 0) {
Expand Down Expand Up @@ -1500,6 +1613,11 @@ public String[] getOptions() {

result.add("-D");
result.add("" + getNumDisplayTrees());

// Always report -T, like the other options, so that a user-specified positive
// tree depth is not dropped from the option list (values <= 0 still select the
// paper's calculation when the options are parsed again).
result.add("-T");
result.add("" + getTreeDepthOption());

result.add("-E");
result.add("" + getPrivacyBudget());
Expand Down

0 comments on commit 8333afb

Please sign in to comment.