| import os |
| import base64 |
| import torch |
| import gradio as gr |
| from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor |
| from qwen_vl_utils import process_vision_info |
| import cv2 |
|
|
|
|
| |
# Hugging Face Hub repository id (or local directory) that holds the model
# weights together with its tokenizer, processor and generation config files.
model_directory = "THP2903/erax_llm"

# Load the vision-language model in bfloat16; device_map="auto" lets the
# library decide on which available device(s) the weights are placed.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_directory,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# The Auto* classes refuse direct instantiation; they must be created through
# ``from_pretrained``, which also locates vocab.json / merges.txt /
# tokenizer.json on its own.
tokenizer = AutoTokenizer.from_pretrained(model_directory)

# ``torch.load`` unpickles checkpoints and cannot parse JSON config files, and
# ``AutoProcessor`` offers no ``from_config``; ``from_pretrained`` reads
# preprocessor_config.json itself.
processor = AutoProcessor.from_pretrained(model_directory)

# ``from_pretrained`` above already parsed generation_config.json (when the
# repository ships one) into ``model.generation_config``; reuse that object
# and override the sampling settings used by this demo.
generation_config = model.generation_config
generation_config.do_sample = True
generation_config.temperature = 0.2
generation_config.top_k = 1
generation_config.top_p = 0.001
generation_config.max_new_tokens = 2048
generation_config.repetition_penalty = 1.1
|
|
| |
def _encode_image_as_data_uri(image):
    """Return *image* as a base64-encoded JPEG ``data:`` URI string.

    Args:
        image: Pixel array as delivered by the Gradio image input
            (assumed RGB or RGBA channel order when it has 3 or 4 channels).

    Raises:
        gr.Error: If OpenCV fails to encode the array as JPEG.
    """
    # OpenCV encoders interpret colour arrays as BGR, while Gradio supplies
    # RGB; without this swap the model would see red and blue exchanged.
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)

    encoded_ok, jpeg_buffer = cv2.imencode(".jpg", image)
    if not encoded_ok:
        raise gr.Error("Could not encode the uploaded image as JPEG.")

    encoded_image = base64.b64encode(jpeg_buffer).decode("utf-8")
    return f"data:image;base64,{encoded_image}"


def generate_description(image, prompt):
    """Answer *prompt* about *image* using the loaded vision-language model.

    Args:
        image: Image array from the Gradio image input, or ``None`` when the
            user submitted without uploading anything.
        prompt: The question or instruction to apply to the image.

    Returns:
        The decoded text generated by the model for this single request.

    Raises:
        gr.Error: If no image was provided or it cannot be JPEG-encoded.
    """
    # Fail with a readable UI message instead of an OpenCV traceback.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    base64_data = _encode_image_as_data_uri(image)

    # Single-turn chat: one user message carrying the image and the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": base64_data,
                },
                {
                    "type": "text",
                    "text": prompt,
                },
            ],
        }
    ]

    # Render the chat template to plain text (tokenization happens in the
    # processor call below) and extract the vision inputs from the messages.
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    # Move the inputs to wherever the model actually lives rather than
    # assuming CUDA: the model is loaded with device_map="auto".
    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(**inputs, generation_config=generation_config)

    # The generated sequences start with the prompt tokens; strip them so
    # only the newly produced tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]
|
|
| |
# Input widgets: the picture to analyse (handed to the callback as a numpy
# array) and a two-line text box for the user's question.
image_input = gr.Image(type="numpy", label="Upload Image")
prompt_input = gr.Textbox(
    lines=2,
    placeholder="Enter your prompt/question here",
    label="Prompt",
)

# Wire the widgets to the generation callback; the result is shown as text.
iface = gr.Interface(
    fn=generate_description,
    inputs=[image_input, prompt_input],
    outputs="text",
    title="Image Description Generator",
    description="Upload an image and enter a prompt/question to get a detailed description or answer based on the image.",
)

# Start serving the web UI.
iface.launch()
|
|