| | from io import BytesIO |
| | from urllib.request import urlopen |
| | import soundfile |
| | import torch |
| | from datasets import load_dataset, Audio |
| | import numpy as np |
| | from transformers import AutoModel, AutoProcessor, BatchFeature,Gemma3ForCausalLM,Gemma3Processor |
| | from tqdm import tqdm |
| | import json |
| | import os |
| | import time |
| | from datetime import datetime |
| | from whisper_normalizer.english import EnglishTextNormalizer |
| | from whisper_normalizer.basic import BasicTextNormalizer |
| | import sacrebleu |
| | from jiwer import cer, wer |
| | from torch.utils.data import Dataset, DataLoader |
| | import soundfile as sf |
| | import re |
| | from pathlib import Path |
| | import opencc |
| | from ASRDataset import * |
| |
|
| | |
| |
|
# Processor loaded from the local checkpoint directory; trust_remote_code
# permits executing any custom processor code found in that directory.
model_id, revision = "./", "main"
processor = AutoProcessor.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)

# Timestamped directory name for this run.
# NOTE(review): results_dir is not used by the evaluation code in this script
# section (outputs are written to './'); confirm whether it is still needed.
results_dir = "evaluation_results_" + datetime.now().strftime('%Y%m%d_%H%M%S')
| | |
| |
|
def _strip_signs(text):
    """Normalise a tool-call string for exact-match comparison.

    Removes the ``'User transcribe is'`` / ``'GPT output is'`` prefixes,
    newlines, spaces and sentence punctuation. Both ASCII and full-width
    question/exclamation marks are stripped, matching the full-width ``'。'``.
    """
    # The prefixes contain spaces, so they must be removed before ' ' is.
    for token in ('User transcribe is', 'GPT output is', '\n', ' ',
                  '?', '？', '!', '！', '。', '.'):
        text = text.replace(token, '')
    return text


def _dump_outputs(save_path, all_output_text):
    """Write per-sample records to *save_path* as UTF-8 JSON; no-op if the path is empty."""
    if not save_path:
        return
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(all_output_text, f, ensure_ascii=False, indent=4)


def _print_summary(all_output_text, func_error, total_func_call):
    """Print sample count, non-zero-CER count/rate, mean CER and tool-call error rate.

    Rates are reported as 0.0 when their denominator is zero instead of
    raising ZeroDivisionError.
    """
    n = len(all_output_text)
    total_error = sum(a['cer'] != 0 for a in all_output_text)
    avg_cer = sum(a['cer'] for a in all_output_text) / n if n else 0.0
    print('total', n)
    print('total_error & rate', total_error, total_error / n if n else 0.0)
    print('avg_cer', avg_cer)
    print('total_func_call', total_func_call)
    print('func_error & rate', func_error, ',',
          func_error / total_func_call if total_func_call else 0.0)


def _score_batch(model, batch, with_input_mode, device):
    """Generate for one collated batch and return one result record per sample.

    Each record is a dict with keys ``input``, ``label``, ``output`` and ``cer``.
    Any exception from moving, generating, decoding or scoring propagates; the
    caller decides whether to skip the batch.
    """
    batch = {k: v.to(device) for k, v in batch.items() if v is not None}
    if not with_input_mode:
        batch.pop('input_modes', None)
    # labels are only used to decode the references below; keep them out of
    # generate(), which may forward extra kwargs to the model's forward().
    labels = batch.pop('labels')
    prompt_len = batch['input_ids'].shape[1]
    with torch.inference_mode():
        # temperature=0.001 with sampling is near-greedy decoding.
        generate_ids = model.generate(
            **batch,
            max_new_tokens=256,
            temperature=0.001, top_p=0.95, top_k=64, do_sample=True,
        )

    def decode(ids):
        # Uses the module-level `processor` defined earlier in this script.
        return processor.batch_decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    inputs = decode(generate_ids[:, :prompt_len])
    predictions = decode(generate_ids[:, prompt_len:])
    # NOTE(review): assumes collate_fn pads labels with real token ids (no -100);
    # otherwise batch_decode will fail — confirm against covost_collate_fn.
    references = decode(labels)

    records = []
    for inp, label, output in zip(inputs, references, predictions):
        # CER in percent on whitespace-free strings, capped at 100.
        cer_o = min(100, round(
            cer(re.sub(r"\s+", "", label), re.sub(r"\s+", "", output)) * 100, 2))
        records.append({'input': inp, 'label': label, 'output': output, 'cer': cer_o})
    return records


def eval_text(model, dataloader, with_input_mode=False, save_path="", start_idx=0,
              device="cuda", save_every=100):
    """Generate for every batch in *dataloader* and score predictions against labels.

    Every sample gets a character error rate (percent, whitespace removed,
    capped at 100). References that contain ``'Action:'`` are treated as tool
    calls and are also scored by exact match after ``_strip_signs``.

    Args:
        model: model exposing a Hugging Face style ``generate``.
        dataloader: yields dicts of tensors (``None`` values are dropped).
            They must contain ``input_ids`` and ``labels``.
        with_input_mode: if False, ``input_modes`` is removed before generation.
        save_path: JSON file written every *save_every* batches and at the
            end; an empty string disables saving.
        start_idx: first batch index to evaluate. Earlier batches are skipped,
            which allows resuming.
        device: device each batch is moved to.
        save_every: checkpoint/summary interval in batches.

    Returns:
        ``(res, all_output_text)``. ``res`` holds parallel per-sample
        ``'label'``/``'pred'``/``'cer'`` lists; ``all_output_text`` holds one
        dict per sample.
    """
    res = {'label': [], 'pred': [], 'cer': []}
    all_output_text = []
    func_error = 0
    total_func_call = 0
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        if batch_idx < start_idx:
            continue
        try:
            records = _score_batch(model, batch, with_input_mode, device)
        except Exception as e:
            # Best-effort evaluation: report the failure and skip this batch.
            # No partial records are kept because nothing is appended until
            # the whole batch succeeds.
            print("error at ", batch_idx, repr(e))
            time.sleep(2)
            continue
        for rec in records:
            res['label'].append(rec['label'])
            res['pred'].append(rec['output'])
            res['cer'].append(rec['cer'])
            all_output_text.append(rec)
            if 'Action:' in rec['label']:
                func_error += _strip_signs(rec['label']) != _strip_signs(rec['output'])
                total_func_call += 1
        if save_every and batch_idx % save_every == 0:
            _dump_outputs(save_path, all_output_text)
            _print_summary(all_output_text, func_error, total_func_call)
    _print_summary(all_output_text, func_error, total_func_call)
    _dump_outputs(save_path, all_output_text)
    return res, all_output_text
| |
|
| |
|
# Text-only evaluation datasets for the navigation and control tool-call tasks.
# NOTE(review): split='eval' is read from *_train.json files; presumably
# MultiturnAudioDataset carves a held-out split from them — confirm.
_eval_sets = {
    name: MultiturnAudioDataset(
        split='eval',
        text_only=True,
        processor=processor,
        json_path=path,
    )
    for name, path in (
        ('nav', '/mnt/data-2t/jeff/codes/LLaMA-Factory/data/nav_toolcall_train.json'),
        ('ctrl', '/mnt/data-2t/jeff/codes/LLaMA-Factory/data/ctrl_toolcall_train.json'),
    )
}
nav_data = _eval_sets['nav']
ctrl_data = _eval_sets['ctrl']

# One sample per batch, in file order.
ctrl_dataloader, nav_dataloader = (
    DataLoader(ds, batch_size=1, shuffle=False, collate_fn=covost_collate_fn)
    for ds in (ctrl_data, nav_data)
)
| |
|
| |
|
| |
|
| | from transformers import AutoProcessor, Gemma3ForConditionalGeneration |
| | from PIL import Image |
| | import requests |
| | import torch |
| |
|
# Base checkpoint on the Hugging Face Hub (presumably the model the LoRA
# adapter below was fine-tuned from).
model_id_org = "google/gemma-3-4b-it"

# device_map="auto" delegates weight placement; eager attention implementation.
model_org = Gemma3ForConditionalGeneration.from_pretrained(
    model_id_org,
    attn_implementation="eager",
    device_map="auto",
)
model_org.eval()

from peft import PeftModel

# Wrap the base model with the fine-tuned LoRA adapter checkpoint.
_lora_checkpoint = '/mnt/data-2t/jeff/codes/LLaMA-Factory/saves/Gemma-3-4B-Instruct/lora/train_123/checkpoint-3270'
model_org = PeftModel.from_pretrained(model_org, _lora_checkpoint)
| |
|
| |
|
| |
|
def _stamped(template):
    """Fill *template*'s single ``{}`` with the current time truncated to minutes."""
    return template.format(str(datetime.now())[:16])


# Evaluate the adapted model on each task. Each output file name is stamped
# when that task's evaluation starts.
res_org_nav, output_org_nav = eval_text(
    model_org, nav_dataloader, save_path=_stamped('./output_org_nav_{}.json'))
res_org_ctrl, output_org_ctrl = eval_text(
    model_org, ctrl_dataloader, save_path=_stamped('./output_org_ctrl_{}.json'))
| |
|
| |
|