File size: 4,066 Bytes
ff1dc9c
 
42121b5
 
 
ff1dc9c
 
 
42121b5
 
 
 
 
 
 
 
 
 
 
 
 
 
2769cd5
 
 
ff1dc9c
42121b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff1dc9c
 
 
 
2769cd5
84119fa
b268ba6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
import fugashi
import numpy as np

# FastAPI application instance; the "/" and "/check" routes are registered on it.
app = FastAPI()

# Load the models once, at import/startup time, rather than per request.
# Cross-encoder NLI model used to score entailment between two words.
nli_model = CrossEncoder("akiFQC/bert-base-japanese-v3_nli-jsnli-jnli-jsick")
# fugashi (MeCab) tokenizer; its tokens expose .surface and .feature.pos1.
tagger = fugashi.Tagger()

# Groups of part-of-speech tags (tokenizer pos1 values) that count as mutually
# compatible when comparing the user's answer with the expected word.
# NOTE: fugashi.Tagger() uses UniDic, which tags na-adjectives (e.g. 静か) as
# "形状詞" rather than the IPADIC-style "形容動詞", so both names are listed.
ALLOWED_POS_GROUPS = [
    {"名詞", "形容詞", "形容動詞", "形状詞"},
    {"動詞", "名詞"},
    {"副詞", "名詞"},
]

# Honorific prefixes that are skipped when picking a word's head part of speech.
HONORIFIC_PREFIXES = {"お", "ご", "御"}

# An input is accepted as a word only if at least one token has one of these
# parts of speech ("形状詞" is UniDic's tag for na-adjectives).
VALID_POS = {"名詞", "動詞", "形容詞", "形容動詞", "形状詞", "副詞", "接頭辞"}

class StringInput(BaseModel):
    # Request body for POST /check.
    user_input: str  # the word the user typed; validated by is_valid_japanese
    correct_word: str  # the expected answer the user's word is compared against

def is_valid_japanese(word: str) -> tuple[bool, str]:
    """Decide whether *word* is acceptable as a single-word answer.

    Returns ``(True, "OK")`` when every check passes. Otherwise returns
    ``(False, message)``, where *message* is a user-facing Japanese explanation
    of the first failed check:

    * the input is empty or whitespace only;
    * it is longer than 10 characters after trimming surrounding whitespace;
    * the tokenizer yields no tokens;
    * no token's part of speech is in ``VALID_POS``.
    """
    stripped = word.strip()
    if stripped == "":
        return False, "入力が空です"
    if len(stripped) > 10:
        return False, "10文字以内で入力してください"

    parsed = list(tagger(word))
    if len(parsed) == 0:
        return False, "日本語として認識できません"

    # Accept as soon as one token carries a content-word part of speech.
    for tok in parsed:
        if tok.feature.pos1 in VALID_POS:
            return True, "OK"
    return False, "単語として成立していません"

# Hiragana a verb's dictionary (終止/連体) form ends in: the whole u-row.
# UniDic returns a dictionary-form verb such as 食べる or 読む as ONE token, so
# the last token's surface must be tested with endswith, not equality.
PRESENT_VERB_ENDINGS = ("う", "く", "ぐ", "す", "つ", "ぬ", "ぶ", "む", "る")

def classify_tense(tokens: list[tuple[str, str]]) -> str:
    """Classify the tense/aspect of a tokenized word.

    Args:
        tokens: ``(surface, pos1)`` pairs in order, as produced by the tagger.

    Returns:
        ``"過去"`` (past), ``"進行・状態"`` (progressive/state), ``"現在"``
        (non-past), or ``"その他"`` (anything else, including no tokens).
    """
    if not tokens:
        return "その他"
    last_surface, last_pos = tokens[-1]
    if last_pos == "助動詞" and last_surface in ("た", "だ"):
        return "過去"
    if last_pos == "動詞" and last_surface in ("いる", "いた", "います"):
        return "進行・状態"
    if last_pos == "助動詞" and last_surface == "ます":
        # Polite forms: 〜ています ends in い(verb) + ます (progressive);
        # any other 〜ます is the polite non-past.
        if len(tokens) >= 2 and tokens[-2] == ("い", "動詞"):
            return "進行・状態"
        return "現在"
    if last_pos == "動詞" and last_surface.endswith(PRESENT_VERB_ENDINGS):
        return "現在"
    return "その他"

def get_word_info(word: str) -> dict:
    """Tokenize *word* and summarise it for answer comparison.

    Returns a dict with:
      * ``"pos"``: pos1 of the first token that is not an honorific prefix
        (a 接頭辞 token in ``HONORIFIC_PREFIXES``), or ``"不明"`` if none;
      * ``"tense"``: the result of ``classify_tense`` on the tokens;
      * ``"surfaces"``: the list of token surface strings.
    """
    tokens = list(tagger(word))
    pos = "不明"
    for token in tokens:
        if token.feature.pos1 == "接頭辞" and token.surface in HONORIFIC_PREFIXES:
            # Skip honorifics so the head word, not the prefix, sets the pos.
            continue
        pos = token.feature.pos1
        break
    tagged = [(t.surface, t.feature.pos1) for t in tokens]
    return {
        "pos": pos,
        "tense": classify_tense(tagged),
        "surfaces": [surface for surface, _ in tagged],
    }

def pos_compatible(pos1: str, pos2: str) -> bool:
    """Return True when two part-of-speech tags may be treated as matching.

    Identical tags always match. Different tags match only if some group in
    ``ALLOWED_POS_GROUPS`` contains both of them.
    """
    return pos1 == pos2 or any(
        {pos1, pos2} <= group for group in ALLOWED_POS_GROUPS
    )

def check_answer(user_input: str, correct_word: str) -> dict:
    """Judge whether *user_input* is an acceptable answer for *correct_word*.

    The checks run in order, and the first one that fails decides the result:

    1. ``is_valid_japanese(user_input)`` must pass; otherwise the result is
       ``"invalid"`` with the validator's message.
    2. The head parts of speech of both words must be ``pos_compatible``.
    3. If both words are verbs, their tense (from ``get_word_info``) must match.
    4. The NLI model is run in both directions. Entailment both ways counts
       as a synonym (score 1.0). Entailment in only one direction is reported
       as a hypernym/hyponym relation (score 0.5). Otherwise the label
       predicted for user_input -> correct_word is returned (score 0.0).

    Returns:
        dict with keys ``"result"``, ``"reason"`` and ``"score"``.
    """
    valid, reason = is_valid_japanese(user_input)
    if not valid:
        return {"result": "invalid", "reason": reason, "score": 0.0}

    user_info = get_word_info(user_input)
    correct_info = get_word_info(correct_word)

    if not pos_compatible(user_info["pos"], correct_info["pos"]):
        return {
            "result": "不正解",
            "reason": f"品詞が違います({user_info['pos']} vs {correct_info['pos']})",
            "score": 0.0,
        }

    # Tense is only compared when both words are verbs.
    if (
        user_info["pos"] == "動詞"
        and correct_info["pos"] == "動詞"
        and user_info["tense"] != correct_info["tense"]
    ):
        return {
            "result": "不正解",
            "reason": f"時制が違います({user_info['tense']} vs {correct_info['tense']})",
            "score": 0.0,
        }

    # NOTE(review): assumes the model's output columns are ordered
    # entailment / neutral / contradiction — confirm against its label2id.
    labels = ["含意", "中立", "矛盾"]
    # Score both directions in one batched predict() call rather than two
    # separate forward passes; the result has one row of scores per pair.
    scores_ab, scores_ba = nli_model.predict(
        [(user_input, correct_word), (correct_word, user_input)]
    )
    res_ab = int(np.argmax(scores_ab))
    res_ba = int(np.argmax(scores_ba))

    if res_ab == 0 and res_ba == 0:
        return {"result": "正解", "reason": "同義語", "score": 1.0}
    if res_ab == 0 or res_ba == 0:
        return {"result": "不正解", "reason": "包含関係(上位/下位概念)", "score": 0.5}
    return {"result": "不正解", "reason": labels[res_ab], "score": 0.0}

@app.get("/")
def root():
    # Liveness probe: always reports that the service is up.
    return dict(status="ok")

@app.post("/check")
def check(data: StringInput):
    # Grade one answer: log the request, judge it, log the verdict, return it.
    user_input, correct_word = data.user_input, data.correct_word
    print(f"受信: user_input='{user_input}' / correct_word='{correct_word}'")
    verdict = check_answer(user_input, correct_word)
    print(f"結果: {verdict}")
    return verdict