import torch
from torchvision import models, transforms
import gradio as gr

# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the class labels (same as provided)
# NOTE(review): predict() maps model output indices straight into this list,
# so the order must match the class-to-index mapping used during training.
# A few names use a single "_" separator where most use "___" (e.g.
# 'Cherry_(including_sour)_Powdery_mildew', 'Corn_(maize)_Common_rust');
# confirm these against the training dataset's class names.
class_labels = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)_Powdery_mildew',
    'Cherry_(including_sour)_healthy',
    'Corn_(maize)_Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)_Common_rust',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)_healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
    'blast',
    'blight',
    'tungro',
]

# Number of top predictions returned by predict() and shown in the UI.
TOP_K = 5

# The preprocessing pipeline is stateless, so build it once at import time
# rather than on every request.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the image to match ResNet input size
    transforms.ToTensor(),  # Convert the image to a tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),  # Normalize using ImageNet statistics
])


# Step 1: Load the model from the saved .pth file
def load_model(model_path):
    """Build a ResNet-101 sized for ``class_labels`` and load trained weights.

    Args:
        model_path: Path to a state dict saved with ``torch.save``.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    print("Device: ", device)
    # weights=None gives a randomly initialised network (replacing the
    # deprecated ``pretrained=False`` flag); the trained weights are loaded
    # from disk right after.
    model = models.resnet101(weights=None, num_classes=len(class_labels))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()  # Set model to evaluation mode
    return model


# Step 2: Define the image preprocessing function
def preprocess_image(image):
    """Turn a PIL image into a normalized ``(1, 3, 224, 224)`` batch tensor.

    The image is converted to RGB first so that RGBA (e.g. PNG with alpha)
    or grayscale uploads do not break the 3-channel normalization.
    """
    rgb_image = image.convert("RGB")
    return _TRANSFORM(rgb_image).unsqueeze(0)  # Add a batch dimension


# Step 3: Define the prediction function
def predict(image):
    """Classify ``image`` and return up to ``TOP_K`` labels with probabilities.

    Args:
        image: A PIL image, as delivered by ``gr.Image(type="pil")``; may be
            ``None`` when nothing was uploaded.

    Returns:
        A dict mapping class label to softmax probability (float), suitable
        for ``gr.Label``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    input_tensor = preprocess_image(image).to(device)

    # Run the model on the preprocessed image without tracking gradients.
    with torch.no_grad():
        output = model(input_tensor)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    # Clamp k so topk cannot fail if the model has fewer than TOP_K classes.
    k = min(TOP_K, probabilities.numel())
    top_prob, top_catid = torch.topk(probabilities, k)
    return {
        class_labels[idx]: prob
        for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
    }


# Load your trained model
model_path = 'model.pth'
model = load_model(model_path)

# Step 4: Define the Gradio interface
interface = gr.Interface(
    fn=predict,  # The prediction function
    inputs=gr.Image(type="pil", label="Upload Image"),  # Accept an image input
    outputs=gr.Label(
        num_top_classes=TOP_K,
        label=f"Top {TOP_K} Predicted Classes",
    ),  # Display top predictions
    title="Plant Disease Detection",
    description="Upload an image and the model will predict the plant disease.",
)

# Step 5: Launch the Gradio interface (only when run as a script)
if __name__ == "__main__":
    interface.launch()