Skip to content

Commit cafe208

Browse files
committed
fix trt
1 parent 2565a6a commit cafe208

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

paddlex/inference/components/paddle_predictor/predictor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _create(self):
125125
max_batch_size=self.option.batch_size,
126126
min_subgraph_size=self.option.min_subgraph_size,
127127
precision_mode=precision_map[self.option.run_mode],
128-
trt_use_static=self.option.trt_use_static,
128+
use_static=self.option.trt_use_static,
129129
use_calib_mode=self.option.trt_calib_mode,
130130
)
131131

paddlex/inference/models/base/base_predictor.py

-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from abc import abstractmethod
1919

2020
from ...components.base import BaseComponent
21-
from ...utils.pp_option import PaddlePredictorOption
2221
from ...utils.process_hook import generatorable_method
2322

2423

paddlex/inference/models/base/basic_predictor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ def _add_component(self, cmps):
7272
def set_predictor(self, batch_size=None, device=None, pp_option=None):
7373
if batch_size:
7474
self.components["ReadCmp"].batch_size = batch_size
75+
76+
self.pp_option.batch_size = batch_size
7577
if device and device != self.pp_option.device:
7678
self.pp_option.device = device
77-
self.components["PPEngineCmp"].reset()
7879
if pp_option and pp_option != self.pp_option:
7980
self.pp_option = pp_option
80-
self.components["PPEngineCmp"].reset()
8181

8282
def _has_setter(self, attr):
8383
prop = getattr(self.__class__, attr, None)

paddlex/inference/utils/pp_option.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def __init__(self, model_name=None, **kwargs):
3434
super().__init__()
3535
self.model_name = model_name
3636
self._cfg = {}
37-
self._init_option(**kwargs)
3837
self._observers = []
38+
self._init_option(**kwargs)
3939

4040
def _init_option(self, **kwargs):
4141
for k, v in kwargs.items():
@@ -62,6 +62,7 @@ def _get_default_config(self):
6262
"trt_use_static": False,
6363
"delete_pass": [],
6464
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
65+
"batch_size": 1, # only for trt
6566
}
6667

6768
def _update(self, k, v):
@@ -179,6 +180,14 @@ def enable_new_ir(self, enable_new_ir: bool):
179180
"""set run mode"""
180181
self._update("enable_new_ir", enable_new_ir)
181182

183+
@property
184+
def batch_size(self):
185+
return self._cfg["batch_size"]
186+
187+
@batch_size.setter
188+
def batch_size(self, batch_size):
189+
self._update("batch_size", batch_size)
190+
182191
def get_support_run_mode(self):
183192
"""get supported run mode"""
184193
return self.SUPPORT_RUN_MODE

0 commit comments

Comments
 (0)