"""FastAPI service that grades a short Japanese answer against an expected word.

The grading pipeline in ``check_answer`` is: input validation, exact match,
curated synonym lookup, part-of-speech compatibility, verb tense agreement,
and finally a bidirectional NLI (natural language inference) model check.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
import fugashi
import numpy as np

app = FastAPI()

# NOTE(review): CORS is fully open (any origin, method and header).
# Confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model and the tokenizer once, at import time, so that requests
# do not pay the loading cost.
nli_model = CrossEncoder("akiFQC/bert-base-japanese-v3_nli-jsnli-jnli-jsick")
tagger = fugashi.Tagger()

# --- Constants and dictionaries used by the grading logic ---

# Minimum mean entailment probability (both directions) for a "synonym" verdict.
SIMILARITY_THRESHOLD = 0.85

# Maximum accepted length, in characters, of the stripped user input.
MAX_INPUT_LENGTH = 10

# Two different parts of speech are compatible when one group contains both.
ALLOWED_POS_GROUPS = [
    {"名詞"},
    {"形容詞", "形容動詞", "形状詞"},
    {"動詞", "名詞"},
    {"副詞", "名詞"},
]

# Honorific prefixes that are skipped when picking a word's main part of speech.
HONORIFIC_PREFIXES = {"お", "ご", "御"}

# Parts of speech that make an input count as a word ("形状詞" included).
VALID_POS = {"名詞", "動詞", "形容詞", "形容動詞", "副詞", "接頭辞", "形状詞"}

# Parts of speech that never determine a word's main part of speech
# (particles, auxiliary verbs, supplementary symbols).
FUNCTION_POS = {"助詞", "助動詞", "補助記号"}

# Word pairs that are always accepted as correct, bypassing the model.
CUSTOM_SYNONYMS = [
    {"スマホ", "スマートフォン"},
    {"チャリ", "自転車"},
    {"おでこ", "額"},
    {"ビリ", "最下位"},
    {"おにぎり", "おむすび"},
    {"めっちゃ", "とても"},
]

# Tense label keyed by the surface form of the LAST token of a word.
TENSE_BY_LAST_SURFACE = {
    "た": "過去",
    "だ": "過去",
    "いる": "進行・状態",
    "いた": "進行・状態",
    "います": "進行・状態",
    "る": "現在",
    "す": "現在",
    "く": "現在",
    "ぬ": "現在",
}
DEFAULT_TENSE = "その他"

# Label names indexed by the model's output position.
# NOTE(review): this order (entailment, neutral, contradiction) is taken from
# the original code; confirm it against the model card.
NLI_LABELS = ["含意", "中立", "矛盾"]


class StringInput(BaseModel):
    """Request body of ``POST /check``."""

    # The answer typed by the user.
    user_input: str
    # The expected answer the input is graded against.
    correct_word: str


def is_valid_japanese(word: str) -> tuple[bool, str]:
    """Check that *word* is a short, non-empty input the tagger sees as a word.

    Args:
        word: Raw user input; surrounding whitespace is ignored.

    Returns:
        ``(True, "OK")`` when the input is acceptable, otherwise
        ``(False, reason)`` where ``reason`` is a user-facing message.
    """
    text = word.strip()
    if not text:
        return False, "入力が空です"
    if len(text) > MAX_INPUT_LENGTH:
        return False, f"{MAX_INPUT_LENGTH}文字以内で入力してください"

    tokens = list(tagger(text))
    if not tokens:
        return False, "日本語として認識できません"

    # At least one token must carry a content part of speech.
    if not any(token.feature.pos1 in VALID_POS for token in tokens):
        return False, "単語として成立していません"
    return True, "OK"


def get_word_info(word: str) -> dict:
    """Tokenize *word* and summarize its main part of speech and tense.

    Args:
        word: The word to analyze.

    Returns:
        A dict with keys ``"pos"`` (part of speech of the last token that is
        neither an honorific prefix nor a function word, or ``"不明"`` when no
        such token exists), ``"tense"`` (label derived from the last token's
        surface form) and ``"surfaces"`` (surface forms of all tokens).
    """
    tokens = list(tagger(word))

    pos = "不明"
    for token in tokens:
        token_pos = token.feature.pos1
        # Skip the honorific prefixes (お / ご / 御).
        if token_pos == "接頭辞" and token.surface in HONORIFIC_PREFIXES:
            continue
        # Function words are never the head of the word; do not overwrite.
        if token_pos in FUNCTION_POS:
            continue
        # Compound words: keep iterating so the LAST content token wins.
        pos = token_pos

    surfaces = [token.surface for token in tokens]
    last_surface = surfaces[-1] if surfaces else ""
    tense = TENSE_BY_LAST_SURFACE.get(last_surface, DEFAULT_TENSE)

    return {"pos": pos, "tense": tense, "surfaces": surfaces}


def pos_compatible(pos1: str, pos2: str) -> bool:
    """Return True when two parts of speech may be compared with each other.

    Identical parts of speech are always compatible; different ones are
    compatible only when a single group in ``ALLOWED_POS_GROUPS`` holds both.
    """
    if pos1 == pos2:
        return True
    return any(pos1 in group and pos2 in group for group in ALLOWED_POS_GROUPS)


def softmax(x):
    """Convert a 1-D sequence of logits into probabilities that sum to 1."""
    logits = np.asarray(x, dtype=float)
    # Subtracting the maximum keeps exp() from overflowing on large logits.
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()


def check_answer(user_input: str, correct_word: str) -> dict:
    """Grade *user_input* against *correct_word*.

    Both strings are stripped of surrounding whitespace once, up front, so
    that every later step (exact match, synonym lookup, tokenization and the
    model) sees the same text.

    Args:
        user_input: The answer typed by the user.
        correct_word: The expected answer.

    Returns:
        A dict with keys ``"result"`` (``"invalid"``, ``"正解"`` or
        ``"不正解"``), ``"reason"`` (user-facing explanation) and ``"score"``
        (``0.0``, ``0.5`` or ``1.0``).
    """
    user_input = user_input.strip()
    correct_word = correct_word.strip()

    # 0. Validation.
    valid, reason = is_valid_japanese(user_input)
    if not valid:
        return {"result": "invalid", "reason": reason, "score": 0.0}

    # Exact match bypasses every other check, including the model.
    if user_input == correct_word:
        return {"result": "正解", "reason": "完全一致", "score": 1.0}

    # 1. Curated synonym dictionary: always correct, bypasses the model.
    for syn_set in CUSTOM_SYNONYMS:
        if user_input in syn_set and correct_word in syn_set:
            return {"result": "正解", "reason": "同義語（辞書一致）", "score": 1.0}

    user_info = get_word_info(user_input)
    correct_info = get_word_info(correct_word)

    # 2. Part-of-speech check.
    if not pos_compatible(user_info["pos"], correct_info["pos"]):
        return {
            "result": "不正解",
            "reason": f"品詞が違います（{user_info['pos']} vs {correct_info['pos']}）",
            "score": 0.0,
        }

    # 3. Tense check, only when both words are verbs.
    both_verbs = user_info["pos"] == "動詞" and correct_info["pos"] == "動詞"
    if both_verbs and user_info["tense"] != correct_info["tense"]:
        return {
            "result": "不正解",
            "reason": f"時制が違います（{user_info['tense']} vs {correct_info['tense']}）",
            "score": 0.0,
        }

    # 4. NLI scores in both directions, obtained with ONE batched model call
    #    instead of two separate ones, then converted with softmax.
    scores_ab, scores_ba = nli_model.predict(
        [(user_input, correct_word), (correct_word, user_input)]
    )
    res_ab = int(np.argmax(scores_ab))
    res_ba = int(np.argmax(scores_ba))
    probs_ab = softmax(scores_ab)
    probs_ba = softmax(scores_ba)

    # 5. Final verdict. Index 0 is the first entry of NLI_LABELS.
    if res_ab == 0 and res_ba == 0:
        # Entailment both ways: a synonym only when the mean probability is
        # strictly above the threshold.
        avg_prob = (probs_ab[0] + probs_ba[0]) / 2
        if avg_prob <= SIMILARITY_THRESHOLD:
            return {"result": "不正解", "reason": "似ているが違う", "score": 0.0}
        return {"result": "正解", "reason": "同義語", "score": 1.0}

    if res_ab == 0 or res_ba == 0:
        # Entailment in one direction only: treated as a broader/narrower term.
        return {"result": "不正解", "reason": "包含関係（上位/下位概念）", "score": 0.5}

    return {"result": "不正解", "reason": NLI_LABELS[res_ab], "score": 0.0}


@app.get("/")
def root():
    """Health-check endpoint."""
    return {"status": "ok"}


@app.post("/check")
def check(data: StringInput):
    """Grade the submitted answer and return the verdict dict."""
    # Request/response logging is kept on purpose for API debugging.
    print(f"受信: user_input='{data.user_input}' / correct_word='{data.correct_word}'")
    result = check_answer(data.user_input, data.correct_word)
    print(f"結果: {result}")
    return result