File size: 883 Bytes
a8441ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Flow Matching Euler Discrete Scheduler for Nucleus-Image."""

import mlx.core as mx


class FlowMatchEulerScheduler:
    def __init__(self, shift: float = 1.0, num_train_timesteps: int = 1000):
        self.shift = shift
        self.num_train_timesteps = num_train_timesteps
        self.sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int):
        sigmas = mx.linspace(1.0, 0.0, num_inference_steps + 1)
        if self.shift != 1.0:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas[:-1] * self.num_train_timesteps

    def step(self, model_output, timestep_idx: int, sample):
        sigma = self.sigmas[timestep_idx]
        sigma_next = self.sigmas[timestep_idx + 1]
        dt = sigma_next - sigma
        return sample + dt * model_output