4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ # pyre-unsafe
8
+
7
9
import argparse
8
10
import inspect
9
11
import os
19
21
from executorch .exir .backend .test .backend_with_compiler_demo import (
20
22
BackendWithCompilerDemo ,
21
23
)
24
+ from executorch .exir .program import ExecutorchProgramManager
22
25
from torch import nn
23
26
from torch .export import export
24
27
@@ -111,10 +114,10 @@ def export_module_to_program(
111
114
* ,
112
115
backend_id : str ,
113
116
extract_delegate_segments : bool ,
114
- constant_tensor_alignemnt : Optional [int ] = None ,
117
+ constant_tensor_alignment : Optional [int ] = None ,
115
118
delegate_alignment : Optional [int ] = None ,
116
119
method : str = "forward" ,
117
- ) -> bytes :
120
+ ) -> ExecutorchProgramManager :
118
121
eager_module = module_class ().eval ()
119
122
inputs = ()
120
123
if hasattr (eager_module , "get_random_inputs" ):
@@ -135,7 +138,7 @@ def forward(self, *args, **kwargs):
135
138
edge_config = EdgeCompileConfig (_check_ir_validity = False )
136
139
et_config = exir .ExecutorchBackendConfig (
137
140
extract_delegate_segments = extract_delegate_segments ,
138
- constant_tensor_alignment = constant_tensor_alignemnt ,
141
+ constant_tensor_alignment = constant_tensor_alignment ,
139
142
delegate_alignment = delegate_alignment ,
140
143
)
141
144
@@ -170,7 +173,7 @@ def forward(self, *args, **kwargs):
170
173
export (composite_module , args = inputs , strict = True )
171
174
).to_executorch (config = et_config )
172
175
173
- return executorch_program . buffer
176
+ return executorch_program
174
177
175
178
176
179
def main () -> None :
@@ -199,6 +202,14 @@ def main() -> None:
199
202
help = "ID of the backend to use for delegation; "
200
203
+ f"one of { known_backend_ids } " ,
201
204
)
205
+ parser .add_argument (
206
+ "--inline_delegate_segments" ,
207
+ action = "store_true" ,
208
+ help = "Store delegate data inside the flatbuffer." ,
209
+ )
210
+ parser .add_argument (
211
+ "--delegate_alignment" , type = int , default = None , help = "Delegate alignment."
212
+ )
202
213
parser .add_argument (
203
214
"--outdir" ,
204
215
type = str ,
@@ -219,25 +230,22 @@ def main() -> None:
219
230
220
231
# Export and write to the output files.
221
232
os .makedirs (args .outdir , exist_ok = True )
233
+ suffix = ""
222
234
for module_name , module_class in module_names_to_classes .items ():
223
- for extract_delegate_segments in (True , False ):
224
- suffix = "" if extract_delegate_segments else "-nosegments"
225
- # Create files with the default alignment, and a large alignment.
226
- # This alignment should be so large that it's extremely unlikely for
227
- # the data to accidentally be aligned to it in the default case.
228
- for delegate_alignment in (None , 1024 ):
229
- suffix += f"-da{ delegate_alignment } " if delegate_alignment else ""
230
- outfile = os .path .join (args .outdir , f"{ module_name } { suffix } .pte" )
231
- with open (outfile , "wb" ) as fp :
232
- fp .write (
233
- export_module_to_program (
234
- module_class ,
235
- backend_id = args .backend_id ,
236
- extract_delegate_segments = extract_delegate_segments ,
237
- delegate_alignment = delegate_alignment ,
238
- )
239
- )
240
- print (f"Exported { module_name } and wrote program data to { outfile } " )
235
+ if args .inline_delegate_segments :
236
+ suffix += "-nosegments"
237
+ if args .delegate_alignment is not None :
238
+ suffix += f"-da{ args .delegate_alignment } "
239
+ outfile = os .path .join (args .outdir , f"{ module_name } { suffix } .pte" )
240
+ executorch_program = export_module_to_program (
241
+ module_class ,
242
+ backend_id = args .backend_id ,
243
+ extract_delegate_segments = not args .inline_delegate_segments ,
244
+ delegate_alignment = args .delegate_alignment ,
245
+ )
246
+ with open (outfile , "wb" ) as fp :
247
+ fp .write (executorch_program .buffer )
248
+ print (f"Exported { module_name } and wrote program data to { outfile } " )
241
249
242
250
243
251
if __name__ == "__main__" :
0 commit comments