diff --git a/src/shrpx_downstream.cc b/src/shrpx_downstream.cc index 2f55e5d6..d9325699 100644 --- a/src/shrpx_downstream.cc +++ b/src/shrpx_downstream.cc @@ -127,7 +127,7 @@ Downstream::Downstream(Upstream *upstream, MemchunkPool *mcpool, request_http2_expect_body_(false), chunked_response_(false), response_connection_close_(false), response_header_key_prev_(false), response_trailer_key_prev_(false), expect_final_response_(false), - request_pending_(false), request_headers_dirty_(false) { + request_pending_(false) { ev_timer_init(&upstream_rtimer_, &upstream_rtimeoutcb, 0., get_config()->stream_read_timeout); @@ -1218,12 +1218,4 @@ bool Downstream::can_detach_downstream_connection() const { !response_connection_close_; } -void Downstream::set_request_headers_dirty(bool f) { - request_headers_dirty_ = f; -} - -bool Downstream::get_request_headers_dirty() const { - return request_headers_dirty_; -} - } // namespace shrpx diff --git a/src/shrpx_downstream.h b/src/shrpx_downstream.h index a12359fa..9139eb5d 100644 --- a/src/shrpx_downstream.h +++ b/src/shrpx_downstream.h @@ -127,8 +127,6 @@ public: void append_last_request_header_value(const char *data, size_t len); // Empties request headers. void clear_request_headers(); - void set_request_headers_dirty(bool f); - bool get_request_headers_dirty() const; size_t get_request_headers_sum() const; @@ -457,8 +455,6 @@ private: // has not been established or should be checked before use; // currently used only with HTTP/2 connection. bool request_pending_; - // true if we need to execute index_request_headers() - bool request_headers_dirty_; }; } // namespace shrpx diff --git a/src/shrpx_mruby.cc b/src/shrpx_mruby.cc index aa3135dd..8e32387d 100644 --- a/src/shrpx_mruby.cc +++ b/src/shrpx_mruby.cc @@ -43,42 +43,49 @@ MRubyContext::MRubyContext(mrb_state *mrb, RProc *on_request_proc, MRubyContext::~MRubyContext() { mrb_close(mrb_); } -int MRubyContext::run_on_request_proc(Downstream *downstream) { - if (!on_request_proc_) { +namespace { +int run_request_proc(mrb_state *mrb, Downstream *downstream, RProc *proc) { + if (!proc) { return 0; } - mrb_->ud = downstream; + MRubyAssocData data{downstream}; + + mrb->ud = &data; int rv = 0; - auto ai = mrb_gc_arena_save(mrb_); + auto ai = mrb_gc_arena_save(mrb); - auto res = mrb_run(mrb_, on_request_proc_, mrb_top_self(mrb_)); + auto res = mrb_run(mrb, proc, mrb_top_self(mrb)); (void)res; - if (mrb_->exc) { + if (mrb->exc) { rv = -1; auto error = - mrb_str_ptr(mrb_funcall(mrb_, mrb_obj_value(mrb_->exc), "inspect", 0)); + mrb_str_ptr(mrb_funcall(mrb, mrb_obj_value(mrb->exc), "inspect", 0)); LOG(ERROR) << "Exception caught while executing mruby code: " << error->as.heap.ptr; - mrb_->exc = 0; + mrb->exc = 0; } - mrb_->ud = nullptr; + mrb->ud = nullptr; - mrb_gc_arena_restore(mrb_, ai); + mrb_gc_arena_restore(mrb, ai); - if (downstream->get_request_headers_dirty()) { - downstream->set_request_headers_dirty(false); + if (data.request_headers_dirty) { downstream->index_request_headers(); } return rv; } +} // namespace -int run_on_response_proc(Downstream *downstream) { +int MRubyContext::run_on_request_proc(Downstream *downstream) { + return run_request_proc(mrb_, downstream, on_request_proc_); +} + +int MRubyContext::run_on_response_proc(Downstream *downstream) { // TODO not implemented yet return 0; } diff --git a/src/shrpx_mruby.h b/src/shrpx_mruby.h index 6e0bc57e..6494d8f0 100644 --- a/src/shrpx_mruby.h +++ b/src/shrpx_mruby.h @@ -52,6 +52,11 @@ private: RProc *on_response_proc_; }; +struct MRubyAssocData { + Downstream *downstream; + bool request_headers_dirty; +}; + RProc *compile(mrb_state *mrb, const char *filename); std::unique_ptr create_mruby_context(); diff --git a/src/shrpx_mruby_module.cc b/src/shrpx_mruby_module.cc index fabb1c5d..9e0be077 100644 --- a/src/shrpx_mruby_module.cc +++ b/src/shrpx_mruby_module.cc @@ -26,21 +26,46 @@ #include #include +#include +#include #include "shrpx_downstream.h" +#include "shrpx_mruby.h" #include "util.h" namespace shrpx { namespace mruby { +namespace { +mrb_value create_headers_hash(mrb_state *mrb, const Headers &headers) { + auto hash = mrb_hash_new(mrb); + + for (auto &hd : headers) { + if (hd.name.empty() || hd.name[0] == ':') { + continue; + } + auto key = mrb_str_new(mrb, hd.name.c_str(), hd.name.size()); + auto ary = mrb_hash_get(mrb, hash, key); + if (mrb_nil_p(ary)) { + ary = mrb_ary_new(mrb); + mrb_hash_set(mrb, hash, key, ary); + } + mrb_ary_push(mrb, ary, mrb_str_new(mrb, hd.value.c_str(), hd.value.size())); + } + + return hash; +} +} // namespace + namespace { mrb_value request_init(mrb_state *mrb, mrb_value self) { return self; } } // namespace namespace { mrb_value request_get_path(mrb_state *mrb, mrb_value self) { - auto downstream = static_cast(mrb->ud); + auto data = static_cast(mrb->ud); + auto downstream = data->downstream; auto &path = downstream->get_request_path(); return mrb_str_new(mrb, path.c_str(), path.size()); @@ -49,7 +74,8 @@ mrb_value request_get_path(mrb_state *mrb, mrb_value self) { namespace { mrb_value request_set_path(mrb_state *mrb, mrb_value self) { - auto downstream = static_cast(mrb->ud); + auto data = static_cast(mrb->ud); + auto downstream = data->downstream; const char *path; mrb_int pathlen; @@ -66,100 +92,51 @@ mrb_value request_set_path(mrb_state *mrb, mrb_value self) { namespace { mrb_value request_get_headers(mrb_state *mrb, mrb_value self) { - auto headers = mrb_iv_get(mrb, self, mrb_intern_lit(mrb, "RequestHeaders")); - if (mrb_nil_p(headers)) { - auto module = mrb_module_get(mrb, "Nghttpx"); - auto headers_class = mrb_class_get_under(mrb, module, "RequestHeaders"); - headers = mrb_obj_new(mrb, headers_class, 0, nullptr); - mrb_iv_set(mrb, self, mrb_intern_lit(mrb, "RequestHeaders"), headers); - } - return headers; + auto data = static_cast(mrb->ud); + auto downstream = data->downstream; + return create_headers_hash(mrb, downstream->get_request_headers()); } } // namespace namespace { -mrb_value headers_init(mrb_state *mrb, mrb_value self) { return self; } -} // namespace +mrb_value request_set_header(mrb_state *mrb, mrb_value self) { + auto data = static_cast(mrb->ud); + auto downstream = data->downstream; -namespace { -mrb_value request_headers_get(mrb_state *mrb, mrb_value self) { - auto downstream = static_cast(mrb->ud); - - mrb_value key; - mrb_get_args(mrb, "o", &key); - - key = mrb_funcall(mrb, key, "downcase", 0); - - if (RSTRING_LEN(key) == 0) { - return key; - } - - auto hd = downstream->get_request_header( - std::string(RSTRING_PTR(key), RSTRING_LEN(key))); - - if (hd == nullptr) { - return mrb_nil_value(); - } - - return mrb_str_new(mrb, hd->value.c_str(), hd->value.size()); -} -} // namespace - -namespace { -mrb_value request_headers_set(mrb_state *mrb, mrb_value self, bool repl) { - auto downstream = static_cast(mrb->ud); - - mrb_value key, value; - mrb_get_args(mrb, "oo", &key, &value); - - key = mrb_funcall(mrb, key, "downcase", 0); + mrb_value key, values; + mrb_get_args(mrb, "oo", &key, &values); if (RSTRING_LEN(key) == 0) { mrb_raise(mrb, E_RUNTIME_ERROR, "empty key is not allowed"); } - if (repl) { - for (auto &hd : downstream->get_request_headers()) { - if (util::streq(std::begin(hd.name), hd.name.size(), RSTRING_PTR(key), - RSTRING_LEN(key))) { - hd.name = ""; - } + key = mrb_funcall(mrb, key, "downcase", 0); + + // making name empty will effectively delete header fields + for (auto &hd : downstream->get_request_headers()) { + if (util::streq(std::begin(hd.name), hd.name.size(), RSTRING_PTR(key), + RSTRING_LEN(key))) { + hd.name = ""; } } - downstream->add_request_header( - std::string(RSTRING_PTR(key), RSTRING_LEN(key)), - std::string(RSTRING_PTR(value), RSTRING_LEN(value))); + if (mrb_obj_is_instance_of(mrb, values, mrb->array_class)) { + auto n = mrb_ary_len(mrb, values); + for (int i = 0; i < n; ++i) { + auto value = mrb_ary_entry(values, i); + downstream->add_request_header( + std::string(RSTRING_PTR(key), RSTRING_LEN(key)), + std::string(RSTRING_PTR(value), RSTRING_LEN(value))); + } + } else { + downstream->add_request_header( + std::string(RSTRING_PTR(key), RSTRING_LEN(key)), + std::string(RSTRING_PTR(values), RSTRING_LEN(values))); + } - downstream->set_request_headers_dirty(true); + data->request_headers_dirty = true; - return key; -} -} // namespace - -namespace { -mrb_value request_headers_set(mrb_state *mrb, mrb_value self) { - return request_headers_set(mrb, self, true); -} -} // namespace - -namespace { -mrb_value request_headers_add(mrb_state *mrb, mrb_value self) { - return request_headers_set(mrb, self, false); -} -} // namespace - -namespace { -void init_headers_class(mrb_state *mrb, RClass *module, const char *name, - mrb_func_t get, mrb_func_t set, mrb_func_t add) { - auto headers_class = - mrb_define_class_under(mrb, module, name, mrb->object_class); - - mrb_define_method(mrb, headers_class, "initialize", headers_init, - MRB_ARGS_NONE()); - mrb_define_method(mrb, headers_class, "get", get, MRB_ARGS_REQ(1)); - mrb_define_method(mrb, headers_class, "set", set, MRB_ARGS_REQ(2)); - mrb_define_method(mrb, headers_class, "add", add, MRB_ARGS_REQ(2)); + return mrb_nil_value(); } } // namespace @@ -176,15 +153,30 @@ void init_request_class(mrb_state *mrb, RClass *module) { MRB_ARGS_REQ(1)); mrb_define_method(mrb, request_class, "headers", request_get_headers, MRB_ARGS_NONE()); + mrb_define_method(mrb, request_class, "set_header", request_set_header, + MRB_ARGS_REQ(2)); +} +} // namespace - init_headers_class(mrb, module, "RequestHeaders", request_headers_get, - request_headers_set, request_headers_add); +namespace { +mrb_value run(mrb_state *mrb, mrb_value self) { + mrb_value b; + mrb_get_args(mrb, "&", &b); + + auto module = mrb_module_get(mrb, "Nghttpx"); + auto request_class = mrb_class_get_under(mrb, module, "Request"); + auto request = mrb_obj_new(mrb, request_class, 0, nullptr); + + std::array args{{request}}; + return mrb_yield_argv(mrb, b, args.size(), args.data()); } } // namespace void init_module(mrb_state *mrb) { auto module = mrb_define_module(mrb, "Nghttpx"); + mrb_define_class_method(mrb, module, "run", run, MRB_ARGS_BLOCK()); + init_request_class(mrb, module); }