@@ -14,53 +14,33 @@ See the License for the specific language governing permissions and
14
14
limitations under the License.
15
15
*/
16
16
17
+ use std:: collections:: HashMap ;
17
18
use std:: io:: { IsTerminal , Write } ;
18
19
19
20
use hyperlight_common:: flatbuffer_wrappers:: function_types:: { ParameterValue , ReturnValue } ;
20
21
use termcolor:: { Color , ColorChoice , ColorSpec , StandardStream , WriteColor } ;
21
22
use tracing:: { instrument, Span } ;
22
23
23
- use super :: { ExtraAllowedSyscall , FunctionsMap } ;
24
- use crate :: func:: host_functions:: HostFunctionDefinition ;
24
+ use super :: ExtraAllowedSyscall ;
25
25
use crate :: func:: HyperlightFunction ;
26
26
use crate :: HyperlightError :: HostFunctionNotFound ;
27
27
use crate :: { new_error, Result } ;
28
28
29
- type HostFunctionDetails = Option < Vec < HostFunctionDefinition > > ;
30
-
31
29
#[ derive( Default , Clone ) ]
32
30
/// A Wrapper around details of functions exposed by the Host
33
31
pub struct HostFuncsWrapper {
34
- functions_map : FunctionsMap ,
35
- function_details : HostFunctionDetails ,
32
+ functions_map : HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
36
33
}
37
34
38
35
impl HostFuncsWrapper {
39
- #[ instrument( skip_all, parent = Span :: current( ) , level = "Trace" ) ]
40
- fn get_host_funcs ( & self ) -> & FunctionsMap {
41
- & self . functions_map
42
- }
43
- #[ instrument( skip_all, parent = Span :: current( ) , level = "Trace" ) ]
44
- fn get_host_funcs_mut ( & mut self ) -> & mut FunctionsMap {
45
- & mut self . functions_map
46
- }
47
- #[ instrument( skip_all, parent = Span :: current( ) , level = "Trace" ) ]
48
- fn get_host_func_details ( & self ) -> & HostFunctionDetails {
49
- & self . function_details
50
- }
51
- #[ instrument( skip_all, parent = Span :: current( ) , level = "Trace" ) ]
52
- fn get_host_func_details_mut ( & mut self ) -> & mut HostFunctionDetails {
53
- & mut self . function_details
54
- }
55
-
56
36
/// Register a host function with the sandbox.
57
37
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
58
38
pub ( crate ) fn register_host_function (
59
39
& mut self ,
60
- hfd : & HostFunctionDefinition ,
40
+ name : String ,
61
41
func : HyperlightFunction ,
62
42
) -> Result < ( ) > {
63
- register_host_function_helper ( self , hfd , func, None )
43
+ register_host_function_helper ( self , name , func, None )
64
44
}
65
45
66
46
/// Register a host function with the sandbox, with a list of extra syscalls
@@ -69,11 +49,11 @@ impl HostFuncsWrapper {
69
49
#[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
70
50
pub ( crate ) fn register_host_function_with_syscalls (
71
51
& mut self ,
72
- hfd : & HostFunctionDefinition ,
52
+ name : String ,
73
53
func : HyperlightFunction ,
74
54
extra_allowed_syscalls : Vec < ExtraAllowedSyscall > ,
75
55
) -> Result < ( ) > {
76
- register_host_function_helper ( self , hfd , func, Some ( extra_allowed_syscalls) )
56
+ register_host_function_helper ( self , name , func, Some ( extra_allowed_syscalls) )
77
57
}
78
58
79
59
/// Assuming a host function called `"HostPrint"` exists, and takes a
@@ -84,7 +64,7 @@ impl HostFuncsWrapper {
84
64
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
85
65
pub ( super ) fn host_print ( & mut self , msg : String ) -> Result < i32 > {
86
66
let res = call_host_func_impl (
87
- self . get_host_funcs ( ) ,
67
+ & self . functions_map ,
88
68
"HostPrint" ,
89
69
vec ! [ ParameterValue :: String ( msg) ] ,
90
70
) ?;
@@ -104,56 +84,40 @@ impl HostFuncsWrapper {
104
84
name : & str ,
105
85
args : Vec < ParameterValue > ,
106
86
) -> Result < ReturnValue > {
107
- call_host_func_impl ( self . get_host_funcs ( ) , name, args)
108
- }
109
-
110
- /// Insert a host function into the list of registered host functions.
111
- pub ( super ) fn insert_host_function ( & mut self , host_function : HostFunctionDefinition ) {
112
- match & mut self . function_details {
113
- Some ( host_functions) => host_functions. push ( host_function) ,
114
- None => {
115
- let host_functions = Vec :: from ( & [ host_function] ) ;
116
- self . function_details = Some ( host_functions) ;
117
- }
118
- }
87
+ call_host_func_impl ( & self . functions_map , name, args)
119
88
}
120
89
}
121
90
122
91
fn register_host_function_helper (
123
92
self_ : & mut HostFuncsWrapper ,
124
- hfd : & HostFunctionDefinition ,
93
+ name : String ,
125
94
func : HyperlightFunction ,
126
95
extra_allowed_syscalls : Option < Vec < ExtraAllowedSyscall > > ,
127
96
) -> Result < ( ) > {
128
97
if let Some ( _syscalls) = extra_allowed_syscalls {
129
98
#[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
130
- self_
131
- . get_host_funcs_mut ( )
132
- . insert ( hfd. function_name . to_string ( ) , func, Some ( _syscalls) ) ;
99
+ self_. functions_map . insert ( name, ( func, Some ( _syscalls) ) ) ;
133
100
134
101
#[ cfg( not( all( feature = "seccomp" , target_os = "linux" ) ) ) ]
135
102
return Err ( new_error ! (
136
103
"Extra syscalls are only supported on Linux with seccomp"
137
104
) ) ;
138
105
} else {
139
- self_
140
- . get_host_funcs_mut ( )
141
- . insert ( hfd. function_name . to_string ( ) , func, None ) ;
106
+ self_. functions_map . insert ( name, ( func, None ) ) ;
142
107
}
143
- self_. insert_host_function ( hfd. clone ( ) ) ;
144
108
145
109
Ok ( ( ) )
146
110
}
147
111
148
112
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
149
113
fn call_host_func_impl (
150
- host_funcs : & FunctionsMap ,
114
+ host_funcs : & HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
151
115
name : & str ,
152
116
args : Vec < ParameterValue > ,
153
117
) -> Result < ReturnValue > {
154
118
// Inner function containing the common logic
155
119
fn call_func (
156
- host_funcs : & FunctionsMap ,
120
+ host_funcs : & HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
157
121
name : & str ,
158
122
args : Vec < ParameterValue > ,
159
123
) -> Result < ReturnValue > {
0 commit comments