gamaly's picture
Update app.py
3c76e95 verified
Raw
History Blame
10.2 kB
"""Gradio app for Maritime Intelligence Classifier."""
import gradio as gr
from setfit import SetFitModel
from pathlib import Path
import os
# Try to load model from Hugging Face Hub first, then fall back to local
# Set MODEL_PATH environment variable or update this line with your Hugging Face repo ID
MODEL_PATH = os.getenv("MODEL_PATH", "gamaly/maritime-intelligence-classifier")
LOCAL_MODEL_PATH = "./maritime_classifier"


def _model_source(path):
    """Return a short label describing where ``path`` will be loaded from.

    A path that exists on disk is loaded locally; anything else is passed to
    ``SetFitModel.from_pretrained`` as a Hugging Face Hub repo ID.
    """
    return "local path" if Path(path).exists() else "Hugging Face Hub"


def _load_model():
    """Load the SetFit classifier, falling back from MODEL_PATH to LOCAL_MODEL_PATH.

    MODEL_PATH (a Hub repo ID or a local directory) is always tried first, so an
    explicitly configured path wins. If that load fails and LOCAL_MODEL_PATH
    exists on disk, it is tried next.

    Returns:
        The loaded ``SetFitModel``, or ``None`` if every candidate failed.
    """
    import traceback

    candidates = [MODEL_PATH]
    # Only fall back to the bundled copy when it is actually on disk; a
    # non-existent path would otherwise be handed to the Hub as a repo ID.
    if LOCAL_MODEL_PATH != MODEL_PATH and Path(LOCAL_MODEL_PATH).exists():
        candidates.append(LOCAL_MODEL_PATH)

    for path in candidates:
        source = _model_source(path)
        print(f"Loading from {source}: {path}")
        try:
            loaded = SetFitModel.from_pretrained(path)
        except Exception as e:  # report, then try the next candidate
            print(f"❌ Error loading model from {path}: {e}")
            print("\nFull traceback:")
            traceback.print_exc()
            continue
        print(f"✓ Successfully loaded model from {source}: {path}")
        return loaded

    print(" Attempted paths:")
    print(f" - Hugging Face: {MODEL_PATH}")
    print(f" - Local: {LOCAL_MODEL_PATH}")
    return None


# Load model
print("Loading model...")
print(f"MODEL_PATH: {MODEL_PATH}")
print(f"LOCAL_MODEL_PATH: {LOCAL_MODEL_PATH}")
model = _load_model()
if model is None:
    print("\n⚠️ WARNING: Model failed to load. The app will not work correctly.")
    print(" Please check:")
    print(f" 1. Model exists at: https://huggingface.co/{MODEL_PATH}")
    print(" 2. Internet connection is available")
    print(" 3. All dependencies are installed (setfit, sentence-transformers, etc.)")
else:
    print("\n✅ Model loaded successfully! Ready for inference.")
def truncate_text(text, max_tokens=256):
    """
    Shorten ``text`` so it fits in roughly ``max_tokens`` model tokens.

    Word count is used as a stand-in for token count (about 0.75 words per
    token), so the text is cut to ``int(max_tokens * 0.75)`` whitespace-separated
    words.

    Returns the input unchanged when it is falsy or already within budget.
    Otherwise, returns the leading words re-joined with single spaces, followed
    by the ``"... [truncated]"`` marker.
    """
    if not text:
        return text
    word_budget = int(max_tokens * 0.75)
    pieces = text.split()
    if len(pieces) > word_budget:
        head = " ".join(pieces[:word_budget])
        return head + "... [truncated]"
    # Already short enough: hand back the original string, whitespace intact.
    return text
def predict_text(text):
    """Predict whether text is actionable (YES) or not (NO).

    Args:
        text: Article text entered by the user.

    Returns:
        A ``(label, confidence, status)`` tuple:

        - ``label`` is the display label, or a user-facing message when no
          prediction was made.
        - ``confidence`` is a plain ``float`` percentage; it is 0.0 when no
          prediction was made.
        - ``status`` is ``"actionable"``, ``"not_actionable"``, ``"neutral"``
          (empty input) or ``"error"``.
    """
    if model is None:
        return "Error: Model not loaded. Please check the console logs.", 0.0, "error"
    if not text or not text.strip():
        return "Please enter some text to classify.", 0.0, "neutral"

    # Per the original notes, SetFit uses the base model's max_length
    # (256 tokens for all-MiniLM-L6-v2).
    max_tokens = 256
    truncate_above = 300  # buffer above max_tokens before we bother truncating
    truncation_marker = "... [truncated]"  # suffix truncate_text appends

    try:
        # SetFit truncates over-long inputs itself, but pre-truncating lets us
        # choose which part is kept: the beginning of a news article usually
        # carries the key facts. Rough estimate: 1 token ≈ 0.75 words.
        token_estimate = int(len(text.split()) / 0.75)
        if token_estimate > truncate_above:
            processed_text = truncate_text(text, max_tokens=max_tokens)
            # The marker is meant for human readers. Strip it so it is not
            # classified as if it were part of the article.
            if processed_text.endswith(truncation_marker):
                processed_text = processed_text[: -len(truncation_marker)]
            print(f"⚠️ Text truncated from ~{token_estimate} tokens to ~{max_tokens} tokens")
        else:
            processed_text = text

        # Make prediction
        prediction = model.predict([processed_text])[0]

        # Get probabilities (handle version compatibility)
        try:
            probabilities = model.predict_proba([processed_text])[0]
            # Indexing can yield a tensor / NumPy scalar; convert so callers
            # always receive a plain float.
            confidence = float(probabilities[prediction]) * 100
        except AttributeError as e:
            # Fallback if predict_proba fails due to version mismatch
            print(f"Warning: predict_proba failed ({e}), using fallback confidence")
            # NOTE(review): placeholder value, not a real probability.
            confidence = 85.0

        is_actionable = bool(prediction == 1)
        label = "YES (Actionable)" if is_actionable else "NO (Not Actionable)"
        # Status drives the explanation text / styling in the UI.
        status = "actionable" if is_actionable else "not_actionable"
        return label, confidence, status
    except Exception as e:
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg, 0.0, "error"
def get_explanation(status):
    """Map a prediction status to the explanation shown under the result.

    ``"neutral"`` and any unrecognised status yield an empty string.
    """
    messages = dict(
        actionable="✓ This text contains actionable vessel-specific evidence (e.g., specific vessel names, crimes, incidents).",
        not_actionable="✗ This text does not contain actionable vessel-specific evidence (e.g., general maritime news, non-specific information).",
        error="⚠️ An error occurred. Please check the model is properly loaded.",
        neutral="",
    )
    return messages.get(status, "")
# Create Gradio interface
# Note: theme parameter moved to launch() in Gradio 6.0+
with gr.Blocks(title="Maritime Intelligence Classifier") as app:
    # Blank lines between Markdown blocks are required. Without the one after
    # the bullet list, the "Non-actionable" sentence would render as a lazy
    # continuation of the last list item.
    gr.Markdown(
        """
# 🚢 Maritime Intelligence Classifier

Classify maritime news articles as containing **actionable vessel-specific evidence** (YES) or not (NO).

**Actionable articles** typically include:

- Specific vessel names
- Specific crimes or incidents
- Evidence that can be used for investigation

**Non-actionable articles** are general maritime news without specific vessel details.
"""
    )

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Article Text",
                placeholder="Paste or type the maritime news article text here...",
                lines=10,
                max_lines=20,
            )
            submit_btn = gr.Button("Classify", variant="primary", size="lg")
        with gr.Column(scale=1):
            prediction_output = gr.Label(
                label="Prediction",
                value={"YES (Actionable)": 0.0, "NO (Not Actionable)": 0.0},
            )
            confidence_output = gr.Number(
                label="Confidence",
                value=0.0,
                precision=1,
            )
            explanation_output = gr.Markdown()

    # Example texts
    gr.Markdown("### 📝 Example Texts")
    with gr.Row():
        example_yes = gr.Examples(
            examples=[
                ["The fishing vessel Marine 707 was involved in the disappearance of fisheries observer Samuel Abayateye in Ghanaian waters. The observer's decapitated body was found weeks later."],
                ["Authorities detained the Meng Xin 15 after discovering evidence of illegal saiko transshipment and threats against fisheries observers."],
            ],
            inputs=text_input,
            label="YES Examples (Actionable)",
        )
        example_no = gr.Examples(
            examples=[
                ["A new maritime museum opened in the port city, showcasing historical ships and ocean exploration artifacts."],
                ["Marine scientists are studying the effects of ocean acidification on coral reefs in tropical waters."],
            ],
            inputs=text_input,
            label="NO Examples (Not Actionable)",
        )

    # Connect the prediction function
    def update_prediction(text):
        """Run the classifier and map its result onto the three output components.

        Returns:
            ``(label_dict, confidence, explanation)``, feeding, in order:

            - ``prediction_output``: class name -> probability.
            - ``confidence_output``: percentage.
            - ``explanation_output``: Markdown text.
        """
        label, confidence, status = predict_text(text)
        # Normalise to a plain float before doing arithmetic and handing it to
        # the UI; the model may produce a tensor / NumPy scalar.
        confidence = float(confidence)
        yes_key, no_key = "YES (Actionable)", "NO (Not Actionable)"
        # Create label dict for gradio Label component
        if status == "actionable":
            label_dict = {yes_key: confidence / 100, no_key: (100 - confidence) / 100}
        elif status == "not_actionable":
            label_dict = {yes_key: (100 - confidence) / 100, no_key: confidence / 100}
        else:
            label_dict = {yes_key: 0.0, no_key: 0.0}
        explanation = get_explanation(status)
        if status in ("error", "neutral"):
            # Here `label` is a user-facing message (e.g. "Please enter some
            # text to classify."). Show it instead of silently dropping it.
            explanation = f"{explanation}\n\n{label}".strip()
        return label_dict, confidence, explanation

    submit_btn.click(
        fn=update_prediction,
        inputs=text_input,
        outputs=[prediction_output, confidence_output, explanation_output],
    )
    text_input.submit(
        fn=update_prediction,
        inputs=text_input,
        outputs=[prediction_output, confidence_output, explanation_output],
    )

    gr.Markdown(
        """
---

### ℹ️ About

This classifier uses SetFit to identify maritime news articles containing actionable vessel-specific evidence.

Built for The Outlaw Ocean Project.

**Model**: SetFit (sentence-transformers/all-MiniLM-L6-v2 base)
"""
    )
if __name__ == "__main__":
    # Script entry point: serve the UI locally without a public share link.
    # The theme is passed at launch time (Gradio 6.0+), not to gr.Blocks().
    soft_theme = gr.themes.Soft()
    app.launch(theme=soft_theme, share=False)