// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>
-
#include <mutex>
#include <vector>

@@ -81,10 +80,10 @@ class ModelState : public BackendModel {
  // onnx file, return in 'session' and 'allocator' the ORT session
  // and allocator.
  TRITONSERVER_Error* LoadModel(
-      const std::string& artifact_name,
+      const std::string& artifact_name, const std::string& instance_name,
      const TRITONSERVER_InstanceGroupKind instance_group_kind,
      const int32_t instance_group_device_id, std::string* model_path,
-      OrtSession** session, OrtAllocator** default_allocator,
+      std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
      cudaStream_t stream);

  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -101,6 +100,11 @@ class ModelState : public BackendModel {
  TRITONSERVER_Error* AutoCompleteIO(
      const char* key, const OnnxTensorInfoMap& io_infos);

+  TRITONSERVER_Error* GetSessionForGroup(
+      const std::string& group_name, std::shared_ptr<OrtSession>& session);
+  TRITONSERVER_Error* SetSessionForGroup(
+      const std::string& group_name,
+      const std::shared_ptr<OrtSession>& session);
+
  // Session options used when creating an ORT session.
  std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;
@@ -110,6 +114,17 @@ class ModelState : public BackendModel {
  // is specified both in the output section and state section, it indicates
  // that the backend must return the output state to the client too.
  std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+  // Indicates whether the ONNX Runtime session should be shared. This is a
+  // model-wide setting that applies to all instances, so it is stored in
+  // the model state.
+  bool share_session_;
+
+  // Maps an instance group name to its ONNX Runtime session. Only used
+  // when share_session is set to true in the model parameters. Note that
+  // share_session is a global model config option that cannot be set per
+  // instance group, so the user should enable it with care.
+  std::unordered_map<std::string, std::shared_ptr<OrtSession>>
+      group_instance_session_map_;
};

TRITONSERVER_Error*
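The two members added above carry the whole sharing scheme: the instance group name keys a cache of shared_ptr-wrapped sessions, and every instance of the group holds one more reference. A minimal stand-alone sketch of that ownership model (toy types, not the ORT API):

    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct FakeSession {};  // stand-in for OrtSession

    int main() {
      std::unordered_map<std::string, std::shared_ptr<FakeSession>> cache;
      cache["group0"] = std::make_shared<FakeSession>();
      {
        // Two "instances" of the same group reuse the cached session.
        auto a = cache["group0"];
        auto b = cache["group0"];
        std::cout << a.use_count() << "\n";  // 3: cache + a + b
      }
      cache.clear();  // last reference dropped, session destroyed here
      return 0;
    }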
@@ -188,7 +203,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model)
+    : BackendModel(triton_model), share_session_(false)
{
  // Create session options that will be cloned and used for each
  // instance when creating that instance's session.
@@ -338,20 +353,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
      }
    }
  }
-
-  // FIXME. Is it possible to share a single OrtSession across
-  // multiple instances? If so then should move loading and validation
-  // of the session to here instead of creating a session for each
-  // instance in ModelStateInstance::Create().
+
+  // This setting applies across all instance groups: when it is enabled,
+  // every instance within an instance group shares a single ORT session.
+  {
+    bool share_session = false;
+    triton::common::TritonJson::Value params;
+    if (ModelConfig().Find("parameters", &params)) {
+      THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
+          params, "share_session", &share_session, false));
+    }
+    share_session_ = share_session;
+  }
}

TRITONSERVER_Error*
ModelState::LoadModel(
-    const std::string& artifact_name,
+    const std::string& artifact_name, const std::string& instance_name,
    const TRITONSERVER_InstanceGroupKind instance_group_kind,
    const int32_t instance_group_device_id, std::string* model_path,
-    OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
+    std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
+    cudaStream_t stream)
{
+  // Get the instance group name for this instance.
+  std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
  // Find the ONNX file that describes the model itself. If the model
  // configuration doesn't have an explicit model file specified then
  // use the default name ("model.onnx").
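For reference, enabling the option in the model's config.pbtxt would look like `parameters { key: "share_session" value: { string_value: "true" } }`. A rough sketch of what the boolean parsing above amounts to; this is a hypothetical stand-in for the real TryParseModelStringParameter helper, assuming the usual TritonJson layout of model config parameters:

    // Illustrative only: "parameters" maps the key to an object whose
    // "string_value" member holds the textual boolean.
    TRITONSERVER_Error*
    ParseBoolParameter(
        triton::common::TritonJson::Value& params, const char* key,
        bool* value, const bool default_value)
    {
      *value = default_value;  // used when the parameter is absent
      triton::common::TritonJson::Value param;
      if (params.Find(key, &param)) {
        std::string str;
        RETURN_IF_ERROR(param.MemberAsString("string_value", &str));
        *value = (str == "true");
      }
      return nullptr;  // success
    }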
@@ -363,6 +389,10 @@ ModelState::LoadModel(
  *model_path = JoinPath(
      {RepositoryPath(), std::to_string(Version()), cc_model_filename});

+  // Get the default CPU allocator.
+  RETURN_IF_ORT_ERROR(
+      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+
  // If the model path is a directory then the actual model is
  // <dir>/model.onnx.
  {
@@ -373,6 +403,20 @@ ModelState::LoadModel(
    }
  }

+  // Check if we are sharing the session. If a session has already been
+  // created for this instance group, reuse it and return.
+  if (share_session_) {
+    TRITONSERVER_Error* lookup_err =
+        GetSessionForGroup(instance_group_name, session);
+    if (lookup_err == nullptr) {
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("Reusing session for group: ") + instance_group_name)
+              .c_str());
+      return nullptr;  // success
+    }
+    // No session has been cached for this group yet: release the lookup
+    // error and fall through to create a new session.
+    TRITONSERVER_ErrorDelete(lookup_err);
+  }
+
  {
    bool exists;
    RETURN_IF_ERROR(FileExists(*model_path, &exists));
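The cache-miss branch above leans on the backend-wide convention that a TRITONSERVER_Error* of nullptr means success, and that an error the caller chooses to swallow must be deleted. A toy illustration (stand-in types, not the Triton API):

    #include <iostream>
    #include <string>

    struct Error {  // stand-in for TRITONSERVER_Error
      std::string msg;
    };

    Error* Lookup(bool found) {
      return found ? nullptr : new Error{"No such group"};
    }

    int main() {
      if (Error* err = Lookup(false)) {
        // Cache miss: swallow the error and fall back to creating a session.
        std::cout << "miss: " << err->msg << "\n";
        delete err;  // a swallowed error must be freed, as LoadModel now does
      }
      return 0;
    }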
@@ -636,12 +680,22 @@ ModelState::LoadModel(
    glock.lock();
  }

-  RETURN_IF_ERROR(OnnxLoader::LoadSession(
-      true /* is_path */, *model_path, soptions, session));
+  {
+    // The raw session is allocated by ONNX Runtime here but is freed only
+    // when the last shared_ptr reference is released.
+    OrtSession* session_ptr;
+    RETURN_IF_ERROR(OnnxLoader::LoadSession(
+        true /* is_path */, *model_path, soptions, &session_ptr));

-  // get default cpu allocator
-  RETURN_IF_ORT_ERROR(
-      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+    session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
+
+    if (share_session_) {
+      // The session itself was created successfully, so failing to cache
+      // it for sharing is not a critical error.
+      LOG_IF_ERROR(
+          SetSessionForGroup(instance_group_name, session),
+          "Failed to map ORT session to the group for sharing");
+    }
+  }

  return nullptr;  // success
}
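SessionDeleter plays the same role here as any custom deleter on a C handle: release runs exactly once, when the last holder (the per-group cache or an instance) drops its copy. The same pattern, runnable with a standard C handle:

    #include <cstdio>
    #include <memory>

    int main() {
      // Wrap a raw C handle so its release runs when the last copy goes away.
      std::shared_ptr<FILE> file(
          std::fopen("example.txt", "w"),
          [](FILE* fp) {
            if (fp != nullptr) {
              std::fclose(fp);  // analogous to SessionDeleter on the session
            }
          });
      if (file != nullptr) {
        auto alias = file;  // a second holder, like a cached shared session
        std::fprintf(alias.get(), "hello\n");
      }  // 'alias' released; file stays open because 'file' still holds it
      return 0;
    }  // last holder released: fclose runs here, exactly once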
@@ -685,7 +739,7 @@ ModelState::AutoCompleteConfig()
  // Must cleanup 'session'. 'allocator' is default allocator which
  // is managed by ONNX Runtime so don't need to free/release
-  std::unique_ptr<OrtSession, SessionDeleter> session;
+  std::shared_ptr<OrtSession> session;
  OrtAllocator* default_allocator;
  std::string model_path;
  {
@@ -714,12 +768,9 @@ ModelState::AutoCompleteConfig()
    }
  }
#endif  // TRITON_ENABLE_GPU
-
-  OrtSession* sptr = nullptr;
  RETURN_IF_ERROR(LoadModel(
-      artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
-      nullptr));
-  session.reset(sptr);
+      artifact_name, "", kind, 0, &model_path, session, &default_allocator,
+      nullptr));
  }
  OnnxTensorInfoMap input_tensor_infos;
  RETURN_IF_ERROR(
@@ -881,6 +932,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
  return nullptr;  // success
}

+TRITONSERVER_Error*
+ModelState::GetSessionForGroup(
+    const std::string& group_name, std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name"));
+  {
+    auto session_entry = group_instance_session_map_.find(group_name);
+    RETURN_ERROR_IF_TRUE(
+        (session_entry == group_instance_session_map_.end()),
+        TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group"));
+
+    session = session_entry->second;
+  }
+  return nullptr;  // success
+}
+
+TRITONSERVER_Error*
+ModelState::SetSessionForGroup(
+    const std::string& group_name, const std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name"));
+
+  group_instance_session_map_[group_name] = session;
+  return nullptr;  // success
+}
+
//
// ModelInstanceState
//
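GetSessionForGroup and SetSessionForGroup access group_instance_session_map_ without a lock, which is safe only while instances of a model are created one at a time. If concurrent instance creation ever needs to be supported, the cache would need its own mutex; a sketch of that variant, hypothetical and not part of this change:

    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    struct Session {};  // stand-in for OrtSession

    class SessionCache {
     public:
      // Returns the cached session for 'group', or nullptr if none exists.
      std::shared_ptr<Session> Get(const std::string& group)
      {
        std::lock_guard<std::mutex> lock(mu_);
        auto it = map_.find(group);
        return (it == map_.end()) ? nullptr : it->second;
      }

      void Set(const std::string& group, std::shared_ptr<Session> session)
      {
        std::lock_guard<std::mutex> lock(mu_);
        map_[group] = std::move(session);
      }

     private:
      std::mutex mu_;
      std::unordered_map<std::string, std::shared_ptr<Session>> map_;
    };

    int main() {
      SessionCache cache;
      cache.Set("group0", std::make_shared<Session>());
      auto s = cache.Get("group0");  // lookup happens under the lock
      return (s != nullptr) ? 0 : 1;
    }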
@@ -967,7 +1050,7 @@ class ModelInstanceState : public BackendModelInstance {

  // Onnx Runtime variables that are used across runs on this
  // instance.
-  OrtSession* session_;
+  std::shared_ptr<OrtSession> session_;
  OrtAllocator* default_allocator_;
  OrtMemoryInfo* cuda_allocator_info_;
  const OrtMemoryInfo* cpu_allocator_info_;
@@ -1013,7 +1096,7 @@ ModelInstanceState::ModelInstanceState(
      io_binding_(nullptr), output_buffer_(nullptr)
{
  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
+      ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
      &default_allocator_, CudaStream()));

  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1026,7 +1109,7 @@ ModelInstanceState::ModelInstanceState(
      ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));

  THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
-      ort_api->CreateIoBinding(session_, &io_binding_));
+      ort_api->CreateIoBinding(session_.get(), &io_binding_));

  THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));

@@ -1114,9 +1197,6 @@ ModelInstanceState::~ModelInstanceState()
  ort_api->ReleaseRunOptions(runOptions_);
  ort_api->ReleaseIoBinding(io_binding_);
  ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
-  if (session_ != nullptr) {
-    OnnxLoader::UnloadSession(session_);
-  }
  // 'default_allocator_' is default allocator which is managed by ONNX
  // Runtime
}
@@ -1176,7 +1256,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
  if (*have_control) {
    OnnxTensorInfoMap input_tensor_infos;
    RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
    const auto& iit = input_tensor_infos.find(tensor_name);
    if (iit == input_tensor_infos.end()) {
      return TRITONSERVER_ErrorNew(
@@ -1233,7 +1313,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
  if (*have_control) {
    OnnxTensorInfoMap input_tensor_infos;
    RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
    const auto& iit = input_tensor_infos.find(tensor_name);
    if (iit == input_tensor_infos.end()) {
      return TRITONSERVER_ErrorNew(
@@ -1280,10 +1360,11 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
  std::set<std::string> input_tensor_names;
-  RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
+  RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));

  OnnxTensorInfoMap input_tensor_infos;
-  RETURN_IF_ERROR(InputInfos(session_, default_allocator_, input_tensor_infos));
+  RETURN_IF_ERROR(
+      InputInfos(session_.get(), default_allocator_, input_tensor_infos));

  if (input_tensor_infos.size() != expected_input_cnt) {
    return TRITONSERVER_ErrorNew(
@@ -1368,10 +1449,10 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateOutputs()
{
  std::set<std::string> output_tensor_names;
-  RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
+  RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));

  RETURN_IF_ERROR(
-      OutputInfos(session_, default_allocator_, output_tensor_infos_));
+      OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));

  triton::common::TritonJson::Value ios;
  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1765,7 +1846,7 @@ ModelInstanceState::OrtRun(
    const uint32_t response_count)
{
  RETURN_IF_ORT_ERROR(
-      ort_api->RunWithBinding(session_, runOptions_, io_binding_));
+      ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
  return nullptr;
}

@@ -2267,7 +2348,6 @@ ModelInstanceState::ReadOutputTensors(
      }
    }

-
  } else {
    char* output_buffer = nullptr;
    RETURN_IF_ORT_ERROR(