wordVec / app.py
wwefih's picture
Update app.py
b268ba6 verified
Raw
History Blame
4.07 kB
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
import fugashi
import numpy as np
app = FastAPI()
# Load the model (only once, at startup). The cross-encoder scores
# (premise, hypothesis) pairs; check_answer reads the scores as
# entailment / neutral / contradiction.
nli_model = CrossEncoder("akiFQC/bert-base-japanese-v3_nli-jsnli-jnli-jsick")
# Morphological tagger shared by the helpers below to split a word into tokens
# and read each token's top-level part of speech (feature.pos1).
tagger = fugashi.Tagger()
# Groups of top-level part-of-speech tags that are considered interchangeable
# when an answer is compared with the expected word: two tags are compatible
# if they are equal or both appear in one of these groups.
#
# "形状詞" is the tag UniDic (the dictionary fugashi.Tagger is built around)
# assigns to na-adjective stems such as 静か; the school-grammar name
# "形容動詞" is kept alongside it for backward compatibility.
ALLOWED_POS_GROUPS = [
    {"名詞", "形容詞", "形容動詞", "形状詞"},
    {"動詞", "名詞"},
    {"副詞", "名詞"},
]

# Prefix surfaces that are skipped when looking for a word's main part of speech.
HONORIFIC_PREFIXES = {"お", "ご", "御"}

# An input must contain at least one token with one of these tags to be
# accepted as a word.
VALID_POS = {"名詞", "動詞", "形容詞", "形容動詞", "形状詞", "副詞", "接頭辞"}
# Request body accepted by the POST /check endpoint.
class StringInput(BaseModel):
    # The answer typed by the user; this is the string that gets validated.
    user_input: str
    # The expected word the answer is compared against.
    correct_word: str
def is_valid_japanese(word: str) -> tuple[bool, str]:
    """Decide whether *word* is acceptable as a short Japanese word.

    The checks run in order: non-empty after trimming, at most 10 characters
    after trimming, tokenizable, and containing at least one token whose
    top-level part of speech is listed in ``VALID_POS``.

    Returns:
        ``(True, "OK")`` when every check passes, otherwise ``(False, reason)``
        where *reason* is a Japanese message describing the first failed check.
    """
    trimmed = word.strip()
    if not trimmed:
        return False, "入力が空です"
    if len(trimmed) > 10:
        return False, "10文字以内で入力してください"

    # The tagger receives the untrimmed input.
    morphemes = list(tagger(word))
    if len(morphemes) == 0:
        return False, "日本語として認識できません"

    for morpheme in morphemes:
        if morpheme.feature.pos1 in VALID_POS:
            return True, "OK"
    return False, "単語として成立していません"
def get_word_info(word: str) -> dict:
    """Tokenize *word* and summarise its part of speech and tense.

    Returns:
        A dict with three keys:

        * ``"pos"`` - ``pos1`` of the first token that is not an honorific
          prefix, or ``"不明"`` when there is no such token.
        * ``"tense"`` - a coarse label derived from the last token:
          ``"過去"``, ``"進行・状態"``, ``"現在"`` or ``"その他"``.
        * ``"surfaces"`` - the surface form of every token, in order.
    """
    tokens = list(tagger(word))

    # Main part of speech: skip a leading honorific prefix such as お / ご.
    pos = next(
        (
            token.feature.pos1
            for token in tokens
            if not (
                token.feature.pos1 == "接頭辞"
                and token.surface in HONORIFIC_PREFIXES
            )
        ),
        "不明",
    )

    surfaces = [token.surface for token in tokens]

    if tokens:
        last_surface = tokens[-1].surface
        last_pos = tokens[-1].feature.pos1
    else:
        last_surface = ""
        last_pos = ""

    # A conjugated verb is a single token whose surface is the whole verb
    # (e.g. 食べる, 書く), so the plain non-past form has to be recognised by
    # its final kana rather than by comparing the whole surface with one kana.
    # Every plain-form verb ends in one of these u-row kana.
    plain_form_endings = ("う", "く", "ぐ", "す", "つ", "ぬ", "ぶ", "む", "る")

    # NOTE(review): a trailing auxiliary だ is labelled past even when it is
    # the copula; the label is only compared when both words are verbs.
    if last_pos == "助動詞" and last_surface in ("た", "だ"):
        tense = "過去"
    elif last_pos == "動詞" and last_surface in ("いる", "いた", "います"):
        tense = "進行・状態"
    elif last_pos == "動詞" and last_surface.endswith(plain_form_endings):
        tense = "現在"
    else:
        tense = "その他"

    return {"pos": pos, "tense": tense, "surfaces": surfaces}
def pos_compatible(pos1: str, pos2: str) -> bool:
    """Return True when the two part-of-speech tags may be compared.

    Tags are compatible when they are identical, or when both belong to the
    same set in ``ALLOWED_POS_GROUPS``.
    """
    if pos1 == pos2:
        return True
    pair = {pos1, pos2}
    return any(pair <= group for group in ALLOWED_POS_GROUPS)
def check_answer(user_input: str, correct_word: str) -> dict:
    """Grade *user_input* against *correct_word*.

    Steps, in order:

    1. Validate *user_input*; an invalid input yields ``result == "invalid"``.
    2. Reject answers whose part of speech is incompatible with the expected
       word's.
    3. When both words are verbs, reject answers whose tense label differs.
    4. Run the NLI model in both directions and grade by whether the top
       class is index 0 (the first entry of ``labels``) in both, one, or
       neither direction.

    Returns:
        A dict with the keys ``"result"``, ``"reason"`` and ``"score"``
        (``1.0``, ``0.5`` or ``0.0``).
    """
    is_ok, why = is_valid_japanese(user_input)
    if not is_ok:
        return {"result": "invalid", "reason": why, "score": 0.0}

    answer = get_word_info(user_input)
    expected = get_word_info(correct_word)
    answer_pos = answer["pos"]
    expected_pos = expected["pos"]

    if not pos_compatible(answer_pos, expected_pos):
        message = f"品詞が違います({answer_pos} vs {expected_pos})"
        return {"result": "不正解", "reason": message, "score": 0.0}

    both_verbs = answer_pos == "動詞" and expected_pos == "動詞"
    if both_verbs and answer["tense"] != expected["tense"]:
        message = f"時制が違います({answer['tense']} vs {expected['tense']})"
        return {"result": "不正解", "reason": message, "score": 0.0}

    labels = ["含意", "中立", "矛盾"]

    # Score each direction separately: answer -> expected, then expected -> answer.
    forward_scores = nli_model.predict([(user_input, correct_word)])[0]
    backward_scores = nli_model.predict([(correct_word, user_input)])[0]
    forward_top = np.argmax(forward_scores)
    backward_top = np.argmax(backward_scores)

    forward_entails = forward_top == 0
    backward_entails = backward_top == 0

    if forward_entails and backward_entails:
        return {"result": "正解", "reason": "同義語", "score": 1.0}
    if forward_entails or backward_entails:
        return {"result": "不正解", "reason": "包含関係(上位/下位概念)", "score": 0.5}
    # Neither direction entails: report the forward direction's label.
    return {"result": "不正解", "reason": labels[forward_top], "score": 0.0}
@app.get("/")
def root():
    # Fixed payload returned for GET "/".
    payload = {"status": "ok"}
    return payload
@app.post("/check")
def check(data: StringInput):
    # POST /check: grade the submitted answer and return the grading dict.
    answer = data.user_input
    expected = data.correct_word

    # Log the incoming request.
    print(f"受信: user_input='{answer}' / correct_word='{expected}'")

    outcome = check_answer(answer, expected)

    # Log the grading result before returning it.
    print(f"結果: {outcome}")
    return outcome