Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
3ba6a3b
1
Parent(s):
45ed92a
- dkt/pipelines/pipeline.py +13 -3
dkt/pipelines/pipeline.py
CHANGED
|
@@ -983,7 +983,17 @@ class DKTPipeline:
|
|
| 983 |
|
| 984 |
|
| 985 |
# 5. 放到 GPU
|
| 986 |
-
input_image = input_image.to(device=device, dtype=torch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 987 |
|
| 988 |
|
| 989 |
model_device = next(self.moge_pipe.parameters()).device
|
|
@@ -1013,9 +1023,9 @@ class DKTPipeline:
|
|
| 1013 |
if iidx == 0:
|
| 1014 |
# Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
|
| 1015 |
if input_image_np.max() > 1:
|
| 1016 |
-
input_image = torch.tensor(input_image_np / 255, dtype=torch.
|
| 1017 |
else:
|
| 1018 |
-
input_image = torch.tensor(input_image_np, dtype=torch.
|
| 1019 |
|
| 1020 |
print(f'moge devices: {moge_device}') #* why cpu?
|
| 1021 |
|
|
|
|
| 983 |
|
| 984 |
|
| 985 |
# 5. 放到 GPU
|
| 986 |
+
input_image = input_image.to(device=device, dtype=torch.float32)
|
| 987 |
+
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
# 🔴 必须补 batch 维度
|
| 992 |
+
if input_image.dim() == 3:
|
| 993 |
+
input_image = input_image.unsqueeze(0) # (1, 3, H, W)
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
|
| 997 |
|
| 998 |
|
| 999 |
model_device = next(self.moge_pipe.parameters()).device
|
|
|
|
| 1023 |
if iidx == 0:
|
| 1024 |
# Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
|
| 1025 |
if input_image_np.max() > 1:
|
| 1026 |
+
input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
|
| 1027 |
else:
|
| 1028 |
+
input_image = torch.tensor(input_image_np, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
|
| 1029 |
|
| 1030 |
print(f'moge devices: {moge_device}') #* why cpu?
|
| 1031 |
|