forked from pytorch/torchtitan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheckpoint.py
646 lines (539 loc) · 24.6 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import enum
import functools
import os
import queue
import re
import shutil
import threading
import time
from typing import Any
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
set_model_state_dict,
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.components.ft import FTManager
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.utils import GarbageCollection
MODEL = "model"
OPTIMIZER = "optimizer"
LR_SCHEDULER = "lr_scheduler"
DATALOADER = "dataloader"
TRAIN_STATE = "train_state"
class AsyncMode(str, enum.Enum):
DISABLED = "disabled"
ASYNC = "async"
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"
class ModelWrapper(Stateful):
def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
}
def state_dict(self) -> dict[str, Any]:
return self.cache_state_dict
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
func = functools.partial(
set_model_state_dict,
model_state_dict=state_dict,
options=StateDictOptions(strict=False),
)
list(map(func, self.model))
# `set_model_state_dict()` does change the keys of the input state_dict,
# we will need to reinitialize the cache_state_dict.
self.cache_state_dict = {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
}
class Terminate:
pass
class SaveDone:
pass
@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection invoked by checkpointer.")
def checkpoint_mp(recv: mp.Queue, send: mp.Queue):
"""Process to save the checkpoint in the background.
This is only used when async_checkpoint_with_pinned_memory is enabled.
Args:
recv (mp.Queue): The queue to receive the state_dict and Terminate signal.
send (mp.Queue): The queue to send the SaveDone signal.
"""
init_logger()
os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group()
try:
while True:
logger.debug("Checkpoint background process is done.")
send.put(SaveDone())
logger.debug("Wait for the new state_dict.")
obj = recv.get()
logger.debug("Received the new state_dict.")
if isinstance(obj, Terminate):
logger.info("Terminating the checkpoint background process.")
return
assert isinstance(obj, tuple)
begin = time.monotonic()
state, checkpoint_id = obj
save_with_gc(state, checkpoint_id=checkpoint_id)
logger.info(
"Finish saving the checkpoint in the background process in %.2f seconds.",
time.monotonic() - begin,
)
finally:
logger.info("Destroying the process group.")
dist.destroy_process_group()
def purge_thread(purge_queue: queue.Queue):
"""Thread to purge the old checkpoints.
This is only used when keep_latest_k > 0.
Args:
purge_queue (queue.Queue): The queue to receive the path to purge and Terminate signal.
"""
try:
while True:
path = purge_queue.get()
if isinstance(path, Terminate):
return
assert isinstance(path, str)
logger.info("Checkpointer is deleting %s.", path)
begin = time.monotonic()
shutil.rmtree(path, ignore_errors=True)
logger.info(
"Checkpointer deleted %s in %.2f seconds.",
path,
time.monotonic() - begin,
)
finally:
logger.info("Destroying the purge thread.")
class CheckpointManager:
"""This class manages the checkpointing logic for the TorchTitan trainer.
Note: Pipeline Parallelism and Virtual Stages
1. even for simple PP schedules, there is a separate optimizer each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the original
model. rank1's would _also_ have a param_group[0], since it's index based, but
referring to layers.1. When saving, these collide and one of them is lost. Then when
reloading, only one stage can restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071 and is
enabled in TorchTitan by passing the 'flatten_optimizer_state_dict' kwarg to DCP
functions called in the OptimizerContainer.
See PR #127071 (https://github.com/pytorch/pytorch/pull/127071) for the example of
a flattening state_dict.
2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds
challenge (1) by also requiring us to reason about multiple 'optim' objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the state dicts
from each object into one state dict before saving/loading. We rely on the individual
state_dicts to not collide, which is gauranteed for the model by correct pipeline
splitting and for the optimizer by the flattening support described in (1).
3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
with the assumption that all lr_schedulers have the same state_dict.
Note: TorchFT checkpointing flow
There are two types of checkpoints: when TorchFT is enabled: 1) the full perisistent
checkpoint, 2) the per-replica checkpoint.
The full perisistent checkpoint is saved by the replica with
``ft_manager.participating_rank() == 0``. It contains everything including the model,
optimizer, lr_scheduler, dataloader, and train_state. Right now the full perisistent
checkpoint is loaded by all replicas. However, we can optimize it to only load if
there are no other alive replicas.
The per-replica checkpoint contains only the dataloader and is saved/loaded by all
replicas to/from the its own folder. The folder name is prefixed with the ft_replica_id.
Args:
dataloader (DataLoader): The dataloader used to load the data.
model_parts (List[nn.Module]): List of model parts to be optimized.
optimizers (OptimizersContainer): The optimizers used to optimize the model.
lr_schedulers (LRSchedulersContainer): The lr schedulers used to optimize the model.
states (Dict[str, Any]): The states that need to be saved, other than the
previous 4 components.
job_config (JobConfig): The job config used to configure the checkpointing.
ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
"""
def __init__(
self,
dataloader: DataLoader,
model_parts: list[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: LRSchedulersContainer,
states: dict[str, Any],
job_config: JobConfig,
ft_manager: FTManager,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.ft_manager = ft_manager.manager if ft_manager.enabled else None
if self.ft_manager:
optimizers.init_cache_state_dict()
def state_dict():
ret = {}
for k, v in self.states.items():
if k in {
MODEL,
OPTIMIZER,
LR_SCHEDULER,
TRAIN_STATE,
}:
ret[k] = v.state_dict()
return ret
def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
self.states[k].load_state_dict(v)
self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
self.ft_replica_id = job_config.fault_tolerance.replica_id
async_mode = ckpt_config.async_mode.lower()
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager
if not self.enable_checkpoint and self.ft_manager is None:
return
self.states = states
self.states.update(
{
MODEL: ModelWrapper(model_parts),
OPTIMIZER: optimizers,
DATALOADER: dataloader,
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}
self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval = ckpt_config.interval
async_mode = ckpt_config.async_mode.lower()
if async_mode == AsyncMode.ASYNC or self.ft_manager:
self.pg = dist.new_group(backend="gloo")
self.keep_latest_k = ckpt_config.keep_latest_k
if self.keep_latest_k > 0:
if self.keep_latest_k == 1:
raise ValueError(
"We need to maintain at least 2 checkpoint replicas, "
"as the last one may be in the process of being saved."
)
self.purge_queue = queue.Queue()
self.purge_thread = threading.Thread(
target=purge_thread, args=(self.purge_queue,), daemon=True
)
self.purge_thread.start()
else:
self.purge_thread = None
self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading
self.mp = None
if async_mode == AsyncMode.DISABLED:
self.async_mode = AsyncMode.DISABLED
elif async_mode == AsyncMode.ASYNC:
self.async_mode = AsyncMode.ASYNC
self.async_future = None
elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
ctx = mp.get_context("spawn")
self.mp_queue_send = ctx.Queue()
self.mp_queue_recv = ctx.Queue()
self.mp = ctx.Process(
target=checkpoint_mp,
args=(
self.mp_queue_send,
self.mp_queue_recv,
),
daemon=True,
)
self.mp.start()
else:
raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
logger.info(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
)
def __del__(self):
self.close()
def close(self):
if self.enable_checkpoint:
if self.mp and self.mp.is_alive():
self.mp_queue_send.put(Terminate())
self.mp.join()
if self.purge_thread and self.purge_thread.is_alive():
self.purge_queue.put(Terminate())
self.purge_thread.join()
@torch.no_grad()
def save(self, curr_step: int, force: bool = False) -> None:
"""Save the checkpoint for the current step.
This function will save the checkpoint for the current step. If ``force`` is
true, it will save the checkpoint even if the interval has not been reached.
This only happens when train_state.step == job_config.training.steps, or
for initial seed checkpoint.
Args:
curr_step (int): The current step.
force (bool, optional): Whether to force save the checkpoint. Defaults to False.
Returns:
None
"""
if self.ft_manager:
self._ft_save(curr_step)
if not self._should_save(curr_step, force):
return
begin = time.monotonic()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
# This GC is called for async checkpoint as it is useless to do
# GC right after async_save -- the CPU memory is not able to be
# freed until _async_wait()
if force:
self._save_last_step(curr_step)
elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self._async_with_pinned_memory(checkpoint_id)
elif self.async_mode == AsyncMode.ASYNC:
GarbageCollection.collect("GC collection invoked by checkpointer.")
self.async_future = dcp.async_save(
self.states, checkpoint_id=checkpoint_id, process_group=self.pg
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
save_with_gc(self.states, checkpoint_id=checkpoint_id)
self._purge_stale_checkpoints()
logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
)
@torch.no_grad()
def load(self, step: int = -1) -> bool:
"""Load the checkpoint for the given step.
This function will load the checkpoint for the given step. If ``step`` is -1, it
will load the latest checkpoint. If the checkpoint does not exist, it will return
False and load nothing.
Args:
step (int, optional): The step to load the checkpoint for. Defaults to -1.
Returns:
bool: Whether the checkpoint was loaded successfully.
"""
if self.ft_manager:
self._ft_load()
if not self.enable_checkpoint or not os.path.isdir(self.folder):
return False
if step == -1:
step = self._find_load_step()
if step == -1:
return False
checkpoint_id = self._create_checkpoint_id(step)
if not os.path.isdir(checkpoint_id):
return False
logger.info(f"Loading the checkpoint at step {step}.")
begin = time.monotonic()
states = self._states_to_load(step)
dcp.load(states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
)
return True
def maybe_wait_for_staging(self) -> None:
"""Wait for the staging to finish if it is enabled.
This function will wait for staging to finish. The staging is only enabled
with ``async_checkpoint_with_pinned_memory``.
"""
if self.enable_staging and self.staging:
if not self.staging_stream.query():
begin = time.monotonic()
self.staging_stream.synchronize()
logger.info(
"Checkpointer waited staging %.2f seconds.",
time.monotonic() - begin,
)
self.staging = False
if self.sending_to_checkpoint_mp:
# Copy the sync staging result to another process.
def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)
# This may be a faster way to do zero-overhead checkpointing staging
# checkpointing but we need more thorough investigation before
# swithing to this method.
# self.my_thread = threading.Thread(target=func).start()
begin = time.monotonic()
sync_func()
logger.info(
"Checkpointer sent staged state_dict to another process %.2f seconds",
time.monotonic() - begin,
)
self.sending_to_checkpoint_mp = False
def _find_load_step(self, folder: str = "") -> int:
"""Find the step to load the checkpoint for.
Args:
folder (str, optional): The folder to find the checkpoint for. If ``folder``
is "", then ``self.folder`` will be used.
Returns:
int: The step to load the checkpoint for.
"""
folder = folder if folder else self.folder
pattern = r"step-(\d+)"
step_counts = []
if not os.path.isdir(folder):
return -1
for filename in os.listdir(folder):
match = re.search(pattern, filename)
metadata_probe = os.path.join(folder, filename, ".metadata")
if match and os.path.isfile(metadata_probe):
step_counts.append(int(match.group(1)))
if not step_counts:
return -1
return max(step_counts)
def _ft_folder(self) -> str:
return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
folder = folder if folder else self.folder
return os.path.join(folder, f"step-{step}")
def _ft_save(self, step: int) -> None:
begin = time.monotonic()
self._async_wait()
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
self.async_future = dcp.async_save(
self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
)
logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
def _ft_load(self) -> None:
step = self._find_load_step(folder=self._ft_folder())
if step == -1:
return
begin = time.monotonic()
logger.info(f"Loading the FT checkpoint at step {step}.")
checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
)
def _states_to_load(self, step: int) -> dict[str, Any]:
"""Determines which states to load for the given step.
When checkpointer determines which step of the checkpoint to load, this API is
used to determine which states to load based on the step.
Args:
step (int): The step to load the checkpoint for.
Returns:
Dict[str, Any]: The states to load for the given step.
"""
# For the first step, we will only load the model weights.
states = {MODEL: self.states[MODEL]} if step == 0 else self.states
states_to_load = {
k: v for k, v in states.items() if k not in self.exclude_from_loading
}
for exclude_key in self.exclude_from_loading:
if exclude_key not in states:
raise ValueError(f"{exclude_key} not found in state_dict.")
if self.ft_manager:
states_to_load.pop(DATALOADER)
return states_to_load
def _save_last_step(self, curr_step: int) -> None:
# We only consider saving weights only at the end of the training. So
# this won't affect preemption and training resume. We also only allow
# dtype conversion when we are checkpoint model weights only and the
# current dtype is not the same as the export dtype at the end of the training.
if self.model_weights_only:
# We update self.states to keep the model only.
# After this update, self.states = {
# 'tok_embeddings.weight':...,
# 'layers.0.attention.wq.weight': ...
# }.
self.states = self.states[MODEL].state_dict()
# For now, we will manually pop the freqs_cis buffer, as we made this permanent
# temporarily and we don't want to include it in the exported state_dict.
# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L348
self.states.pop("freqs_cis")
if self.export_dtype != torch.float32:
self.states = {
k: v.to(self.export_dtype) for k, v in self.states.items()
}
logger.info(
f"Saving a model weights only checkpoint in {self.export_dtype} "
f"at last step, step {curr_step}."
)
else:
logger.info(f"Saving a full checkpoint at last step, step {curr_step}.")
save_with_gc(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
def _should_save(self, curr_step: int, force: bool = False) -> bool:
if not self.enable_checkpoint:
return False
# Force saving a checkpoint at step 1 to fail fast if checkpointer is not
# compatible with the cluster.
if curr_step == 1:
return True
if force:
return True
if curr_step % self.interval == 0:
return True
return False
def _async_wait(self) -> None:
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
logger.debug(
f"Waiting for the background process to finish, {time.monotonic()=}.:.2f"
)
if not self.mp.is_alive():
raise RuntimeError("The checkpoint background process is dead.")
_ = self.mp_queue_recv.get()
elif self.async_mode == AsyncMode.ASYNC:
if self.async_future is not None:
self.async_future.result()
def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
self._cpu_staging(checkpoint_id)
self.sending_to_checkpoint_mp = True
def _cpu_staging(self, checkpoint_id: str | None) -> None:
"""Offload state_dict to CPU memory"""
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)
logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id
def _purge_stale_checkpoints(self):
if (
self.keep_latest_k > 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
):
discovered_checkpoints = []
for filename in os.listdir(self.folder):
match = re.search(r"step-(\d+)", filename)
path = os.path.join(self.folder, filename)
discovered_checkpoints.append((int(match.group(1)), path))
discovered_checkpoints.sort()
to_delete = discovered_checkpoints[: -1 * self.keep_latest_k]
for _, path in to_delete:
assert self.purge_thread is not None
self.purge_queue.put(path)