-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy pathutils.py
1755 lines (1435 loc) · 64.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions used throughout Megatron core"""
import array
import functools
import hashlib
import logging
import math
import operator
import queue
import socket
import sys
import threading
import time
import traceback
import warnings
from dataclasses import dataclass
from datetime import datetime
from functools import reduce, wraps
from importlib.metadata import version
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import torch
from packaging.version import Version as PkgVersion
from megatron.core import config
from megatron.core.package_info import __version__ as mcore_version
try:
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor
logger = logging.getLogger(__name__)

# Parsed torch version, computed eagerly at import time and read by get_torch_version().
try:
    _torch_version = PkgVersion(torch.__version__)
except Exception:
    # Workaround (WAR) for building docs, where torch is not actually imported and
    # torch.__version__ cannot be parsed; fall back to a placeholder version.
    _torch_version = PkgVersion("0.0.0")

# Caches filled lazily on first call to get_te_version() / get_fa_version().
_te_version = None
_fa_version = None
class ExperimentalNotEnabledError(Exception):
    """Signals that experimental code was invoked while ``config.ENABLE_EXPERIMENTAL`` is not True."""

    pass
def experimental_fn(introduced_with_version: str):
    """A decorator that marks a function as experimental.

    Experimental functions may change quickly and do not guarantee backwards
    compatibility. Experimental functions have a limited lifetime and should
    either be productionized or deprecated.

    Args:
        introduced_with_version (str): A version-like string of Mcore at time of
            introduction.

    Raises:
        ExperimentalNotEnabledError: Error raised when experimental function
            was called without enabling the experimental flag.
    """

    def validator(func: Callable, max_lifetime: int = 3) -> Callable:
        """Validates the request to the experimental function.

        Logs a warning at decoration time when the function has outlived
        ``max_lifetime`` minor releases of Mcore.

        Args:
            func (Callable): Callee
            max_lifetime (int, optional): Number of minor versions that the experimental
                function is allowed to exist. Defaults to 3.

        Raises:
            ExperimentalNotEnabledError: Error raised (when the wrapped function is called)
                if the experimental flag is not enabled.

        Returns:
            Callable: The wrapped callee function.
        """
        introduced = PkgVersion(introduced_with_version)
        current = PkgVersion(mcore_version)
        # Compare (major, minor) pairs: comparing minor numbers alone would never flag
        # end-of-life once the major version is bumped (e.g. 0.8 -> 1.0).
        if (introduced.major, introduced.minor + max_lifetime) < (current.major, current.minor):
            logger.warning(
                "%s has reached end of life. Please migrate to a non-experimental function.",
                func.__name__,
            )

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # The flag is checked on every call so it can be toggled at runtime.
            if config.ENABLE_EXPERIMENTAL is not True:
                raise ExperimentalNotEnabledError(f"Flag {config.ENABLE_EXPERIMENTAL} not enabled.")

            logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.")

            return func(*args, **kwargs)

        return wrapped_func

    return validator
def experimental_cls(introduced_with_version: str):
    """A decorator that marks a Class as experimental.

    Experimental Classes may change quickly and do not guarantee backwards
    compatibility. Experimental classes have a limited lifetime and should
    either be productionized or deprecated.

    Args:
        introduced_with_version (str): A version-like string of Mcore at time of
            introduction.

    Raises:
        ExperimentalNotEnabledError: Error raised when experimental class
            was called without enabling the experimental flag.
    """

    def validator(cls: Callable, max_lifetime: int = 3) -> Callable:
        """Validates the request to the experimental class.

        Logs a warning at decoration time when the class has outlived
        ``max_lifetime`` minor releases of Mcore, then returns a proxy subclass
        whose attribute accesses (on the class and on instances) are guarded by
        the experimental flag.

        Args:
            cls (Callable): The class being decorated.
            max_lifetime (int, optional): Number of minor versions that the experimental
                class is allowed to exist. Defaults to 3.

        Raises:
            ExperimentalNotEnabledError: Raised on attribute access of the proxy when
                the experimental flag is not enabled.

        Returns:
            Callable: The proxy class.
        """
        introduced = PkgVersion(introduced_with_version)
        current = PkgVersion(mcore_version)
        # Compare (major, minor) pairs: comparing minor numbers alone would never flag
        # end-of-life once the major version is bumped (e.g. 0.8 -> 1.0).
        if (introduced.major, introduced.minor + max_lifetime) < (current.major, current.minor):
            logger.warning(
                "%s has reached end of life. Please migrate to a non-experimental function.",
                cls.__name__,
            )

        def wrapped_func(cls):
            def guard(super: super, attr: str):
                """Pass-through to callee attribute if experimental flag is enabled.

                Args:
                    super (super): Parent class of callee.
                    attr (str): Attribute of callee that is being called.

                Raises:
                    ExperimentalNotEnabledError: Raised if flag is not set.

                Returns:
                    Attribute of callee.
                """
                # `is_experimental` is always readable so callers can probe the flag.
                if attr == "is_experimental":
                    return config.ENABLE_EXPERIMENTAL

                if config.ENABLE_EXPERIMENTAL is not True:
                    raise ExperimentalNotEnabledError(
                        f"Flag {config.ENABLE_EXPERIMENTAL} not enabled."
                    )

                logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.")

                return super.__getattribute__(attr)

            class ClassInterceptor(type):
                """Metaclass to intercept calls from the uninitialized class."""

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    self.__class__ = type(cls.__qualname__, (ClassInterceptor,), {})

                def __getattribute__(self, attr):
                    """Intercepts calls like A.hello_world()"""
                    return guard(super(), attr)

            class Proxy(cls, metaclass=ClassInterceptor):
                """Proxies calls from caller to the callee by relaying all
                attribute calls through a guarding mechanism.

                We use `__getattribute__` for relaying calls. Opposed to `__getattr__`,
                this is called regardless of whether the attribute exists or not.

                We need to distinguish two cases: callee is an instance vs. a class.

                If callee is an instance, `__getattribute__` will look and find attributes
                at the class level.

                If callee is a class, `__getattribute__` will look for attributes at
                _its_ class, which is `type`. Here, it won't find attributes.
                We solve this a metaclass mixin which swaps `type` with a custom class
                that supersets the callee's class. For mixins, any methods provided on
                parent classes will be provided to the metaclass. We add a
                `__getattribute__` to the metaclass as to allow it to fetch it from the
                callees class.
                """

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    self.__class__ = type(cls.__qualname__, (Proxy,), {})

                def __getattribute__(self, attr):
                    """Intercepts calls like a.hello_world()"""
                    return guard(super(), attr)

            return Proxy

        return wrapped_func(cls)

    return validator
def get_torch_version():
    """Get pytorch version from __version__; if not available use pip's. Use caching.

    Returns the cached module-level ``_torch_version`` (a packaging ``Version``),
    populating it on first use when it is ``None``.

    NOTE(review): a second ``def get_torch_version`` appears later in this module and
    rebinds the name at import time, so this definition is never the one callers reach.
    Consider removing one of the two.
    """

    def get_torch_version_str():
        # Imported lazily so the lookup only happens when the cache is empty.
        import torch

        if hasattr(torch, '__version__'):
            return str(torch.__version__)
        else:
            # Fall back to the installed package metadata for "torch".
            return version("torch")

    global _torch_version
    if _torch_version is None:
        _torch_version = PkgVersion(get_torch_version_str())
    return _torch_version
def get_te_version():
    """Return the installed Transformer Engine version as a packaging ``Version``.

    The result is cached in the module-level ``_te_version``. ``transformer_engine.__version__``
    is preferred; when that attribute is missing, the installed package metadata for
    ``transformer-engine`` is used instead.
    """
    global _te_version
    if _te_version is None:
        import transformer_engine as te

        raw_version = (
            str(te.__version__) if hasattr(te, '__version__') else version("transformer-engine")
        )
        _te_version = PkgVersion(raw_version)
    return _te_version
def is_te_min_version(version, check_equality=True):
    """Return True if the installed `transformer-engine` is at least ``version``.

    With ``check_equality=False`` the installed version must be strictly newer.
    """
    installed = get_te_version()
    required = PkgVersion(version)
    return installed >= required if check_equality else installed > required
def get_torch_version():
    """Get torch version as a packaging ``Version``, using the module-level cache.

    The cache ``_torch_version`` is normally populated at import time. If it is unset
    (``None``), it is filled from ``torch.__version__``, falling back to the installed
    package metadata for ``torch`` when that attribute is missing, instead of returning
    ``None``.
    """
    global _torch_version
    if _torch_version is None:
        raw_version = str(torch.__version__) if hasattr(torch, '__version__') else version("torch")
        _torch_version = PkgVersion(raw_version)
    return _torch_version
def is_torch_min_version(version, check_equality=True):
    """Return True if the installed `torch` is at least ``version``.

    With ``check_equality=False`` the installed version must be strictly newer.
    """
    installed = get_torch_version()
    required = PkgVersion(version)
    if not check_equality:
        return installed > required
    return installed >= required
def get_fa_version():
    """Return the installed flash-attn version as a packaging ``Version``.

    The result is cached in the module-level ``_fa_version``. ``flash_attn.__version__`` is
    preferred; when that attribute is missing, the installed package metadata for
    ``flash-attn`` is used instead.
    """
    global _fa_version
    if _fa_version is None:
        import flash_attn as fa

        raw_version = str(fa.__version__) if hasattr(fa, '__version__') else version("flash-attn")
        _fa_version = PkgVersion(raw_version)
    return _fa_version
def is_fa_min_version(version, check_equality=True):
    """Return True if the installed `flash-attn` is at least ``version``.

    With ``check_equality=False`` the installed version must be strictly newer.
    """
    installed = get_fa_version()
    required = PkgVersion(version)
    return installed >= required if check_equality else installed > required
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Raises:
        AssertionError: if ``numerator % denominator != 0``. The error is raised
            explicitly rather than via an ``assert`` statement so the check still
            runs under ``python -O``; the exception type is unchanged for callers.
    """
    if numerator % denominator != 0:
        raise AssertionError("{} is not divisible by {}".format(numerator, denominator))
def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking the division is exact.

    The exactness check is delegated to ``ensure_divisibility``.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
def deprecate_inference_params(inference_context, inference_params):
    """Resolve the renamed ``inference_params`` argument to ``inference_context``.

    If only the deprecated ``inference_params`` is given, a warning is emitted and it
    is returned; otherwise ``inference_context`` is returned unchanged.
    """
    if inference_context is not None or inference_params is None:
        return inference_context
    warnings.warn(
        "`inference_params` renamed to `inference_context`, and will be "
        "removed in `megatron-core` 0.13."
    )
    return inference_params
def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True):
    """Return ``tp_group`` or, when it is None, the default (expert) TP group from parallel_state.

    A DeprecationWarning is emitted on global rank 0 (when torch.distributed is initialized)
    whenever the default group has to be substituted.
    """
    # TODO(zijiey): remove this function later.
    if tp_group is not None:
        return tp_group

    if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        warnings.warn(
            "Warning: tp_group is None, using default tp group. "
            "Passing tp_group will be mandatory soon",
            DeprecationWarning,
            stacklevel=2,
        )

    group_getter = (
        parallel_state.get_expert_tensor_parallel_group
        if is_expert
        else parallel_state.get_tensor_model_parallel_group
    )
    return group_getter(check_initialized=check_initialized)
def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False):
    """Get an attribute from a wrapped model.

    Follows the chain of ``.module`` wrappers until an object carrying ``attr`` is found.
    With ``allow_none=False`` an attribute whose value is None does not count as found.
    If ``return_model_obj`` is true, the object holding the attribute is returned;
    otherwise the attribute value itself.

    Raises:
        RuntimeError: if ``model`` is a list, or no object in the chain has ``attr``.
    """
    if isinstance(model, list):
        raise RuntimeError("_get_attr_wrapped_model given a list of models")

    def _has_usable_attr(obj):
        if allow_none:
            return hasattr(obj, attr)
        return getattr(obj, attr, None) is not None

    current = model
    while not _has_usable_attr(current):
        if not hasattr(current, "module"):
            raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
        current = current.module

    return current if return_model_obj else getattr(current, attr)
def get_model_type(model):
    """Return the ``model_type`` attribute, unwrapping ``.module`` wrappers as needed."""
    model_type = get_attr_wrapped_model(model, 'model_type')
    return model_type
def get_model_xattn(model):
    """Return the model's ``xattn_needed`` attribute, or False if no wrapper level defines it."""
    try:
        xattn_needed = get_attr_wrapped_model(model, 'xattn_needed')
    except RuntimeError:
        return False
    return xattn_needed
def get_model_config(model):
    """Return the first non-None ``config`` attribute found while unwrapping ``.module`` wrappers."""
    model_config = get_attr_wrapped_model(model, 'config', allow_none=False)
    return model_config
class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.

    Flat buffers are cached per ``(name, dtype)`` key. Tensors handed out under the same
    name share storage, so callers must not use two of them concurrently.
    """

    def __init__(self):
        # Maps (name, dtype) -> flat 1-D tensor reused across calls.
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        """Return a view of shape ``tensor_shape`` over the cached buffer for ``(name, dtype)``.

        The backing buffer is (re)allocated on the current CUDA device only when it is
        missing or has fewer elements than ``tensor_shape`` requires.
        """
        key = (name, dtype)
        numel_needed = reduce(operator.mul, tensor_shape, 1)
        cached = self.buffer.get(key, None)
        if cached is None or cached.numel() < numel_needed:
            self.buffer[key] = torch.empty(
                numel_needed, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
            )
        return self.buffer[key][0:numel_needed].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
"""Make a viewless tensor.
View tensors have the undesirable side-affect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
"""
out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad)
out.data = inp.data
return out
class WrappedTensor:
    """Indirect, single-use reference to a tensor.

    A caller wraps a tensor and drops its own reference. The callee then calls
    ``unwrap()``, which also drops the wrapper's reference. Once the callee releases
    the tensor, nothing here keeps it alive.
    """

    def __init__(self, tensor: torch.Tensor):
        # One-element list so unwrap() can remove the only stored reference.
        self._wrapper = [tensor]

    def unwrap(self):
        """Return the wrapped tensor and forget it; a second call raises RuntimeError."""
        if not self._wrapper:
            raise RuntimeError("WrappedTensor has already been unwrapped")
        return self._wrapper.pop()
class MakeViewlessTensor(torch.autograd.Function):
    """Autograd function that produces a viewless copy-by-reference of a tensor.

    Use this when the autograd graph must be preserved but a viewless tensor is
    required (e.g. ParallelTransformer's hidden_states). Reach it through
    ``make_viewless_tensor(..., keep_graph=True)`` rather than directly.
    """

    @staticmethod
    def forward(ctx, inp, requires_grad):
        """Forward: delegate to ``_kernel_make_viewless_tensor``."""
        viewless = _kernel_make_viewless_tensor(inp, requires_grad)
        return viewless

    @staticmethod
    def backward(ctx, grad_output):
        """Backward: identity for ``inp``; the ``requires_grad`` flag receives no gradient."""
        flag_grad = None
        return grad_output, flag_grad
def make_viewless_tensor(inp, requires_grad, keep_graph):
    """Entry point for creating viewless tensors.

    Prefer this over calling ``MakeViewlessTensor`` or ``_kernel_make_viewless_tensor``
    directly. Tensors that are not views are returned unchanged. Otherwise
    ``keep_graph`` selects between the autograd-aware function and the plain helper.
    """
    if inp._base is not None:
        if not keep_graph:
            return _kernel_make_viewless_tensor(inp, requires_grad)
        return MakeViewlessTensor.apply(inp, requires_grad)
    return inp
def assert_viewless_tensor(tensor, extra_msg=None):
    """Assert that a tensor is not a view (i.e., its '._base' field is not set).

    Lists are checked element-wise and the same ``extra_msg`` is forwarded to each
    element's check. Non-tensor values pass through unchecked.

    Args:
        tensor: a tensor, a list of tensors, or any other value.
        extra_msg (str, optional): text appended to the assertion message.

    Returns:
        The input, unchanged.

    Raises:
        AssertionError: if a tensor is a view.
    """
    if isinstance(tensor, list):
        # Plain loop (not a side-effect comprehension) and forward extra_msg, which the
        # element-wise check previously dropped.
        for item in tensor:
            assert_viewless_tensor(item, extra_msg=extra_msg)
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        f"likely accumulate over iterations). {extra_msg}"
    )
    return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    """Assign ``new_data_tensor`` to ``tensor.data`` after verifying ``tensor`` is not a view.

    The check is done by ``assert_viewless_tensor``; on failure its message includes the
    shapes of ``tensor._base`` and ``new_data_tensor``.
    """
    base_shape = "--" if tensor._base is None else tensor._base.shape
    assert_viewless_tensor(
        tensor,
        extra_msg=f"FYI, tensor._base has shape {base_shape}, "
        f"and new_data_tensor has shape {new_data_tensor.shape}.",
    )
    tensor.data = new_data_tensor
def init_method_normal(sigma):
    """Return an in-place initializer drawing from a normal with mean 0 and std ``sigma``."""
    normal_kwargs = dict(mean=0.0, std=sigma)
    return functools.partial(torch.nn.init.normal_, **normal_kwargs)
def scaled_init_method_normal(sigma, num_layers, multiplier=2.0):
    """Init method based on N(0, std) with std = sigma / sqrt(multiplier * num_layers).

    With the default ``multiplier=2.0`` this is sigma / sqrt(2 * num_layers).

    Args:
        sigma (float): base standard deviation.
        num_layers (int): number of layers used to scale the std down.
        multiplier (float): extra factor inside the square root. Defaults to 2.0.

    Returns:
        A ``functools.partial`` of ``torch.nn.init.normal_`` that fills a tensor in place.
    """
    std = sigma / math.sqrt(multiplier * num_layers)
    return functools.partial(torch.nn.init.normal_, mean=0.0, std=std)
def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
    """Log via ``logger.log`` on one rank only when torch.distributed is initialized.

    Without an initialized process group, every caller logs.

    Args:
        logger (logging.Logger): The logger to write the logs
        args (Tuple[Any]): All logging.Logger.log positional arguments
        rank (int, optional): The rank to write on. Defaults to 0.
        kwargs (Dict[str, Any]): All logging.Logger.log keyword arguments
    """
    should_log = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank
    if should_log:
        logger.log(*args, **kwargs)
def log_on_each_pipeline_stage(logger: logging.Logger, *args: Any, **kwargs: Any):
    """Log only on ranks whose data-parallel (incl. context-parallel) rank and TP rank are both 0.

    Requires torch.distributed to be initialized.

    Args:
        logger (logging.Logger): The logger to write the logs
        args (Tuple[Any]): All logging.Logger.log positional arguments
        kwargs (Dict[str, Any]): All logging.Logger.log keyword arguments
    """
    assert torch.distributed.is_initialized()
    first_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
    if first_dp_rank and parallel_state.get_tensor_model_parallel_rank() == 0:
        logger.log(*args, **kwargs)
def check_param_hashes_across_dp_replicas(
    model: List[torch.nn.Module], cross_check: bool = False
) -> bool:
    """Computes hashes of all parameters in model, all-gathers hashes across DP replicas,
    and then checks for equality between the locally-computed hashes and those of other ranks.

    NOTE: This function computes SHA-1 hashes on the CPU and thus needs to move all param
    tensors from GPU to CPU first; as a result, this function is not intended to be called
    very frequently in the main training loop.

    Args:
        model (List[torch.nn.Module]): List of model chunks whose parameter hashes need to
            be checked.
        cross_check (bool): If true, will check whether hashes match across all DP replicas.

    Returns:
        True if all param hashes match with corresponding hash on DP replica 0 or
        across all replicas if cross_check is enabled, False otherwise.
    """

    # Compute per-parameter hashes on this rank.
    # Keep track of expert and non-expert parameters separately since they need to be
    # all-gathered across different sets of ranks.
    non_expert_params, expert_params = [], []
    local_non_expert_param_hashes, local_expert_param_hashes = [], []
    for model_chunk_id, model_chunk in enumerate(model):
        for param_name, param in model_chunk.named_parameters():
            param_hash = torch.frombuffer(
                array.array(
                    'B', hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
                ),
                dtype=torch.uint8,
            )
            # Params with `allreduce=False` are treated as expert params.
            if getattr(param, 'allreduce', True):
                non_expert_params.append((model_chunk_id, param_name, param))
                local_non_expert_param_hashes.append(param_hash)
            else:
                expert_params.append((model_chunk_id, param_name, param))
                local_expert_param_hashes.append(param_hash)

    # Use data-modulo-expert parallel group to all-gather expert param hashes, regular
    # data-parallel group for non-expert param hashes.
    all_param_hashes_match = True
    for params, local_param_hashes, all_gather_group in zip(
        [non_expert_params, expert_params],
        [local_non_expert_param_hashes, local_expert_param_hashes],
        [parallel_state.get_data_parallel_group(), parallel_state.get_expert_data_parallel_group()],
    ):
        # Collect per-parameter hashes across all ranks in group.
        assert len(params) == len(local_param_hashes)
        if len(params) == 0:
            continue
        local_param_hashes = torch.stack(local_param_hashes).cuda()
        all_param_hashes = [
            torch.zeros_like(local_param_hashes)
            for _ in range(torch.distributed.get_world_size(all_gather_group))
        ]
        torch.distributed.all_gather(all_param_hashes, local_param_hashes, group=all_gather_group)

        # Make sure local per-parameter hash matches DP rank 0.
        param_hashes_match = torch.equal(local_param_hashes, all_param_hashes[0])
        if not param_hashes_match:
            rank = torch.distributed.get_rank()
            for i, (model_chunk_id, param_name, _param) in enumerate(params):
                if not torch.equal(local_param_hashes[i], all_param_hashes[0][i]):
                    # Space before the chunk id was previously missing ("model chunk0").
                    logger.info(
                        "[Rank %s] Hash not matching for %s in model chunk %s",
                        rank,
                        param_name,
                        model_chunk_id,
                    )
        if cross_check:
            # Make sure all ranks have the same hash.
            all_param_hashes_match &= all(
                torch.equal(local_param_hashes, other) for other in all_param_hashes
            )
        else:
            all_param_hashes_match &= param_hashes_match

    return all_param_hashes_match
def make_tp_sharded_tensor_for_checkpoint(
    tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
    """Helper for instantiating a ShardedTensor where the `tp_axis` dimension
    is sharded across TP group.

    Optionally, can provide offsets which prepend new dimensions to the tensor.

    Args:
        tensor: the local tensor to wrap. A DTensor is replaced by its local tensor and
            additionally described as sharded across the DP group. A tensor carrying a
            `fully_shard_param_local_shard` attribute is wrapped via
            `ShardedTensor.from_rank_offsets_flat` instead.
        key: checkpoint key forwarded to ShardedTensor.
        tp_axis (int): axis of `tensor` (before prepended dims) that is split across TP ranks.
        replica_id (tuple, optional): replica id forwarded to ShardedTensor. When None it
            defaults to (0, 0, DP rank incl. context parallel), or (0, 0, 0) for DTensor input.
        prepend_offsets (tuple): offsets for extra leading dimensions, forwarded as-is.
        **kwargs: forwarded to the ShardedTensor factory.

    Returns:
        ShardedTensor describing this rank's piece of the tensor.
    """
    prepend_axis_num = len(prepend_offsets)

    new_offsets = []
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    dp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True)
    dp_replica_id = parallel_state.get_data_parallel_rank(with_context_parallel=True)

    # Each offset entry is (axis, rank, number of shards) along that axis.
    new_offsets.append((tp_axis + prepend_axis_num, tp_rank, tp_size))

    if HAVE_DTENSOR and isinstance(tensor, DTensor):
        # TP + FSDP2 sharding
        # DP ranks hold distinct shards here, so none of them is a replica.
        dp_replica_id = 0
        tensor = tensor._local_tensor

        if tp_axis == 0:
            # both FSDP2 and TP shards axis 0
            # default MCore uses tp-cp-ep-dp-pp
            # FSDP2 is compatible with TP, CP
            # Combine both splits of axis 0 into a single (tp_rank, dp_rank) shard index.
            new_offsets[0] = (prepend_axis_num, tp_rank * dp_size + dp_rank, tp_size * dp_size)
        else:
            # FSDP2 shards axis 0 and TP shards some other axis
            new_offsets.append((prepend_axis_num, dp_rank, dp_size))

    if replica_id is None:
        replica_id = (0, 0, dp_replica_id)

    if hasattr(tensor, 'fully_shard_param_local_shard'):
        assert len(replica_id) == 3, f'Expected replica_id format (PP, TP, DP), got: {replica_id}'
        # The DP component of the replica id is forced to 0 for fully-sharded params.
        replica_id = (*replica_id[:2], 0)

        sh_ten = ShardedTensor.from_rank_offsets_flat(
            key,
            tensor.fully_shard_param_local_shard,
            tensor.shape,
            *prepend_offsets,
            (
                tp_axis + prepend_axis_num,
                parallel_state.get_tensor_model_parallel_rank(),
                parallel_state.get_tensor_model_parallel_world_size(),
            ),
            flattened_range=slice(*tensor.fully_shard_param_local_index),
            replica_id=replica_id,
            prepend_axis_num=prepend_axis_num,
            **kwargs,
        )
        setattr(sh_ten, 'is_data_parallel_fully_shard', True)
        return sh_ten

    return ShardedTensor.from_rank_offsets(
        key,
        tensor,
        *prepend_offsets,
        *new_offsets,
        replica_id=replica_id,
        prepend_axis_num=prepend_axis_num,
        **kwargs,
    )
def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_id=None, **kwargs):
    """Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group).

    Optionally, can provide offsets which prepend new dimensions to the tensor.

    A DTensor input is the exception to "non-sharded": it is converted with
    `get_full_tensor_if_necessary` and described as sharded along axis 0 (after any
    prepended dims) across the DP group (incl. context parallel).

    Args:
        tensor: tensor (or DTensor, or a tensor carrying `fully_shard_param_local_shard`)
            to wrap.
        key: checkpoint key forwarded to ShardedTensor.
        prepend_offsets (tuple): offsets for extra leading dimensions, forwarded as-is.
        replica_id (tuple, optional): replica id forwarded to ShardedTensor. When None it
            defaults to (0, TP rank, DP rank), with the DP component 0 for DTensor input.
        **kwargs: forwarded to the ShardedTensor factory.

    Returns:
        ShardedTensor for this rank.
    """
    prepend_axis_num = len(prepend_offsets)

    new_offsets = []
    dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True)
    dp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True)
    dp_replica_id = parallel_state.get_data_parallel_rank(with_context_parallel=True)

    if HAVE_DTENSOR and isinstance(tensor, DTensor):
        # FSDP2 sharding
        dp_replica_id = 0
        # NOTE(review): get_full_tensor_if_necessary may return the *full* tensor, yet the
        # offset below still describes a DP-sharded axis 0 -- confirm this is intended.
        tensor = get_full_tensor_if_necessary(tensor)
        new_offsets.append((prepend_axis_num, dp_rank, dp_size))

    if replica_id is None:
        replica_id = (0, parallel_state.get_tensor_model_parallel_rank(), dp_replica_id)

    if hasattr(tensor, 'fully_shard_param_local_shard'):
        assert len(replica_id) == 3, f'Expected replica_id format (PP, TP, DP), got: {replica_id}'
        # The DP component of the replica id is forced to 0 for fully-sharded params.
        replica_id = (*replica_id[:2], 0)

        sh_ten = ShardedTensor.from_rank_offsets_flat(
            key,
            tensor.fully_shard_param_local_shard,
            tensor.shape,
            *prepend_offsets,
            flattened_range=slice(*tensor.fully_shard_param_local_index),
            replica_id=replica_id,
            prepend_axis_num=prepend_axis_num,
            **kwargs,
        )
        setattr(sh_ten, 'is_data_parallel_fully_shard', True)
        return sh_ten

    return ShardedTensor.from_rank_offsets(
        key,
        tensor,
        *prepend_offsets,
        *new_offsets,
        replica_id=replica_id,
        prepend_axis_num=prepend_axis_num,
        **kwargs,
    )
def get_full_tensor_if_necessary(tensor):
"""For DTensor gets full tensor if some ranks will not have a local copy"""
need_full_tensor = False
for i in range(tensor.device_mesh.ndim):
if (
isinstance(tensor.placements[i], Shard)
and tensor.device_mesh.shape[i] > tensor.shape[tensor.placements[i].dim]
):
need_full_tensor = True
break
tensor = tensor.full_tensor() if need_full_tensor else tensor._local_tensor
return tensor
def to_local_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
    """Return ``tensor.to_local()`` for a DTensor (under no_grad), else ``tensor`` itself."""
    with torch.no_grad():
        if HAVE_DTENSOR and isinstance(tensor, DTensor):
            return tensor.to_local()
        return tensor
def get_data_parallel_group_if_dtensor(
    tensor: Union[torch.Tensor, "DTensor"], data_parallel_group: "ProcessGroup" = None
) -> Optional["ProcessGroup"]:
    """Return the process group of a DTensor's device mesh, or None for ordinary tensors.

    When ``data_parallel_group`` is given, it must equal the mesh's group (asserted).
    """
    if not (HAVE_DTENSOR and isinstance(tensor, DTensor)):
        return None
    mesh_group = tensor.device_mesh.get_group()
    assert data_parallel_group is None or mesh_group == data_parallel_group
    return mesh_group
def prepare_input_tensors_for_wgrad_compute(grad_output, all_gathered_input):
    """Make both inputs contiguous and, for 3-D ``grad_output``, flatten both to 2-D.

    The first two dimensions are merged (``view(d0 * d1, d2)``) for each tensor; 2-D inputs
    keep their shape. NOTE(review): the 3-D branch assumes ``all_gathered_input`` is 3-D too.
    """
    # Gather + slicing during the NeMo forward pass can leave these non-contiguous; PyTorch
    # only clones when the tensor is not already contiguous:
    # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
    grad_output, all_gathered_input = grad_output.contiguous(), all_gathered_input.contiguous()

    if grad_output.dim() == 3:
        g_shape = grad_output.shape
        grad_output = grad_output.view(g_shape[0] * g_shape[1], g_shape[2])
        in_shape = all_gathered_input.shape
        all_gathered_input = all_gathered_input.view(in_shape[0] * in_shape[1], in_shape[2])

    return grad_output, all_gathered_input
# Tensor-output all-gather: the public `all_gather_into_tensor` on torch >= 1.13,
# otherwise the older private `_all_gather_base`.
dist_all_gather_func = (
    torch.distributed.all_gather_into_tensor
    if is_torch_min_version("1.13.0")
    else torch.distributed._all_gather_base
)
def drain_embedding_wgrad_compute(config, embedding_activation_buffer, grad_output_buffer, weight):
    """Helper for performing embedding wgrad GEMM's during the pipeline drain phase, pipelines the
    AllGather and GEMM's.

    Should only be used when pipeline model parallelism and gradient accumulation
    fusion are enabled.

    Both buffers are consumed (popped from the front) in lockstep. With sequence
    parallelism, each activation is all-gathered across the TP group into one of two
    alternating global-memory buffers ("mpu_0"/"mpu_1"). The gather for the next entry
    overlaps with the GEMM on the current one.

    Args:
        config: object exposing ``sequence_parallel`` and ``gradient_accumulation_fusion``.
        embedding_activation_buffer (list): buffered embedding inputs; emptied by this call.
        grad_output_buffer (list): matching buffered output gradients; emptied by this call.
        weight: parameter whose ``main_grad`` the fused kernels accumulate into.

    Raises:
        AssertionError: if the two buffers differ in length.
        RuntimeError: if ``weight.main_grad`` has an unsupported dtype.
    """
    assert len(embedding_activation_buffer) == len(
        grad_output_buffer
    ), "Length of activation and gradient buffers need to be equal!"
    if not embedding_activation_buffer:
        # Nothing buffered: previously this crashed on pop(0) from an empty list.
        return

    import fused_weight_gradient_mlp_cuda

    from megatron.core.parallel_state import (
        get_global_memory_buffer,
        get_tensor_model_parallel_group,
        get_tensor_model_parallel_world_size,
    )

    inp = embedding_activation_buffer.pop(0)
    world_size = get_tensor_model_parallel_world_size()
    dim_size = list(inp.size())
    dim_size[0] = dim_size[0] * world_size

    all_gathered_input = [None, None]
    if config.sequence_parallel:
        # The first gather is synchronous: nothing is available to overlap it with.
        all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, inp.dtype, "mpu_0")
        dist_all_gather_func(
            all_gather_buffer, inp, group=get_tensor_model_parallel_group(), async_op=False
        )
        all_gathered_input[0] = all_gather_buffer
        all_gather_buffer = None
    else:
        all_gathered_input[0] = inp

    inp = None

    def wgrad_compute(all_gathered_input, grad_output, weight):
        grad_output, all_gathered_input = prepare_input_tensors_for_wgrad_compute(
            grad_output, all_gathered_input
        )

        # NOTE(review): when gradient_accumulation_fusion is False nothing is computed here.
        if config.gradient_accumulation_fusion:
            if weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                    all_gathered_input, grad_output, weight.main_grad
                )
            elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                    all_gathered_input, grad_output, weight.main_grad
                )
            else:
                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")

    # We have all_gathered_input list acting as a double buffer here,
    # since we are pipelining the AllGather and GEMM, one buffer all gathers
    # the input while the other buffer reads from it for the GEMM. We use i
    # and (i+1) for indexing to enable this double buffering.
    handle = None  # work handle of the most recently launched async all-gather
    drain_idx = 0  # buffer still to be consumed after the loop (0 if the loop never runs)
    for i in range(len(embedding_activation_buffer)):
        inp = embedding_activation_buffer.pop(0)
        prev_handle = handle
        if config.sequence_parallel:
            name = "mpu_" + str((i + 1) % 2)
            all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, inp.dtype, name)
            handle = dist_all_gather_func(
                all_gather_buffer, inp, group=get_tensor_model_parallel_group(), async_op=True
            )

            all_gathered_input[(i + 1) % 2] = all_gather_buffer
            all_gather_buffer = None
        else:
            all_gathered_input[(i + 1) % 2] = inp

        # The buffer consumed below was filled by the previous iteration's async
        # all-gather; wait for it before reading (its handle used to be overwritten
        # without a wait, letting the GEMM race the collective).
        if prev_handle is not None:
            prev_handle.wait()

        grad_output = grad_output_buffer.pop(0)
        wgrad_compute(all_gathered_input[i % 2], grad_output, weight)
        drain_idx = (i + 1) % 2
        inp, all_gathered_input[i % 2], grad_output = None, None, None

    # The async handle is None when the loop never ran or sequence parallelism is off.
    if handle is not None:
        handle.wait()

    grad_output = grad_output_buffer.pop(0)
    wgrad_compute(all_gathered_input[drain_idx], grad_output, weight)
    inp, all_gathered_input[drain_idx], grad_output = None, None, None
def local_multi_tensor_applier(op, noop_flag_buffer, tensor_lists, *args):
    """Multi tensor op applier: call ``op`` with a fixed chunk size of 2048 * 32."""
    chunk_size = 2048 * 32
    return op(chunk_size, noop_flag_buffer, tensor_lists, *args)
# computes l2 norm for a list of contiguous tensors
# works as a drop-in replacement for amp_C.multi_tensor_l2norm
def local_multi_tensor_l2_norm(chunk_size, noop_flag, tensor_lists, per_tensor, *args):
    """
    Computes l2 norm for a list of contiguous tensors
    works as a drop-in replacement for amp_C.multi_tensor_l2norm

    The global norm is sqrt(sum of squared per-tensor norms) over every tensor in every
    list. ``chunk_size``, ``noop_flag`` and ``per_tensor`` are accepted for interface
    compatibility and ignored.

    Returns:
        (norm, None) where ``norm`` is a detached float32 tensor of shape (1,), placed on
        the device of the first input tensor (or 'cuda' when there are no tensors).
    """
    # Flatten across lists before stacking: the previous torch.tensor(nested list) failed
    # on ragged lists and mixed dtypes and forced one host sync per element. Casting
    # to float32 lets torch.stack accept mixed-precision inputs.
    norms = [torch.norm(tensor).float() for tensor_list in tensor_lists for tensor in tensor_list]
    if not norms:
        return torch.zeros(1, dtype=torch.float, device='cuda'), None
    device = norms[0].device
    total = torch.norm(torch.stack([n.to(device) for n in norms]))
    return total.detach().reshape(1).to(torch.float), None
# works as a drop-in replacement for amp_C.multi_tensor_scale
def local_multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale):
    """Works as a drop-in replacement for amp_C.multi_tensor_scale.

    Writes ``src * scale`` into each paired ``dst`` in place, pairing
    ``tensor_lists[0]`` with ``tensor_lists[1]``; ``chunk_size`` and ``noop_flag``
    are ignored.
    """
    sources, destinations = tensor_lists[0], tensor_lists[1]
    for source, destination in zip(sources, destinations):
        destination.copy_(source * scale)
class _ValueWithRank:
"""This is an internal class, not for use outside this module
Attributes:
_rank (int): rank for the value
_value (float) : the value it stores, eg elapsed time
_unit (str) : unit for the value
"""
def __init__(self, value: float, rank: int, unit: str = "") -> None:
"""Initializer
Args:
_value (float): the initial value with which it is inited
_rank (int): the rank number
_unit (str) : the unit of the value, eg ms or flops
"""
self._rank = rank
self._value = value
self._unit = unit
def __lt__(self, other) -> bool:
"""Check if value of self is smaller than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is less than rhs._value, else False
"""
return self._value < other._value
def __gt__(self, other) -> bool:
"""Check if value of self is larger than other's value
Args:
other (_ValueWithRank): The other object to compare with
Returns:
bool: True if lhs._value of operand is greater than rhs._value, else False
"""
return self._value > other._value
def __call__(self) -> Tuple[float, int, str]:
"""Returns the value, the rank, and unit as a Tuple
Returns:
Tuple[float, int, str]: value, rank, unit
"""
return self._value, self._rank, self._unit
def __str__(self) -> str: