Composer distributes work across devices via data-parallelism-only. We made this design choice in order to provide the most flexibility to algorithms, which can modify the training loop in complex ways. Data parallelism greatly simplifies model building and memory management. Every GPU is performing the same work, so inspecting rank zero is sufficient to reason about memory, performance, and other properties.
Within Composer, we have three options for data-parallelism-only execution: PyTorch DDP (default), PyTorch FSDP, and DeepSpeed ZeRO. Although PyTorch DDP is the default, PyTorch FSDP increases memory and computational efficiency when configured correctly while producing the same results, and is the recommended option.
The port on the master hosting the C10d TCP store. If you are running multiple trainers on a single node, this generally needs to be unique for each one. Overrides env var MASTER_PORT. Defaults to a random free port in a single-node environment.
Developers may need to access the current rank or world size in a distributed setting. For example, a callback may only want to log something for rank zero. Use our composer.utils.dist module to retrieve this information. The methods are similar to torch.distributed, but also return defaults in a non-distributed setting.
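The "defaults in a non-distributed setting" behavior can be illustrated with a small, library-free sketch. The helpers below are hypothetical stand-ins written for this example, not Composer's implementation. They assume the launcher exports the conventional RANK and WORLD_SIZE environment variables, and fall back to rank 0 and world size 1 when those are absent. In real code, prefer the functions in composer.utils.dist.

```python
import os


def get_global_rank() -> int:
    """Return this process's global rank, or 0 when not launched distributed.

    Assumption: the launcher sets the RANK environment variable.
    """
    return int(os.environ.get("RANK", "0"))


def get_world_size() -> int:
    """Return the number of processes, or 1 when not launched distributed.

    Assumption: the launcher sets the WORLD_SIZE environment variable.
    """
    return int(os.environ.get("WORLD_SIZE", "1"))


def log_on_rank_zero(message: str) -> bool:
    """Print only on rank zero; return whether the message was printed."""
    if get_global_rank() != 0:
        return False
    print(message)
    return True
```

Because the fallbacks are rank 0 and world size 1, the same callback code runs unchanged in a single-process debugging session.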
When providing a torch.utils.data.Dataset that is not a torch.utils.data.IterableDataset with a torch.utils.data.DataLoader to Composer, a torch.utils.data.distributed.DistributedSampler is necessary to ensure different devices receive different batches. Composer will raise an error if a DistributedSampler is not provided. composer.utils.dist provides a helper function, composer.utils.dist.get_sampler(), that creates a DistributedSampler with the correct parameters.
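To see why a distributed sampler gives each device different batches, here is a minimal pure-Python sketch of the index partitioning such a sampler performs. The function name shard_indices is hypothetical. Shuffling is omitted for clarity, and the sketch assumes the dataset has at least as many samples as there are replicas. It pads the index list by wrapping around so every rank gets the same number of samples, then assigns indices to ranks in a strided fashion.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list:
    """Return the dataset indices that a given rank would iterate over.

    Illustrative only: no shuffling, and assumes dataset_len >= num_replicas.
    """
    # Every rank must see the same number of samples per epoch.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(dataset_len))
    # Pad by repeating the first few indices so the list divides evenly.
    padding = total_size - len(indices)
    indices += indices[:padding]

    # Strided assignment: rank r takes positions r, r + N, r + 2N, ...
    return indices[rank:total_size:num_replicas]
```

With 10 samples on 4 ranks, each rank receives 3 indices, and two samples are repeated to fill the padding.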
composer.datasets.StreamingDataset is an IterableDataset, so a DistributedSampler is not supported, as IterableDatasets need to handle multi-worker training internally. See the IterableDataset docs for more information.
Composer comes with DeepSpeed support, allowing you to leverage their full set of features that make it easier to train large models across (1) any type of GPU and (2) multiple nodes. For more details on DeepSpeed, see their website.
At a high level, when you use the Composer Trainer, you must pass it a ComposerModel like ComposerGPT that defines certain functions like forward, eval_forward, loss, etc. that are called during the training loop.
Inside that ComposerModel you may have one or many submodules, such as a .model or .language_model or .classifier that is the actual torch.nn.Module that you will be deploying at inference time. In our case, this is the GPT module that we build and attach as ComposerGPT.model.
When you provide a parallelism_config={'fsdp': {...}} dictionary to the Composer Trainer, then on __init__, the Trainer will attempt to wrap each of the submodules of your ComposerModel with an FSDP auto wrap policy. This wrapping is recursive, so not only is GPT wrapped, but all submodules of GPT may or may not be wrapped too. See the FSDP documentation for more details on how auto wrap policies work.
To save and load sharded checkpoints with FSDP, you can use the state_dict_type field in fsdp_config. Depending on the value you set for state_dict_type, you can get different checkpointing behavior:
1. state_dict_type='full': The default. Saves one big checkpoint file for the whole model. It does this by gathering the model state to the global rank 0 device, unflattening it, and then saving it out. If load_monolith_rank0_only=True, then when loading checkpoints the global rank 0 device will load in the checkpoint file and scatter the model and optimizer state to the other ranks, which will dramatically reduce the memory usage on the system. Otherwise, all ranks will separately load in the checkpoint file.
Composer with PyTorch version 2.0.0 and higher does support elastic checkpointing (more ranks than checkpoint files or more files than ranks), so you can resume on a different number of ranks than you saved on.
Sentence Transformers implements two forms of distributed training: Data Parallel (DP) and Distributed Data Parallel (DDP). Read the Data Parallelism documentation on Hugging Face for more details on these strategies. Some of the key differences include:
In short, DDP is generally recommended. You can use DDP by running your normal training scripts with torchrun or accelerate. For example, if you have a script called train_script.py, you can run it with DDP using the following command:
Fully Sharded Data Parallelism (FSDP) is another distributed training strategy that is not fully supported by Sentence Transformers. It is a more advanced version of DDP that is particularly useful for very large models. Note that in the previous comparison, FSDP reaches 5782 samples per second (2.122x speedup), i.e. worse than DDP. FSDP only makes sense with very large models. If you want to use FSDP with Sentence Transformers, you have to be aware of the following limitations:
You have to use fsdp=["full_shard", "auto_wrap"] and fsdp_config={"transformer_layer_cls_to_wrap": "BertLayer"} in your SentenceTransformerTrainingArguments, where BertLayer is the repeated layer in the encoder that houses the multi-head attention and feed-forward layers, so e.g. BertLayer or MPNetLayer.
In this blog post, we will look at how to fine-tune Llama 2 70B using PyTorch FSDP and related best practices. We will be leveraging Hugging Face Transformers, Accelerate and TRL. We will also learn how to use Accelerate with SLURM.
Fully Sharded Data Parallelism (FSDP) is a paradigm in which the optimizer states, gradients and parameters are sharded across devices. During the forward pass, each FSDP unit performs an all-gather operation to get the complete weights, computation is performed followed by discarding the shards from other devices. After the forward pass, the loss is computed followed by the backward pass. In the backward pass, each FSDP unit performs an all-gather operation to get the complete weights, with computation performed to get the local gradients. These local gradients are averaged and sharded across the devices via a reduce-scatter operation so that each device can update the parameters of its shard. For more information on what PyTorch FSDP is, please refer to this blog post: Accelerate Large Model Training using PyTorch Fully Sharded Data Parallel.
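The all-gather and reduce-scatter flow described above can be mimicked in a few lines of plain Python. This is a toy simulation under stated assumptions, not how PyTorch implements FSDP. Ranks are simulated as entries in a list, the "model" is a single linear unit with a squared-error loss, and the helper names (all_gather, reduce_scatter_mean, fsdp_step) are hypothetical.

```python
def all_gather(shards):
    """Concatenate every rank's shard into the full parameter vector."""
    return [w for shard in shards for w in shard]


def reduce_scatter_mean(per_rank_grads, world_size):
    """Average the full gradients across ranks, then give each rank its slice."""
    n = len(per_rank_grads[0])
    mean = [sum(g[i] for g in per_rank_grads) / world_size for i in range(n)]
    shard_len = n // world_size
    return [mean[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def fsdp_step(shards, batches, lr):
    """One simulated training step; shards[r] holds rank r's parameter slice.

    batches[r] is the (x, y) example seen by rank r (data parallelism).
    """
    world_size = len(shards)
    per_rank_grads = []
    for x, y in batches:
        # Each rank all-gathers the complete weights before computing.
        full = all_gather(shards)
        pred = sum(w * xi for w, xi in zip(full, x))
        err = pred - y
        # Local gradient of (pred - y)^2 with respect to every weight.
        per_rank_grads.append([2 * err * xi for xi in x])
        # The gathered copy is discarded here; only the local shard persists.
    grad_shards = reduce_scatter_mean(per_rank_grads, world_size)
    # Each rank updates only the parameters it owns.
    return [
        [w - lr * g for w, g in zip(shard, gshard)]
        for shard, gshard in zip(shards, grad_shards)
    ]
```

The update matches what plain data parallelism would compute on the full weight vector. The difference is that each rank stores and updates only its own slice between steps.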
Saving entire intermediate checkpoints using FULL_STATE_DICT with CPU offloading on rank 0 takes a lot of time and often results in NCCL Timeout errors due to indefinite hanging during broadcasting. However, at the end of training, we want the whole model state dict instead of the sharded state dict which is only compatible with FSDP.
Below is the output snippet on a 7B model on 2 GPUs measuring the memory consumed and model parameters at various stages. We can observe that while loading the pre-trained model, rank 0 and rank 1 have a CPU total peak memory of 32744 MB and 1506 MB, respectively. Therefore, only rank 0 is loading the pre-trained model, leading to efficient usage of CPU RAM. The whole logs can be found here
This is addressed by choosing the SHARDED_STATE_DICT state dict type when creating the FSDP config. SHARDED_STATE_DICT saves one shard per GPU separately, which makes it quick to save or resume training from an intermediate checkpoint. When FULL_STATE_DICT is used, the first process (rank 0) gathers the whole model on CPU and then saves it in a standard format.
The resulting config is available here: fsdp_config.yaml. Here, the sharding strategy is FULL_SHARD. We are using TRANSFORMER_BASED_WRAP for the auto wrap policy, and it uses _no_split_modules to find the Transformer block name for nested FSDP auto wrap. We use SHARDED_STATE_DICT to save the intermediate checkpoints and optimizer states in this format recommended by the PyTorch team. Make sure to enable broadcasting module parameters from rank 0 at the start as mentioned in the above paragraph on addressing Challenge 1. We are enabling bf16 mixed precision training.
Flash Attention and enabling gradient checkpointing are required for faster training and reducing VRAM usage to enable fine-tuning and save compute costs. The codebase currently uses monkey patching and the implementation is at chat_assistant/training/llama_flash_attn_monkey_patch.py.
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness introduces a way to compute exact attention while being faster and memory-efficient by leveraging the knowledge of the memory hierarchy of the underlying hardware/GPUs - The higher the bandwidth/speed of the memory, the smaller its capacity as it becomes more expensive.
If we follow the blog Making Deep Learning Go Brrrr From First Principles, we can figure out that the Attention module on current hardware is memory-bound/bandwidth-bound. The reason is that Attention mostly consists of elementwise operations, as shown below on the left hand side. We can observe that masking, softmax and dropout operations take up the bulk of the time instead of matrix multiplications, which account for the bulk of the FLOPs.
This is precisely the problem that Flash Attention addresses. The idea is to remove redundant HBM reads/writes. It does so by keeping everything in SRAM, performing all the intermediate steps, and only then writing the final result back to HBM, a technique also known as Kernel Fusion. Below is an illustration of how this overcomes the memory-bound bottleneck.
Tiling is used during the forward and backward passes to chunk the NxN softmax/scores computation into blocks to overcome the limitation of SRAM memory size. To enable tiling, the online softmax algorithm is used. Recomputation is used during the backward pass in order to avoid storing the entire NxN softmax/score matrix during the forward pass. This greatly reduces the memory consumption.
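The online softmax trick that makes tiling possible can be sketched for a single query row in plain Python. This is an illustrative simplification, not the Flash Attention kernel. Values are scalars rather than vectors, and the function names are chosen for this example. The point is that a running maximum and a running normalizer let each block be processed without ever materializing the full row of scores.

```python
import math


def naive_softmax(scores):
    """Reference softmax that needs the whole row of scores at once."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def online_softmax_weighted_sum(scores, values, block_size):
    """Compute sum_i softmax(scores)_i * values_i one block at a time."""
    running_max = float("-inf")
    running_sum = 0.0  # sum of exp(score - running_max) seen so far
    acc = 0.0          # sum of exp(score - running_max) * value seen so far
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale earlier partial results to the new maximum.
        # On the first block this factor is exp(-inf) == 0.0.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc *= scale
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - new_max)
            running_sum += p
            acc += p * v
        running_max = new_max
    return acc / running_sum
```

The blockwise result agrees with the naive computation up to floating-point rounding, regardless of the block size chosen.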
For a simplified and in-depth understanding of Flash Attention, please refer to the blog posts ELI5: FlashAttention and Making Deep Learning Go Brrrr From First Principles, along with the original paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.