AKSARA-CLM-v1 / aksara /training /objective.py
emylton's picture
Upload folder using huggingface_hub
9338a41 verified
Raw
History Blame Contribute Delete
6.3 kB
"""
objective.py — Objective Layer untuk AKSARA.
Komponen:
embedding_relation_loss : cek apakah verb dan objek "nyambung" secara semantik
berbasis semantic_slots dari BSU — BUKAN embedding baru
cooccurrence_loss : cek apakah kombinasi token masuk akal secara statistik
berbasis PMI matrix yang diprecompute dari corpus
build_cooccurrence_matrix: precompute PMI(a, b) dari corpus, domain-agnostic
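make_negative_batch : buat batch negatif via token shuffle (posisi padding id 0 tidak diacak)
CompositeLoss : gabungkan base loss dengan kedua loss di atas, masing-masing berbobot lambda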
Prinsip implementasi:
- Tidak ada hardcode domain atau aturan semantik manual
- Semua berbasis representasi model sendiri (BSU semantic_slots)
- Co-occurrence berbasis data corpus, bukan library NLP eksternal
- Tidak ada layer baru — hanya operasi di atas output BSU yang sudah ada
"""
from __future__ import annotations
import math
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple
import torch
import torch.nn.functional as F
def build_cooccurrence_matrix(
    corpus: List[str],
    root_vocab: Dict[str, int],
    window: int = 4,
    min_count: int = 2,
) -> Dict[Tuple[int, int], float]:
    """Precompute positive PMI scores for pairs of in-vocabulary tokens.

    Each text is lowercased and whitespace-split. Words missing from
    ``root_vocab`` are dropped *before* windowing, so the window spans
    in-vocabulary tokens only. A pair is counted each time token ``b``
    appears within the ``window`` tokens that follow token ``a``. Self-pairs
    are ignored, and pairs are stored order-insensitively as
    ``(min_id, max_id)``.

    Args:
        corpus: raw texts.
        root_vocab: word -> token id mapping.
        window: number of following tokens paired with each token.
        min_count: pairs seen fewer times than this are discarded.

    Returns:
        ``{(id_a, id_b): pmi}`` containing only pairs with strictly positive
        PMI, where ``pmi = log(p_ab / (p_a * p_b))`` and every probability
        is a count divided by the total number of in-vocabulary tokens.
        Returns an empty dict when the corpus contains no known tokens.
    """
    token_count: Dict[int, int] = defaultdict(int)
    pair_count: Dict[Tuple[int, int], int] = defaultdict(int)
    total_tokens = 0
    for text in corpus:
        ids = [root_vocab[t] for t in text.lower().split() if t in root_vocab]
        for i, a in enumerate(ids):
            token_count[a] += 1
            total_tokens += 1
            for j in range(i + 1, min(i + 1 + window, len(ids))):
                b = ids[j]
                if a != b:
                    pair_count[(min(a, b), max(a, b))] += 1
    if total_tokens == 0:
        return {}
    co_matrix: Dict[Tuple[int, int], float] = {}
    for (a, b), co_c in pair_count.items():
        if co_c < min_count:
            continue
        # log(p_ab / (p_a * p_b)) with p = count / total_tokens, rewritten
        # in count space. Both tokens of a recorded pair were counted, so all
        # counts are >= 1 and no smoothing epsilon is needed. An additive
        # epsilon would dominate p_a * p_b for rare tokens in large corpora
        # and systematically shrink their PMI.
        pmi = math.log(co_c * total_tokens / (token_count[a] * token_count[b]))
        if pmi > 0:
            co_matrix[(a, b)] = pmi
    return co_matrix
def embedding_relation_loss(
    semantic_slots: torch.Tensor,
    morpheme_ids: torch.Tensor,
    verb_token_ids: Set[int],
    min_sim_target: float = 0.3,
    pad_id: int = 0,
) -> torch.Tensor:
    """Hinge penalty when a verb's slot vector is far from the slots after it.

    For each sequence, the first position whose id is in ``verb_token_ids``
    is taken as the verb. The cosine similarity between the verb's slot
    vector and every later *non-padding* slot is averaged, and the sequence
    contributes ``max(0, min_sim_target - mean_similarity)``. A sequence is
    skipped when it has no verb, or when the verb is followed only by
    padding or by nothing.

    Args:
        semantic_slots: slot vectors, unpacked as ``(B, L, d)``.
        morpheme_ids: token ids aligned position-by-position with the first
            two dims of ``semantic_slots`` (assumed ``(B, L)``; TODO confirm
            against callers).
        verb_token_ids: ids treated as verbs.
        min_sim_target: mean cosine similarity below which a penalty applies.
        pad_id: id marking padding positions, which are excluded from the
            average. The default 0 matches the padding id skipped by
            ``cooccurrence_loss`` and ``make_negative_batch``.

    Returns:
        Scalar tensor: the mean deficit over valid sequences, or 0.0 when no
        sequence qualifies.
    """
    device = semantic_slots.device
    B, L, _ = semantic_slots.shape
    if L < 2 or not verb_token_ids:
        return torch.tensor(0.0, device=device)
    # One host transfer for the whole id matrix instead of an .item() sync
    # per position.
    rows = morpheme_ids[:B, :L].tolist()
    total_loss = torch.tensor(0.0, device=device)
    n_valid = 0
    for b, row in enumerate(rows):
        verb_pos = next(
            (pos for pos, tok in enumerate(row) if tok in verb_token_ids), -1
        )
        if verb_pos == -1 or verb_pos >= L - 1:
            continue
        # Padding after the verb is not an object; including it would skew
        # the mean similarity toward whatever the pad embedding looks like.
        keep = torch.tensor(
            [tok != pad_id for tok in row[verb_pos + 1:]],
            dtype=torch.bool,
            device=device,
        )
        objects = semantic_slots[b, verb_pos + 1:][keep]
        if objects.size(0) == 0:
            continue
        v_norm = F.normalize(semantic_slots[b, verb_pos].unsqueeze(0), dim=-1)
        o_norm = F.normalize(objects, dim=-1)
        avg_sim = (v_norm * o_norm).sum(dim=-1).mean()
        total_loss = total_loss + (min_sim_target - avg_sim).clamp(min=0.0)
        n_valid += 1
    if n_valid == 0:
        return torch.tensor(0.0, device=device)
    return total_loss / n_valid
def cooccurrence_loss(
    morpheme_ids: torch.Tensor,
    co_matrix: Dict[Tuple[int, int], float],
    window: int = 3,
) -> torch.Tensor:
    """Negative mean PMI over known token pairs inside a sliding window.

    Every pair of non-padding ids (padding is id 0) lying within ``window``
    positions of each other is looked up order-insensitively, as
    ``(smaller_id, larger_id)``, in ``co_matrix``. Pairs absent from the table
    are ignored. The result is ``-mean(PMI)`` over the pairs that were found,
    or 0.0 when none were.

    Note: the value is built from integer ids and a fixed lookup table and
    wrapped with ``torch.tensor`` from a Python float, so it carries no
    gradient. Adding it to a loss changes the reported number but not the
    parameter updates.
    """
    device = morpheme_ids.device
    _, seq_len = morpheme_ids.shape
    pmi_sum = 0.0
    hits = 0
    for row in morpheme_ids.tolist():
        for i, left in enumerate(row):
            if left == 0:
                continue
            for j in range(i + 1, min(i + 1 + window, seq_len)):
                right = row[j]
                if right == 0:
                    continue
                pair = (left, right) if left <= right else (right, left)
                if pair in co_matrix:
                    pmi_sum += co_matrix[pair]
                    hits += 1
    if not hits:
        return torch.tensor(0.0, device=device)
    return torch.tensor(-(pmi_sum / hits), dtype=torch.float32, device=device)
def make_negative_batch(
    morpheme_ids: torch.Tensor,
    n_neg: int = 1,
) -> torch.Tensor:
    """Build token-shuffled copies of a batch to serve as negative examples.

    For each of ``n_neg`` rounds the batch is cloned. In every row with more
    than two non-padding tokens (padding is id 0), those tokens are randomly
    permuted among their own positions. Padding stays in place, and rows with
    two or fewer real tokens are copied unchanged. The rounds are
    concatenated along dim 0, giving ``n_neg * B`` rows.

    Permutations come from torch's global RNG and can occasionally be the
    identity. With ``n_neg < 1`` there is nothing to concatenate, and
    ``torch.cat`` raises. The input tensor is not modified.
    """
    batch_size, _ = morpheme_ids.shape
    device = morpheme_ids.device

    def _shuffle_once() -> torch.Tensor:
        # Clone, then permute the real tokens of each row in place.
        shuffled = morpheme_ids.clone()
        for row in range(batch_size):
            positions = torch.nonzero(shuffled[row] != 0, as_tuple=True)[0]
            count = len(positions)
            if count <= 2:
                continue
            order = positions[torch.randperm(count, device=device)]
            shuffled[row, positions] = shuffled[row, order]
        return shuffled

    rounds = [_shuffle_once() for _ in range(n_neg)]
    return torch.cat(rounds, dim=0)
class CompositeLoss:
    """Combine a base loss with weighted relation and co-occurrence terms.

    ``total = base + lambda_rel * rel + lambda_co * co``. Each extra term is
    added only when its inputs are available:

    - the relation term needs semantic slots in ``model_output`` and a
      non-empty verb id set;
    - the co-occurrence term needs a non-empty PMI table.

    A skipped term is logged as 0.0 in the returned breakdown.
    """

    def __init__(
        self,
        lambda_rel: float = 0.3,
        lambda_co: float = 0.1,
        verb_ids: Optional[Set[int]] = None,
        co_matrix: Optional[Dict[Tuple[int, int], float]] = None,
    ):
        self.lambda_rel = lambda_rel
        self.lambda_co = lambda_co
        # None or empty inputs collapse to an empty container.
        self.verb_ids = verb_ids if verb_ids else set()
        self.co_matrix = co_matrix if co_matrix else {}

    @staticmethod
    def _find_slots(model_output: Dict) -> Optional[torch.Tensor]:
        """Return ``semantic_slots`` from the output dict, falling back to ``bsu_original`` (or None)."""
        slots = model_output.get("semantic_slots")
        return model_output.get("bsu_original") if slots is None else slots

    def __call__(
        self,
        base_loss: torch.Tensor,
        model_output: Dict,
        morpheme_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return ``(total_loss, breakdown)``.

        ``breakdown`` holds Python floats under the keys ``"base"``,
        ``"rel"``, ``"co"`` and ``"total"``, in that order.
        """
        log: Dict[str, float] = {"base": base_loss.item()}
        total = base_loss

        slots = self._find_slots(model_output)
        if slots is None or not self.verb_ids:
            log["rel"] = 0.0
        else:
            rel = embedding_relation_loss(slots, morpheme_ids, self.verb_ids)
            total = total + self.lambda_rel * rel
            log["rel"] = rel.item()

        if not self.co_matrix:
            log["co"] = 0.0
        else:
            co = cooccurrence_loss(morpheme_ids, self.co_matrix)
            total = total + self.lambda_co * co
            log["co"] = co.item()

        log["total"] = total.item()
        return total, log