日本熟妇hd丰满老熟妇,中文字幕一区二区三区在线不卡 ,亚洲成片在线观看,免费女同在线一区二区

AIGC:TorchAcc提速Stable Diffusion分布式訓練

阿里云PAI為您提供了部分典型場景下的示例模型,便于您便捷地接入TorchAcc進行訓練加速。本文為您介紹如何在Stable Diffusion分布式訓練中接入TorchAcc并實現訓練加速。

測試環境配置

測試環境配置方法,請參見配置測試環境

本案例以DSW環境V100M16卡型為例,例如:節點規格選擇ecs.gn6v-c8g1.16xlarge-64c256gNVIDIA V100 * 8

接入TorchAcc加速Stable Diffusion分布式訓練

DSW環境為例:

  1. 進入DSW實例頁面下載并解壓測試代碼及腳本文件。

    1. 交互式建模(DSW)頁面,單擊DSW實例操作列下的打開

    2. Notebook頁簽的Launcher頁面,單擊快速開始區域Notebook下的Python3

    3. 執行以下命令下載并解壓測試代碼及腳本文件。

      !wget http://odps-release.cn-hangzhou.oss.aliyun-inc.com/torchacc/accbench/gallery/stable_diffusion.tar.gz && tar -zxvf stable_diffusion.tar.gz
  2. 進入stable-diffusion目錄,雙擊打開stable_diffusion.ipynb文件。

    后續,您可以直接在該文件中運行下述步驟中的命令,當成功運行結束一個步驟命令后,再順次運行下個步驟的命令。image..png

  3. 執行以下命令下載類Imagenet-1k的mock數據集并安裝Stable Diffusion模型依賴的第三方包。

    !bash prepare.sh
  4. 分別使用普通訓練方法(baseline)和接入TorchAcc進行Stable Diffusion模型分布式訓練,來驗證TorchAcc的性能提升效果。

    說明
    • 在測試不同GPU卡型(例如V100、A10等)時,可以通過調整batch_size來適配不同卡型的顯存大小。

    • 在測試不同機器實例時,由于單機GPU卡數不同(假設為N),因此可以通過設置nproc_per_node來啟動單卡或多卡的任務,其中:1<=nproc_per_node<=N。

    • Pytorch Eager單卡(baseline訓練)

      !#!/bin/bash
      
      !set -ex
      
      !python launch_single_task.py --batch_size=4 --nproc_per_node=1
    • Pytorch Eager八卡(baseline訓練)

      !#!/bin/bash
      
      !set -ex
      
      !python launch_single_task.py --batch_size=4 --nproc_per_node=8
    • TorchAcc單卡(PAI-OPT)

      !#!/bin/bash
      
      !set -ex
      
      !python launch_single_task.py --batch_size=4 --nproc_per_node=1 --compiler-opt
    • TorchAcc八卡(PAI-OPT)

      !#!/bin/bash
      
      !set -ex
      
      !python launch_single_task.py --batch_size=4 --nproc_per_node=8 --compiler-opt

    其中:普通訓練方法和接入TorchAcc訓練方法的優化配置如下:

    • baseline:Torch112+DDP+AMPO1

    • PAI-Opt:Torch112+TorchAcc+AMPO1

  5. 執行以下命令,獲取性能數據結果。

    import os
    from plot import plot, traverse
    from parser import parse_file
    #import seaborn as sns
    
    if __name__ == '__main__':
        path = "output"
        file_names = {}
        traverse(path, file_names)
    
        for model, tags in file_names.items():
            for tag, suffixes in tags.items():
                title = model + "_" + tag
                label = []
                api_data = []
                for suffix, o_suffixes in suffixes.items():
                    label.append(suffix)
                    for output_suffix, node_ranks in o_suffixes.items():
                        assert "0" in node_ranks
                        assert "log" in node_ranks["0"]
                        parse_data = parse_file(node_ranks["0"]["log"])
                        api_data.append(parse_data)
                plot(title, label, api_data)

    對于V100M16卡型,由于顯存有限,batch_size設置的值比較小,無法獲得較大程度的加速效果。但在實際場景中,經過在A10上的測試驗證,使用TorchAcc在單卡和多卡上均能夠獲得40%以上的提速效果。關于接入TorchAcc更詳細的代碼實現原理,請參見代碼實現原理

代碼實現原理

基于StableDiffusion使用三方包pytorch-lighting==1.8.6版本時,可以直接導入stable-diffusion目錄下的pl_hooks.py和logger.py完成TorchAcc接入。

Import TorchAcc API

main函數import處添加以下代碼,具體請參考main.py文件中35-45行代碼:

from logger import create_logger, enable_torchacc_compiler, enable_torchacc_kernel, log_params, log_metrics

+if enable_torchacc_compiler():
+    from torchacc.torch_xla.amp import GradScaler
+    import torchacc.torch_xla.distributed.xla_backend
+    import torchacc.torch_xla.core.xla_model as xm
+    import torchacc.torch_xla.distributed.parallel_loader as ploader
+    dist.get_rank = xm.get_ordinal
+    dist.get_world_size = xm.xrt_world_size
+    device = xm.xla_device()
+    xm.set_replication(device, [device])
+else:
    from torch.cuda.amp import GradScaler

Enable Pytorch-lightning hook

使用pl_hooks.py文件的enable_pl_hooks.py完成TorchAcc接入,具體請參考main.py文件中588行代碼:

from pl_hooks import enable_pl_hooks

+if enable_torchacc_compiler():
+    from torchacc.torch_xla.amp import syncfree
+    torch.optim.Adam = syncfree.Adam
+    torch.optim.AdamW = syncfree.AdamW
+    torch.optim.SGD = syncfree.SGD
+if opt.use_pl_logger:
+    os.environ["USE_PL_LOGGER"] = "1"
+if opt.log_freq is not None:
+    os.environ["LOG_FREQ"] = str(opt.log_freq)
    
+enable_pl_hooks() # call hook of acclerate