chainerでモデルを入れ子にしたら重みが更新されなかった話

概要

chainerのmodel(Chainクラス)を入れ子にして使っていたら重みが更新されなかった.
Chainクラスで重みの更新がされるのは self.init_scope()内に書いている linkオブジェクトだけだったことが判明し,
with self.init_scope():以下に書くとちゃんと更新された.

状況

version

chainer==3.0.0

やりたかったこと

あるmodelAlayerNを追加して,新たに modelBを作成したかった.

だめなコード

計算グラフを出力すると,ちゃんとmodelA -> layerN という風に接続されていたので,これでうまく接続されているものだと思っていた.
が,実際に学習中に都度重みを出力してみると,modelA内の重み(l1, l2, l3の重み)が全く更新されていないことがわかった.

# example/train_mnist.pyから拝借
class modelA(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(modelA, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

class modelB(chainer.Chain):
    def __init__(self, n_out, modelA):
        super(modelB, self).__init__()
        self.modelA = modelA
        with self.init_scope():
            self.layerN = L.Linear(None, n_out)
    
    def __call__(self, x):
        h1 = F.relu(self.modelA(x))
        h2 = self.layerN(h1)
        return h2

よいコード

まあちゃんとドキュメント見ればそれっぽいことは書いてあるんだが,まったく気づかなかった..
init_scope内に書くと,context managerとやらに登録されるらしい.
chainer.Chain — Chainer 3.0.0 documentation

# example/train_mnist.pyから拝借
class modelA(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(modelA, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

class modelB(chainer.Chain):
    def __init__(self, n_out, modelA):
        super(modelB, self).__init__()
        # self.modelA = modelA
        with self.init_scope():
            self.modelA = modelA # ここに書くのが正解
            self.layerN = L.Linear(None, n_out)
    
    def __call__(self, x):
        h1 = F.relu(self.modelA(x))
        h2 = self.layerN(h1)
        return h2

まとめ

重み更新したい linkオブジェクトは init_scope内に書きましょう.
逆に,fine-tuningとかで重みを更新したくない場合は, init_scope内に書かなければ更新されないようなので,便利だなーと思った.

ドキュメントはちゃんと読みましょう.