机器学习洞察 | JAX,机器学习领域的“新面孔”

Python
深度学习
Numpy
机器学习
0
0
在之前的《[机器学习](https://aws.amazon.com/cn/machine-learning/?trk=cndc-detail)洞察》系列文章中,我们分别针对于多模态[机器学习](https://aws.amazon.com/cn/machine-learning/?trk=cndc-detail)和分布式训练、[无服务器](https://aws.amazon.com/cn/serverless/?trk=cndc-detail)推理进行了解读,本文将为您重点介绍 JAX 的发展并剖析其演变和动机。下面,就让我们来认识一下 JAX 这一新崛起的深度学习框架—— ## **开源机器学习框架的演进** ![image.png](https://dev-media.amazoncloud.cn/06f39917375a4876852e8ba571834619_image.png "image.png") 从这张 GitHub Star 趋势图可以看到,自 2019 年 JAX 出现到如今保持着一个向上的抛物线走势。 在考察一个开源[机器学习](https://aws.amazon.com/cn/machine-learning/?trk=cndc-detail)框架时,例如开发者熟知的 PyTorch, TensorFlow, MXNet 等,往往会从支持模型的广泛性、部署的成熟性、生态系统的丰富性来对它做一个评估:包括是否支持 Hugging Face 等主流模型,以及其框架相关研究论文的数量,还有它可提供复现代码的论文数量等等。 ## **JAX 的源起** - 为什么 Eager 模式是在 TensorFlow 1.4 版本之后引入的? - Eager 模式在 TensorFlow 2.0 之后变成了一个默认的执行模式,和原有的 Graph 模式的区别是什么? - … 回归并理清这些历史问题有助于开发者了解[机器学习](https://aws.amazon.com/cn/machine-learning/?trk=cndc-detail)的演变逻辑,并了解 JAX 是如何吸取之前的教训,帮助开发者更方便地实践深度学习或[机器学习](https://aws.amazon.com/cn/machine-learning/?trk=cndc-detail)应用。 ### **Eager 模式 V.S. Graph 模式** 在 TF 引进了 Eager 模式之后,它会采用更直观的界面,使用自然的 Python 代码和数据结构,而且享受更加便携的调试,在 Eager 模式中可以通过直接调用操作来检查和测试模型,而之前 Graph 这种模式有点类似于 C 和 C++,它的编程是写好程序之后要先进行编译才能运行。 Eager 模式有自然控制的流程,使用 Python 而不是图控制流,以及支持 GPU 和 TPU 的加速。做为开发者,我们希望可以客观地看待不同的框架,而不是比较他们的优劣。值得思考的一个问题是:通过了解 TF 的 Eager 模式对于 Graph 模式的改进,它的改进逻辑和思路在 JAX 中都有身影。 ## **什么是 JAX** ![image.png](https://dev-media.amazoncloud.cn/2ae38ac7608a462588b4b2819bc7e86d_image.png "image.png") JAX 作为现在越来越流行的库,是一种类似于 NumPy(使用 Python 开源的数值计算扩展库)的轻量级用于阵列的计算。JAX 最开始的设计不仅仅是为了深度学习而设计的,深度学习只是它的一小部分,它提供了编写 NumPy 程序的能力,这些程序可以使用 GPU/TPU 自动拆分和加速。 JAX 用于基于阵列的计算时,开发者无需修改代码就可以在 CPU/GPU/ASIC 上同时运行,并支持原生 Python 和 NumPy 函数的四种可组合函数转换: - 自动微分 (Autodiff) - 即时编译 (JIT compilation) - 自动向量化 (Vectorization) - 代码并行化 (Parallelization) ### **JAX 初体验** 我们可以通过下面这个简单的测试对比 JAX 和 NumPy 的计算性能。 输入一个 100 X 100 的二维数组 X,选取 ml.g4dn.12xlarge 计算实例通过 NumPy 和 JAX 分别对矩阵的前三次幂求和: ``` def fn(x): return x + x*x + x*x*x x = np.random.randn(10000, 10000).astype(dtype='float32') %timeit -n5 fn(x) 436 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 5 loops each) ``` 我们发现此计算大约需要 436 毫秒。接下来,我们使用 JAX 实现以下计算: ``` jax_fn = jit(fn) x = jnp.array(x) %timeit jax_fn(x).block_until_ready() 3.67 ms ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` JAX 仅在 3.67 毫秒内执行此计算,比 NumPy 快 118 倍以上。可见,JAX 有可能比 NumPy 快几个数量级(注意,JAX 使用 TPU 而 NumPy 正在使用 CPU)。 **以上为个人测试结果,非官方提供的数据,仅供研究参考* 对比测试结果可得,NumPy 完成计算需要 436 毫秒,而 JAX 仅需要 3.67 毫秒,计算速度相差 100 多倍。这个测试也说明了为什么很多开发者对它的性能赞不绝口。 ## **JAX 的动机剖析** 我们希望通过回答这个问题来解读 JAX 的动机: 如何使用 Python 从头开始实现高性能和可扩展的深度神经网络? ### **在 NumPy 中创建深度学习系统** 通常,Python 程序员会从 NumPy 之类的东西开始,因为它是一种熟悉的、基于数组的数据处理语言,在 Python 社区中已经使用了几十年。如果你想在 NumPy 中创建深度学习系统,你可以从预测方法开始。 这里可以用一个详细的例子说明问题,从 NumPy 上的深度学习的场景说起: ![image.png](https://dev-media.amazoncloud.cn/a6245f7bb8a24d449daeabd27ed48c6a_image.png "image.png") 上述代码展示了订阅一个前馈的神经网络,它执行了一系列的点积和激活函数,然后将输入转化为某种可以学习的输出。一旦定义了这样的一个模型,接下来需要做就是要定义损失函数,这个函数将为你提供正在尝试优化的那些指标,来适应最佳的[机器学习](https://aws.amazon.com/cn/machine-learning/?trk=cndc-detail)模型。例如以上代码的损失函数是以均方误差损失函数 MSE 为例。 现在我们来分析下:在深度学习场景使用 NumPy 还缺少什么? 硬件加速 (GPU/TPU) 自动微分 (autodiff) 快速优化 添加编译 (Compilation) 融合操作 向量化操作批处理 (batching) 大型数据集并行化 (Parallelization) 1)硬件加速 (GPU/TPU):首先深度学习需要大量的计算,我们想在加速的硬件上运行它。所以我们想在 GPU 和 TPU/ASIC 上运行这个模型,这对于经典的 NumPy 来说有点困难; 2)自动微分 (autodiff) 快速优化:接下来我们想要做自动微分,这样就可以有效地拟合这个损失函数,而不必自己来实现数值微分; 3)然后我们需要添加编译 (Compilation):这样你就可以将这些操作融合在一起,使它们更加高效; 4)向量化操作批处理 (Batching):另外,当我们编写了某些函数后,可能希望将其应用于多个数据片段,而不再需要重写预测和损失函数来处理这些批量数据; 5)大型数据集并行化 (Parallelization):最后,如果我们正在处理大型数据集,会希望能够支持跨多个 cores 或多台 machines 做并行化操作。 ## **JAX 的动机剖析:XLA 和自动定位** JAX 非常重要的一个动机就是 XLA 和自动定位。让我们来看看 JAX 可以做些什么,来填补前面分析的在深度学习场景使用 NumPy 还缺少的功能。 ![image.png](https://dev-media.amazoncloud.cn/91daba687d0741048e60ff8c46601b43_image.png "image.png") 首先,用 jax.numpy 替换 numpy 导入模块。在许多情况下,jax.numpy 与经典的 NumPy 具有相同的 API,但 jax.numpy 可以完成前面分析时发现 NumPy 缺少,但是在深度学习场景却非常需要的的东西。 **JAX 可以通过 XLA 后端,来自动定位 CPU、GPU 和 TPU 或者 ASIC,以便快速计算模型和算法。** ### **JAX 动机剖析:Autograd** 第二个重要动机是 Autograd。开发者可以通过下面的代码调用 Autograd 版本: ![image.png](https://dev-media.amazoncloud.cn/881484df5c744cf891210bf10452b96a_image.png "image.png") 通过 from jax import grad 模块,使用 Autograd 的更新版本,**JAX 可以自动微分原生 Python 和 NumPy 函数**。它可以处理 Python 功能的大子集,包括循环、Ifs、递归等,甚至可以接受导数的导数。 JAX 提供了一组可组合的变换,其中之一是 **grad 变换**。 例子中,像 mse_loss 这样的损失函数,通过 grad (mse_loss) 将其转换为计算梯度的 Python 函数。 **Autograd 的主要预期应用是基于梯度的优化。** 有关更多信息,请查看 JAX 教程和示例:Https://github.com/hips/autograd ### **JAX 动机剖析:vmap** 在使用梯度函数时,开发者希望将其应用于多个数据片段,而在 JAX 中,你不再需要重写预测和损失函数来处理这些批量数据。 ![image.png](https://dev-media.amazoncloud.cn/39867227bda74f66906d88cf0dac397a_image.png "image.png") 如图中代码最后一行 (perexample_grads …) 所诠释的那样,如果你通过 vmap transform 传递它,这会自动向量化这个代码,这样就可以在多个批次中使用相同的代码。 ### **JAX 动机剖析:jit** JAX 还有一个重要的组合函数——jit,开发者可以使用 jit transform 实现即时编译。 ![image.png](https://dev-media.amazoncloud.cn/2e700d4381a14e4e9f92384e6f158cbb_image.png "image.png") jit 结合后台可以使用 XLA 后端编译器将操作融合在一起,来自动定位 CPU、GPU 和 TPU 或者 ASIC,加速计算模型和算法。 ### **JAX 动机剖析:pmap** 最后,如果想并行化你的代码,有一个和 vmap 非常相似得转换叫 pmap。 通过代码运行 pmap,开发者能够本地定位系统中的多个内核或你有权访问的 GPU、TPU 或 ASIC 集群。 ![image.png](https://dev-media.amazoncloud.cn/955248f46f2f481c9785481a2a11b8cf_image.png "image.png") 这最终成为一个非常强大的系统,**可以在没有太多额外代码的情况下构建我们用类似于 NumPy 的熟悉 API,做深度学习的快速计算等工作负载。** JAX 的关键设计思想 通过上述对比可以看到, JAX 不仅为开发者提供了和 NumPy 相似的 API,上述的五大函数转换组合也让 JAX 可以在不需要额外代码的情况下,帮助开发者构建深度学习应用进行快速计算。 这里的关键思想是: 1)首先,在 JAX 中,Python 代码被追溯到中间表示,JAX 知道如何转换这个中间表示。 2)在下篇文章中我们也将详细分析 JAX 的工作机制:同样的中间表示,通过允许 XLA 进行特定领域 (CPU/GPU 等) 的编译,如何来瞄准不同的后端; 3)另外,JAX 还有基于 NumPy 和 SciPy 的面向用户的 API,如果开发者一直使用 Python 的技术栈,应该会对 JAX 感觉相当熟悉; 4)最后,JAX 提供了**功能强大的变换:grad, git, vmap, pmap 等,来支持深度学习等计算**,因此 JAX 可以做到之前 NumPy 代码无法做到的事情。 通过前面的介绍,我们可以看到,开发者熟悉的 API 和语法以及四种强大的转换组合让开发者更加喜欢 JAX,并让深度学习场景或者科学计算变得非常简便。篇幅有限,我们将在下篇文章中通过实例为开发者深入介绍 JAX 的工作机制和生态场景。您也可以在 Build On Cloud 视频号观看这一部分的视频演讲: <video src="https://dev-media.amazoncloud.cn/bfecb1dbadc344f0a7d012903863f9f6_%E9%BB%84%E6%B5%A9%E6%96%87JAX1%20_%E6%96%B0%E9%9D%A2%E5%AD%94.mp4" class="bytemdVideo" controls="controls"></video> 欢迎回顾关于[机器学习](https://mp.weixin.qq.com/mp/appmsgalbum?__biz=Mzg5Mzg1NDc2NQ==&action=getalbum&album_id=2674381074877399043#wechat_redirect?trk=cndc-detail)的往期文章,以及更多面向开发者的技术分享。请持续关注 Build On Cloud 微信公众号! ## **往期推荐** - [机器学习洞察 | 挖掘多模态数据机器学习的价值](https://dev.amazoncloud.cn/column/article/63e32a58e5e05b6ff897ca0c) - [机器学习洞察 | 分布式训练让机器学习更加快速准确](https://dev.amazoncloud.cn/column/article/63e32dd06b109935d3b77259) - [机器学习洞察 | 降本增效,无服务器推理是怎么做到的?](https://dev.amazoncloud.cn/column/article/63e33010e5e05b6ff897ca0d) ![Build on cloud.gif](https://dev-media.amazoncloud.cn/c7279db01ec943f0a230b24e3642fff6_Build%20on%20cloud.gif "Build on cloud.gif")
目录
亚马逊云科技解决方案 基于行业客户应用场景及技术领域的解决方案
联系亚马逊云科技专家
亚马逊云科技解决方案
基于行业客户应用场景及技术领域的解决方案
联系专家
0
目录
关闭