from dataclasses import dataclass
import math
from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_upscaler import UpscalerConfig


# -------------------------
# Architecture aligned with the training script
# -------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a LeakyReLU in between and a skip connection.

    The channel count is preserved (both convolutions map ``channels`` to
    ``channels`` with ``padding=1``), so the block output can be added to
    its input.
    """

    def __init__(self, channels: int, negative_slope: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        """Return ``x + conv2(act(conv1(x)))``."""
        # The in-place activation only touches the fresh conv1 output, so
        # ``x`` is still intact for the residual sum below.
        y = self.act(self.conv1(x))
        y = self.conv2(y)
        return x + y


class RestorationNet(nn.Module):
    """Residual network that predicts a correction added to its input.

    Pipeline: 3x3 conv (``in_channels`` -> ``width``) + LeakyReLU, then
    ``num_blocks`` :class:`ResidualBlock` modules, then a 3x3 conv back to
    ``in_channels``. The result is added to the original input, so the
    output has the same number of channels as the input.
    """

    def __init__(
        self,
        in_channels=3,
        width=32,
        num_blocks=3,
        negative_slope: float = 0.1,
    ):
        super().__init__()
        self.in_conv = nn.Conv2d(in_channels, width, 3, padding=1)
        self.in_act = nn.LeakyReLU(negative_slope=negative_slope, inplace=True)
        self.blocks = nn.Sequential(
            *[
                ResidualBlock(width, negative_slope=negative_slope)
                for _ in range(num_blocks)
            ]
        )
        self.out_conv = nn.Conv2d(width, in_channels, 3, padding=1)

    def forward(self, lr):
        """Return ``lr`` plus the residual predicted by the network."""
        y = self.in_act(self.in_conv(lr))
        y = self.blocks(y)
        y = self.out_conv(y)
        return lr + y


class ESPCNExactUpsampler(nn.Module):
    """ESPCN-style upsampler: Tanh feature extractor followed by PixelShuffle.

    ``feature_maps`` is a 5x5 conv (``num_channels`` -> ``channels``) and a
    3x3 conv (``channels`` -> ``channels // 2``), each followed by ``Tanh``.
    ``sub_pixel`` is a 3x3 conv producing ``num_channels * scale_factor**2``
    channels followed by ``nn.PixelShuffle(scale_factor)``.
    """

    def __init__(
        self,
        scale_factor: int = 4,
        num_channels: int = 3,
        channels: int = 64,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        hidden_channels = channels // 2

        self.feature_maps = nn.Sequential(
            nn.Conv2d(num_channels, channels, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
            nn.Conv2d(channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )
        self.sub_pixel = nn.Sequential(
            nn.Conv2d(
                hidden_channels,
                num_channels * (scale_factor ** 2),
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.PixelShuffle(scale_factor),
        )

        self._initialize_weights(num_channels)

    def _initialize_weights(self, num_channels: int):
        """Initialise every ``Conv2d`` weight and zero every bias.

        The convolution feeding ``PixelShuffle`` gets ``N(0, 0.001)``
        weights; every other convolution gets ``N(0, sqrt(2 / fan))`` with
        ``fan = out_channels * kernel_h * kernel_w``.

        Args:
            num_channels: Kept for call compatibility; no longer needed now
                that the final convolution is identified by identity.
        """
        # Identify the sub-pixel convolution by identity rather than by its
        # output width: a feature conv whose out_channels happens to equal
        # num_channels * scale_factor**2 (e.g. channels=24, scale_factor=2,
        # num_channels=3) must not receive the near-zero init.
        last_conv = self.sub_pixel[0]
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if module is last_conv:
                nn.init.normal_(module.weight.data, mean=0.0, std=0.001)
            else:
                fan = (
                    module.out_channels
                    * module.kernel_size[0]
                    * module.kernel_size[1]
                )
                nn.init.normal_(
                    module.weight.data, mean=0.0, std=math.sqrt(2.0 / fan)
                )
            if module.bias is not None:
                nn.init.zeros_(module.bias.data)

    def forward(self, x):
        """Run the feature extractor, then the sub-pixel conv + PixelShuffle."""
        x = self.feature_maps(x)
        x = self.sub_pixel(x)
        return x


@dataclass
class UpscalerOutput(ModelOutput):
    """Output container returned by :meth:`UpscalerModel.forward`.

    Fields:
        sr: Output of the upsampler stage.
        restored: Output of the restoration stage (the tensor that was fed
            to the upsampler), or ``None`` if not provided.
    """

    sr: torch.FloatTensor
    restored: Optional[torch.FloatTensor] = None


class UpscalerModel(PreTrainedModel):
    """Hugging Face wrapper chaining ``RestorationNet`` and the ESPCN upsampler."""

    config_class = UpscalerConfig
    main_input_name = "pixel_values"

    def __init__(self, config: UpscalerConfig):
        super().__init__(config)

        # IMPORTANT:
        # the sub-modules are exposed directly as attributes so that the
        # state-dict keys match the Lightning checkpoint:
        # restoration.* / upsampler.*
        self.scale = config.scale

        self.restoration = RestorationNet(
            in_channels=config.in_channels,
            width=config.width,
            num_blocks=config.num_blocks,
        )
        self.upsampler = ESPCNExactUpsampler(
            scale_factor=config.scale,
            num_channels=config.in_channels,
            channels=config.espcn_channels,
        )

        # NOTE(review): depending on the installed transformers version,
        # post_init() may apply a default `_init_weights` to the sub-modules;
        # confirm the custom ESPCN initialisation survives when the model is
        # built from scratch rather than loaded from a checkpoint.
        self.post_init()

    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
        """Restore ``pixel_values`` and then upsample the restored tensor.

        Args:
            pixel_values: Input tensor passed to the restoration network.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            An :class:`UpscalerOutput` holding the upsampled tensor (``sr``)
            and the intermediate restored tensor (``restored``).
        """
        restored = self.restoration(pixel_values)
        sr = self.upsampler(restored)
        return UpscalerOutput(sr=sr, restored=restored)