chainerでモデルを入れ子にしたら重みが更新されなかった話
概要
chainerのmodel(Chainクラス)を入れ子にして使っていたら重みが更新されなかった.
Chainクラスで重みの更新がされるのは self.init_scope()
内に書いている link
オブジェクトだけだったことが判明し,
with self.init_scope():
以下に書くとちゃんと更新された.
状況
version
chainer==3.0.0
やりたかったこと
あるmodelA
に layerN
を追加して,新たに modelB
を作成したかった.
だめなコード
計算グラフを出力すると,ちゃんとmodelA -> layerN という風に接続されていたので,これでうまく接続されているものだと思っていた.
が,実際に学習中に都度重みを出力してみると,modelA内の重み(l1, l2, l3の重み)が全く更新されていないことがわかった.
# example/train_mnist.pyから拝借 class modelA(chainer.Chain): def __init__(self, n_units, n_out): super(modelA, self).__init__() with self.init_scope(): self.l1 = L.Linear(None, n_units) # n_in -> n_units self.l2 = L.Linear(None, n_units) # n_units -> n_units self.l3 = L.Linear(None, n_out) # n_units -> n_out def __call__(self, x): h1 = F.relu(self.l1(x)) h2 = F.relu(self.l2(h1)) return self.l3(h2) class modelB(chainer.Chain): def __init__(self, n_out, modelA): super(modelB, self).__init__() self.modelA = modelA with self.init_scope(): self.layerN = L.Linear(None, n_out) def __call__(self, x): h1 = F.relu(self.modelA(x)) h2 = self.layerN(h1) return h2
よいコード
まあちゃんとドキュメント見ればそれっぽいことは書いてあるんだが,まったく気づかなかった..
init_scope内に書くと,context managerとやらに登録されるらしい.
chainer.Chain — Chainer 3.0.0 documentation
# example/train_mnist.pyから拝借 class modelA(chainer.Chain): def __init__(self, n_units, n_out): super(modelA, self).__init__() with self.init_scope(): self.l1 = L.Linear(None, n_units) # n_in -> n_units self.l2 = L.Linear(None, n_units) # n_units -> n_units self.l3 = L.Linear(None, n_out) # n_units -> n_out def __call__(self, x): h1 = F.relu(self.l1(x)) h2 = F.relu(self.l2(h1)) return self.l3(h2) class modelB(chainer.Chain): def __init__(self, n_out, modelA): super(modelB, self).__init__() # self.modelA = modelA with self.init_scope(): self.modelA = modelA # ここに書くのが正解 self.layerN = L.Linear(None, n_out) def __call__(self, x): h1 = F.relu(self.modelA(x)) h2 = self.layerN(h1) return h2
まとめ
重み更新したい link
オブジェクトは init_scope
内に書きましょう.
逆に,fine-tuningとかで重みを更新したくない場合は, init_scope
内に書かなければ更新されないようなので,便利だなーと思った.
ドキュメントはちゃんと読みましょう.