해당글은 medium 원문을 읽고, 이해하는 글입니다.
이해 못했으면 댓글달아주세요 +_+
A Loss Function Suitable for Class Imbalanced Data: “Focal Loss”
Deep Learning with Class Imbalanced Data
towardsdatascience.com
1단계 검출기는 빠르지만, 정확도가 2 단계 검출기의 10~40% 정도입니다.
이때, 1단계 검출기가 2단계 검출기만큼 정확도를 얻지 못하는 주요 요인은, 클래스 불균형입니다.
이미지에서 20000개의 bouding box(bb)가 있다면, 검출 하려는 객체는 7~8개에 불과하며, 나머지 bb는 배경입니다.
검출 target에 비해 negative case가 너무 많기 때문에, 클래스간 불균형이 발생합니다.
Cross entropy에서 잘못 분류된 객체가 올바르게 분류된 객체보다 패널티를 많이 받지만,
background에서 추출된 샘플(easy sample)이 수량이 너무 많기때문에, loss function은 제대로 학습할 수 없습니다.
(=> 예를들어 패널티 (5점 x 1개 ) 줘봐야 수량(0.1점 x 10000개) 으로 압도당합니다 .
이를 보완하기 위한 focal loss 수식은 아래와 같습니다.
(1-pt)^r 이면, pt 값 (예측값 낮 : 0.0 / 예측값 높: 1.0) ==> 따라서 (1-pt)값 : (1.0~ 0.0)^r 가 가중치로 부여됨.
위 식이 어떻게 적용되는지 아래 코드를 통해 확인 할 수 있습니다.
import tensorflow as tf
from tensorflow import keras
y_example_true = [1.0]
y_example_pred = [0.90]
y_example_pred0 = [0.95]
y_example_pred1 = [0.20]
binary_ce = tf.keras.losses.BinaryCrossentropy()
loss_close = binary_ce(y_example_true, y_example_pred).numpy()
loss_veryclose = binary_ce(y_example_true, y_example_pred0).numpy()
loss_far = binary_ce(y_example_true, y_example_pred1).numpy()
print ('CE loss when pred is close to true: ', loss_close)
print ('CE loss when pred is very close to true: ', loss_veryclose)
print ('CE loss when pred is far from true: ', loss_far)
focal_factor_close = (1.0-0.90)**2 ## (take gamma = 2, as in paper)
focal_factor_veryclose = (1.0-0.95)**2
focal_factor_far = (1.0-0.20)**2 ## ()
print ('\n')
print ('focal loss when pred is close to true: ', loss_close*focal_factor_close)
print ('focal loss when pred is very close to true: ', loss_veryclose*focal_factor_veryclose)
print ('focal loss when pred is far from true: ', loss_far*focal_factor_far)
y_example_true = [1.0] ==> 정답
y_example_pred = [0.90] ==> 잘예측
y_example_pred0 = [0.95] ==> 아주 잘예측
y_example_pred1 = [0.20] ==> 예측 잘 못함.
위 코드를 예측해보면,
객체를 해당 클래스로 잘 분류하면 prediction 값이 1에 가깞고, (p = 0.90, 0.95 등)
잘 못 분류하면 prediction 값이 낮습니다 ( p = 0.2)
y=1인 클래스를 잘 예측 한 경우 (1-p) 가 되고 (1-0.95) 이므로, 0.05 값이 곱해지므로, 반영이 적을 것이고
대신 y=1인 클래스에서 잘 못 예측했으면, (1-p ) -> (1-0.2) => 0.8 이므로 많이 반영됩니다.
이제 위 코드를 출력해보면, 아래와 같습니다.
CE loss when pred is close to true: 0.10536041
CE loss when pred is very close to true: 0.051293183
CE loss when pred is far from true: 1.6094373
focal loss when pred is close to true: 0.0010536041110754007
focal loss when pred is very close to true: 0.00012823295779526255
focal loss when pred is far from true: 1.0300399017333985
CE loss에서 정답을 맞췄을때의 패널티가, 틀렸을때의 패널티와 32배 (=1.6 / 0.05) 차이나지만,
focal loss에서 정답을 맞췄을때의 패널티가, 틀렸을때의 패널티와 10000배 (=1.03/ 0.0001) 차이남을 알수있습니다.
focal loss에서는 정답을 틀렸을때의 상당히 많은 패널티를 줌을 알 수 있습니다.
## 추가 예제
kaggle credit card fraud를 살펴보면, class imbalance data set 을 확인할수 있는데...
불균형한 데이터 셋이고, 각각 ce loss 랑, focal loss 살펴보면
recall이 0.73 -> 0.84 로 증가했고, 개수로 따지면,
CE 가 전체 148 case 중 69 사례 예측을 성공한 반면, focal loss가 101 개를 올바르게 감지하였습니다.