| """ |
| objective.py — Objective Layer untuk AKSARA. |
| |
| Komponen: |
| embedding_relation_loss : cek apakah verb dan objek "nyambung" secara semantik |
| berbasis semantic_slots dari BSU — BUKAN embedding baru |
| cooccurrence_loss : cek apakah kombinasi token masuk akal secara statistik |
| berbasis PMI matrix yang diprecompute dari corpus |
| build_cooccurrence_matrix: precompute PMI(a, b) dari corpus, domain-agnostic |
| negative_sample : buat kalimat negatif via token shuffle |
| |
| Prinsip implementasi: |
| - Tidak ada hardcode domain atau aturan semantik manual |
| - Semua berbasis representasi model sendiri (BSU semantic_slots) |
| - Co-occurrence berbasis data corpus, bukan library NLP eksternal |
| - Tidak ada layer baru — hanya operasi di atas output BSU yang sudah ada |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| import random |
| from collections import defaultdict |
| from typing import Dict, List, Optional, Set, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| def build_cooccurrence_matrix( |
| corpus: List[str], |
| root_vocab: Dict[str, int], |
| window: int = 4, |
| min_count: int = 2, |
| ) -> Dict[Tuple[int, int], float]: |
| token_count: Dict[int, int] = defaultdict(int) |
| pair_count: Dict[Tuple[int, int], int] = defaultdict(int) |
| total_tokens = 0 |
|
|
| for text in corpus: |
| tokens = text.lower().split() |
| ids = [root_vocab[t] for t in tokens if t in root_vocab] |
| for i, a in enumerate(ids): |
| token_count[a] += 1 |
| total_tokens += 1 |
| for j in range(i + 1, min(i + 1 + window, len(ids))): |
| b = ids[j] |
| if a != b: |
| pair_count[(min(a, b), max(a, b))] += 1 |
|
|
| if total_tokens == 0: |
| return {} |
|
|
| co_matrix: Dict[Tuple[int, int], float] = {} |
| for (a, b), co_c in pair_count.items(): |
| if co_c < min_count: |
| continue |
| p_ab = co_c / total_tokens |
| p_a = token_count[a] / total_tokens |
| p_b = token_count[b] / total_tokens |
| pmi = math.log(p_ab / (p_a * p_b + 1e-9) + 1e-9) |
| if pmi > 0: |
| co_matrix[(a, b)] = pmi |
|
|
| return co_matrix |
|
|
|
|
| def embedding_relation_loss( |
| semantic_slots: torch.Tensor, |
| morpheme_ids: torch.Tensor, |
| verb_token_ids: Set[int], |
| min_sim_target: float = 0.3, |
| ) -> torch.Tensor: |
| device = semantic_slots.device |
| B, L, d = semantic_slots.shape |
|
|
| if L < 2: |
| return torch.tensor(0.0, device=device) |
|
|
| total_loss = torch.tensor(0.0, device=device) |
| n_valid = 0 |
|
|
| for b in range(B): |
| ids = morpheme_ids[b] |
| sem = semantic_slots[b] |
|
|
| verb_pos = -1 |
| for pos in range(L): |
| if ids[pos].item() in verb_token_ids: |
| verb_pos = pos |
| break |
|
|
| if verb_pos == -1 or verb_pos >= L - 1: |
| continue |
|
|
| v_sem = sem[verb_pos] |
| after_verb = sem[verb_pos + 1:] |
| if after_verb.size(0) == 0: |
| continue |
|
|
| v_norm = F.normalize(v_sem.unsqueeze(0), dim=-1) |
| o_norm = F.normalize(after_verb, dim=-1) |
| sims = (v_norm * o_norm).sum(dim=-1) |
|
|
| avg_sim = sims.mean() |
| deficit = (min_sim_target - avg_sim).clamp(min=0.0) |
| total_loss = total_loss + deficit |
| n_valid += 1 |
|
|
| if n_valid == 0: |
| return torch.tensor(0.0, device=device) |
|
|
| return total_loss / n_valid |
|
|
|
|
| def cooccurrence_loss( |
| morpheme_ids: torch.Tensor, |
| co_matrix: Dict[Tuple[int, int], float], |
| window: int = 3, |
| ) -> torch.Tensor: |
| device = morpheme_ids.device |
| B, L = morpheme_ids.shape |
|
|
| total_pmi = 0.0 |
| n_pairs = 0 |
|
|
| for b in range(B): |
| ids = morpheme_ids[b].tolist() |
| for i in range(L): |
| a = ids[i] |
| if a == 0: |
| continue |
| for j in range(i + 1, min(i + 1 + window, L)): |
| bb = ids[j] |
| if bb == 0: |
| continue |
| key = (min(a, bb), max(a, bb)) |
| if key in co_matrix: |
| total_pmi += co_matrix[key] |
| n_pairs += 1 |
|
|
| if n_pairs == 0: |
| return torch.tensor(0.0, device=device) |
|
|
| avg_pmi = total_pmi / n_pairs |
| return torch.tensor(-avg_pmi, dtype=torch.float32, device=device) |
|
|
|
|
| def make_negative_batch( |
| morpheme_ids: torch.Tensor, |
| n_neg: int = 1, |
| ) -> torch.Tensor: |
| B, L = morpheme_ids.shape |
| device = morpheme_ids.device |
| negs = [] |
|
|
| for _ in range(n_neg): |
| neg = morpheme_ids.clone() |
| for b in range(B): |
| non_pad = (neg[b] != 0).nonzero(as_tuple=True)[0] |
| if len(non_pad) > 2: |
| perm = non_pad[torch.randperm(len(non_pad), device=device)] |
| neg[b, non_pad] = neg[b, perm] |
| negs.append(neg) |
|
|
| return torch.cat(negs, dim=0) |
|
|
|
|
| class CompositeLoss: |
| def __init__( |
| self, |
| lambda_rel: float = 0.3, |
| lambda_co: float = 0.1, |
| verb_ids: Optional[Set[int]] = None, |
| co_matrix: Optional[Dict[Tuple[int, int], float]] = None, |
| ): |
| self.lambda_rel = lambda_rel |
| self.lambda_co = lambda_co |
| self.verb_ids = verb_ids or set() |
| self.co_matrix = co_matrix or {} |
|
|
| def __call__( |
| self, |
| base_loss: torch.Tensor, |
| model_output: Dict, |
| morpheme_ids: torch.Tensor, |
| ) -> Tuple[torch.Tensor, Dict[str, float]]: |
| breakdown: Dict[str, float] = {"base": base_loss.item()} |
| total = base_loss |
|
|
| semantic_slots = model_output.get("semantic_slots") |
| if semantic_slots is None: |
| semantic_slots = model_output.get("bsu_original") |
|
|
| if semantic_slots is not None and self.verb_ids: |
| l_rel = embedding_relation_loss( |
| semantic_slots, morpheme_ids, self.verb_ids, |
| ) |
| total = total + self.lambda_rel * l_rel |
| breakdown["rel"] = l_rel.item() |
| else: |
| breakdown["rel"] = 0.0 |
|
|
| if self.co_matrix: |
| l_co = cooccurrence_loss(morpheme_ids, self.co_matrix) |
| total = total + self.lambda_co * l_co |
| breakdown["co"] = l_co.item() |
| else: |
| breakdown["co"] = 0.0 |
|
|
| breakdown["total"] = total.item() |
| return total, breakdown |