| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from sentence_transformers import CrossEncoder |
| import fugashi |
| import numpy as np |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# CORS is fully open: any origin, any HTTP method and any request header
# is accepted.
# NOTE(review): presumably intended for a browser front-end hosted on a
# different origin — confirm before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Both heavy objects are created once at import time and reused by every
# request handled by this module.
# Cross-encoder whose scores are read as NLI class logits in check_answer().
nli_model = CrossEncoder("akiFQC/bert-base-japanese-v3_nli-jsnli-jnli-jsick")
# Morphological analyser: yields tokens exposing .surface and .feature.pos1.
tagger = fugashi.Tagger()
|
|
| |
# check_answer() accepts a pair as synonyms only when the mean of the two
# directional entailment probabilities is strictly greater than this value.
SIMILARITY_THRESHOLD = 0.85

# Part-of-speech tags that pos_compatible() treats as interchangeable:
# two different tags are compatible when a single set contains both.
ALLOWED_POS_GROUPS = [
    {"名詞"},
    {"形状詞", "形容動詞", "形容詞"},
    {"名詞", "動詞"},
    {"名詞", "副詞"},
]

# Prefix surfaces that get_word_info() skips when the token is tagged 接頭辞.
HONORIFIC_PREFIXES = {"御", "ご", "お"}

# is_valid_japanese() requires at least one token whose pos1 is in this set.
VALID_POS = {"名詞", "動詞", "形容詞", "形容動詞", "形状詞", "副詞", "接頭辞"}

# Hand-maintained synonym sets; check_answer() accepts the answer outright
# when the answer and the expected word both belong to the same set.
CUSTOM_SYNONYMS = [
    {"スマートフォン", "スマホ"},
    {"自転車", "チャリ"},
    {"額", "おでこ"},
    {"最下位", "ビリ"},
    {"おむすび", "おにぎり"},
    {"とても", "めっちゃ"},
]
|
|
# Request body of POST /check.
# (Documented with comments rather than a docstring so that nothing new is
# added to the schema generated from this model.)
class StringInput(BaseModel):
    # Text to be graded; this is the string that gets validated.
    user_input: str
    # Reference word the text is graded against.
    correct_word: str
|
|
def is_valid_japanese(word: str, max_length: int = 10) -> tuple[bool, str]:
    """Check that ``word`` looks like a gradable Japanese word.

    The input is trimmed once, and the same trimmed text is used for the
    emptiness check, the length check and tokenisation.

    Args:
        word: Raw text to validate.
        max_length: Maximum number of characters allowed after trimming.
            Defaults to 10.

    Returns:
        ``(True, "OK")`` when the text passes every check, otherwise
        ``(False, message)`` where ``message`` is the user-facing reason.
    """
    text = word.strip()
    if not text:
        return False, "入力が空です"
    if len(text) > max_length:
        return False, f"{max_length}文字以内で入力してください"

    tokens = list(tagger(text))
    if not tokens:
        return False, "日本語として認識できません"

    # At least one token must carry a content-bearing part of speech.
    if not any(tok.feature.pos1 in VALID_POS for tok in tokens):
        return False, "単語として成立していません"
    return True, "OK"
|
|
def get_word_info(word: str) -> dict:
    """Tokenise ``word`` and summarise its part of speech and tense label.

    Returns:
        A dict with three keys:

        * ``"pos"`` - ``pos1`` of the *last* token that is neither an
          honorific prefix nor tagged 助詞 / 助動詞 / 補助記号; ``"不明"``
          when no such token exists.
        * ``"tense"`` - a label chosen from the surface of the final token
          only (``"過去"``, ``"進行・状態"``, ``"現在"`` or ``"その他"``).
        * ``"surfaces"`` - the surface strings of all tokens, in order.
    """
    tokens = list(tagger(word))

    # Function-word tags that never decide the word's part of speech.
    ignored_pos = ("助詞", "助動詞", "補助記号")
    content_pos = [
        tok.feature.pos1
        for tok in tokens
        if not (tok.feature.pos1 == "接頭辞" and tok.surface in HONORIFIC_PREFIXES)
        and tok.feature.pos1 not in ignored_pos
    ]
    # The last surviving token wins.
    pos = content_pos[-1] if content_pos else "不明"

    surfaces = [tok.surface for tok in tokens]
    final_surface = surfaces[-1] if surfaces else ""

    # Checked top to bottom; the first group containing the final surface
    # supplies the label.
    tense_table = (
        (("た", "だ"), "過去"),
        (("いる", "いた", "います"), "進行・状態"),
        (("る", "す", "く", "ぬ"), "現在"),
    )
    tense = next(
        (label for endings, label in tense_table if final_surface in endings),
        "その他",
    )

    return {"pos": pos, "tense": tense, "surfaces": surfaces}
|
|
def pos_compatible(pos1: str, pos2: str) -> bool:
    """Return True when two part-of-speech tags may be compared.

    Identical tags are always compatible. Different tags are compatible
    only when a single entry of ``ALLOWED_POS_GROUPS`` contains both.
    """
    return pos1 == pos2 or any(
        pos1 in members and pos2 in members for members in ALLOWED_POS_GROUPS
    )
|
|
| |
def softmax(x, axis=None):
    """Return the numerically stable softmax of ``x``.

    Args:
        x: Array-like of scores (logits).
        axis: Axis along which to normalise. The default ``None``
            normalises over every element, which for a 1-D input is the
            ordinary softmax. Pass ``axis=-1`` to normalise each row of a
            batch of logits independently.

    Returns:
        ``numpy.ndarray`` with the same shape as ``x`` whose entries sum to
        1 along ``axis``. An empty input yields an empty array.
    """
    logits = np.asarray(x)
    if logits.size == 0:
        # np.max has no identity for an empty array; nothing to normalise.
        return np.exp(logits)

    # Subtracting the maximum keeps exp() from overflowing and does not
    # change the result.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)
|
|
def check_answer(user_input: str, correct_word: str) -> dict:
    """Grade ``user_input`` against ``correct_word``.

    The stages run in order and the first one that reaches a decision
    returns:

    1. ``is_valid_japanese`` on the answer; a failure gives result
       ``"invalid"``.
    2. Exact string match gives ``"正解"``.
    3. Both words in the same ``CUSTOM_SYNONYMS`` set gives ``"正解"``.
    4. Incompatible parts of speech give ``"不正解"``.
    5. Two verbs with different tense labels give ``"不正解"``.
    6. The NLI model is run in both directions. Entailment both ways with
       a mean entailment probability above ``SIMILARITY_THRESHOLD`` gives
       ``"正解"``; entailment in one direction only gives ``"不正解"`` with
       score 0.5; everything else gives ``"不正解"`` with score 0.0.

    Args:
        user_input: The answer to grade.
        correct_word: The expected word.

    Returns:
        A dict with keys ``"result"``, ``"reason"`` and ``"score"``
        (0.0, 0.5 or 1.0).
    """
    # Trim once and use the trimmed text in every stage, so the exact-match
    # test, the synonym table, the tokeniser and the NLI model all see the
    # same strings (e.g. " スマホ" must still hit the synonym table).
    answer = user_input.strip()
    expected = correct_word.strip()

    valid, reason = is_valid_japanese(answer)
    if not valid:
        return {"result": "invalid", "reason": reason, "score": 0.0}

    if answer == expected:
        return {"result": "正解", "reason": "完全一致", "score": 1.0}

    if any(answer in syn_set and expected in syn_set for syn_set in CUSTOM_SYNONYMS):
        return {"result": "正解", "reason": "同義語(辞書一致)", "score": 1.0}

    user_info = get_word_info(answer)
    correct_info = get_word_info(expected)

    if not pos_compatible(user_info["pos"], correct_info["pos"]):
        return {
            "result": "不正解",
            "reason": f"品詞が違います({user_info['pos']} vs {correct_info['pos']})",
            "score": 0.0,
        }

    # Tense is only compared when both sides are verbs.
    both_verbs = user_info["pos"] == "動詞" and correct_info["pos"] == "動詞"
    if both_verbs and user_info["tense"] != correct_info["tense"]:
        return {
            "result": "不正解",
            "reason": f"時制が違います({user_info['tense']} vs {correct_info['tense']})",
            "score": 0.0,
        }

    # Index i of the model output is reported as labels[i]; index 0 is
    # treated as entailment.
    # NOTE(review): assumes the model's class order is entailment, neutral,
    # contradiction — confirm against the model card.
    labels = ["含意", "中立", "矛盾"]
    entailment = 0

    # Score both directions in a single batched call instead of two
    # separate forward passes.
    scores_ab, scores_ba = nli_model.predict(
        [(answer, expected), (expected, answer)]
    )
    res_ab = int(np.argmax(scores_ab))
    res_ba = int(np.argmax(scores_ba))

    if res_ab == entailment and res_ba == entailment:
        # Mutual entailment: accept only when the model is confident enough.
        avg_prob = (
            softmax(scores_ab)[entailment] + softmax(scores_ba)[entailment]
        ) / 2
        if avg_prob <= SIMILARITY_THRESHOLD:
            return {"result": "不正解", "reason": "似ているが違う", "score": 0.0}
        return {"result": "正解", "reason": "同義語", "score": 1.0}

    if res_ab == entailment or res_ba == entailment:
        # Entailment in one direction only.
        return {"result": "不正解", "reason": "包含関係(上位/下位概念)", "score": 0.5}

    return {"result": "不正解", "reason": labels[res_ab], "score": 0.0}
|
|
# GET / : fixed liveness payload; it does not depend on the request.
@app.get("/")
def root():
    return dict(status="ok")
|
|
# POST /check : grade the submitted answer with check_answer() and return
# its result dict unchanged. The request and the verdict are both echoed to
# stdout.
@app.post("/check")
def check(data: StringInput):
    user_input, correct_word = data.user_input, data.correct_word
    print(f"受信: user_input='{user_input}' / correct_word='{correct_word}'")
    verdict = check_answer(user_input, correct_word)
    print(f"結果: {verdict}")
    return verdict