Ventali commited on
Commit
2fdf0fc
·
verified ·
1 Parent(s): 413cd2c

Upload llama31-8b-ade-sft-v2 adapter (exact_match 0.715, positive_f1 0.860)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: llama3.1
3
+ base_model: meta-llama/Llama-3.1-8B-Instruct
4
+ library_name: peft
5
+ pipeline_tag: text-generation
6
+ language:
7
+ - en
8
+ tags:
9
+ - medical
10
+ - biomedical
11
+ - adverse-drug-events
12
+ - ade
13
+ - pharmacovigilance
14
+ - distillation
15
+ - lora
16
+ - peft
17
+ - llama-3.1
18
+ datasets:
19
+ - ade-benchmark-corpus/ade_corpus_v2
20
+ model-index:
21
+ - name: llama31-8b-ade-sft-v2
22
+ results:
23
+ - task:
24
+ type: text-generation
25
+ name: ADE Binary QA + span extraction
26
+ dataset:
27
+ type: ade-benchmark-corpus/ade_corpus_v2
28
+ name: ade_corpus_v2 (200 held-out)
29
+ metrics:
30
+ - type: exact_match
31
+ value: 0.715
32
+ name: exact_match (answer ∈ {yes,no,abstain})
33
+ - type: f1
34
+ value: 0.860
35
+ name: positive_f1 (answer=yes)
36
+ - type: precision
37
+ value: 0.785
38
+ name: positive_precision
39
+ - type: recall
40
+ value: 0.950
41
+ name: positive_recall
42
+ - type: f1
43
+ value: 0.883
44
+ name: span_drug_token_f1 (positives only)
45
+ - type: f1
46
+ value: 0.866
47
+ name: span_event_token_f1 (positives only)
48
+ ---
49
+
50
+ # llama31-8b-ade-sft-v2
51
+
52
+ A LoRA adapter for [`meta-llama/Llama-3.1-8B-Instruct`](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) that answers adverse drug event (ADE) questions on single-sentence clinical text and extracts the implicated drug and event as structured JSON. Distilled from a Vertex-hosted Llama 3.3 70B teacher; trained with QLoRA on ~3k teacher-labeled sentences from `ade_corpus_v2`.
53
+
54
+ **⚠️ Not clinical grade.** This is a research / educational artifact. Do not use for patient-care decisions.
55
+
56
+ ## Intended use
57
+
58
+ Given a short clinical vignette (one or a few sentences), produce a JSON object:
59
+
60
+ ```json
61
+ {
62
+ "answer": "yes | no | abstain",
63
+ "drug": "<drug name or empty>",
64
+ "event": "<adverse event or empty>",
65
+ "evidence": "<quoted or closely paraphrased text>",
66
+ "short_justification": "<one short sentence>",
67
+ "confidence": 0.0
68
+ }
69
+ ```
70
+
71
+ - `answer` is `yes` only when the text supports a causally plausible drug-event relationship.
72
+ - `abstain` is reserved for cases where the text names no plausible drug or no plausible event. Temporal co-occurrence with a clear external cause (e.g., "on metformin, slipped and fractured ankle") should be `no`, not `abstain`.
73
+
74
+ ## Evaluation
75
+
76
+ Held-out split (200 rows, balanced 100 positive / 100 negative) sampled from `ade_corpus_v2` and never seen during training. Compared against a v1 baseline that did not use few-shots or hard negatives.
77
+
78
+ | Metric | v1 | **v2 (this model)** |
79
+ |---|---|---|
80
+ | exact_match (yes/no/abstain) | 0.555 | **0.715** |
81
+ | abstain_rate | 0.315 | **0.135** |
82
+ | positive_f1 | 0.884 | 0.860 |
83
+ | positive_precision | 0.798 | 0.785 |
84
+ | positive_recall | 0.990 | 0.950 |
85
+ | span_drug_exact_match (pos) | 0.940 | 0.840 |
86
+ | span_drug_token_f1 (pos) | 0.952 | 0.883 |
87
+ | span_event_exact_match (pos) | 0.660 | 0.710 |
88
+ | span_event_token_f1 (pos) | 0.816 | 0.866 |
89
+
90
+ **Tradeoff to know.** v2 adds 600 "hard negatives" (drug mentioned, answer=no) to teach calibrated abstention. This halved the abstain rate and added 16 pts of exact_match, but cost ~10 pts of drug-span exact match vs v1 — the model learned to be more cautious about emitting a drug name. If your use case needs drug extraction on positives above all else, the earlier v1 checkpoint may be preferable.
91
+
92
+ ## Usage
93
+
94
+ ```python
95
+ from peft import PeftModel
96
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
97
+ import torch
98
+
99
+ base_id = "meta-llama/Llama-3.1-8B-Instruct"
100
+ adapter_id = "Ventali/llama31-8b-ade-sft-v2"
101
+
102
+ bnb = BitsAndBytesConfig(
103
+ load_in_4bit=True,
104
+ bnb_4bit_quant_type="nf4",
105
+ bnb_4bit_use_double_quant=True,
106
+ bnb_4bit_compute_dtype=torch.bfloat16,
107
+ )
108
+ tokenizer = AutoTokenizer.from_pretrained(base_id)
109
+ model = AutoModelForCausalLM.from_pretrained(base_id, quantization_config=bnb, device_map="auto")
110
+ model = PeftModel.from_pretrained(model, adapter_id)
111
+ model.eval()
112
+
113
+ messages = [
114
+ {"role": "system", "content": "You are a careful biomedical assistant. For each case, return a compact JSON answer grounded in the provided evidence. If the evidence is insufficient, abstain."},
115
+ {"role": "user", "content": "Case: The patient developed diffuse urticaria three days after starting amoxicillin.\n\nIs this consistent with a possible adverse drug event? Identify the drug and event if so, or abstain if the evidence is insufficient."},
116
+ ]
117
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
118
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
119
+ with torch.no_grad():
120
+ out = model.generate(**inputs, max_new_tokens=256, do_sample=False, pad_token_id=tokenizer.eos_token_id)
121
+ print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
122
+ ```
123
+
124
+ For Apple Silicon you can fuse the adapter into the base and run via `mlx-lm`:
125
+
126
+ ```bash
127
+ pip install mlx-lm
128
+ mlx_lm.fuse --model meta-llama/Llama-3.1-8B-Instruct \
129
+ --adapter-path <local-adapter-dir> \
130
+ --save-path ~/models/llama31-ade-mlx
131
+ mlx_lm.generate --model ~/models/llama31-ade-mlx --prompt "..."
132
+ ```
133
+
134
+ ## Training
135
+
136
+ - Base: `meta-llama/Llama-3.1-8B-Instruct`, loaded in 4-bit (NF4, double-quant, bf16 compute).
137
+ - LoRA: r=32, alpha=64, dropout=0.05, target modules {q,k,v,o,gate,up,down}_proj. 41.9M trainable params (0.52% of base).
138
+ - Data: 2,999 (prompt, teacher JSON) pairs. Prompts drawn from `ade_corpus_v2` as 1,200 positive (from `drug_ade_relation`) + 1,200 easy-negative + 600 hard-negative (classification label=0 rows whose text mentions a drug from the positive-split vocabulary). Teacher: Vertex AI managed `llama-3.3-70b-instruct-maas` (temperature 0.2), seeded with 3 yes/no/abstain few-shots and prompted to reserve abstention for cases with no plausible drug or no plausible event.
139
+ - Filter: required non-empty `answer` and `evidence`, `confidence ≥ 0.65`, evidence-source word overlap ≥ 0.6. 2,999/3,000 retained.
140
+ - Optimizer: AdamW, lr=2e-4, warmup_ratio=0.03, weight_decay=0.01, bf16, gradient_checkpointing on.
141
+ - 3 epochs with `load_best_model_at_end=True` on `eval_loss`; the epoch-1 checkpoint (eval_loss 0.506) was restored, eclipsing the overfit epochs 2–3 (0.547, 0.676).
142
+ - Hardware: single A100 40GB on GCP `a2-highgpu-1g`. Training wall time ~94 min.
143
+
144
+ ## Limitations
145
+
146
+ - Trained on single-sentence, literature-style clinical text. Longer narratives (discharge summaries, EHR free-text) are out of distribution and will likely perform worse.
147
+ - Teacher labels are synthetic. A clinician-reviewed eval set was not used; regressions against human judgment have not been measured.
148
+ - The model occasionally produces an empty `drug` or `event` field on positive cases, which is a regression from v1 on drug-span extraction. See the tradeoff note above.
149
+ - English only.
150
+
151
+ ## Reproducibility
152
+
153
+ Full pipeline (seed building, teacher generation config, filter, SFT prep, training, evaluation) lives at https://github.com/ventali/medical-distill. Commit [`547629f`](https://github.com/ventali/medical-distill/commit/547629f) records this adapter's metrics.
154
+
155
+ ## License
156
+
157
+ Inherits the [Llama 3.1 Community License](https://llama.meta.com/llama3_1/license/) from the base model.
adapter_config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alora_invocation_tokens": null,
3
+ "alpha_pattern": {},
4
+ "arrow_config": null,
5
+ "auto_mapping": null,
6
+ "base_model_name_or_path": "meta-llama/Llama-3.1-8B-Instruct",
7
+ "bias": "none",
8
+ "corda_config": null,
9
+ "ensure_weight_tying": false,
10
+ "eva_config": null,
11
+ "exclude_modules": null,
12
+ "fan_in_fan_out": false,
13
+ "inference_mode": true,
14
+ "init_lora_weights": true,
15
+ "layer_replication": null,
16
+ "layers_pattern": null,
17
+ "layers_to_transform": null,
18
+ "loftq_config": {},
19
+ "lora_alpha": 64,
20
+ "lora_bias": false,
21
+ "lora_dropout": 0.05,
22
+ "lora_ga_config": null,
23
+ "megatron_config": null,
24
+ "megatron_core": "megatron.core",
25
+ "modules_to_save": null,
26
+ "peft_type": "LORA",
27
+ "peft_version": "0.19.1",
28
+ "qalora_group_size": 16,
29
+ "r": 32,
30
+ "rank_pattern": {},
31
+ "revision": null,
32
+ "target_modules": [
33
+ "v_proj",
34
+ "up_proj",
35
+ "o_proj",
36
+ "q_proj",
37
+ "k_proj",
38
+ "down_proj",
39
+ "gate_proj"
40
+ ],
41
+ "target_parameters": null,
42
+ "task_type": "CAUSAL_LM",
43
+ "trainable_token_indices": null,
44
+ "use_bdlora": null,
45
+ "use_dora": false,
46
+ "use_qalora": false,
47
+ "use_rslora": false
48
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a846d9693f4bc75f02e0b9e7b846e50d37bca3656051eae0d1faf125bb0b9ee
3
+ size 335604696
chat_template.jinja ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {{- bos_token }}
2
+ {%- if custom_tools is defined %}
3
+ {%- set tools = custom_tools %}
4
+ {%- endif %}
5
+ {%- if not tools_in_user_message is defined %}
6
+ {%- set tools_in_user_message = true %}
7
+ {%- endif %}
8
+ {%- if not date_string is defined %}
9
+ {%- set date_string = "26 Jul 2024" %}
10
+ {%- endif %}
11
+ {%- if not tools is defined %}
12
+ {%- set tools = none %}
13
+ {%- endif %}
14
+
15
+ {#- This block extracts the system message, so we can slot it into the right place. #}
16
+ {%- if messages[0]['role'] == 'system' %}
17
+ {%- set system_message = messages[0]['content']|trim %}
18
+ {%- set messages = messages[1:] %}
19
+ {%- else %}
20
+ {%- set system_message = "" %}
21
+ {%- endif %}
22
+
23
+ {#- System message + builtin tools #}
24
+ {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
25
+ {%- if builtin_tools is defined or tools is not none %}
26
+ {{- "Environment: ipython\n" }}
27
+ {%- endif %}
28
+ {%- if builtin_tools is defined %}
29
+ {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
30
+ {%- endif %}
31
+ {{- "Cutting Knowledge Date: December 2023\n" }}
32
+ {{- "Today Date: " + date_string + "\n\n" }}
33
+ {%- if tools is not none and not tools_in_user_message %}
34
+ {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
35
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
36
+ {{- "Do not use variables.\n\n" }}
37
+ {%- for t in tools %}
38
+ {{- t | tojson(indent=4) }}
39
+ {{- "\n\n" }}
40
+ {%- endfor %}
41
+ {%- endif %}
42
+ {{- system_message }}
43
+ {{- "<|eot_id|>" }}
44
+
45
+ {#- Custom tools are passed in a user message with some extra guidance #}
46
+ {%- if tools_in_user_message and not tools is none %}
47
+ {#- Extract the first user message so we can plug it in here #}
48
+ {%- if messages | length != 0 %}
49
+ {%- set first_user_message = messages[0]['content']|trim %}
50
+ {%- set messages = messages[1:] %}
51
+ {%- else %}
52
+ {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
53
+ {%- endif %}
54
+ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
55
+ {{- "Given the following functions, please respond with a JSON for a function call " }}
56
+ {{- "with its proper arguments that best answers the given prompt.\n\n" }}
57
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
58
+ {{- "Do not use variables.\n\n" }}
59
+ {%- for t in tools %}
60
+ {{- t | tojson(indent=4) }}
61
+ {{- "\n\n" }}
62
+ {%- endfor %}
63
+ {{- first_user_message + "<|eot_id|>"}}
64
+ {%- endif %}
65
+
66
+ {%- for message in messages %}
67
+ {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
68
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
69
+ {%- elif 'tool_calls' in message %}
70
+ {%- if not message.tool_calls|length == 1 %}
71
+ {{- raise_exception("This model only supports single tool-calls at once!") }}
72
+ {%- endif %}
73
+ {%- set tool_call = message.tool_calls[0].function %}
74
+ {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
75
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
76
+ {{- "<|python_tag|>" + tool_call.name + ".call(" }}
77
+ {%- for arg_name, arg_val in tool_call.arguments | items %}
78
+ {{- arg_name + '="' + arg_val + '"' }}
79
+ {%- if not loop.last %}
80
+ {{- ", " }}
81
+ {%- endif %}
82
+ {%- endfor %}
83
+ {{- ")" }}
84
+ {%- else %}
85
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
86
+ {{- '{"name": "' + tool_call.name + '", ' }}
87
+ {{- '"parameters": ' }}
88
+ {{- tool_call.arguments | tojson }}
89
+ {{- "}" }}
90
+ {%- endif %}
91
+ {%- if builtin_tools is defined %}
92
+ {#- This means we're in ipython mode #}
93
+ {{- "<|eom_id|>" }}
94
+ {%- else %}
95
+ {{- "<|eot_id|>" }}
96
+ {%- endif %}
97
+ {%- elif message.role == "tool" or message.role == "ipython" %}
98
+ {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
99
+ {%- if message.content is mapping or message.content is iterable %}
100
+ {{- message.content | tojson }}
101
+ {%- else %}
102
+ {{- message.content }}
103
+ {%- endif %}
104
+ {{- "<|eot_id|>" }}
105
+ {%- endif %}
106
+ {%- endfor %}
107
+ {%- if add_generation_prompt %}
108
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
109
+ {%- endif %}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:384a7e7c676f7be2e5d2e8449c508be9b00e5b18c5b3c39ebc626e96b3f4b988
3
+ size 17210019
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<|begin_of_text|>",
4
+ "clean_up_tokenization_spaces": true,
5
+ "eos_token": "<|eot_id|>",
6
+ "is_local": false,
7
+ "model_input_names": [
8
+ "input_ids",
9
+ "attention_mask"
10
+ ],
11
+ "model_max_length": 131072,
12
+ "pad_token": "<|eot_id|>",
13
+ "tokenizer_class": "TokenizersBackend"
14
+ }