18
18
19
19
from . import logging
20
20
from .errors import raise_unsupported_device_error
21
-
22
- SUPPORTED_DEVICE_TYPE = ["cpu" , "gpu" , "xpu" , "npu" , "mlu" ]
21
+ from .other_devices_model_list import OTHER_DEVICES_MODEL_LIST
23
22
24
23
25
24
def _constr_device (device_type , device_ids ):
@@ -38,6 +37,21 @@ def get_default_device():
38
37
return _constr_device ("gpu" , [avail_gpus [0 ]])
39
38
40
39
40
+ def check_device (model_name , device_type ):
41
+ supported_device_type = ["cpu" , "gpu" , "xpu" , "npu" , "mlu" , "dcu" ]
42
+ device_type = device_type .lower ()
43
+ if device_type not in supported_device_type :
44
+ support_run_mode_str = ", " .join (supported_device_type )
45
+ raise ValueError (
46
+ f"The device type must be one of { support_run_mode_str } , but received { repr (device_type )} ."
47
+ )
48
+ if device_type in OTHER_DEVICES_MODEL_LIST :
49
+ if model_name not in OTHER_DEVICES_MODEL_LIST [device_type ]:
50
+ raise ValueError (
51
+ f"The model '{ model_name } ' is not supported on { device_type } ."
52
+ )
53
+
54
+
41
55
def parse_device (device ):
42
56
"""parse_device"""
43
57
# According to https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/device/set_device_cn.html
@@ -55,14 +69,10 @@ def parse_device(device):
55
69
f"Device ID must be an integer. Invalid device ID: { device_id } "
56
70
)
57
71
device_ids = list (map (int , device_ids ))
58
- device_type = device_type .lower ()
59
- # raise_unsupported_device_error(device_type, SUPPORTED_DEVICE_TYPE)
60
- assert device_type .lower () in SUPPORTED_DEVICE_TYPE
61
72
return device_type , device_ids
62
73
63
74
64
- def update_device_num (device , num ):
65
- device_type , device_ids = parse_device (device )
75
+ def update_device_num (device_type , device_ids , num ):
66
76
if device_ids :
67
77
assert len (device_ids ) >= num
68
78
return _constr_device (device_type , device_ids [:num ])
0 commit comments