TesorFlow: Pythonで学習したデータをAndroidで実行

1 Star2 Stars3 Stars4 Stars5 Stars (まだ評価されていません)
Loading...

TensorFlowのチュートリアル(手書き画像データの画像認識)を元に、DeepLearningのネットワークデータを書き出し、Androidで手書き認識をするデモを作成してみました。

tensorflow_mnist_screen0.png

学習データの書き出し

TensorFlowのチュートリアルの”MNIST For ML Beginners” のモデルを元にまずは、学習データをPythonにてPC上で書き出します。

“MNIST For ML Beginners”
https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html

ここのチュートリアルを元に、グラフデータ書き出し用にスクリプトを変更しました。

https://github.com/miyosuda/TensorFlowAndroidMNIST/blob/master/trainer-script/beginner.py

ネットワークデータを書き出すには、グラフ情報とVariableに入ったテンソルデータ(学習された内容)をいっしょにして書き出す必要があるのですが、現時点ではTensorFlowは、グラフ情報とVariableをいっしょに保存することができない様でした。

なので、学習後に、Viriablesの中身を評価して、一度ndarrayに変換し、

# Store variable
_W = W.eval(sess)
_b = b.eval(sess)

ndarrayをConstantに変換し、Variablesの代わりとして利用しグラフを再構成することでグラフと学習データをまとめて書き出しました。

# グラフを再生成
g_2 = tf.Graph()
with g_2.as_default():
# 入力部分に"input"と名前付けしておく
x_2 = tf.placeholder("float", [None, 784], name="input")
# VariablesをConstantで置き換え
W_2 = tf.constant(_W, name="constant_W")
b_2 = tf.constant(_b, name="constant_b")
# 出力部分に"output"と名前付けしておく
y_2 = tf.nn.softmax(tf.matmul(x_2, W_2) + b_2, name="output")
sess_2 = tf.Session()
init_2 = tf.initialize_all_variables()
sess_2.run(init_2)
# グラフを書き出し
graph_def = g_2.as_graph_def()
tf.train.write_graph(graph_def, './tmp/beginner-export',
'beginner-graph.pb', as_text=False)

Android側で呼び出すために、入力部分と出力部分のノードにそれぞれ、”input”, “output”という名前をつけておきました。

このモデルの学習と書き出しは数秒ですみました。

Android側

TensorFlowにもともと入っているAndroidデモは、Bazel環境でしかビルドできなかったので、AndroidStudioとNDKだけでAndroidアプリがビルドできる環境を作りました。

TensorFlowのAndroidサンプルをBazelでビルドした時にできるライブラリファイル(.aファイル)を保存しておき、NDKだけでビルドできる様にしました

https://github.com/miyosuda/TensorFlowAndroidMNIST/tree/master/jni-build/jni

Android.mkはこの様になりました。

Makefile

LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
TENSORFLOW_CFLAGS     := -frtti \
-fstack-protector-strong \
-fpic \
-ffunction-sections \
-funwind-tables \
-no-canonical-prefixes \
'-march=armv7-a' \
'-mfpu=vfpv3-d16' \
'-mfloat-abi=softfp' \
'-std=c++11' '-mfpu=neon' -O2 \
TENSORFLOW_SRC_FILES := ./tensorflow_jni.cc \
./jni_utils.cc \
LOCAL_MODULE    := tensorflow_mnist
LOCAL_ARM_MODE  := arm
LOCAL_SRC_FILES := $(TENSORFLOW_SRC_FILES)
LOCAL_CFLAGS    := $(TENSORFLOW_CFLAGS)
LOCAL_LDLIBS    := \
-Wl,-whole-archive \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libandroid_tensorflow_lib.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libre2.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotos_all_cc.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotobuf.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotobuf_lite.a \
-Wl,-no-whole-archive \
$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(TARGET_ARCH_ABI)/libgnustl_static.a \
$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(TARGET_ARCH_ABI)/libsupc++.a \
-llog -landroid -lm -ljnigraphics -pthread -no-canonical-prefixes '-march=armv7-a' -Wl,--fix-cortex-a8 -Wl,-S \
LOCAL_C_INCLUDES += $(LOCAL_PATH)/include $(LOCAL_PATH)/genfiles $(LOCAL_PATH)/include/third_party/eigen3
NDK_MODULE_PATH := $(call my-dir)
include $(BUILD_SHARED_LIBRARY)

コンパイラオプションや、リンカオプションを上記の様に設定しないと、ビルドが通っても、学習データ(ProtocolBufferデータ)が正しく読み込めませんでした。

Java側で28×28の手書きのピクセルデータを用意して、JNIでc++側に渡し、グラフデータを元に作成したグラフに入力します。

https://github.com/miyosuda/TensorFlowAndroidMNIST/blob/master/jni-build/jni/tensorflow_jni.cc

tensorflow_mnist_screen0.png

無事に認識できました。

学習データの置き換え

上記のモデルだと、認識率が91%くらいなので、TensorFlowの”Deep MNIST for Experts”にあるDeep Learningをつかったモデル(認識率99.2%)に置き換えてみました。

“Deep MNIST for Experts”
https://www.tensorflow.org/versions/master/tutorials/mnist/pros/index.html

学習データ書き出し用スクリプト
https://github.com/miyosuda/TensorFlowAndroidMNIST/blob/master/trainer-script/expert.py

DropOutのノードが入っていると、Android側での実行時に何故かエラーが出てしまったのと、元々DropOutノードは学習時にしか必要ないものなので、書き出し時にはグラフから外しました。

自分の環境(MacBook Pro)で学習には1時間くらいかかりました。

inputノードとoutputノードの名前を同じにしておいたので、書き出し後は、学習データを差し替えるだけで、Android側のコードはすべてそのままで実行できます。

手書き認識を試してみたところ、かなり正確に0〜9の数字を認識するのが確認できました。

ソース

上記のソース一式ははこちら
https://github.com/miyosuda/TensorFlowAndroidMNIST


1 Star2 Stars3 Stars4 Stars5 Stars (まだ評価されていません)
Loading...
      この投稿は審査処理中  | 元のサイトへ