@@ -42,20 +42,24 @@ def __init__(self, engine_path: str, image_input_name: str|None=None):
4242
4343 self .context = self .engine .create_execution_context ()
4444
45- self .input_names = [self .engine .get_binding_name (i ) for i in range (self .engine .num_bindings ) if self .engine .binding_is_input (i )]
46- self .output_names = [self .engine .get_binding_name (i ) for i in range (self .engine .num_bindings ) if not self .engine .binding_is_input (i )]
47- self .output_shapes = [list (self .engine .get_binding_shape (name )) for name in self .output_names ]
45+ names = [self .engine .get_tensor_name (i ) for i in range (self .engine .num_io_tensors )]
46+
47+ self .input_names = [name for name in names if self .engine .get_tensor_mode (name ) == trt .TensorIOMode .INPUT ]
48+ self .output_names = [name for name in names if self .engine .get_tensor_mode (name ) == trt .TensorIOMode .OUTPUT ]
49+ self .output_shapes = [list (self .engine .get_tensor_shape (name )) for name in self .output_names ]
4850
4951 self .initialize_bindings ()
5052
5153 # Cache binding indices
5254 self .input_binding_idxs = {
53- name : self .engine .get_binding_index (name )
54- for name in self .input_names
55+ name : i
56+ for i , name in enumerate (names )
57+ if self .engine .get_tensor_mode (name ) == trt .TensorIOMode .INPUT
5558 }
5659 self .output_binding_idxs = {
57- name : self .engine .get_binding_index (name )
58- for name in self .output_names
60+ name : i
61+ for i , name in enumerate (names )
62+ if self .engine .get_tensor_mode (name ) == trt .TensorIOMode .OUTPUT
5963 }
6064
6165 if len (self .input_names ) != 1 and image_input_name is None :
@@ -64,7 +68,7 @@ def __init__(self, engine_path: str, image_input_name: str|None=None):
6468 assert image_input_name in self .input_names , f"Image input name { image_input_name } not found in model inputs"
6569
6670 self .image_input_name = image_input_name if image_input_name is not None else self .input_names [0 ]
67- self .image_input_shape = self .engine .get_binding_shape (self .image_input_name )
71+ self .image_input_shape = self .engine .get_tensor_shape (self .image_input_name )
6872
6973 self .profiler = CUDAProfiler ()
7074
@@ -78,7 +82,7 @@ def construct_bindings(self, input_image: torch.Tensor) -> list[int]:
7882
7983 input_image = input_image .contiguous ()
8084
81- bindings = [None ] * self .engine .num_bindings
85+ bindings = [None ] * self .engine .num_io_tensors
8286
8387 for name , binding_idx in self .output_binding_idxs .items ():
8488 bindings [binding_idx ] = self .binding_ptrs [name ]
@@ -90,11 +94,11 @@ def construct_bindings(self, input_image: torch.Tensor) -> list[int]:
9094 def initialize_bindings (self ):
9195 self .bindings = {}
9296 self .binding_ptrs = {}
93- for i in range (self .engine .num_bindings ):
94- name = self .engine .get_binding_name (i )
95- if not self .engine .binding_is_input ( i ) :
96- shape = self .engine .get_binding_shape (name )
97- dtype = trt .nptype (self .engine .get_binding_dtype (name ))
97+ for i in range (self .engine .num_io_tensors ):
98+ name = self .engine .get_tensor_name (i )
99+ if not self .engine .get_tensor_mode ( name ) == trt . TensorIOMode . INPUT :
100+ shape = self .engine .get_tensor_shape (name )
101+ dtype = trt .nptype (self .engine .get_tensor_dtype (name ))
98102 self .bindings [name ] = torch .from_numpy (np .empty (shape , dtype = dtype )).cuda ()
99103 self .binding_ptrs [name ] = self .bindings [name ].data_ptr ()
100104
0 commit comments