Alzheimer_UI / flask_app.py
ak0601's picture
Update flask_app.py
513d71c verified
raw
history blame
2.64 kB
import json
import math
import os

import requests
from flask import Flask, render_template, request
# Flask application object; the route handler below is registered on it.
app = Flask(__name__)
# Endpoint of the remote prediction service; index() POSTs the feature vector here as JSON.
API_URL = "https://ak0601-et-alzheimer.hf.space/predict"
# Names of the form fields / model inputs, in the order their values are
# placed into the feature vector sent to the prediction API. The position of
# a name here determines the position of its value in that vector, so the
# order must be preserved.
FEATURE_NAMES = [
    "ROI",
    "nFixations",
    "nTobiiSaccades", "regSaccades", "longSaccades", "tinySaccades",
    "saccadeTotLength",
    "totalFixTime", "totalSpokenTime",
    "speechDelay", "endSpeechDelay",
    "startPupL", "startPupR",
    "endPupL", "endPupR",
    "diffPupL", "diffPupR",
]
# Seconds to wait for the prediction API before giving up. Without a timeout
# ``requests.post`` can block the request handler indefinitely.
_API_TIMEOUT_SECONDS = 30

# Human-readable labels for the class indices returned by the API; indices
# not listed here are shown as "Class <index>".
_CLASS_LABELS = {
    0: "Control/Healthy",
    1: "MCI (Mild Cognitive Impairment)",
    2: "Alzheimer's Disease",
}


def _parse_features(raw_values):
    """Convert submitted form values into an ordered list of floats.

    Args:
        raw_values: Mapping of feature name to the raw submitted string
            (``None`` when the field was absent from the form).

    Returns:
        One float per entry of ``FEATURE_NAMES``, in that order.

    Raises:
        ValueError: If any value is missing, blank, non-numeric, or not a
            finite number (``nan``/``inf`` are not valid JSON numbers).
    """
    features = []
    for name in FEATURE_NAMES:
        raw = raw_values.get(name)
        if raw is None:
            # float(None) would raise TypeError; report it as bad input instead.
            raise ValueError(f"Missing value for {name}")
        number = float(raw)  # ValueError for blank or non-numeric text
        if not math.isfinite(number):
            raise ValueError(f"Non-finite value for {name}")
        features.append(number)
    return features


def _build_prediction(result):
    """Shape the API's JSON object into the dict passed to the template.

    Args:
        result: Decoded JSON object returned by the prediction API.

    Returns:
        Dict with ``class_index``, ``label``, ``confidence`` (percentage
        string, or ``"N/A"`` when the API value is not numeric) and
        ``probabilities``.
    """
    class_idx = result.get("predicted_class")
    try:
        confidence_text = f"{float(result.get('confidence', 0)):.2%}"
    except (TypeError, ValueError):
        # A null or non-numeric confidence should not discard the prediction.
        confidence_text = "N/A"
    return {
        "class_index": class_idx,
        "label": _CLASS_LABELS.get(class_idx, f"Class {class_idx}"),
        "confidence": confidence_text,
        "probabilities": result.get("probabilities"),
    }


def _request_prediction(features):
    """Send the feature vector to the prediction API.

    Args:
        features: List of floats ordered like ``FEATURE_NAMES``.

    Returns:
        A ``(prediction, error_message)`` tuple; exactly one of the two is
        ``None``.
    """
    try:
        response = requests.post(
            API_URL,
            json={"features": features},
            timeout=_API_TIMEOUT_SECONDS,
        )
    except requests.exceptions.Timeout:
        return None, "Connection Error: The prediction API did not respond in time."
    except requests.exceptions.ConnectionError:
        return None, "Connection Error: Could not connect to the prediction API."
    except requests.exceptions.RequestException as exc:
        return None, f"An error occurred: {str(exc)}"

    if response.status_code != 200:
        return None, f"API Error: {response.status_code} - {response.text}"

    try:
        result = response.json()
    except ValueError:
        # JSON decode errors subclass ValueError; keep them distinct from
        # the user-input validation message.
        return None, "API Error: The prediction API returned a response that is not valid JSON."
    if not isinstance(result, dict):
        return None, "API Error: Unexpected response format from the prediction API."

    return _build_prediction(result), None


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the input form and, on submission, the prediction result.

    On a POST that contains ``predict_btn``, the submitted feature values are
    validated, sent to the prediction API, and the outcome (or an error
    message) is passed to the ``index.html`` template together with the raw
    inputs so the form can be re-populated.

    Returns:
        The rendered ``index.html`` page.
    """
    prediction_result = None
    input_values = {}
    error_message = None

    if request.method == 'POST' and 'predict_btn' in request.form:
        # Capture every submitted value first so the whole form can be
        # re-populated even when one of the fields fails validation.
        input_values = {name: request.form.get(name) for name in FEATURE_NAMES}
        try:
            features = _parse_features(input_values)
        except ValueError:
            error_message = "Invalid input: Please ensure all fields contain numeric values."
        else:
            try:
                prediction_result, error_message = _request_prediction(features)
            except Exception as e:
                # Top-level boundary: show the failure on the page instead of
                # returning a 500, but keep the traceback in the log.
                app.logger.exception("Unexpected error while requesting a prediction")
                error_message = f"An error occurred: {str(e)}"

    return render_template('index.html', feature_names=FEATURE_NAMES, result=prediction_result, inputs=input_values, error=error_message)
if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute arbitrary code. It is therefore opt-in
    # through the FLASK_DEBUG environment variable instead of always on.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    # Serve on port 7860 rather than Flask's default 5000.
    app.run(debug=debug_enabled, port=7860)