verilog書く人

自称ASIC設計者です。どなたかkaggle一緒に出ましょう。

scikit-learn 0.20からクロスバリデーションの使い方が変更される模様

機械学習をやるときに結構良く使うところだなと思っていた、クロスバリデーションのスプリッター系モジュールのインターフェースが変わることに気づいたのでメモ。

 

scikit-learnの従来のクロスバリデーション関係のモジュール(sklearn.cross_vlidation)は、scikit-learn 0.18で既にDeprecationWarningが表示されるようになっており、ver0.20で完全に廃止されると宣言されています。

 

詳しくはこちら↓

Release history — scikit-learn 0.18 documentation

 

 

まず、import元がsklearn.cross_varidationからsklearn.model_selectionに変わります。

 

これによって、例えば従来は

from sklearn.cross_varidation import StratifiedKFold

としていたところを

from sklearn.model_selection import StratifiedKFold

にするようになりました。

 

また、KFold系のクラスは、クラス自体をイテレータとして使うやり方から、splitメソッドを使うやり方になりました。

ver0.17まではn_folds=3の場合、

for train_idx, test_idx in KFold(3, iris.target):
xs_train = iris.data[train_idx]
y_train = iris.target[train_idx]
xs_test= iris.data[test_idx]
y_test = iris.target[test_idx]

としていたところを

for train_idx, test_idx in KFold(n_splits=3).split(iris.data):
xs_train = iris.data[train_idx]
y_train = iris.target[train_idx]
xs_test = iris.data[test_idx]
y_test = iris.target[test_idx]

とすることで、クロスバリデーションのテストデータと学習用データを分割できるようになりました。

今までn_foldsと呼ばれていた変数が、n_splitsという名前になることにも注意してください。

 

StratifiedKFoldでは、

for train_idx, test_idx in StratifiedKFold(3, iris.target):
xs_train = iris.data[train_idx]
y_train = iris.target[train_idx]
xs_test = iris.data[test_idx]
y_test = iris.target[test_idx]

としていたところを

for train_idx, test_idx in StratifiedKFold(n_splits=3).split(iris.data, iris.target):
xs_train = iris.data[train_idx]
y_train = iris.target[train_idx]
xs_test = iris.data[test_idx]
y_test = iris.target[test_idx]

とします。

 

 

今までのsklearn.cross_vlidationはそのうち使えなくなるので気をつけましょう。

今後公開用のスクリプトを書くときは新しい方式で書いたほうがよいと思われます。

 

それから、異なるFoldにおいて同じラベルが表れない

cross_validation.LabelKFold

cross_validation.LabelShuffleSplit

cross_validation.LeaveOneLabelOut

cross_validation.LeavePLabelOut

 

系のスプリッターは、

 

model_selection.GroupKFold

model_selection.GroupShuffleSplit

model_selection.LeaveOneGroupOut

model_selection.LeavePGroupsOut

 

になり、

これらのクラスにおいてlabelsと呼ばれていたパラメータがgroupsに、n_labelsがn_groupsになりました。