
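Integrating TorchAcc into Custom Models

Alibaba Cloud PAI provides sample models for several typical scenarios so that you can easily accelerate training with TorchAcc. You can also integrate models that you develop yourself. This topic describes how to integrate TorchAcc into a custom model to improve the speed and efficiency of distributed training.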

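Background

TorchAcc optimizations fall into the following two categories. Choose the one that fits your model to improve training speed and efficiency.

  • Compilation optimization

    TorchAcc converts PyTorch's dynamic (eager) graph into a static graph, optimizes the computation graph, and uses a JIT compiler to compile it into more efficient code. This avoids part of the overhead of eager execution in PyTorch and speeds up training.

  • Custom optimization

    When a model contains dynamic shapes, custom operators, dynamic control flow, or similar features, global compilation optimization cannot yet be used to accelerate its distributed training. For these cases, TorchAcc provides custom optimizations:

    • I/O optimization

    • Compute (kernel) optimization

    • GPU memory optimization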
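The toy model below sketches the idea behind compilation optimization and why dynamic shapes get in its way. It is plain Python, not TorchAcc's internals, and every name in it (TracedGraph, ToyGraphCompiler, train) is hypothetical: operations are recorded instead of executed, each distinct (operation sequence, input shape) pair is compiled once into a fused function and cached, and later steps with the same shape reuse the cached function. A batch with a new shape forces a new compilation.

```python
from dataclasses import dataclass, field


@dataclass
class TracedGraph:
    """Operations recorded lazily during one training step (toy model)."""
    ops: list = field(default_factory=list)  # (op_name, constant) pairs

    def add(self, c):
        self.ops.append(("add", c))
        return self

    def mul(self, c):
        self.ops.append(("mul", c))
        return self


class ToyGraphCompiler:
    """Compiles each distinct (graph, input shape) once and caches the result."""

    def __init__(self):
        self.cache = {}
        self.compilations = 0

    def run(self, graph, data):
        key = (tuple(graph.ops), len(data))  # the input shape is part of the cache key
        if key not in self.cache:
            self.compilations += 1
            self.cache[key] = self._compile(graph.ops)
        return self.cache[key](data)

    @staticmethod
    def _compile(ops):
        ops = tuple(ops)

        # "Fuse" all recorded element-wise ops into a single pass over the data.
        def fused(data):
            out = []
            for x in data:
                for name, c in ops:
                    x = x + c if name == "add" else x * c
                out.append(x)
            return out

        return fused


def train(batch_lengths):
    compiler = ToyGraphCompiler()
    outputs = []
    for n in batch_lengths:
        graph = TracedGraph().mul(2.0).add(1.0)          # tracing: nothing runs yet
        outputs.append(compiler.run(graph, [1.0] * n))   # step boundary: compile if needed, then run
    return compiler.compilations, outputs


static_counts, static_out = train([4, 4, 4, 4])
dynamic_counts, _ = train([3, 4, 5, 6])
print(static_counts, dynamic_counts)  # 1 4
```

With a fixed batch shape the compilation cost is paid once; with a different shape every step, the toy recompiles every step, which is the kind of cost that makes dynamic shapes a poor fit for global compilation.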

TorchAcc compilation optimization

Integrating distributed training

To use the TorchAcc compiler for distributed training, follow these steps:

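  1. Fix the random seed.

    Fixing the random seed ensures that every worker initializes its weights identically, which makes a weight broadcast unnecessary.

    import torchacc.torch_xla.core.xla_model as xm

    # Replace torch.manual_seed(SEED_NUMBER) with:
    xm.set_rng_state(SEED_NUMBER)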
  2. After obtaining the XLA device, call set_replication, wrap the data loaders, and move the model to the device.

    import torchacc.torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    xm.set_replication(device, [device])

    # Wrap the data loaders so that batches are loaded onto the XLA device
    data_loader_train = pl.MpDeviceLoader(data_loader_train, device)
    data_loader_val = pl.MpDeviceLoader(data_loader_val, device)

    # Move the model to the XLA device
    model.to(device)
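  3. Initialize the process group.

    Set the backend argument of dist.init_process_group to 'xla':

    import torch.distributed as dist
    import torchacc.torch_xla.distributed.xla_backend  # makes the 'xla' backend available, as in the BERT example

    dist.init_process_group(backend='xla', init_method='env://')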
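  4. All-reduce the gradients.

    After loss.backward(), all-reduce the gradients across workers. Summing them and scaling by 1/world_size averages them:

    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())

    Important

    If you train with automatic mixed precision (AMP) and call scaler.unscale_ manually, you must call xm.all_reduce before scaler.unscale_, so that overflow detection is based on the all-reduced gradients.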

  5. Launch the job with xlarun.

    xlarun --nproc_per_node=8 YOUR_MODEL.py

    Note

    For multi-node jobs, use xlarun the same way as torchrun.
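The following is a minimal pure-Python sketch (not TorchAcc code; init_weight, local_gradient, all_reduce_mean and data_parallel_sgd_step are hypothetical names) that simulates two workers inside one process with a single scalar weight. It shows why steps 1 and 4 keep the replicas in sync: identical seeds give identical initial weights, and summing the gradients and scaling by 1/world_size gives every worker the same averaged gradient, which, for equal-sized shards, matches single-process SGD on the full batch.

```python
import random


def init_weight(seed):
    # Every worker seeds its RNG identically (random.Random stands in for
    # xm.set_rng_state here), so all replicas start from the same weight.
    return random.Random(seed).uniform(-1.0, 1.0)


def local_gradient(w, shard):
    # d/dw of mean 0.5 * (w*x - y)^2 over this worker's shard of the batch.
    return sum((w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    # 'sum' across workers, then scale by 1/world_size (as in step 4);
    # every worker receives the same averaged value.
    world_size = len(values)
    averaged = sum(values) * (1.0 / world_size)
    return [averaged] * world_size


def data_parallel_sgd_step(seed, shards, lr=0.1):
    weights = [init_weight(seed) for _ in shards]                    # step 1
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]
    reduced = all_reduce_mean(grads)                                 # step 4
    return [w - lr * g for w, g in zip(weights, reduced)]


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
new_weights = data_parallel_sgd_step(seed=101, shards=[data[:2], data[2:]])

# Reference: one process doing SGD on the whole batch.
w0 = init_weight(101)
single_process = w0 - 0.1 * local_gradient(w0, data)
print(new_weights, single_process)
```

Without the shared seed, each replica would start from a different weight and the averaged gradient would be applied to different models, which is what the broadcast normally prevents.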

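Integrating mixed precision

Mixed-precision (AMP) training can further speed up training. Starting from single-GPU or distributed training (for example, the setup from the previous section), add AMP to TorchAcc compilation optimization as follows:

  1. Implement AMP with native PyTorch features.

    TorchAcc mixed precision is used in essentially the same way as native PyTorch mixed precision. First implement native PyTorch AMP by following the PyTorch AMP documentation.

  2. Replace GradScaler.

    Replace torch.cuda.amp.GradScaler with torchacc.torch_xla.amp.GradScaler:

    from torchacc.torch_xla.amp import GradScaler
  3. Replace the optimizer.

    Native PyTorch optimizers perform somewhat worse here. To further speed up training, replace the torch.optim optimizer with a syncfree optimizer:

    from torchacc.torch_xla.amp import syncfree
    
    # Assumed to take the same arguments as torch.optim.Adam/AdamW/SGD; lr values are illustrative
    adam_optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
    adamw_optimizer = syncfree.AdamW(model.parameters(), lr=1e-3)
    sgd_optimizer = syncfree.SGD(model.parameters(), lr=1e-3)

    syncfree optimizers are currently provided only for these three optimizers; for other optimizer types, keep using the native PyTorch optimizer.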
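To illustrate what dynamic loss scaling does and why the Important note in step 4 of the distributed section asks for xm.all_reduce before scaler.unscale_, here is a toy loss scaler in pure Python. ToyLossScaler and step_decisions are hypothetical; the defaults mirror the documented defaults of torch.cuda.amp.GradScaler, but this is not TorchAcc's implementation. If one worker's gradients overflow, checking for inf before the all-reduce lets workers make different step/skip decisions, while all-reducing first propagates the inf to every worker so they all skip together.

```python
import math


class ToyLossScaler:
    """Dynamic loss scaling in the spirit of GradScaler (toy, pure Python)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def unscale(self, scaled_grads):
        # Divide out the loss scale, then check for inf/NaN (overflow detection).
        inv = 1.0 / self.scale
        grads = [g * inv for g in scaled_grads]
        found_inf = not all(math.isfinite(g) for g in grads)
        return grads, found_inf

    def update(self, found_inf):
        if found_inf:
            # Overflow: the optimizer step is skipped and the scale shrinks.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            # After growth_interval clean steps in a row, try a larger scale.
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


def step_decisions(per_worker_scaled_grads, all_reduce_first):
    """Return 'step' or 'skip' for each simulated worker."""
    world_size = len(per_worker_scaled_grads)
    if all_reduce_first:
        summed = [sum(col) for col in zip(*per_worker_scaled_grads)]
        grads = [[g / world_size for g in summed]] * world_size
    else:
        grads = per_worker_scaled_grads
    decisions = []
    for worker_grads in grads:
        _, found_inf = ToyLossScaler().unscale(worker_grads)
        decisions.append("skip" if found_inf else "step")
    return decisions


# Worker 1's scaled gradients overflowed; worker 0's did not.
grads = [[1000.0, 2.0], [float("inf"), 3.0]]
print(step_decisions(grads, all_reduce_first=False))  # ['step', 'skip']: replicas diverge
print(step_decisions(grads, all_reduce_first=True))   # ['skip', 'skip']: replicas stay in sync
```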

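Integration example

The following example uses the BERT-base model. Lines prefixed with + are the TorchAcc-specific changes; the TorchAcc path is enabled when the TORCHACC_COMPILER_OPT environment variable is set (see enable_torchacc_compiler()).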

import argparse
import os
import time
import torch
import torch.distributed as dist

from datasets import load_from_disk
from datetime import datetime as dt
from time import gmtime, strftime
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

# PyTorch 1.12 and later disable TF32 for matmul by default; re-enable it.
torch.backends.cuda.matmul.allow_tf32 = True

parser = argparse.ArgumentParser()
parser.add_argument("--amp-level", choices=["O0", "O1"], default="O1", help="AMP level: O1 enables mixed precision, O0 trains in FP32.")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile_folder", type=str, default="./profile_folder")
parser.add_argument("--dataset_path", type=str, default="./sst_data/train")
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument("--break_step_for_profiling", type=int, default=20)
parser.add_argument("--model_name", type=str, default="bert-base-cased")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--log-interval", type=int, default=10)
parser.add_argument('--max-steps', type=int, default=200, help='maximum number of training steps per epoch.')
args = parser.parse_args()
print("Job running args: ", args)
args.local_rank = int(os.getenv("LOCAL_RANK", 0))


+def enable_torchacc_compiler():
+    return os.getenv('TORCHACC_COMPILER_OPT') is not None


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.get_rank() == 0:
        print(message, flush=True)


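train_format_string = ("{} | epoch: {} | step: {} | batch size: {} | loss: {:.4f} | "
                       "time/step (s): {} | samples/s: {} | peak memory (GB): {}")


def print_test_update(epoch, step, batch_size, loss, time_elapsed, samples_per_step, peak_mem):
  # Get the current date and time (UTC)
  now = strftime("%a, %d %b %Y %H:%M:%S", gmtime())
  print_rank_0(train_format_string.format(
      now, epoch, step, batch_size, float(loss), time_elapsed, samples_per_step, peak_mem))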


def log_metrics(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem):
    batch_time = f"{batch_time:.3f}"
    samples_per_step = f"{samples_per_step:.3f}"
    peak_mem = f"{peak_mem:.3f}"
+    if enable_torchacc_compiler():
+        import torchacc.torch_xla.core.xla_model as xm
+        xm.add_step_closure(
+            print_test_update, args=(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem), run_async=True)
+    else:
        print_test_update(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem)


+if enable_torchacc_compiler():
+  import torchacc.torch_xla.core.xla_model as xm
+  import torchacc.torch_xla.distributed.parallel_loader as pl
+  import torchacc.torch_xla.distributed.xla_backend
+  from torchacc.torch_xla.amp import autocast, GradScaler, syncfree
+  xm.set_rng_state(101)
+  dist.init_process_group(backend="xla", init_method="env://")
+else:
  from torch.cuda.amp import autocast, GradScaler
  dist.init_process_group(backend="nccl", init_method="env://")

dist.barrier()
args.world_size = dist.get_world_size()
args.rank = dist.get_rank()
print("world size:", args.world_size, " rank:", args.rank, " local rank:", args.local_rank)


def get_autocast_and_scaler():
  # autocast and GradScaler resolve to the TorchAcc versions when the compiler
  # is enabled and to torch.cuda.amp otherwise (see the conditional imports above).
  return autocast, GradScaler()


def loop_with_amp(model, inputs, optimizer, autocast, scaler):
  with autocast():
    outputs = model(**inputs)
    loss = outputs["loss"]

  scaler.scale(loss).backward()
+  if enable_torchacc_compiler():
+    gradients = xm._fetch_gradients(optimizer)
+    xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())
  scaler.step(optimizer)
  scaler.update()

  return loss, optimizer


def loop_without_amp(model, inputs, optimizer):
  outputs = model(**inputs)
  loss = outputs["loss"]
  loss.backward()
+  if enable_torchacc_compiler():
+    xm.optimizer_step(optimizer)
+  else:
    optimizer.step()
  return loss, optimizer


def full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=None):
  model.train()

  iteration_time = time.time()
  num_steps = int(len(train_device_loader.dataset) / args.batch_size)
  for step, inputs in enumerate(train_device_loader):
    if step > args.max_steps:
      break
+    if not enable_torchacc_compiler():
      inputs.to(device)

    optimizer.zero_grad()

    if args.amp_level == "O1":
      loss, optimizer = loop_with_amp(model, inputs, optimizer, autocast, scaler)
    else:
      loss, optimizer = loop_without_amp(model, inputs, optimizer)

    if args.profile and profiler:
      profiler.step()

    if step % args.log_interval == 0:
      time_elapsed = (time.time() - iteration_time) / args.log_interval
      iteration_time = time.time()
      samples_per_step = float(args.batch_size / time_elapsed) * args.world_size
      peak_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
      log_metrics(epoch, step, args.batch_size, loss, time_elapsed, samples_per_step, peak_mem)


def train_bert():
  model = AutoModelForSequenceClassification.from_pretrained(args.model_name, cache_dir="./model")
  tokenizer = AutoTokenizer.from_pretrained(args.model_name)
  tokenizer.model_max_length = args.max_seq_length

  training_dataset = load_from_disk(args.dataset_path)
  collator = DataCollatorWithPadding(tokenizer)
  training_dataset = training_dataset.remove_columns(['text'])
  train_device_loader = torch.utils.data.DataLoader(
      training_dataset, batch_size=args.batch_size, collate_fn=collator, shuffle=True, num_workers=4)

+  if enable_torchacc_compiler():
+    device = xm.xla_device()
+    xm.set_replication(device, [device])
+    train_device_loader = pl.MpDeviceLoader(train_device_loader, device)
+    model = model.to(device)
+  else:
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)


+  if enable_torchacc_compiler() and args.amp_level == "O1":
+    optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
+  else:
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

  autocast, scaler = None, None
  if args.amp_level == "O1":
    autocast, scaler = get_autocast_and_scaler()

  if args.profile:
    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=2, warmup=2, active=20),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(args.profile_folder)) as prof:
      for epoch in range(args.num_epochs):
        full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=prof)
  else:
    for epoch in range(args.num_epochs):
      full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler)


if __name__ == "__main__":
  train_bert()

TorchAcc custom optimization

I/O optimization

Data Prefetcher

Prefetches training data ahead of time. The preprocess_fn argument lets you apply preprocessing to each batch.

+ from torchacc.runtime.io.prefetcher import Prefetcher

# build_data_loader/build_model/build_optimizer stand for your own setup code
data_loader = build_data_loader()
model = build_model()
optimizer = build_optimizer()

# Optional preprocessing applied to each batch; None means no preprocessing
preprocess_fn = None

+ prefetcher = Prefetcher(data_loader, preprocess_fn)

for step, samples in enumerate(prefetcher):
    optimizer.zero_grad()
    loss = model(samples)
    loss.backward()

    # Prefetch to CPU first. Call after backward and before update.
    # At this point we are waiting for kernels launched by cuda graph
    #  to finish, so CPU is idle. Take advantage of this by loading next
    #  input batch before calling step.
+    prefetcher.prefetch_CPU()

    optimizer.step()

    # Prefetch to GPU. Call after optimizer step.
+    prefetcher.prefetch_GPU()
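Below is a minimal sketch of the prefetching idea in plain Python. It assumes nothing about TorchAcc's Prefetcher internals: a background thread (standing in for whatever mechanism TorchAcc actually uses) loads and preprocesses batches while the training loop consumes earlier ones. ToyPrefetcher is hypothetical and models only the CPU-side overlap, not the GPU staging that prefetch_GPU performs.

```python
import queue
import threading


class ToyPrefetcher:
    """Loads and preprocesses batches on a background thread, up to `depth` ahead."""

    def __init__(self, loader, preprocess_fn=None, depth=2):
        self._queue = queue.Queue(maxsize=depth)
        self._preprocess_fn = preprocess_fn or (lambda batch: batch)
        self._thread = threading.Thread(target=self._produce, args=(loader,), daemon=True)
        self._thread.start()

    def _produce(self, loader):
        try:
            for batch in loader:
                # put() blocks once `depth` batches are waiting, bounding memory use.
                self._queue.put(("batch", self._preprocess_fn(batch)))
        except Exception as exc:  # hand errors to the consumer instead of hanging it
            self._queue.put(("error", exc))
            return
        self._queue.put(("done", None))

    def __iter__(self):
        while True:
            kind, payload = self._queue.get()
            if kind == "done":
                return
            if kind == "error":
                raise payload
            yield payload


prefetcher = ToyPrefetcher(range(5), preprocess_fn=lambda b: b * 10)
print(list(prefetcher))  # [0, 10, 20, 30, 40]
```

The bounded queue is the design choice that matters: the loader runs ahead of the training loop, but never by more than `depth` batches.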

Pack Dataset

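Language datasets, such as text sentences and speech, contain samples of varying length. To improve compute efficiency, you can take advantage of these length differences by packing several samples together into a fixed-shape batch. This reduces the fraction of zero padding and the variability of batch shapes, which improves the efficiency of each (distributed) training epoch.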
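A minimal sketch of packing, assuming a greedy first-fit-decreasing heuristic (the source does not say which packing algorithm TorchAcc uses): samples are sorted by length, each is placed into the first row that still has room, and the padding fraction is compared with padding every sample to 512. pack_first_fit_decreasing and padding_ratio are hypothetical helpers.

```python
def pack_first_fit_decreasing(lengths, capacity):
    """Group sample indices into rows whose total length is at most `capacity`."""
    rows = []  # each row: [remaining capacity, list of sample indices]
    for i in sorted(range(len(lengths)), key=lambda k: lengths[k], reverse=True):
        if lengths[i] > capacity:
            raise ValueError(f"sample {i} has length {lengths[i]} > capacity {capacity}")
        for row in rows:
            if row[0] >= lengths[i]:  # first row with enough room
                row[0] -= lengths[i]
                row[1].append(i)
                break
        else:
            rows.append([capacity - lengths[i], [i]])
    return [indices for _, indices in rows]


def padding_ratio(lengths, rows, row_len):
    """Fraction of padded (unused) positions when each row is padded to row_len."""
    used = sum(lengths[i] for row in rows for i in row)
    return 1.0 - used / (row_len * len(rows))


lengths = [500, 120, 380, 60, 450, 30, 200, 250]
one_per_row = [[i] for i in range(len(lengths))]  # every sample padded to 512 on its own
packed = pack_first_fit_decreasing(lengths, capacity=512)
print(packed)  # [[0], [4, 3], [2, 1], [7, 6, 5]]
print(f"{padding_ratio(lengths, one_per_row, 512):.3f} -> {padding_ratio(lengths, packed, 512):.3f}")  # 0.514 -> 0.028
```

Because a packed row holds several samples, depending on the model you may also need per-sample position ids or an attention mask that keeps packed samples from attending to each other.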

Pin memory

When defining the DataLoader, set pin_memory=True and increase num_workers appropriately.

Compute optimization

Kernel fusion optimization

The following optimizations are supported:

  • FusedLayerNorm

    # Drop-in replacement kernel for LayerNorm
    from torchacc.runtime import hooks
    # Call this before importing torch
    hooks.enable_fused_layer_norm()
  • FusedAdam

    # Drop-in replacement kernel for Adam/AdamW
    from torchacc.runtime import hooks
    # Call this before importing torch
    hooks.enable_fused_adam()
  • QuickGelu

    # Replace nn.GELU with QuickGelu
    from torchacc.runtime.nn.quick_gelu import QuickGelu
  • fused_bias_dropout_add

    # from torchacc.runtime.nn import dropout_add_fused_train, 
    # Fuses Dropout with element-wise operations such as bias add
    if self.training:
        # train mode
        with torch.enable_grad():
            x = dropout_add_fused_train(x, to_add, drop_rate)
    else:
        # inference mode
        x = dropout_add_fused(x, to_add, drop_rate)
  • WindowProcess

    # The WindowProcess kernels fuse the shifted-window operations of Swin Transformer:
    # WindowProcess fuses window cyclic shift + window partition, and
    # WindowProcessReverse fuses window merge + reverse cyclic shift.
    from torchacc.runtime.nn.window_process import WindowProcess
    
    if not fused:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    else:
        x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
    
    from torchacc.runtime.nn.window_process import WindowProcessReverse
    
    if not fused:
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
  • FusedSwinFmha

    # Fuses the qk_result + relative_position_bias + mask + softmax part of MHA in Swin Transformer
    from torchacc.runtime.nn.fmha import FusedSwinFmha
    
    FusedSwinFmha.apply(attn, relative_pos_bias, attn_mask, batch_size, window_num,
                  num_head, window_len)
  • nms/nms_normal/soft_nms/batched_soft_nms

    # CUDA kernel implementations of the nms, nms_normal, soft_nms and batched_soft_nms operators.
    
    from torchacc.runtime.nn.nms import nms, nms_normal
    from torchacc.runtime.nn.nms import soft_nms, batched_soft_nms
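For reference, here is what replacing nn.GELU with QuickGelu changes numerically. This sketch assumes TorchAcc's QuickGelu computes the common x·sigmoid(1.702·x) approximation (the form used in CLIP); verify that against the TorchAcc source. gelu and quick_gelu are hypothetical pure-Python helpers, not TorchAcc functions.

```python
import math


def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def quick_gelu(x):
    # Sigmoid approximation: x * sigmoid(1.702 * x).
    return x / (1.0 + math.exp(-1.702 * x))


xs = [i / 100.0 for i in range(-600, 601)]
max_abs_diff = max(abs(gelu(x) - quick_gelu(x)) for x in xs)
print(f"max |gelu - quick_gelu| on [-6, 6]: {max_abs_diff:.4f}")
```

The two functions differ by about 0.02 at most on this range, so swapping the activation in a model pretrained with exact GELU changes its outputs slightly.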
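The unfused WindowProcess path above (torch.roll followed by window_partition, and window_reverse followed by the reverse roll) can be sketched in plain Python on a single H x W grid, ignoring the batch and channel dimensions. This is a reference for what the fused kernels compute under that simplification, not their implementation; all helper names are hypothetical.

```python
def roll2d(grid, shift):
    """Cyclic shift of an H x W grid, like torch.roll(x, (shift, shift), dims=(0, 1))."""
    h, w = len(grid), len(grid[0])
    return [[grid[(i - shift) % h][(j - shift) % w] for j in range(w)] for i in range(h)]


def window_partition(grid, ws):
    """Split an H x W grid into non-overlapping ws x ws windows, in row-major window order."""
    h, w = len(grid), len(grid[0])
    return [[row[c:c + ws] for row in grid[r:r + ws]]
            for r in range(0, h, ws) for c in range(0, w, ws)]


def window_reverse(windows, ws, h, w):
    """Inverse of window_partition."""
    grid = [[None] * w for _ in range(h)]
    per_row = w // ws
    for k, win in enumerate(windows):
        r0, c0 = (k // per_row) * ws, (k % per_row) * ws
        for i in range(ws):
            for j in range(ws):
                grid[r0 + i][c0 + j] = win[i][j]
    return grid


def shifted_windows(grid, ws, shift):
    # What WindowProcess fuses: cyclic shift by -shift, then window partition.
    return window_partition(roll2d(grid, -shift), ws)


def merge_windows(windows, ws, shift, h, w):
    # What WindowProcessReverse fuses: window merge, then reverse cyclic shift.
    return roll2d(window_reverse(windows, ws, h, w), shift)


H = W = 4
grid = [[r * W + c for c in range(W)] for r in range(H)]
wins = shifted_windows(grid, ws=2, shift=1)
print(wins[0])  # [[5, 6], [9, 10]]
print(merge_windows(wins, ws=2, shift=1, h=H, w=W) == grid)  # True
```

Running the unfused version costs two full passes over the activations in each direction; fusing them into one kernel is what saves memory traffic.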
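As a reference for the semantics behind the NMS operators, the standard greedy hard NMS and Gaussian soft-NMS algorithms can be written in a few lines of pure Python. These are not TorchAcc's CUDA kernels; the argument names and defaults here (iou_threshold, sigma=0.5, score_threshold) are assumptions of this sketch, so check the TorchAcc signatures before relying on them.

```python
import math


def iou(a, b):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold):
    """Greedy hard NMS: keep the best box, drop boxes that overlap it too much, repeat."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep


def soft_nms(boxes, scores, sigma=0.5, score_threshold=0.001):
    """Gaussian soft-NMS: decay the scores of overlapping boxes instead of discarding them."""
    scores = list(scores)
    remaining = list(range(len(boxes)))
    keep = []
    while remaining:
        best = max(remaining, key=lambda i: scores[i])
        remaining.remove(best)
        keep.append((best, scores[best]))
        for i in remaining:
            scores[i] *= math.exp(-(iou(boxes[best], boxes[i]) ** 2) / sigma)
        remaining = [i for i in remaining if scores[i] >= score_threshold]
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, iou_threshold=0.5))    # [0, 2]
print([i for i, _ in soft_nms(boxes, scores)])  # [0, 2, 1]
```

Hard NMS discards box 1 because it overlaps box 0 with IoU of about 0.68; soft-NMS keeps it but lowers its score, so it ends up ranked last.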
    

Parallelized kernel optimization

DCN/DCNv2:

# Parallelizes the backward pass of dcn_v2_cuda.
from torchacc.runtime.nn.dcn_v2 import DCN, DCNv2

self.conv = DCN(chi, cho, kernel_size, stride, padding, dilation, deformable_groups)

Multi-stream kernel optimization

Uses multiple streams to compute a function over a set of inputs concurrently. The computation logic is the same as that of mmdet.core.multi_apply.

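import torch
from torchacc.runtime.utils.misc import multi_apply_multi_stream
from mmdet.core import multi_apply

# Hypothetical switch for this example: True uses the TorchAcc multi-stream version
enable_torchacc = True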

def test_func(t1, t2, t3):
  t1 = t1 * 2.0
  t2 = t2 + 2.0
  t3 = t3 - 2.0
  return (t1, t2, t3)

cuda = torch.device('cuda')
t1 = torch.empty((100, 1000), device=cuda).normal_(0.0, 1.0)
t2 = torch.empty((100, 1000), device=cuda).normal_(0.0, 2.0)
t3 = torch.empty((100, 1000), device=cuda).normal_(0.0, 3.0)

if enable_torchacc:
    result = multi_apply_multi_stream(test_func, 2, t1, t2, t3)
else:
    result = multi_apply(test_func, t1, t2, t3)

GPU memory optimization

Gradient Checkpointing

import torchacc

model = torchacc.auto_checkpoint(model)
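Gradient checkpointing trades compute for memory: only some activations are kept during the forward pass, and the rest are recomputed during the backward pass. The toy below shows this on a chain of scalar layers in pure Python. It is not how torchacc.auto_checkpoint decides what to checkpoint (the source does not describe that), and all names are hypothetical. Storing only each segment's input gives the same gradient while keeping 3 instead of 12 activations, plus one segment's worth at a time during recomputation.

```python
import math


def layer(a, x):
    # Toy scalar layer f(x) = a*x + sin(x); its derivative is a + cos(x).
    return a * x + math.sin(x)


def layer_grad(a, x):
    return a + math.cos(x)


def backprop(coeffs, inputs, g):
    # Chain rule through a run of layers, from the last layer to the first.
    for a, x_in in zip(reversed(coeffs), reversed(inputs)):
        g *= layer_grad(a, x_in)
    return g


def grad_store_all(coeffs, x0):
    inputs, x = [], x0
    for a in coeffs:
        inputs.append(x)  # keep every layer's input: n stored values
        x = layer(a, x)
    return backprop(coeffs, inputs, 1.0), len(inputs)


def grad_checkpointed(coeffs, x0, segment):
    checkpoints, x = [], x0
    for k, a in enumerate(coeffs):
        if k % segment == 0:
            checkpoints.append(x)  # keep only each segment's input
        x = layer(a, x)
    g = 1.0
    for s in reversed(range(len(checkpoints))):
        seg = coeffs[s * segment:(s + 1) * segment]
        inputs, x = [], checkpoints[s]
        for a in seg:  # recompute this segment's activations from its checkpoint
            inputs.append(x)
            x = layer(a, x)
        g = backprop(seg, inputs, g)
    return g, len(checkpoints)


coeffs = [0.5 + 0.1 * k for k in range(12)]
full_grad, stored_full = grad_store_all(coeffs, 0.3)
ckpt_grad, stored_ckpt = grad_checkpointed(coeffs, 0.3, segment=4)
print(stored_full, stored_ckpt)  # 12 3
print(math.isclose(full_grad, ckpt_grad))  # True
```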