# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron Core number of microbatches calculators."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union

logger = logging.getLogger(__name__)
# TODO: global_var merge into mcore?
_GLOBAL_NUM_MICROBATCHES_CALCULATOR: Optional[
    Union['ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator']
] = None
def get_num_microbatches() -> int:
"""Get number of microbatches."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size() -> int:
"""Get current global batch size."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def get_micro_batch_size() -> int:
"""Get micro batch size."""
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_micro_batch_size()
def get_current_running_global_batch_size() -> int:
    """Get current running global batch size. If `decrease_batch_size_if_needed` is True, this
    may be smaller than the requested global batch size, since the requested size can be
    incompatible with (i.e. not divisible by) the number of DP replicas times the micro batch
    size."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_running_global_batch_size()
def update_num_microbatches(
consumed_samples: int, consistency_check: bool = True, verbose: bool = False
) -> None:
"""Update number of microbatches.
Args:
consumed_samples (int):
Number of samples consumed.
consistency_check (bool, optional):
Option to check current schedule's consistency. Defaults to True.
verbose (bool, optional):
Option to control logging. Defaults to False.
"""
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check, verbose)
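

# Illustrative usage sketch (hypothetical values and a placeholder `train_step`), showing how
# the module-level helpers above are typically combined in a training loop: initialize once,
# query the schedule each iteration, and update it with the samples consumed so far.
#
#     init_num_microbatches_calculator(
#         rank=0, rampup_batch_size=None, global_batch_size=512,
#         micro_batch_size=4, data_parallel_size=16)
#     consumed_samples = 0
#     for _ in range(num_iterations):
#         train_step(num_microbatches=get_num_microbatches())
#         consumed_samples += get_current_running_global_batch_size()
#         update_num_microbatches(consumed_samples)
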
def unset_num_microbatches_calculator():
"""Unset microbatches calculator.
Useful for multiple runs. See `tests/unit_tests/ckpt_converter/test_ckpt_converter.py`
for an example.
"""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
def init_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool = False,
) -> None:
"""Initialize number of microbatches calculator. Supporting backward compatibility.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in the format of
            [start_global_batch_size, batch_size_increment, rampup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool, optional):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
Defaults to False.
"""
_configure_global_num_microbatches_calculator(
rank,
rampup_batch_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
init=True,
)
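
# Example (hypothetical values): rampup_batch_size=[32, 16, 1_000_000] starts training at a
# global batch size of 32 and raises it by 16 at a time until `global_batch_size` is reached,
# spreading the increases evenly over the first 1,000,000 consumed samples;
# rampup_batch_size=None keeps the global batch size constant.
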
def destroy_num_microbatches_calculator():
"""Destroy number of microbatches calculator."""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
def reconfigure_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool = False,
) -> None:
"""Reconfigure number of microbatches calculator. Supporting backward compatibility.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in the format of
            [start_global_batch_size, batch_size_increment, rampup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool, optional):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
Defaults to False.
"""
_configure_global_num_microbatches_calculator(
rank,
rampup_batch_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
init=False,
)
def _configure_global_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool = False,
init: bool = False,
) -> None:
"""Configure number of microbatches calculator. Can be used for initialization and
reconfiguration.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in the format of
            [start_global_batch_size, batch_size_increment, rampup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool, optional):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
Defaults to False.
init (bool, optional):
If true, initialize the calculator. Defaults to False.
"""
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
if init:
assert (
_GLOBAL_NUM_MICROBATCHES_CALCULATOR is None
), 'num microbatches calculator is already initialized.'
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = _build_num_microbatches_calculator(
rank,
rampup_batch_size,
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
)
def _build_num_microbatches_calculator(
rank: int,
rampup_batch_size: Optional[List[int]],
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool,
) -> Union['ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator']:
"""Build number of microbatches calculator. Internal helper method.
Args:
rank (int):
Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in the format of
            [start_global_batch_size, batch_size_increment, rampup_samples].
global_batch_size (int):
Global batch size for the model.
micro_batch_size (int):
Micro batch size at initialization.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool):
If true, scale down batch size to ensure divisibility by DP size * microbatch size.
"""
# Constant batch size.
if rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatchesCalculator(
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
rank,
)
if rank == 0:
logger.info(
f'setting number of microbatches to constant {num_microbatches_calculator.get()}'
)
# Batch size ramp up.
else:
assert len(rampup_batch_size) == 3, (
'expected the following '
'format: --rampup-batch-size <start batch size> '
            '<batch size increment> <ramp-up samples>'
)
start_global_batch_size = int(rampup_batch_size[0])
batch_size_increment = int(rampup_batch_size[1])
ramup_samples = int(rampup_batch_size[2])
if rank == 0:
logger.info(
f'will use batch size rampup starting from global batch size '
                f'{start_global_batch_size} to global batch size {global_batch_size} with batch '
f'size increments {batch_size_increment} over {ramup_samples} samples.'
)
num_microbatches_calculator = RampupBatchsizeNumMicroBatchesCalculator(
global_batch_size,
micro_batch_size,
data_parallel_size,
decrease_batch_size_if_needed,
rank,
start_global_batch_size,
batch_size_increment,
ramup_samples,
)
return num_microbatches_calculator
def _round(batch_size: int, divisor: int) -> int:
"""Round `batch_size` down to nearest batch size divisible by `divisor`."""
return (batch_size // divisor) * divisor
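
# For example (hypothetical values): with micro_batch_size=4 and data_parallel_size=16 the
# divisor is 64, and _round(1000, 64) == 960, the largest multiple of 64 not exceeding 1000.
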
class NumMicroBatchesCalculator(ABC):
"""Base class for number of microbatches calculator."""
def __init__(self) -> None:
self.num_micro_batches = None
self.current_global_batch_size = None
self.micro_batch_size = None
self.current_running_global_batch_size = None
def get(self) -> int:
"""Get number of microbatches."""
return self.num_micro_batches
def get_current_global_batch_size(self) -> int:
"""Get current global batch size."""
return self.current_global_batch_size
    def get_micro_batch_size(self) -> int:
        """Get micro batch size."""
return self.micro_batch_size
def get_current_running_global_batch_size(self) -> int:
"""Get current running global batch size. If decrease_batch_size_if_needed is False,
this just equals global batch size."""
return self.current_running_global_batch_size
@abstractmethod
def update(self, consumed_samples, consistency_check, verbose=False) -> None:
"""Update number of microbatches depending on batch size rampup."""
pass
class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator):
"""Calculator of number of microbatches with constant global batch size.
Args:
global_batch_size (int):
Global batch size.
micro_batch_size (int):
Micro batch size.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool):
If true, decrease batch size to ensure divisibility by DP size * microbatch size
(if needed).
rank (int):
Rank (to determine whether logging should be performed).
"""
def __init__(
self,
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool,
rank: int,
) -> None:
micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
if decrease_batch_size_if_needed:
running_global_batch_size = _round(
global_batch_size, micro_batch_times_data_parallel_size
)
assert running_global_batch_size % micro_batch_times_data_parallel_size == 0
if rank == 0:
logger.info(
                    f'decreasing batch size from {global_batch_size} to '
                    f'{running_global_batch_size} to keep divisibility by '
                    f'micro_batch_size={micro_batch_size} * '
f'data_parallel_size={data_parallel_size}'
)
self.num_micro_batches = (
running_global_batch_size // micro_batch_times_data_parallel_size
)
else:
assert global_batch_size % micro_batch_times_data_parallel_size == 0, (
'global batch size ({}) is not divisible by micro batch size ({})'
' times data parallel size ({})'.format(
global_batch_size, micro_batch_size, data_parallel_size
)
)
running_global_batch_size = global_batch_size
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel_size
assert (
self.num_micro_batches >= 1
), 'number of microbatches should be at least 1, got {}.'.format(self.num_micro_batches)
self.current_global_batch_size = global_batch_size
self.current_running_global_batch_size = running_global_batch_size
self.micro_batch_size = micro_batch_size
def update(self, consumed_samples, consistency_check, verbose=False) -> None:
pass
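
# Worked example (hypothetical values): global_batch_size=512, micro_batch_size=4,
# data_parallel_size=16 gives num_micro_batches = 512 // (4 * 16) = 8, i.e. each of the 16
# data-parallel replicas runs 8 microbatches of 4 samples per optimizer step.
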
class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator):
"""Calculator of number of microbatches with batch size rampup.
    Over `steps = (global_batch_size - start_global_batch_size) / batch_size_increment`
    increments, the batch size is ramped up from `start_global_batch_size` to
    `global_batch_size`, holding each intermediate batch size for `rampup_samples / steps`
    consumed samples.
Args:
global_batch_size (int):
Global batch size post rampup.
micro_batch_size (int):
Micro batch size.
data_parallel_size (int):
Data parallel size.
decrease_batch_size_if_needed (bool):
If true, decrease batch size to ensure divisibility by DP size * microbatch size
(if needed).
rank (int):
Rank (to determine whether logging should be performed).
start_global_batch_size (int):
Global batch size to start with.
batch_size_increment (int):
Global batch size increments.
        ramup_samples (int):
            Number of samples over which to ramp up the global batch size from
            `start_global_batch_size` to `global_batch_size`.
"""
def __init__(
self,
global_batch_size: int,
micro_batch_size: int,
data_parallel_size: int,
decrease_batch_size_if_needed: bool,
rank: int,
start_global_batch_size: int,
batch_size_increment: int,
ramup_samples: int,
) -> None:
assert global_batch_size > 0, 'global batch size should be positive, got {}.'.format(
global_batch_size
)
assert start_global_batch_size > 0, 'start batch size should be positive, got {}.'.format(
start_global_batch_size
)
assert batch_size_increment > 0, 'batch size increment should be positive, got {}.'.format(
batch_size_increment
)
assert ramup_samples >= 0, 'ramp-up samples should be non-negative, got {}.'.format(
ramup_samples
)
self.global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.decrease_batch_size_if_needed = decrease_batch_size_if_needed
self.rank = rank
self.start_global_batch_size = start_global_batch_size
self.batch_size_increment = batch_size_increment
self.ramup_samples = ramup_samples
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
assert self.micro_batch_times_data_parallel_size > 0
self.current_global_batch_size = None
diff_batch_size = self.global_batch_size - self.start_global_batch_size
assert diff_batch_size >= 0, (
'expected global batch size to be greater than or equal to start batch size, '
f'got {self.global_batch_size} and {self.start_global_batch_size}'
)
assert diff_batch_size % batch_size_increment == 0, (
'expected '
f'global batch size interval ({diff_batch_size}) to be divisible by global batch '
f'size increment ({batch_size_increment})'
)
num_increments = diff_batch_size // self.batch_size_increment
self.rampup_samples_per_increment = self.ramup_samples / num_increments
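        # E.g. (hypothetical values): ramping 32 -> 256 in increments of 16 gives
        # num_increments = (256 - 32) // 16 = 14, so each batch-size level lasts
        # ramup_samples / 14 consumed samples. Note this division assumes
        # global_batch_size > start_global_batch_size, so that num_increments >= 1.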
# Initialize number of microbatches.
self.update(0, consistency_check=False, verbose=True)
def update(self, consumed_samples: int, consistency_check: bool, verbose: bool = False) -> None:
"""Update number of microbatches.
Args:
consumed_samples (int): Number of samples consumed.
consistency_check (bool): Option to check current schedule's consistency.
verbose (bool, optional): Option to control logging. Defaults to False.
"""
# Update current global batch size.
global_batch_size_changed = False
old_current_global_batch_size = self.current_global_batch_size
if consumed_samples > self.ramup_samples:
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = (
self.start_global_batch_size + steps * self.batch_size_increment
)
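            # E.g. (hypothetical values): with start_global_batch_size=32,
            # batch_size_increment=16 and ~71,428.6 samples per increment,
            # consumed_samples=200_000 gives steps=2 and a current global batch
            # size of 32 + 2 * 16 = 64.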
assert self.current_global_batch_size <= self.global_batch_size
if old_current_global_batch_size != self.current_global_batch_size:
global_batch_size_changed = True
if self.rank == 0 and global_batch_size_changed and verbose:
if old_current_global_batch_size is None:
logger.info(f'setting initial batch size to {self.current_global_batch_size}')
else:
logger.info(
f'ramping up batch size from {old_current_global_batch_size} to '
f'{self.current_global_batch_size}'
)
# Check consistency of the current global batch size.
if consistency_check and not self.decrease_batch_size_if_needed:
assert (
self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0
), (
'current global '
                'batch size ({}) is not divisible by micro-batch-size ({}) times '
'data parallel size ({})'.format(
self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
)
)
if (
self.decrease_batch_size_if_needed
and self.current_global_batch_size % self.micro_batch_times_data_parallel_size != 0
):
self.current_running_global_batch_size = _round(
self.current_global_batch_size, self.micro_batch_times_data_parallel_size
)
if self.rank == 0 and global_batch_size_changed and verbose:
logger.info(
f'decreasing batch size from {self.current_global_batch_size} to '
                    f'{self.current_running_global_batch_size} to keep divisibility by '
f'micro_batch_size={self.micro_batch_size} * '
f'data_parallel_size={self.data_parallel_size}'
)
assert (
self.current_running_global_batch_size % self.micro_batch_times_data_parallel_size
== 0
)
else:
self.current_running_global_batch_size = self.current_global_batch_size
self.num_micro_batches = (
self.current_running_global_batch_size // self.micro_batch_times_data_parallel_size
)
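

# Minimal illustrative driver (hypothetical values): a sketch for manually smoke-testing the
# rampup schedule. It is not part of the library API and only runs when this file is executed
# directly.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # 2 DP replicas x micro batch size 4 => every batch size must be divisible by 8.
    init_num_microbatches_calculator(
        rank=0,
        rampup_batch_size=[32, 16, 1_000_000],
        global_batch_size=256,
        micro_batch_size=4,
        data_parallel_size=2,
    )
    for consumed in (0, 100_000, 500_000, 1_000_000):
        update_num_microbatches(consumed, consistency_check=True, verbose=True)
        print(
            f'consumed={consumed}: global batch size '
            f'{get_current_global_batch_size()}, '
            f'{get_num_microbatches()} microbatches'
        )
    destroy_num_microbatches_calculator()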