| import base64 |
| import logging |
| import os |
| import random |
| import sys |
|
|
| import comfy.model_management |
| import folder_paths |
| import numpy as np |
| import torch |
| import trimesh |
| from PIL import Image |
| from trimesh.exchange import gltf |
|
|
| sys.path.append(os.path.dirname(__file__)) |
| from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE |
| from spar3d.system import SPAR3D |
| from spar3d.utils import foreground_crop |
|
|
# Menu category shared by every node class in this module (assigned to each
# class's CATEGORY attribute).
SPAR3D_CATEGORY = "SPAR3D"
# Model identifier handed to SPAR3D.from_pretrained() by SPAR3DLoader.
# NOTE(review): looks like a Hugging Face repo id -- confirm against
# SPAR3D.from_pretrained.
SPAR3D_MODEL_NAME = "stabilityai/spar3d"
|
|
|
|
class SPAR3DLoader:
    """Node that loads the pretrained SPAR3D model and hands it downstream."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("spar3d_model",)
    RETURN_TYPES = ("SPAR3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a single required ``low_vram_mode`` toggle."""
        low_vram_spec = ("BOOLEAN", {"default": False})
        return {"required": {"low_vram_mode": low_vram_spec}}

    def load(self, low_vram_mode=False):
        """Build the SPAR3D model and prepare it for inference.

        Args:
            low_vram_mode: Forwarded unchanged to ``SPAR3D.from_pretrained``.

        Returns:
            A one-element tuple holding the model, after it has been moved to
            the device reported by ``comfy.model_management`` and switched to
            eval mode.
        """
        target_device = comfy.model_management.get_torch_device()
        pretrained_kwargs = {
            "config_name": "config.yaml",
            "weight_name": "model.safetensors",
            "low_vram_mode": low_vram_mode,
        }
        spar3d = SPAR3D.from_pretrained(SPAR3D_MODEL_NAME, **pretrained_kwargs)
        spar3d.to(target_device)
        spar3d.eval()
        return (spar3d,)
|
|
|
|
class SPAR3DPreview:
    """Output node that sends meshes to the UI as base64-encoded GLB blobs."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: one required ``mesh`` of type ``MESH``."""
        required_inputs = {"mesh": ("MESH",)}
        return {"required": required_inputs}

    def preview(self, mesh):
        """Encode every mesh as GLB and return the results for the UI.

        Args:
            mesh: Iterable of mesh objects; each one is wrapped in its own
                ``trimesh.Scene`` before export.

        Returns:
            ``{"ui": {"glbs": [...]}}`` where each list entry is the UTF-8
            decoded base64 text of one exported GLB, in input order.
        """

        def encode_as_glb(single_mesh):
            # Normals are included in the export, as requested via
            # include_normals=True.
            glb_bytes = gltf.export_glb(
                trimesh.Scene(single_mesh), include_normals=True
            )
            return base64.b64encode(glb_bytes).decode("utf-8")

        encoded = [encode_as_glb(single_mesh) for single_mesh in mesh]
        return {"ui": {"glbs": encoded}}
|
|
|
|
class SPAR3DSampler:
    """Node that runs SPAR3D on a single image and returns a mesh and point cloud."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh", "pointcloud")
    RETURN_TYPES = ("MESH", "POINTCLOUD")

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs.

        ``model``, ``image``, ``foreground_ratio`` and ``texture_resolution``
        are required. Everything else is optional; the ``remesh`` choice is
        only offered when at least one remeshing backend was importable.
        """
        remesh_choices = ["none"]
        if TRIANGLE_REMESH_AVAILABLE:
            remesh_choices.append("triangle")
        if QUAD_REMESH_AVAILABLE:
            remesh_choices.append("quad")

        opt_dict = {
            "mask": ("MASK",),
            "pointcloud": ("POINTCLOUD",),
            "target_type": (["none", "vertex", "face"],),
            "target_count": (
                "INT",
                {"default": 1000, "min": 3, "max": 20000, "step": 1},
            ),
            "guidance_scale": (
                "FLOAT",
                {"default": 3.0, "min": 1.0, "max": 5.0, "step": 0.05},
            ),
            "seed": (
                "INT",
                {"default": 42, "min": 0, "max": 2**32 - 1, "step": 1},
            ),
        }
        if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
            opt_dict["remesh"] = (remesh_choices,)

        return {
            "required": {
                "model": ("SPAR3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 1.3, "min": 1.0, "max": 2.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": opt_dict,
        }

    def predict(
        s,
        model,
        image,
        mask=None,
        foreground_ratio=1.3,
        texture_resolution=1024,
        pointcloud=None,
        remesh="none",
        target_type="none",
        target_count=1000,
        guidance_scale=3.0,
        seed=42,
    ):
        """Reconstruct a mesh (and point cloud) from one image.

        ``mask`` is declared as an *optional* input in ``INPUT_TYPES``, so it
        must have a default here; otherwise a call without a mask fails with
        ``TypeError``. ``foreground_ratio`` and ``texture_resolution`` gained
        defaults (the same values ``INPUT_TYPES`` advertises) only because
        they follow ``mask`` in the signature.

        Args:
            model: Loaded SPAR3D model; its ``cfg.guidance_scale`` is
                overwritten and ``run_image`` is called on it.
            image: Image batch indexed as ``image[0]`` with the channel count
                read from ``image.shape[3]``; the first axis must have length
                1. Values are scaled by 255 and clamped to ``[0, 255]``
                (presumably floats in ``[0, 1]`` -- confirm against caller).
            mask: Optional mask batch; ``mask[0]`` becomes the alpha channel.
                When omitted, an image without a 4th channel is converted to
                RGBA.
            foreground_ratio: Forwarded to ``foreground_crop``.
            texture_resolution: Forwarded to ``run_image`` as
                ``bake_resolution``.
            pointcloud: Optional point cloud forwarded to ``run_image``.
            remesh: ``"none"``, ``"triangle"`` or ``"quad"``.
            target_type: ``"none"`` disables the vertex budget (``-1``);
                ``"vertex"`` uses ``target_count`` directly; ``"face"`` uses
                ``target_count // 2``.
            target_count: Requested vertex or face count (see ``target_type``).
            guidance_scale: Written to ``model.cfg.guidance_scale``.
            seed: Seed applied to ``random``, ``torch`` and ``numpy``.

        Returns:
            ``([mesh], pointcloud)`` where ``pointcloud`` is the model's
            ``"pointcloud"`` tensor flattened into a 1-D Python list.

        Raises:
            ValueError: If more than one image is supplied, or the resulting
                mesh has no vertices.
            ImportError: If the requested remesh backend is unavailable.
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # Validate the remesh request up front so a failing call neither
        # mutates model.cfg nor reseeds the global RNGs.
        if remesh == "triangle" and not TRIANGLE_REMESH_AVAILABLE:
            raise ImportError(
                "Triangle remeshing requires gpytoolbox to be installed"
            )
        if remesh == "quad" and not QUAD_REMESH_AVAILABLE:
            raise ImportError("Quad remeshing requires pynim to be installed")

        if target_type == "none":
            vertex_count = -1
        elif target_type == "face":
            vertex_count = target_count // 2
        else:
            vertex_count = target_count

        pil_image = Image.fromarray(
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )

        if mask is not None:
            logging.info("Using Mask")
            mask_np = np.clip(
                255.0 * mask[0].detach().cpu().numpy(), 0, 255
            ).astype(np.uint8)
            mask_pil = Image.fromarray(mask_np, mode="L")
            pil_image.putalpha(mask_pil)
        elif image.shape[3] != 4:
            logging.info("No mask or alpha channel detected, Converting to RGBA")
            pil_image = pil_image.convert("RGBA")

        pil_image = foreground_crop(pil_image, foreground_ratio)

        model.cfg.guidance_scale = guidance_scale
        # Seed every RNG source used here so a given seed is repeatable.
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

        logging.info("Remesh mode: %s", remesh)
        # NOTE(review): autocast is hard-coded to "cuda" while the loader picks
        # its device via comfy.model_management -- confirm non-CUDA behaviour.
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.float16
        ):
            mesh, glob_dict = model.run_image(
                pil_image,
                bake_resolution=texture_resolution,
                pointcloud=pointcloud,
                remesh=remesh,
                vertex_count=vertex_count,
            )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return (
            [mesh],
            glob_dict["pointcloud"].view(-1).detach().cpu().numpy().tolist(),
        )
|
|
|
|
class SPAR3DSave:
    """Output node that writes meshes to disk as GLB and echoes them to the UI."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: a ``mesh`` and a ``filename_prefix`` string."""
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def __init__(self):
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export each mesh to ``<prefix>_<counter>_.glb`` in the output folder.

        The save location is resolved once and the counter is incremented per
        mesh, instead of rescanning the output directory (and overwriting the
        ``filename_prefix`` argument) on every loop iteration.

        Args:
            mesh: Iterable of mesh objects; each is wrapped in its own
                ``trimesh.Scene`` for export.
            filename_prefix: Prefix passed to
                ``folder_paths.get_save_image_path``. A literal
                ``%batch_num%`` in the resolved file name is replaced by the
                mesh index.

        Returns:
            ``{"ui": {"glbs": [...]}}`` with the base64 text of every GLB
            written, in input order.
        """
        output_dir = folder_paths.get_output_directory()
        # Throwaway names for the values we do not use, so the caller's
        # filename_prefix is never clobbered.
        full_output_folder, filename, counter, _subfolder, _resolved_prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )

        glbs = []
        for idx, m in enumerate(mesh):
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            logging.info("Generated GLB model with %d bytes", len(glb_data))

            batch_filename = filename.replace("%batch_num%", str(idx))
            # The trailing "_" after the counter is kept so later counter
            # scans by get_save_image_path can parse these file names.
            out_path = os.path.join(
                full_output_folder, f"{batch_filename}_{counter:05}_.glb"
            )
            with open(out_path, "wb") as f:
                f.write(glb_data)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
            counter += 1
        return {"ui": {"glbs": glbs}}
|
|
|
|
class SPAR3DPointCloudLoader:
    """Node that reads a geometry file and emits it as a flat point-cloud list."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "load_pointcloud"
    RETURN_TYPES = ("POINTCLOUD",)
    RETURN_NAMES = ("pointcloud",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a single ``file`` path string."""
        return {
            "required": {
                "file": ("STRING", {"default": None}),
            }
        }

    def load_pointcloud(self, file):
        """Load ``file`` with trimesh and flatten it to ``[x, y, z, r, g, b, ...]``.

        Args:
            file: Path handed to ``trimesh.load``. ``None`` or an empty string
                means "no point cloud".

        Returns:
            A one-element tuple. It holds ``None`` when no file was given,
            otherwise a flat list of Python floats with six values per point:
            the vertex position followed by its RGB colour divided by 255.
            Points default to white when the loaded object reports no vertex
            colours.
        """
        if file is None or file == "":
            return (None,)

        mesh = trimesh.load(file)
        vertices = np.asarray(mesh.vertices, dtype=np.float64)

        vertex_colors = mesh.visual.vertex_colors
        if vertex_colors is not None:
            # Drop any alpha column and rescale to the 0-1 range.
            colors = np.asarray(vertex_colors)[:, :3] / 255.0
        else:
            colors = np.ones((len(vertices), 3))

        # Interleave position and colour per point in one vectorised step.
        # Truncating to the shorter array mirrors pairing the two with zip().
        count = min(len(vertices), len(colors))
        interleaved = np.hstack(
            [vertices[:count, :3], np.asarray(colors, dtype=np.float64)[:count]]
        )
        return (interleaved.ravel().tolist(),)
|
|
|
|
class SPAR3DPointCloudSaver:
    """Output node that writes a flat point-cloud list to a PLY file."""

    CATEGORY = SPAR3D_CATEGORY
    FUNCTION = "save_pointcloud"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: a ``pointcloud`` and a ``filename_prefix``."""
        return {
            "required": {
                "pointcloud": ("POINTCLOUD",),
                "filename_prefix": ("STRING", {"default": "SPAR3D"}),
            }
        }

    def save_pointcloud(self, pointcloud, filename_prefix):
        """Export ``pointcloud`` to ``<prefix>_<counter>_.ply``.

        Args:
            pointcloud: Flat sequence reshaped to rows of six values; columns
                0-2 are used as the position and columns 3-5 as the colour,
                which is scaled by 255 (so presumably in ``[0, 1]`` -- confirm
                against the producer). ``None`` means there is nothing to save.
            filename_prefix: Prefix passed to
                ``folder_paths.get_save_image_path``.

        Returns:
            ``{"ui": {"text": ...}}`` describing what happened.
        """
        if pointcloud is None:
            return {"ui": {"text": "No point cloud data to save"}}

        points = np.array(pointcloud, dtype=np.float64).reshape(-1, 6)

        # Convert colours to explicit 8-bit values: rounded so a /255 -> *255
        # round trip is lossless, and clipped so out-of-range inputs saturate
        # instead of wrapping. This also avoids handing trimesh float colours
        # whose scale (0-1 vs 0-255) it would otherwise have to guess.
        colors = np.clip(np.rint(points[:, 3:] * 255.0), 0, 255).astype(np.uint8)
        ply_data = trimesh.PointCloud(vertices=points[:, :3], colors=colors)

        output_dir = folder_paths.get_output_directory()
        full_output_folder, filename, counter, _subfolder, _resolved_prefix = (
            folder_paths.get_save_image_path(filename_prefix, output_dir)
        )
        # The trailing "_" after the counter matches the naming used by
        # SPAR3DSave. The counter lookup in get_save_image_path appears to
        # parse the digits up to the next "_"; without it the counter would
        # not advance and each save would overwrite the previous file.
        out_path = os.path.join(
            full_output_folder, f"{filename}_{counter:05}_.ply"
        )

        ply_data.export(out_path)

        return {"ui": {"text": f"Saved point cloud to {out_path}"}}
|
|
|
|
# Single registry of (node class, display label) pairs. Both mapping tables
# below are derived from it, keyed by the class name, so the two can never
# drift apart.
_SPAR3D_NODES = (
    (SPAR3DLoader, "SPAR3D Loader"),
    (SPAR3DPreview, "SPAR3D Preview"),
    (SPAR3DSampler, "SPAR3D Sampler"),
    (SPAR3DSave, "SPAR3D Save"),
    (SPAR3DPointCloudLoader, "SPAR3D Point Cloud Loader"),
    (SPAR3DPointCloudSaver, "SPAR3D Point Cloud Saver"),
)

# Node identifier -> human-readable label.
NODE_DISPLAY_NAME_MAPPINGS = {
    node_cls.__name__: label for node_cls, label in _SPAR3D_NODES
}

# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {node_cls.__name__: node_cls for node_cls, _ in _SPAR3D_NODES}

# NOTE(review): presumably the folder holding this extension's front-end
# assets, relative to this module -- confirm against the host application.
WEB_DIRECTORY = "./comfyui"

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
|
|