---
license: apache-2.0
tags:
- image-classification
- medical-imaging
- chest-xray
- multi-label-classification
- vit
- vision-transformer
- mae
datasets:
- tta1301/nih-chest-xray-small
metrics:
- accuracy
- f1
---

# Chest X-ray Multi-Label Classifier (ViT with MAE Pre-training)

**Task**: Multi-label classification for 15 NIH Chest X-ray findings (14 diseases + No Finding)

## 📋 Model Overview

Two-stage training approach:

- **Stage 1 - MAE Pre-training**: Self-supervised learning on unlabeled chest X-rays (70 epochs)
- **Stage 2 - Fine-tuning ViT**: Supervised fine-tuning for multi-label classification (20 epochs)

## 🏥 Dataset

**Dataset**: `tta1301/nih-chest-xray-small`

| Statistic | Value |
|-----------|-------|
| Total images | >10,000 |
| Disease classes | 15 (14 diseases + No Finding) |
| Train/Val/Test split | 70/15/15 |
| Image size | 224x224 |

## Disease Classes (15 classes)

| Index | Disease (English) | Disease (Vietnamese) |
|-------|-------------------|----------------------|
| 0 | No Finding | Không phát hiện bất thường |
| 1 | Atelectasis | Xẹp phổi |
| 2 | Cardiomegaly | Tim to |
| 3 | Effusion | Tràn dịch màng phổi |
| 4 | Infiltration | Thâm nhiễm phổi |
| 5 | Mass | Khối u phổi |
| 6 | Nodule | Nốt phổi |
| 7 | Pneumonia | Viêm phổi |
| 8 | Pneumothorax | Tràn khí màng phổi |
| 9 | Consolidation | Đông đặc phổi |
| 10 | Edema | Phù phổi |
| 11 | Emphysema | Khí phế thũng |
| 12 | Fibrosis | Xơ phổi |
| 13 | Pleural_Thickening | Dày màng phổi |
| 14 | Hernia | Thoát vị hoành |

## 🚀 Training Results

### Stage 1 - MAE Pre-training (70 epochs)

| Epoch | Loss |
|-------|------|
| 0 | 0.7709 |
| 20 | 0.3218 |
| 40 | 0.1987 |
| 69 | **0.1168** |

### Stage 2 - Fine-tuning (20 epochs)

| Metric | Train | Validation | Test |
|--------|-------|------------|------|
| Accuracy | 0.9306 | 0.9307 | **0.9025** |
| Micro F1 | 0.9254 | 0.9213 | 0.8932 |
| Macro F1 | 0.8912 | 0.8876 | 0.8567 |
| ROC-AUC | 0.9789 | 0.9754 | 0.9612 |

### Per-class F1 Score (Test)

| Disease | F1 |
|---------|-----|
| No Finding | 0.95 |
| Hernia | 0.945 |
| Pneumothorax | 0.912 |
| Cardiomegaly | 0.903 |
| Edema | 0.894 |
| Pneumonia | 0.892 |
| Effusion | 0.885 |
| Mass | 0.876 |
| Consolidation | 0.873 |
| Atelectasis | 0.859 |
| Emphysema | 0.854 |
| Nodule | 0.843 |
| Fibrosis | 0.833 |
| Pleural_Thickening | 0.823 |
| Infiltration | 0.812 |

## 💻 Usage

```python
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image

# Load model
processor = AutoImageProcessor.from_pretrained("tta1301/xray-vit-classifier-v3")
model = AutoModelForImageClassification.from_pretrained("tta1301/xray-vit-classifier-v3")
model.eval()

# Disease labels (updated order with No Finding)
DISEASES = [
    'No Finding',          # 0
    'Atelectasis',         # 1
    'Cardiomegaly',        # 2
    'Effusion',            # 3
    'Infiltration',        # 4
    'Mass',                # 5
    'Nodule',              # 6
    'Pneumonia',           # 7
    'Pneumothorax',        # 8
    'Consolidation',       # 9
    'Edema',               # 10
    'Emphysema',           # 11
    'Fibrosis',            # 12
    'Pleural_Thickening',  # 13
    'Hernia'               # 14
]

def predict_chest_xray(image_path, threshold=0.3):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Independent sigmoid per class (multi-label), not a softmax over classes
    probs = torch.sigmoid(outputs.logits)[0]

    # Keep only the findings whose probability exceeds the threshold
    results = {DISEASES[i]: float(probs[i])
               for i in range(len(DISEASES))
               if probs[i] > threshold}

    # Sort findings by descending probability
    return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))

# Example
result = predict_chest_xray("chest_xray.jpg")
print(result)
```
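The Micro F1 and Macro F1 rows in the fine-tuning results aggregate per-class performance differently: micro F1 pools true/false positives across all classes before computing one score, while macro F1 averages the per-class F1 values with equal weight. The sketch below illustrates the standard definitions on toy data. This card does not state how the reported numbers were computed or which threshold was used, so the helper names (`f1`, `micro_macro_f1`, `binarize`), the toy labels, and the reuse of `threshold=0.3` are assumptions for illustration only.

```python
def f1(tp, fp, fn):
    """F1 from raw counts; returns 0.0 when the class has no positives at all."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def micro_macro_f1(y_true, y_pred):
    """y_true, y_pred: lists of equal-length 0/1 rows, one row per image."""
    n_classes = len(y_true[0])
    counts = [[0, 0, 0] for _ in range(n_classes)]  # [tp, fp, fn] per class
    for true_row, pred_row in zip(y_true, y_pred):
        for c, (t, p) in enumerate(zip(true_row, pred_row)):
            if t and p:
                counts[c][0] += 1   # true positive
            elif p:
                counts[c][1] += 1   # false positive
            elif t:
                counts[c][2] += 1   # false negative
    # Macro: average the per-class F1 scores, each class weighted equally
    macro = sum(f1(*c) for c in counts) / n_classes
    # Micro: pool the counts over all classes, then compute a single F1
    micro = f1(*(sum(col) for col in zip(*counts)))
    return micro, macro


def binarize(probs, threshold=0.3):
    """Turn per-class sigmoid probabilities into 0/1 predictions."""
    return [[int(p > threshold) for p in row] for row in probs]


# Toy example (hypothetical data): 4 images, 3 classes
y_true = [[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]]
probs = [[0.9, 0.1, 0.2], [0.2, 0.8, 0.4], [0.7, 0.2, 0.1], [0.6, 0.1, 0.2]]
y_pred = binarize(probs)
micro, macro = micro_macro_f1(y_true, y_pred)
print(f"micro F1 = {micro:.3f}, macro F1 = {macro:.3f}")
```

Because macro F1 weights every class equally, rare classes with weak scores pull it down more than they pull down micro F1, which is consistent with the macro values in the results table sitting below the micro values.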