Skip to content

Commit ca9d659

Browse files
committed
add other devices supported model list
1 parent cafe208 commit ca9d659

File tree

8 files changed

+293
-23
lines changed

8 files changed

+293
-23
lines changed

paddlex/inference/utils/new_ir_blacklist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
NEWIR_BLOCKLIST = [
15+
NEWIR_BLACKLIST = [
1616
"FasterRCNN-ResNet34-FPN",
1717
"FasterRCNN-ResNet50",
1818
"FasterRCNN-ResNet50-FPN",

paddlex/inference/utils/pp_option.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from ...utils.device import parse_device, set_env_for_device, get_default_device
15+
from ...utils.device import (
16+
parse_device,
17+
set_env_for_device,
18+
get_default_device,
19+
check_device,
20+
)
1621
from ...utils import logging
17-
from .new_ir_blacklist import NEWIR_BLOCKLIST
22+
from .new_ir_blacklist import NEWIR_BLACKLIST
1823

1924

2025
class PaddlePredictorOption(object):
@@ -28,7 +33,6 @@ class PaddlePredictorOption(object):
2833
"mkldnn",
2934
"mkldnn_bf16",
3035
)
31-
SUPPORT_DEVICE = ("gpu", "cpu", "npu", "xpu", "mlu", "dcu")
3236

3337
def __init__(self, model_name=None, **kwargs):
3438
super().__init__()
@@ -61,7 +65,7 @@ def _get_default_config(self):
6165
"cpu_threads": 1,
6266
"trt_use_static": False,
6367
"delete_pass": [],
64-
"enable_new_ir": True if self.model_name not in NEWIR_BLOCKLIST else False,
68+
"enable_new_ir": True if self.model_name not in NEWIR_BLACKLIST else False,
6569
"batch_size": 1, # only for trt
6670
}
6771

@@ -101,11 +105,7 @@ def device(self, device: str):
101105
if not device:
102106
return
103107
device_type, device_ids = parse_device(device)
104-
if device_type not in self.SUPPORT_DEVICE:
105-
support_run_mode_str = ", ".join(self.SUPPORT_DEVICE)
106-
raise ValueError(
107-
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
108-
)
108+
check_device(device_type)
109109
self._update("device", device_type)
110110
device_id = device_ids[0] if device_ids is not None else 0
111111
self._update("device_id", device_id)

paddlex/modules/base/evaluator.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from abc import ABC, abstractmethod
1818

1919
from .build_model import build_model
20-
from ...utils.device import update_device_num, set_env_for_device
20+
from ...utils.device import (
21+
update_device_num,
22+
set_env_for_device,
23+
parse_device,
24+
check_device,
25+
)
2126
from ...utils.misc import AutoRegisterABCMetaClass
2227
from ...utils.config import AttrDict
2328
from ...utils.logging import *
@@ -138,8 +143,10 @@ def get_device(self, using_device_number: int = None) -> str:
138143
Returns:
139144
str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
140145
"""
146+
device_type, device_ids = parse_device(self.global_config.device)
147+
check_device(self.global_config.model, device_type)
141148
if using_device_number:
142-
return update_device_num(self.global_config.device, using_device_number)
149+
return update_device_num(device_type, device_ids, using_device_number)
143150
set_env_for_device(self.global_config.device)
144151
return self.global_config.device
145152

paddlex/modules/base/exportor.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from abc import ABC, abstractmethod
1818

1919
from .build_model import build_model
20-
from ...utils.device import update_device_num, set_env_for_device
20+
from ...utils.device import (
21+
update_device_num,
22+
set_env_for_device,
23+
parse_device,
24+
check_device,
25+
)
2126
from ...utils.misc import AutoRegisterABCMetaClass
2227
from ...utils.config import AttrDict
2328
from ...utils import logging
@@ -103,8 +108,10 @@ def get_device(self, using_device_number: int = None) -> str:
103108
Returns:
104109
str: device setting, such as: `gpu:0,1`, `npu:0,1`, `cpu`.
105110
"""
111+
device_type, device_ids = parse_device(self.global_config.device)
112+
check_device(self.global_config.model, device_type)
106113
if using_device_number:
107-
return update_device_num(self.global_config.device, using_device_number)
114+
return update_device_num(device_type, device_ids, using_device_number)
108115
set_env_for_device(self.global_config.device)
109116
return self.global_config.device
110117

paddlex/modules/base/trainer.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
from abc import ABC, abstractmethod
1717
from pathlib import Path
1818
from .build_model import build_model
19-
from ...utils.device import update_device_num, set_env_for_device
19+
from ...utils.device import (
20+
update_device_num,
21+
set_env_for_device,
22+
parse_device,
23+
check_device,
24+
)
2025
from ...utils.misc import AutoRegisterABCMetaClass
2126
from ...utils.config import AttrDict
2227

@@ -95,8 +100,10 @@ def get_device(self, using_device_number: int = None) -> str:
95100
Returns:
96101
str: device setting, such as: `gpu:0,1`, `npu:0,1` `cpu`.
97102
"""
103+
device_type, device_ids = parse_device(self.global_config.device)
104+
check_device(self.global_config.model, device_type)
98105
if using_device_number:
99-
return update_device_num(self.global_config.device, using_device_number)
106+
return update_device_num(device_type, device_ids, using_device_number)
100107
set_env_for_device(self.global_config.device)
101108
return self.global_config.device
102109

paddlex/paddlex_cli.py

+15
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
import argparse
1717
import subprocess
1818
import sys
19+
import shutil
1920
import tempfile
21+
from pathlib import Path
2022

2123
from . import create_pipeline
2224
from .inference.pipelines import create_pipeline_from_config, load_pipeline_config
2325
from .repo_manager import setup, get_all_supported_repo_names
26+
from .utils.cache import CACHE_DIR
2427
from .utils import logging
2528
from .utils.interactive_get_pipeline import interactive_get_pipeline
2629

@@ -65,6 +68,7 @@ def parse_str(s):
6568

6669
################# install pdx #################
6770
parser.add_argument("--install", action="store_true", default=False, help="")
71+
parser.add_argument("--clear_cache", action="store_true", default=False, help="")
6872
parser.add_argument("plugins", nargs="*", default=[])
6973
parser.add_argument("--no_deps", action="store_true")
7074
parser.add_argument("--platform", type=str, default="github.com")
@@ -159,6 +163,15 @@ def serve(pipeline, *, device, use_hpip, serial_number, update_license, host, po
159163
run_server(app, host=host, port=port, debug=False)
160164

161165

166+
def clear_cache():
167+
cache_dir = Path(CACHE_DIR) / "official_models"
168+
if cache_dir.exists() and cache_dir.is_dir():
169+
shutil.rmtree(cache_dir)
170+
logging.info(f"Successfully cleared the cache models at {cache_dir}")
171+
else:
172+
logging.info(f"No cache models found at {cache_dir}")
173+
174+
162175
# for CLI
163176
def main():
164177
"""API for commad line"""
@@ -180,6 +193,8 @@ def main():
180193
host=args.host,
181194
port=args.port,
182195
)
196+
elif args.clear_cache:
197+
clear_cache()
183198
else:
184199
if args.get_pipeline_config is not None:
185200
interactive_get_pipeline(args.get_pipeline_config, args.save_path)

paddlex/utils/device.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919
from . import logging
2020
from .errors import raise_unsupported_device_error
21-
22-
SUPPORTED_DEVICE_TYPE = ["cpu", "gpu", "xpu", "npu", "mlu"]
21+
from .other_devices_model_list import OTHER_DEVICES_MODEL_LIST
2322

2423

2524
def _constr_device(device_type, device_ids):
@@ -38,6 +37,21 @@ def get_default_device():
3837
return _constr_device("gpu", [avail_gpus[0]])
3938

4039

40+
def check_device(model_name, device_type):
41+
supported_device_type = ["cpu", "gpu", "xpu", "npu", "mlu", "dcu"]
42+
device_type = device_type.lower()
43+
if device_type not in supported_device_type:
44+
support_run_mode_str = ", ".join(supported_device_type)
45+
raise ValueError(
46+
f"The device type must be one of {support_run_mode_str}, but received {repr(device_type)}."
47+
)
48+
if device_type in OTHER_DEVICES_MODEL_LIST:
49+
if model_name not in OTHER_DEVICES_MODEL_LIST[device_type]:
50+
raise ValueError(
51+
f"The model '{model_name}' is not supported on {device_type}."
52+
)
53+
54+
4155
def parse_device(device):
4256
"""parse_device"""
4357
# According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
@@ -55,14 +69,10 @@ def parse_device(device):
5569
f"Device ID must be an integer. Invalid device ID: {device_id}"
5670
)
5771
device_ids = list(map(int, device_ids))
58-
device_type = device_type.lower()
59-
# raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
60-
assert device_type.lower() in SUPPORTED_DEVICE_TYPE
6172
return device_type, device_ids
6273

6374

64-
def update_device_num(device, num):
65-
device_type, device_ids = parse_device(device)
75+
def update_device_num(device_type, device_ids, num):
6676
if device_ids:
6777
assert len(device_ids) >= num
6878
return _constr_device(device_type, device_ids[:num])

0 commit comments

Comments
 (0)