| import torch, random |
| from torch import nn |
| from einops import rearrange |
|
|
| from stldm.submodules import * |
|
|
class Down_Block(nn.Module):
    """Encoder stage of the spatio-temporal U-Net.

    Pipeline: concat(input, condition) -> ResNet block -> spatial attention ->
    ResNet block -> temporal attention -> downsample / channel conversion.

    ``in_ch`` is the channel count *after* the input and the condition have
    been concatenated. Spatial attention is ``Quadratic_SpatialAttention``
    when ``patch_size`` is ``None`` and ``Linear_SpatialAttention`` otherwise.
    The closing layer is ``Downsample2D`` unless ``is_last`` is true, in which
    case ``ChannelConversion`` is used instead.
    """

    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super().__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)

        if patch_size is None:
            spatial_attention = Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        else:
            spatial_attention = Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)
        self.attn_spatial = Residual(PreNorm(hid_ch, spatial_attention))

        # Unlike block1, block2 is created without a time-embedding dimension.
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)

        temporal_attention = TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        self.attn_temporal = Residual(PreNorm(hid_ch, temporal_attention))

        if is_last:
            self.last = ChannelConversion(hid_ch, out_ch)
        else:
            self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch)

    def forward(self, x, time_emb, cond=None, relative_pos=None):
        """Run the stage on a 5-D tensor unpacked as (B, T, C, H, W).

        Args:
            x: input of rank 5.
            time_emb: diffusion-step embedding; it is repeated T times per
                sample so every frame receives its sample's embedding.
            cond: optional conditioning tensor concatenated with the
                flattened frames along dim 1; zeros are used when ``None``.
            relative_pos: accepted but not used by this block.

        Returns:
            Tuple ``(out, spatial_attn, temporal_attn)``. ``out`` is reshaped
            back to (B, T, c, h, w); the two attention outputs are returned
            in the flattened per-frame layout and serve as skip connections.
        """
        assert x.ndim==5
        batch, frames, channels, height, width = x.shape
        flat_len = batch * frames

        flat_x = x.reshape(flat_len, channels, height, width)
        if cond is None:
            # No conditioning: feed zeros so block1 still sees in_ch channels.
            cond = torch.zeros_like(flat_x)

        # One copy of the step embedding per frame, ordered like flat_x.
        frame_emb = time_emb.unsqueeze(1).repeat(1, frames, 1).reshape(flat_len, -1)

        hidden = self.block1(torch.cat((flat_x, cond), dim=1), frame_emb)

        spatial_attn = self.attn_spatial(hidden)
        hidden = self.block2(spatial_attn, frame_emb)
        c, h, w = hidden.shape[-3:]

        # Temporal attention consumes the un-flattened (B, T, c, h, w) layout.
        temporal_attn = self.attn_temporal(hidden.reshape(batch, frames, c, h, w))
        temporal_attn = temporal_attn.reshape(flat_len, c, h, w)

        reduced = self.last(temporal_attn)
        c, h, w = reduced.shape[-3:]
        return reduced.reshape(batch, frames, c, h, w), spatial_attn, temporal_attn
|
|
class MidBlock(nn.Module):
    """Bottleneck of the U-Net; channel count is unchanged throughout.

    Pipeline: ResNet block -> quadratic spatial attention -> ResNet block ->
    temporal attention -> ResNet block. All three ResNet blocks receive the
    time embedding.
    """

    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super().__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

        spatial_attention = Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)
        self.qattn_spatial = Residual(PreNorm(in_ch, spatial_attention))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

        temporal_attention = TemporalAttention(dim=in_ch, heads=heads, dim_head=dim_head)
        self.qattn_time = Residual(PreNorm(in_ch, temporal_attention))
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """Process a rank-5 tensor unpacked as (B, T, C, H, W).

        Args:
            x: input of rank 5.
            time_emb: diffusion-step embedding, repeated T times per sample.
            relative_pos: accepted but not used by this block.

        Returns:
            Tensor reshaped to (B, T, c, h, w), where c, h, w are taken from
            the output of the last ResNet block.
        """
        assert x.ndim==5
        batch, frames, channels, height, width = x.shape
        flat_shape = (batch * frames, channels, height, width)

        # One copy of the step embedding per frame, ordered like the flat batch.
        frame_emb = time_emb.unsqueeze(1).repeat(1, frames, 1).reshape(batch * frames, -1)

        hidden = self.block1(x.reshape(flat_shape), frame_emb)
        hidden = self.qattn_spatial(hidden)
        hidden = self.block2(hidden, frame_emb)

        # Temporal attention works on the un-flattened layout.
        hidden = hidden.reshape((batch, frames, channels, height, width))
        hidden = self.qattn_time(hidden).reshape(flat_shape)
        hidden = self.block3(hidden, frame_emb)

        c, h, w = hidden.shape[-3:]
        return hidden.reshape(batch, frames, c, h, w)
|
|
class Up_Block(nn.Module):
    """Decoder stage of the spatio-temporal U-Net.

    Pipeline: upsample / channel conversion -> temporal attention ->
    concat(temporal skip) -> ResNet block -> spatial attention ->
    concat(spatial skip) -> ResNet block.

    ``in_chs`` is a pair ``(in_ch, skip_ch)``: the channel count of the
    incoming tensor and of each skip tensor. The first layer is
    ``Upsample2D`` unless ``is_last`` is true, in which case it is a
    ``ChannelConversion``.
    """

    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super().__init__()
        in_ch, skip_ch = in_chs

        if is_last:
            self.up = ChannelConversion(in_ch, hid_ch)
        else:
            self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch)

        if patch_size is None:
            spatial_attention = Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        else:
            spatial_attention = Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)
        self.attn_spatial = Residual(PreNorm(hid_ch, spatial_attention))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)

        temporal_attention = TemporalAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        self.attn_temporal = Residual(PreNorm(hid_ch, temporal_attention))
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, spatialattn_skip, tempattn_skip, relative_pos=None):
        """Run the stage on a rank-5 tensor unpacked as (B, T, C, H, W).

        Args:
            x: input of rank 5.
            time_emb: diffusion-step embedding, repeated T times per sample.
            spatialattn_skip: skip tensor concatenated (dim 1) before block2.
            tempattn_skip: skip tensor concatenated (dim 1) before block1.
            relative_pos: accepted but not used by this block.

        Returns:
            Tensor reshaped to (B, T, c, h, w) from the output of block2.
        """
        assert x.ndim==5
        batch, frames, channels, height, width = x.shape
        flat_len = batch * frames

        # One copy of the step embedding per frame, ordered like the flat batch.
        frame_emb = time_emb.unsqueeze(1).repeat(1, frames, 1).reshape(flat_len, -1)

        hidden = self.up(x.reshape(flat_len, channels, height, width))
        c, h, w = hidden.shape[-3:]

        # Temporal attention first, on the un-flattened layout.
        hidden = self.attn_temporal(hidden.reshape(-1, frames, c, h, w))
        hidden = hidden.reshape(flat_len, c, h, w)

        hidden = self.block1(torch.cat((hidden, tempattn_skip), dim=1), frame_emb)
        hidden = self.attn_spatial(hidden)
        hidden = self.block2(torch.cat((hidden, spatialattn_skip), dim=1), frame_emb)

        c, h, w = hidden.shape[-3:]
        return hidden.reshape(batch, frames, c, h, w)
|
|
class LDM(nn.Module):
    """Spatio-temporal U-Net used as the denoiser of the latent diffusion model.

    The network has ``len(chs_mult)`` encoder stages (``Down_Block``), a
    ``MidBlock`` bottleneck and the same number of decoder stages
    (``Up_Block``). Level ``n`` of the encoder maps ``chs[n]`` to
    ``chs[n + 1]`` channels, where ``chs = [in_ch, base_ch * m for m in
    chs_mult]``. A parallel list of ``Downsample2D`` modules
    (``self.conditions``) brings the conditioning tensor to the resolution
    and width expected by the next encoder level.

    Args:
        in_ch: channel count of the input latent (and of the condition).
        chs_mult: per-level multipliers applied to ``base_ch``.
        patch_size: if given, linear spatial attention is used with a patch
            size halved at every level; ``None`` selects quadratic attention.
        num_groups, heads, dim_head: forwarded to every block.
        base_ch: base channel width; the time embedding has ``4 * base_ch``
            features.
    """

    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super().__init__()

        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        layer_no = len(chs_mult)
        chs = [in_ch, *(base_ch*m for m in chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = list(reversed(ch_out)), list(reversed(ch_in))

        # Patch size per encoder level (halved each level); None -> quadratic attention.
        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)]

        ups, downs, conditions = [], [], []
        for n in range(layer_no):
            # Encoder level n sees input and condition concatenated, hence 2*ch_in[n].
            downs.append(
                Down_Block(
                    in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim,
                    patch_size=None if patches is None else patches[n],
                    is_last=(n==layer_no-1),
                    num_groups=num_groups, heads=heads, dim_head=dim_head,
                )
            )
            # Decoder level n mirrors encoder level layer_no-n-1 (its skip source).
            ups.append(
                Up_Block(
                    in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim,
                    patch_size=None if patches is None else patches[layer_no-n-1],
                    is_last=(n==0),
                    num_groups=num_groups, heads=heads, dim_head=dim_head,
                )
            )
            # The original guard here (`n != -1`) was always true, so one
            # condition downsampler exists per level, including the last one
            # whose output is never consumed. It is deliberately kept so the
            # parameter layout (state_dict keys) stays compatible with
            # existing checkpoints.
            conditions.append(
                Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
            )

        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)

    def forward(self, x, time, conds=None):
        """Denoise ``x`` at diffusion step ``time``.

        Args:
            x: rank-5 input passed to the first ``Down_Block``.
            time: diffusion steps, embedded by ``self.time_mlp``.
            conds: optional conditioning tensor; ``None`` makes every
                ``Down_Block`` substitute zeros (unconditional pass).

        Returns:
            Output of the last ``Up_Block``.
        """
        t = self.time_mlp(time)

        hid_spatial = []
        hid_temporal = []

        last_level = len(self.downs) - 1
        for n, down_block in enumerate(self.downs):
            x, spatial_attn, time_attn = down_block(x, t, conds)
            hid_spatial.append(spatial_attn)
            hid_temporal.append(time_attn)
            # Prepare the condition for the next level. After the last level
            # nothing reads it any more, so the downsample is skipped there.
            if conds is not None and n != last_level:
                conds = self.conditions[n](conds)

        out = self.mid(x, t)

        # Decoder consumes the skips in reverse (deepest first).
        for up_block in self.ups:
            out = up_block(out, t, hid_spatial.pop(), hid_temporal.pop())

        return out
|
|
| |
| from collections import namedtuple |
| from torch.cuda.amp import autocast |
| import torch.nn.functional as F |
| from einops import reduce |
| from tqdm.auto import tqdm |
|
|
# Pair returned by GaussianDiffusion.model_predictions: the predicted noise
# and the predicted clean sample (x_0).
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')
|
|
def identity(t, *args, **kwargs):
    """Return ``t`` untouched; any extra positional/keyword arguments are ignored."""
    del args, kwargs  # accepted only so this can stand in for other callables
    return t
|
|
def extract(a, t, x_shape):
    """Look up ``a`` at indices ``t`` and shape the result for broadcasting.

    Args:
        a: tensor gathered along its last dimension.
        t: index tensor; its first dimension is the batch size ``b``.
        x_shape: shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        Tensor of shape ``(b, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape((batch,) + singleton_dims)
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    If the fallback ``d`` is callable it is invoked (with no arguments) and
    its result is returned, which allows lazily computed defaults.
    """
    if val is None:
        return d() if callable(d) else d
    return val
|
|
def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)
|
|
def guidance_scheduler(sampling_step: int, const: float):
    """Build a constant guidance schedule.

    Returns a 1-D tensor of length ``sampling_step`` in which every entry is
    ``const``.
    """
    schedule = torch.ones(sampling_step)
    return const * schedule
|
|
| from huggingface_hub import PyTorchModelHubMixin |
|
|
class GaussianDiffusion(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/sqfoo/stldm_official",
    pipeline_tag="Precipitation_Nowcasting",
    license="mit"):
    """Gaussian diffusion in the latent space of a SimVPV2 backbone.

    The module owns two sub-networks built from plain config dicts:
    ``self.backbone`` (``SimVPV2_Model``, which also exposes a ``vae``) and
    ``self.diff_unet`` (``LDM``), the denoiser. All diffusion-schedule
    constants are registered as float32 buffers.

    Classifier-free guidance is off until ``setup_guidance`` is called with a
    per-timestep weight tensor.
    """

    def __init__(
        self,
        vp_param: dict,
        stldm_param: dict,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = None,
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0.,
        min_snr_loss_weight = False,
        min_snr_gamma = 5
    ):
        """Build the sub-networks and precompute the diffusion schedule.

        Args:
            vp_param: keyword arguments for ``SimVPV2_Model``.
            stldm_param: keyword arguments for ``LDM``.
            timesteps: number of diffusion steps requested from the schedule.
            sampling_timesteps: steps used at sampling time; defaults to
                ``timesteps``. A smaller value switches sampling to DDIM.
            objective: one of ``'pred_noise'``, ``'pred_x0'``, ``'pred_v'``.
            beta_schedule: ``'linear'``, ``'cosine'`` or ``'sigmoid'``.
            schedule_fn_kwargs: extra keyword arguments for the schedule
                function; ``None`` (the default) means none.
            ddim_sampling_eta: eta of the DDIM sampler.
            offset_noise_strength: default offset-noise strength for training.
            min_snr_loss_weight: clamp the SNR used in the loss weights.
            min_snr_gamma: clamp value used when ``min_snr_loss_weight``.

        Raises:
            AssertionError: on an unknown objective or when
                ``sampling_timesteps`` exceeds the schedule length.
            ValueError: on an unknown ``beta_schedule``.
        """
        super().__init__()

        self.backbone = SimVPV2_Model(**vp_param)
        self.diff_unet = LDM(**stldm_param)

        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        # `None` replaces the former mutable `dict()` default.
        betas = beta_schedule_fn(timesteps, **(schedule_fn_kwargs or {}))

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        # The schedule function decides the final number of steps.
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # Sampling configuration: fewer sampling steps than training steps -> DDIM.
        self.sampling_timesteps = default(sampling_timesteps, timesteps)

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # Guidance is disabled until setup_guidance() is called. Defining the
        # attribute here keeps model_predictions() from raising
        # AttributeError when sampling without prior guidance setup.
        self.CFG_sch = None

        def register_buffer(name, val):
            # All schedule constants are stored as float32 buffers.
            self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # Constants for the forward process q(x_t | x_0) and its inversions.
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # Constants for the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        register_buffer('posterior_variance', posterior_variance)

        # Clamped before the log because the first posterior variance is 0.
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        self.offset_noise_strength = offset_noise_strength

        # Per-timestep loss weights derived from the signal-to-noise ratio.
        snr = alphas_cumprod / (1 - alphas_cumprod)

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

    @property
    def device(self):
        """Device of the registered schedule buffers."""
        return self.betas.device

    def setup_guidance(self, scheduler):
        """Install (or clear, with ``None``) the classifier-free guidance weights.

        ``scheduler`` is indexed by the integer timestep in
        ``model_predictions``; it is moved to the model's current device.
        """
        self.CFG_sch = scheduler.to(self.device) if exists(scheduler) else None

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and the noise that produced it."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from x_t and x_0 (inverse of ``predict_start_from_noise``)."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """Compute the v-parameterization target from x_0 and noise."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """Recover x_0 from x_t and a v prediction."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoiser and convert its output to (noise, x_0).

        With guidance weights installed, the denoiser is evaluated both
        without and with the condition and the two outputs are combined as
        ``cond_out + w * (cond_out - uncond_out)``, where ``w`` is the weight
        at the timestep of the first batch element.

        Args:
            x: noisy latent x_t.
            t: batched integer timesteps.
            cond: conditioning tensor (or ``None``).
            clip_x_start: clamp the predicted x_0 to [-1, 1].
            rederive_pred_noise: for the ``pred_noise`` objective, recompute
                the noise from the clipped x_0.

        Returns:
            ``ModelPrediction(pred_noise, pred_x_start)``.
        """
        # getattr keeps instances created without __init__ working as "no guidance".
        guidance = getattr(self, 'CFG_sch', None)
        if exists(guidance):
            uncond = self.diff_unet(x, t, conds=None)
            model_output = self.diff_unet(x, t, conds=cond)
            time = int(t[0])
            model_output = model_output - guidance[time] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)

        def maybe_clip(value):
            # Local replacement for `partial(torch.clamp, ...)`: `partial`
            # is not imported by this module.
            return torch.clamp(value, min = -1., max = 1.) if clip_x_start else value

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Return the reverse-step mean/variance/log-variance and the predicted x_0."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False,)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """Take one ancestral sampling step from timestep ``t``.

        Returns the sample for the previous step and the predicted x_0. No
        noise is added at ``t == 0``.
        """
        batched_times = torch.full((x.shape[0],), t, device = self.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0.
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full ancestral sampling starting from Gaussian noise of ``shape``.

        With ``return_all_timesteps`` every intermediate sample (including
        the initial noise) is stacked along a new dim 1.
        """
        frames_pred = torch.randn(shape, device = self.device)
        imgs = [frames_pred]

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """DDIM sampling over ``self.sampling_timesteps`` steps.

        With ``return_all_timesteps`` every intermediate sample (including
        the initial noise) is stacked along a new dim 1.
        """
        batch, device = shape[0], self.device
        total_timesteps, sampling_timesteps = self.num_timesteps, self.sampling_timesteps
        eta = self.ddim_sampling_eta

        # Evenly spaced timesteps from T-1 down to -1; -1 marks the final step.
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond,
                clip_x_start = False,
                rederive_pred_noise = True
            )

            if time_next < 0:
                # Last step: the predicted x_0 is the final sample.
                frames_pred = x_start
                imgs.append(frames_pred)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(frames_pred)

            frames_pred = x_start * alpha_next.sqrt() + \
                          c * pred_noise + \
                          sigma * noise

            imgs.append(frames_pred)

        ret = frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)
        return ret

    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Generate future frames from the rank-5 input ``frames_in``.

        The backbone is run once; its second output is used as the
        conditioning tensor and also fixes the latent shape to sample. The
        sampled latents are decoded by ``self.backbone.vae.decode``.

        Returns:
            ``(frames_pred, backbone_output)``. ``frames_pred`` keeps every
            leading dimension of the sampled latents, so it has one extra
            dimension (the sampling steps, at dim 1) when
            ``return_all_timesteps`` is true.
        """
        assert frames_in.ndim == 5
        B = frames_in.shape[0]

        backbone_output, conds, *_ = self.backbone(frames_in)
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )

        # Flatten every leading dim before decoding and restore them after.
        # For the usual (b, t, c, h, w) latents this matches the former
        # 'b t c h w -> (b t) c h w' rearrange; it also handles the extra
        # step dimension present when return_all_timesteps=True, which the
        # fixed 5-D pattern rejected.
        lead_dims = ldm_pred.shape[:-3]
        flat_latents = ldm_pred.reshape(-1, *ldm_pred.shape[-3:])
        frames_pred = self.backbone.vae.decode(flat_latents)
        frames_pred = frames_pred.reshape(*lead_dims, *frames_pred.shape[-3:])
        return frames_pred, backbone_output

    def predict(self, frames_in, compute_loss=False, **kwargs):
        """Sample predictions for ``frames_in``; ``compute_loss`` and extra kwargs are ignored."""
        pred, mu = self.sample(frames_in=frames_in)
        return pred, mu

    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Compute the four training losses.

        Args:
            frames_in: input frames, rank 5 with batch first.
            frames_gt: ground-truth future frames, same layout.
            validate: accepted for interface compatibility; it currently has
                no effect on the computation.

        Returns:
            ``(recon_loss, mu_loss, diff_loss, prior_loss)``.
        """
        B = frames_in.shape[0]

        # --- Diffusion loss ---
        backbone_output, conds = self.backbone(frames_in)
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        # Drop the condition for roughly 15% of the batches so the denoiser
        # also learns the unconditional prediction used by guidance.
        if random.random() > 0.85:
            conds = None
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        # --- Backbone loss ---
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        # --- VAE loss (reconstruction of all frames, input == target) ---
        all_frames = rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        ae_loss, kl_loss = self.backbone.vae._losses_(all_frames, all_frames)
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        # --- Prior loss at t = T (fully noised latent vs. standard normal) ---
        # NOTE(review): the ground truth is encoded a second time here; the
        # call is kept because it is not visible from this file whether
        # vae.encode is deterministic.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.ones((B,), device=self.device).long() * (self.num_timesteps - 1)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)

        return recon_loss, mu_loss, diff_loss, prior_loss

    def kl_from_standard_normal(self, mean, log_var):
        """Mean KL divergence between N(mean, exp(log_var)) and N(0, 1)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Sample x_t from q(x_t | x_0); fresh Gaussian noise is drawn when ``noise`` is None."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """Weighted MSE between the denoiser output and the objective's target.

        Args:
            x_start: clean latents; dim 0 is the batch.
            t: batched integer timesteps.
            noise: optional noise to use; drawn from N(0, 1) when ``None``.
                The tensor passed in is never modified.
            offset_noise_strength: overrides ``self.offset_noise_strength``.
            cond: conditioning tensor for the denoiser (or ``None``).

        Returns:
            Scalar loss.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            # One offset per leading index (everything except the last two,
            # spatial, dims), constant over space. The previous code drew a
            # (b, T) tensor and reshaped it with a 4-D pattern, which cannot
            # broadcast against 5-D latents. Out-of-place add so a
            # caller-supplied `noise` is left untouched.
            offset_noise = torch.randn(x_start.shape[:-2], device = x_start.device)
            noise = noise + offset_noise_strength * offset_noise[..., None, None]

        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        model_out = self.diff_unet(x, t, conds=cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            target = self.predict_v(x_start, t, noise)
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    @torch.no_grad()
    def forward(self, input_x, include_mu=False, **kwargs):
        """Inference entry point.

        Returns the sampled prediction, or ``(prediction, backbone_output)``
        when ``include_mu`` is true. Extra keyword arguments are ignored.
        """
        pred, mu = self.predict(input_x, compute_loss=False)
        if include_mu:
            return pred, mu
        return pred
|
|
| from stldm.modules import SimVPV2_Model, VAE |
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build a ``GaussianDiffusion`` model from a config dict.

    Args:
        model_config: mapping with the keys ``'vp_param'`` (backbone
            kwargs), ``'stldm_param'`` (LDM kwargs) and ``'param'``
            (remaining ``GaussianDiffusion`` kwargs).
        print_info: print a short description of the setup.
        cfg_str: constant classifier-free guidance weight; ``None`` disables
            guidance.

    Returns:
        The constructed model with guidance configured.
    """
    if print_info:
        print('Setup the model with considering temporal attention be (BHW, T, C) ... ...')
        print('Train it from end to end')

    # GaussianDiffusion builds its own backbone and LDM from the config
    # dicts. The previous call passed pre-built modules through `vp_model=`
    # / `diffusion=` keywords that its __init__ does not accept.
    model = GaussianDiffusion(
        vp_param=model_config['vp_param'],
        stldm_param=model_config['stldm_param'],
        **model_config['param'],
    )

    scheduler = None
    if cfg_str is not None:
        # One weight per diffusion step; taken from the model so a config
        # that relies on the default `timesteps` works too.
        scheduler = guidance_scheduler(sampling_step=model.num_timesteps, const=cfg_str)
    model.setup_guidance(scheduler)

    return model
|
|
def ae_setup(model_config):
    """Build a SimVPV2 backbone from ``model_config['vp_param']`` and return its ``vae``."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
|
|
def backbone_setup(model_config):
    """Build and return a SimVPV2 backbone from ``model_config['vp_param']``."""
    return SimVPV2_Model(**model_config['vp_param'])