wordVec / app.py
wwefih's picture
Update app.py
30f2866 verified
Raw
History Blame
6.13 kB
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
import fugashi
import numpy as np
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# モデルのロード(起動時に1回だけ)
nli_model = CrossEncoder("akiFQC/bert-base-japanese-v3_nli-jsnli-jnli-jsick")
tagger = fugashi.Tagger()
# --- 新規追加: 判定用の定数と辞書 ---
SIMILARITY_THRESHOLD = 0.85
ALLOWED_POS_GROUPS = [
{"名詞"},
{"形容詞", "形容動詞", "形状詞"},
{"動詞", "名詞"},
{"副詞", "名詞"},
]
HONORIFIC_PREFIXES = {"お", "ご", "御"}
# 形状詞も有効な品詞として追加
VALID_POS = {"名詞", "動詞", "形容詞", "形容動詞", "副詞", "接頭辞", "形状詞"}
CUSTOM_SYNONYMS = [
{"スマホ", "スマートフォン"},
{"チャリ", "自転車"},
{"おでこ", "額"},
{"ビリ", "最下位"},
{"おにぎり", "おむすび"},
{"めっちゃ", "とても"}
]
class StringInput(BaseModel):
user_input: str
correct_word: str
def is_valid_japanese(word: str) -> tuple[bool, str]:
if not word.strip():
return False, "入力が空です"
if len(word.strip()) > 10:
return False, "10文字以内で入力してください"
tokens = list(tagger(word))
if not tokens:
return False, "日本語として認識できません"
has_valid_token = any(t.feature.pos1 in VALID_POS for t in tokens)
if not has_valid_token:
return False, "単語として成立していません"
return True, "OK"
def get_word_info(word: str) -> dict:
tokens = list(tagger(word))
pos = "不明"
for token in tokens:
# お・ご・御 はスキップ
if token.feature.pos1 == "接頭辞" and token.surface in HONORIFIC_PREFIXES:
continue
# 助詞や助動詞など、単語のメインにならない付属語は無視して上書きしない
if token.feature.pos1 in ["助詞", "助動詞", "補助記号"]:
continue
# 複合語対策: ループを最後まで回して一番最後の主要な品詞を取得する
pos = token.feature.pos1
surfaces = [t.surface for t in tokens]
last_surface = surfaces[-1] if surfaces else ""
if last_surface in ["た", "だ"]:
tense = "過去"
elif last_surface in ["いる", "いた", "います"]:
tense = "進行・状態"
elif last_surface in ["る", "す", "く", "ぬ"]:
tense = "現在"
else:
tense = "その他"
return {"pos": pos, "tense": tense, "surfaces": surfaces}
def pos_compatible(pos1: str, pos2: str) -> bool:
if pos1 == pos2:
return True
for group in ALLOWED_POS_GROUPS:
if pos1 in group and pos2 in group:
return True
return False
# --- 新規追加: Softmax関数 ---
def softmax(x):
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
def check_answer(user_input: str, correct_word: str) -> dict:
# 0. バリデーションチェック
valid, reason = is_valid_japanese(user_input)
if not valid:
return {"result": "invalid", "reason": reason, "score": 0.0}
# --- 新規追加: 完全一致チェック(AI判定や品詞チェックをすべてバイパス) ---
# .strip() を使って、目に見えない前後のスペースや改行を無視して比較します
if user_input.strip() == correct_word.strip():
return {"result": "正解", "reason": "完全一致", "score": 1.0}
# ------------------------------------------------------------------
# 1. カスタム辞書による絶対正解のチェック(AI判定をバイパス)
for syn_set in CUSTOM_SYNONYMS:
if user_input in syn_set and correct_word in syn_set:
return {"result": "正解", "reason": "同義語(辞書一致)", "score": 1.0}
user_info = get_word_info(user_input)
correct_info = get_word_info(correct_word)
# 2. 品詞チェック
if not pos_compatible(user_info["pos"], correct_info["pos"]):
return {
"result": "不正解",
"reason": f"品詞が違います({user_info['pos']} vs {correct_info['pos']})",
"score": 0.0
}
# 3. 時制チェック
if user_info["pos"] == "動詞" and correct_info["pos"] == "動詞":
if user_info["tense"] != correct_info["tense"]:
return {
"result": "不正解",
"reason": f"時制が違います({user_info['tense']} vs {correct_info['tense']})",
"score": 0.0
}
# 4. AI (NLIモデル) によるスコア算出と Softmax 変換
labels = ["含意", "中立", "矛盾"]
scores_ab = nli_model.predict([(user_input, correct_word)])[0]
scores_ba = nli_model.predict([(correct_word, user_input)])[0]
res_ab = np.argmax(scores_ab)
res_ba = np.argmax(scores_ba)
probs_ab = softmax(scores_ab)
probs_ba = softmax(scores_ba)
# 5. 最終判定ロジック
if res_ab == 0 and res_ba == 0:
avg_prob = (probs_ab[0] + probs_ba[0]) / 2
if avg_prob <= SIMILARITY_THRESHOLD:
return {"result": "不正解", "reason": "似ているが違う", "score": 0.0}
return {"result": "正解", "reason": "同義語", "score": 1.0}
elif res_ab == 0 or res_ba == 0:
return {"result": "不正解", "reason": "包含関係(上位/下位概念)", "score": 0.5}
else:
return {"result": "不正解", "reason": labels[res_ab], "score": 0.0}
@app.get("/")
def root():
return {"status": "ok"}
@app.post("/check")
def check(data: StringInput):
# APIリクエスト・レスポンスの確認用ログは保持
print(f"受信: user_input='{data.user_input}' / correct_word='{data.correct_word}'")
result = check_answer(data.user_input, data.correct_word)
print(f"結果: {result}")
return result