// g1-moves / viewer.js
// Author: exptech — commit 21739c9 (verified): "Fix Pull Over pipeline availability and 154D viewer"
import * as THREE from 'three';
import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
// ============================================================
// Constants
// ============================================================
const N_JOINTS = 29;
const DEFAULT_OBS_DIM = 160;
const DATASET_BASE_URL = 'https://huggingface.co/datasets/exptech/g1-moves/resolve/main';
// Policy period in seconds: 50 Hz, i.e. 4 physics substeps of 0.005 s (training: decimation=4, dt=0.005).
const CONTROL_DT = 0.02;
// Index of torso_link in the motion NPZ body arrays (0 = pelvis, 15 = torso).
const MOTION_ANCHOR_IDX = 15;
// Index of torso_link in MuJoCo's xpos/xquat arrays (0 = world, 1 = pelvis, 16 = torso).
const MUJOCO_ANCHOR_BODY = 16;

// Nominal joint angles (knees-bent keyframe), one row per limb group, flattened in XML joint order.
const DEFAULT_JOINT_POS = new Float32Array([
  [-0.312, 0, 0, 0.669, -0.363, 0], // left leg: hip_pitch, hip_roll, hip_yaw, knee, ankle_pitch, ankle_roll
  [-0.312, 0, 0, 0.669, -0.363, 0], // right leg (same order)
  [0, 0, 0], // waist: yaw, roll, pitch
  [0.2, 0.2, 0, 0.6, 0, 0, 0], // left arm: shoulder_pitch, shoulder_roll, shoulder_yaw, elbow, wrist_roll, wrist_pitch, wrist_yaw
  [0.2, -0.2, 0, 0.6, 0, 0, 0], // right arm (same order)
].flat());

// Per-joint multiplier applied to raw policy actions, same row layout and XML joint order as above.
const ACTION_SCALES = new Float32Array([
  [0.5475, 0.3507, 0.5475, 0.3507, 0.4386, 0.4386], // left leg
  [0.5475, 0.3507, 0.5475, 0.3507, 0.4386, 0.4386], // right leg
  [0.5475, 0.4386, 0.4386], // waist
  [0.4386, 0.4386, 0.4386, 0.4386, 0.4386, 0.0745, 0.0745], // left arm
  [0.4386, 0.4386, 0.4386, 0.4386, 0.4386, 0.0745, 0.0745], // right arm
].flat());
// ============================================================
// Coordinate conversions (MuJoCo Z-up → Three.js Y-up)
// ============================================================
/**
 * Read the `index`-th MuJoCo (Z-up) 3-vector from a flat buffer into a Three.js (Y-up) vector.
 *
 * @param {ArrayLike<number>} buffer - Packed xyz triples.
 * @param {number} index - Which triple to read.
 * @param {{set: Function}} target - Destination vector (e.g. THREE.Vector3).
 * @returns {*} Whatever `target.set` returns (THREE returns the target itself).
 */
function getPosition(buffer, index, target) {
  const base = index * 3;
  const mjX = buffer[base];
  const mjY = buffer[base + 1];
  const mjZ = buffer[base + 2];
  // Three.js (x, y, z) = MuJoCo (x, z, -y)
  return target.set(mjX, mjZ, -mjY);
}
/**
 * Read the `index`-th MuJoCo quaternion (w, x, y, z) into a Three.js quaternion in the Y-up frame.
 *
 * The vector part follows the same (x, y, z) → (x, z, -y) swizzle as positions. The whole quaternion
 * is then negated, which encodes the same rotation. Three.js expects (x, y, z, w) argument order.
 *
 * @param {ArrayLike<number>} buffer - Packed wxyz quadruples.
 * @param {number} index - Which quaternion to read.
 * @param {{set: Function}} target - Destination quaternion (e.g. THREE.Quaternion).
 * @returns {*} Whatever `target.set` returns.
 */
function getQuaternion(buffer, index, target) {
  const base = index * 4;
  const w = buffer[base];
  const x = buffer[base + 1];
  const y = buffer[base + 2];
  const z = buffer[base + 3];
  return target.set(-x, -z, y, -w);
}
// ============================================================
// Quaternion math (wxyz convention — MuJoCo native, no swizzle)
// ============================================================
function quatMul(a, b) {
return [
a[0]*b[0] - a[1]*b[1] - a[2]*b[2] - a[3]*b[3],
a[0]*b[1] + a[1]*b[0] + a[2]*b[3] - a[3]*b[2],
a[0]*b[2] - a[1]*b[3] + a[2]*b[0] + a[3]*b[1],
a[0]*b[3] + a[1]*b[2] - a[2]*b[1] + a[3]*b[0]
];
}
/**
 * Conjugate of a (w, x, y, z) quaternion.
 * This equals the inverse only for unit quaternions; no normalization is applied.
 *
 * @param {ArrayLike<number>} q - Quaternion [w, x, y, z].
 * @returns {number[]} [w, -x, -y, -z].
 */
function quatInv(q) {
  const [w, x, y, z] = q;
  return [w, -x, -y, -z];
}
/**
 * Rotate 3-vector `v` by quaternion `q` (w, x, y, z) via q ⊗ (0, v) ⊗ conj(q).
 * `q` is assumed to be unit length; otherwise the result is scaled by |q|².
 *
 * @param {ArrayLike<number>} q - Rotation quaternion.
 * @param {ArrayLike<number>} v - Vector [x, y, z].
 * @returns {number[]} The rotated vector.
 */
function quatRotate(q, v) {
  const pure = [0, v[0], v[1], v[2]];
  const rotated = quatMul(quatMul(q, pure), quatInv(q));
  const [, rx, ry, rz] = rotated;
  return [rx, ry, rz];
}
/**
 * Rotate `v` by the inverse (conjugate) of unit quaternion `q`.
 * Used here to express world-frame vectors in a body frame.
 *
 * @param {ArrayLike<number>} q - Rotation quaternion (w, x, y, z).
 * @param {ArrayLike<number>} v - Vector [x, y, z].
 * @returns {number[]} The inversely rotated vector.
 */
function quatRotateInv(q, v) {
  const conjugate = quatInv(q);
  return quatRotate(conjugate, v);
}
/**
 * 6D rotation representation matching training:
 * `matrix_from_quat(q)[:, :, :2].reshape(-1, 6)`.
 *
 * The result is the first two columns of the rotation matrix, flattened row by row.
 *
 * @param {ArrayLike<number>} q - Quaternion (w, x, y, z).
 * @returns {number[]} [R00, R01, R10, R11, R20, R21].
 */
function quatToMat(q) {
  const [w, x, y, z] = q;
  const xx = x * x;
  const yy = y * y;
  const zz = z * z;
  const xy = x * y;
  const xz = x * z;
  const yz = y * z;
  const wx = w * x;
  const wy = w * y;
  const wz = w * z;
  const row0 = [1 - 2 * (yy + zz), 2 * (xy - wz)];
  const row1 = [2 * (xy + wz), 1 - 2 * (xx + zz)];
  const row2 = [2 * (xz - wy), 2 * (yz + wx)];
  return [...row0, ...row1, ...row2];
}
// ============================================================
// NPZ parser (ZIP + numpy)
// ============================================================
/**
 * Fetch a NumPy `.npz` archive and parse every array it contains.
 *
 * @param {string} url - Location of the `.npz` file.
 * @returns {Promise<Object<string, {shape: number[], data: ArrayLike}>>} Arrays keyed by their name
 *   with the trailing `.npy` removed. A null-prototype object, so entry names cannot clash with
 *   `Object.prototype` members.
 * @throws {Error} On a non-2xx response, or when the archive or an array cannot be parsed.
 */
async function loadNPZ(url) {
  const response = await fetch(url);
  if (!response.ok) throw new Error(`Failed to fetch ${url}: ${response.status}`);
  const bytes = new Uint8Array(await response.arrayBuffer());
  const entries = await parseZipEntries(bytes);
  const arrays = Object.create(null);
  const suffix = '.npy';
  for (const [name, data] of Object.entries(entries)) {
    // Strip only a trailing ".npy"; String#replace would remove the first occurrence anywhere.
    const key = name.endsWith(suffix) ? name.slice(0, -suffix.length) : name;
    arrays[key] = parseNpy(data);
  }
  return arrays;
}
/**
 * Walk the local file headers of a ZIP archive and return each entry's uncompressed bytes.
 *
 * Supports stored (0) and deflated (8) entries, ZIP64 sizes from the ZIP64 extra field, and
 * entries followed by a data descriptor when the local header still carries the sizes.
 * Parsing stops at the first non-local-header signature (normally the central directory).
 *
 * @param {Uint8Array} bytes - The whole archive (may be a view with a non-zero byteOffset).
 * @returns {Promise<Object<string, Uint8Array>>} Entry name → bytes (null-prototype object).
 * @throws {Error} For encrypted entries, unsupported compression methods, ZIP64 entries without
 *   a ZIP64 extra field, entries whose sizes are only in a trailing data descriptor, and entries
 *   that run past the end of the buffer.
 */
async function parseZipEntries(bytes) {
  const LOCAL_HEADER_SIG = 0x04034b50;
  const DATA_DESCRIPTOR_SIG = 0x08074b50;
  const LOCAL_HEADER_SIZE = 30;
  const FLAG_ENCRYPTED = 0x0001;
  const FLAG_DATA_DESCRIPTOR = 0x0008;
  const ZIP64_MARKER = 0xFFFFFFFF;
  const ZIP64_EXTRA_ID = 0x0001;

  const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
  const decoder = new TextDecoder();
  const entries = Object.create(null);
  let offset = 0;
  while (offset + LOCAL_HEADER_SIZE <= bytes.length) {
    if (view.getUint32(offset, true) !== LOCAL_HEADER_SIG) break;
    const flags = view.getUint16(offset + 6, true);
    const method = view.getUint16(offset + 8, true);
    let compSize = view.getUint32(offset + 18, true);
    let uncompSize = view.getUint32(offset + 22, true);
    const nameLen = view.getUint16(offset + 26, true);
    const extraLen = view.getUint16(offset + 28, true);
    const nameStart = offset + LOCAL_HEADER_SIZE;
    const extraStart = nameStart + nameLen;
    const dataStart = extraStart + extraLen;
    const name = decoder.decode(bytes.subarray(nameStart, extraStart));

    if (flags & FLAG_ENCRYPTED) throw new Error(`Encrypted ZIP entries are not supported: ${name}`);

    // ZIP64: the 32-bit size fields hold 0xFFFFFFFF and the real 64-bit sizes live in extra field 0x0001
    // (uncompressed size first, then compressed size).
    let isZip64 = false;
    if (compSize === ZIP64_MARKER || uncompSize === ZIP64_MARKER) {
      let fieldOffset = extraStart;
      while (fieldOffset + 4 <= dataStart) {
        const headerId = view.getUint16(fieldOffset, true);
        const fieldSize = view.getUint16(fieldOffset + 2, true);
        if (headerId === ZIP64_EXTRA_ID && fieldSize >= 16) {
          uncompSize = view.getUint32(fieldOffset + 4, true) + view.getUint32(fieldOffset + 8, true) * 0x100000000;
          compSize = view.getUint32(fieldOffset + 12, true) + view.getUint32(fieldOffset + 16, true) * 0x100000000;
          isZip64 = true;
          break;
        }
        fieldOffset += 4 + fieldSize;
      }
      if (!isZip64) throw new Error(`ZIP64 entry without a usable ZIP64 extra field: ${name}`);
    }

    // With flag bit 3 the sizes may be written only after the data; without a central-directory
    // pass we cannot find where this entry ends, so refuse instead of silently truncating.
    if ((flags & FLAG_DATA_DESCRIPTOR) && compSize === 0 && uncompSize === 0) {
      throw new Error(`ZIP entry ${name} stores its sizes in a trailing data descriptor, which is not supported`);
    }
    if (method !== 0 && method !== 8) throw new Error(`Unsupported ZIP method: ${method}`);

    const readLength = method === 0 ? uncompSize : compSize;
    if (dataStart + readLength > bytes.length) {
      throw new Error(`Truncated ZIP entry: ${name} needs ${readLength} bytes at offset ${dataStart}`);
    }

    entries[name] = method === 0
      ? bytes.slice(dataStart, dataStart + uncompSize)
      : await decompressRaw(bytes.slice(dataStart, dataStart + compSize));

    offset = dataStart + compSize;
    if (flags & FLAG_DATA_DESCRIPTOR) {
      // Skip the optional descriptor signature, then CRC-32 plus both sizes (8-byte sizes for ZIP64).
      if (offset + 4 <= bytes.length && view.getUint32(offset, true) === DATA_DESCRIPTOR_SIG) offset += 4;
      offset += isZip64 ? 20 : 12;
    }
  }
  return entries;
}
function decompressRaw(compressed) {
const ds = new DecompressionStream('deflate-raw');
const writer = ds.writable.getWriter();
const reader = ds.readable.getReader();
writer.write(compressed);
writer.close();
const chunks = [];
return new Promise((resolve, reject) => {
(function pump() {
reader.read().then(({ done, value }) => {
if (done) {
const total = chunks.reduce((s, c) => s + c.length, 0);
const result = new Uint8Array(total);
let off = 0;
for (const c of chunks) { result.set(c, off); off += c.length; }
resolve(result);
} else {
chunks.push(value);
pump();
}
}).catch(reject);
})();
});
}
// Little-endian / byte-order-free NumPy dtype descriptors → typed-array constructor.
// A Map, so a malicious 'descr' such as 'constructor' cannot resolve to an Object.prototype member.
const NPY_TYPED_ARRAYS = new Map([
  ['<f4', Float32Array],
  ['<f8', Float64Array],
  ['<i4', Int32Array],
  ['<u4', Uint32Array],
  ['<i2', Int16Array],
  ['<u2', Uint16Array],
  ['|i1', Int8Array],
  ['|u1', Uint8Array],
  ['|b1', Uint8Array],
  ['<i8', BigInt64Array], // elements are BigInt; callers convert with Number(...)
  ['<u8', BigUint64Array],
]);

/**
 * Parse a single `.npy` file (format versions 1.x–3.x).
 *
 * @param {Uint8Array} bytes - The complete `.npy` file (may be a view with a non-zero byteOffset).
 * @returns {{shape: number[], data: ArrayLike}} Shape (an empty array for scalars) and a freshly
 *   allocated typed array holding the elements in C (row-major) order.
 * @throws {Error} On a bad magic string, a malformed header, an unsupported dtype, Fortran-ordered
 *   multi-dimensional data, or a payload shorter than the shape requires.
 */
function parseNpy(bytes) {
  const MAGIC = [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59]; // "\x93NUMPY"
  if (bytes.length < 10 || MAGIC.some((byte, i) => bytes[i] !== byte)) {
    throw new Error('Invalid .npy data: bad magic string');
  }
  const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
  const major = bytes[6];
  // v1 uses a 2-byte header length; v2+ widen it to 4 bytes.
  const headerOffset = major <= 1 ? 10 : 12;
  const headerLen = major <= 1 ? view.getUint16(8, true) : view.getUint32(8, true);
  const headerStr = new TextDecoder().decode(bytes.subarray(headerOffset, headerOffset + headerLen));

  const shapeMatch = headerStr.match(/'shape':\s*\(([^)]*)\)/);
  const descrMatch = headerStr.match(/'descr':\s*'([^']*)'/);
  if (!shapeMatch || !descrMatch) throw new Error(`Malformed .npy header: ${headerStr}`);
  const shape = shapeMatch[1]
    .split(',')
    .map((s) => s.trim())
    .filter(Boolean)
    .map((s) => Number.parseInt(s, 10));
  const dtype = descrMatch[1];
  const ArrayType = NPY_TYPED_ARRAYS.get(dtype);
  if (!ArrayType) throw new Error(`Unsupported numpy dtype: ${dtype}`);
  // Column-major storage would be silently mis-indexed by every caller (they assume row-major).
  if (/'fortran_order':\s*True/.test(headerStr) && shape.length > 1) {
    throw new Error('Fortran-ordered .npy arrays are not supported');
  }

  const numElements = shape.reduce((a, b) => a * b, 1); // 1 for a 0-d (scalar) array
  const byteLength = numElements * ArrayType.BYTES_PER_ELEMENT;
  const dataStart = headerOffset + headerLen;
  if (dataStart + byteLength > bytes.length) {
    throw new Error(`Truncated .npy data: expected ${byteLength} bytes, found ${bytes.length - dataStart}`);
  }
  // slice() copies into a new zero-offset buffer, which satisfies typed-array alignment rules.
  return { shape, data: new ArrayType(bytes.slice(dataStart, dataStart + byteLength).buffer) };
}
// ============================================================
// Three.js scene builder from MuJoCo model
// ============================================================
/**
 * Build a Three.js BufferGeometry for MuJoCo mesh `meshID`, converted to the Y-up frame.
 *
 * Vertices are COPIED before swizzling. `model.mesh_vert` is a view into MuJoCo's WASM heap, so
 * mutating it would alter the model's own mesh data. Keeping that view in a BufferAttribute would
 * also break if the heap grows and detaches it.
 */
function createMjMeshGeometry(model, meshID) {
  const vertStart = model.mesh_vertadr[meshID] * 3;
  const source = model.mesh_vert.subarray(vertStart, vertStart + model.mesh_vertnum[meshID] * 3);
  const positions = new Float32Array(source.length);
  for (let v = 0; v < source.length; v += 3) {
    // (x, y, z) → (x, z, -y) for Three.js Y-up
    positions[v] = source[v];
    positions[v + 1] = source[v + 2];
    positions[v + 2] = -source[v + 1];
  }
  const faceStart = model.mesh_faceadr[meshID] * 3;
  const faceVerts = model.mesh_face.subarray(faceStart, faceStart + model.mesh_facenum[meshID] * 3);
  const geometry = new THREE.BufferGeometry();
  geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3));
  geometry.setIndex(Array.from(faceVerts));
  geometry.computeVertexNormals();
  return geometry;
}

/**
 * Create an RGBA DataTexture from the RGB-role texture of MuJoCo material `matId`.
 * Returns undefined when the material has no such texture.
 * NOTE(review): assumes the mjNTEXROLE = 10 texid layout of MuJoCo 3.x; confirm against the loaded mujoco-js build.
 */
function createMjTexture(model, matId) {
  const mjNTEXROLE = 10;
  const texId = model.mat_texid[matId * mjNTEXROLE + 1]; // mjTEXROLE_RGB = 1
  if (texId === -1) return undefined;
  const width = model.tex_width[texId];
  const height = model.tex_height[texId];
  const offset = model.tex_adr[texId];
  const channels = model.tex_nchannel[texId];
  const texData = model.tex_data;
  const rgbaArray = new Uint8Array(width * height * 4);
  for (let p = 0; p < width * height; p++) {
    // Single- and dual-channel textures are expanded by repeating the first channel.
    rgbaArray[p * 4 + 0] = texData[offset + p * channels + 0];
    rgbaArray[p * 4 + 1] = channels > 1 ? texData[offset + p * channels + 1] : rgbaArray[p * 4];
    rgbaArray[p * 4 + 2] = channels > 2 ? texData[offset + p * channels + 2] : rgbaArray[p * 4];
    rgbaArray[p * 4 + 3] = 255;
  }
  const texture = new THREE.DataTexture(rgbaArray, width, height, THREE.RGBAFormat, THREE.UnsignedByteType);
  texture.repeat.set(
    model.mat_texrepeat[matId * 2 + 0] || 1,
    model.mat_texrepeat[matId * 2 + 1] || 1
  );
  texture.wrapS = THREE.RepeatWrapping;
  texture.wrapT = THREE.RepeatWrapping;
  texture.needsUpdate = true;
  return texture;
}

/**
 * Create one THREE.Group per MuJoCo body holding meshes for that body's visible geoms.
 *
 * Geoms in group 3 and above are skipped. Supported geom types are plane (0), sphere (2),
 * capsule (3), ellipsoid (4), cylinder (5), box (6) and mesh (7); other types are skipped.
 * Each geom's local pose is converted to Three.js Y-up. Body poses are set later from `data`
 * (see syncBodies).
 *
 * @param {*} mujoco - MuJoCo WASM module (unused; kept for interface compatibility).
 * @param {*} model - MjModel the geoms are read from.
 * @param {*} data - MjData (unused; kept for interface compatibility).
 * @returns {Object<number, THREE.Group>} Groups keyed by MuJoCo body ID.
 */
function buildSceneFromModel(mujoco, model, data) {
  const bodies = {};
  const meshCache = {}; // meshID → BufferGeometry, shared by every geom using that mesh
  const textureCache = {}; // matId → DataTexture | undefined, one texture per material
  for (let g = 0; g < model.ngeom; g++) {
    if (!(model.geom_group[g] < 3)) continue;
    const b = model.geom_bodyid[g];
    const type = model.geom_type[g];
    const size = [
      model.geom_size[g * 3 + 0],
      model.geom_size[g * 3 + 1],
      model.geom_size[g * 3 + 2]
    ];
    if (!(b in bodies)) {
      bodies[b] = new THREE.Group();
      bodies[b].name = `body_${b}`;
      bodies[b].bodyID = b;
    }
    // Three.js primitives are Y-up, so MuJoCo's local Z extent goes on Three's Y axis.
    let geometry;
    if (type === 0) { // plane
      geometry = new THREE.PlaneGeometry(100, 100);
    } else if (type === 2) { // sphere
      geometry = new THREE.SphereGeometry(size[0], 16, 16);
    } else if (type === 3) { // capsule: size[1] is the half-length of the cylindrical part
      geometry = new THREE.CapsuleGeometry(size[0], size[1] * 2, 12, 16);
    } else if (type === 4) { // ellipsoid: unit sphere scaled by the swizzled semi-axes
      geometry = new THREE.SphereGeometry(1, 16, 16);
      geometry.scale(size[0], size[2], size[1]);
    } else if (type === 5) { // cylinder
      geometry = new THREE.CylinderGeometry(size[0], size[0], size[1] * 2, 16);
    } else if (type === 6) { // box: half-extents, swizzled (x, z, y)
      geometry = new THREE.BoxGeometry(size[0] * 2, size[2] * 2, size[1] * 2);
    } else if (type === 7) { // mesh
      const meshID = model.geom_dataid[g];
      if (meshID < 0) continue;
      if (!(meshID in meshCache)) meshCache[meshID] = createMjMeshGeometry(model, meshID);
      geometry = meshCache[meshID];
    } else {
      continue;
    }
    if (!geometry) continue;

    // Material RGBA (and texture) override the geom's own rgba when a material is assigned.
    let matColor = [
      model.geom_rgba[g * 4 + 0],
      model.geom_rgba[g * 4 + 1],
      model.geom_rgba[g * 4 + 2],
      model.geom_rgba[g * 4 + 3]
    ];
    let texture;
    const matId = model.geom_matid[g];
    if (matId !== -1) {
      matColor = [
        model.mat_rgba[matId * 4 + 0],
        model.mat_rgba[matId * 4 + 1],
        model.mat_rgba[matId * 4 + 2],
        model.mat_rgba[matId * 4 + 3]
      ];
      if (!(matId in textureCache)) textureCache[matId] = createMjTexture(model, matId);
      texture = textureCache[matId];
    }

    let mesh;
    if (type === 0) {
      // The ground plane uses a fixed dark, rough look regardless of its MuJoCo colour.
      mesh = new THREE.Mesh(geometry, new THREE.MeshStandardMaterial({
        color: 0x2a2a30, roughness: 0.95, metalness: 0.0, ...(texture ? { map: texture } : {})
      }));
      mesh.rotateX(-Math.PI / 2); // PlaneGeometry lies in XY; lay it flat on Three's XZ ground
    } else {
      const matOpts = {
        color: new THREE.Color(matColor[0], matColor[1], matColor[2]),
        transparent: matColor[3] < 1.0,
        opacity: matColor[3],
        roughness: 0.6,
        metalness: 0.1
      };
      if (texture) matOpts.map = texture;
      mesh = new THREE.Mesh(geometry, new THREE.MeshStandardMaterial(matOpts));
    }
    mesh.castShadow = type !== 0;
    mesh.receiveShadow = true;
    // Geom pose relative to its body, swizzled; the plane keeps the rotation set above.
    getPosition(model.geom_pos, g, mesh.position);
    if (type !== 0) getQuaternion(model.geom_quat, g, mesh.quaternion);
    bodies[b].add(mesh);
  }
  return bodies;
}
// ============================================================
// Main viewer class
// ============================================================
/**
 * Browser viewer that runs a G1 motion-tracking policy (ONNX) inside MuJoCo WASM and renders the
 * robot with Three.js.
 *
 * `init()` loads MuJoCo, the robot model, the policy and the reference motion, then builds the
 * scene. `animate()` then renders every display frame and advances the simulation in real time
 * at one policy step per CONTROL_DT.
 */
class G1PolicyViewer {
  constructor() {
    this.mujoco = null; // The WASM module
    this.model = null;
    this.data = null;
    this.policySession = null;
    this.policyKind = 'policy';
    this.obsDim = DEFAULT_OBS_DIM;
    this.motionData = null;
    this.motionFps = 60;
    this.nFrames = 0;
    this.anchorBodyIdx = -1;
    this.lastAction = new Float32Array(N_JOINTS);
    this.simTime = 0; // always stepCount * CONTROL_DT, so it never accumulates rounding drift
    this.stepCount = 0; // policy steps completed since the last reset
    this.resetGeneration = 0; // bumped by resetSimulation so an in-flight step can detect staleness
    this.paused = false;
    this.embed = false;
    this.stepInFlight = false; // true while a policy step is awaiting inference
    this.stepAccumulator = 0; // wall-clock seconds the simulation still owes
    this.lastFrameTime = null;
    this.bodies = {}; // Three.js body groups, keyed by body ID
    this.scene = null;
    this.camera = null;
    this.renderer = null;
    this.controls = null;
    this.fpsCounter = { frames: 0, lastTime: performance.now(), value: 0 };
  }

  /**
   * Load every asset, build the scene and start the render loop.
   * Failures are shown in the overlay and reported to an embedding parent frame.
   *
   * @param {string} clipId - Clip identifier used to build the policy and motion URLs.
   * @param {string} category - Category folder of the clip.
   * @param {boolean} [embed=false] - Lightweight iframe mode: no HUD, no shadows, auto-rotating camera.
   * @param {string} [policyKind='policy'] - 'policy_154' loads the 154-D policy and motion from the
   *   dataset repo; any other value loads the local 160-D policy from `media/`.
   */
  async init(clipId, category, embed = false, policyKind = 'policy') {
    this.embed = embed;
    this.policyKind = policyKind || 'policy';
    this.obsDim = this.policyKind === 'policy_154' ? 154 : DEFAULT_OBS_DIM;
    const status = document.getElementById('load-status');
    const progress = document.getElementById('progress-fill');
    const errorMsg = document.getElementById('error-msg');
    const notifyParent = (msg) => {
      if (window.parent !== window) {
        window.parent.postMessage({ type: 'viewer-status', clipId, status: msg }, '*');
      }
    };
    try {
      status.textContent = 'Loading MuJoCo WASM...';
      notifyParent('Loading WASM...');
      progress.style.width = '10%';
      const mjModule = await import('https://cdn.jsdelivr.net/npm/mujoco-js@0.0.7/+esm');
      this.mujoco = await mjModule.default();
      status.textContent = 'Loading G1 robot model...';
      notifyParent('Loading robot model...');
      progress.style.width = '30%';
      await this.loadModel();
      status.textContent = 'Loading neural network policy...';
      notifyParent('Loading policy...');
      progress.style.width = '50%';
      const policyUrl = this.policyKind === 'policy_154'
        ? `${DATASET_BASE_URL}/${category}/${clipId}/policy_154/${clipId}_policy.onnx`
        : `media/${category}/${clipId}/policy/${clipId}_policy.onnx`;
      await this.loadPolicy(policyUrl);
      status.textContent = 'Loading reference motion...';
      notifyParent('Loading motion...');
      progress.style.width = '70%';
      const motionUrl = this.policyKind === 'policy_154'
        ? `${DATASET_BASE_URL}/${category}/${clipId}/training/${clipId}.npz`
        : `media/${category}/${clipId}/training/${clipId}.npz`;
      await this.loadMotion(motionUrl);
      status.textContent = 'Building 3D scene...';
      notifyParent('Building scene...');
      progress.style.width = '85%';
      this.buildRenderer();
      this.bodies = buildSceneFromModel(this.mujoco, this.model, this.data);
      for (const b of Object.values(this.bodies)) {
        this.scene.add(b);
      }
      // Anchor body: torso_link = MuJoCo body 16 (0=world, 1=pelvis, ..., 16=torso)
      this.anchorBodyIdx = MUJOCO_ANCHOR_BODY;
      this.resetSimulation();
      this.resetCamera();
      this.syncBodies();
      progress.style.width = '100%';
      status.textContent = 'Ready!';
      if (!embed) {
        document.getElementById('clip-name').textContent =
          clipId.replace(/^[BJMV]_/, '').replace(/([a-z])([A-Z])/g, '$1 $2');
        document.getElementById('hud').style.display = 'flex';
        document.getElementById('controls').style.display = 'flex';
        document.getElementById('info-panel').style.display = 'block';
        this.setupControls();
      }
      setTimeout(() => document.getElementById('overlay').classList.add('hidden'), 300);
      // Notify parent frame that viewer is fully loaded and rendering
      if (window.parent !== window) {
        window.parent.postMessage({ type: 'viewer-ready', clipId }, '*');
      }
      this.animate();
    } catch (err) {
      console.error('Viewer init failed:', err);
      errorMsg.textContent = `Error: ${err.message}`;
      errorMsg.style.display = 'block';
      status.textContent = 'Failed to load';
      // Without this an embedding page would be left showing the last progress message forever.
      notifyParent(`Error: ${err.message}`);
    }
  }

  /**
   * Fetch the MJCF model and every mesh it references into MuJoCo's virtual filesystem, then compile it.
   * Meshes are looked up under `model/meshes/` by the `file` attribute of each `<mesh>` element.
   */
  async loadModel() {
    const mj = this.mujoco;
    mj.FS.mkdir('/model');
    mj.FS.mkdir('/model/meshes');
    const xmlResponse = await fetch('model/g1_viewer.xml');
    if (!xmlResponse.ok) throw new Error(`Failed to fetch model/g1_viewer.xml: ${xmlResponse.status}`);
    const xmlText = await xmlResponse.text();
    mj.FS.writeFile('/model/g1_viewer.xml', xmlText);
    // Parse XML to find mesh files, fetch all in parallel
    const doc = new DOMParser().parseFromString(xmlText, 'text/xml');
    const meshFiles = Array.from(doc.querySelectorAll('mesh[file]')).map((m) => m.getAttribute('file'));
    await Promise.all(meshFiles.map(async (filename) => {
      const resp = await fetch(`model/meshes/${filename}`);
      if (!resp.ok) throw new Error(`Missing mesh: ${filename}`);
      const data = new Uint8Array(await resp.arrayBuffer());
      mj.FS.writeFile(`/model/meshes/${filename}`, data);
    }));
    this.model = mj.MjModel.loadFromXML('/model/g1_viewer.xml');
    this.data = new mj.MjData(this.model);
  }

  /** Create the ONNX Runtime session for the policy at `url` (WASM execution provider). */
  async loadPolicy(url) {
    this.policySession = await ort.InferenceSession.create(url, {
      executionProviders: ['wasm']
    });
  }

  /**
   * Load the reference motion NPZ. It must contain joint_pos, joint_vel, body_pos_w and
   * body_quat_w; an optional `fps` entry overrides the default playback rate.
   *
   * @throws {Error} If a required array is missing or the motion has no frames.
   */
  async loadMotion(url) {
    const arrays = await loadNPZ(url);
    const required = ['joint_pos', 'joint_vel', 'body_pos_w', 'body_quat_w'];
    const missing = required.filter((key) => !arrays[key]);
    if (missing.length > 0) {
      throw new Error(`Motion file ${url} is missing arrays: ${missing.join(', ')}`);
    }
    this.motionData = {
      jointPos: arrays.joint_pos,
      jointVel: arrays.joint_vel,
      bodyPosW: arrays.body_pos_w,
      bodyQuatW: arrays.body_quat_w
    };
    this.nFrames = this.motionData.jointPos.shape[0];
    // getCurrentFrame() takes the frame modulo nFrames, which would yield NaN for an empty motion.
    if (!(this.nFrames > 0)) throw new Error(`Motion file ${url} contains no frames`);
    if (arrays.fps) {
      const f = Number(arrays.fps.data[0]);
      if (f > 0 && f < 1000) this.motionFps = f;
    }
  }

  /** Create the renderer, scene, lights, camera and orbit controls; embed mode trades quality for speed. */
  buildRenderer() {
    const canvas = document.getElementById('viewer-canvas');
    this.renderer = new THREE.WebGLRenderer({ canvas, antialias: !this.embed });
    this.renderer.setSize(window.innerWidth, window.innerHeight);
    this.renderer.setPixelRatio(this.embed ? 1 : Math.min(window.devicePixelRatio, 2));
    this.renderer.shadowMap.enabled = !this.embed;
    if (!this.embed) this.renderer.shadowMap.type = THREE.PCFSoftShadowMap;
    this.renderer.toneMapping = THREE.ACESFilmicToneMapping;
    this.scene = new THREE.Scene();
    this.scene.background = new THREE.Color(0x07070b);
    this.scene.fog = new THREE.Fog(0x07070b, 8, 25);
    this.scene.add(new THREE.AmbientLight(0xffffff, 1.0));
    const dir1 = new THREE.DirectionalLight(0xffffff, 1.5);
    dir1.position.set(3, 5, 3);
    dir1.castShadow = true;
    dir1.shadow.mapSize.set(2048, 2048);
    const sc = dir1.shadow.camera;
    sc.near = 0.5; sc.far = 20; sc.left = sc.bottom = -5; sc.right = sc.top = 5;
    this.scene.add(dir1);
    this.scene.add(new THREE.DirectionalLight(0xffffff, 0.8).translateX(-2).translateY(3));
    this.camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.1, 100);
    this.resetCamera();
    this.controls = new OrbitControls(this.camera, canvas);
    this.controls.target.set(0, 0.8, 0); // overridden by resetCamera() after sim init
    this.controls.enableDamping = true;
    this.controls.dampingFactor = 0.05;
    this.controls.minDistance = 0.5;
    this.controls.maxDistance = 10;
    if (this.embed) {
      this.controls.autoRotate = true;
      this.controls.autoRotateSpeed = 1.5;
      this.controls.enableZoom = false;
      this.controls.enablePan = false;
    }
    this.controls.update();
    window.addEventListener('resize', () => {
      this.camera.aspect = window.innerWidth / window.innerHeight;
      this.camera.updateProjectionMatrix();
      this.renderer.setSize(window.innerWidth, window.innerHeight);
    });
  }

  /**
   * Place the camera in front of the robot, looking back at it.
   * The robot's facing direction is the pelvis +X axis projected onto the ground plane.
   * Before the simulation exists, a fixed default view is used.
   */
  resetCamera() {
    const dist = this.embed ? 2.1 : 2.8;
    const height = this.embed ? 1.0 : 1.2;
    let camPos = [0, height, dist]; // default: in front along +Z (Three.js)
    let target = [0, 0.8, 0];
    if (this.data) {
      const qpos = this.data.qpos;
      // Pelvis position, MuJoCo Z-up → Three.js Y-up: (x, y, z) → (x, z, -y)
      target = [qpos[0], qpos[2] * 0.9, -qpos[1]];
      const qw = qpos[3];
      const qx = qpos[4];
      const qy = qpos[5];
      const qz = qpos[6];
      // Pelvis +X rotated into the MuJoCo world frame (first column of R), XY components only.
      const fwdX = 1 - 2 * (qy * qy + qz * qz);
      const fwdY = 2 * (qx * qy + qw * qz);
      // Normalize the horizontal heading so a pitched or rolled pelvis does not pull the camera closer.
      const horiz = Math.hypot(fwdX, fwdY);
      const dirX = horiz > 1e-6 ? fwdX / horiz : 1;
      const dirY = horiz > 1e-6 ? fwdY / horiz : 0;
      // MuJoCo (dirX, dirY) maps to Three.js (dirX, -dirY) on the XZ plane.
      camPos = [target[0] + dirX * dist, target[1] + height, target[2] - dirY * dist];
    }
    this.camera.position.set(camPos[0], camPos[1], camPos[2]);
    this.camera.lookAt(target[0], target[1], target[2]);
    if (this.controls) {
      this.controls.target.set(target[0], target[1], target[2]);
      this.controls.update();
    }
  }

  /**
   * Reset physics to motion frame 0 (matching training initialization), or to the default keyframe
   * when no motion is loaded. Also clears time, last action and pending real-time debt.
   */
  resetSimulation() {
    this.mujoco.mj_resetData(this.model, this.data);
    const qpos = this.data.qpos;
    const ctrl = this.data.ctrl;
    if (this.motionData) {
      const md = this.motionData;
      // Frame 0, body 0 (pelvis) are the first entries of the [frame][body][component] arrays.
      for (let k = 0; k < 3; k++) qpos[k] = Number(md.bodyPosW.data[k]);
      for (let k = 0; k < 4; k++) qpos[3 + k] = Number(md.bodyQuatW.data[k]); // w, x, y, z
      for (let i = 0; i < N_JOINTS; i++) {
        const jointAngle = Number(md.jointPos.data[i]);
        qpos[7 + i] = jointAngle;
        ctrl[i] = jointAngle;
      }
    } else {
      qpos[0] = 0; qpos[1] = 0; qpos[2] = 0.76;
      qpos[3] = 1; qpos[4] = 0; qpos[5] = 0; qpos[6] = 0;
      for (let i = 0; i < N_JOINTS; i++) {
        qpos[7 + i] = DEFAULT_JOINT_POS[i];
        ctrl[i] = DEFAULT_JOINT_POS[i];
      }
    }
    this.mujoco.mj_forward(this.model, this.data);
    this.simTime = 0;
    this.stepCount = 0;
    this.stepAccumulator = 0;
    this.resetGeneration += 1;
    this.lastAction.fill(0);
  }

  /** Reference-motion frame index for the current simulation time, looping over the clip. */
  getCurrentFrame() {
    // The epsilon absorbs floating-point error (0.29 * 100 === 28.999999999999996) so exact frame
    // boundaries are not rounded down to the previous frame.
    return Math.floor(this.simTime * this.motionFps + 1e-6) % this.nFrames;
  }

  /**
   * Assemble the policy observation vector.
   *
   * Both layouts contain, in order: reference joint pos and vel, anchor orientation (6), base angular
   * velocity, joint pos relative to default, joint vel, and last action. The 160-D layout adds the
   * anchor position (3) and the base linear velocity (3); the 154-D layout omits both.
   *
   * @returns {Float32Array} Observation of length `this.obsDim`.
   * @throws {Error} If the assembled length disagrees with `this.obsDim`.
   */
  constructObservation() {
    const obs = new Float32Array(this.obsDim);
    const frame = this.getCurrentFrame();
    const md = this.motionData;
    let idx = 0;
    // Command: ref joint_pos (29) + joint_vel (29) = 58
    const jpCols = md.jointPos.shape[1] || N_JOINTS;
    const jvCols = md.jointVel.shape[1] || N_JOINTS;
    for (let i = 0; i < N_JOINTS; i++) obs[idx++] = Number(md.jointPos.data[frame * jpCols + i]);
    for (let i = 0; i < N_JOINTS; i++) obs[idx++] = Number(md.jointVel.data[frame * jvCols + i]);
    const nBod = md.bodyPosW.shape[1];
    const ab = this.anchorBodyIdx;
    const robQuat = [this.data.xquat[ab * 4], this.data.xquat[ab * 4 + 1], this.data.xquat[ab * 4 + 2], this.data.xquat[ab * 4 + 3]]; // wxyz
    if (this.obsDim === 160) {
      // Motion anchor pos in robot anchor frame (3). The 154D robot-compatible policy omits this.
      const rpOff = frame * nBod * 3 + MOTION_ANCHOR_IDX * 3;
      const refPos = [Number(md.bodyPosW.data[rpOff]), Number(md.bodyPosW.data[rpOff + 1]), Number(md.bodyPosW.data[rpOff + 2])];
      const robPos = [this.data.xpos[ab * 3], this.data.xpos[ab * 3 + 1], this.data.xpos[ab * 3 + 2]];
      const dp = [refPos[0] - robPos[0], refPos[1] - robPos[1], refPos[2] - robPos[2]];
      const posB = quatRotateInv(robQuat, dp);
      obs[idx++] = posB[0]; obs[idx++] = posB[1]; obs[idx++] = posB[2];
    }
    // Motion anchor orientation relative to the robot anchor (6): the first two columns of the
    // rotation matrix, flattened row-major as [R00, R01, R10, R11, R20, R21].
    const rqOff = frame * nBod * 4 + MOTION_ANCHOR_IDX * 4;
    const refQuat = [
      Number(md.bodyQuatW.data[rqOff]), Number(md.bodyQuatW.data[rqOff + 1]),
      Number(md.bodyQuatW.data[rqOff + 2]), Number(md.bodyQuatW.data[rqOff + 3])
    ];
    const oriQ = quatMul(quatInv(robQuat), refQuat);
    const mat = quatToMat(oriQ);
    for (let k = 0; k < 6; k++) obs[idx++] = mat[k];
    const wx = this.data.qvel[3], wy = this.data.qvel[4], wz = this.data.qvel[5];
    if (this.obsDim === 160) {
      // Base linear velocity in body frame (3). The 154D robot-compatible policy omits this.
      // Velocimeter = R^T * v_world + omega_body × r_imu
      const rootQuat = [this.data.qpos[3], this.data.qpos[4], this.data.qpos[5], this.data.qpos[6]];
      const linVelW = [this.data.qvel[0], this.data.qvel[1], this.data.qvel[2]];
      const linVelB = quatRotateInv(rootQuat, linVelW);
      // IMU offset in pelvis frame: (0.04525, 0, -0.08339)
      const rx = 0.04525, ry = 0, rz = -0.08339;
      obs[idx++] = linVelB[0] + (wy * rz - wz * ry);
      obs[idx++] = linVelB[1] + (wz * rx - wx * rz);
      obs[idx++] = linVelB[2] + (wx * ry - wy * rx);
    }
    // Base angular velocity (3) — MuJoCo free joint qvel[3:6] is already in body frame (matches gyro)
    obs[idx++] = wx; obs[idx++] = wy; obs[idx++] = wz;
    // Joint positions relative to defaults (29)
    for (let i = 0; i < N_JOINTS; i++) obs[idx++] = this.data.qpos[7 + i] - DEFAULT_JOINT_POS[i];
    // Joint velocities (29)
    for (let i = 0; i < N_JOINTS; i++) obs[idx++] = this.data.qvel[6 + i];
    // Last actions (29)
    for (let i = 0; i < N_JOINTS; i++) obs[idx++] = this.lastAction[i];
    // Float32Array silently drops out-of-range writes, so a layout mismatch would otherwise go unnoticed.
    if (idx !== this.obsDim) {
      throw new Error(`Observation size mismatch: built ${idx} values for a ${this.obsDim}-D policy`);
    }
    return obs;
  }

  /**
   * Run one policy step: observe, infer, set joint targets, then advance physics by CONTROL_DT.
   * If a reset happens while inference is running, the now-stale action is discarded.
   */
  async step() {
    if (this.paused) return;
    const generation = this.resetGeneration;
    const obs = this.constructObservation();
    const session = this.policySession;
    const inputName = session.inputNames[0];
    const outputName = session.outputNames[0];
    const feeds = { [inputName]: new ort.Tensor('float32', obs, [1, this.obsDim]) };
    if (session.inputNames.includes('time_step')) {
      feeds.time_step = new ort.Tensor('float32', new Float32Array([this.getCurrentFrame()]), [1, 1]);
    }
    const results = await session.run(feeds);
    if (generation !== this.resetGeneration) return;
    const actions = results[outputName].data;
    // Apply: ctrl = default_pos + clipped_action * scale
    for (let i = 0; i < N_JOINTS; i++) {
      const a = Math.max(-10, Math.min(10, actions[i]));
      this.data.ctrl[i] = DEFAULT_JOINT_POS[i] + a * ACTION_SCALES[i];
      this.lastAction[i] = a;
    }
    // Step physics: 4 substeps at 0.005s = 0.02s per policy step (matching training)
    const substeps = 4;
    for (let s = 0; s < substeps; s++) {
      this.mujoco.mj_step(this.model, this.data);
    }
    this.stepCount += 1;
    this.simTime = this.stepCount * CONTROL_DT;
  }

  /** Copy MuJoCo body world poses into the matching Three.js groups. */
  syncBodies() {
    for (let b = 0; b < this.model.nbody; b++) {
      if (!(b in this.bodies)) continue;
      getPosition(this.data.xpos, b, this.bodies[b].position);
      getQuaternion(this.data.xquat, b, this.bodies[b].quaternion);
    }
  }

  /**
   * Start the render loop.
   *
   * Each display frame adds elapsed wall time to a capped budget and runs the policy steps that are
   * due, so playback runs at real time regardless of the display refresh rate. Only one batch of
   * steps is in flight at a time, so async inference never overlaps.
   */
  animate() {
    // Cap the owed steps so a background tab or a slow frame does not trigger a long catch-up burst.
    const MAX_CATCHUP_STEPS = 4;
    const fpsEl = document.getElementById('info-fps');
    const frameEl = document.getElementById('info-frame');
    const loop = async () => {
      requestAnimationFrame(loop);
      const frameStart = performance.now();
      const elapsed = this.lastFrameTime === null ? 0 : (frameStart - this.lastFrameTime) / 1000;
      this.lastFrameTime = frameStart;
      this.stepAccumulator = this.paused
        ? 0
        : Math.min(this.stepAccumulator + elapsed, MAX_CATCHUP_STEPS * CONTROL_DT);
      if (!this.stepInFlight) {
        const due = Math.floor(this.stepAccumulator / CONTROL_DT);
        if (due > 0) {
          this.stepAccumulator -= due * CONTROL_DT;
          this.stepInFlight = true;
          try {
            for (let i = 0; i < due; i++) await this.step();
          } catch (e) {
            console.error('Step:', e);
          } finally {
            this.stepInFlight = false;
          }
        }
      }
      this.syncBodies();
      // Track robot pelvis (body 1) with camera
      const pelvis = this.bodies[1];
      if (pelvis) this.controls.target.lerp(pelvis.position, 0.05);
      this.controls.update();
      this.renderer.render(this.scene, this.camera);
      // FPS
      this.fpsCounter.frames++;
      const now = performance.now();
      if (now - this.fpsCounter.lastTime > 1000) {
        this.fpsCounter.value = Math.round(this.fpsCounter.frames * 1000 / (now - this.fpsCounter.lastTime));
        this.fpsCounter.frames = 0;
        this.fpsCounter.lastTime = now;
        if (fpsEl) fpsEl.textContent = `${this.fpsCounter.value} FPS`;
      }
      if (frameEl) frameEl.textContent = `Frame: ${this.getCurrentFrame()} / ${this.nFrames}`;
    };
    loop();
  }

  /** Wire the pause, reset and camera buttons plus the Space-to-pause shortcut (non-embed mode only). */
  setupControls() {
    const pauseBtn = document.getElementById('btn-pause');
    const resetBtn = document.getElementById('btn-reset');
    const cameraBtn = document.getElementById('btn-camera');
    pauseBtn.addEventListener('click', () => {
      this.paused = !this.paused;
      pauseBtn.textContent = this.paused ? 'Play' : 'Pause';
      pauseBtn.classList.toggle('active', this.paused);
    });
    resetBtn.addEventListener('click', () => {
      this.resetSimulation();
      this.syncBodies();
    });
    cameraBtn.addEventListener('click', () => {
      // resetCamera() already aims the controls at the robot. Forcing the target back to the
      // origin here made the button point away from a robot that had walked off.
      this.resetCamera();
    });
    document.addEventListener('keydown', (e) => {
      if (e.code === 'Space') { e.preventDefault(); pauseBtn.click(); }
    });
  }
}
// ============================================================
// Entry point
// ============================================================
// Read clip/category/embed/policy from the query string and boot the viewer once the DOM is ready.
document.addEventListener('DOMContentLoaded', () => {
  const query = new URLSearchParams(window.location.search);
  const clip = query.get('clip');
  const category = query.get('category');
  if (!clip || !category) {
    document.getElementById('load-status').textContent = 'Missing clip or category parameter';
    return;
  }
  const embed = query.get('embed') === '1';
  const policy = query.get('policy') || 'policy';
  const viewer = new G1PolicyViewer();
  // init() reports its own failures in the overlay, so the returned promise is not awaited here.
  viewer.init(clip, category, embed, policy);
});