Commit e000a48

samgoel01 authored and quic-suppugun committed
Add support for sharing an ORT session
For every instance in a model instance group, a new ORT session is created. This change adds support for sharing a single session across all instances in an instance group. The feature can be enabled by setting 'share_session' to true in the "parameters" section of the Triton model config:

parameters [
  .....
  {
    key: "share_session"
    value: { string_value: "true" }
  }
]

This is a global parameter and cannot be set per instance group; the user should determine whether it makes sense for their setup.
1 parent abc3ee7 commit e000a48
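
In essence, the commit keys a cache of shared_ptr-managed sessions by instance group name, so every instance in a group receives the same underlying OrtSession. Below is a minimal, self-contained sketch of that pattern, not the backend's actual code: Session, SessionCache, and GetOrCreate are illustrative stand-ins (the real diff uses OrtSession with a custom deleter plus the GetSessionForGroup/SetSessionForGroup helpers shown further down).

#include <memory>
#include <string>
#include <unordered_map>

struct Session {};  // stand-in for OrtSession

class SessionCache {
 public:
  // Return the session cached for 'group', creating one on first use.
  // Every caller asking for the same group gets the same shared session.
  std::shared_ptr<Session> GetOrCreate(const std::string& group)
  {
    auto it = cache_.find(group);
    if (it != cache_.end()) {
      return it->second;  // reuse the session already loaded for this group
    }
    auto session = std::make_shared<Session>();  // the real code loads the model here
    cache_[group] = session;
    return session;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Session>> cache_;
};

Because the map stores shared_ptr copies, the session outlives any individual instance and is destroyed only once the map entry and every instance holding it are gone.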

File tree

3 files changed: +137 -36 lines changed

src/onnxruntime.cc

+116 -36
@@ -25,7 +25,6 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
-
 #include <mutex>
 #include <vector>

@@ -81,10 +80,10 @@ class ModelState : public BackendModel {
   // onnx file, return in 'session' and 'allocator' the ORT session
   // and allocator.
   TRITONSERVER_Error* LoadModel(
-      const std::string& artifact_name,
+      const std::string& artifact_name, const std::string& instance_name,
       const TRITONSERVER_InstanceGroupKind instance_group_kind,
       const int32_t instance_group_device_id, std::string* model_path,
-      OrtSession** session, OrtAllocator** default_allocator,
+      std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
       cudaStream_t stream);

   const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -101,6 +100,11 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* AutoCompleteIO(
       const char* key, const OnnxTensorInfoMap& io_infos);

+  TRITONSERVER_Error* GetSessionForGroup(
+      const std::string& group_name, std::shared_ptr<OrtSession>& session);
+  TRITONSERVER_Error* SetSessionForGroup(
+      const std::string& group_name, const std::shared_ptr<OrtSession>& session);
+
   // Session options used when creating a ORT session.
   std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;

@@ -110,6 +114,17 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+  // Indicate if an onnxrt session should be shared or not. This is a model
+  // global and applies to all instances. So, storing it in the model state
+  bool share_session_;
+
+  // maintain a map of group id to onnx_rt session. This is only useful if
+  // share_session is set to true in parameters. share_session is a global model
+  // config and the user should be careful when setting this. There is no way to
+  // set this per instance group.
+  std::unordered_map<std::string, std::shared_ptr<OrtSession>>
+      groupInstanceSessionMap_;
 };

 TRITONSERVER_Error*
@@ -188,7 +203,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 }

 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model)
+    : BackendModel(triton_model), share_session_(false)
 {
   // Create session options that will be cloned and used for each
   // instance when creating that instance's session.
@@ -338,20 +353,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       }
     }
   }
-
-  // FIXME. Is it possible to share a single OrtSession across
-  // multiple instances? If so then should move loading and validation
-  // of the session to here instead of creating a session for each
-  // instance in ModelStateInstance::Create().
+
+  // This setting will apply across multiple instance groups.
+  // If this value is set, all instances within an instance group will share
+  // the ort session.
+  {
+    bool share_session = false;
+    triton::common::TritonJson::Value params;
+    if (ModelConfig().Find("parameters", &params)) {
+      THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
+          params, "share_session", &share_session, false));
+    }
+    share_session_ = share_session;
+  }
 }

 TRITONSERVER_Error*
 ModelState::LoadModel(
-    const std::string& artifact_name,
+    const std::string& artifact_name, const std::string& instance_name,
     const TRITONSERVER_InstanceGroupKind instance_group_kind,
     const int32_t instance_group_device_id, std::string* model_path,
-    OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
+    std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
+    cudaStream_t stream)
 {
+  // Get the group name for the instance
+  std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
   // Find the ONNX file that describes the model itself. If the model
   // configuration doesn't have an explicit model file specified then
   // use the default name ("model.onnx").
@@ -363,6 +389,10 @@ ModelState::LoadModel(
   *model_path = JoinPath(
       {RepositoryPath(), std::to_string(Version()), cc_model_filename});

+  // get default cpu allocator
+  RETURN_IF_ORT_ERROR(
+      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+
   // If the model path is a directory then the actual model is
   // <dir>/model.onnx.
   {
@@ -373,6 +403,20 @@ ModelState::LoadModel(
     }
   }

+  // Check if we are sharing the session. If so, get the session pointer and
+  // return.
+  if (share_session_) {
+    if (GetSessionForGroup(instance_group_name, session) == nullptr) {
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("Reusing session for group: ") + instance_group_name)
+              .c_str());
+      // Return the session
+      return nullptr;
+    }
+    // In case of error, carry on and create a new session below
+  }
+
   {
     bool exists;
     RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -636,12 +680,22 @@ ModelState::LoadModel(
     glock.lock();
   }

-  RETURN_IF_ERROR(OnnxLoader::LoadSession(
-      true /* is_path */, *model_path, soptions, session));
+  {
+    // This will be allocated by OnnxRT here but will be freed when the last
+    // instance of the shared_ptr is released
+    OrtSession* session_ptr;
+    RETURN_IF_ERROR(OnnxLoader::LoadSession(
+        true /* is_path */, *model_path, soptions, &session_ptr));

-  // get default cpu allocator
-  RETURN_IF_ORT_ERROR(
-      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+    session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
+
+    if (share_session_) {
+      // The session was created fine; this is not a critical error
+      LOG_IF_ERROR(
+          SetSessionForGroup(instance_group_name, session),
+          "Failed to map ort session to the group for sharing");
+    }
+  }

   return nullptr;  // success
 }
@@ -685,7 +739,7 @@ ModelState::AutoCompleteConfig()

   // Must cleanup 'session'. 'allocator' is default allocator which
   // is managed by ONNX Runtime so don't need to free/release
-  std::unique_ptr<OrtSession, SessionDeleter> session;
+  std::shared_ptr<OrtSession> session;
   OrtAllocator* default_allocator;
   std::string model_path;
   {
@@ -714,12 +768,9 @@ ModelState::AutoCompleteConfig()
       }
     }
 #endif  // TRITON_ENABLE_GPU
-
-    OrtSession* sptr = nullptr;
     RETURN_IF_ERROR(LoadModel(
-        artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
-        nullptr));
-    session.reset(sptr);
+        artifact_name, "", kind, 0, &model_path,
+        session, &default_allocator, nullptr));
   }
   OnnxTensorInfoMap input_tensor_infos;
   RETURN_IF_ERROR(
@@ -881,6 +932,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
   return nullptr;  // success
 }

+TRITONSERVER_Error*
+ModelState::GetSessionForGroup(
+    const std::string& group_name, std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name"));
+  {
+    std::unordered_map<std::string, std::shared_ptr<OrtSession>>::iterator
+        sessionEntry;
+    sessionEntry = groupInstanceSessionMap_.find(group_name);
+    RETURN_ERROR_IF_TRUE(
+        (sessionEntry == groupInstanceSessionMap_.end()),
+        TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group"));
+
+    session = sessionEntry->second;
+  }
+  return nullptr;
+}
+
+TRITONSERVER_Error*
+ModelState::SetSessionForGroup(
+    const std::string& group_name, const std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name"));
+
+  groupInstanceSessionMap_[group_name] = session;
+  return nullptr;
+}
+
 //
 // ModelInstanceState
 //
@@ -967,7 +1050,7 @@ class ModelInstanceState : public BackendModelInstance {

   // Onnx Runtime variables that are used across runs on this
   // instance.
-  OrtSession* session_;
+  std::shared_ptr<OrtSession> session_;
   OrtAllocator* default_allocator_;
   OrtMemoryInfo* cuda_allocator_info_;
   const OrtMemoryInfo* cpu_allocator_info_;
@@ -1013,7 +1096,7 @@ ModelInstanceState::ModelInstanceState(
       io_binding_(nullptr), output_buffer_(nullptr)
 {
   THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
+      ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
       &default_allocator_, CudaStream()));

   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1026,7 +1109,7 @@
       ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));

   THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
-      ort_api->CreateIoBinding(session_, &io_binding_));
+      ort_api->CreateIoBinding(session_.get(), &io_binding_));

   THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));

@@ -1114,9 +1197,6 @@ ModelInstanceState::~ModelInstanceState()
   ort_api->ReleaseRunOptions(runOptions_);
   ort_api->ReleaseIoBinding(io_binding_);
   ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
-  if (session_ != nullptr) {
-    OnnxLoader::UnloadSession(session_);
-  }
   // 'default_allocator_' is default allocator which is managed by ONNX
   // Runtime
 }
@@ -1176,7 +1256,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
   if (*have_control) {
     OnnxTensorInfoMap input_tensor_infos;
     RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
     const auto& iit = input_tensor_infos.find(tensor_name);
     if (iit == input_tensor_infos.end()) {
       return TRITONSERVER_ErrorNew(
@@ -1233,7 +1313,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
   if (*have_control) {
     OnnxTensorInfoMap input_tensor_infos;
     RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
     const auto& iit = input_tensor_infos.find(tensor_name);
     if (iit == input_tensor_infos.end()) {
       return TRITONSERVER_ErrorNew(
@@ -1280,10 +1360,11 @@ TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
   std::set<std::string> input_tensor_names;
-  RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
+  RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));

   OnnxTensorInfoMap input_tensor_infos;
-  RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
+  RETURN_IF_ERROR(
+      InputInfos(session_.get(), default_allocator_, input_tensor_infos));

   if (input_tensor_infos.size() != expected_input_cnt) {
     return TRITONSERVER_ErrorNew(
@@ -1368,10 +1449,10 @@ TRITONSERVER_Error*
 ModelInstanceState::ValidateOutputs()
 {
   std::set<std::string> output_tensor_names;
-  RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
+  RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));

   RETURN_IF_ERROR(
-      OutputInfos(session_, default_allocator_, output_tensor_infos_));
+      OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));

   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1765,7 +1846,7 @@ ModelInstanceState::OrtRun(
     const uint32_t response_count)
 {
   RETURN_IF_ORT_ERROR(
-      ort_api->RunWithBinding(session_, runOptions_, io_binding_));
+      ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
   return nullptr;
 }

@@ -2267,7 +2348,6 @@ ModelInstanceState::ReadOutputTensors(
       }
     }

-
     } else {
       char* output_buffer = nullptr;
       RETURN_IF_ORT_ERROR(
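
The ownership change in this file hinges on wrapping the raw OrtSession* in a std::shared_ptr with a custom deleter, so a session shared by several instances is unloaded exactly once, when the last holder releases it. A minimal sketch of that idiom, with Handle, CreateHandle, and ReleaseHandle as illustrative stand-ins for OrtSession, OnnxLoader::LoadSession, and OnnxLoader::UnloadSession:

#include <memory>

struct Handle {};                              // stand-in for OrtSession
Handle* CreateHandle() { return new Handle; }  // stand-in for session creation
void ReleaseHandle(Handle* h) { delete h; }    // stand-in for session unload

int main()
{
  // Hand the raw pointer to shared_ptr together with its release function.
  std::shared_ptr<Handle> first(CreateHandle(), ReleaseHandle);
  std::shared_ptr<Handle> second = first;  // a second instance sharing the session

  first.reset();   // nothing released yet; 'second' still holds the session
  second.reset();  // last holder gone: ReleaseHandle runs exactly once here
  return 0;
}

This is why the destructor above no longer calls OnnxLoader::UnloadSession explicitly: the deleter attached to the shared_ptr handles it.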

src/onnxruntime_utils.cc

+17
@@ -493,5 +493,22 @@ CompareDimsSupported(
   return nullptr;  // success
 }

+std::string
+GetInstanceGroupName(
+    const std::string& model_name, const std::string& instance_name)
+{
+  std::regex groupNameRegex('(' + model_name + '_' + "[0-9]" + ')');
+  std::smatch groupName;
+
+  if (model_name.empty() || instance_name.empty()) {
+    return "";
+  }
+
+  if (std::regex_search(instance_name, groupName, groupNameRegex)) {
+    return groupName.str(1);
+  }
+
+  return "";
+}

 }}} // namespace triton::backend::onnxruntime
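
For illustration, here is a self-contained re-creation of the helper above with a small driver. The instance names are hypothetical, and the helper assumes names that embed the group as "<model>_<digit>" (e.g. "resnet_0_1" for instance 1 of group "resnet_0"); note the [0-9] pattern matches a single digit, so it assumes fewer than ten instance groups.

#include <iostream>
#include <regex>
#include <string>

// Re-creation for illustration: extract the "<model>_<digit>" instance group
// prefix from a full instance name, or return "" when there is no match.
std::string GetInstanceGroupName(
    const std::string& model_name, const std::string& instance_name)
{
  if (model_name.empty() || instance_name.empty()) {
    return "";
  }
  std::regex group_regex("(" + model_name + "_[0-9])");
  std::smatch match;
  if (std::regex_search(instance_name, match, group_regex)) {
    return match.str(1);
  }
  return "";
}

int main()
{
  // Hypothetical instance names following a "<model>_<group>_<instance>" scheme.
  std::cout << GetInstanceGroupName("resnet", "resnet_0_1") << "\n";  // resnet_0
  std::cout << GetInstanceGroupName("resnet", "resnet_0") << "\n";    // resnet_0
  std::cout << GetInstanceGroupName("resnet", "densenet_0") << "\n";  // (empty)
}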

src/onnxruntime_utils.h

+4
@@ -27,6 +27,7 @@
 #pragma once

 #include <onnxruntime_c_api.h>
+#include <regex>
 #include <set>
 #include <string>
 #include <unordered_map>
@@ -149,4 +150,7 @@ TRITONSERVER_Error* CompareDimsSupported(
     const std::vector<int64_t>& model_shape, const std::vector<int64_t>& dims,
     const int max_batch_size, const bool compare_exact);

+std::string GetInstanceGroupName(
+    const std::string& model_name, const std::string& instance_name);
+
 }}} // namespace triton::backend::onnxruntime
