treadon's picture
Upload nucleus_image/scheduler.py with huggingface_hub
a8441ec verified
Raw
History Blame
883 Bytes
"""Flow Matching Euler Discrete Scheduler for Nucleus-Image."""
import mlx.core as mx
class FlowMatchEulerScheduler:
def __init__(self, shift: float = 1.0, num_train_timesteps: int = 1000):
self.shift = shift
self.num_train_timesteps = num_train_timesteps
self.sigmas = None
self.timesteps = None
def set_timesteps(self, num_inference_steps: int):
sigmas = mx.linspace(1.0, 0.0, num_inference_steps + 1)
if self.shift != 1.0:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas[:-1] * self.num_train_timesteps
def step(self, model_output, timestep_idx: int, sample):
sigma = self.sigmas[timestep_idx]
sigma_next = self.sigmas[timestep_idx + 1]
dt = sigma_next - sigma
return sample + dt * model_output