diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc index 32454b86b27..7e1261b8c61 100644 --- a/tensorflow_serving/model_servers/main.cc +++ b/tensorflow_serving/model_servers/main.cc @@ -250,6 +250,9 @@ int main(int argc, char** argv) { &options.num_request_iterations_for_warmup, "Number of times a request is iterated during warmup " "replay. This value is used only if > 0."), + tensorflow::Flag("num_warmup_threads", &options.num_warmup_threads, + "Number of threads for warmp up threads pool to use for model warmup." + "Default is 0, which means no thread pool is used."), tensorflow::Flag("version", &display_version, "Display version"), tensorflow::Flag( "monitoring_config_file", &options.monitoring_config_file, diff --git a/tensorflow_serving/model_servers/server.cc b/tensorflow_serving/model_servers/server.cc index 0655b3f33b5..9df9bc47d45 100644 --- a/tensorflow_serving/model_servers/server.cc +++ b/tensorflow_serving/model_servers/server.cc @@ -287,6 +287,9 @@ Status Server::BuildAndStart(const Options& server_options) { session_bundle_config.mutable_model_warmup_options() ->mutable_num_request_iterations() ->set_value(server_options.num_request_iterations_for_warmup); + session_bundle_config.mutable_model_warmup_options() + ->mutable_num_model_warmup_threads() + ->set_value(server_options.num_warmup_threads); } session_bundle_config.set_remove_unused_fields_from_bundle_metagraph( server_options.remove_unused_fields_from_bundle_metagraph); diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h index ac7828fd4d8..9f035889d84 100644 --- a/tensorflow_serving/model_servers/server.h +++ b/tensorflow_serving/model_servers/server.h @@ -91,6 +91,7 @@ class Server { bool enable_model_warmup = true; // This value is used only if > 0. tensorflow::int32 num_request_iterations_for_warmup = 0; + tensorflow::int32 num_warmup_threads = 0; tensorflow::string monitoring_config_file; // Tensorflow session run options. bool enforce_session_run_timeout = true; diff --git a/tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc b/tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc index 4a889447dd3..720f2564633 100644 --- a/tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc +++ b/tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc @@ -99,12 +99,12 @@ absl::Status RunSavedModelWarmup( int num_model_warmup_threads = model_warmup_options.has_num_model_warmup_threads() - ? std::max(model_warmup_options.num_model_warmup_threads().value(), 1) - : 1; + ? std::max(model_warmup_options.num_model_warmup_threads().value(), 0) + : 0; std::unique_ptr tf_record_file_reader; absl::Status status; int num_warmup_records = 0; - if (num_model_warmup_threads <= 1) { + if (num_model_warmup_threads < 1) { tf_record_file_reader.reset( new tensorflow::io::SequentialRecordReader(tf_record_file.get())); tstring record; diff --git a/tensorflow_serving/servables/tensorflow/session_bundle_config.proto b/tensorflow_serving/servables/tensorflow/session_bundle_config.proto index f4e79c2630a..fe109ce6b53 100644 --- a/tensorflow_serving/servables/tensorflow/session_bundle_config.proto +++ b/tensorflow_serving/servables/tensorflow/session_bundle_config.proto @@ -10,7 +10,8 @@ import "tensorflow/core/protobuf/named_tensor.proto"; message ModelWarmupOptions { // Number of times a request is iterated during warmup replay. By default 1. google.protobuf.Int32Value num_request_iterations = 1; - // The number of threads to parallel execute warm up queries. By default 1. + // The number of threads to parallel execute warm up queries. By default 0. + // which means that no thread pool will be used. google.protobuf.Int32Value num_model_warmup_threads = 2; // Model name. string model_name = 3;