Transformers多机多卡的炼丹实践

前言

随着预训练语言模型的快速发展,很多问题可以通过堆数据和堆模型参数简单粗暴的有效解决。所以亲自训练一个大模型一定是每个NLPer都想尝试的事,这时候就需要进行多机多卡的分布式训练了。本文是一篇踩坑后的总结,介绍如何基于huggingface的transformers库来快速实现。

注意本文仅涉及数据并行,而不涉及模型并行。所以参考本文可以自己从零训练一个bert,bert-large等,但想训练万亿参数的超大模型(一张显卡都不能存储模型的参数)就需要更复杂的实现了。

基本概念

  • node_rank: 节点的编号
  • rank: 全局进程的编号
  • local_rank: 单个节点上进程的编号
  • word_size: 全局总进程的数量
  • master ip: master进程的ip地址
  • master port: master进程的端口

一般rank编号为0的进程会作为master进程

具体举个例子:当前有2个节点,每个节点有8块GPU卡,然后启动多机多卡的分布式训练用满这16块卡,这时候:

  • node_rank: [0,1]
  • rank: [0,1,2,3,4,…,15]
  • loacal_rank: 节点1上[0,1,2,..,7], 节点2上[0,1,2,..,7]
  • word_size: 16

如果通过python -m torch.distributed.launch的方式启动,部分参数都会自动注入到环境变量中,可以在脚本中进行获取。例如:

1
2
3
rank = int(os.environ.get('RANK'))
local_rank = int(os.environ.get('LOCAL_RANK'))
word_size = int(os.environ.get('WORLD_SIZE'))

IterableDataset

训练大模型一定是基于大数据,可能非常大(例如上百GB),所以不能采用map-style的dataset作为训练集的dataset,因为无法直接load到内存中,所以需要采用IterableDataset。同时为了训练的数据较快,需要采用多进程的数据加载,即num_worker>0

这时候假设在一个2个节点,每个节点8张卡的分布式环境中,同时采用10个子进程进行数据加载。
那么此时在数据加载阶段启用的进程总数为: 2 * 8 * 10 = 160

在IterableDataset如果直接按照最简单的写法,如下所示:

1
2
3
4
5
6
7
8
9
10
11
class CustomIterableDataset(IterableDataset):

def __init__(self, data_file):
self.data_file = data_file

def __iter__(self):
while True:
with open(self.data_file, 'rt', encoding='utf-8') as f:
for line in f:
print(f"系统进程号{os.getpid()}, 加载的数据{line.strip()}")
yield line.strip()

通过日志打印可以发现,同一份数据将被160个进程重复加载,这显然就不是数据并行了。

所以迭代阶段就需要进行精细的处理,避免一份数据被多个进程重复加载。参考Pytorch官方的文档,可以发现实际上是预计算出每个子进程需要迭代的区间,然后结合子进程的信息找到对应的区间进行迭代。

https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
else: # in a worker process
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return iter(range(iter_start, iter_end))
# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)

# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

# Mult-process loading with two worker processes
# Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))

# With even more workers
print(list(torch.utils.data.DataLoader(ds, num_workers=20)))

但是这里只处理了单块GPU卡多子进程加载数据的写法,我们这里是分布式的多机多卡,所以还需要对以上代码进行改造,进一步引入rank的信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class CustomIterableDataset(IterableDataset):

def __init__(self, data_file, num_lines):
self.data_file = data_file
self.start = 0
self.end = num_lines
self.word_size = int(os.environ.get('WORLD_SIZE'))
self.rank = int(os.environ.get('RANK'))

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
total_data_loader_count = worker_info.num_workers * self.word_size
per_worker_lines = int(math.ceil((self.end - self.start) / total_data_loader_count))
iter_start = self.start + (self.rank * worker_info.num_workers + worker_info.id) * per_worker_lines
while True:
with open(self.data_file, 'rt', encoding='utf-8') as f:
# 跳过区间之前的数据
for _ in range(iter_start):
f.readline()
print(f"系统进程号:{os.getpid()} rank编号:{} dataloader子进程编号:{worker_info.id}, 开始加载数据")
# 开始加载数据,长度为per_worker_lines
for _ in range(per_worker_lines):
line = f.readline().strip()
print(f"系统进程号{os.getpid()}, 加载的数据{line.strip()}")
yield line

train_file_lines = sum(1 for line in open('big_file.txt')
train_dataset = CustomIterableDataset('big_file.txt', train_file_lines)

通过以上处理后,可以观察到160个进程每个进程加载的数据都是不一样的。

Trainer is ALL YOU NEED

如果采用原生的torch.distributed.launch进行多机多卡的训练是需要写很多范式代码的,例如init_process_group。而采用transformersTrainer自动帮你去适配你当前的环境,也就是无论是单机单卡,还是单机多卡,还是多机多卡都是一份代码,并且对于fp16这种配置也就是一个参数项。最近发现Trainer实现分布式训练的底层逻辑其实已经进一步抽象成了一个新的库huggingface/accelerate

贴一段这个库的描述:🚀 A simple way to train and use PyTorch models with multi-GPU, TPU, mixed-precision

再贴一段这个库与transformerstrainer的关系:https://github.com/huggingface/accelerate/issues/144

所以直接使用Trainer就好了!! 下面是一段示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# 定义模型
model = CustomModel()
# 定义数据集
train_file_lines = sum(1 for line in open(args.train_data_file)
train_dataset = CustomIterableDataset(args.train_data_file, train_file_lines)
dev_dataset = CustomDataset(args.eval_data_file)
# 计算整体迭代步数max_steps
total_training_samples = train_file_lines * args.num_train_epochs
total_batch_size = args.train_batch_size * int(os.environ.get('WORLD_SIZE'))
max_steps = math.ceil(total_training_samples / total_batch_size)
# 构造TrainingArguments
train_args = TrainingArguments(
output_dir=args.output_dir,
per_device_train_batch_size=args.train_batch_size,
per_device_eval_batch_size=args.eval_batch_size,
do_train=True,
do_eval=True,
evaluation_strategy="steps",
logging_steps=args.logging_steps,
eval_steps=args.eval_steps,
save_steps=args.eval_steps,
overwrite_output_dir=True,
save_total_limit=1,
local_rank=int(os.environ.get('LOCAL_RANK', -1)),
learning_rate=args.learning_rate,
metric_for_best_model='eval_loss',
fp16=True,
max_steps=max_steps,
dataloader_num_workers=10,
)
# 构造Trainer
trainer = Trainer(
model=model,
args=train_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
)
# 开始炼丹!!
trainer.train()

train.sh

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# 单机单卡
CUDA VISIBLE DEVICE=0 python train.py \
--train_data_file $DATA_DRT/train.txt \
--eval_data_file $DATA_DRT/dev.txt \
--learning_rate 5e-5 \
--train_batch_size 128 \
--eval_batch_size 128 \
--eval_steps 1000 \
--num_train_epochs 10

# 单机多卡
python -m torch.distributed.launch \
--nproc_per_node=8 \
train.py \
--train_data_file $DATA_DRT/train.txt \
--eval_data_file $DATA_DRT/dev.txt \
--learning_rate 5e-5 \
--train_batch_size 128 \
--eval_batch_size 128 \
--eval_steps 1000 \
--num_train_epochs 10

# 多机多卡
python -m torch.distributed.launch \
--nproc_per_node=8 \
--use_env \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
--train_data_file $DATA_DRT/train.txt \
--eval_data_file $DATA_DRT/dev.txt \
--learning_rate 5e-5 \
--train_batch_size 128 \
--eval_batch_size 128 \
--eval_steps 1000 \
--num_train_epochs 10

其他tips


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!