h2load: Reservoir sampling

This commit is contained in:
Tatsuhiro Tsujikawa 2017-08-14 20:25:02 +09:00
parent 4c53da6961
commit 83039ae2d4
2 changed files with 29 additions and 49 deletions

View File

@ -151,33 +151,12 @@ std::mt19937 gen(rd());
} // namespace
namespace {
void sampling_init(Sampling &smp, size_t total, size_t max_samples) {
void sampling_init(Sampling &smp, size_t max_samples) {
smp.n = 0;
if (total <= max_samples) {
smp.interval = 0.;
smp.point = 0.;
return;
}
smp.interval = static_cast<double>(total) / max_samples;
std::uniform_real_distribution<> dis(0., smp.interval);
smp.point = dis(gen);
smp.max_samples = max_samples;
}
} // namespace
namespace {
// Reports whether the current observation should be recorded.  A zero
// interval means every observation is picked; otherwise an
// observation is picked only when the running count smp.n equals
// smp.point rounded up to the next integer.
bool sampling_should_pick(Sampling &smp) {
  if (smp.interval == 0.) {
    return true;
  }
  const auto next_pick = ceil(smp.point);
  return smp.n == next_pick;
}
} // namespace
namespace {
// Moves the pick point forward by exactly one sampling interval.
void sampling_advance_point(Sampling &smp) {
  smp.point = smp.point + smp.interval;
}
} // namespace
namespace {
void writecb(struct ev_loop *loop, ev_io *w, int revents) {
auto client = static_cast<Client *>(w->data);
@ -361,10 +340,7 @@ Client::~Client() {
SSL_free(ssl);
}
if (sampling_should_pick(worker->client_smp)) {
sampling_advance_point(worker->client_smp);
worker->sample_client_stat(&cstat);
}
worker->sample_client_stat(&cstat);
++worker->client_smp.n;
}
@ -726,10 +702,7 @@ void Client::on_stream_close(int32_t stream_id, bool success, bool final) {
++worker->stats.req_failed;
}
if (sampling_should_pick(worker->request_times_smp)) {
sampling_advance_point(worker->request_times_smp);
worker->sample_req_stat(req_stat);
}
worker->sample_req_stat(req_stat);
// Count up in successful cases only
++worker->request_times_smp.n;
@ -1190,8 +1163,8 @@ Worker::Worker(uint32_t id, SSL_CTX *ssl_ctx, size_t req_todo, size_t nclients,
stats.req_stats.reserve(std::min(req_todo, max_samples));
stats.client_stats.reserve(std::min(nclients, max_samples));
sampling_init(request_times_smp, req_todo, max_samples);
sampling_init(client_smp, nclients, max_samples);
sampling_init(request_times_smp, max_samples);
sampling_init(client_smp, max_samples);
}
Worker::~Worker() {
@ -1224,14 +1197,28 @@ void Worker::run() {
ev_run(loop, 0);
}
namespace {
// Feeds one observation |s| into the reservoir |stats| controlled by
// |smp| (reservoir sampling).
//
// smp.n is incremented on every call, whether or not the observation
// is kept.  While |stats| holds fewer than smp.max_samples entries,
// |*s| is appended unconditionally.  Once the reservoir is full, an
// index is drawn uniformly from [0, smp.n - 1] and |*s| overwrites the
// entry at that index only if the index lies inside the reservoir;
// otherwise the observation is discarded.
//
// |Stats| must provide size(), push_back() and operator[]; |Stat|
// must be copyable into an element of |stats|.
//
// NOTE(review): the call sites in Client::~Client and
// Client::on_stream_close also increment the n field of the same
// Sampling object right after sampling.  Confirm whether counting
// each observation twice is intended, since it changes the
// replacement odds computed here.
template <typename Stats, typename Stat>
void sample(Sampling &smp, Stats &stats, Stat *s) {
  ++smp.n;
  if (stats.size() < smp.max_samples) {
    stats.push_back(*s);
    return;
  }
  // Draw the index as size_t.  smp.n is a size_t, so the default int
  // result type would narrow smp.n - 1 (an invalid range once n
  // exceeds INT_MAX) and would force a signed/unsigned comparison
  // against smp.max_samples below.
  auto d = std::uniform_int_distribution<size_t>(0, smp.n - 1);
  auto i = d(gen);
  if (i < smp.max_samples) {
    stats[i] = *s;
  }
}
} // namespace
void Worker::sample_req_stat(RequestStat *req_stat) {
stats.req_stats.push_back(*req_stat);
assert(stats.req_stats.size() <= max_samples);
sample(request_times_smp, stats.req_stats, req_stat);
}
void Worker::sample_client_stat(ClientStat *cstat) {
stats.client_stats.push_back(*cstat);
assert(stats.client_stats.size() <= max_samples);
sample(client_smp, stats.client_stats, cstat);
}
void Worker::report_progress() {
@ -1313,14 +1300,10 @@ process_time_stats(const std::vector<std::unique_ptr<Worker>> &workers) {
size_t nclient_times = 0;
for (const auto &w : workers) {
nrequest_times += w->stats.req_stats.size();
if (w->request_times_smp.interval != 0.) {
request_times_sampling = true;
}
request_times_sampling = w->request_times_smp.n > w->stats.req_stats.size();
nclient_times += w->stats.client_stats.size();
if (w->client_smp.interval != 0.) {
client_times_sampling = true;
}
client_times_sampling = w->client_smp.n > w->stats.client_stats.size();
}
std::vector<double> request_times;

View File

@ -217,13 +217,10 @@ enum ClientState { CLIENT_IDLE, CLIENT_CONNECTED };
struct Client;
// We use systematic sampling method
// We use the reservoir sampling method
struct Sampling {
// sampling interval
double interval;
// cumulative value of interval, and the next point is the integer
// rounded up from this value.
double point;
// maximum number of samples
size_t max_samples;
// number of samples seen, including discarded samples.
size_t n;
};