Wooks_learning

[논문리뷰] Zero shot text-to-image generation 본문

딥러닝/논문 리뷰

[논문리뷰] Zero shot text-to-image generation

Wooks_ 2023. 9. 6. 22:23

[발표자료] Zero-Shot Text-to-Image Generation.pdf
2.78MB

오늘 리뷰할 논문은 DALL.E라고 잘 알려져 있는 text-to-image generation 모델에 대해 집중적으로 알아볼 예정이다.

리뷰 순서는 아래와 같이 리뷰할 예정이다.

 

1. Abstract

2. Introduction

3. Method

    - Stage 1.

    - Stage 2.

4. Experiments

 

1. Abstract

전통적으로 Text-to-image generation의 경우 특정 데이터 셋에 잘 작동하는 모델링 기법을 찾는데에 집중되어 있었다. 해당 과정에서 모델의 복잡한 구조, 추가적인 손실함수, 추가적인 라벨 작업이 필수적인 요소로써 요구되었다. DALLE의 경우 text-to-image generation을 text와 image 토큰들의 single stream 형태로 받아 transformer의 특징인 autoregressive한 성질을 이용한 접근법을 제시하였다.

 

2. Introduction

"Generative adversarial text to image synthesis" 논문에서처럼 GAN을 이용하여 text를 주었을 때 고품질의 이미지를 생성할 수 있도록 하였는데, 해당 논문에서 zero-shot 일반화가 가능하다는 것도 실험적으로 보여주었다.

 

그렇다면 생성 관점에서의 zero-shot이란 무엇일까?

그림 1. Zero-shot의 예시

그림 1에서 보이는 것처럼 학습시 보지 못했던 text에 대해 의미론적으로 이질감이 없는 이미지를 잘 생성하는 것. 즉, 문맥과 객체를 잘 이해하여 zero-shot(학습때 보지못한 데이터) 상황에서 일반화가 잘 된 것을 zero-shot 일반화가 가능하다라고 이해할 수 있을것 같다.

 

Introduction을 이어서 설명하면 transformer의 성공적인 데뷔에 대해 언급하고 있다. 모델의 크기와 데이터의 양을 방대한 양으로 늘려서 autoregressive한 방법을 이용하여 인상적인 결과를 얻었다. 이를 텍스트에 국한하지 않고 text-to-image generation에 적용해보면 좋은 결과가 나오지 않을까? 하는 의문에서 해당 논문이 출발하게 된다.

 

3. Method

결국 본 논문은 앞에서 설명한 것처럼 transformer를 이용하여 autoregressive한 특성을 잘 살려볼것인데, 이미지 pixel 하나하나를 autoregressive하게 생성하게 된다면 연산량이 상당할 것이다. 따라서 해당 문제를 완화시키기 위해 transformer는 pixel이 quantized된 image token들을 예측하는 문제로 바꾸게 되었다. 또한 pixel CNN의 단점도 언급하고 있는데, pixel CNN의 경우 픽셀간의 짧은 범위의 종속성을 모델링하는 것을 likelihood objective function에 의해 우선시 되는 경향이 있다고 이야기한다. 즉 모델이 고주파 성분을 캡처하는데에 보다 더 많은 노력을 하겠다는 이야기로 해석할 수 있다.

그림 2. 픽셀간 종속성 설명을 위한 예시 그림

어떤 이야기인지 조금 더 풀어서 설명하자면, 그림 2에서 빨간 박스를 통해 초록색 박스를 예측하는 구조로 pixel CNN이 동작하게 된다. 해당 이미지에서 저주파 성분이라고 하면 고양이의 전체적인 실루엣을 의미할 것이고, 고주파 성분은 고양이의 디테일한 털, 눈동자의 디테일 등을 의미하게 될 것이다. 따라서 pixel CNN을 이용했을 때 고주파 성분을 주로 학습하기 때문에 물체의 전반적인 엣지, 실루엣 등을 잘 표현하지 못하는 문제가 생긴다.

 

따라서 본 논문에서는 image를 표현할 수 있는 8192개의 code book을 학습하고, 해당 code book과 transformer의 autoregressive한 성질로 input text가 들어왔을 때 code의 나열을 이쁘게하여 고품질의 이미지를 생성하는 것이 목적이라고 할 수 있다.

 

3.1 Stage 1

그림 3. VQ-VAE 모델 구조

stage 1에서는 image만을 이용하여 image의 code book을 사용한다. VQ-VAE의 경우 워낙 유명하고 정리되어 있는 글이 많아 해당 글에서는 디테일하게 다루진 않겠지만, 간략하게 흐름 정도만 다루어 보겠다.

 

먼저, input image를 CNN encoder에 넣어 code book의 index가 배열에 잘 박힐 수 있도록 학습이 진행된다. 이 과정에서 code book 또한 학습되고, 생성된 q(z|x)를 decoder에 넣어 reconstruction을 잘 할 수 있도록 encoder, code book, decoder를 학습하게 된다.

 

DALLE에서 사용한 VQ-VAE 2(dVAE)는 위에서 설명한 내용을 조금 확장하여 hierarchical한 모델 구조를 띄게 된다.

그림 4. VQ-VAE(dVAE)의 구조

일반적으로 DNN은 하위 레벨로 갈수록 세부적인 특징을 학습하고, 상위 레벨로 갈수록 더 복잡한 특징(형태, 객체) 등을 학습하는 것으로 알려져 있다. 따라서 해당 구조를 사용하면 당연하게도 scale에 따른 표현을 더 세세하게 학습할 수 있으므로 고주파, 저주파 성분을 학습할 때 큰 이점을 살릴 수 있다. 

 

그림 5. VQ-VAE2의 multi resolution에 대한 고찰

그림 5에서 표현한 것 처럼 각각의 layer에서 추출한 feature들이 이미지를 복원할 때 저주파, 고주파의 의미들을 잘 담고있는 예시로써 표현하고 있다.

 

결국, stage 1에서는 이미지의 코드북을 잘 생성하고, 매핑할 수 있는 것을 목표로 한다. 하지만 여기서 하나의 문제점이 존재한다. 우리가 코드북을 매핑할 때 argmax등을 이용할 수 있는데, 이렇게 되면 연속적이지 않게 되므로 미분을 할 수 없다. 즉 학습을 할때 방해가 되는 요소로써 작용하게 되는데, 이를 gumbel-softmax를 이용하여 미분이 가능하도록 reparameterization trick을 사용하였다.

 

그림 6. Gumbel-softmax 수식

 

우리가 흔히 아는 softmax의 형태를 띄고있다. 다만, gi와 타우가 추가되어있는 것을 확인할 수 있는데, 각각의 역할은 다음과 같다.

 

gi : 해당 값을 이용하여 약간의 랜덤성을 추가해줌.

타우 : temperature의 역할을 하여 output distribution을 완만하거나, categorical하게 만들어주는 역할을 함.

해당 설명으로 축약할 수 있다.

 

그림 6. 타우에 따른 output distribution

조금 더 자세하게 설명하자면, 타우를 이용하여 매핑되는 코드북의 확률 분포가 완만하게 나올지, confidence가 높은 것을 주력으로 뽑아낼지를 결정할 수 있고, gi를 이용하여 이전과는 다른 코드북이 매핑될 수 있다는 것을 의미한다.

 

이렇게 stage 1.에서는 VQ-VAE를 이용하여 코드북을 잘 표현하고, 매핑하는 단계를 거치게 된다.

 

3.2 Stage 2

 

stage 2의 경우 transformer 구조를 이용하여 text를 input으로 받고, 해당 text를 바탕으로 codebook의 token index를 잘 예측하는 것이 목적인 단계이다. 즉, stage1에서 이미지를 이용해서 codebook으로 잘 표현할 수 있는 dictionary를 만들었다면, 해당 dictionary를 순서대로 잘 나열할 수 있게 만들어주는 것이 stage 2에서 하는 목적이자, DALLE의 최종적인 지향점이라고 할 수 있다. 이게 어떻게 학습되느냐는 굉장히 간단하다.

 

1. transformer(혹은 GPT)를 이용해서 text를 input으로 받는다.

2. 해당 input을 바탕으로 token들을 autoregressive하게 예측한다.

3. 32x32개의 token을 모두 예측할때까지 지속.

4. VQ-VAE decoder를 이용하여 이미지를 복원한다.

 

위의 flow를 따라 학습이 진행되는데 여기서 transformer가 뱉은 output이 정확하게 예측되었는지 측정할 수 있는 비교군이 필요하다. 따라서 stage 1에서 학습한 va-vae encoder를 이용하여 정답 label을 구축한다. 그림으로 표현하자면 다음과 같다.

그림 7. Stage 2의 전체적인 학습 flow

조금 더 자세하게 설명하면, transformer가 뱉은 32x32개의 token을 VA-VAE가 뱉은 32x32개의 codebook 나열과 cross-entropy를 이용하여 정확한 token index를 출력하도록 유도한다.

 

 

4. Experiments

그림 8. Qualitative results

DALL-E를 통해 생성된 결과들은 다른 모델들에 비해 굉장히 사실적이고, 디테일하게 생성되는 것을 확인할 수 있다. 그림 8에서 우측읜 빨간 박스가 DALL-E가 생성한 결과이고, 나머지가 타 모델이 생성한 결과인것을 보면, text-to-image generation task에서 왜 유명한 논문이 되었는지 결과로써 바로 확인할 수 있다.

 

그림 9. Qualitative Findings

또한, DALL-E의 새로운 특성을 발견하였는데, 색다른 개념을 조합하는 일반화 능력을 DALL-E가 가지고 있다는 것을 알 수 있었다. 해당 특성으로 인해 zero-shot 상황에서도 잘 생성할 수 있는 기반이 되지 않았을까 생각이 든다.

 

여러가지 실험 내용이 존재하지만, 첨부한 발표자료를 통해 확인하면 좋을 것으로 예상된다.

Comments