使用JojoGAN创建风格化的面部图

磐创AI

    
    介绍
    风格迁移是神经网络的一个发展领域,它是一个非常有用的功能,可以集成到社交媒体和人工智能应用程序中。几个神经网络可以根据训练数据将图像样式映射和传输到输入图像。在本文中,我们将研究 JojoGAN,以及仅使用一种参考样式来训练和生成具有该样式的任何图像的过程。
    JoJoGAN:One Shot Face Stylization
    One Shot Face Stylization(一次性面部风格化)可用于 AI 应用程序、社交媒体过滤器、有趣的应用程序和业务用例。随着 AI 生成的图像和视频滤镜的日益普及,以及它们在社交媒体和短视频、图像中的使用,一次性面部风格化是一个有用的功能,应用程序和社交媒体公司可以将其集成到最终产品中。
    因此,让我们来看看用于一次性生成人脸样式的流行 GAN 架构——JojoGAN。
    JojoGAN 架构
    JojoGAN 是一种风格迁移程序,可让将人脸图像的风格迁移为另一种风格。它通过GAN将参考风格图像反转为近似的配对训练数据,根据风格化代码生成真实的人脸图像,并与参考风格图像相匹配。然后将该数据集用于微调 StyleGAN,并且可以使用新的输入图像,JojoGAN 将根据 GAN 反转(inversion)将其转换为该特定样式。
    
    JojoGAN 架构和工作流程
    JojoGAN 只需一种参考风格即可在很短的时间内(不到 1 分钟)进行训练,并生成高质量的风格化图像。
    JojoGan 的一些例子
    JojoGAN 生成的风格化图像的一些示例:
    
    风格化的图像可以在各种不同的输入风格上生成并且可以修改。
    JojoGan 代码深潜
    让我们看看 JojoGAN 生成风格化人像的实现。有几个预训练模型可用,它们可以在我们的风格图像上进行训练,或者可以修改模型以在几分钟内更改风格。
    JojoGAN 的设置和导入
    克隆 JojoGAN 存储库并导入必要的库。在 Google Colab 存储中创建一些文件夹,用于存储反转代码、样式图像和模型。
    !git clone https://github.com/mchong6/JoJoGAN.git
    %cd JoJoGAN
    !pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb
    !wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
    !sudo unzip ninja-linux.zip -d /usr/local/bin/
    import torch
    torch.backends.cudnn.benchmark = True
    from torchvision import transforms, utils
    from util import *
    from PIL import Image
    import math
    import random
    import os
    import numpy
    from torch import nn, autograd, optim
    from torch.nn import functional
    from tqdm import tqdm
    import wandb
    from model import *
    from e4e_projection import projection
    from google.colab import files
    from copy import deepcopy
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive
    from google.colab import auth
    from oauth2client.client import GoogleCredentials
    模型文件
    使用 Pydrive 下载模型文件。一组驱动器 ID 可用于预训练模型。这些预训练模型可用于随时随地生成风格化图像,并具有不同的准确度。之后,可以训练用户创建的模型。
    #Download models
    #optionally enable downloads with pydrive in order to authenticate and avoid drive download limits.
    download_with_pydrive = True  
    device = 'cuda' #['cuda', 'cpu']
    !wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
    !bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2
    !mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat
    %matplotlib inline
    drive_ids = {
       "stylegan2-ffhq-config-f.pt": "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
       "e4e_ffhq_encode.pt": "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",
       "restyle_psp_ffhq_encode.pt": "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",
       "arcane_caitlyn.pt": "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",
       "arcane_caitlyn_preserve_color.pt": "1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
       "arcane_jinx_preserve_color.pt": "1jElwHxaYPod5Itdy18izJk49K1nl4ney",
       "arcane_jinx.pt": "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",
       "arcane_multi_preserve_color.pt": "1enJgrC08NpWpx2XGBmLt1laimjpGCyfl",
       "arcane_multi.pt": "15V9s09sgaw-zhKp116VHigf5FowAy43f",
       "sketch_multi.pt": "1GdaeHGBGjBAFsWipTL0y-ssUiAqk8AxD",
       "disney.pt": "1zbE2upakFUAx8ximYnLofFwfT8MilqJA",
       "disney_preserve_color.pt": "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
       "jojo.pt": "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",
       "jojo_preserve_color.pt": "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",
       "jojo_yasuho.pt": "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",
       "jojo_yasuho_preserve_color.pt": "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
       "art.pt": "1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT",
    }
    # from StyelGAN-NADA
    class Downloader(object):
       def __init__(self, use_pydrive):
           self.use_pydrive = use_pydrive
           if self.use_pydrive:
               self.authenticate()
       def authenticate(self):
           auth.authenticate_user()
           gauth = GoogleAuth()
           gauth.credentials = GoogleCredentials.get_application_default()
           self.drive = GoogleDrive(gauth)
       def download_file(self, file_name):
           file_dst = os.path.join('models', file_name)
           file_id = drive_ids[file_name]
           if not os.path.exists(file_dst):
               print(f'Downloading {file_name}')
               if self.use_pydrive:
                   downloaded = self.drive.CreateFile({'id':file_id})
                   downloaded.FetchMetadata(fetch_all=True)
                   downloaded.GetContentFile(file_dst)
               else:
                   !gdown --id $file_id -O $file_dst
    downloader = Downloader(download_with_pydrive)
    downloader.download_file('stylegan2-ffhq-config-f.pt')
    downloader.download_file('e4e_ffhq_encode.pt')
    加载生成器
    加载原始和微调生成器。设置用于调整图像大小和规范化图像的 transforms。
    latent_dim = 512
    # Load original generator
    original_generator = Generator(1024, latent_dim, 8, 2).to(device)
    ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
    original_generator.load_state_dict(ckpt["g_ema"], strict=False)
    mean_latent = original_generator.mean_latent(10000)
    # to be finetuned generator
    generator = deepcopy(original_generator)
    transform = transforms.Compose(
       [
           transforms.Resize((1024, 1024)),
           transforms.ToTensor(),
           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
       ]
    )
    输入图像
    设置输入图像位置。对齐和裁剪面并重新设置映射的样式。
    #image to the test_input directory and put the name here
    filename = 'face.jpeg' #@param {type:"string"}
    filepath = f'test_input/{filename}'
    name = strip_path_extension(filepath)+'.pt'
    # aligns and crops face
    aligned_face = align_face(filepath)
    # my_w = restyle_projection(aligned_face, name, device, n_iters=1).unsqueeze(0)
    my_w = projection(aligned_face, name, device).unsqueeze(0)
    
    预训练图
    选择预训练好的图类型,选择不保留颜色的检查点,效果更好。
    plt.rcParams['figure.dpi'] = 150
    pretrained = 'sketch_multi' #['art', 'arcane_multi', 'sketch_multi', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
    #Preserve color tries to preserve color of original image by limiting family of allowable transformations.
    if preserve_color:
       ckpt = f'{pretrained}_preserve_color.pt'
    else:
       ckpt = f'{pretrained}.pt'
    生成结果
    加载检查点和生成器并设置种子值,然后开始生成风格化图像。用于 Elon Musk 的输入图像将根据图类型进行风格化。
    #Generate results
    n_sample =  5#{type:"number"}
    seed = 3000 #{type:"number"}
    torch.manual_seed(seed)
    with torch.no_grad():
       generator.eval()
       z = torch.randn(n_sample, latent_dim, device=device)
       original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
       sample = generator([z], truncation=0.7, truncation_latent=mean_latent)
       original_my_sample = original_generator(my_w, input_is_latent=True)
       my_sample = generator(my_w, input_is_latent=True)
    # display reference images
    if pretrained == 'arcane_multi':
       style_path = f'style_images_aligned/arcane_jinx.png'
    elif pretrained == 'sketch_multi':
       style_path = f'style_images_aligned/sketch.png'
    else:   
       style_path = f'style_images_aligned/{pretrained}.png'
    style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)
    face = transform(aligned_face).unsqueeze(0).to(device)
    my_output = torch.cat([style_image, face, my_sample], 0)
    
    生成的结果
    结果生成为预先训练的类型“Jojo”,看起来相当准确。
    现在让我们看一下在自创样式上训练 GAN。
    使用你的风格图像进行训练
    选择一些面部图,甚至创建一些自己的面部图并加载这些图像以训练 GAN,并设置路径。裁剪和对齐人脸并执行 GAN 反转。
    names = ['1.jpg', '2.jpg', '3.jpg']
    targets = []
    latents = []
    for name in names:
       style_path = os.path.join('style_images', name)
       assert os.path.exists(style_path), f"{style_path} does not exist!"
       name = strip_path_extension(name)
       # crop and align the face
       style_aligned_path = os.path.join('style_images_aligned', f'{name}.png')
       if not os.path.exists(style_aligned_path):
           style_aligned = align_face(style_path)
           style_aligned.save(style_aligned_path)
       else:
           style_aligned = Image.open(style_aligned_path).convert('RGB')
       # GAN invert
       style_code_path = os.path.join('inversion_codes', f'{name}.pt')
       if not os.path.exists(style_code_path):
           latent = projection(style_aligned, style_code_path, device)
       else:
           latent = torch.load(style_code_path)['latent']
       latents.append(latent.to(device))
    targets = torch.stack(targets, 0)
    latents = torch.stack(latents, 0)
    
    微调 StyleGAN
    通过调整 alpha、颜色保留和设置迭代次数来微调 StyleGAN。加载感知损失的鉴别器并重置生成器。
    #Finetune StyleGAN
    #alpha controls the strength of the style
    alpha =  1.0 # min:0, max:1, step:0.1
    alpha = 1-alpha
    #preserve color of original image by limiting family of allowable transformations
    preserve_color = False 
    #Number of finetuning steps.
    num_iter = 300
    #Log training on wandb and interval for image logging
    use_wandb = False 
    log_interval = 50
    if use_wandb:
       wandb.init(project="JoJoGAN")
       config = wandb.config
       config.num_iter = num_iter
       config.preserve_color = preserve_color
       wandb.log(
       {"Style reference": [wandb.Image(transforms.ToPILImage()(target_im))]},
       step=0)
    # load discriminator for perceptual loss
    discriminator = Discriminator(1024, 2).eval().to(device)
    ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
    discriminator.load_state_dict(ckpt["d"], strict=False)
    # reset generator
    del generator
    generator = deepcopy(original_generator)
    g_optim = optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))
    训练生成器从潜在空间生成图像,并优化损失。
    if preserve_color:
       id_swap = [9,11,15,16,17]
    z = range(numiter)
    for idx in tqdm( z):
       mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim]).to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)
       
    in_latent = latents.clone()
       in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha*mean_w[:, id_swap]
       img = generator(in_latent, input_is_latent=True)
       with torch.no_grad():
           real_feat = discriminator(targets)
        
       fake_feat = discriminator(img)
       loss = sum([functional.l1_loss(a, b) for a, b in zip(fake_feat, real_feat)])/len(fake_feat)  
        
       if use_wandb:
           wandb.log({"loss": loss}, step=idx)
           if idx % log_interval == 0:
               generator.eval()
               my_sample = generator(my_w, input_is_latent=True)
               generator.train()
               wandb.log(
               {"Current stylization": [wandb.Image(my_sample)]},
               step=idx)
       g_optim.zero_grad()
       loss.backward()
       g_optim.step()
    使用 JojoGAN 生成结果
    现在生成结果。下面已经为原始图像和示例图像生成了结果以进行比较。
    #Generate resultsn_sample =  5
    seed = 3000
    torch.manual_seed(seed)
    with torch.no_grad():
       generator.eval()
       z = torch.randn(n_sample, latent_dim, device=device)
       original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
       sample = generator([z], truncation=0.7, truncation_latent=mean_latent)
       original_my_sample = original_generator(my_w, input_is_latent=True)
       my_sample = generator(my_w, input_is_latent=True)
    # display reference images
    style_images = []
    for name in names:
       style_path = f'style_images_aligned/{strip_path_extension(name)}.png'
       style_image = transform(Image.open(style_path))
       style_images.append(style_image)
    face = transform(aligned_face).to(device).unsqueeze(0)
    style_images = torch.stack(style_images, 0).to(device)
    my_output = torch.cat([face, my_sample], 0)
    output = torch.cat([original_sample, sample], 0)
    
    
    生成的结果
    现在,你可以使用 JojoGAN 生成你自己风格的图像。结果令人印象深刻,但可以通过调整训练方法和训练图像中的更多特征来进一步改进。
    结论
    JojoGAN 能够以快速有效的方式准确地映射和迁移用户定义的样式。关键要点是:
    · JojoGAN 可以只用一种风格进行训练,以轻松映射并创建任何面部的风格化图
    · JojoGAN 非常快速有效,可以在不到一分钟的时间内完成训练
    · 结果非常准确,类似于逼真的肖像
    · JojoGAN 可以轻松微调和修改,使其适用于 AI 应用程序
    因此,无论风格类型、形状和颜色如何,JojoGAN 都是用于风格转移的理想神经网络,因此可以成为各种社交媒体应用程序和 AI 应用程序中非常有用的功能。