2021年08月05日

Lua で行列計算:LAPACK を使う

 思うところがあって、Lua で行列計算をすることになりました。特異値分解 (singular value decomposition) をやりたい。Lua でコーディングすることもできなくはないだろうけど、車輪の再発明は避けたいので、BLAS/LAPACK をリンクすることにした。

 昔作った C のライブラリがあったので、それを流用する。Lua 上では、LMatrix という型のユーザーデータを定義する。作戦としては以下の通り。

  • 行列の本体は、下のような構造型として実装する。data として実際に確保されるメモリサイズは、「行数×列数×sizeof(double)」とする。
    typedef struct LAMatrix {
      int row, column;
      double data[1];
    } LAMatrix;
  • Lua のユーザーデータは、LAMatrix へのポインタを保持する。ユーザーデータの中身を書き換えたい時、LAMatrix 自体をアロケートし直すことがあるので、LAMatrix 本体をユーザーデータとして扱うのは無理筋。
    typedef struct LMatrixData {
      LAMatrix *m;
    } LMatrixData;
  • LAMatrix から Lua のユーザーデータを作る関数は下の通り。ユーザーデータが確実に LAMatrix を含んでいることを確認するため、LAMatrix へのポインタを内部のテーブルに登録しておく。
    /*  Registry key for metatable  */
    #define LUA_LMATRIX_INTERNAL "LMatrix**"
    static int
    lm_LMatrixNewFromLAMatrix(lua_State *L, LAMatrix *m)
    {
      /* Stk: ... */
      LMatrixData *p = (LMatrixData *)lua_newuserdata(L, sizeof(LMatrixData));
        // 新しいユーザーデータを作成
      if (p == NULL)
        return luaL_error(L, "cannot allocate LMatrixData userdata");
      p->m = m;
      /* Stk: ... p */
      lua_getglobal(L, "LMatrix");
      /* Stk: ... p metatable */
      lua_setmetatable(L, -2);  // ユーザーデータのメタテーブルを LMatrix にする
      /* Stk: ... p */
      lua_pushstring(L, LUA_LMATRIX_INTERNAL);
      lua_rawget(L, LUA_REGISTRYINDEX);
      /* Stk: ... p table */
      /* Register m in the internal table */
      lua_pushlightuserdata(L, p->m); // LAMatrix を内部のテーブルに登録する
      lua_pushboolean(L, 1);
      /* Stk: ... p table m 1 */
      lua_rawset(L, -3);  /*  table[m] = 1  */
      lua_pop(L, 1);
      /* Stk: ... p */
      return 1;
    }
    
  • Lua のユーザーデータが LAMatrix を保持しているかどうかを調べる関数。
    static LMatrixData *
    lm_LMatrixCheckLMatrix(lua_State *L, int stackPos, int throwError)
    {
      LMatrixData *p = NULL;
      if (lua_isuserdata(L, stackPos)) {
        p = (LMatrixData *)lua_touserdata(L, stackPos);  // LMatrixData へのポインタにキャスト
        /* Stk: ...  */
        lua_pushstring(L, LUA_LMATRIX_INTERNAL);  // LAMatrix は内部のテーブルに登録されているか?
        lua_rawget(L, LUA_REGISTRYINDEX);
        /* Stk: ... table */
        lua_pushlightuserdata(L, p->m);
        lua_rawget(L, -2);
        /* Stk: ... table isregistered? */
        if (!lua_toboolean(L, -1))  // 登録されていれば OK, いなければ NG
          p = NULL;
        lua_pop(L, 2);
        /* Stk: ... */
      }
      if (throwError) {
        luaL_argcheck(L, p != NULL, stackPos, "expecting LMatrix");
      }
      return p;
    }
    
  • ユーザーデータをガーベジとして回収するときは、ポインタを内部のテーブルから削除してから、LAMatrix を解放する。
    static int
    lm_LMatrixRelease(lua_State *L)
    {
      /* Stk: self */
      LMatrixData *p = lm_LMatrixCheckLMatrix(L, 1, false);
      if (p != NULL) {
        /*  Unregister m from the internal table  */
        lua_pushstring(L, LUA_LMATRIX_INTERNAL);  // 内部のテーブルから p->m を削除
        lua_rawget(L, LUA_REGISTRYINDEX);
        /* Stk: self table */
        lua_pushlightuserdata(L, p->m);
        lua_pushnil(L);
        /* Stk: self table m nil */
        lua_rawset(L, -3);  /*  table[m] = nil  */
        /* Stk: self table */
        lua_pop(L, 1);
        LAMatrixRelease(p->m);  // LAMatrix のメモリを解放
        p->m = NULL;
      }
      return 0;
    }
    
  • LMatrix という Lua ライブラリを実装する。
    int luaopen_LMatrix(lua_State *L)
    {
      static const luaL_Reg LMatrixMap[] = {
        {"new", lm_LMatrixNew},
        // ... 中略 ...
        {"__gc", lm_LMatrixRelease},
        {NULL, NULL}
      };  /* LAMatrixMap */
      
      /*  Create an empty table for recording LMatrix pointers  */
      //  内部用のテーブルを初期化
      /* Stk: */
      lua_pushstring(L, LUA_LMATRIX_INTERNAL);
      lua_newtable(L);
      /* Stk: LUA_LMATRIX_INTERNAL newtable */
      lua_rawset(L, LUA_REGISTRYINDEX);
      /* Stk: */
      
    #if LUA_VERSION_NUM >= 502
      luaL_newlib(L, LMatrixMap);
    #else
      luaL_register(L, "LMatrix", LMatrixMap);
    #endif
      
      //  LMatrix テーブルをメタテーブルとして使えるように設定
      /* Stk: table  */
      lua_pushstring(L, "__index");
      /* Stk: table "__index" */
      lua_pushvalue(L, -2);
      /* Stk: table "__index" table */
      lua_settable(L, -3); /* table.__index = table */
      /* Stk: table */
      
      //  LMatrix(n, m) が LMatrix:new(n, m) に読み替えられるように
      //  LMatrix 自身にもメタテーブルを設定
      /* set metatable for LMatrix (to implement "__call" metamethod)  */
      lua_newtable(L);
      /* Stk: table newtable */
      lua_pushstring(L, "__call");
      lua_pushcfunction(L, lm_LMatrixNew);
      lua_settable(L, -3); /* newtable.__call = (lm_LMatrixNew) */
      /* Stk: table newtable */
      lua_setmetatable(L, -2);
      /* Stk: table */
      
      return 1;
    }
    

 コンパイルして LMatrix.dylib または LMatrix.dll を作ればよい。ここで少し面倒なのは、BLAS/LAPACK をどうやってリンクするか。MacOS の場合は、Accelerate.framework の中に clapack があるけど、Windows の場合は標準ライブラリには入ってない(よね?)ので、自前で用意しないといけない。CLAPACK のビルドは結構面倒なので、LAPACK を gfortran でビルドしてリンクする方が簡単。下のようなコマンドラインを使う。../lapack-3.10.0/libw64 に LAPACK と REFBLAS が入っている。また、../LuaAppMaker/build-win/build/liblua51.dll が入っている。-static-Wl,-Bdynamic,-llua51 の組み合わせが要注意。なるべくスタティックリンクしたいのだけど、lua51.dll だけはダイナミックリンクが必須なのです。

$ x86_64-w64-mingw32-gcc -I../LuaAppMaker/LuaJIT-2.0.5/src \
  -Wno-discarded-qualifiers -Wall -O2 -c -o luamatrix_w64.o luamatrix.c
$ x86_64-w64-mingw32-gcc -shared -fPIC -o LMatrix.dll luamatrix_w64.o \
  -L../LuaAppMaker/build-win/build/lib -L../lapack-3.10.0/libw64 \
  -static -llapack -lrefblas -lm -lgfortran -lquadmath -lpthread \
  -Wl,-Bdynamic,-llua51

 下のようなことができるようになりました。一応目標達成かな。

% require "LMatrix"
% a = LMatrix{{1,2,3},{4,5,6}}
% u,s,v = a:svd()  # (NxM)行列の特異値分解
% u   # 左直交行列 (NxN)
-->LMatrix{{-0.42866713354863,-0.56630691884804,-0.70394670414744},
  {0.8059639085893,0.11238241409659,-0.58119908039611},
  {0.40824829046386,-0.81649658092773,0.40824829046386}}
% s   # 特異値
-->LMatrix{{9.5080320006957,0.77286963567348}}
% v   # 右直交行列 (MxM)
-->LMatrix{{-0.38631770311861,-0.92236578007706},
  {-0.92236578007706,0.38631770311861}}
% ss = LMatrix{{9.5080320006957,0,0}, {0,0.77286963567348,0}}
% ss  # 特異値を対角要素に持つ(NxM)行列
-->LMatrix{{9.5080320006957,0,0},
  {0,0.77286963567348,0}}
% c = u:multiply(ss, v)  # u x ss x v は a になるはず
% c
-->LMatrix{{1,2,3},
  {4,5,6}}
Posted at 2021年08月05日 20:24:28
email.png