Few Shot Learning – Siamese Network

https://miro.medium.com/max/1200/0*AuZLKEVsQpsTD3Dt

Original Source Here

Few Shot Learning – Siamese Network

透過few shot learning 來打造生物識別(人臉、聲紋、手寫…)模型!

Photo by Green Chameleon on Unsplash

由於疫情,近期都很少出門宅在家,就來記錄一下最近有用到的技術吧~

一般的分類問題往往都是屬於類別不多且每個類別資料量很多任務,比如MNIST 手寫資料集或 imagenet 的影像分類問題。但在生物識別的task 上我們往往沒辦法收集到那麼多資料,比如說我要建立一個人臉識別的模型,我應該不太可能跟每個我要識別的人都收集大量的照片,況且世界上的人那麽多,我也不大可能收集所有人的照片來建立分類模型。

Few shot learning 算是 meta-learning 的其中一塊,核心概念是讓模型學會學習(learn to learn)。這樣說有點懸,我們可以把它理解成: few shot learning 是要讓模型學會區分事物的差異。一個學會區分事物差異的模型,我們可以把它用在訓練集從未見過的新類別,並且可以只透過很少的樣本(few shot) 就學會區別此事物。

Siamese Network

Siamese 這個詞是孿生、連體嬰的意思,表示兩個人身體相連且共享部分的器官。而siamese network 是只有兩個架構權重都相同的類神經網路組合在一起(如下右圖)

可以看到這個網路的input 是一個image pairs,而我們的目標是要訓練一個能夠區分事物差異的網路。想必聰明的你已經想到要如何使用這個網路結構了!

首先我們要準備很多positive samples 以及 negative samples,分別表示相同類別的 image pairs 以及不同類別的 image pairs:

https://www.youtube.com/watch?v=UkQ2FVpDxHg

而 siamese network 就是要預測 input 的 image pairs 是否為相同類別,所以說就是一個binary classification 的問題!

實際要搭建siamese network 也非常的簡單,以pytorch 做一個範例:

class siameseNet(nn.Module):
def __init__(self, embedding_net):
super(siameseNet, self).__init__()
self.embedding_net = embedding_net
def forward(self, x1, x2):
output1 = self.embedding_net(x1)
output2 = self.embedding_net(x2)
return output1, output2
def get_embedding(self, x):
return self.embedding_net(x)

其中 embedding_net 是任何你自己搭建的CNN 網路,最後再把output1, output2 喂給 loss function backward 即可。

預測時可以調用 get_embedding() 就不用每次都 forward 兩張image了~

Contrastive loss

目標: 使相同類別的embedding 越接近越好,不同類別的embedding 越遠越好,用這個觀點來看下面的式子就會非常直觀了

Dw 表示兩embedding 之距離(歐式距離)
class ContrastiveLoss(nn.Module):   def __init__(self, margin):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.eps = 1e-9
def forward(self, output1, output2, target, size_average=True):
distances = (output2 - output1).pow(2).sum(1) # squared distances
losses = 0.5 * (target.float() * distances +
(1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
return losses.mean() if size_average else losses.sum()

Triplet loss

Triplet Loss是Google 在 2015 年發表的 FaceNet 論文中提出。可視為Contrastive loss 的改良。

triplet loss 必須建構在三元的image pair 下才能計算,搭配的網路架構如下

可以看到triplet loss 的做法直接喂給模型一個positive 以及一個 negative sample 來訓練,目標一樣是期望positive 能越接近anchor 而 negative 能越遠離anchor:

triplet loss
class TripletLoss(nn.Module):   def __init__(self, margin):
super(TripletLoss, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative, size_average=True):
distance_positive = (anchor - positive).pow(2).sum(1)
distance_negative = (anchor - negative).pow(2).sum(1)
losses = F.relu(distance_positive - distance_negative + self.margin)
return losses.mean() if size_average else losses.sum()

要搭建triplet 的 network 其實和siamese 幾乎一模一樣,只是input, output 變成三個而已

class TripletNet(nn.Module):
def __init__(self, embedding_net):
super(TripletNet, self).__init__()
self.embedding_net = embedding_net
def forward(self, x1, x2, x3):
output1 = self.embedding_net(x1)
output2 = self.embedding_net(x2)
output3 = self.embedding_net(x3)
return output1, output2, output3
def get_embedding(self, x):
return self.embedding_net(x)

AI/ML

Trending AI/ML Article Identified & Digested via Granola by Ramsey Elbasheer; a Machine-Driven RSS Bot

%d bloggers like this: