Daniellesry commited on
Commit
3ba6a3b
·
1 Parent(s): 45ed92a
Files changed (1) hide show
  1. dkt/pipelines/pipeline.py +13 -3
dkt/pipelines/pipeline.py CHANGED
@@ -983,7 +983,17 @@ class DKTPipeline:
983
 
984
 
985
  # 5. 放到 GPU
986
- input_image = input_image.to(device=device, dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
987
 
988
 
989
  model_device = next(self.moge_pipe.parameters()).device
@@ -1013,9 +1023,9 @@ class DKTPipeline:
1013
  if iidx == 0:
1014
  # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
1015
  if input_image_np.max() > 1:
1016
- input_image = torch.tensor(input_image_np / 255, dtype=torch.bfloat16, device=moge_device).permute(2, 0, 1)
1017
  else:
1018
- input_image = torch.tensor(input_image_np, dtype=torch.bfloat16, device=moge_device).permute(2, 0, 1)
1019
 
1020
  print(f'moge devices: {moge_device}') #* why cpu?
1021
 
 
983
 
984
 
985
  # 5. 放到 GPU
986
+ input_image = input_image.to(device=device, dtype=torch.float32)
987
+
988
+
989
+
990
+
991
+ # 🔴 必须补 batch 维度
992
+ if input_image.dim() == 3:
993
+ input_image = input_image.unsqueeze(0) # (1, 3, H, W)
994
+
995
+
996
+
997
 
998
 
999
  model_device = next(self.moge_pipe.parameters()).device
 
1023
  if iidx == 0:
1024
  # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
1025
  if input_image_np.max() > 1:
1026
+ input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
1027
  else:
1028
+ input_image = torch.tensor(input_image_np, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
1029
 
1030
  print(f'moge devices: {moge_device}') #* why cpu?
1031