VAE설명

2023. 4. 29. 12:56VAE에 대한 정리

VAE란?

 

VAE는 Input image X를 잘 설명하는 feature를 추출하여 Latent vector z에 담고, 이 Latent vector z를 통해 X와 유사하지만 완전히 새로운 데이터를 생성해내는 것을 목표로 함.

 

이때 각 feature는 가우시안 분포를 따른다고 가정하며 latent z는 각 feature의 평균과 분산값을 나타냄 마치 AE의 decoder처럼 latent vector로부터 이미지를 생성해낸다고 보면 된다.

 

latent vector z는 정보를 담고 있음

 

p(z): latent vector z의 확률밀도 함수, 가우시안 분포를 따른다고 가정함. 즉, latent vector z를 sampling 할 수 있는 확률밀도함수

 

p(x|z): 주어진 z에서 특정 x가 나올 확률에 대한 확률밀도함수

 

θ: 모델의 파라미터

 

VAE는 Input Image X를 Encoder에 통과시켜 Latent vector z를 구하고, Latent vector z를 다시 Decoder에 통과시켜 기존 input image X와 비슷하지만 새로운 이미지 X를 찾아내는 구조를 가지고 있다.

 

Encoder

 

 

 

Reparameterization Trick (Sampling)

Decoder

Loss Fucntion

Encoder Code

img_shape = (28,28,1)
batch_size = 16
latent_dim = 2

input_img = keras.Input(shape = img_shape)
x = layers.Conv2D(32,3,padding='same',activation='relu')(input_img)
x = layers.Conv2D(64,3,padding='same',activation='relu',strides=(2,2))(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
x = layers.Conv2D(64,3,padding='same',activation='relu')(x)

shape_before_flattening = K.int_shape(x) # return tuple of integers of shape of x

x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)

z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

Conv2D: 이미지를 다루는 딥러닝에서 핵심적인 요소로 쓰이며, CNN 구조이다.

 

Activation function: 입력된 데이터의 가중 합을 출력 신호로 변환하는 함수이다.

 

ReLU:  x 가 0 이하일 때를 차단하여 아무 값도 출력하지 않고 0 을 출력하며, 따라서 ReLU 함수를 '정류된 선형 함수' 라고 할 수 있다.

 

Reparameterization Trick (Sampling) Code

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0],latent_dim),mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])

Decoder Code

decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flattening[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flattening[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation='relu',strides=(2,2))(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)

decoder = Model(decoder_input, x)
z_decoded = decoder(z)

reshape: numpy.ndarray의 차원과 모양을 바꿔준다.

 

Loss Fucntion Code

 def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)
        kl_loss   = -5e-4*K.mean(1+z_log_var-K.square(z_mean)-K.exp(z_log_var),axis=-1)
        return K.mean(xent_loss + kl_loss)