kaggle

pytorchでSpecAugmentationの実装!

この記事では、SpecAugmentationの実装を紹介します。

kaggle、研究、インターン先でも使う場面が訪れたので、そろそろ記事にしてみます!

紹介といっても、torchlibrosaを用いるので実装は超簡単です。

SpecAugmentationとは?

SpecAugmentationとは、音におけるデータ拡張手法の一つです。

2019年にGoogle Brainによって提案されました。

SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

複雑なモデルよりも、シンプルなモデル+SpecAugmentationの方が音声認識タスクにおいて良い性能が出たと報告されています!

また、音のImageNetのようなポジションにあるAudio-setの事前学習モデルを作成した、

PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern RecognitionでもSpecAugmentationは適応されています。

この論文の中でもSpecAugmentationは学習に効果があったと報告されています。

 

特徴として、

  • ログメルスペクトログラムに適応!(データセット内で完結できる!)
  • 計算が軽い!
  • 実装が簡単!
  • 汎化性能UP(過学習対策に有効!)

が挙げられます。

使いどころとしては、

これ、どう見ても過学習しているんだよな。でも、表現力的にまだこのモデルで行きたい。

データ拡張したいけど、ノイズとかはもう試したんだよな。ピッチを変更したり生音をいじるのは計算コスト的に億劫だな。

とりあえず、SpecAugmentation適応するか。

という流れです。それくらい簡単に実装できます。

これを用いるリスクは、音イベントが単一の周波数もしくは、瞬間的な音イベントの場合SpecAugmentationを行うことでその部分がマスクされてしまい意味のないデータを生成してしまう可能性があるという点です。

なので、スペクトログラムの一部がマスクされても意味のないデータにならないことが、使用する際の前提条件です。

以下、実装をします。

SpecAugmentationの実装

今回は、PANNsの実装(https://github.com/qiuqiangkong/audioset_tagging_cnn)にならって書きます。

PANNsではモデルの中に、torchlibrosaで実装をしていました。

torchlibrosaはpytorchで音をいじるときに何かと便利なので、最近使い始めました。

インストールもpipでいけます。

pip install torchlibrosa

お勧めです。笑

では、実装です。

from torchlibrosa.augmentation import SpecAugmentation
import torch
logmel = torch.tensor(mel_norm).unsqueeze(0).unsqueeze(0)
print(f"mel_norm:{mel_norm.shape}, logmel:{logmel.shape}")
spec_augmenter = SpecAugmentation(
            time_drop_width=32,
            time_stripes_num=2,
            freq_drop_width=32,
            freq_stripes_num=4,
        )
augmented = spec_augmenter(logmel)
augmented = augmented.squeeze(0).squeeze(0).numpy()
plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)
plt.imshow(mel_norm, aspect="auto")
plt.xlabel("frequency", size=16)
plt.ylabel("time", size=16)
plt.title("original", size=16)
plt.subplot(1, 2, 2)
plt.imshow(augmented, aspect="auto")
plt.colorbar()
plt.xlabel("frequency", size=16)
plt.ylabel("time", size=16)
plt.title("SpecAug", size=16)
plt.savefig('spec_aug.png')

# mel_norm:(128, 627), logmel:torch.Size([1, 1, 128, 627])

 

注意点としては、spec_augmenterが受け付けるデータがtorch.tensorの4dimだということです。

これは、入力がx : (batch_size, channel, time, freq) を前提としているからです。

実装は以上です!

最後に

たったこれだけの実装でモデルの性能を爆上げすることができます!

SpecAugmentation恐るべしです。

音を扱っていて、過学習が気になったらその対策候補の一つに入れてみるのもいいかもですね!

では!

 

オススメのプログラミングスクールをご紹介

タイピングもままならない完全にプログラミング初心者から

アホいぶきんぐ
アホいぶきんぐ
プログラミングってどこの国の言語なの~?

たった二ヶ月で

いぶきんぐ
いぶきんぐ
え!?人工知能めっちゃ簡単にできるじゃん!

応用も簡単にできる…!!

という状態になるまで、一気に成長させてくれたオススメのプログラミングスクールをご紹介します!

テックアカデミーのPython+AIコースを受講した僕が本音のレビュー・割引あり! というプログラミング完全初心者だった僕が Tech Academy(テックアカデミー)のPython×AIコース を二ヶ月間...

COMMENT

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です