Skip to content

Commit

Permalink
Merge pull request #420 from jmalkin/ebpps_speedup
Browse files Browse the repository at this point in the history
Ebpps speedup
  • Loading branch information
jmalkin authored Jan 30, 2024
2 parents 349b6e7 + 37ff643 commit f905bdd
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 35 deletions.
8 changes: 4 additions & 4 deletions sampling/include/ebpps_sample.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ class ebpps_sample {
public:
explicit ebpps_sample(uint32_t k, const A& allocator = A());

// constructor used to create a sample to merge one item
template<typename TT>
ebpps_sample(TT&& item, double theta, const A& allocator = A());

// for deserialization
class items_deleter;
ebpps_sample(std::vector<T, A>&& data, optional<T>&& partial_item, double c, const A& allocator = A());

// used instead of having a single-item constructor for update/merge calls
template<typename TT>
void replace_content(TT&& item, double theta);

void reset();
void downsample(double theta);

Expand Down
29 changes: 13 additions & 16 deletions sampling/include/ebpps_sample_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,6 @@ ebpps_sample<T,A>::ebpps_sample(uint32_t reserved_size, const A& allocator) :
data_.reserve(reserved_size);
}

// Builds a sample that holds exactly one item with c_ set to theta.
// The item is stored in data_ when theta is exactly 1.0; for any other
// theta it is stored as the partial item and data_ stays empty.
template<typename T, typename A>
template<typename TT>
ebpps_sample<T,A>::ebpps_sample(TT&& item, double theta, const A& allocator) :
  allocator_(allocator),
  c_(theta),
  partial_item_(),
  data_(allocator)
{
  const bool holds_full_item = (theta == 1.0);
  if (!holds_full_item) {
    // any theta other than exactly 1.0: keep the item only as the partial item
    partial_item_.emplace(std::forward<TT>(item));
    return;
  }

  // theta is exactly 1.0: the item becomes the single entry in data_
  data_.reserve(1);
  data_.emplace_back(std::forward<TT>(item));
}

template<typename T, typename A>
ebpps_sample<T,A>::ebpps_sample(std::vector<T, A>&& data, optional<T>&& partial_item, double c, const A& allocator) :
allocator_(allocator),
Expand All @@ -65,6 +49,19 @@ ebpps_sample<T,A>::ebpps_sample(std::vector<T, A>&& data, optional<T>&& partial_
data_(data, allocator)
{}

// Discards the current contents of this sample and loads a single item
// with c_ set to theta. The item goes into data_ when theta is exactly 1.0;
// for any other theta it is stored as the partial item instead.
template<typename T, typename A>
template<typename TT>
void ebpps_sample<T,A>::replace_content(TT&& item, double theta) {
  // wipe the previous state before installing the new item
  partial_item_.reset();
  data_.clear();
  c_ = theta;

  if (theta != 1.0) {
    partial_item_.emplace(std::forward<TT>(item));
    return;
  }
  data_.emplace_back(std::forward<TT>(item));
}

template<typename T, typename A>
auto ebpps_sample<T,A>::get_sample() const -> result_type {
double unused;
Expand Down
2 changes: 2 additions & 0 deletions sampling/include/ebpps_sketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ class ebpps_sketch {

ebpps_sample<T,A> sample_; // Object holding the current state of the sample

ebpps_sample<T,A> tmp_; // Reusable single-item sample, refilled via replace_content() on each update and merge step

// handles merge after ensuring other.cumulative_wt_ <= this->cumulative_wt_
// so we can send items in individually
template<typename O>
Expand Down
21 changes: 10 additions & 11 deletions sampling/include/ebpps_sketch_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ ebpps_sketch<T, A>::ebpps_sketch(uint32_t k, const A& allocator) :
cumulative_wt_(0.0),
wt_max_(0.0),
rho_(1.0),
sample_(check_k(k), allocator)
sample_(check_k(k), allocator),
tmp_(1, allocator)
{}

template<typename T, typename A>
Expand All @@ -53,7 +54,8 @@ ebpps_sketch<T,A>::ebpps_sketch(uint32_t k, uint64_t n, double cumulative_wt,
cumulative_wt_(cumulative_wt),
wt_max_(wt_max),
rho_(rho),
sample_(sample)
sample_(sample),
tmp_(1, allocator)
{}

template<typename T, typename A>
Expand Down Expand Up @@ -148,9 +150,8 @@ void ebpps_sketch<T, A>::internal_update(FwdItem&& item, double weight) {
if (cumulative_wt_ > 0.0)
sample_.downsample(new_rho / rho_);

ebpps_sample<T,A> tmp(conditional_forward<FwdItem>(item), new_rho * weight, allocator_);

sample_.merge(tmp);
tmp_.replace_content(conditional_forward<FwdItem>(item), new_rho * weight);
sample_.merge(tmp_);

cumulative_wt_ = new_cum_wt;
wt_max_ = new_wt_max;
Expand Down Expand Up @@ -240,9 +241,8 @@ void ebpps_sketch<T, A>::internal_merge(O&& sk) {
if (cumulative_wt_ > 0.0)
sample_.downsample(new_rho / rho_);

ebpps_sample<T,A> tmp(conditional_forward<O>(items[i]), new_rho * avg_wt, allocator_);

sample_.merge(tmp);
tmp_.replace_content(conditional_forward<O>(items[i]), new_rho * avg_wt);
sample_.merge(tmp_);

cumulative_wt_ = new_cum_wt;
rho_ = new_rho;
Expand All @@ -259,9 +259,8 @@ void ebpps_sketch<T, A>::internal_merge(O&& sk) {
if (cumulative_wt_ > 0.0)
sample_.downsample(new_rho / rho_);

ebpps_sample<T,A> tmp(conditional_forward<O>(other_sample.get_partial_item()), new_rho * other_c_frac * avg_wt, allocator_);

sample_.merge(tmp);
tmp_.replace_content(conditional_forward<O>(other_sample.get_partial_item()), new_rho * other_c_frac * avg_wt);
sample_.merge(tmp_);

cumulative_wt_ = new_cum_wt;
rho_ = new_rho;
Expand Down
11 changes: 7 additions & 4 deletions sampling/test/ebpps_sample_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,24 @@ TEST_CASE("ebpps sample: basic initialization", "[ebpps_sketch]") {

// Checks the state reported by a sample after it is loaded with a single
// item, first with theta exactly 1.0 and then with a tiny theta.
TEST_CASE("ebpps sample: pre-initialized", "[ebpps_sketch]") {
double theta = 1.0;
ebpps_sample<int> sample = ebpps_sample<int>(-1, theta);
ebpps_sample<int> sample(1);
sample.replace_content(-1, theta);
// theta == 1.0: expect one retained item, returned by get_sample(), and no partial item
REQUIRE(sample.get_c() == theta);
REQUIRE(sample.get_num_retained_items() == 1);
REQUIRE(sample.get_sample().size() == 1);
REQUIRE(sample.has_partial_item() == false);

// tiny theta: the item is expected to be retained only as a partial item
theta = 1e-300;
sample = ebpps_sample<int>(-1, theta);
sample.replace_content(-1, theta);
REQUIRE(sample.get_c() == theta);
REQUIRE(sample.get_num_retained_items() == 1);
REQUIRE(sample.get_sample().size() == 0); // assuming the random number is > 1e-300
REQUIRE(sample.has_partial_item());
}

TEST_CASE("ebpps sample: downsampling", "[ebpps_sketch]") {
ebpps_sample<char> sample = ebpps_sample<char>('a', 1.0);
ebpps_sample<char> sample(1);
sample.replace_content('a', 1.0);

sample.downsample(2.0); // no-op
REQUIRE(sample.get_c() == 1.0);
Expand Down Expand Up @@ -121,8 +123,9 @@ TEST_CASE("ebpps sample: merge unit samples", "[ebpps_sketch]") {
uint32_t k = 8;
ebpps_sample<int> sample = ebpps_sample<int>(k);

ebpps_sample<int> s(1);
for (uint32_t i = 1; i <= k; ++i) {
ebpps_sample<int> s = ebpps_sample<int>(i, 1.0);
s.replace_content(i, 1.0);
sample.merge(s);
REQUIRE(sample.get_c() == static_cast<double>(i));
REQUIRE(sample.get_num_retained_items() == i);
Expand Down

0 comments on commit f905bdd

Please sign in to comment.