yehaochen committed on
Commit
bbe1e5b
·
1 Parent(s): 262600b

[test] update time series test scripts

Browse files
Files changed (2) hide show
  1. 0092638_seism.npy +3 -0
  2. test_inference_ts.py +78 -0
0092638_seism.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2b94653c6964b630038897a27cb6d276ff866d9ecd1f6419358b9407f0df62e
3
+ size 72128
test_inference_ts.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor
4
+
5
+
6
+ model_path = Path(__file__).parent.resolve()
7
+ print(f"Loading model from: {model_path}")
8
+
9
+ # Load the model configuration
10
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
11
+ print(f"Model config: {config.model_type}")
12
+ print(f"Architecture: {config.architectures}")
13
+
14
+ # Load the processor (tokenizer + image processor + time-series processor)
15
+ print("\nLoading processor...")
16
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
17
+
18
+ # Load the model (bfloat16 precision with automatic device mapping)
19
+ print("\nLoading model...")
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ model_path,
22
+ dtype=torch.bfloat16,
23
+ device_map="auto",
24
+ # attn_implementation="flash_attention_2",  # time-series input does not support flash_attn yet; adding this line makes loading fail
25
+ trust_remote_code=True
26
+ )
27
+
28
+ print(f"✓ Model loaded successfully!")
29
+ print(f"Model type: {type(model).__name__}")
30
+ print(f"Model device: {model.device}")
31
+
32
+ # ============================================================================
33
+ # Test 3: time-series conversation
34
+ # ============================================================================
35
+ print("\n" + "=" * 80)
36
+ print("测试 3: 时序对话")
37
+ print("=" * 80)
38
+
39
+ messages = [
40
+ {
41
+ "role": "user",
42
+ "content": [
43
+ {"type": "time_series", "data": "./0092638_seism.npy", "sampling_rate": 100},
44
+ {"type": "text", "text": "Please determine whether an Earthquake event has occurred in the provided time-series data. If so, please specify the starting time point indices of the P-wave and S-wave in the event."},
45
+ ],
46
+ }
47
+ ]
48
+
49
+ time_series_inputs = processor.time_series_preprocessor(messages)
50
+ multimodal_inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", enable_thinking=False, **time_series_inputs).to(model.device, dtype=torch.bfloat16)
51
+
52
+ print("\n生成时序回复...")
53
+ with torch.inference_mode():
54
+ multimodal_generated_ids = model.generate(
55
+ **multimodal_inputs,
56
+ max_new_tokens=200,
57
+ do_sample=False,
58
+ temperature=1.0,
59
+ )
60
+
61
+ # Keep only the newly generated tokens (strip the prompt portion)
62
+ multimodal_generated_ids_trimmed = [
63
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(multimodal_inputs.input_ids, multimodal_generated_ids)
64
+ ]
65
+
66
+ # Decode the tokens to text
67
+ multimodal_output = processor.batch_decode(
68
+ multimodal_generated_ids_trimmed,
69
+ skip_special_tokens=True,
70
+ clean_up_tokenization_spaces=False
71
+ )
72
+
73
+ print("\n" + "-" * 80)
74
+ print("时序输出:")
75
+ print("-" * 80)
76
+ print(multimodal_output[0])
77
+ print("-" * 80)
78
+ print("\n✅ 时序功能测试完成!")