I am running the code at https://github.com/jadore801120/attention-is-all-you-need-pytorch, which implements the paper "Attention Is All You Need" in PyTorch, but on the Gigaword dataset, so the task is text summarization rather than the original paper's machine translation (that dataset is WMT'16).
Gigaword is much larger than WMT'16: about 3,800,000 training examples versus 29,056.
The problem is that when I train on Gigaword, at exactly 31% of the way through the data, training dramatically slows down, becoming up to 30 times slower, and GPU utilization drops from 99-100% to 0-15%.
I can't figure out why training slows down at exactly that 31% point in every epoch.
What would be a good starting point for debugging this problem?
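To narrow it down, I was thinking of instrumenting the training loop to see whether the stall is in data loading or in the model step. Here is a minimal, framework-agnostic sketch of what I mean (`timed_loop` is a hypothetical helper I wrote for illustration, not part of the repo):

```python
import time
from collections import deque

def timed_loop(batch_iter, step_fn, window=100):
    """Record rolling averages of data-loading time vs. compute time per batch.

    batch_iter: any iterable yielding batches (e.g. a DataLoader)
    step_fn:    callable running forward/backward on one batch
    Returns a list of (avg_load_seconds, avg_step_seconds) per batch.
    """
    load_times = deque(maxlen=window)
    step_times = deque(maxlen=window)
    stats = []
    it = iter(batch_iter)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)        # time spent waiting on the data pipeline
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)              # time spent in the model step
        t2 = time.perf_counter()
        load_times.append(t1 - t0)
        step_times.append(t2 - t1)
        stats.append((sum(load_times) / len(load_times),
                      sum(step_times) / len(step_times)))
    return stats
```

My idea is that if the load average climbs around the 31% mark while the step average stays flat, the bottleneck is the input pipeline (CPU preprocessing, disk, or memory pressure); if the step average climbs, it's something on the GPU side, e.g. batches with much longer sequences at that point in the (possibly length-sorted) data. Does this sound like a reasonable approach, or is there a better tool for this?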