Instructions to use daiweichen/pal-b-large-opt-350m with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use daiweichen/pal-b-large-opt-350m with Transformers:
# Use a pipeline as a high-level helper # Warning: Pipeline type "summarization" is no longer supported in transformers v5. # You must load the model directly (see below) or downgrade to v4.x with: # 'pip install "transformers<5.0.0' from transformers import pipeline pipe = pipeline("summarization", model="daiweichen/pal-b-large-opt-350m", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("daiweichen/pal-b-large-opt-350m", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload PAL_B_RM_opt
Browse files- README.md +4 -4
- itemLearner.py +4 -1
- learner.py +9 -7
- pytorch_model.bin +1 -1
- userLearner.py +7 -2
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
datasets:
|
| 5 |
- CarperAI/openai_summarize_tldr
|
| 6 |
language:
|
| 7 |
- en
|
| 8 |
-
|
| 9 |
-
|
| 10 |
---
|
| 11 |
|
| 12 |
# Model Card for Model ID
|
|
|
|
| 1 |
---
|
| 2 |
+
base_model:
|
| 3 |
+
- facebook/opt-350m
|
| 4 |
datasets:
|
| 5 |
- CarperAI/openai_summarize_tldr
|
| 6 |
language:
|
| 7 |
- en
|
| 8 |
+
library_name: transformers
|
| 9 |
+
license: mit
|
| 10 |
---
|
| 11 |
|
| 12 |
# Model Card for Model ID
|
itemLearner.py
CHANGED
|
@@ -27,6 +27,8 @@ class ItemLearner(nn.Module):
|
|
| 27 |
'''
|
| 28 |
input_ids = x['input_ids']
|
| 29 |
attention_mask = x['attention_mask']
|
|
|
|
|
|
|
| 30 |
|
| 31 |
if rm_cached is None:
|
| 32 |
llm_res = self.llm(
|
|
@@ -37,11 +39,12 @@ class ItemLearner(nn.Module):
|
|
| 37 |
llm_res = self.llm(
|
| 38 |
input_ids=input_ids[:, -1:], # attention_mask=attention_mask,
|
| 39 |
past_key_values=rm_cached["item_learner"],
|
| 40 |
-
use_cache=
|
| 41 |
)
|
| 42 |
rm_cached["item_learner"] = llm_res.past_key_values
|
| 43 |
|
| 44 |
embeds = llm_res.last_hidden_state
|
|
|
|
| 45 |
# embeds shape: (bs, seq_len, hidden_size)
|
| 46 |
shape = embeds.shape
|
| 47 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hidden_size)
|
|
|
|
| 27 |
'''
|
| 28 |
input_ids = x['input_ids']
|
| 29 |
attention_mask = x['attention_mask']
|
| 30 |
+
# logger.critical(f"ItemLearner: {input_ids=}")
|
| 31 |
+
# logger.critical(f"ItemLearner: {attention_mask=}")
|
| 32 |
|
| 33 |
if rm_cached is None:
|
| 34 |
llm_res = self.llm(
|
|
|
|
| 39 |
llm_res = self.llm(
|
| 40 |
input_ids=input_ids[:, -1:], # attention_mask=attention_mask,
|
| 41 |
past_key_values=rm_cached["item_learner"],
|
| 42 |
+
use_cache=True
|
| 43 |
)
|
| 44 |
rm_cached["item_learner"] = llm_res.past_key_values
|
| 45 |
|
| 46 |
embeds = llm_res.last_hidden_state
|
| 47 |
+
# logger.critical(f"ItemLearner: {embeds=}")
|
| 48 |
# embeds shape: (bs, seq_len, hidden_size)
|
| 49 |
shape = embeds.shape
|
| 50 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hidden_size)
|
learner.py
CHANGED
|
@@ -113,16 +113,18 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
|
|
| 113 |
|
| 114 |
def forward(self, x, rm_cached=None):
|
| 115 |
assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
|
| 116 |
-
|
| 117 |
if rm_cached is None:
|
| 118 |
items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
|
| 119 |
else:
|
| 120 |
items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
|
| 121 |
-
logger.
|
| 122 |
-
logger.
|
| 123 |
-
logger.
|
| 124 |
-
logger.
|
|
|
|
| 125 |
if self.pref_learner_type == 'angle':
|
|
|
|
| 126 |
prompt_last_prime = prompt_prime[:, -1, :]
|
| 127 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
| 128 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
|
@@ -131,8 +133,8 @@ class PrefLearner(BasePrefLearner): # <f(x),f(u)>
|
|
| 131 |
items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
|
| 132 |
logit_scale = self.logit_scale.exp()
|
| 133 |
clamped_logit_scale = torch.clamp(logit_scale, max=100)
|
| 134 |
-
logger.
|
| 135 |
-
logger.
|
| 136 |
sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale # (bs, max_token_length)
|
| 137 |
if rm_cached is None:
|
| 138 |
return sim_score
|
|
|
|
| 113 |
|
| 114 |
def forward(self, x, rm_cached=None):
|
| 115 |
assert self.uid is not None, "Please specify the user id first by calling specify_user_ids() to personalize the reward model"
|
| 116 |
+
prompt, items = x
|
| 117 |
if rm_cached is None:
|
| 118 |
items_prime, prompt_prime = self.map_to_pref_embedding_space((self.uid, prompt, items))
|
| 119 |
else:
|
| 120 |
items_prime, prompt_prime, rm_cached = self.map_to_pref_embedding_space((self.uid, prompt, items), rm_cached)
|
| 121 |
+
# logger.critical(f"{items_prime[0]=}")
|
| 122 |
+
# logger.critical(f"{prompt_prime[0]=}")
|
| 123 |
+
# logger.critical(f"{items_prime.shape=}")
|
| 124 |
+
# logger.critical(f"{prompt_prime.shape=}")
|
| 125 |
+
# FIXME: bug exist here
|
| 126 |
if self.pref_learner_type == 'angle':
|
| 127 |
+
# FIXME: do the cumulative evaluation!
|
| 128 |
prompt_last_prime = prompt_prime[:, -1, :]
|
| 129 |
prompt_last_prime = prompt_last_prime.unsqueeze(1)
|
| 130 |
prompt_last_prime = prompt_last_prime / torch.norm(prompt_last_prime, dim=-1, keepdim=True)
|
|
|
|
| 133 |
items_last_prime = items_last_prime / torch.norm(items_last_prime, dim=-1, keepdim=True)
|
| 134 |
logit_scale = self.logit_scale.exp()
|
| 135 |
clamped_logit_scale = torch.clamp(logit_scale, max=100)
|
| 136 |
+
# logger.critical(f"{prompt_last_prime.shape=}")
|
| 137 |
+
# logger.critical(f"{items_last_prime.shape=}")
|
| 138 |
sim_score = (prompt_last_prime * items_last_prime).sum(dim=-1) * clamped_logit_scale # (bs, max_token_length)
|
| 139 |
if rm_cached is None:
|
| 140 |
return sim_score
|
pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1334487698
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c5e8e5083c6c333b9ba3284e989dab95ef00f70e1b97770df255acebada4388
|
| 3 |
size 1334487698
|
userLearner.py
CHANGED
|
@@ -92,6 +92,11 @@ class UserLearner(nn.Module):
|
|
| 92 |
|
| 93 |
# embeds shape: (bs, seq_len, hid_dim)
|
| 94 |
shape = embeds.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hid_dim)
|
| 96 |
# g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
|
| 97 |
logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],dim=1)
|
|
@@ -118,8 +123,8 @@ class UserLearner(nn.Module):
|
|
| 118 |
# assert sum(mix_weight) == 1
|
| 119 |
# w = self.softmax(mix_weight.repeat(bs, 1))
|
| 120 |
# w = mix_weight.repeat(bs, 1)
|
| 121 |
-
logger.info(f"{w=}")
|
| 122 |
-
logger.info(f"{w.shape=}")
|
| 123 |
w = w.unsqueeze(-1).unsqueeze(-1)
|
| 124 |
y_hat = (w * prompt_logits).sum(dim=1)
|
| 125 |
self.tmp_store_user_ideal_points = y_hat
|
|
|
|
| 92 |
|
| 93 |
# embeds shape: (bs, seq_len, hid_dim)
|
| 94 |
shape = embeds.shape
|
| 95 |
+
# only last hidden state start
|
| 96 |
+
embeds = embeds[:, -1, :] # (bs, seq_len, hid_dim) -> (bs, hid_dim)
|
| 97 |
+
embeds = embeds.unsqueeze(1).repeat(1, shape[1], 1) # (bs, hid_dim) -> (bs, seq_len, hid_dim)
|
| 98 |
+
# only last hidden state end
|
| 99 |
+
# logger.critical("using only last hidden state of prompt tokens")
|
| 100 |
embeds = embeds.view(-1, shape[-1]) # (bs*seq_len, hid_dim)
|
| 101 |
# g(embeds) shape: (bs*seq_len, hid_dim) -> (bs*seq_len, pref_dim)
|
| 102 |
logits = torch.stack([g(embeds).view(shape[0], shape[1], -1) for g in self.projectors.values()],dim=1)
|
|
|
|
| 123 |
# assert sum(mix_weight) == 1
|
| 124 |
# w = self.softmax(mix_weight.repeat(bs, 1))
|
| 125 |
# w = mix_weight.repeat(bs, 1)
|
| 126 |
+
# logger.info(f"{w=}")
|
| 127 |
+
# logger.info(f"{w.shape=}")
|
| 128 |
w = w.unsqueeze(-1).unsqueeze(-1)
|
| 129 |
y_hat = (w * prompt_logits).sum(dim=1)
|
| 130 |
self.tmp_store_user_ideal_points = y_hat
|