from torch import nn
import torch
from modules import devices, shared
import ldm.models.diffusion.ddpm


class VectorAdjustPrior(nn.Module):
    """Learned adjustment applied to a sequence of conditioning vectors.

    Every token vector is concatenated with the mean vector of its own
    sequence, projected down to ``inter_dim`` features, concatenated with the
    token again and projected back to ``hidden_size``.

    Args:
        hidden_size: Size of each token vector (last dimension of the input).
        inter_dim: Width of the intermediate projection.
    """

    def __init__(self, hidden_size, inter_dim=64):
        super().__init__()
        # Input is [sequence mean, token], hence hidden_size * 2.
        self.vector_proj = nn.Linear(hidden_size * 2, inter_dim, bias=True)
        # Input is [token, projected features].
        self.out_proj = nn.Linear(hidden_size + inter_dim, hidden_size, bias=True)

    def forward(self, z):
        """Return the adjusted vectors for ``z``.

        Args:
            z: Tensor of shape ``(batch, seq, hidden_size)``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        b, s = z.shape[0:2]
        tokens = z.reshape(b, s, -1)
        # Broadcast each sample's mean over *its own* sequence. Tiling the
        # (batch, hidden) means with ``repeat(s, 1)`` orders rows as
        # [m0, m1, ..., m0, m1, ...], which does not line up with the
        # sample-major token layout and mixes samples whenever batch > 1.
        pooled = tokens.mean(dim=1, keepdim=True).expand(b, s, -1)
        x = torch.cat((pooled, tokens), dim=-1)
        x = self.vector_proj(x)
        x = torch.cat((tokens, x), dim=-1)
        return self.out_proj(x)

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        """Build a model and load its weights from a checkpoint file.

        Args:
            model_path: Path to a checkpoint whose top-level object has a
                ``"state_dict"`` entry matching this module's parameters.
            hidden_size: Passed through to the constructor.
            inter_dim: Passed through to the constructor.

        Returns:
            The model with the checkpoint weights loaded. Its parameters are
            on the CPU; callers move it to the target device themselves.
        """
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        # map_location="cpu" lets checkpoints saved from a GPU load on hosts
        # without one; load_state_dict copies into the CPU parameters anyway.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])

        return model


# Load the prior once at import time from 'v2.pt' (a relative path, so it is
# resolved against the current working directory), then move it to the
# configured device.
vap = VectorAdjustPrior.load_model('v2.pt')
vap = vap.to(devices.device)


def get_learned_conditioning_with_prior(self, c):
    """Replacement for ``LatentDiffusion.get_learned_conditioning``.

    Calls the saved original method and, when the ``use_prior`` option is
    truthy, passes its result through the module-level ``vap`` model.
    """
    base = ldm.models.diffusion.ddpm.LatentDiffusion.get_learned_conditioning_original(self, c)

    if not shared.opts.use_prior:
        return base

    return vap(base)


# Install the wrapper on LatentDiffusion. The original method is stashed only
# the first time, so running this module again does not overwrite the saved
# original with the wrapper itself.
_LatentDiffusion = ldm.models.diffusion.ddpm.LatentDiffusion
if not hasattr(_LatentDiffusion, 'get_learned_conditioning_original'):
    _LatentDiffusion.get_learned_conditioning_original = _LatentDiffusion.get_learned_conditioning
_LatentDiffusion.get_learned_conditioning = get_learned_conditioning_with_prior

# Register the "use_prior" option (default True) under the "NAI" section.
shared.options_templates.update(
    shared.options_section(
        ('nai', "NAI"),
        {"use_prior": shared.OptionInfo(True, "use v2.pt")},
    )
)
# Edit
# Pub: 08 Oct 2022 21:59 UTC
# Edit: 09 Oct 2022 06:48 UTC
# Views: 12919