-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
425 lines (337 loc) · 15.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import shutil
import time
from enum import Enum
import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import config
from dataset import CUDAPrefetcher
from dataset import TrainValidImageDataset, TestImageDataset
from image_quality_assessment import PSNR, SSIM
from model import DBPN
def main():
    """Train DBPN, evaluating on the valid and test sets after every epoch.

    All hyper-parameters come from the ``config`` module. When ``config.resume``
    is set, model/optimizer/scheduler state, the epoch counter and the best
    PSNR/SSIM are restored from that checkpoint. One checkpoint per epoch is
    written to ``samples/<exp_name>``; it is also copied to
    ``results/<exp_name>/best.pth.tar`` when both PSNR and SSIM on the test set
    improve, and to ``results/<exp_name>/last.pth.tar`` after the final epoch.
    """
    # Epoch to start from (overwritten when resuming)
    start_epoch = 0
    # Best test-set metrics seen so far (overwritten when resuming)
    best_psnr = 0.0
    best_ssim = 0.0

    train_prefetcher, valid_prefetcher, test_prefetcher = load_dataset()
    print("Load train dataset and valid dataset successfully.")

    model = build_model()
    print("Build DBPN model successfully.")

    pixel_criterion = define_loss()
    print("Define all loss functions successfully.")

    optimizer = define_optimizer(model)
    print("Define all optimizer functions successfully.")

    scheduler = define_scheduler(optimizer)
    print("Define all optimizer scheduler successfully.")

    print("Check whether the pretrained model is restored...")
    if config.resume:
        # Load the checkpoint onto the CPU storage it was saved from (no device remap)
        checkpoint = torch.load(config.resume, map_location=lambda storage, loc: storage)
        # Restore the training progress counters
        start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        best_ssim = checkpoint["best_ssim"]
        # Only copy weights whose names exist in the current model; others are ignored
        model_state_dict = model.state_dict()
        new_state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict}
        model_state_dict.update(new_state_dict)
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        print("Loaded pretrained model weights.")
    else:
        # NOTE(review): this mutates the config module; the loop below uses the local
        # ``start_epoch``. Kept as-is in case other modules read ``config.start_epoch``.
        config.start_epoch = 0

    # Create the folders for per-epoch checkpoints and for best/last results
    samples_dir = os.path.join("samples", config.exp_name)
    results_dir = os.path.join("results", config.exp_name)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    # Create training process log file
    writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name))

    # Initialize the gradient scaler for mixed precision training
    scaler = amp.GradScaler()

    # Create the IQA evaluation models and move them to the training device
    psnr_model = PSNR(config.upscale_factor, config.only_test_y_channel)
    ssim_model = SSIM(config.upscale_factor, config.only_test_y_channel)
    psnr_model = psnr_model.to(device=config.device, memory_format=torch.channels_last, non_blocking=True)
    ssim_model = ssim_model.to(device=config.device, memory_format=torch.channels_last, non_blocking=True)

    try:
        for epoch in range(start_epoch, config.epochs):
            train(model, train_prefetcher, pixel_criterion, optimizer, epoch, scaler, writer)
            # Validation metrics are only logged; model selection uses the test set
            validate(model, valid_prefetcher, epoch, writer, psnr_model, ssim_model, "Valid")
            psnr, ssim = validate(model, test_prefetcher, epoch, writer, psnr_model, ssim_model, "Test")
            print("\n")

            # Update lr
            scheduler.step()

            # A checkpoint counts as "best" only if PSNR and SSIM both improve
            is_best = psnr > best_psnr and ssim > best_ssim
            best_psnr = max(psnr, best_psnr)
            best_ssim = max(ssim, best_ssim)

            checkpoint_path = os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar")
            torch.save({"epoch": epoch + 1,
                        "best_psnr": best_psnr,
                        "best_ssim": best_ssim,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict()},
                       checkpoint_path)
            if is_best:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "best.pth.tar"))
            if (epoch + 1) == config.epochs:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "last.pth.tar"))
    finally:
        # Flush pending TensorBoard events and release the event file, even on error
        writer.close()
def load_dataset() -> "tuple[CUDAPrefetcher, CUDAPrefetcher, CUDAPrefetcher]":
    """Build the train/valid/test data pipelines.

    The train and valid sets are ``TrainValidImageDataset`` instances read from
    ``config.train_image_dir`` / ``config.valid_image_dir``; the test set is a
    ``TestImageDataset`` over paired LR/HR directories. Each DataLoader is wrapped
    in a ``CUDAPrefetcher`` targeting ``config.device``.

    Returns:
        tuple: ``(train_prefetcher, valid_prefetcher, test_prefetcher)``.
    """
    # Load train, test and valid datasets
    train_datasets = TrainValidImageDataset(config.train_image_dir, config.image_size, config.upscale_factor, "Train")
    valid_datasets = TrainValidImageDataset(config.valid_image_dir, config.image_size, config.upscale_factor, "Valid")
    test_datasets = TestImageDataset(config.test_lr_image_dir, config.test_hr_image_dir, config.upscale_factor)

    # Generator all dataloader
    train_dataloader = DataLoader(train_datasets,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers,
                                  pin_memory=True,
                                  drop_last=True,
                                  # DataLoader raises ValueError for persistent_workers=True
                                  # with num_workers == 0, so only enable it when workers exist.
                                  persistent_workers=config.num_workers > 0)
    valid_dataloader = DataLoader(valid_datasets,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=1,
                                  pin_memory=True,
                                  drop_last=False,
                                  persistent_workers=True)
    test_dataloader = DataLoader(test_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=True)

    # Place all data on the preprocessing data loader
    train_prefetcher = CUDAPrefetcher(train_dataloader, config.device)
    valid_prefetcher = CUDAPrefetcher(valid_dataloader, config.device)
    test_prefetcher = CUDAPrefetcher(test_dataloader, config.device)

    return train_prefetcher, valid_prefetcher, test_prefetcher
def build_model() -> nn.Module:
    """Instantiate DBPN for ``config.upscale_factor`` on ``config.device`` (channels-last layout)."""
    network = DBPN(upscale_factor=config.upscale_factor)
    return network.to(device=config.device, memory_format=torch.channels_last)
def define_loss() -> nn.L1Loss:
    """Create the pixel-wise L1 loss used for training.

    Returns:
        nn.L1Loss: the criterion, moved to ``config.device``.
    """
    pixel_criterion = nn.L1Loss()
    pixel_criterion = pixel_criterion.to(device=config.device, memory_format=torch.channels_last)
    return pixel_criterion
def define_optimizer(model) -> optim.Adam:
    """Return an Adam optimizer over ``model``'s parameters using ``config.model_lr`` / ``config.model_betas``."""
    return optim.Adam(model.parameters(),
                      lr=config.model_lr,
                      betas=config.model_betas)
def define_scheduler(optimizer) -> lr_scheduler.StepLR:
    """Return a StepLR schedule: multiply the LR by ``config.lr_scheduler_gamma``
    every ``config.lr_scheduler_step_size`` scheduler steps."""
    step_size = config.lr_scheduler_step_size
    gamma = config.lr_scheduler_gamma
    return lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
def train(model: nn.Module,
          train_prefetcher: CUDAPrefetcher,
          pixel_criterion: nn.L1Loss,
          optimizer: optim.Adam,
          epoch: int,
          scaler: amp.GradScaler,
          writer: SummaryWriter) -> None:
    """Run one epoch of mixed-precision training.

    Each batch dict supplies ``"lr"``, ``"bic"`` and ``"hr"`` tensors. The model's
    output on ``"lr"`` is added to ``"bic"`` and compared against ``"hr"`` with
    ``pixel_criterion``. Every ``config.print_frequency`` batches the current loss
    is logged to TensorBoard under ``"Train/Loss"`` and a progress line is printed.

    Args:
        model (nn.Module): network being trained.
        train_prefetcher (CUDAPrefetcher): iterator over training batches; ``next()``
            returns ``None`` when the epoch is exhausted.
        pixel_criterion (nn.L1Loss): pixel loss between prediction and target.
        optimizer (optim.Adam): optimizer stepped through ``scaler``.
        epoch (int): zero-based epoch index, used for display and the global log step.
        scaler (amp.GradScaler): gradient scaler for mixed precision.
        writer (SummaryWriter): TensorBoard writer.
    """
    num_batches = len(train_prefetcher)
    # Global TensorBoard step = batch index + this offset
    step_offset = epoch * num_batches + 1

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":6.6f")
    progress = ProgressMeter(num_batches, [batch_time, data_time, losses], prefix=f"Epoch: [{epoch + 1}]")

    model.train()

    step = 0
    train_prefetcher.reset()
    batch = train_prefetcher.next()
    tick = time.time()

    while batch is not None:
        # Time elapsed since the previous step finished (logging + fetching this batch)
        data_time.update(time.time() - tick)

        lr_images = batch["lr"].to(config.device, non_blocking=True)
        bic_images = batch["bic"].to(config.device, non_blocking=True)
        hr_images = batch["hr"].to(config.device, non_blocking=True)

        model.zero_grad()
        with amp.autocast():
            # The model output is added to the "bic" input to form the prediction
            sr_images = model(lr_images) + bic_images
            loss = pixel_criterion(sr_images, hr_images)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        losses.update(loss_value, lr_images.size(0))

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % config.print_frequency == 0:
            writer.add_scalar("Train/Loss", loss_value, step + step_offset)
            progress.display(step)

        batch = train_prefetcher.next()
        step += 1
def validate(model: nn.Module,
             data_prefetcher: CUDAPrefetcher,
             epoch: int,
             writer: SummaryWriter,
             psnr_model: nn.Module,
             ssim_model: nn.Module,
             mode: str) -> "tuple[float, float]":
    """Evaluate ``model`` on one dataset and log average PSNR/SSIM.

    Args:
        model (nn.Module): network to evaluate; switched to eval mode.
        data_prefetcher (CUDAPrefetcher): iterator over evaluation batches; ``next()``
            returns ``None`` when exhausted.
        epoch (int): zero-based epoch index; scalars are logged at step ``epoch + 1``.
        writer (SummaryWriter): TensorBoard writer.
        psnr_model (nn.Module): module computing PSNR between prediction and target.
        ssim_model (nn.Module): module computing SSIM between prediction and target.
        mode (str): ``"Valid"`` or ``"Test"``; used as the log tag prefix.

    Returns:
        tuple: ``(average PSNR, average SSIM)``, weighted by batch size.

    Raises:
        ValueError: if ``mode`` is not ``"Valid"`` or ``"Test"``.
    """
    # Reject an unsupported mode up front instead of after a full evaluation pass
    if mode != "Valid" and mode != "Test":
        raise ValueError("Unsupported mode, please use `Valid` or `Test`.")

    # Calculate how many batches of data are in each Epoch
    batches = len(data_prefetcher)
    batch_time = AverageMeter("Time", ":6.3f")
    psnres = AverageMeter("PSNR", ":4.2f")
    ssimes = AverageMeter("SSIM", ":4.4f")
    progress = ProgressMeter(batches, [batch_time, psnres, ssimes], prefix=f"{mode}: ")
    # Print about five progress lines per pass; clamp to 1 so a dataset with fewer
    # than five batches does not cause a modulo-by-zero below.
    display_interval = max(batches // 5, 1)

    # Put the model in evaluation mode
    model.eval()

    batch_index = 0
    data_prefetcher.reset()
    batch_data = data_prefetcher.next()
    end = time.time()

    with torch.no_grad():
        while batch_data is not None:
            lr = batch_data["lr"].to(config.device, non_blocking=True)
            bic = batch_data["bic"].to(config.device, non_blocking=True)
            hr = batch_data["hr"].to(config.device, non_blocking=True)

            # Mixed precision inference; the model output is added to the "bic" input
            with amp.autocast():
                sr = model(lr) + bic

            psnr = psnr_model(sr, hr)
            ssim = ssim_model(sr, hr)
            psnres.update(psnr.item(), lr.size(0))
            ssimes.update(ssim.item(), lr.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % display_interval == 0:
                progress.display(batch_index)

            batch_data = data_prefetcher.next()
            batch_index += 1

    # print metrics
    progress.display_summary()

    writer.add_scalar(f"{mode}/PSNR", psnres.avg, epoch + 1)
    writer.add_scalar(f"{mode}/SSIM", ssimes.avg, epoch + 1)

    return psnres.avg, ssimes.avg
class Summary(Enum):
    """Selects which aggregate ``AverageMeter.summary()`` reports."""

    NONE = 0  # summary line is empty
    AVERAGE = 1  # report the running average
    SUM = 2  # report the accumulated sum
    COUNT = 3  # report the accumulated count
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric.

    Args:
        name (str): label printed before the value.
        fmt (str): format spec including the leading ``:`` (e.g. ``":4.2f"``), used
            for ``val``/``avg`` in ``__str__`` and for the summary line.
        summary_type (Summary): which aggregate ``summary()`` reports.
    """

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times (``n`` is the weight, e.g. batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the zero-weight case (e.g. update(x, n=0) on a fresh meter)
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the one-line summary selected by ``summary_type``.

        AVERAGE and SUM use the meter's own ``fmt`` so the summary keeps the same
        precision as the per-batch display (e.g. four decimals for SSIM).

        Raises:
            ValueError: if ``summary_type`` is not a known ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg" + self.fmt + "}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum" + self.fmt + "}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.2f}"
        else:
            raise ValueError(f"Invalid summary type {self.summary_type}")

        return fmtstr.format(**self.__dict__)
class ProgressMeter:
    """Print progress lines built from a batch counter and a list of meters.

    Args:
        num_batches (int): total number of batches; sets the counter's width.
        meters (list): objects rendered with ``str()`` by ``display`` and through
            their ``summary()`` method by ``display_summary``.
        prefix (str): text printed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the tab-separated progress line for ``batch``."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([head, *map(str, self.meters)]))

    def display_summary(self):
        """Print a space-separated line of every meter's summary, led by ``" *"``."""
        parts = [" *"]
        parts.extend(meter.summary() for meter in self.meters)
        print(" ".join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch index to the width of the total, e.g. "[  5/100]"
        width = len(str(num_batches // 1))
        counter = "{:%dd}" % width
        return f"[{counter}/{counter.format(num_batches)}]"
# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()