JVM上如何进行高效的矩阵相乘
02 Oct 2015在神经网络中,矩阵相乘往往占据了70%以上的时间。矩阵相乘不仅仅被用在全连接层,在这篇文章中,还介绍了怎样利用矩阵相乘进行卷积运算。可以说矩阵相乘是深度学习的核心。
对于矩阵相乘进行性能优化,可以采用包括使用SIMD指令、多线程和Cache访问优化等方法。
在不同的平台上,业界早已经有了比较成熟的高效矩阵相乘的实现,例如MKL,OpenBLAS,clBLAS,cuBLAS等。
从性能的角度,目前JVM上还没有与以上这些相媲美的实现。所以,一种玩法是通过JNI调用这些native的库。
netlib-java就是这么玩的。它提供一组标准的线性代数运算接口(BLAS, LAPACK, ARPACK),如果本地安装了支持这些接口的native库,它会直接使用这些高效的native库进行计算,否则使用一个JVM的版本进行计算。
目前Spark就是使用netlib-java进行线性代数运算。
矩阵相乘的方法名称是gemm(General Matrix to Matrix Multiplication)。分成双精度(dgemm)和单精度(sgemm)两个版本,这两个版本的参数是一致的,只不过在一些参数类型上是double和float的区别。
这里以mkl为例,介绍在JVM上如何使用netlib-java。MKL 是Intel开发的在x86 CPU上最快和使用最为广泛的数学运算库。假设MKL的安装在/opt/intel
下面
- netlib-java会寻找blas和lapack的动态库,所以要在/usr/lib/下面创建名为libblas.so和liblapack.so的软连接。
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/libblas.so
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/libblas.so.3
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/liblapack.so
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/liblapack.so.3
- 将mkl的lib路径加入ld.conf,使得操作系统载入mkl的动态库
sudo echo -e '/opt/intel/mkl/lib/intel64\n/opt/intel/lib/intel64'>/etc/ld.so.conf.d/libblas.conf
sudo ldconfig
- 在maven项目的pom.xml文件中加入netlib-java的依赖
- 代码示例(一个2 x 3的矩阵乘以3 x 2的矩阵)
当第一次看到gemm的参数时,感到有点云里雾里。这里对这几个参数做一些简单的说明:
-
这个API进行的是 C = alpha * A * B + beta * C的运算,A是个m x k的矩阵,B是一个k x n的矩阵。
-
A,B,C的数据放在数组里。这里要注意的是,矩阵默认是以column major(先列后行)的形式存储的。这是因为最早gemm是在Fortran上提出的,Fortran的二维数组是先列后行,这和现在的C和Java都不一样。
-
第一个和第二个字符串参数表示A,B是否需要要转置(“T”)还是不转置(“N”)成为指定形状(m x k,k x n)的矩阵。当为“T”时,相当于矩阵按照row marjor(先行后列)来存了。
-
lda,ldb,ldc是指存储不连续的维度上两个相邻的元素的存储间隔。举个栗子,对于一个m x n的矩阵A,当不转置的时候,由于默认是column-major,数组中的元素会按照列顺序去填,填满一列再填下一列。所以存储不连续的维度是行,这样行上面两个相邻元素实际间隔lda就是一个m。如果是转置的,即A转置后才是m x k的矩阵,数组中的元素先填行,填完一行再填下一行。这样存储不连续的维度是列,lda也就是k了。
netlib-java也可以和GPU整合。可以参考这篇文章。简单的说,只要配置环境不需要改代码
- 安装cublas和blas,在debian / ubuntu上
sudo apt-get install cublas blas
-
创建NVBLAS配置文件,指定矩阵在GPU上如何运算(见示例)
-
修改环境变量
export LD_LIBRARY_PATH=PATH_TO_CUBLAS/lib64:PATH_TO_SYSTEM_BLAS
export LD_PRELOAD=libnvblas.so