この記事はChainer Advent Calendar 2017の23日目の記事です。
僕は普段、Chainerを使って研究開発しています。 このとき、クラスをどう分けるべきかよく悩みます。 いろいろやってみてある程度固まってきたので、自分なりにまとめてみました。
ChainerなどのDeepLearningフレームワークを使う理由は大きく分けて3段階ほどあります。
- 再現実験
- 試行錯誤を伴う実験
- 学習済みモデルを用いたシステムづくり
世の中に転がっているChainerサンプルプログラムは大体(1)のもので、こちらは綺麗にまとまっているものが多いです。 一方で、何か新規に実験していると、どうしても試行錯誤が発生してコードが煩雑になります(2)。 そしてさらには、(1)や(2)で学習したモデルを使ってサービス応用しようとすることもあります(3)。
今回は研究開発用のコード、つまり、サービス応用を考えつつ実験コードを書く際に、どうクラスを切っていくべきか考えをまとめます。
まず、細かいのも含めると、実験コードには次の構成要素があります。
- DataProcess : 入力・出力データの加工する
- Dataset : データをChainer用にまとめる
- Network : 汎用のニューラルネットワーク
- Loss : 損失の取り回し
- Model : ニューラルネットワーク全体
- Updater : モデルの更新(+データの取り回し)
- Trainer : 便利モジュールとの連携
規模や実験内容に応じてDataProcess
はDataset
に、Network
とLoss
はModel
にまとめることもあります。
このうち、学習済みモデルを用いたサービスを作る際に必要なのは、DataProcess
とModel
だけです。
それぞれに関して、なんなのか、なぜそれが必要か、どういうときに必要かを書きます。
DataProcess
入力データや出力データを加工する関数、もしくは呼び出し可能なオブジェクトです。
画像を読み出す、クロップする、線画化する、などなど。
これらのデータ処理は、DatasetMixin
オブジェクトのget_example
メソッドに書くこともできますが、こうしてしまうとあとで流用する際にそのオブジェクトの構造を意識する必要が出てきます。
例えば1枚の画像を加工したいだけでも、DatasetMixin
オブジェクトを作成し、get_example(0)
しなければいけません。
最初からデータを加工する関数を切り出しておけば、後で簡単に流用できます。
Dataset
データをChainer用にまとめるクラスです。DatasetMixin
を継承して作るのが一般的です。
DataProcess
にも書いたとおり、ここに記述した処理は後で流用しづらいので、なるべく簡単なことしか書かないほうが良いと思います。
僕はDataProcess
を1つだけ受け取ってデータ加工するDataset
クラスをよく使っています。
class Dataset(chainer.dataset.DatasetMixin):
def __init__(self, inputs, data_process):
self._inputs = inputs
self._data_process = data_process
def __len__(self):
return len(self._inputs)
def get_example(self, i):
return self._data_process(self._inputs[i])
Network
汎用のネットワークを書きます。簡単なモデルの場合はなくても良いと思います。 僕はよくBatchNormalizationとConvolution2Dをまとめたのを流用しています。
class BNConvolution2D(chainer.link.Chain):
def __init__(self, in_channels, out_channels, ksize, stride=1, pad=0, **kwargs):
super().__init__()
with super().init_scope():
self.conv = chainer.links.Convolution2D(in_channels, out_channels, ksize, stride, pad, nobias=True, **kwargs)
self.bn = chainer.links.BatchNormalization(out_channels)
def __call__(self, x):
return chainer.functions.relu(self.bn(self.conv(x)))
Loss
損失関数を実装します。簡単なモデルの場合はなくても良いと思います。
ChainerのTrainer
とloss周りの扱いはややこしく、chainer.report
を使ったりする必要があります。
Loss
クラスの書き方はchainer.links.Classifierがとても参考になります。
コンストラクタでModel
オブジェクトを受け取って__call__
でフォワードし、得られた出力を元にlossを作ってreturnする設計です。
Loss
クラスが必要になるのはモデルが2種類以上あるときです。DCGANなどのタスクではLoss
クラスを作って、生成器と判別器用のlossを返すと綺麗にコードが書けます。
Model
ニューラルネットワークをまとめたクラスです。
Optimizer
1つにつきModel
1つと考えると理解しやすいです。
chainer.link.Chain
やchainer.link.ChainList
を継承して書くのが一般的です。
Updater
こいつがむちゃくちゃしんどいです。
Model
が1つしかなければchainer.training.StandardUpdater
を使うと大体うまく行きます。
Model
が複数ある場合、StandardUpdater
を継承したUpdater
クラスを自分で定義し、データの流れとモデルの更新を自分で書く必要があります。
DCGANのサンプル実装でちょっと雰囲気がつかめると思います。
Loss
クラスをうまく切り出せてさえいればある程度綺麗に書けます。
Trainer
Chainerが用意した学習用のクラスです。
Updater
やModel
を与えるとよしなに色々やってくれます。
これに関してはいろんな記事があるので説明は割愛します。
これらの方式で実験コードを書くと、ある程度煩雑になってきても大規模な改修は発生しづらくなります。
また、DataProcess
とModel
をライブラリ化すれば、サービス応用も比較的簡単に行なえます。
Chainerは柔軟でいろんなクラス設計が可能です。 試行錯誤を伴う実験をしていてもコードが散らばらないような設計があれば、ぜひ教えてください。 開発方針を自分の中で持っておいて、どんどん研究開発していきたいものです。