| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from sentence_transformers import CrossEncoder |
| import fugashi |
| import numpy as np |
|
|
# ASGI application object; the route handlers in this file register on it.
app = FastAPI()

# Both objects are created once at import time and shared by every request.
# Cross-encoder NLI model used to compare the submitted word with the answer.
nli_model = CrossEncoder("akiFQC/bert-base-japanese-v3_nli-jsnli-jnli-jsick")
# Morphological analyzer; the code in this file reads `feature.pos1` and
# `surface` from each token it yields.
tagger = fugashi.Tagger()

# Part-of-speech groups whose members may be compared with one another.
# Two different POS values are compatible when one group contains both.
# UniDic tags na-adjective stems as "形状詞" (it has no "形容動詞" tag), so that
# tag is listed next to the legacy "形容動詞" entry; without it na-adjectives
# could never be matched against nouns or i-adjectives.
ALLOWED_POS_GROUPS = [
    {"名詞", "形容詞", "形容動詞", "形状詞"},
    {"動詞", "名詞"},
    {"副詞", "名詞"},
]

# Prefix surfaces that are skipped when picking a word's main part of speech.
HONORIFIC_PREFIXES = {"お", "ご", "御"}

# An input must contain at least one token with one of these POS tags to be
# accepted. "形状詞" is included so that na-adjective answers (e.g. 静か) are
# not rejected; "形容動詞" is kept for backward compatibility.
VALID_POS = {"名詞", "動詞", "形容詞", "形容動詞", "形状詞", "副詞", "接頭辞"}
|
|
class StringInput(BaseModel):
    # Request body for POST /check.
    # (Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.)

    # The word submitted by the user.
    user_input: str
    # The reference answer the submission is compared against.
    correct_word: str
|
|
def is_valid_japanese(word: str) -> tuple[bool, str]:
    """Check whether *word* is acceptable as an answer.

    The checks run in this order and the first failure wins:

    1. the input, ignoring surrounding whitespace, must not be empty;
    2. it must be at most 10 characters long (again ignoring surrounding
       whitespace);
    3. the tokenizer must produce at least one token;
    4. at least one token must have a part of speech listed in ``VALID_POS``.

    Returns:
        ``(True, "OK")`` on success, otherwise ``(False, message)`` where
        *message* is the user-facing reason for the rejection.
    """
    trimmed = word.strip()
    if trimmed == "":
        return False, "入力が空です"
    if len(trimmed) > 10:
        return False, "10文字以内で入力してください"

    # The tokenizer is given the string exactly as received (not trimmed).
    tokens = list(tagger(word))
    if not tokens:
        return False, "日本語として認識できません"

    # One token with an accepted part of speech is enough.
    for token in tokens:
        if token.feature.pos1 in VALID_POS:
            return True, "OK"
    return False, "単語として成立していません"
|
|
def get_word_info(word: str) -> dict:
    """Tokenize *word* and summarize its part of speech and tense.

    Returns:
        A dict with three keys:

        * ``"pos"`` - ``pos1`` of the first token that is not an honorific
          prefix (a "接頭辞" token whose surface is in ``HONORIFIC_PREFIXES``),
          or ``"不明"`` when no such token exists;
        * ``"tense"`` - one of ``"過去"``, ``"進行・状態"``, ``"現在"`` or
          ``"その他"``, decided from the last token only;
        * ``"surfaces"`` - the surface strings of all tokens, in order.
    """
    tokens = list(tagger(word))

    # Main part of speech: first token that is not an honorific prefix.
    pos = next(
        (
            tok.feature.pos1
            for tok in tokens
            if not (
                tok.feature.pos1 == "接頭辞"
                and tok.surface in HONORIFIC_PREFIXES
            )
        ),
        "不明",
    )

    # Only the final token takes part in the tense decision.
    if tokens:
        final_surface = tokens[-1].surface
        final_pos = tokens[-1].feature.pos1
    else:
        final_surface = ""
        final_pos = ""

    # (required pos1, accepted surfaces, label); the first matching rule wins.
    # NOTE(review): every rule compares the *whole* surface of the last token,
    # so the "現在" rule only fires when that token is literally る/す/く/ぬ.
    # A dictionary-form verb such as 食べる is presumably emitted as a single
    # token and would fall through to "その他" - confirm this is intended.
    tense_rules = (
        ("助動詞", ("た", "だ"), "過去"),
        ("動詞", ("いる", "いた", "います"), "進行・状態"),
        ("動詞", ("る", "す", "く", "ぬ"), "現在"),
    )
    tense = next(
        (
            label
            for required_pos, accepted, label in tense_rules
            if final_pos == required_pos and final_surface in accepted
        ),
        "その他",
    )

    return {
        "pos": pos,
        "tense": tense,
        "surfaces": [tok.surface for tok in tokens],
    }
|
|
def pos_compatible(pos1: str, pos2: str) -> bool:
    """Return whether two part-of-speech tags may be compared.

    Identical tags are always compatible. Different tags are compatible only
    when a single group in ``ALLOWED_POS_GROUPS`` contains both of them.
    """
    if pos1 == pos2:
        return True
    wanted = {pos1, pos2}
    return any(wanted.issubset(group) for group in ALLOWED_POS_GROUPS)
|
|
def check_answer(user_input: str, correct_word: str) -> dict:
    """Judge whether *user_input* means the same as *correct_word*.

    The checks run in order and the first one that fails decides the result:

    1. *user_input* must pass ``is_valid_japanese``; otherwise the result is
       ``"invalid"`` with the validation message as the reason.
    2. The main parts of speech of both words must be compatible
       (``pos_compatible``).
    3. When both words are verbs, their detected tenses must be equal.
    4. The NLI model scores the pair in both directions:
       entailment both ways -> ``"正解"`` (score 1.0);
       entailment one way only -> ``"不正解"`` (score 0.5);
       otherwise -> ``"不正解"`` (score 0.0) with the label predicted for the
       ``user_input -> correct_word`` direction as the reason.

    Args:
        user_input: The word submitted by the user.
        correct_word: The reference answer. It is not validated here.

    Returns:
        A dict with the keys ``"result"``, ``"reason"`` and ``"score"``.
    """
    valid, reason = is_valid_japanese(user_input)
    if not valid:
        return {"result": "invalid", "reason": reason, "score": 0.0}

    user_info = get_word_info(user_input)
    correct_info = get_word_info(correct_word)

    if not pos_compatible(user_info["pos"], correct_info["pos"]):
        return {
            "result": "不正解",
            "reason": f"品詞が違います({user_info['pos']} vs {correct_info['pos']})",
            "score": 0.0,
        }

    # Tense is compared only for verb/verb pairs.
    both_verbs = user_info["pos"] == "動詞" and correct_info["pos"] == "動詞"
    if both_verbs and user_info["tense"] != correct_info["tense"]:
        return {
            "result": "不正解",
            "reason": f"時制が違います({user_info['tense']} vs {correct_info['tense']})",
            "score": 0.0,
        }

    # Index i of a score row is interpreted as labels[i].
    # NOTE(review): assumes the model outputs entailment / neutral /
    # contradiction in this order - confirm against the model card.
    labels = ["含意", "中立", "矛盾"]

    # Score both directions in one batched call instead of two separate
    # predict() calls; row 0 is user -> correct, row 1 is correct -> user.
    scores = nli_model.predict(
        [(user_input, correct_word), (correct_word, user_input)]
    )
    res_ab = int(np.argmax(scores[0]))
    res_ba = int(np.argmax(scores[1]))

    if res_ab == 0 and res_ba == 0:
        return {"result": "正解", "reason": "同義語", "score": 1.0}
    if res_ab == 0 or res_ba == 0:
        return {"result": "不正解", "reason": "包含関係(上位/下位概念)", "score": 0.5}
    return {"result": "不正解", "reason": labels[res_ab], "score": 0.0}
|
|
@app.get("/")
def root():
    # Liveness probe: always answers with a fixed status payload.
    # (Comment instead of docstring so the generated API docs are unchanged.)
    payload = {"status": "ok"}
    return payload
|
|
@app.post("/check")
def check(data: StringInput):
    # Judge the submitted word against the reference answer and return the
    # verdict dict produced by check_answer. The request and the verdict are
    # both echoed to stdout.
    print(f"受信: user_input='{data.user_input}' / correct_word='{data.correct_word}'")
    result = check_answer(
        user_input=data.user_input,
        correct_word=data.correct_word,
    )
    print(f"結果: {result}")
    return result