---
base_model: Qwen/Qwen2.5-3B-Instruct
library_name: transformers
model_name: Qwen2.5-3B-Inst-SQL-Reasoning-GRPO
tags:
- trl
- grpo
license: apache-2.0
language:
- en
---

# Qwen-2.5-3B-Instruct Based Text-to-SQL Generation Model Aligned with Multiple Reward Functions via GRPO

This model is RL-tuned with GRPO to produce reasoning-based SQL queries, and it is loaded as a PEFT adapter on top of `Qwen/Qwen2.5-3B-Instruct`. You can reuse the `system` prompt below as-is or adapt it to your needs. Provide the `SCHEMAS` and the `QUESTION` in the `user` prompt using the format below, and the model generates its reasoning trace followed by the SQL query that answers the question.

## Quick start

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Load the base model, then attach the GRPO-trained adapter on top of it (inference only)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)
model = PeftModel.from_pretrained(model, "DeathReaper0965/Qwen2.5-3B-Inst-SQL-Reasoning-GRPO", is_trainable=False)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct", max_length=2560)


def create_prompt(schemas, question):
    """Build the chat messages: a fixed system prompt plus the schemas and question as the user turn."""
    prompt = [
        {
            'role': 'system',
            'content': """\
You are an expert SQL Query Writer.
Given relevant Schemas and the Question, you first understand the problem entirely and then reason about the best possible approach to come up with an answer.
Once, you are confident in your reasoning, you will then start generating the SQL Query as the answer that accurately solves the given question leveraging some or all schemas.

Remember that you should place all your reasoning between and tags.
Also, you should provide your solution between and tags.

An example generation is as follows:

This is a sample reasoning that solves the question based on the schema.
SELECT COLUMN FROM TABLE_NAME WHERE CONDITION
"""
        },
        {
            'role': 'user',
            'content': f"""\
SCHEMAS:
---------------

{schemas}

---------------

QUESTION: "{question}"\
"""
        }
    ]
    return prompt


schemas = """\
CREATE TABLE lab (
  subject_id text,
  hadm_id text,
  itemid int,
  charttime date,
  flag bool,
  value_unit int,
  label text,
  fluid text
)

CREATE TABLE diagnoses (
  subject_id text,
  hadm_id text,
  icd9_code text,
  short_title text,
  long_title text
)

CREATE TABLE procedures (
  subject_id text,
  hadm_id text,
  icd9_code text,
  short_title text,
  long_title text
)

CREATE TABLE demographic (
  subject_id text,
  hadm_id text,
  name text,
  marital_status text,
  age int,
  dob date,
  gender text,
  language text,
  religion text,
  admission_type text,
  days_stay text,
  insurance text,
  ethnicity text,
  expire_flag bool,
  admission_location text,
  discharge_location text,
  diagnosis text,
  dod date,
  dob_year date,
  dod_year date,
  admittime date,
  dischtime date,
  admityear int
)

CREATE TABLE prescriptions (
  subject_id text,
  hadm_id text,
  icustay_id text,
  drug_type text,
  drug text,
  formulary_drug_cd text,
  route text,
  drug_dose text
)\
"""

question = "How many patients whose admission type is emergency and diagnoses icd9 code is 56210?"

example_prompt = create_prompt(schemas, question)

# Stream tokens to stdout as they are generated, without echoing the prompt
streamer = TextStreamer(tokenizer, skip_prompt=True)

inputs = tokenizer.apply_chat_template(example_prompt, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt")

with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=1024, streamer=streamer)

outputs = tokenizer.batch_decode(outputs)

# Keep only the assistant's turn from the decoded chat transcript
print(outputs[0].split("<|im_start|>assistant")[-1])
```

Example output:

```
To answer this question, we need to perform the following steps:

1. Identify patients who have an 'emergency' admission type from the `demographic` table.
2. Identify patients who have the ICD-9 code '56210' in their `diagnosis` field from the same `demographic` table.
3. Find the intersection of these two groups by joining the results of the above queries.
4. Count the number of unique patients who meet both criteria.

We can achieve this using a combination of JOIN operations in our SQL query.

SELECT COUNT(DISTINCT d.subject_id)
FROM demographic AS d
JOIN diagnoses AS di ON d.subject_id = di.subject_id AND d.hadm_id = di.hadm_id
WHERE d.admission_type = 'Emergency' AND di.icd9_code = '56210'
```

> Designed and Developed by [Praneet](https://deathreaper0965.github.io/) | [LinkedIn](http://linkedin.com/in/deathreaper0965) | [GitHub](https://github.com/DeathReaper0965/)
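## Sanity-checking the generated SQL

The model returns plain SQL text, so one lightweight way to catch hallucinated tables or columns is to compile the query against the same `SCHEMAS` string in an in-memory SQLite database. This is a minimal sketch and is not part of this model's training or evaluation code. It assumes you have already extracted the SQL query from the model's response, that the schemas are plain `CREATE TABLE` statements as in the Quick start, and that the query only uses syntax SQLite accepts. `check_sql_against_schema`, `demo_schemas`, and `generated_sql` are hypothetical names used for illustration. Passing the check only means the query parses and its tables and columns exist; it says nothing about whether the query answers the question correctly.

```python
import re
import sqlite3


def check_sql_against_schema(schemas: str, sql: str) -> tuple[bool, str]:
    """Hypothetical helper: compile `sql` against an in-memory SQLite copy of `schemas`.

    Returns (True, "ok") if SQLite can prepare the query, otherwise
    (False, <SQLite error message>). It does not run the query on real data.
    """
    conn = sqlite3.connect(":memory:")
    try:
        # The example schemas are not separated by semicolons, so split the
        # text just before each CREATE TABLE and run the statements one by one.
        for stmt in re.split(r"(?i)(?=\bcreate\s+table\b)", schemas):
            if stmt.strip():
                conn.execute(stmt)
        # EXPLAIN makes SQLite prepare the query (resolving every table and
        # column) and return its bytecode listing instead of running it.
        conn.execute("EXPLAIN " + sql.strip().rstrip(";"))
        return True, "ok"
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()


# Assumption: a trimmed copy of the Quick start schemas with only the tables the query needs
demo_schemas = """\
CREATE TABLE demographic (
  subject_id text,
  hadm_id text,
  admission_type text
)

CREATE TABLE diagnoses (
  subject_id text,
  hadm_id text,
  icd9_code text
)
"""

# The query from the example output above, as a single string
generated_sql = (
    "SELECT COUNT(DISTINCT d.subject_id) FROM demographic AS d "
    "JOIN diagnoses AS di ON d.subject_id = di.subject_id AND d.hadm_id = di.hadm_id "
    "WHERE d.admission_type = 'Emergency' AND di.icd9_code = '56210'"
)

print(check_sql_against_schema(demo_schemas, generated_sql))  # (True, 'ok')
print(check_sql_against_schema(demo_schemas, "SELECT age FROM diagnoses"))
```

Using `EXPLAIN` keeps the check cheap: SQLite still prepares the statement, so an unknown table or column raises `sqlite3.OperationalError`, but no rows are needed and nothing is executed. If the model emits dialect-specific SQL that SQLite does not support, this check will report a false failure.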