Critical batch size in ML training
Seunghyun Seo @SeunghyunSEO7 2025-01-09
The concept of critical batch size is quite simple.
-
Let’s assume we have a training dataset with 1M tokens. If we use a batch size of 10 tokens, we can update the model parameters 100,000 times. On the other hand, if we increase the batch size to 100, the number of update steps decreases to 10,000. (1/n)
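A quick sketch of that arithmetic, assuming batch size is counted in tokens and we make a single pass over the data (the function name `num_updates` is just for illustration):

```python
def num_updates(dataset_tokens: int, batch_tokens: int) -> int:
    """Number of full optimizer steps in one pass over the dataset."""
    return dataset_tokens // batch_tokens


dataset_tokens = 1_000_000
for batch_tokens in (10, 100, 10_000):
    # Prints 100000, 10000 and 100 updates respectively.
    print(batch_tokens, num_updates(dataset_tokens, batch_tokens))
```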
-
With a larger batch size, we use more samples to estimate the averaged gradient for each update, which can produce more accurate gradients by reducing variance and bringing the sample mean closer to the true gradient. (2/n)
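To see the variance reduction, here is a toy simulation. It assumes each per-sample "gradient" is the true gradient plus independent unit Gaussian noise, which is a simplification and not how real gradients behave. Under that assumption the spread of the mini-batch mean should shrink roughly like 1/sqrt(B).

```python
import random
import statistics


def minibatch_mean_std(batch_size: int, trials: int = 2000, seed: int = 0) -> float:
    """Empirical std of the mini-batch mean of noisy per-sample gradients.

    Toy assumption: per-sample gradient = true_grad + N(0, 1) noise.
    """
    rng = random.Random(seed)
    true_grad = 1.0
    means = []
    for _ in range(trials):
        batch = [true_grad + rng.gauss(0.0, 1.0) for _ in range(batch_size)]
        means.append(sum(batch) / batch_size)
    return statistics.pstdev(means)


for b in (1, 4, 16, 64):
    # Second column: measured std; third column: the 1/sqrt(B) prediction.
    print(b, round(minibatch_mean_std(b), 3), round(1 / b ** 0.5, 3))
```

Note the diminishing returns: going from 1 to 4 samples halves the noise, but halving it again takes 16.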
-
However, this comes at the cost of fewer opportunities to improve the model parameters (the number of steps is reduced). To reach the same level of convergence, the LR must be adjusted accordingly, often scaled by \sqrt{n} or n when the batch size is increased by a factor of n. (3/n)
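The two scaling rules mentioned above can be written down as a small helper. This is only a sketch: `scale_lr` and its arguments are hypothetical names, and which rule works better depends on the optimizer and setup.

```python
import math


def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "sqrt") -> float:
    """Rescale the learning rate when the batch size changes by n = new/base."""
    n = new_batch / base_batch
    if rule == "linear":
        return base_lr * n
    if rule == "sqrt":
        return base_lr * math.sqrt(n)
    raise ValueError(f"unknown rule: {rule}")


# Batch size 10 -> 100, so n = 10.
print(scale_lr(1e-3, 10, 100, rule="linear"))
print(scale_lr(1e-3, 10, 100, rule="sqrt"))
```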
-
But what happens if we increase the batch size to an extreme value, like 10,000? In this scenario, the number of updates drops drastically to just 100. Beyond a certain point (e.g., 1,000 or 2,000), the improvement in the averaged gradient diminishes as the batch size increases. (4/n)
-
This point is the critical batch size. Beyond it, more parallelization (increasing the batch size) becomes ineffective: reaching convergence still requires a minimum number of parameter updates (e.g., at least 1,000), because the gradient estimate no longer improves meaningfully. (5/n)
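One common way to make this trade-off concrete is the functional form from McCandlish et al. (2018), "An Empirical Model of Large-Batch Training". It is not part of this thread, so treat it as a borrowed assumption: steps needed S = S_min * (1 + B_crit / B). The values of `b_crit` and `s_min` below are made up to match the toy numbers above.

```python
def steps_needed(batch: float, b_crit: float, s_min: float) -> float:
    """Steps to reach a target loss under the assumed large-batch model."""
    return s_min * (1.0 + b_crit / batch)


def tokens_needed(batch: float, b_crit: float, s_min: float) -> float:
    """Total tokens consumed = batch size * number of steps."""
    return batch * steps_needed(batch, b_crit, s_min)


B_CRIT, S_MIN = 1_000, 1_000  # hypothetical values, not measurements
for batch in (10, 100, 1_000, 10_000, 100_000):
    print(batch, round(steps_needed(batch, B_CRIT, S_MIN)), round(tokens_needed(batch, B_CRIT, S_MIN)))
```

In this model the step count flattens out near `s_min` once the batch is well past `b_crit`, while the token cost keeps growing almost linearly with batch size.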
-
In this regime, increasing the batch size just wastes training tokens, and you can’t even scale the LR to compensate for the reduced number of steps, because of the overshoot problem. (6/n)
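The overshoot problem shows up even with a perfect gradient. A minimal sketch on a 1-D quadratic loss L(w) = 0.5 * h * w^2 (a toy assumption standing in for a real loss surface): each step multiplies w by (1 - lr * h), so once lr exceeds 2 / h the iterates grow instead of shrinking.

```python
def run_gd(lr: float, curvature: float = 1.0, w0: float = 1.0, steps: int = 20) -> float:
    """Run exact (noise-free) gradient descent on L(w) = 0.5 * curvature * w**2.

    Returns |w| after `steps` updates; the optimum is w = 0.
    """
    w = w0
    for _ in range(steps):
        grad = curvature * w  # exact gradient, as if the batch were infinite
        w -= lr * grad
    return abs(w)


for lr in (0.5, 1.5, 2.5):
    # With curvature 1.0 the stability limit is lr = 2.0.
    print(lr, run_gd(lr))
```

So a noise-free gradient does not buy an arbitrarily large LR; the curvature sets a ceiling.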
-
Critical batch size is not directly related to model size; it is instead determined by the achievable loss. Modern large language models (LLMs) often employ batch size ramp-up strategies. (7/n)
-
The reason is that in the early phases of training, the critical batch size is much smaller, because the model is primarily learning low-level features. (8/n)
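A ramp-up schedule can be as simple as a linear interpolation over the first part of training. The sketch below is hypothetical: the function name, the linear shape, and all the numbers are illustrative assumptions, not what any particular LLM uses.

```python
def batch_size_at(tokens_seen: int, start_batch: int, final_batch: int, ramp_tokens: int) -> int:
    """Linearly ramp the batch size from start_batch to final_batch.

    The ramp runs over the first `ramp_tokens` training tokens, then stays flat.
    """
    if tokens_seen >= ramp_tokens:
        return final_batch
    frac = tokens_seen / ramp_tokens
    return int(start_batch + frac * (final_batch - start_batch))


for seen in (0, 250_000, 500_000, 1_000_000, 2_000_000):
    print(seen, batch_size_at(seen, start_batch=256, final_batch=4096, ramp_tokens=1_000_000))
```

The idea is just to keep the batch size near the (growing) critical batch size instead of paying for a large batch while it is still small.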
-
However, since the target loss is often set by the compute budget, C=6ND (as proposed by Kaplan et al.), it is indirectly influenced by model size N and the number of training tokens D. Consequently, larger models often have a larger critical batch size. (9/n)
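For reference, the C = 6ND rule of thumb as code. The model size and token count below are arbitrary example values, chosen only to show the arithmetic.

```python
def training_flops(n_params: int, n_tokens: int) -> int:
    """Approximate training compute using the C = 6 * N * D rule of thumb."""
    return 6 * n_params * n_tokens


def tokens_for_budget(flops: float, n_params: int) -> float:
    """Invert C = 6ND: tokens affordable for a given compute budget and model size."""
    return flops / (6 * n_params)


n_params = 1_000_000_000   # hypothetical 1B-parameter model
n_tokens = 20_000_000_000  # hypothetical 20B training tokens
budget = training_flops(n_params, n_tokens)
print(budget)                                   # 120000000000000000000, i.e. 1.2e20 FLOPs
print(tokens_for_budget(budget, 2 * n_params))  # doubling N halves D at fixed C
```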
-
Finally, some have claimed that xAI could train the DeepSeek V3 model in just one day with their 100k H100 GPUs (Colossus), but this is unlikely because of the critical batch size. Regardless of the computational power available, sufficient time is required for the model to learn. (10/n)
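A back-of-the-envelope version of this argument: if batches larger than some maximum useful size stop helping, the steps have to run one after another, and wall-clock time is bounded below by steps times time per step. Every number here is a placeholder chosen for illustration, not an estimate for DeepSeek V3, Colossus, or any real system.

```python
def min_wallclock_hours(total_tokens: int, max_useful_batch_tokens: int, seconds_per_step: float) -> float:
    """Lower bound on training time when steps are sequential.

    Assumes batches beyond max_useful_batch_tokens give no further speed-up,
    and that seconds_per_step is the fastest a single step can run.
    """
    min_steps = total_tokens / max_useful_batch_tokens
    return min_steps * seconds_per_step / 3600.0


# Placeholder inputs: 1T tokens, 10M-token batches, 2 s per step.
print(min_wallclock_hours(1_000_000_000_000, 10_000_000, 2.0))
```

Adding more GPUs can shrink `seconds_per_step` only up to a point; it does not reduce the number of sequential steps.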