Focal loss란 무엇일까?

Object detection에서는 class imbalance 문제가 있었다.
이미지를 보면 객체가 이미지를 꽉 차지하기 보다 배경이 훨씬 이미지를 많이 차지하는 것을 알 수 있다.
One-stage detector에서 훈련할 때 이러한 배경 샘플이 너무 많고 객체 샘플이 적은 class imbalance 문제가 Cross entropy loss에 영향을 주는 것으로 나타났다. 쉽게 분류되는 배경 즉, negative sample이 손실의 대부분을 차지하며 기울기까지 영향을 주게 되는 것이다.
왜 클래스 불균형 문제가 문제인 것일까?


대부분의 배경 샘플은 너무 쉽게 분류가 된다. 위 cross entropy loss 식을 보면 p 값이 클수록 (= 정확한 예측일수록) 손실이 작아지고 p 값이 작을수록 (= 잘못된 예측일수록) 손실이 커진다.
결국 배경 샘플은 쉽게 분류되어 p 값이 높아지고 손실이 작아지는 것이다. 하지만 Cross entropy loss는 작은 손실도 합산하기 때문에 많은 쉬운 샘플이 총 손실을 지배하게 된다. 결국 모델이 어려운 하드 샘플보다 쉬운 샘플을 더 신경쓰게 되고 결과적으로 진짜 중요한 객체 탐지 성능을 제대로 학습하지 못하게 되는 것이다.
특히 모델이 훈련할 때 손실 함수의 Gradient를 기반으로 학습이 진행되는데 즉, 손실이 클수록 (잘못된 예측일수록) Gradient 업데이트가 크고 손실이 작을수록 영향이 적다. 때문에 클래스 불균형 상태에선 배경 샘플이 너무 많기 때문에 쉽게 분류된 Background 샘플이 총 Gradient 업데이트를 차지하게 되어 정작 중요한 어려운 샘플은 학습되지 않게 된다.
좀 더 쉽게 설명한 사이트가 있어서 해당 사이트의 그림을 첨부해보았다.



위에서 계속 설명한 바와 같이 모델은 어려운 문제 하나를 맞춰서 학습을 진행하기 보다 쉬운 문제 여러 개를 맞춰서 loss를 줄이고자 하는 것이다.
이러한 문제를 해결하기 위해 RetinaNet이라는 모델을 제안한 논문에서 Focal loss를 제안했다.
Focal loss는 쉽게 분류되는 샘플의 손실을 자동으로 낮추고 어려운 샘플에 집중하도록 설계되었다.

Cross entropy loss에 조정 계수(modulating factor) (1-p_t)^γ 를 추가했다. p_t 가 클수록 (= 쉽게 분류될수록) (1-p_t)^γ 값이 작아지고 손실이 줄어든다. 반대로 p_t가 작을수록 (= 어려운 샘플일수록) (1-p_t)^γ 값이 커지고 손실이 유지된다. 이러한 메커니즘을 통해 모델이 하드 샘플 (= 탐지가 어려운 작은 객체)에 더 집중하도록 유도한다.

조정 강도 γ 의 역할은 무엇일까?
γ = 0이면 Focal loss는 그냥 Cross entropy loss와 동일하다. 즉, γ 값이 클수록 쉬운 샘플의 손실 기여도를 더 줄일 수 있는 것이다. 논문에서 실험한 결과, γ = 2일 때 가장 좋은 성능을 보였다.
- p_t = 0.9 (거의 맞춘 샘플) → 손실이 Cross entropy loss 대비 100배 감소
- p_t = 0.968 (완전히 맞춘 샘플) → 손실이 Cross entropy loss 대비 1000배 감소
- p_t <= 0.5 (잘못된 샘플) → 손실이 크게 줄어들지 않음 (4배 정도만 감소) → 여전히 학습 가능
Focal loss는 object detection에만 사용될 수 있을까?
Focal loss는 object detection 전용 손실 함수가 아니다. 즉, Cross-entropy 손실을 변형한 형태로 잘 분류된 샘플의 손실 기여도를 줄이고 어려운 샘플의 손실 기여도를 높이는 방식이다.
때문에 class imbalance 문제가 있는 어떤 분류 문제에도 응용할 수 있다.
실제로 classification에 focal loss를 적용하고자 하는 글도 있고, 논문도 존재한다.. 그렇다..
https://discuss.pytorch.org/t/focal-loss-for-imbalanced-multi-class-classification-in-pytorch/61289
'헷개정 - 헷갈리는 개념 정리' 카테고리의 다른 글
| Word2Vec (0) | 2025.10.05 |
|---|---|
| 왜 RNN, LSTM, GRU에선 ReLU를 안 쓰고 Sigmoid, Tanh를 사용할까? (0) | 2025.10.05 |
| U-Net이 왜 성능이 좋을까? (0) | 2025.09.26 |
| CNN은 시각 데이터 처리에만 사용이 가능한 네트워크일까? (0) | 2025.09.26 |
| CNN이 시각 데이터를 처리하기 위한 가장 우수한 인공 신경망일까? (0) | 2025.09.26 |