diff --git a/src/h2load.cc b/src/h2load.cc index 492d8d9f..8d72e3ba 100644 --- a/src/h2load.cc +++ b/src/h2load.cc @@ -151,33 +151,12 @@ std::mt19937 gen(rd()); } // namespace namespace { -void sampling_init(Sampling &smp, size_t total, size_t max_samples) { +void sampling_init(Sampling &smp, size_t max_samples) { smp.n = 0; - - if (total <= max_samples) { - smp.interval = 0.; - smp.point = 0.; - return; - } - - smp.interval = static_cast(total) / max_samples; - - std::uniform_real_distribution<> dis(0., smp.interval); - - smp.point = dis(gen); + smp.max_samples = max_samples; } } // namespace -namespace { -bool sampling_should_pick(Sampling &smp) { - return smp.interval == 0. || smp.n == ceil(smp.point); -} -} // namespace - -namespace { -void sampling_advance_point(Sampling &smp) { smp.point += smp.interval; } -} // namespace - namespace { void writecb(struct ev_loop *loop, ev_io *w, int revents) { auto client = static_cast(w->data); @@ -361,10 +340,7 @@ Client::~Client() { SSL_free(ssl); } - if (sampling_should_pick(worker->client_smp)) { - sampling_advance_point(worker->client_smp); - worker->sample_client_stat(&cstat); - } + worker->sample_client_stat(&cstat); ++worker->client_smp.n; } @@ -726,10 +702,7 @@ void Client::on_stream_close(int32_t stream_id, bool success, bool final) { ++worker->stats.req_failed; } - if (sampling_should_pick(worker->request_times_smp)) { - sampling_advance_point(worker->request_times_smp); - worker->sample_req_stat(req_stat); - } + worker->sample_req_stat(req_stat); // Count up in successful cases only ++worker->request_times_smp.n; @@ -1190,8 +1163,8 @@ Worker::Worker(uint32_t id, SSL_CTX *ssl_ctx, size_t req_todo, size_t nclients, stats.req_stats.reserve(std::min(req_todo, max_samples)); stats.client_stats.reserve(std::min(nclients, max_samples)); - sampling_init(request_times_smp, req_todo, max_samples); - sampling_init(client_smp, nclients, max_samples); + sampling_init(request_times_smp, max_samples); + sampling_init(client_smp, max_samples); } Worker::~Worker() { @@ -1224,14 +1197,28 @@ void Worker::run() { ev_run(loop, 0); } +namespace { +template +void sample(Sampling &smp, Stats &stats, Stat *s) { + ++smp.n; + if (stats.size() < smp.max_samples) { + stats.push_back(*s); + return; + } + auto d = std::uniform_int_distribution<>(0, smp.n - 1); + auto i = d(gen); + if (i < smp.max_samples) { + stats[i] = *s; + } +} +} // namespace + void Worker::sample_req_stat(RequestStat *req_stat) { - stats.req_stats.push_back(*req_stat); - assert(stats.req_stats.size() <= max_samples); + sample(request_times_smp, stats.req_stats, req_stat); } void Worker::sample_client_stat(ClientStat *cstat) { - stats.client_stats.push_back(*cstat); - assert(stats.client_stats.size() <= max_samples); + sample(client_smp, stats.client_stats, cstat); } void Worker::report_progress() { @@ -1313,14 +1300,10 @@ process_time_stats(const std::vector> &workers) { size_t nclient_times = 0; for (const auto &w : workers) { nrequest_times += w->stats.req_stats.size(); - if (w->request_times_smp.interval != 0.) { - request_times_sampling = true; - } + request_times_sampling = w->request_times_smp.n > w->stats.req_stats.size(); nclient_times += w->stats.client_stats.size(); - if (w->client_smp.interval != 0.) { - client_times_sampling = true; - } + client_times_sampling = w->client_smp.n > w->stats.client_stats.size(); } std::vector request_times; diff --git a/src/h2load.h b/src/h2load.h index db50472d..cd194bfd 100644 --- a/src/h2load.h +++ b/src/h2load.h @@ -217,13 +217,10 @@ enum ClientState { CLIENT_IDLE, CLIENT_CONNECTED }; struct Client; -// We use systematic sampling method +// We use reservoir sampling method struct Sampling { - // sampling interval - double interval; - // cumulative value of interval, and the next point is the integer - // rounded up from this value. - double point; + // maximum number of samples + size_t max_samples; // number of samples seen, including discarded samples. size_t n; };