JVM上如何进行高效的矩阵相乘

在神经网络中,矩阵相乘往往占据了70%以上的时间。矩阵相乘不仅仅被用在全连接层,在这篇文章中,还介绍了怎样利用矩阵相乘进行卷积运算。可以说矩阵相乘是深度学习的核心。

对于矩阵相乘进行性能优化,可以采用包括使用SIMD指令、多线程和Cache访问优化等方法。

在不同的平台上,业界早已经有了比较成熟的高效矩阵相乘的实现,例如MKL,OpenBLAS,clBLAS,cuBLAS等。

从性能的角度,目前JVM上还没有与以上这些相媲美的实现。所以,一种玩法是通过JNI调用这些native的库。

netlib-java就是这么玩的。它提供一组标准的线性代数运算接口(BLAS, LAPACK, ARPACK),如果本地安装了支持这些接口的native库,它会直接使用这些高效的native库进行计算,否则使用一个JVM的版本进行计算。

目前Spark就是使用netlib-java进行线性代数运算。

矩阵相乘的方法名称是gemm(General Matrix to Matrix Multiplication)。分成双精度(dgemm)和单精度(sgemm)两个版本,这两个版本的参数是一致的,只不过在一些参数类型上是double和float的区别。

这里以mkl为例,介绍在JVM上如何使用netlib-java。MKL 是Intel开发的在x86 CPU上最快和使用最为广泛的数学运算库。假设MKL的安装在/opt/intel下面

sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/libblas.so
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/libblas.so.3
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/liblapack.so
sudo ln -sf /opt/intel/mkl/lib/intel64/libmkl_rt.so /usr/lib/liblapack.so.3
sudo echo -e '/opt/intel/mkl/lib/intel64\n/opt/intel/lib/intel64'>/etc/ld.so.conf.d/libblas.conf
sudo ldconfig
<dependency>
    <groupId>com.github.fommil.netlib</groupId>
    <artifactId>all</artifactId>
    <version>1.1.2</version>
    <type>pom</type>
</dependency>
import com.github.fommil.netlib.BLAS;
import static com.github.fommil.netlib.BLAS.getInstance;
public class BlasDemo {
    public static void main(String[] args) {
        BLAS blas = getInstance();

        // 2 x 3 matrix
        //   1.0 3.0 5.0
        //   2.0 4.0 6.0
        double[] A = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

        // 3 x 2 matrix
        //   6.0 3.0
        //   5.0 2.0
        //   4.0 1.0
        double[] B = {6.0, 5.0, 4.0, 3.0, 2.0, 1.0};

        double[] C = new double[4];

        int m = 2;
        int n = 2;
        int k = 3;
        double alpha = 1.0;
        int lda = 2;
        int ldb = 3;
        double beta = 0.0;
        int ldc = 2;

        blas.dgemm(
            "N", "N",
            m, n, k,
            alpha, A, lda, B, ldb,
            beta, C, ldc
        );

        for(int y = 0; y < 2; y++) {
            for(int x = 0; x < 2; x++) {
                System.out.print(C[x + y * ldc] + " ");
            }
            System.out.println();
        }
    }
}

当第一次看到gemm的参数时,感到有点云里雾里。这里对这几个参数做一些简单的说明:

netlib-java也可以和GPU整合。可以参考这篇文章。简单的说,只要配置环境不需要改代码

sudo apt-get install cublas blas
export LD_LIBRARY_PATH=PATH_TO_CUBLAS/lib64:PATH_TO_SYSTEM_BLAS
export LD_PRELOAD=libnvblas.so