Knowledge distillation for better convergence in multitask learning

![Multitask loss functions converging at different rates](https://dev-media.amazoncloud.cn/e886e99e541f44b7bb8acfdb790d4b5e_%E4%B8%8B%E8%BD%BD.jpg)

Each task in a multitask model (left) typically has its own loss function, and during training, the loss functions converge at different rates (right). A new method attempts to preserve the gains (dotted lines) on tasks whose performance has already peaked.

![Validation curves in a five-task setup](https://dev-media.amazoncloud.cn/2d2de2624964407bab70e18944efcc00_%E4%B8%8B%E8%BD%BD%20%281%29.jpg)

Validation curves in a five-task multitask learning setup, where training minimizes the sum of the task losses. The tasks corresponding to the blue, purple, and red curves show signs of overfitting. The tasks corresponding to the orange and green curves are underfitted at the end of training.

Multitask learning (MTL) typically involves jointly optimizing the losses of a set of tasks. One naive approach is simply to minimize the sum of the losses. However, tasks can converge at different speeds depending on their difficulty. As a result, this naive approach is usually suboptimal: the model can end up overfitting some tasks and underfitting others.

To address this issue, many existing methods try to balance the learning speed across tasks. They facilitate or inhibit the learning of each individual task so that all tasks converge at roughly the same rate. These methods include applying static loss weights, dynamically adjusting loss weights during training, and manipulating the gradients of different tasks.

![Illustration of the proposed approach](https://dev-media.amazoncloud.cn/051b1ed8c43e4c3a889d23641fb26cae_%E4%B8%8B%E8%BD%BD%20%282%29.jpg)

Illustration of the idea of the proposed approach. When the validation curve of a task reaches its peak, we switch to a knowledge distillation loss for that task from that point on. The aim is to achieve the dotted lines, where the performance of each task stays at its peak level until the end of training.

In a [paper](https://www.amazon.science/publications/asynchronous-convergence-in-multi-task-learning-via-knowledge-distillation-from-converged-tasks) we presented in the NAACL 2022 [industry track](https://www.amazon.science/blog/naacl-industry-track-offers-reality-checks-new-directions), we propose a method for achieving convergence in MTL that improves on approaches that artificially enforce the same convergence rate across tasks. Instead, we let the tasks converge on their own schedules. When a task converges, we switch to a knowledge distillation (KD) loss to keep that task's performance at its best level while the model learns the remaining tasks. The figure above illustrates the idea.

We evaluate the proposed method in two five-task MTL setups consisting of proprietary e-commerce datasets. The results show that our method consistently outperforms existing loss-weighting and gradient-balancing approaches. It achieves average improvements of 0.9% and 1.5%, respectively, over the best-performing baseline model in the two setups.

#### **Asynchronous convergence via knowledge distillation**

Our proposed method works as follows:
- After the model converges on a task, we use its best-performing parameter values to run inference on the task's training set, and we record the predictions.
- For the remaining training steps, we use these predictions as soft labels to train the model on the converged task, while using real labels to train on the remaining tasks.
- We repeat this until all tasks converge.

With this method, we experiment with two different training settings:

- Joint setting: Train on all tasks together and switch to KD losses as the model converges on different tasks.
- Sequential setting: Train on the tasks one after another instead of all at once, switching each task to its KD loss once the model converges on it.

#### **Experiments and results**

We evaluate our approach using two five-task setups built from proprietary e-commerce tasks. The tasks in the first setup are more similar to each other and are all classification tasks. The tasks in the second setup are more diverse in both application and task type. Using both benchmarks lets us test the effectiveness and robustness of our method in different MTL scenarios.

In both setups, the joint and sequential settings both substantially outperform the baseline methods. On average, our best results exceed the best-performing baseline by 0.9% in the first setup and by 1.5% in the second.

Below are the validation curves for the first five-task setup. They cover the baseline that simply minimizes the sum of task losses, as well as our proposed joint and sequential settings. None of the tasks in the joint or sequential setting shows a downward trend. This suggests that the method is indeed effective at keeping converged tasks at their best level while the model learns the remaining tasks.

![Validation curves of the baseline and of the joint and sequential settings](https://dev-media.amazoncloud.cn/7c2fff5012194ef0bdada3118ce29684_%E4%B8%8B%E8%BD%BD%20%283%29.jpg)

Validation curves of the baseline that simply minimizes the sum of task losses and of our proposed joint and sequential settings.

ABOUT THE AUTHOR
#### **[Weiyi Lu](https://www.amazon.science/author/weiyi-lu)**
Weiyi Lu is an applied scientist at Amazon.
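The three-step procedure under "Asynchronous convergence via knowledge distillation" can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation. The post does not say how convergence is detected, so the `patience` rule used here is hypothetical: a task counts as converged after its validation score fails to improve for a fixed number of rounds. The KD loss is also an assumption: it is taken to be cross-entropy against the recorded soft labels, with no temperature. All names (`TaskState`, `on_validation`, `total_loss`, `predict`) are hypothetical. Model parameters are represented by an opaque snapshot that the caller supplies.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

Probs = List[float]


def softmax(logits: List[float]) -> Probs:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target: Probs, logits: List[float]) -> float:
    """H(target, softmax(logits)); target is one-hot (real label) or soft (KD)."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0.0)


def one_hot(label: int, num_classes: int) -> Probs:
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


@dataclass
class TaskState:
    """Per-task bookkeeping. `patience` is a hypothetical convergence rule."""
    name: str
    num_classes: int
    patience: int = 2
    best_score: float = float("-inf")
    best_params: Optional[object] = None  # assumed: a snapshot of model params
    evals_since_best: int = 0
    converged: bool = False
    soft_labels: Dict[int, Probs] = field(default_factory=dict)

    def on_validation(self, score: float, params: object,
                      predict: Callable[[object], Dict[int, Probs]]) -> None:
        """Track the best checkpoint; on convergence, record soft labels."""
        if self.converged:
            return  # already distilling from the recorded predictions
        if score > self.best_score:
            self.best_score, self.best_params = score, params
            self.evals_since_best = 0
            return
        self.evals_since_best += 1
        if self.evals_since_best >= self.patience:
            # Run inference with the best-performing parameters on this task's
            # training set and keep the predictions (example id -> probs).
            self.soft_labels = predict(self.best_params)
            self.converged = True

    def loss(self, example_id: int, logits: List[float], label: int) -> float:
        if self.converged:
            target = self.soft_labels[example_id]  # KD loss: recorded soft label
        else:
            target = one_hot(label, self.num_classes)  # ordinary loss: real label
        return cross_entropy(target, logits)


def total_loss(tasks: List[TaskState],
               batch: List[Tuple[int, int, List[float], int]]) -> float:
    """Sum of per-task losses; items are (task index, example id, logits, label)."""
    return sum(tasks[t].loss(ex_id, logits, label) for t, ex_id, logits, label in batch)


if __name__ == "__main__":
    task = TaskState("toy", num_classes=2)
    # Hypothetical "model": params is just the step index, and the recorded
    # training-set prediction for example 0 depends on it.
    for step, score in enumerate([0.70, 0.80, 0.78, 0.79]):
        task.on_validation(score, step, lambda p: {0: [p / 10, 1 - p / 10]})
    print(task.converged, task.best_params, task.soft_labels[0])
```

The soft labels are cached once, at the moment a task converges. The converged task is therefore pulled toward the predictions of its best checkpoint rather than toward a moving target, while unconverged tasks keep training on their real labels. Because cross-entropy against a fixed soft target differs from the KL divergence only by the target's entropy, which is constant, the two losses give the same gradients. In a real framework, `params` would be a copied checkpoint and `predict` would run batched inference over the task's training set.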