hwding committed
Commit b9bae32 · verified · 1 Parent(s): 793a0f4

Upload handler.py with huggingface_hub

Files changed (1)
  1. handler.py +83 -21
handler.py CHANGED
@@ -1,12 +1,17 @@
-from typing import Dict, Any
+from typing import Dict, Any, List
 import torch
+import time
+import uuid
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel
 
+DEFAULT_SYSTEM_PROMPT = "You are an expert Minecraft Forge mod developer for version 1.21.11. Write clean, efficient, and well-structured Java code."
+
 
 class EndpointHandler:
     def __init__(self, path: str = ""):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model_id = "hwding/forge-coder-v1.21.11"
 
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -27,36 +32,38 @@ class EndpointHandler:
         self.model = PeftModel.from_pretrained(self.model, path)
         self.model.eval()
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        inputs = data.get("inputs", "")
-        parameters = data.get("parameters", {})
+    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
+        prompt_parts = []
+        has_system = False
 
-        max_new_tokens = parameters.get("max_new_tokens", 512)
-        temperature = parameters.get("temperature", 0.7)
-        top_p = parameters.get("top_p", 0.95)
-        do_sample = parameters.get("do_sample", True)
+        for msg in messages:
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+
+            if role == "system":
+                prompt_parts.append(f"### System:\n{content}")
+                has_system = True
+            elif role == "user":
+                prompt_parts.append(f"### User:\n{content}")
+            elif role == "assistant":
+                prompt_parts.append(f"### Assistant:\n{content}")
 
-        if not inputs.startswith("### System:"):
-            prompt = f"""### System:
-You are an expert Minecraft Forge mod developer for version 1.21.11. Write clean, efficient, and well-structured Java code.
-
-### User:
-{inputs}
-
-### Assistant:
-"""
-        else:
-            prompt = inputs
+        if not has_system:
+            prompt_parts.insert(0, f"### System:\n{DEFAULT_SYSTEM_PROMPT}")
 
+        prompt_parts.append("### Assistant:\n")
+        return "\n\n".join(prompt_parts)
+
+    def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
 
         with torch.no_grad():
             outputs = self.model.generate(
                 **input_ids,
                 max_new_tokens=max_new_tokens,
-                temperature=temperature,
+                temperature=temperature if temperature > 0 else 1.0,
                 top_p=top_p,
-                do_sample=do_sample,
+                do_sample=temperature > 0,
                 pad_token_id=self.tokenizer.eos_token_id,
             )
 
@@ -65,4 +72,59 @@ You are an expert Minecraft Forge mod developer for version 1.21.11. Write clean
         if "### Assistant:" in generated_text:
             generated_text = generated_text.split("### Assistant:")[-1].strip()
 
+        return generated_text
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        messages = data.get("messages")
+        if messages:
+            return self._handle_openai_format(data)
+        return self._handle_simple_format(data)
+
+    def _handle_openai_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        messages = data.get("messages", [])
+        max_tokens = data.get("max_tokens", 512)
+        temperature = data.get("temperature", 0.7)
+        top_p = data.get("top_p", 0.95)
+
+        prompt = self._format_messages(messages)
+        generated_text = self._generate(prompt, max_tokens, temperature, top_p)
+
+        prompt_tokens = len(self.tokenizer.encode(prompt))
+        completion_tokens = len(self.tokenizer.encode(generated_text))
+
+        return {
+            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
+            "object": "chat.completion",
+            "created": int(time.time()),
+            "model": self.model_id,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": generated_text,
+                },
+                "finish_reason": "stop",
+            }],
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            }
+        }
+
+    def _handle_simple_format(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        inputs = data.get("inputs", "")
+        parameters = data.get("parameters", {})
+
+        max_new_tokens = parameters.get("max_new_tokens", 512)
+        temperature = parameters.get("temperature", 0.7)
+        top_p = parameters.get("top_p", 0.95)
+
+        if not inputs.startswith("### System:"):
+            prompt = f"### System:\n{DEFAULT_SYSTEM_PROMPT}\n\n### User:\n{inputs}\n\n### Assistant:\n"
+        else:
+            prompt = inputs
+
+        generated_text = self._generate(prompt, max_new_tokens, temperature, top_p)
+
+        return {"generated_text": generated_text}
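
For reference, the two request shapes that the new __call__ dispatches on can be exercised like this. This is a minimal local sketch, not part of the commit: the payload fields mirror the handler code above, and the adapter path "./forge-coder" is a placeholder.

# Local smoke test for both payload shapes accepted by the handler above.
# Assumption: the adapter has been downloaded to ./forge-coder (placeholder).
from handler import EndpointHandler

handler = EndpointHandler(path="./forge-coder")

# 1. Simple format: no "messages" key, so __call__ routes to
#    _handle_simple_format and returns {"generated_text": ...}.
simple_payload = {
    "inputs": "Write a Forge item class for a ruby gem.",
    "parameters": {"max_new_tokens": 256, "temperature": 0.2, "top_p": 0.9},
}
print(handler(simple_payload)["generated_text"])

# 2. OpenAI-style format: the "messages" key routes to
#    _handle_openai_format, which returns a chat.completion object.
chat_payload = {
    "messages": [
        {"role": "system", "content": "Answer with Java code only."},
        {"role": "user", "content": "Register a creative mode tab."},
    ],
    "max_tokens": 256,
    "temperature": 0.2,
}
response = handler(chat_payload)
print(response["choices"][0]["message"]["content"])
print(response["usage"])  # prompt_tokens / completion_tokens / total_tokens

Note that setting temperature to 0 in either payload switches _generate to greedy decoding, since the new code derives do_sample from temperature > 0.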
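
Since both entry points meet at the same "### System:" / "### User:" / "### Assistant:" template, the prompt that _format_messages builds for the chat_payload above is worth spelling out (derived directly from the formatting code in the diff):

# Prompt produced by _format_messages for chat_payload: blocks joined by
# blank lines, ending with an open "### Assistant:" header for the model
# to complete. Because a system message is present, DEFAULT_SYSTEM_PROMPT
# is not inserted.
expected_prompt = (
    "### System:\nAnswer with Java code only."
    "\n\n"
    "### User:\nRegister a creative mode tab."
    "\n\n"
    "### Assistant:\n"
)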