python の高速化ツール numba を使う

Python3の64bit環境へ移行してプログラムの高速化への準備を整える

まずは、python2.7で使っていたプログラムをpython3へと移行する。
これはディストリビューションに付属していたツール「2to3」を使って変換する。このプログラムの
オプションがよくわかなかったので何度かつまずいたがとりあえず変換に成功。

次に高速化にトライ。いろいろと見て回ったうちでもっとも簡単に
高速化できるツールがnumbaではないか?

Python: Numbaによる高速化

バージョンがまだ1.0になっていないようだけどなにしろやり方が簡単だ。関数の前に一行加えるだけで高速化できる。これは簡単さっそく上記のサイトのプログラムを書き換えてトライ。

import numpy as np
from numba.decorators import jit,autojit

def pairwise_python(X, D):
  M = X.shape[0]
  N = X.shape[1]
  for i in range(M):
    for j in range(M):
      d = 0.0
      for k in range(N):
        tmp = X[i, k] - X[j, k]
        d += tmp * tmp
      D[i, j] = np.sqrt(d)

@jit
def pairwise_numba(X, D):
  M = X.shape[0]
  N = X.shape[1]
  for i in range(M):
  for j in range(M):
  d = 0.0
      for k in range(N):
        tmp = X[i, k] - X[j, k]
        d += tmp * tmp
      D[i, j] = np.sqrt(d)

@autojit
def pairwise_numba_auto(X, D):
  M = X.shape[0]
  N = X.shape[1]
  for i in range(M):
    for j in range(M):
      d = 0.0
      for k in range(N):
        tmp = X[i, k] - X[j, k]
        d += tmp * tmp
      D[i, j] = np.sqrt(d)

@jit("void(f8[:,:],f8[:,:])")
def pairwise_numba_jit2(X, D):
  M = X.shape[0]
  N = X.shape[1]
  for i in range(M):
    for j in range(M):
      d = 0.0
      for k in range(N):
        tmp = X[i, k] - X[j, k]
        d += tmp * tmp
      D[i, j] = np.sqrt(d)

X = np.random.random((1000, 3))
D = np.empty((1000, 1000))

#%timeit pairwise_python(X, D)
#%timeit pairwise_numba(X, D)
#%timeit pairwise_numba_auto(X, D)
#%timeit pairwise_numba_jit2(X, D)


このブログラムで実行時間は
1 loops, best of 3: 7.33 s per loop
1 loops, best of 3: 12.5 ms per loop
1 loops, best of 3: 12.9 ms per loop
100 loops, best of 3: 12.8 ms per loop

いきなり100倍ぐらいになった!しかし、@autojitと型指定した場合の効果はなかった。
通常の@jitオプションでもコンパイル結果に差がない場合は、スピードは変わらないということか。

上記記事の、np.sum関数を中で使うと速度が低下するのはよくわからないが、np.関数によってはnumbaの高速化できないものがあるということか?

numbaでサポートするnumpy関数。
Numpy Support in numba
ここをみるとsqrtは入っているがsumは入ってないね。つまりサポートしていないということかも。

ここでも、numbaで高速化できてない例が挙げられているが、np.maxは上記のサポートの中に入っていないみたいだ。なので、
for i in arr:
  if i > MAX:
    MAX = i
という手順でmax関数を代用している回答が寄せられている。
http://stackoverflow.com/questions/20679380/optimizing-access-on-numpy-arrays-for-numba 

numbaでサポートしているかどうかチェックして使う必要がありそうだ


この記事へのコメント


この記事へのトラックバック