본문 바로가기
AI/메타 러닝

[Image Classification] One-Shot Learning

by 박서현 2022. 12. 23.

One-Shot Learning

이미지 분류 문제에서는 Few-Shot Learning을 K-way N-shot Learning이라고도 합니다. way는 학습 데이터로 사용한 클래스의 수입니다. 예를 들어 남자와 여자 사진을 하나씩 갖고 학습하는 것은 2-way 1-shot Learning입니다.

클래스별로 학습 데이터가 1개씩만 있는 경우 One-Shot Learning이라고 합니다. 이미지 분류에서는 클래스별로 이미지 한 장씩 있는 경우입니다. One-Shot Learning이 작동하는 방식은 클래스 수 K만큼의 이미지로 데이터 베이스를 구성한 다음, 새로운 이미지가 주어졌을 때 데이터 베이스에서 가장 유사한 이미지의 클래스로 분류하는 것입니다. 따라서 우리는 두 이미지 사이에 유사도를 계산하면 됩니다.

Train Siamese Network

One-Shot Learning에서는 유사도를 구하기 위해서 별도의 Neural Network를 이용합니다. 이를 Siamese Network라고 하는데 학습 방식은 다음과 같습니다.

호랑이 이미지 2개(x1, x2)를 CNN(f)에 입력해 Feature Vector (h1, h2)를 얻습니다. 그리고 두 Feature Vector 차의 절대값(z)을 구합니다. 

z를 Dense Layers에 한번 더 입력해 나온 값에 Sigmoid 함수를 취하면 우리가 구하고자 했던 두 이미지 사이의 유사도sim(x1, x2)를 구할 수 있습니다. sim(x1, x2)는 Sigmoid 함수값이기 때문에 0과 1사이의 값을 가집니다.

x1과 x2는 같은 클래스(호랑이)이기 때문에 유사도는 1이여야 합니다. 따라서 Target 값 1과 sim(x1, x2)와 1 사이에 Cross Entropy Loss 값을 구할 수 있습니다.

Loss 값으로 Update하는 것은 CNN(f)와 Dense Layers입니다.

두 이미지가 다른 클래스라면 Target 값을 0으로 두고 Loss 값을 계산하면 됩니다.

One-Shot Prediction

Siamese Network를 활용해서 6-way 1-shot preciction을 어떻게 하는지 살펴보겠습니다. 6-way 1-shot이니깐 이미지 분류를 하기 위해 필요한 이미지 데이터는 6개입니다. 이렇게 미리 준비하는 데이터셋을 Support Set이라고 합니다. 위 이미지를 보면 여우, 다람쥐, 토끼, 햄스터, 수달, 비버를 분류하는 문제를 풀 수 있도록 Support Set을 구성했습니다. 참고로 Siamese Network 학습은 Support Set과 관계없이 진행합니다. 즉, Siamese Network를 학습하는 데이터는 Support Set의 클래스를 포함하지 않을 수 있습니다. 

클래스를 분류하고 싶은 데이터를 Query라고 합니다. Query가 주어지면 아래와 같이 Query와 Support Set 사이에 유사도를 구합니다.

  • sim(Query, Fox) = 0.2
  • sim(Query, Squirrel) = 0.9
  • sim(Query, Rabbit) = 0.7
  • sim(Query, Hamster) = 0.5
  • sim(Query, Otter) = 0.3
  • sim(Query, Beaver) = 0.4

Support Set 중 Query와 유사도가 가장 높은 이미지는 Squirrel이기 때문에, 최종적으로 "Query는 Squirrel이다"라고 예측합니다.

Train Siamese Network with Triplet Loss

Siamese Network를 학습할 수 있는 다른 방법이 있습니다. 바로 Triplet Loss를 활용하는 것입니다. Triplet Loss를 이용하기 위해서는 데이터를 3개씩 묶어서 사용해야 합니다. 

데이터 구성

위와 같이 Siamese Network를 학습할 데이터셋이 있으면, 학습을 위해서는 아래와 같이 재구성해야 합니다.

  1. Anchor: 랜덤하게 하나의 데이터를 선택합니다.
  2. Positive: Anchor와 같은 클래스의 데이터 중 무작위로 하나를 선택합니다.
  3. Negative: Anchor와 다른 클래스 중 무작위로 하나를 선택합니다.

학습

Anchor와 Positive는 호랑이 이미지, Negative는 코끼리 이미지로 구성했습니다. Triplet Loss는 같은 클래스인 Positive와 Anchor는 유사도가 높게, 다른 클래스인 Negative와 Anchor는 유사도가 낮게 나오게끔 하는 것이 목적입니다.

위에서 언급했던 것처럼 유사도를 구하기 위해서는 먼저 이미지를 벡터로 만들어야 합니다. Siamese Network가 바로 이 기능을 합니다. 세 이미지를 Siamese Network(f)에 입력해 Feature Vector(f(x+), f(xa), f(x-)를 계산합니다. 유사도 계산은 L2 Distance를 활용합니다. 위의 오른쪽 이미지에는 Feature Space 상에서 f(x+), f(xa) 사이의 거리 d+와 f(x-), f(xa) 사이의 거리 d-를 직관적으로 표현했습니다.

L2 Distance는 "거리" 개념이기 때문에 두 벡터가 유사할 수록 거리 값이 작아야 하고, 다를 수록 거리 값이 커야 합니다. 따라서 d+는 d-보다 작아야 합니다. Triplet Loss는 이러한 사실을 이용합니다. d+가 d-보다 작으면 수정해야할 것이 없기 때문에 이 경우 Loss는 0입니다. 그리고 d+가 d-보다 작으면 그 차이만큼 Loss로 사용합니다.

  • if d- >= d+, then loss = 0
  • otherwise, loss = (d+) - (d-)

그런데 d+가 0.5, d-가 0.5001이면 "d+가 d-보다 작으니 문제없어" 라고 말할 순 없습니다. 이런 경우 d+가 d-보다 적어도 α만큼은 더 작아야 한다고 조건을 달면 해결할 수 있습니다. 예를들어 d+가 d-보다 0.1보다 더 작아야 한다면, 0.5와 0.5001은 Loss값을 가져야 합니다. 이때 α는 margin이라고 합니다.

  • Loss = max(0, (d+) + α - (d-))
    • if d- >= (d+) + α, then loss = 0
    • otherwise, loss = (d+) + α - (d-)

학습 데이터 구성 시 유의사항

위 데이터 구성에서 Anchor, Positive, Negative 데이터를 선정할 때 무작위로 추출했습니다. 하지만 호랑이와 코끼리 데이터를 보면 알 수 있듯이 무작위로 추출한 데이터는 Positive와 Negative가 구분하기 너무 쉬운 경우가 많습니다. 이 경우 당연히 d-가 d+보다 훨씬 작을 것입니다. 효율적이고 효과적으로 Siamese Network를 학습하기 위해서는 Anchor와 구분하기 어려운 Negative 데이터를 선택해야 합니다. https://www.youtube.com/watch?v=d2XB5-tuCWU를 참고하면 구분하기 어려운 데이터를 선택하는 방법은 2015년에 발표된 FaceNet 논문에 잘 설명되어 있다고 합니다. 나중에 기회되면 이 내용도 공부해서 포스팅하겠습니다.

One-Shot Prediction

코사인 유사도를 이용할 때와 Prediction과정은 동일합니다. 다만 Query 이미지와 Support Set 사이에 거리(dist)를 계산하기 때문에 값이 가장 작은 이미지의 클래스(Squirrel)로 최종 예측합니다.

 

참고 : https://www.youtube.com/watch?v=4S-XDefSjTM

'AI > 메타 러닝' 카테고리의 다른 글

[Image Classification] Fine Tune Few-Shot Learning  (0) 2023.01.05
Transfer Learning VS Few-Shot Learning  (0) 2022.12.22
메타 러닝 (Meta Learning)  (0) 2022.12.19