partition.py
"""
Auxiliary data structures for AllReduceRunner
"""
import asyncio
from collections import deque
from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar
import numpy as np
import torch
from hivemind.averaging.accumulators import AccumulatorFactory
from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor
T = TypeVar("T")
DEFAULT_PART_SIZE_BYTES = 2 ** 16
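# (note added for clarity) 2 ** 16 bytes = 64 KiB: default target size of one tensor part,
# estimated after compression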


class TensorPartContainer:
    """
    Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
    The class is designed to avoid excessive memory allocation and run all heavy computation in the background.

    :param tensors: local tensors to be split and aggregated
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
    :param compression: optionally compress tensors with this compression algorithm before sending them to peers
    :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
    :param prefetch: when compressing, pre-compute this many compressed tensors in the background
    """

    def __init__(
        self,
        tensors: Sequence[torch.Tensor],
        peer_fractions: Sequence[float],
        compression: CompressionBase = NoCompression(),
        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
        prefetch: int = 5,
    ):
        if tensor_infos is None:
            tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
        assert len(tensor_infos) == len(tensors), "compression types do not match the number of tensors"
        self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
        self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
        self.total_size = sum(tensor.numel() for tensor in tensors)
        self.prefetch = prefetch

        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
        self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
        self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
        self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
        self._outputs_consumed = False
        self.finished = asyncio.Event()
        self.num_parts_by_tensor = []

        # split tensor parts in proportion to target_size_by_peer
        current_length = 0
        current_peer_index = 0
        pivots = (np.cumsum(peer_fractions) / np.sum(peer_fractions) * self.total_size).astype(np.int64)
        pivots[-1] = self.total_size
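
        # Illustrative note (added for clarity, not in the original source): `pivots` are
        # cumulative element boundaries over the flat concatenation of all tensors. For
        # example, with total_size = 100 and peer_fractions = (0.5, 0.25, 0.25), pivots
        # become [50, 75, 100]: elements [0, 50) go to peer 0, [50, 75) to peer 1 and
        # [75, 100) to peer 2; a part that straddles a pivot is assigned to the peer with
        # the largest overlap (handled in the loop below).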

        for tensor, info in zip(self.local_tensors, self.tensor_infos):
            bytes_per_value = tensor.element_size() * compression.estimate_compression_ratio(info)
            part_size_values = int(part_size_bytes / bytes_per_value)
            tensor_parts = tensor.detach().view(-1).split(part_size_values)
            self.num_parts_by_tensor.append(len(tensor_parts))
            for part_index, part in enumerate(tensor_parts):
                part_info = info.get_part(part_index, part_size_values)
                if current_length + len(part) > pivots[current_peer_index]:
                    # switch to next peer; if a part lands between parts of two or
                    # more peers, assign that part to the peer with highest intersection
                    prev_peer_index = current_peer_index
                    peer_intersections = [pivots[current_peer_index] - current_length]
                    while current_length + len(part) > pivots[current_peer_index]:
                        current_peer_index += 1
                        current_peer_part_end = min(current_length + len(part), pivots[current_peer_index])
                        peer_intersections.append(current_peer_part_end - pivots[current_peer_index - 1])
                    assigned_peer_index = prev_peer_index + np.argmax(peer_intersections)
                    self._input_parts_by_peer[assigned_peer_index].append((part, part_info))
                else:
                    self._input_parts_by_peer[current_peer_index].append((part, part_info))
                current_length += len(part)

        assert current_length == self.total_size
        self.num_parts_by_peer = tuple(len(parts) for parts in self._input_parts_by_peer)

    @torch.no_grad()
    def get_raw_input_parts(self, peer_index: int) -> Tuple[torch.Tensor, ...]:
        """get non-serialized tensor parts for a peer at a given index"""
        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
        self._inputs_consumed_by_peer[peer_index] = True
        input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
        self._input_parts_by_peer[peer_index].clear()
        return input_parts
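
    # Note (added for clarity, not in the original source): raw, uncompressed parts are
    # typically requested for the local peer's own share of the tensors, which is reduced
    # in-process and therefore never needs to be serialized.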

    @torch.no_grad()
    async def iterate_input_parts_for(self, peer_index: int) -> AsyncIterator[runtime_pb2.Tensor]:
        """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
        assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
        self._inputs_consumed_by_peer[peer_index] = True

        async def _aiterate_parts():
            for _ in range(self.num_parts_by_peer[peer_index]):
                yield self._input_parts_by_peer[peer_index].popleft()

        async for serialized_part in amap_in_executor(
            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
        ):
            yield serialized_part
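
    # Note (added for clarity, not in the original source): a typical caller drains this
    # iterator from an async task that streams parts to the corresponding peer, e.g.
    #
    #     async for serialized_part in container.iterate_input_parts_for(peer_index):
    #         await stream.send(serialized_part)  # `stream.send` is a hypothetical RPC call
    #
    # Compression runs in a background executor via amap_in_executor, keeping up to
    # `prefetch` serialized parts ready ahead of the consumer.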

    def register_processed_part(self, peer_index: int, part_index: int, part: torch.Tensor):
        """
        register next-in-line part of results received from a given peer for use in iterate_output_tensors
        depending on the algorithm, processed part is an average, difference from average or another aggregation
        """
        if part_index != self._outputs_registered_by_peer[peer_index]:
            raise ValueError(
                f"Could not register part #{part_index} from peer #{peer_index}, "
                f"expected part index: {self._outputs_registered_by_peer[peer_index]}"
            )
        self._output_parts_by_peer[peer_index].append(part)
        self._outputs_registered_by_peer[peer_index] += 1
        self._output_part_available[peer_index].set()
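
    # Note (added for clarity, not in the original source): parts from each peer must be
    # registered strictly in order; _outputs_registered_by_peer tracks the next expected
    # index so that iterate_output_tensors can reassemble tensors without reordering.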

    async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
        """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
        assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
        self._outputs_consumed = True
        peer_index = num_parts_processed = 0
        for tensor_index in range(len(self.local_tensors)):
            tensor_parts = []
            while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
                if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                    num_parts_processed = 0
                    peer_index += 1
                    continue
                if not self._output_parts_by_peer[peer_index]:
                    self._output_part_available[peer_index].clear()
                    await self._output_part_available[peer_index].wait()
                    if self.finished.is_set():
                        raise AllreduceException("All-reduce was terminated during iteration.")

                tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
                num_parts_processed += 1
            tensor = torch.cat(tensor_parts)
            del tensor_parts
            yield tensor.reshape(self.local_tensors[tensor_index].shape)
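
    # Note (added for clarity, not in the original source): output parts are consumed in the
    # same order in which input parts were assigned in __init__: all parts owned by peer 0
    # first, then peer 1, and so on. Since the assignment walks the flattened tensors
    # sequentially, concatenating the popped parts and reshaping restores each original tensor.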

    def __del__(self):
        self.finalize()

    def finalize(self):
        """terminate all iterators, delete intermediate data"""
        if not self.finished.is_set():
            for peer_index in range(self.group_size):
                self._inputs_consumed_by_peer[peer_index] = True
                self._input_parts_by_peer[peer_index].clear()
                self._output_parts_by_peer[peer_index].clear()
                self._output_part_available[peer_index].set()
            self._outputs_consumed = True
            self.finished.set()
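

# Illustrative usage sketch (added for clarity; not part of the original module). It assumes
# an async context inside AllReduceRunner and a hypothetical `stubs` mapping of peer RPC stubs:
#
#     container = TensorPartContainer(tensors, peer_fractions=(0.5, 0.25, 0.25))
#     local_parts = container.get_raw_input_parts(my_peer_index)  # local share, not serialized
#     for peer_index in other_peer_indices:
#         async for msg in container.iterate_input_parts_for(peer_index):
#             await stubs[peer_index].send_part(msg)  # hypothetical RPC call
#     # as reduced parts arrive back, register them and reassemble the full tensors
#     container.register_processed_part(peer_index, part_index, averaged_part)
#     async for averaged_tensor in container.iterate_output_tensors():
#         ...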


class TensorPartReducer:
    """
    Auxiliary data structure responsible for running asynchronous all-reduce

    :param part_shapes: a sequence of shapes of torch tensors that will be averaged by this reducer
    :param num_senders: total number of peers in a given all-reduce group that will send gradients
    :param weights: relative importance of each sender, used for weighted average (default = equal weights)
    :note: even if the local peer is not sending data, local parts will be used for shape information
    """

    def __init__(
        self,
        part_shapes: Sequence[torch.Size],
        num_senders: int,
        *,
        weights: Optional[Sequence[float]],
        accumulator_factory: AccumulatorFactory,
    ):
        self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
        self.weights = tuple(weights or (1 for _ in range(num_senders)))
        assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
        assert all(isinstance(weight, (int, float)) for weight in self.weights)
        self.current_part_index = -1  # index in local_parts of the part that should be loaded next
        self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated

        self.current_part_future = asyncio.Future()
        self.accumulator_factory = accumulator_factory
        self.accumulator = None

        self.finished = asyncio.Event()
        self.reset_accumulators()

    def reset_accumulators(self):
        """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
        if self.current_part_index >= self.num_parts - 1:
            self.finalize()
            return

        self.current_part_index += 1
        self.current_part_accumulated_from = 0
        self.current_part_future = asyncio.Future()
        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)
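
    # Note (added for clarity, not in the original source): accumulator_factory is called with
    # (part_shape, num_senders) and is expected to return an object exposing
    # accumulate_part(tensor, weight) and reduce(), as used in accumulate_part below.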

    async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
        """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
        assert 0 <= sender_index < self.num_senders, "invalid sender index"
        assert 0 <= part_index < self.num_parts, "invalid part index"

        while part_index > self.current_part_index:
            # wait for previous parts to finish processing ...
            await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
            if self.finished.is_set():
                raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
        assert part_index == self.current_part_index

        current_part_future = self.current_part_future

        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
        self.current_part_accumulated_from += 1

        assert self.current_part_accumulated_from <= self.num_senders
        if self.current_part_accumulated_from == self.num_senders:
            current_part_future.set_result(self.accumulator.reduce())
            self.reset_accumulators()
        return await current_part_future
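
    # Note (added for clarity, not in the original source): accumulate_part acts as a per-part
    # barrier. Parts are reduced strictly in order: a sender that is ahead waits on
    # current_part_future until all num_senders peers have contributed the current part;
    # the final contribution resolves the future with accumulator.reduce() and advances to
    # the next part via reset_accumulators().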

    def finalize(self):
        if not self.finished.is_set():
            if hasattr(self, "current_part_future"):
                self.current_part_future.cancel()
                self.accumulator = None
            self.finished.set()

    def __del__(self):
        self.finalize()
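

# Illustrative usage sketch (added for clarity; not part of the original module). The factory
# below, `mean_accumulator_factory`, is a hypothetical name for any AccumulatorFactory that
# builds an accumulator with accumulate_part(tensor, weight) and reduce():
#
#     reducer = TensorPartReducer(
#         part_shapes=[part.shape for part in local_parts],
#         num_senders=group_size,
#         weights=None,  # equal weights
#         accumulator_factory=mean_accumulator_factory,
#     )
#     # each sender's parts are submitted concurrently; every call returns the reduced part
#     reduced_part = await reducer.accumulate_part(sender_index, part_index, tensor_part)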


class AllreduceException(Exception):
    """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""