| from transformers import PretrainedConfig |
|
|
class DiffAEConfig(PretrainedConfig):
    """Hugging Face configuration object for the ``DiffAE`` model.

    Each constructor argument below is stored unchanged as an attribute of
    the same name. Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the per-argument descriptions are inferred from names and
    defaults only; the code that consumes them is not visible here. TODO
    confirm against the model/training code.

    Args:
        is3D: Presumably selects a volumetric (3D) network instead of 2D.
        in_channels: Number of input channels (default 1).
        out_channels: Number of output channels (default 1).
        latent_dim: Presumably the size of the latent embedding (default 128).
        net_ch: Presumably the base channel width of the network (default 32).
        sample_every_batches: Presumably the interval, in batches, between
            sample generations (default 1000).
        sample_size: Presumably how many samples are generated each time
            (default 4).
        test_with_TEval: Boolean flag consumed elsewhere; its meaning is
            not visible here.
        ampmode: Mixed-precision mode string. "16-mixed" looks like a
            PyTorch Lightning ``precision`` value; confirm.
        grey2RGB: Greyscale-to-RGB option. -1 presumably means "disabled";
            confirm.
        test_emb_only: Boolean flag consumed elsewhere; presumably limits
            testing to embeddings.
        test_ema: Boolean flag; presumably evaluates with EMA weights.
        batch_size: Batch size (default 9).
        data_name: Dataset identifier (default "ukbb").
        diffusion_type: Diffusion backbone identifier (default "beatgans").
        lr: Learning rate (default 1e-4).
        seed: Random seed (default 1701).
        input_shape: Input size, default ``(50, 128, 128)``. The meaning of
            each axis is not visible here. A JSON save/load round trip will
            presumably return it as a list rather than a tuple.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "DiffAE"

    def __init__(
        self,
        is3D=True,
        in_channels=1,
        out_channels=1,
        latent_dim=128,
        net_ch=32,
        sample_every_batches=1000,
        sample_size=4,
        test_with_TEval=True,
        ampmode="16-mixed",
        grey2RGB=-1,
        test_emb_only=True,
        test_ema=True,
        batch_size=9,
        data_name="ukbb",
        diffusion_type="beatgans",
        lr=0.0001,
        seed=1701,
        input_shape=(50, 128, 128),
        **kwargs,
    ):
        # Tuple-unpacking assigns left to right, so attributes are set in the
        # same order as the signature lists them.

        # Network / channel settings.
        self.is3D, self.in_channels = is3D, in_channels
        self.out_channels, self.latent_dim = out_channels, latent_dim
        self.net_ch = net_ch

        # Periodic sampling settings.
        self.sample_every_batches, self.sample_size = (
            sample_every_batches,
            sample_size,
        )

        # Evaluation and precision toggles.
        self.test_with_TEval, self.ampmode = test_with_TEval, ampmode
        self.grey2RGB = grey2RGB
        self.test_emb_only, self.test_ema = test_emb_only, test_ema

        # Data / optimisation settings.
        self.batch_size, self.data_name = batch_size, data_name
        self.diffusion_type, self.lr = diffusion_type, lr
        self.seed, self.input_shape = seed, input_shape

        # The base initializer runs last and receives only the extra kwargs,
        # matching the original call order.
        super().__init__(**kwargs)