
PixelCNN: The Art of Generating Images Pixel by Pixel

Published: at 04:18 PM

Hi! Today we'll look at a somewhat different approach to image generation: PixelCNN. Unlike GANs, PixelCNN follows an entirely different philosophy: it generates each pixel conditioned on the pixels generated before it. You could call it the image-domain counterpart of autoregressive models.

What Is PixelCNN?

PixelCNN is an autoregressive image generation model proposed in 2016 by van den Oord et al. The core idea is simple: generate an image pixel by pixel, and determine each new pixel from the pixels you have already generated.

It's like writing a sentence word by word, where each word depends on the ones before it. Here, the "words" are pixels and the "sentence" is the image.

The Logic of Autoregressive Models

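Let's treat an image x as a sample from a probability distribution p(x), where x₁, ..., xₙ are its pixels in raster-scan order (row by row, left to right). Using the chain rule of probability, we can factorize this distribution: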

p(x) = p(x₁) × p(x₂|x₁) × p(x₃|x₁,x₂) × ... × p(xₙ|x₁,...,xₙ₋₁)

In other words, each pixel's probability depends on all the pixels that precede it!
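To make the factorization concrete, here is a tiny sketch over 2x2 binary images with a made-up conditional model: `cond_prob_one` is a hypothetical stand-in for what PixelCNN learns, not part of the original post. Because every factor is a proper conditional distribution, multiplying them in raster order yields a valid distribution, so the probabilities of all 16 possible images sum to 1.

```python
import itertools
import math


def cond_prob_one(prev):
    """Hypothetical toy conditional p(x_k = 1 | x_1, ..., x_{k-1}).

    Any function of the previous pixels would do; this one is a logistic
    function of how many earlier pixels are 'on'.
    """
    return 1.0 / (1.0 + math.exp(-(sum(prev) - 0.5)))


def log_likelihood(pixels):
    """log p(x) = sum_k log p(x_k | x_<k), with pixels in raster-scan order."""
    total = 0.0
    for k, value in enumerate(pixels):
        p1 = cond_prob_one(pixels[:k])
        total += math.log(p1 if value == 1 else 1.0 - p1)
    return total


# A 2x2 binary image flattened row by row has 4 pixels -> 16 possible images.
images = list(itertools.product([0, 1], repeat=4))
probs = [math.exp(log_likelihood(img)) for img in images]
print(f"sum over all {len(images)} images: {sum(probs):.6f}")
```

Working in log space, as `log_likelihood` does, is also how the exact likelihood of a real image is computed later: a sum of per-pixel log-probabilities.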

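Masked Convolution: The Art of Not Seeing the Future

PixelCNN's key ingredient is the masked convolution. In a normal convolution, the filter can look at all neighboring pixels; in PixelCNN it may only look at pixels that have already been generated (the rows above, and the pixels to the left in the current row). Mask type A, used in the first layer, also hides the current pixel. Mask type B, used in later layers, lets the current position through, because by then the features at that position only encode earlier pixels. For simplicity, the masks below enforce only this spatial ordering; the original paper additionally masks across the R, G, B channels so that G can see R, and B can see R and G.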

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MaskedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, mask_type='A'):
        super(MaskedConv2d, self).__init__()
        self.mask_type = mask_type

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                             padding=kernel_size//2, bias=False)

        # Create mask
        mask = torch.ones(kernel_size, kernel_size)
        mask[kernel_size//2, kernel_size//2+1:] = 0  # Right side
        mask[kernel_size//2+1:, :] = 0  # Bottom rows

        if mask_type == 'A':
            # First layer: exclude current pixel
            mask[kernel_size//2, kernel_size//2] = 0

        self.register_buffer('mask', mask)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        return self.conv(x)

# Mask visualization
def visualize_mask(kernel_size=5, mask_type='A'):
    mask = torch.ones(kernel_size, kernel_size)
    mask[kernel_size//2, kernel_size//2+1:] = 0
    mask[kernel_size//2+1:, :] = 0

    if mask_type == 'A':
        mask[kernel_size//2, kernel_size//2] = 0

    print(f"Mask type {mask_type}:")
    print(mask.numpy())

# Let's visualize
visualize_mask(5, 'A')
visualize_mask(5, 'B')
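A quick way to convince yourself that a type-A mask does its job is a perturbation test. This sketch rebuilds the 5x5 type-A mask with plain `F.conv2d` (so it runs on its own, independent of the `MaskedConv2d` class), changes the current pixel and every later pixel, and checks that the output at that position stays exactly the same, while changing a single earlier pixel directly above does alter it.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
k = 5
mask = torch.ones(k, k)
mask[k // 2, k // 2 + 1:] = 0   # same row, right of the center
mask[k // 2 + 1:, :] = 0        # all rows below the center
mask[k // 2, k // 2] = 0        # type A: hide the current pixel too

weight = torch.randn(1, 1, k, k) * mask  # masked filter, 1 input / 1 output channel
x = torch.randn(1, 1, 8, 8)
i, j = 4, 4
y = F.conv2d(x, weight, padding=k // 2)

# Perturb the current pixel and everything after it in raster order.
x_future = x.clone()
x_future[0, 0, i, j:] += 10.0      # current pixel and the rest of its row
x_future[0, 0, i + 1:, :] += 10.0  # all rows below
y_future = F.conv2d(x_future, weight, padding=k // 2)

# Perturb one earlier pixel (directly above) for comparison.
x_past = x.clone()
x_past[0, 0, i - 1, j] += 10.0
y_past = F.conv2d(x_past, weight, padding=k // 2)

print("future change visible:", not torch.allclose(y[0, 0, i, j], y_future[0, 0, i, j]))
print("past change visible:  ", not torch.allclose(y[0, 0, i, j], y_past[0, 0, i, j]))
```

Since PyTorch's `conv2d` is a cross-correlation, `weight[..., a, b]` multiplies the input at offset `(a - k//2, b - k//2)`, which is why the mask can be written in the same row/column layout that `visualize_mask` prints.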

PixelCNN Architecture

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = MaskedConv2d(channels, channels//2, 1, mask_type='B')
        self.conv2 = MaskedConv2d(channels//2, channels//2, 3, mask_type='B')
        self.conv3 = MaskedConv2d(channels//2, channels, 1, mask_type='B')

    def forward(self, x):
        residual = x
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.conv3(out)
        return F.relu(out + residual)

class PixelCNN(nn.Module):
    def __init__(self, num_layers=7, num_feature_maps=128, num_colors=256, in_channels=3):
        super(PixelCNN, self).__init__()

        # First layer uses mask type A
        self.first_layer = MaskedConv2d(in_channels, num_feature_maps, 7, mask_type='A')

        # Hidden layers use mask type B
        self.hidden_layers = nn.ModuleList([
            ResidualBlock(num_feature_maps) for _ in range(num_layers)
        ])

        # Output layers: a separate 256-way distribution for every color channel,
        # i.e. in_channels * num_colors maps (channel c uses maps c*256:(c+1)*256)
        self.output_conv1 = MaskedConv2d(num_feature_maps, num_feature_maps, 1, mask_type='B')
        self.output_conv2 = MaskedConv2d(num_feature_maps, in_channels * num_colors, 1, mask_type='B')

    def forward(self, x):
        x = F.relu(self.first_layer(x))

        for layer in self.hidden_layers:
            x = layer(x)

        x = F.relu(self.output_conv1(x))
        x = self.output_conv2(x)

        return x
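The training and sampling code in this post reads the output as `logits[:, c*256:(c+1)*256]`, which assumes the final layer produces 3 × 256 = 768 maps packed channel by channel. This small sketch, on a random stand-in tensor, shows that this slicing is exactly what a `view` to `[B, 3, 256, H, W]` exposes, which is handy for vectorized losses.

```python
import torch

B, C, K, H, W = 2, 3, 256, 4, 4       # batch, color channels, intensity levels, height, width
logits = torch.randn(B, C * K, H, W)  # stand-in with the packed PixelCNN output layout

# Channel c's 256 logits live in the contiguous block c*256:(c+1)*256 ...
per_channel = [logits[:, c * K:(c + 1) * K] for c in range(C)]

# ... so a single view groups them as [B, C, 256, H, W].
grouped = logits.view(B, C, K, H, W)

for c in range(C):
    assert torch.equal(grouped[:, c], per_channel[c])
print(grouped.shape)
```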

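The Training Process

Training PixelCNN is much more straightforward than training GANs: it is plain maximum likelihood, i.e. a 256-way cross-entropy loss on every subpixel, with no adversary to balance: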

def train_pixelcnn(model, dataloader, num_epochs=100, learning_rate=0.001):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0

        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.to(device)

            # Convert to discrete values (0-255)
            data_discrete = (data * 255).long()

            optimizer.zero_grad()

            # Forward pass
            logits = model(data)

            # Calculate loss for each pixel
            batch_size, channels, height, width = data_discrete.shape

            loss = 0
            for c in range(channels):
                loss += criterion(
                    logits[:, c*256:(c+1)*256, :, :].contiguous().view(-1, 256),
                    data_discrete[:, c, :, :].contiguous().view(-1)
                )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        avg_loss = total_loss / num_batches
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')

# Data preprocessing for training
def preprocess_for_pixelcnn(images):
    # From the [0, 1] range to discrete values in [0, 255] (rounded, not truncated)
    discrete_images = (images * 255).round().long()
    return discrete_images
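The per-channel loop in `train_pixelcnn` can also be written as a single cross-entropy over all subpixels. The sketch below uses random stand-in tensors (no trained model) to check that the two agree (the loop sums three per-channel means, so it equals three times the single mean) and converts the mean loss into bits per dimension, the metric usually reported for PixelCNN.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, K, H, W = 4, 3, 256, 8, 8              # batch, channels, levels, height, width
logits = torch.randn(B, C * K, H, W)         # stand-in for model(data)
levels = torch.randint(0, K, (B, C, H, W))   # ground-truth intensities 0..255
data = levels / 255.0                        # what a [0, 1] dataloader would yield
targets = (data * 255).round().long()        # round back to integer class ids

# Per-channel loop, mirroring train_pixelcnn: sum of three per-channel means
loop_loss = sum(
    F.cross_entropy(logits[:, c * K:(c + 1) * K].permute(0, 2, 3, 1).reshape(-1, K),
                    targets[:, c].reshape(-1))
    for c in range(C)
)

# One call over all subpixels: [B, C, 256, H, W] -> [B*C*H*W, 256]
grouped = logits.view(B, C, K, H, W).permute(0, 1, 3, 4, 2).reshape(-1, K)
mean_nats = F.cross_entropy(grouped, targets.reshape(-1))

# Average negative log-likelihood per subpixel, converted from nats to bits
bits_per_dim = mean_nats.item() / math.log(2)
print(f"loop: {loop_loss.item():.4f}, 3 x vectorized: {3 * mean_nats.item():.4f}, "
      f"bits/dim: {bits_per_dim:.3f}")
```

With random logits the bits/dim value is meaningless (it sits above the 8 bits of a uniform guess); for a trained model it measures how many bits per subpixel the model needs to encode the data.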

Sampling: Generating Images

Sampling from PixelCNN is a sequential process: each pixel has to be drawn before the next one can be predicted.

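def sample_from_pixelcnn(model, shape=(1, 3, 32, 32), device=None):
    model.eval()
    # Default to the device the model lives on
    if device is None:
        device = next(model.parameters()).device

    with torch.no_grad():
        # Start from an all-zero image and fill it in raster-scan order
        sample = torch.zeros(shape).to(device)

        batch_size, channels, height, width = shape

        for i in range(height):
            for j in range(width):
                for c in range(channels):
                    # Get logits for current pixel. With the simplified masks above the
                    # channels of a pixel are predicted independently, so this call is
                    # redundant for c > 0; with R -> G -> B channel masking it is needed.
                    logits = model(sample)

                    # Get probability distribution for current subpixel
                    pixel_logits = logits[:, c*256:(c+1)*256, i, j]
                    pixel_probs = F.softmax(pixel_logits, dim=1)

                    # Sample from distribution: [B, 1] -> [B]
                    pixel_sample = torch.multinomial(pixel_probs, 1).squeeze(1)

                    # Set pixel value (normalize to [0, 1])
                    sample[:, c, i, j] = pixel_sample.float() / 255.0

    return sample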

# Optimized version for batch sampling
def fast_sample_batch(model, num_samples=10, image_size=32):
    model.eval()
    device = next(model.parameters()).device

    samples = torch.zeros(num_samples, 3, image_size, image_size).to(device)

    with torch.no_grad():
        for i in range(image_size):
            for j in range(image_size):
                # One forward pass per pixel predicts all channels at once; valid here
                # because the channels of a pixel don't condition on each other
                logits = model(samples)

                for c in range(3):
                    channel_logits = logits[:, c*256:(c+1)*256, i, j]
                    probs = F.softmax(channel_logits, dim=1)
                    pixel_values = torch.multinomial(probs, 1).squeeze(1)
                    samples[:, c, i, j] = pixel_values.float() / 255.0

    return samples
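Because every new pixel needs a fresh forward pass, sampling cost grows with the number of pixels. The sketch below makes that explicit with a hypothetical stand-in model (`UniformStandIn`, which just returns uniform logits and is not a real PixelCNN) that counts how often it is called during the same one-pass-per-pixel loop as `fast_sample_batch`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UniformStandIn(nn.Module):
    """Hypothetical stand-in for a PixelCNN: returns all-zero logits
    (a uniform distribution over 256 levels) and counts its calls."""

    def __init__(self):
        super().__init__()
        self.dummy = nn.Parameter(torch.zeros(1))  # only so parameters() is non-empty
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        b, c, h, w = x.shape
        return torch.zeros(b, c * 256, h, w, device=x.device)


def sample_one_pass_per_pixel(model, num_samples=2, image_size=8):
    """Same loop structure as fast_sample_batch: one forward pass per pixel."""
    device = next(model.parameters()).device
    samples = torch.zeros(num_samples, 3, image_size, image_size, device=device)
    with torch.no_grad():
        for i in range(image_size):
            for j in range(image_size):
                logits = model(samples)
                for c in range(3):
                    probs = F.softmax(logits[:, c * 256:(c + 1) * 256, i, j], dim=1)
                    samples[:, c, i, j] = torch.multinomial(probs, 1).squeeze(1).float() / 255.0
    return samples


model = UniformStandIn()
samples = sample_one_pass_per_pixel(model, num_samples=2, image_size=8)
print(model.calls)  # 64: one forward pass per pixel of an 8x8 image
```

For a 32x32 image the same loop makes 1024 calls, and `sample_from_pixelcnn`, which re-runs the model for every channel, makes 3072.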

PixelCNN vs. GANs

Both approaches have their strengths and weaknesses:

Advantages of PixelCNN:

# 1. Stable training
def stable_training_demo():
    # No adversarial training needed
    # No mode collapse
    # Tractable likelihood calculation
    pass

# 2. Exact likelihood calculation
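def calculate_likelihood(model, image):
    """Calculate the exact log-likelihood (in nats) of a [3, H, W] image in [0, 1]"""
    model.eval()
    device = next(model.parameters()).device
    image = image.to(device)

    with torch.no_grad():
        logits = model(image.unsqueeze(0))  # [1, 3*256, H, W]
        C, H, W = image.shape

        # One 256-way log-distribution per subpixel: [1, C, 256, H, W]
        log_probs = F.log_softmax(logits.view(1, C, 256, H, W), dim=2)

        # Pick the log-probability of the observed value of every subpixel
        targets = (image * 255).round().long().unsqueeze(0)  # [1, C, H, W]
        pixel_log_probs = log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)

        return pixel_log_probs.sum().item()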

Disadvantages of PixelCNN:

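# Slow sampling
import time

def compare_generation_speed(model, generator, latent_dim=100):
    # Assumes `generator` is a GAN generator on the same device as `model`
    device = next(model.parameters()).device

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock
        if device.type == 'cuda':
            torch.cuda.synchronize()

    start_time = time.time()

    # PixelCNN: sequential generation
    # sample_from_pixelcnn runs 32*32*3 = 3072 forward passes (fast_sample_batch: 1024)
    pixelcnn_sample = sample_from_pixelcnn(model, (1, 3, 32, 32), device=device)
    sync()
    pixelcnn_time = time.time() - start_time

    start_time = time.time()

    # GAN: parallel generation
    # Single forward pass
    with torch.no_grad():
        noise = torch.randn(1, latent_dim, device=device)
        gan_sample = generator(noise)
    sync()
    gan_time = time.time() - start_time

    print(f"PixelCNN generation time: {pixelcnn_time:.2f}s")
    print(f"GAN generation time: {gan_time:.4f}s")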

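The Gated PixelCNN Improvement

The original PixelCNN had a "blind spot" problem: stacked masked convolutions extend the receptive field by at most one column to the right for every row upward, so some pixels above and to the right of the current pixel are never seen, even though they were already generated. Gated PixelCNN solves this by splitting the network into a vertical stack, which sees all rows above, and a horizontal stack, which sees the pixels to the left in the current row. It also replaces ReLU with a gated activation, tanh(a) · sigmoid(b):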
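The blind spot can be made visible with a small sketch: stack one mask-A and four mask-B 3x3 convolutions (the masks are rebuilt here so the snippet runs on its own, and all weights are 1 so every path contributes), backpropagate from the center output, and mark which inputs receive a nonzero gradient. The pixel one row up and two columns to the right was generated before the center pixel, yet it never influences it.

```python
import torch
import torch.nn.functional as F


def causal_mask(k, mask_type):
    """Spatial PixelCNN mask: 1 = visible, 0 = blocked."""
    mask = torch.ones(k, k)
    mask[k // 2, k // 2 + 1:] = 0   # right of the center in the current row
    mask[k // 2 + 1:, :] = 0        # all rows below the center
    if mask_type == 'A':
        mask[k // 2, k // 2] = 0    # type A also hides the current pixel
    return mask


def receptive_field(num_layers=5, size=15, k=3):
    """0/1 map of input pixels that influence the center output of a stack of
    masked convolutions (first layer type A, the rest type B)."""
    x = torch.ones(1, 1, size, size, requires_grad=True)
    h = x
    for layer in range(num_layers):
        # The masks themselves serve as all-ones (masked) weights
        weight = causal_mask(k, 'A' if layer == 0 else 'B').view(1, 1, k, k)
        h = F.conv2d(h, weight, padding=k // 2)
    c = size // 2
    h[0, 0, c, c].backward()
    # Nonzero gradient = this input pixel reaches the center output
    return (x.grad[0, 0] != 0).int()


rf = receptive_field()
c = 15 // 2
# Rows c-3..c, columns c-3..c+3 around the center pixel (c, c)
print(rf[c - 3:c + 1, c - 3:c + 4])
```

Adding more layers does not close the gap: each extra layer pushes the visible region one column further right only for rows further up, so the upper-right triangle stays hidden.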

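class GatedBlock(nn.Module):
    # One hidden Gated PixelCNN layer (type-B behavior: the horizontal stack may
    # see the current position). Causality comes from asymmetric padding instead
    # of weight masks; the first layer would also need to hide the current pixel.
    def __init__(self, channels, kernel_size=3):
        super(GatedBlock, self).__init__()
        k = kernel_size
        self.k = k

        # Vertical stack: (k//2 + 1) x k kernel = current row and the k//2 rows above
        self.vertical_conv = nn.Conv2d(channels, 2*channels, (k//2 + 1, k))

        # Horizontal stack: 1 x (k//2 + 1) kernel = current pixel and k//2 pixels to its left
        self.horizontal_conv = nn.Conv2d(channels, 2*channels, (1, k//2 + 1))

        # Connection from vertical to horizontal (on pre-activation features)
        self.v_to_h = nn.Conv2d(2*channels, 2*channels, 1)

        # 1x1 conv before the residual connection of the horizontal stack
        self.residual = nn.Conv2d(channels, channels, 1)

    @staticmethod
    def gate(x):
        # Gated activation: tanh(a) * sigmoid(b)
        a, b = torch.chunk(x, 2, dim=1)
        return torch.tanh(a) * torch.sigmoid(b)

    def forward(self, v_input, h_input):
        k, H = self.k, v_input.size(2)

        # Vertical stack: pad left/right/top (not bottom) so row i sees rows i-k//2 .. i
        v_pre = self.vertical_conv(F.pad(v_input, (k//2, k//2, k//2, 0)))

        # Shift down one row before feeding the horizontal stack, so position (i, j)
        # only receives vertical information from rows strictly above i
        v_shifted = F.pad(v_pre, (0, 0, 1, 0))[:, :, :H, :]

        # Horizontal stack: pad on the left only so column j sees columns j-k//2 .. j
        h_pre = self.horizontal_conv(F.pad(h_input, (k//2, 0, 0, 0)))

        # Add vertical information to horizontal
        h_pre = h_pre + self.v_to_h(v_shifted)

        # Gated outputs; the residual connection is used on the horizontal stack only
        v_out = self.gate(v_pre)
        h_out = h_input + self.residual(self.gate(h_pre))

        return v_out, h_out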

Conditional PixelCNN

For class-conditional generation, the model also receives a class label:

class ConditionalPixelCNN(nn.Module):
    def __init__(self, num_classes=10, num_colors=256, num_filters=128):
        super(ConditionalPixelCNN, self).__init__()

        # Class embedding: one learned bias vector per class
        self.class_embedding = nn.Embedding(num_classes, num_filters)

        # Rest of the network
        self.pixelcnn = PixelCNN(num_feature_maps=num_filters, num_colors=num_colors)

    def forward(self, x, class_labels):
        # Inject class information as a position-independent bias on the hidden
        # features (in the spirit of Conditional PixelCNN). Concatenating it to the
        # input instead would change the channel count the first layer expects.
        net = self.pixelcnn
        class_bias = self.class_embedding(class_labels)[:, :, None, None]  # [B, F, 1, 1]

        h = F.relu(net.first_layer(x) + class_bias)
        for layer in net.hidden_layers:
            h = layer(h)

        h = F.relu(net.output_conv1(h))
        return net.output_conv2(h)

# Conditional sampling
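def conditional_sample(model, class_label, shape=(1, 3, 32, 32)):
    model.eval()
    device = next(model.parameters()).device

    sample = torch.zeros(shape).to(device)
    # One label per image in the batch
    class_tensor = torch.full((shape[0],), class_label, dtype=torch.long, device=device)

    with torch.no_grad():
        for i in range(shape[2]):
            for j in range(shape[3]):
                logits = model(sample, class_tensor)

                # Same per-pixel sampling as in fast_sample_batch
                for c in range(shape[1]):
                    probs = F.softmax(logits[:, c*256:(c+1)*256, i, j], dim=1)
                    sample[:, c, i, j] = torch.multinomial(probs, 1).squeeze(1).float() / 255.0

    return sample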

Performance Optimizations

Some ways to speed up PixelCNN:

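# 1. Caching intermediate results
# Note: this naive cache is keyed only on the position (i, j), so it must be reset
# before generating each new image. It only saves work when the same position is
# queried repeatedly (e.g. the per-channel loop in sample_from_pixelcnn), which is
# safe here because the channels of a pixel are predicted independently.
# Larger speedups (the idea behind Fast PixelCNN++) cache per-layer activations.
class CachedPixelCNN(nn.Module):
    def __init__(self, base_model):
        super(CachedPixelCNN, self).__init__()
        self.base_model = base_model
        self.cache = {}

    def reset_cache(self):
        # Call before generating each new image (or batch)
        self.cache = {}

    def forward_cached(self, x, i, j):
        # Reuse the full forward pass for repeated queries at the same (i, j)
        cache_key = (i, j)
        if cache_key in self.cache:
            return self.cache[cache_key]

        result = self.base_model(x)
        self.cache[cache_key] = result
        return result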

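# 2. Quantization
# Caveat: PyTorch's dynamic quantization targets nn.Linear and recurrent layers by
# default; nn.Conv2d is not in its default mapping, so the convolutions of this
# model would likely stay in float. Conv layers need static quantization instead.
def quantize_model(model):
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Conv2d}, dtype=torch.qint8
    )
    return quantized_model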

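# 3. Model distillation
def distill_pixelcnn(teacher_model, student_model, dataloader, num_colors=256):
    # 'batchmean' matches the mathematical definition of the KL divergence
    criterion = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.Adam(student_model.parameters())
    device = next(student_model.parameters()).device
    teacher_model.eval()

    for data, _ in dataloader:
        data = data.to(device)
        with torch.no_grad():
            teacher_logits = teacher_model(data)

        student_logits = student_model(data)

        # Softmax over the 256 intensities of each channel separately:
        # [B, 3*256, H, W] -> [B, 3, 256, H, W]
        B, _, H, W = student_logits.shape
        student_logits = student_logits.view(B, -1, num_colors, H, W)
        teacher_logits = teacher_logits.view(B, -1, num_colors, H, W)

        loss = criterion(
            F.log_softmax(student_logits, dim=2),
            F.softmax(teacher_logits, dim=2)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()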

Evaluation and Metrics

def evaluate_pixelcnn(model, test_loader):
    model.eval()
    total_likelihood = 0.0
    num_samples = 0
    num_dims = None

    with torch.no_grad():
        for data, _ in test_loader:
            for img in data:
                # Exact log-likelihood of one image, in nats
                total_likelihood += calculate_likelihood(model, img)

            num_samples += len(data)
            num_dims = data[0].numel()  # e.g. 3 * 32 * 32 for 32x32 RGB images

    average_likelihood = total_likelihood / num_samples
    nats_per_dim = -average_likelihood / num_dims

    # Per-dimension perplexity (exp of the whole-image NLL would overflow)
    perplexity = float(np.exp(nats_per_dim))

    return {
        'likelihood': average_likelihood,
        'perplexity': perplexity,
        'bits_per_dim': nats_per_dim / np.log(2)
    }

Final Words

PixelCNN is a great example of autoregressive generation in the image domain. It is not as fast as GANs, but it rests on much more solid mathematical foundations. Its ability to compute exact likelihoods and its stable training process are important advantages.

Today, diffusion models have set a new standard by combining the principled approach of PixelCNN with the quality of GANs. Still, understanding PixelCNN is a great starting point for learning the fundamentals of autoregressive modeling.

In particular, the transformers we use for text generation rely on the same autoregressive principle. So once you understand PixelCNN, you will recognize patterns that show up across many areas of today's AI.

Pixel by pixel, step by step! 🎨📊


