Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| import io | |
| from PIL import Image | |
| import os | |
| from cryptography.fernet import Fernet | |
| from google.cloud import storage | |
| import pinecone | |
| import json | |
| import uuid | |
| import pandas as pd | |
# --- Cloud Storage credentials -------------------------------------------
# The service-account JSON ships encrypted; the Fernet key comes from the
# environment.
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())

# Write the decrypted credentials to disk so the path can be handed to the
# Google client through GOOGLE_APPLICATION_CREDENTIALS below.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    json.dump(creds, fp, indent=4)

# --- Cloud Storage connection --------------------------------------------
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')

# --- Pinecone connection -------------------------------------------------
PINECONE_KEY = os.environ['PINECONE_KEY']
index_id = "hf-diffusion"
pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")
# Fail fast at start-up when the expected index does not exist.
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)

# --- Stable Diffusion pipeline -------------------------------------------
# Device the pipeline and the tokenized inputs are moved to (CPU here).
device = 'cpu'
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=os.environ['HF_AUTH'],
)
pipe.to(device)

# Placeholder shown in the gallery when a stored image fails validation.
missing_im = Image.open('missing.png')
# Similarity score a cached result must exceed to count as a match.
threshold = 0.85
def encode_text(text: str):
    """Embed *text* with the pipeline's tokenizer and text encoder.

    Args:
        text: Prompt to embed.

    Returns:
        The encoder's pooled output for the single input, converted to a
        plain Python list.
    """
    # truncation=True clips the prompt to the tokenizer's configured maximum
    # length, so an over-long prompt is shortened instead of being passed to
    # the text encoder at a length it may not accept.
    text_inputs = pipe.tokenizer(
        text, truncation=True, return_tensors='pt'
    ).to(device)
    # Inference only: no gradients are needed for an embedding lookup.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    # One input was encoded, so take row 0 of the pooled output.
    return text_embeds.pooler_output.cpu().tolist()[0]
def prompt_query(text: str):
    """Return up to five previously stored prompts similar to *text*.

    The text is embedded with ``encode_text`` and the 30 nearest entries are
    requested from the Pinecone index. If the query fails, the Pinecone
    connection is re-initialised and the query is retried once.

    Args:
        text: Prompt typed by the user.

    Returns:
        A ``pandas.DataFrame`` with columns ``Similarity`` (score rounded to
        two decimals) and ``Prompt``, de-duplicated by prompt, restricted to
        prompts longer than 7 characters, and cut to the first five rows.

    Raises:
        ValueError: If the retried query fails as well; the underlying
            exception is chained as ``__cause__``.
    """
    embeds = encode_text(text)
    try:
        xc = index.query(embeds, top_k=30, include_metadata=True)
    except Exception as first_err:
        print(f"Error during query: {first_err}")
        # reinitialize connection and retry once on a fresh index handle
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index2 = pinecone.Index(index_id)
        try:
            xc = index2.query(embeds, top_k=30, include_metadata=True)
        except Exception as retry_err:
            # Keep the ValueError callers see, but preserve the real cause.
            raise ValueError(retry_err) from retry_err

    matches = xc['matches']
    df = pd.DataFrame({
        'Similarity': [round(match['score'], 2) for match in matches],
        'Prompt': [match['metadata']['prompt'] for match in matches],
    })
    if df.empty:
        # No matches: return the empty, correctly-labelled frame as is
        # rather than running string filters on a column with no strings.
        return df
    # deduplicate while preserving order
    df = df.drop_duplicates(subset='Prompt', keep='first')
    return df[df['Prompt'].str.len() > 7].head()
def diffuse(text: str):
    """Generate an image for *text* and store it for later retrieval.

    The image is uploaded to the bucket under ``images/<uuid>.png`` and the
    prompt embedding is upserted into the Pinecone index with the prompt and
    image path as metadata.

    Args:
        text: Prompt to generate an image from.

    Returns:
        The generated image, or an empty dict when the pipeline flagged the
        output as NSFW (nothing is stored in that case).
    """
    out = pipe(text)
    nsfw_flags = out.nsfw_content_detected
    # The flags can be None (e.g. no safety check was run); any(None) would
    # raise TypeError, so treat None as "nothing flagged".
    if nsfw_flags is not None and any(nsfw_flags):
        return {}

    _id = str(uuid.uuid4())
    im = out.images[0]
    local_path = f'{_id}.png'
    remote_path = f'images/{_id}.png'
    try:
        # Stage the image on disk, then push it to Cloud Storage.
        im.save(local_path, format='png')
        blob = bucket.blob(remote_path)
        blob.upload_from_filename(local_path)
    finally:
        # Always clean up the staging file, even when save/upload fails.
        if os.path.exists(local_path):
            os.remove(local_path)

    # add embedding and metadata to Pinecone
    embeds = encode_text(text)
    meta = {
        'prompt': text,
        'image_url': remote_path,
    }
    index.upsert([(_id, embeds, meta)])
    return im
def get_image(url: str):
    """Fetch the bucket object named *url* and open it as a PIL image.

    Args:
        url: Object name inside the bucket.

    Returns:
        The image opened from the downloaded bytes.
    """
    raw = bucket.blob(url).download_as_string()
    return Image.open(io.BytesIO(raw))
def test_image(_id, image):
    """Check that *image* can be encoded; purge it from storage if not.

    Args:
        _id: Identifier of the image; it is the Pinecone vector id and the
            stem of the stored file ``images/<_id>.png``.
        image: Image object to validate (must provide ``save``).

    Returns:
        True when the image could be saved. False when saving raised
        ``OSError``, in which case the entry has been deleted from both the
        Pinecone index and the bucket.
    """
    try:
        # Encode into memory instead of a shared 'tmp.png': no leftover file,
        # no clash between concurrent requests, and an unwritable working
        # directory can no longer be mistaken for a corrupted image (which
        # would delete a perfectly good entry below).
        image.save(io.BytesIO(), format='PNG')
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        print(f"DELETED '{_id}'")
        return False
    return True
def prompt_image(text: str):
    """Fetch the stored images whose prompts are closest to *text*.

    The text is embedded with ``encode_text`` and the 9 nearest entries are
    requested from the Pinecone index. If the query fails, the Pinecone
    connection is re-initialised and the query is retried once. Each match's
    image is then downloaded and validated with ``test_image``.

    Args:
        text: Prompt to search for.

    Returns:
        A tuple ``(images, scores)``. ``images`` holds one entry per match,
        in match order; any image that cannot be loaded or fails validation
        is replaced by the ``missing_im`` placeholder so the list stays
        aligned with ``scores``.

    Raises:
        ValueError: If the retried query fails as well; the underlying
            exception is chained as ``__cause__``.
    """
    embeds = encode_text(text)
    try:
        xc = index.query(embeds, top_k=9, include_metadata=True)
    except Exception as first_err:
        print(f"Error during query: {first_err}")
        # reinitialize connection and retry once on a fresh index handle
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        index2 = pinecone.Index(index_id)
        try:
            xc = index2.query(embeds, top_k=9, include_metadata=True)
        except Exception as retry_err:
            # Keep the ValueError callers see, but preserve the real cause.
            raise ValueError(retry_err) from retry_err

    matches = xc['matches']
    image_urls = [match['metadata']['image_url'] for match in matches]
    scores = [match['score'] for match in matches]
    ids = [match['id'] for match in matches]

    images = []
    for _id, image_url in zip(ids, image_urls):
        try:
            im = get_image(image_url)
            usable = test_image(_id, im)
        except ValueError:
            print(f"ValueError: '{image_url}'")
            usable = False
        except OSError as err:
            # Image.open raises an OSError subclass for data it cannot
            # identify; one bad object must not abort the whole search.
            print(f"OSError: '{image_url}': {err}")
            usable = False
        images.append(im if usable else missing_im)
    return images, scores
| # __APP FUNCTIONS__ | |
def set_suggestion(text: str):
    """Build a TextArea update whose value is the first item of *text*.

    Args:
        text: Indexable input; only ``text[0]`` is used.

    Returns:
        The result of ``gr.TextArea.update`` with that first item as value.
    """
    first_item = text[0]
    return gr.TextArea.update(value=first_item)
def set_images(text: str):
    """Fill the gallery for *text*, generating a new image when needed.

    Stored images are looked up with ``prompt_image``. When no returned
    score exceeds ``threshold``, a new image is generated with ``diffuse``
    and the lookup is repeated.

    Args:
        text: Prompt entered by the user.

    Returns:
        A ``gr.Gallery.update`` carrying the images to display.
    """
    images, scores = prompt_image(text)
    if any(score > threshold for score in scores):
        print("MATCH FOUND")
    else:
        print("NO MATCH FOUND")
        # Nothing close enough is cached: create it, then search again.
        diffuse(text)
        images, scores = prompt_image(text)
    return gr.Gallery.update(value=images)
# __CREATE APP__
# Two-column layout: prompt input, search button and similar-prompt
# suggestions on the left; the image gallery on the right.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# Dream Cacher
"""
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A person surfing",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # The component's initial-content keyword is `value`; the
            # original passed `values`, which is not a Dataframe parameter.
            suggestions = gr.Dataframe(
                value=[],
                headers=['Similarity', 'Prompt']
            )
            # event listener for change in prompt
            prompt.change(
                prompt_query, prompt, suggestions,
                show_progress=False
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
            # search event listening
            # NOTE(review): this try/except only covers registering the
            # handler, not the later execution of set_images on a click.
            try:
                search.click(set_images, prompt, pics)
            except OSError:
                print("OSError")
demo.launch()