{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: /opt/conda did not contain libcudart.so as expected! Searching further paths...\n", " warn(msg)\n", "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n", "The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. \n", "The class this function is called from is 'LlamaTokenizer'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n", "CUDA SETUP: Detected CUDA version 113\n", "CUDA SETUP: Loading binary /home/matt/hf/sqllama-V0/.venv/lib/python3.7/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f8ad2d1a5de842bcb6b7e3c6972d9074", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/33 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import LlamaTokenizer, LlamaForCausalLM\n", "from peft import prepare_model_for_int8_training\n", "tokenizer = LlamaTokenizer.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\")\n", " \n", "tokenizer.pad_token_id = 0\n", "tokenizer.padding_side = 'left'\n", "\n", "model = LlamaForCausalLM.from_pretrained(\n", " \"decapoda-research/llama-7b-hf\",\n", " load_in_8bit=True,\n", " device_map=\"auto\",\n", " torch_dtype=torch.float16\n", ")\n", "\n", "model = prepare_model_for_int8_training(model)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "table: 2-17672470-19\n", "columns: Stage,Winner,General Classification,Mountains Classification,Points Classification,Sprints classification,Team Classification\n", "Q: What is the stage of Gerolsteiner?\n", "A: SELECT Stage FROM 2-17672470-19 WHERE Team Classification = 'gerolsteiner'\n", "END\n", "\n", "\n", "table: 2-12518301-2\n", "columns: Rider,Matches,Rides,Bonus Pts,Total Points\n", "Q: What was the average number of points with bonus pts less than 31 with the rider dennis gavros?\n", "A: SELECT AVG Total Points FROM 2-12518301-2 WHERE Rider = 'dennis gavros' AND Bonus Pts < 31\n", "END\n", "\n", "\n", "table: 1-27961684-1\n", "columns: Institution,City,State,Team Name,Affiliation,Enrollment,Home Conference\n", "Q: How many states were there when there was an enrollment of 2789?\n", "A: SELECT COUNT State FROM 1-27961684-1 WHERE Enrollment = 2789\n", "END\n", "\n", "\n", "table: 2-17441442-2\n", "columns: Res.,Record,Opponent,Method,Event,Round,Time,Location\n", "Q: What is the round number when the record is 15–7–1?\n", "A: SELECT COUNT Round FROM 2-17441442-2 WHERE Record = '15–7–1'\n", "END\n", "\n", "\n", "table: 2-17406982-1\n", "columns: Round,Pick,Player,Position,School/Club Team\n", "Q: What pick in round 5 did the 49ers pick Jim Pilot?\n", "A: SELECT SUM Pick FROM 2-17406982-1 WHERE Player = 'jim pilot' AND Round > 5\n", "END\n", "\n" ] } ], "source": [ "import random\n", "import json\n", "\n", "# defined by WikiSQL\n", "\n", "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n", "cond_ops = ['=', '>', '<', 'OP']\n", "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n", "\n", "def fix_repr(d,cols,types,tid):\n", " sel_index=d['sel'] \n", " agg_index=d['agg']\n", " conditions=d['conds']\n", " col = cols[sel_index]\n", " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n", " agg=agg_ops[agg_index],\n", " sel=col,\n", " tid=tid\n", " )\n", " if conditions:\n", " cs = []\n", " for i, o, v in conditions:\n", " #print(i,cols)\n", " nm = cols[i]\n", " op = cond_ops[o]\n", " \n", " if types[i] in ['text']:\n", " val = f\"\\'{v}\\'\"\n", " else:\n", " val = v\n", " cs.append(f'{nm} {op} {val}')\n", " #print(cs)\n", "\n", " rep += ' WHERE ' + ' AND '.join(cs)\n", " \n", " return rep\n", "\n", "tbl_cols = {}\n", "tbl_types = {}\n", "tbl_str = {}\n", "\n", "prefix = 'Below is a question that describes a data request, paired with an input that describes a SQL table. Write a SQL query that retrieves the data.'\n", "\n", "def tbl_def_to_string(id, header, types):\n", " s = f'table: {id}\\ncolumns: ' + ','.join(header)\n", " return s\n", "\n", "with open('data/train.tables.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['id']\n", " hdr = js['header']\n", " ts = js['types']\n", " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n", " tbl_cols[id] = hdr\n", " tbl_types[id] = ts\n", "\n", "q_s = []\n", "a_s = []\n", "\n", "with open('data/train.jsonl') as f:\n", " for line in f:\n", " js = json.loads(line)\n", " id = js['table_id']\n", " s = tbl_str[id]\n", " qst = js['question']\n", " nl = s + '\\nQ: ' + qst + '\\nA: '\n", " q_s.append(nl)\n", "\n", " sql = js['sql']\n", " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n", " a = a + \"\\nEND\\n\"\n", " a_s.append(a)\n", "\n", "M = len(q_s)\n", "\n", "data_txt = [q_s[i] + a_s[i] for i in range(M)]\n", "\n", "for i in range(5):\n", " j = random.randint(0,M-1)\n", " print()\n", " print(data_txt[j]) \n", " \n", " " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "toks = [tokenizer(s) for s in data_txt]\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "92\n", " 0\n", "count 56355.000000\n", "mean 101.219519\n", "std 21.740325\n", "min 63.000000\n", "25% 87.500000\n", "50% 97.000000\n", "75% 109.000000\n", "max 461.000000\n", "32084\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "print(len(toks[0].input_ids))\n", "lens = np.array([len(tok.input_ids) for tok in toks])\n", "print(pd.DataFrame(lens).describe())\n", "\n", "z = zip(q_s,lens)\n", "q_red = [a for a,b in z if b < 100]\n", "z = zip(a_s,lens)\n", "a_red = [a for a,b in z if b < 100]\n", "\n", "data_red = [q_red[i] + a_red[i] for i in range(len(q_red))]\n", "print(len(data_red))\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "708e075933754c6c940eeae9e3d3abc9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/32084 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import random, datasets\n", "#d = {'prompt': random.sample(data_red, 1000)}\n", "d = {'prompt': data_red}\n", "\n", "data = datasets.Dataset.from_dict(d)\n", "data = data.map(lambda x:\n", " tokenizer(\n", " x['prompt'],\n", " truncation=True,\n", " max_length=100,\n", " padding=\"max_length\"\n", " ))\n", "\n", "data = data.remove_columns('prompt')\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from peft import LoraConfig, get_peft_model\n", "import transformers\n", "import datasets\n", "\n", "LORA_R = 4\n", "LORA_ALPHA = 16\n", "LORA_DROPOUT = .1\n", "BATCH = 128\n", "MICRO_BATCH = 4\n", "N_GAS = BATCH//MICRO_BATCH\n", "EPOCHS = 2\n", "LR = 1e-5\n", "\n", "lora_cfg = LoraConfig(\n", " r = LORA_R,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT,\n", " task_type='CASUAL_LM',\n", " target_modules=['q_proj','v_proj']\n", ")\n", "\n", "model = get_peft_model(model,lora_cfg)\n", "\n", "targs = transformers.TrainingArguments(\n", " per_device_train_batch_size=MICRO_BATCH,\n", " gradient_accumulation_steps=N_GAS,\n", " warmup_steps=0,\n", " num_train_epochs=EPOCHS,\n", " learning_rate=LR,\n", " fp16=True,\n", " logging_steps=1,\n", " output_dir='sqllama-out3',\n", " save_total_limit=3,\n", " remove_unused_columns=False\n", ")\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
| Step | \n", "Training Loss | \n", "
|---|---|
| 1 | \n", "2.748800 | \n", "
| 2 | \n", "2.725100 | \n", "
"
],
"text/plain": [
"