本文中含有需要您注意的重要提示信息,忽略該信息可能對您的業務造成影響,請務必仔細閱讀。
在大規模分布式異步訓練中,您可以使用WorkQueue進行彈性數據切分,以緩解長尾效應,從而降低模型訓練所需的時間。本文介紹WorkQueue的調用格式、參數及其提供的方法。同時,以文件數據源和MaxCompute表數據源為例,介紹實現數據切分的經典示例。
公共云GPU服務器即將過保下線,您可以繼續提交CPU版本的TensorFlow任務。如需使用GPU進行模型訓練,請前往DLC提交任務,具體操作請參見創建訓練任務。
背景信息
在大規模分布式異步訓練中,如果每個Worker讀取相同數量的樣本,則慢節點的訓練時長會遠大于其他節點,造成長尾效應。并且隨著訓練規模擴大,長尾效應會越來越嚴重,導致訓練的整體數據吞吐降低,進而增加訓練時間。
為解決該問題,PAI提供了pai.data.WorkQueue
類,支持對多種數據源進行彈性數據切分,讓慢節點獲取較少的訓練數據,快節點獲取更多的訓練數據,以緩解長尾效應,從而降低模型訓練所需的時間。
版本配套關系
Python版本:Python 2.7
PAI-TensorFlow版本:PAI-TensorFlow 1.12
pai.data.WorkQueue
功能
工作項隊列類,用于統一管理所有Worker上的工作項。每個Worker的當前剩余工作項被消費完后,會從同一個WorkQueue獲得新的工作項,并將其作為數據源進行訓練,從而使得訓練快的Worker獲得更多的工作項進行訓練,以減少長尾效應。
格式
class pai.data.WorkQueue(works, num_epochs=1, shuffle=True, seed=None, prefix=None, num_slices=None, name='work_queue')
參數
參數名
描述
類型
是否必選
默認值
works
文件名或表名列表。
LIST of STRING
是
無
num_epochs
讀取全部數據的次數。
INT
否
1
shuffle
是否每個Epoch都隨機重洗數據,取值如下:
True:每個Epoch都隨機重洗數據。
False:不進行數據重洗。
BOOL
否
True
seed
重洗數據的隨機種子。取值為None時,表示系統自動選取隨機種子。
INT
否
None
prefix
工作項(文件名或表名)的前綴。取值為None時,表示無前綴。
STRING
否
None
num_slices
工作項的總數量。集群越不穩定,需要將工作項總數量配置的越大,通常為Worker數量的10倍以上。取值為None時,表示不分片。
INT
否
None
num_clients
工作隊列支持的最大工作搶占并發數。
INT
否
1
name
工作隊列的名稱。
STRING
否
work_queue
返回值
返回WorkQueue對象,您可以使用該對象調用
pai.data.WorkQueue
類提供的方法。
pai.data.WorkQueue提供的方法
pai.data.WorkQueue
類提供以下方法:
take
功能
從全局工作隊列獲取一個工作項,并下載至本地。
格式
WorkQueue.take()
參數
無
返回值
返回值類型為
tensorflow.Tensor
。
input_dataset
功能
返回一個Dataset,其每個元素為一個工作項。
格式
WorkQueue.input_dataset()
參數
無
返回值
返回值類型為
tensorflow.data.Dataset
。
input_producer
功能
返回全局工作隊列在本地的代理隊列,為Reader類Op使用。
格式
WorkQueue.input_producer()
參數
無
返回值
返回值類型為
tensorflow.FIFOQueue
。
add_summary
功能
在Tensorboard中顯示WorkQueue的資源水位信息。
格式
WorkQueue.add_summary()
參數
無
返回值
無
典型示例
pai.data.WorkQueue
類支持對多種數據源進行彈性數據切分,以下分別以文件數據源和MaxCompute表數據源為例,介紹如何使用pai.data.WorkQueue
類實現彈性數據切分(僅提供核心代碼片段):
文件數據源
import pai # ... # path1、path2及path3表示需要讀取的文件列表。 # shuffle取值為True,表示每個Epoch都隨機化打散文件路徑。 work_queue = pai.data.WorkQueue([path1, path2, path3], shuffle=True) # 讓WorkQueue支持TensorBoard。 work_queue.add_summary() # 創建文件讀取器。 reader = tf.TextLineReader() # 從文件列表中讀取2條記錄。 keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2) with tf.train.MonitoredTrainingSession() as sess: sess.run(...)
MaxCompute表數據源
TableRecordDataset數據源
import pai #... # odps_path1、odps_path2及odps_path3表示需要讀取的MaxCompute表列表。 # shuffle取值為True,表示每個Epoch都隨機化打散表路徑。 # num_slices為工作項總數量。 # FLAGS.num_workers為訓練中的Worker數量。 work_queue = pai.data.WorkQueue([odps_path1, odps_path2, odps_path3],shuffle=True, num_slices=FLAGS.num_workers * 10) # 創建文件名Dataset。 filenames_dataset = work_queue.input_dataset() # 將dataset作為文件名傳入TableRecordDataset。 dataset = tf.data.TableRecordDataset(filenames_dataset, record_defaults=...)
關于
tf.data.TableRecordDataset
接口的調用,請參見TableRecordDataset。TableRecordReader數據源
import pai # ... # odps_path1、odps_path2及odps_path3表示需要讀取的MaxCompute表列表。 # shuffle取值為True,表示每個Epoch都隨機化打散表路徑。 # num_slices為工作項總數量。 # FLAGS.num_workers為訓練中的Worker數量。 work_queue = pai.data.WorkQueue( [odps_path1, odps_path2, odps_path3], shuffle=True, num_slices=FLAGS.num_workers * 10) # 創建表讀取器。 reader = tf.TableRecordReader() # 從表中讀取2條記錄。 keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)