-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pr-curve support #8
base: master
Are you sure you want to change the base?
Conversation
Just noticed, histogram support broke because of change in default_gernerate, I will try to fix it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks really nice! I don't have enough time to look into the details currently, and only commented on some coding styles. Most remaining format related issues can be fixed by clang-format
. I'll come back to this as soon as possible (maybe this weekend).
@@ -140,9 +142,24 @@ class TensorBoardLogger { | |||
const std::vector<std::string> &metadata = std::vector<std::string>(), | |||
const std::string &metadata_filename = "", | |||
int step = 1 /* no effect */); | |||
|
|||
int prcurve(const std::string tag, | |||
const std::vector<double>labels, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use reference
@@ -243,6 +248,121 @@ int TensorBoardLogger::add_embedding( | |||
tensor_shape, step); | |||
} | |||
|
|||
std::vector<std::vector<double>> TensorBoardLogger::compute_curve( | |||
const std::vector<double>labels, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reference
{ | ||
weights.push_back(1.0); | ||
} | ||
generate_default_buckets({0, (double)num_thresholds - 1}, num_thresholds, true, true); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this num_thresholds
be fixed to 127, so buckets only need to be generated once? I don't see the necessity to change it. (I will look deeper later)
return data; | ||
} | ||
int TensorBoardLogger::prcurve( | ||
const std::string tag, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
references
auto *tensor = new TensorProto(); | ||
tensor->set_dtype(tensorflow::DataType::DT_DOUBLE); | ||
tensor->set_allocated_tensor_shape(tensorshape); | ||
for(int i=0;i<data.size();i++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Range-based for loop is better suited here.
private: | ||
int generate_default_buckets(); | ||
std::vector<std::vector<double>> compute_curve( | ||
const std::vector<double>labels, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
references
I have no idea, but I think maybe it's not necessary, since reasonable PR curve won't be such a flat line. |
I think bucket generation still have some issues, histogram looks kind of strange now. I also tested pr curve using data generated by pr curve demo code in TensorBoard repo, and the result looks different: vs. I'm not sure whether I used it correctly, I'll try to find out. |
Ok, I will take a look again |
Hi can you quickly re run the script or test |
Sorry for the delay. The pr curve looks correct now, however the histogram still has some issue. I will try to find a fix. |
@RustingSword There you go
Not sure about the math, it probably looks correct but have a look at the math again :)
Changed
generate_bucket()
a little bit to resemble ->
np.histogram()
Note : https://github.com/reminisce/tensorboard-mxnet-logger/blob/ad7d6522e4010deaa3a2a95e93629d8b01751078/tensorboardX/summary.py#L219