"""Gradio app that summarizes arXiv research on a user-supplied topic.

Relevant papers are fetched live from the arXiv Atom API and their abstracts
are given as context to a Gemma causal LM (optionally wrapped with a PEFT
adapter selected via the ``ADAPTER_MODEL_ID`` environment variable).
"""

import os
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base model to use
BASE_MODEL_NAME = "google/gemma-2-2b"
# Optional PEFT adapter repo id; when unset/empty the base model is used as-is.
ADAPTER_MODEL_NAME = os.environ.get("ADAPTER_MODEL_ID", "")
HF_TOKEN = os.environ.get("HF_TOKEN")

# arXiv Atom API endpoint (HTTPS so results cannot be altered in transit).
ARXIV_API_URL = "https://export.arxiv.org/api/query"
# XML namespace used by the elements of the arXiv Atom feed.
_ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"}

print("Initializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # generate() is given pad_token_id below; reuse EOS when no pad token exists.
    tokenizer.pad_token = tokenizer.eos_token

print("Initializing model...")
# Free spaces run on CPU. We load in float32 and use low_cpu_mem_usage to fit in CPU RAM.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)

if ADAPTER_MODEL_NAME:
    print(f"Loading custom PEFT adapter: {ADAPTER_MODEL_NAME}...")
    try:
        model = PeftModel.from_pretrained(model, ADAPTER_MODEL_NAME)
    except Exception as e:
        # Deliberate best effort: a bad or missing adapter should not take the
        # whole app down, so fall back to the already-loaded base model.
        print(f"Error loading PEFT adapter: {e}. Using base model instead.")

# Inference only: ensure dropout and similar layers are disabled for whichever
# model (base or adapter-wrapped) ended up in ``model``.
model.eval()


def _element_text(node, default=""):
    """Return the whitespace-normalised text of an XML element, or *default*.

    arXiv wraps long titles/abstracts across lines with indentation, so every
    run of whitespace (not only newlines) is collapsed to a single space.
    A missing element or an element with no text yields *default*.
    """
    if node is None or node.text is None:
        return default
    return " ".join(node.text.split())


def _format_authors(authors):
    """Join author names for display, with a placeholder when none are known."""
    return ", ".join(authors) if authors else "Unknown Authors"


def fetch_arxiv_papers(topic, max_results=3):
    """Query the arXiv API for papers matching *topic*.

    Args:
        topic: Free-text search terms, searched across all arXiv fields.
        max_results: Maximum number of entries to request from arXiv.

    Returns:
        A list of dicts with keys ``title`` (str), ``authors`` (list of str),
        ``summary`` (str) and ``url`` (str). On any network or parsing error
        the error is printed and an empty list is returned, so callers can
        fall back to a prompt without literature context.
    """
    try:
        query = urllib.parse.quote(topic)
        url = f"{ARXIV_API_URL}?search_query=all:{query}&max_results={int(max_results)}"
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=10) as response:
            root = ET.fromstring(response.read())
    except Exception as e:
        # Best effort boundary: arXiv being slow/unreachable must not break the UI.
        print(f"Error fetching from arXiv: {e}")
        return []

    papers = []
    for entry in root.findall("atom:entry", _ATOM_NS):
        names = (
            _element_text(author.find("atom:name", _ATOM_NS))
            for author in entry.findall("atom:author", _ATOM_NS)
        )
        papers.append({
            "title": _element_text(entry.find("atom:title", _ATOM_NS)) or "Unknown Title",
            "authors": [name for name in names if name],
            "summary": _element_text(entry.find("atom:summary", _ATOM_NS)),
            "url": _element_text(entry.find("atom:id", _ATOM_NS)),
        })
    return papers


def _build_prompt(topic, papers):
    """Build the generation prompt, embedding paper abstracts when available.

    The prompt always ends with ``"Summary:\\n"`` so the model continues with
    the summary text.
    """
    if not papers:
        return (
            f"Document:\nTopic: {topic}\n"
            "Provide key points and overview of research papers associated with this topic.\n\n"
            "Summary:\n"
        )
    context = "".join(
        f"Paper {i}: {paper['title']} by {_format_authors(paper['authors'])}\n"
        f"Abstract: {paper['summary']}\n\n"
        for i, paper in enumerate(papers, 1)
    )
    return (
        f"Document:\nTopic: {topic}\n\nRelevant Literature:\n{context}"
        "Based on the above papers, provide key points and an overview of the research on this topic.\n\n"
        "Summary:\n"
    )


def _format_references(papers):
    """Render the fetched papers as a Markdown reference list.

    Lines within one reference end in two spaces (a Markdown hard line break)
    so title, authors and URL render on separate lines.
    """
    if not papers:
        return "No external papers could be retrieved."
    entries = []
    for i, paper in enumerate(papers, 1):
        url = paper["url"]
        link = f"[{url}]({url})" if url else "unavailable"
        entries.append(
            f"**[{i}] {paper['title']}**  \n"
            f"*Authors:* {_format_authors(paper['authors'])}  \n"
            f"*URL:* {link}\n\n"
        )
    return "".join(entries)


def summarize_topic(topic):
    """Fetch arXiv papers for *topic* and generate an overview with the model.

    Args:
        topic: User-entered research topic; ``None`` or blank is rejected.

    Returns:
        A ``(summary_text, references_markdown)`` tuple. For a blank topic the
        tuple is ``("Please enter a valid topic.", "")``.
    """
    topic = (topic or "").strip()
    if not topic:
        return "Please enter a valid topic.", ""

    papers = fetch_arxiv_papers(topic, max_results=3)
    prompt = _build_prompt(topic, papers)

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation; decode only the new tokens so
    # the answer is never truncated by string-matching on the prompt text.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    return response, _format_references(papers)


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔬 Scientific Research Paper Summarizer")
    gr.Markdown(
        "Enter any scientific topic or research question. The assistant will fetch relevant papers from "
        "arXiv in real-time, synthesize the core findings using our fine-tuned Gemma model, and provide direct references."
    )
    with gr.Row():
        with gr.Column():
            topic_input = gr.Textbox(
                label="Research Topic / Keywords",
                placeholder="e.g. quantum machine learning, artificial intelligence in healthcare...",
            )
            submit_btn = gr.Button("Generate Overview", variant="primary")
        with gr.Column():
            answer_output = gr.Textbox(label="Model Overview & Points", interactive=False)
            refs_output = gr.Markdown(label="References Cited")

    submit_btn.click(
        fn=summarize_topic,
        inputs=[topic_input],
        outputs=[answer_output, refs_output],
    )

if __name__ == "__main__":
    demo.launch()