NightRaven109 committed
Commit 789b504 · verified · 1 Parent(s): 4cf6eef

Update app.py

Files changed (1)
  1. app.py +202 -8
app.py CHANGED
@@ -16,6 +16,7 @@ from gradio.utils import get_cache_folder
 from infer import lotus, lotus_video
 import transformers
 from huggingface_hub import login
+import cv2
 
 transformers.utils.move_cache()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -27,18 +28,206 @@ def apply_gaussian_blur(image, radius=1.0):
     """Apply Gaussian blur to PIL Image with specified radius"""
     return image.filter(ImageFilter.GaussianBlur(radius=radius))
 
+class NormalMapSimple:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "images": ("IMAGE",),
+                "scale_XY": ("FLOAT", {"default": 1, "min": 0, "max": 100, "step": 0.001}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "normal_map"
+
+    CATEGORY = "image/filters"
+
+    def normal_map(self, images, scale_XY):
+        t = images.detach().clone().cpu().numpy().astype(np.float32)
+        L = np.mean(t[:,:,:,:3], axis=3)
+        for i in range(t.shape[0]):
+            t[i,:,:,0] = cv2.Scharr(L[i], -1, 1, 0, cv2.BORDER_REFLECT) * -1
+            t[i,:,:,1] = cv2.Scharr(L[i], -1, 0, 1, cv2.BORDER_REFLECT)
+        t[:,:,:,2] = 1
+        t = torch.from_numpy(t)
+        t[:,:,:,:2] *= scale_XY
+        t[:,:,:,:3] = torch.nn.functional.normalize(t[:,:,:,:3], dim=3) / 2 + 0.5
+        return (t,)
+
+class ConvertNormals:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "normals": ("IMAGE",),
+                "input_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
+                "output_mode": (["BAE", "MiDaS", "Standard", "DirectX"],),
+                "scale_XY": ("FLOAT", {"default": 1, "min": 0, "max": 100, "step": 0.001}),
+                "normalize": ("BOOLEAN", {"default": True}),
+                "fix_black": ("BOOLEAN", {"default": True}),
+            },
+            "optional": {
+                "optional_fill": ("IMAGE",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "convert_normals"
+
+    CATEGORY = "image/filters"
+
+    def convert_normals(self, normals, input_mode, output_mode, scale_XY, normalize, fix_black, optional_fill=None):
+        try:
+            t = normals.detach().clone()
+
+            if input_mode == "BAE":
+                t[:,:,:,0] = 1 - t[:,:,:,0]  # invert R
+            elif input_mode == "MiDaS":
+                t[:,:,:,:3] = torch.stack([1 - t[:,:,:,2], t[:,:,:,1], t[:,:,:,0]], dim=3)  # BGR -> RGB and invert R
+            elif input_mode == "DirectX":
+                t[:,:,:,1] = 1 - t[:,:,:,1]  # invert G
+
+            if fix_black:
+                key = torch.clamp(1 - t[:,:,:,2] * 2, min=0, max=1)
+                if optional_fill is None:
+                    t[:,:,:,0] += key * 0.5
+                    t[:,:,:,1] += key * 0.5
+                    t[:,:,:,2] += key
+                else:
+                    fill = optional_fill.detach().clone()
+                    if fill.shape[1:3] != t.shape[1:3]:
+                        fill = torch.nn.functional.interpolate(fill.movedim(-1,1), size=(t.shape[1], t.shape[2]), mode='bilinear').movedim(1,-1)
+                    if fill.shape[0] != t.shape[0]:
+                        fill = fill[0].unsqueeze(0).expand(t.shape[0], -1, -1, -1)
+                    t[:,:,:,:3] += fill[:,:,:,:3] * key.unsqueeze(3).expand(-1, -1, -1, 3)
+
+            t[:,:,:,:2] = (t[:,:,:,:2] - 0.5) * scale_XY + 0.5
+
+            if normalize:
+                # Transform to [-1, 1] range
+                t_norm = t[:,:,:,:3] * 2 - 1
+
+                # Calculate the length of each vector
+                lengths = torch.sqrt(torch.sum(t_norm**2, dim=3, keepdim=True))
+
+                # Avoid division by zero
+                lengths = torch.clamp(lengths, min=1e-6)
+
+                # Normalize each vector to unit length
+                t_norm = t_norm / lengths
+
+                # Transform back to [0, 1] range
+                t[:,:,:,:3] = (t_norm + 1) / 2
+
+            if output_mode == "BAE":
+                t[:,:,:,0] = 1 - t[:,:,:,0]  # invert R
+            elif output_mode == "MiDaS":
+                t[:,:,:,:3] = torch.stack([t[:,:,:,2], t[:,:,:,1], 1 - t[:,:,:,0]], dim=3)  # invert R and BGR -> RGB
+            elif output_mode == "DirectX":
+                t[:,:,:,1] = 1 - t[:,:,:,1]  # invert G
+
+            return (t,)
+        except Exception as e:
+            print(f"Error in convert_normals: {str(e)}")
+            return (normals,)
+
+def get_image_intensity(img, gamma_correction=1.0):
+    """
+    Extract intensity map from an image using HSV color space
+    """
+    # Convert to HSV color space
+    result = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
+    # Extract Value channel (intensity)
+    result = result[:, :, 2].astype(np.float32) / 255.0
+    # Apply gamma correction
+    result = result ** gamma_correction
+    # Convert back to 0-255 range
+    result = (result * 255.0).clip(0, 255).astype(np.uint8)
+    # Convert to RGB (still grayscale but in RGB format)
+    result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
+    return result
+
+def blend_numpy_images(image1, image2, blend_factor=0.25, mode="normal"):
+    """
+    Blend two numpy images using normal mode
+    """
+    # Convert to float32 and normalize to 0-1
+    img1 = image1.astype(np.float32) / 255.0
+    img2 = image2.astype(np.float32) / 255.0
+
+    # Normal blend mode
+    blended = img1 * (1 - blend_factor) + img2 * blend_factor
+
+    # Convert back to uint8
+    blended = (blended * 255.0).clip(0, 255).astype(np.uint8)
+    return blended
+
+def process_normal_map(image):
+    """
+    Process image through NormalMapSimple and ConvertNormals
+    """
+    # Convert numpy image to torch tensor with batch dimension
+    image_tensor = torch.from_numpy(image).unsqueeze(0).float() / 255.0
+
+    # Create instances of the classes
+    normal_map_generator = NormalMapSimple()
+    normal_converter = ConvertNormals()
+
+    # Generate initial normal map
+    normal_map = normal_map_generator.normal_map(image_tensor, scale_XY=1.0)[0]
+
+    # Convert normal map from Standard to DirectX
+    converted_normal = normal_converter.convert_normals(
+        normal_map,
+        input_mode="Standard",
+        output_mode="DirectX",
+        scale_XY=1.0,
+        normalize=True,
+        fix_black=True
+    )[0]
+
+    # Convert back to numpy array
+    result = (converted_normal.squeeze(0).numpy() * 255).astype(np.uint8)
+    return result
+
 def infer(path_input, seed=None):
     name_base, name_ext = os.path.splitext(os.path.basename(path_input))
     _, output_d = lotus(path_input, 'depth', seed, device)
-
+
     # Apply Gaussian blur with 0.75 radius
     output_d = apply_gaussian_blur(output_d, radius=0.75)
-
+
+    # Convert depth to numpy for normal map processing
+    depth_array = np.array(output_d)
+
+    # Load original image for intensity blending
+    input_image = Image.open(path_input)
+    input_array = np.array(input_image)
+
+    # Get intensity map from original image
+    intensity_map = get_image_intensity(input_array, gamma_correction=1.0)
+
+    # Blend depth with intensity map
+    blended_result = blend_numpy_images(
+        cv2.cvtColor(depth_array, cv2.COLOR_RGB2BGR if len(depth_array.shape) == 3 else cv2.COLOR_GRAY2BGR),
+        intensity_map,
+        blend_factor=0.25,
+        mode="normal"
+    )
+
+    # Generate normal map from blended result
+    normal_map = process_normal_map(blended_result)
+
     if not os.path.exists("files/output"):
         os.makedirs("files/output")
     d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
+    n_save_path = os.path.join("files/output", f"{name_base}_n{name_ext}")
+
     output_d.save(d_save_path)
-    return [path_input, d_save_path]
+    Image.fromarray(normal_map).save(n_save_path)
+
+    return [path_input, d_save_path, n_save_path]
 
 def infer_video(path_input, seed=None):
     _, frames_d, fps = lotus_video(path_input, 'depth', seed, device)
@@ -70,7 +259,7 @@ def run_demo_server():
 
     with gr.Blocks(
         theme=gradio_theme,
-        title="LOTUS (Depth - Discriminative)",
+        title="LOTUS (Depth & Normal Maps - Discriminative)",
         css="""
         #download {
             height: 118px;
@@ -133,6 +322,11 @@ def run_demo_server():
                 elem_classes="slider",
                 position=0.25,
             )
+            image_output_n = gr.Image(
+                label="Normal Map Output",
+                type="filepath",
+                interactive=False,
+            )
 
             gr.Examples(
                 fn=infer_gpu,
@@ -141,7 +335,7 @@ def run_demo_server():
                     for name in os.listdir(os.path.join("files", "images"))
                 ]),
                 inputs=[image_input],
-                outputs=[image_output_d],
+                outputs=[image_output_d, image_output_n],
                 cache_examples=False,
             )
 
@@ -182,13 +376,13 @@ def run_demo_server():
         image_submit_btn.click(
             fn=infer_gpu,
             inputs=[image_input],
-            outputs=[image_output_d],
+            outputs=[image_output_d, image_output_n],
             concurrency_limit=1,
         )
         image_reset_btn.click(
-            fn=lambda: None,
+            fn=lambda: [None, None],
             inputs=[],
-            outputs=[image_output_d],
+            outputs=[image_output_d, image_output_n],
             queue=False,
        )
 
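For reference, a minimal sketch of driving the two new node classes directly, outside the Gradio UI. It assumes NormalMapSimple and ConvertNormals are in scope (e.g. the snippet is appended to app.py); "input.png" is a placeholder path, not a file shipped with this Space.

import numpy as np
import torch
from PIL import Image

# Any RGB image as a (1, H, W, 3) float tensor in [0, 1], the layout
# NormalMapSimple.normal_map expects.
img = np.array(Image.open("input.png").convert("RGB"), dtype=np.float32) / 255.0
batch = torch.from_numpy(img).unsqueeze(0)

# Scharr gradients of the luminance become the X/Y components of the normals.
normal = NormalMapSimple().normal_map(batch, scale_XY=1.0)[0]

# Re-encode from the Standard convention to DirectX (green channel inverted),
# the same conversion process_normal_map performs.
directx = ConvertNormals().convert_normals(
    normal,
    input_mode="Standard",
    output_mode="DirectX",
    scale_XY=1.0,
    normalize=True,
    fix_black=True,
)[0]

out = (directx.squeeze(0).numpy() * 255).clip(0, 255).astype(np.uint8)
Image.fromarray(out).save("input_normal_dx.png")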
 
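The new preprocessing inside infer() can likewise be replayed step by step. A sketch under the same assumption that the helpers above are in scope; "photo.png" and "photo_d.png" are placeholder inputs (the upload and a saved depth map) and must share one resolution, since blend_numpy_images only blends arrays of identical shape.

import numpy as np
import cv2
from PIL import Image

photo = np.array(Image.open("photo.png").convert("RGB"))
depth = np.array(Image.open("photo_d.png").convert("RGB"))

# Intensity map: the HSV value channel, gamma-corrected, replicated to RGB.
intensity = get_image_intensity(photo, gamma_correction=1.0)

# Linear 75/25 blend, per pixel: blended = 0.75 * depth + 0.25 * intensity.
# infer() swaps the depth map to BGR first, mirrored here.
blended = blend_numpy_images(
    cv2.cvtColor(depth, cv2.COLOR_RGB2BGR),
    intensity,
    blend_factor=0.25,
    mode="normal",
)

# Scharr-based normal map in the DirectX convention, saved with the _n suffix.
normal = process_normal_map(blended)
Image.fromarray(normal).save("photo_n.png")

With this commit, infer() returns three paths (input image, blurred depth map with the _d suffix, and the derived normal map with the _n suffix), which is why the Examples, submit, and reset callbacks now all list [image_output_d, image_output_n] as outputs.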