Skip to content

Commit f0f00ce

Browse files
lancerts and msaroufim authored
In tutorials/quantize_vit, extract common methods to util.py (#238)
* Extract common methods to util.py
* Update tutorials/quantize_vit/run_vit_b.py
  Co-authored-by: Mark Saroufim <[email protected]>
* Update tutorials/quantize_vit/run_vit_b_quant.py
  Co-authored-by: Mark Saroufim <[email protected]>
* amend
* amend
* Include the torchao utils

---------

Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
1 parent adfe570 commit f0f00ce

File tree

3 files changed

+30
-48
lines changed

3 files changed

+30
-48
lines changed

torchao/utils.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
3+
4+
def benchmark_model(model, num_runs, input_tensor):
    """Return the average time of one ``model(input_tensor)`` call on CUDA.

    The loop of ``num_runs`` forward calls is bracketed by a pair of CUDA
    timing events, and the elapsed time between them is divided by
    ``num_runs``. Each call is also wrapped in a ``"timed region"`` profiler
    range so it can be located in a trace.

    Args:
        model: Callable invoked as ``model(input_tensor)``; its result is
            discarded.
        num_runs: Number of timed invocations. Must be positive.
        input_tensor: Argument forwarded unchanged to ``model``.

    Returns:
        Elapsed time between the two CUDA events divided by ``num_runs``
        (``torch.cuda.Event.elapsed_time`` reports milliseconds).

    Raises:
        ValueError: If ``num_runs`` is not positive.
    """
    # Validate before touching the GPU: zero would otherwise fail with a bare
    # ZeroDivisionError after the work is done, and a negative count would
    # silently yield a meaningless (negative-signed) latency.
    if num_runs <= 0:
        raise ValueError(f"num_runs must be a positive integer, got {num_runs}")

    # Drain previously queued work so it is not attributed to the timed region.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # benchmark
    for _ in range(num_runs):
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)

    end_event.record()
    # Wait until the end event has actually been reached before reading it.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs
18+
19+
def profiler_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` under the PyTorch profiler and save a trace.

    CPU and CUDA activities are requested with shape recording enabled. Once
    the profiling context has exited, the collected trace is exported in
    Chrome trace format to ``path``.

    Args:
        path: Destination passed to ``export_chrome_trace``.
        fn: Callable to profile.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returned.
    """
    traced_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    profiler = torch.profiler.profile(
        activities=traced_activities,
        record_shapes=True,
    )
    with profiler:
        output = fn(*args, **kwargs)
    # Export only after the context has closed, as the original does.
    profiler.export_chrome_trace(path)
    return output

tutorials/quantize_vit/run_vit_b.py

+2-24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import torchvision.models.vision_transformer as models
33

4+
from torchao.utils import benchmark_model, profiler_runner
5+
torch.set_float32_matmul_precision("high")
46
# Load Vision Transformer model
57
model = models.vit_b_16(pretrained=True)
68

@@ -12,30 +14,6 @@
1214

1315
model = torch.compile(model, mode='max-autotune')
1416

15-
def benchmark_model(model, num_runs, input_tensor):
16-
torch.cuda.synchronize()
17-
start_event = torch.cuda.Event(enable_timing=True)
18-
end_event = torch.cuda.Event(enable_timing=True)
19-
start_event.record()
20-
21-
# benchmark
22-
for _ in range(num_runs):
23-
with torch.autograd.profiler.record_function("timed region"):
24-
model(input_tensor)
25-
26-
end_event.record()
27-
torch.cuda.synchronize()
28-
return start_event.elapsed_time(end_event) / num_runs
29-
30-
def profiler_runner(path, fn, *args, **kwargs):
31-
with torch.profiler.profile(
32-
activities=[torch.profiler.ProfilerActivity.CPU,
33-
torch.profiler.ProfilerActivity.CUDA],
34-
record_shapes=True) as prof:
35-
result = fn(*args, **kwargs)
36-
prof.export_chrome_trace(path)
37-
return result
38-
3917
# Must run with no_grad when optimizing for inference
4018
with torch.no_grad():
4119
# warmup

tutorials/quantize_vit/run_vit_b_quant.py

+2-24
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import torchao
33
import torchvision.models.vision_transformer as models
44

5+
from torchao.utils import benchmark_model, profiler_runner
6+
torch.set_float32_matmul_precision("high")
57
# Load Vision Transformer model
68
model = models.vit_b_16(pretrained=True)
79

@@ -19,30 +21,6 @@
1921

2022
model = torch.compile(model, mode='max-autotune')
2123

22-
def benchmark_model(model, num_runs, input_tensor):
23-
torch.cuda.synchronize()
24-
start_event = torch.cuda.Event(enable_timing=True)
25-
end_event = torch.cuda.Event(enable_timing=True)
26-
start_event.record()
27-
28-
# benchmark
29-
for _ in range(num_runs):
30-
with torch.autograd.profiler.record_function("timed region"):
31-
model(input_tensor)
32-
33-
end_event.record()
34-
torch.cuda.synchronize()
35-
return start_event.elapsed_time(end_event) / num_runs
36-
37-
def profiler_runner(path, fn, *args, **kwargs):
38-
with torch.profiler.profile(
39-
activities=[torch.profiler.ProfilerActivity.CPU,
40-
torch.profiler.ProfilerActivity.CUDA],
41-
record_shapes=True) as prof:
42-
result = fn(*args, **kwargs)
43-
prof.export_chrome_trace(path)
44-
return result
45-
4624
# Must run with no_grad when optimizing for inference
4725
with torch.no_grad():
4826
# warmup

0 commit comments

Comments
 (0)