Upload sample_llama3-8B.py
Browse files- sample_llama3-8B.py +270 -0
sample_llama3-8B.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
| 3 |
+
import os
|
| 4 |
+
import fire
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from typing import Optional, List
|
| 10 |
+
from llama import Llama
|
| 11 |
+
from peft import PeftModel
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def extract_svg_from_text(text: str) -> Optional[str]:
|
| 18 |
+
"""
|
| 19 |
+
从包含SVG的文本中提取出完整的<svg>...</svg>结构。
|
| 20 |
+
如果未匹配到,则返回一个默认的空SVG。
|
| 21 |
+
"""
|
| 22 |
+
pattern = r"<svg\b[^>]*>.*?</svg>"
|
| 23 |
+
matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
|
| 24 |
+
if matches:
|
| 25 |
+
return matches[0]
|
| 26 |
+
else:
|
| 27 |
+
return """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def code_style_prompt(desc: str) -> str:
|
| 31 |
+
return f"""\
|
| 32 |
+
// SVG CODE GENERATION TASK FOR CODELLAMA
|
| 33 |
+
// OBJECTIVE: Create simple yet accurate SVG contour drawing
|
| 34 |
+
// DESCRIPTION: {desc}
|
| 35 |
+
|
| 36 |
+
// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):
|
| 37 |
+
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
|
| 38 |
+
<!-- Wheelchair seat -->
|
| 39 |
+
<path d="M30,40 L50,40 L50,60 L30,60 Z" fill="#555"/>
|
| 40 |
+
|
| 41 |
+
<!-- Wheelchair back -->
|
| 42 |
+
<path d="M30,40 L20,30 L20,20 L30,20 L30,40" fill="#555"/>
|
| 43 |
+
|
| 44 |
+
<!-- Large wheel -->
|
| 45 |
+
<circle cx="65" cy="65" r="25" stroke="#333" stroke-width="3" fill="none"/>
|
| 46 |
+
<circle cx="65" cy="65" r="5" fill="#333"/>
|
| 47 |
+
|
| 48 |
+
<!-- Small wheel -->
|
| 49 |
+
<circle cx="30" cy="70" r="10" stroke="#333" stroke-width="3" fill="none"/>
|
| 50 |
+
<circle cx="30" cy="70" r="3" fill="#333"/>
|
| 51 |
+
|
| 52 |
+
<!-- Wheel spokes (large wheel) -->
|
| 53 |
+
<line x1="65" y1="65" x2="80" y2="65" stroke="#333" stroke-width="2"/>
|
| 54 |
+
<line x1="65" y1="65" x2="65" y2="80" stroke="#333" stroke-width="2"/>
|
| 55 |
+
<line x1="65" y1="65" x2="55" y2="75" stroke="#333" stroke-width="2"/>
|
| 56 |
+
<line x1="65" y1="65" x2="55" y2="55" stroke="#333" stroke-width="2"/>
|
| 57 |
+
|
| 58 |
+
<!-- Wheel spokes (small wheel) -->
|
| 59 |
+
<line x1="30" y1="70" x2="38" y2="70" stroke="#333" stroke-width="2"/>
|
| 60 |
+
<line x1="30" y1="70" x2="30" y2="78" stroke="#333" stroke-width="2"/>
|
| 61 |
+
</svg>
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
// CODE GENERATION INSTRUCTIONS:
|
| 65 |
+
1. Figure out the main parts of the object(animal) according to the DESCRIPTION
|
| 66 |
+
1. Fill path data for main-outline using basic commands
|
| 67 |
+
2. Position eye element at logical position
|
| 68 |
+
3. Keep all coordinates within viewBox
|
| 69 |
+
4. Use 2 decimal precision for coordinates
|
| 70 |
+
5. Close all path elements properly
|
| 71 |
+
|
| 72 |
+
// {desc} GENERATION START FROM HERE:
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def post_process(code: str) -> str:
|
| 77 |
+
"""针对代码模型的输出优化后处理"""
|
| 78 |
+
# 提取闭合的SVG代码块
|
| 79 |
+
svg_match = re.search(r'<svg.*?</svg>', code, re.DOTALL)
|
| 80 |
+
if svg_match:
|
| 81 |
+
code = svg_match.group(0)
|
| 82 |
+
|
| 83 |
+
# 确保XML声明
|
| 84 |
+
if '<?xml' not in code:
|
| 85 |
+
code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code
|
| 86 |
+
|
| 87 |
+
# 验证必要元素
|
| 88 |
+
required_elements = {
|
| 89 |
+
'<svg': 1,
|
| 90 |
+
'</svg>': 1,
|
| 91 |
+
'<path': 1,
|
| 92 |
+
'<circle': 1
|
| 93 |
+
}
|
| 94 |
+
for elem, count in required_elements.items():
|
| 95 |
+
if code.count(elem) < count:
|
| 96 |
+
code = code.replace('</svg>',
|
| 97 |
+
f'<!-- Auto-added {elem} -->\n<{elem} />\n</svg>')
|
| 98 |
+
|
| 99 |
+
return code.strip()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def strict_svg_postprocess(raw_code: str) -> str:
|
| 103 |
+
"""
|
| 104 |
+
严格按照需求设计的SVG后处理器
|
| 105 |
+
|
| 106 |
+
处理逻辑:
|
| 107 |
+
1. 按行处理,找到第一个不以<svg开头的行作为内容起点
|
| 108 |
+
2. 逐行检查:去重(最多3次)、完整性、排除</svg>
|
| 109 |
+
3. 自动添加标准头尾
|
| 110 |
+
"""
|
| 111 |
+
# 预处理:清理前后空白,分割为行
|
| 112 |
+
lines = [line.strip() for line in raw_code.strip().split('\n')]
|
| 113 |
+
|
| 114 |
+
# 阶段1:找到有效内容起始行
|
| 115 |
+
start_index = 0
|
| 116 |
+
for i, line in enumerate(lines):
|
| 117 |
+
if not line.lower().startswith("<svg"):
|
| 118 |
+
start_index = i
|
| 119 |
+
break
|
| 120 |
+
|
| 121 |
+
# 阶段2:逐行处理有效内容
|
| 122 |
+
valid_lines = []
|
| 123 |
+
line_counter = {}
|
| 124 |
+
|
| 125 |
+
for line in lines[start_index:]:
|
| 126 |
+
# 排除</svg>标签
|
| 127 |
+
if re.match(r'</\s*svg\s*>', line, re.IGNORECASE):
|
| 128 |
+
continue
|
| 129 |
+
if re.match(r'<\s*svg', line, re.IGNORECASE):
|
| 130 |
+
continue
|
| 131 |
+
# 检查完整性(匹配XML标签语法)
|
| 132 |
+
is_valid_tag = re.fullmatch(
|
| 133 |
+
r'\s*<[^>]+/?>\s*',
|
| 134 |
+
line,
|
| 135 |
+
re.IGNORECASE
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# 检查是否已存在3次
|
| 139 |
+
count = line_counter.get(line, 0)
|
| 140 |
+
|
| 141 |
+
if is_valid_tag and count < 1:
|
| 142 |
+
valid_lines.append(line)
|
| 143 |
+
line_counter[line] = count + 1
|
| 144 |
+
|
| 145 |
+
# 阶段3:组装最终结果
|
| 146 |
+
core_content = '\n'.join(valid_lines)
|
| 147 |
+
|
| 148 |
+
return f'''<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
|
| 149 |
+
{core_content}
|
| 150 |
+
</svg>'''
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def load_label_names(json_path: str) -> dict:
|
| 154 |
+
"""加载标签映射表"""
|
| 155 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 156 |
+
data = json.load(f)
|
| 157 |
+
return data['dataset_info']['features'][0]['dtype']['class_label']['names']
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def main_infer( ):
|
| 161 |
+
# 初始化代码模型
|
| 162 |
+
# generator = Llama.build(
|
| 163 |
+
# ckpt_dir=ckpt_dir,
|
| 164 |
+
# tokenizer_path=tokenizer_path,
|
| 165 |
+
# max_seq_len=max_seq_len,
|
| 166 |
+
# max_batch_size=max_batch_size,
|
| 167 |
+
# )
|
| 168 |
+
# 加载基础模型(请根据具体模型名称或路径调整)
|
| 169 |
+
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct",token="")
|
| 170 |
+
# 加载 LoRA 模型,加载 LoRA 权重(此处使用“steve329/llama3-8B-edit-lora-12k”)
|
| 171 |
+
model = PeftModel.from_pretrained(base_model, "steve329/llama3-8B-edit-lora-12k")
|
| 172 |
+
# 设置评估模式
|
| 173 |
+
model.eval()
|
| 174 |
+
# 加载对应的分词器(确保与基础模型匹配)
|
| 175 |
+
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
| 176 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 177 |
+
model.to(device)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
root = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b"
|
| 181 |
+
for dir in os.listdir(root):
|
| 182 |
+
print(dir)
|
| 183 |
+
if (dir == "LICENSE-CODE") or (dir == "LICENSE-IMAGES") or (dir == "README.md") or (dir == "CaseGenerator.py") : continue
|
| 184 |
+
output_dir = os.path.jon(root+dir+'generated_svg') # 替换为实际的目标文件夹路径
|
| 185 |
+
|
| 186 |
+
# 确保目标文件夹存在,如果不存在则创建
|
| 187 |
+
if not os.path.exists(output_dir):
|
| 188 |
+
os.makedirs(output_dir)
|
| 189 |
+
|
| 190 |
+
file_dir = os.path.jon(root+dir+'query')
|
| 191 |
+
i=0
|
| 192 |
+
for file in os.listdir(file_dir):
|
| 193 |
+
|
| 194 |
+
file_name = os.path.splitext(file)[0]
|
| 195 |
+
file_path = os.path.join(file_dir, file)
|
| 196 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
| 197 |
+
content = file.read()
|
| 198 |
+
print(content)
|
| 199 |
+
if len(content) > 4383:
|
| 200 |
+
file_path = os.path.join(output_dir, file_name + '.svg')
|
| 201 |
+
|
| 202 |
+
# 将final_code写入到description.svg文件中
|
| 203 |
+
with open(file_path, 'w', encoding='utf-8') as svg_file:
|
| 204 |
+
svg_file.write("""<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>""")
|
| 205 |
+
|
| 206 |
+
print(f"SVG文件已保存至: {file_path}")
|
| 207 |
+
with open('/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean/'+dir+'/skipped_file.txt', 'w', encoding='utf-8') as f:
|
| 208 |
+
f.write(f"{file_name}"+"\n")
|
| 209 |
+
|
| 210 |
+
continue
|
| 211 |
+
|
| 212 |
+
# test_input = (
|
| 213 |
+
# '{"instruction": "You are an expert SVG graphics generator. You generate clean, valid SVG code according to user instructions.", '
|
| 214 |
+
# f'"input": {content}'
|
| 215 |
+
# )
|
| 216 |
+
|
| 217 |
+
inputs = tokenizer(content, return_tensors="pt")
|
| 218 |
+
input_ids = inputs.input_ids.to(device)
|
| 219 |
+
attention_mask = inputs.attention_mask.to(device)
|
| 220 |
+
# 使用模型生成文本(可以根据需要调整生成参数)
|
| 221 |
+
with torch.no_grad():
|
| 222 |
+
generated_ids = model.generate(
|
| 223 |
+
input_ids,
|
| 224 |
+
attention_mask=attention_mask,
|
| 225 |
+
max_length=4096, # 指定生成文本的最大长度
|
| 226 |
+
do_sample=True, # 是否使用采样,True 可生成更多样化结果
|
| 227 |
+
top_k=50, # Top-K 采样参数
|
| 228 |
+
top_p=0.95 # Top-p (nucleus) 采样参数
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# 解码生成的 token 成为文本
|
| 232 |
+
generated_text = tokenizer.decode(generated_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
|
| 233 |
+
# print("Prompt:", prompt)
|
| 234 |
+
# print("Generated text:", generated_text)
|
| 235 |
+
# print("-" * 80)
|
| 236 |
+
# results = generator.text_completion(
|
| 237 |
+
# prompts=[content],
|
| 238 |
+
# max_gen_len=max_gen_len,
|
| 239 |
+
# temperature=temperature,
|
| 240 |
+
# top_p=top_p,
|
| 241 |
+
# )
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# 后处理
|
| 248 |
+
# raw_code = results[0]['generation']aa
|
| 249 |
+
print("raw_code:")
|
| 250 |
+
print(generated_text)
|
| 251 |
+
final_code = extract_svg_from_text(generated_text)
|
| 252 |
+
|
| 253 |
+
# 输出结果
|
| 254 |
+
print(f"\n=== Input: {file_name} ===")
|
| 255 |
+
print(f"// Generated SVG Code:")
|
| 256 |
+
print(final_code)
|
| 257 |
+
print("\n" + "=" * 40 + "\n")
|
| 258 |
+
|
| 259 |
+
# 定义SVG文件的完整路径
|
| 260 |
+
file_path = os.path.join(output_dir, file_name + '.svg')
|
| 261 |
+
|
| 262 |
+
# 将final_code写入到description.svg文件中
|
| 263 |
+
with open(file_path, 'w', encoding='utf-8') as svg_file:
|
| 264 |
+
svg_file.write(final_code)
|
| 265 |
+
|
| 266 |
+
print(f"SVG文件已保存至: {file_path}")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
|
| 270 |
+
main_infer()
|