"""Small ResNet that scores 512x512 grayscale handwriting on [0, 1]. Inputs follow MLX's NHWC convention: `[batch, height, width, channels]`. The forward pass returns raw logits; apply `mx.sigmoid` (or use BCE with `with_logits=True`) for a probability. """ from __future__ import annotations import mlx.core as mx import mlx.nn as nn class BasicBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int, stride: int = 1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm(out_channels) if stride != 1 or in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False), nn.BatchNorm(out_channels), ) else: self.shortcut = nn.Identity() def __call__(self, x: mx.array) -> mx.array: residual = self.shortcut(x) out = nn.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) return nn.relu(out + residual) class QualityClassifier(nn.Module): def __init__(self): super().__init__() self.stem_conv = nn.Conv2d(1, 4, 7, stride=2, padding=3, bias=False) self.stem_bn = nn.BatchNorm(4) self.stem_pool = nn.MaxPool2d(3, stride=2, padding=1) self.stage1 = nn.Sequential(BasicBlock(4, 8, stride=2), BasicBlock(8, 8)) self.stage2 = nn.Sequential(BasicBlock(8, 16, stride=2), BasicBlock(16, 16)) self.stage3 = nn.Sequential(BasicBlock(16, 32, stride=2), BasicBlock(32, 32)) self.stage4 = nn.Sequential(BasicBlock(32, 32, stride=2), BasicBlock(32, 32)) self.head = nn.Linear(32, 1) def __call__(self, x: mx.array) -> mx.array: x = nn.relu(self.stem_bn(self.stem_conv(x))) x = self.stem_pool(x) x = self.stage1(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = x.mean(axis=(1, 2)) return self.head(x).squeeze(-1) def score(self, x: mx.array) -> mx.array: return mx.sigmoid(self(x))