Meta-Learning Mini-Batch Risk Functionals
Supervised learning typically optimizes the expected value risk functional of the loss, but in many cases we want to optimize for other risk functionals. In full-batch gradient descent, this is done by taking gradients of a risk functional of interest, such as the Conditional Value at Risk (CVaR), which ignores some quantile of extreme losses. However, deep learning must almost always use mini-batch gradient descent, and the lack of unbiased estimators of various risk functionals makes the right optimization procedure unclear. In this work, we introduce a meta-learning-based method of learning an interpretable mini-batch risk functional during model training, in a single shot. When optimizing for various risk functionals, the learned mini-batch risk functionals lead to risk reductions of up to 10%. Then, in a setting where the right risk functional is unknown a priori, our method improves over the baseline by 14%. We analyze the learned mini-batch risk functionals at different points through training, and find that they learn a curriculum (including warm-up periods), and that their final form can be surprisingly different from the underlying risk functional that they optimize for.
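For concreteness, below is a minimal sketch (in PyTorch) of the kind of naive plug-in mini-batch risk estimator the abstract alludes to: a CVaR-style functional that ignores a quantile of extreme losses. The function name, the alpha convention, and the trimming rule are illustrative assumptions, not the paper's implementation. Because it is computed from a single mini-batch, this plug-in estimate is generally biased for the full-batch risk functional, which is exactly the estimation difficulty that motivates the meta-learned alternative.

    import torch

    def trimmed_cvar_loss(losses: torch.Tensor, alpha: float = 0.95) -> torch.Tensor:
        """Plug-in mini-batch estimate of a CVaR-style risk functional that
        ignores the most extreme per-example losses: average the smallest
        floor(alpha * n) losses in the batch and drop the rest. Note that
        this mini-batch plug-in is generally a biased estimate of the
        corresponding full-batch risk functional."""
        n = losses.numel()
        k = max(1, int(alpha * n))
        kept, _ = torch.topk(losses, k, largest=False)  # smallest k losses
        return kept.mean()

    # Example usage in a training step, with per-example (unreduced) losses:
    # losses = torch.nn.functional.cross_entropy(logits, targets, reduction="none")
    # loss = trimmed_cvar_loss(losses, alpha=0.95)
    # loss.backward()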