Commit · ec9b09e
1 Parent(s): 16f5fe6
dataset v2 and pybullet

- app.py +1 -1
- src/backend.py +57 -73
app.py CHANGED
@@ -63,7 +63,7 @@ pre, code {
 
 
 REPO_ID = "open-rl-leaderboard/leaderboard"
-RESULTS_REPO = "open-rl-leaderboard/
+RESULTS_REPO = "open-rl-leaderboard/results_v2"
 
 
 links_md = f"""
src/backend.py CHANGED
@@ -2,11 +2,9 @@ import fnmatch
 import importlib
 import json
 import os
-import
+import random
 import shutil
 import sys
-import tempfile
-import time
 import zipfile
 from pathlib import Path
 from typing import Optional
@@ -15,7 +13,8 @@ import numpy as np
 import rl_zoo3.import_envs # noqa: F401 pylint: disable=unused-import
 import torch as th
 import yaml
-from
+from datasets import load_dataset
+from huggingface_hub import HfApi
 from huggingface_hub.utils import EntryNotFoundError
 from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub
 from requests.exceptions import HTTPError
@@ -118,6 +117,16 @@ ALL_ENV_IDS = [
     "Reacher-v4",
     "Swimmer-v4",
     "Walker2d-v4",
+    # PyBullet
+    "AntBulletEnv-v0",
+    "HalfCheetahBulletEnv-v0",
+    "HopperBulletEnv-v0",
+    "HumanoidBulletEnv-v0",
+    "InvertedDoublePendulumBulletEnv-v0",
+    "InvertedPendulumSwingupBulletEnv-v0",
+    "MinitaurBulletEnv-v0",
+    "ReacherBulletEnv-v0",
+    "Walker2DBulletEnv-v0",
 ]
 
 
@@ -504,85 +513,59 @@ def evaluate(
 logger = setup_logger(__name__)
 
 API = HfApi(token=os.environ.get("TOKEN"))
-RESULTS_REPO = "open-rl-leaderboard/
+RESULTS_REPO = "open-rl-leaderboard/results_v2"
 
 
 def _backend_routine():
     # List only the text classification models
-
-
-
-
-        compatible_models.append((model.modelId, model.sha))
-
-    logger.info(f"Found {len(compatible_models)} compatible models")
-
+    sb3_models = [
+        (model.modelId, model.sha) for model in API.list_models(filter=["reinforcement-learning", "stable-baselines3"])
+    ]
+    logger.info(f"Found {len(sb3_models)} SB3 models")
     # Get the results
-
-
-
-
-
-    for filename in filenames:
-        path = API.hf_hub_download(repo_id=RESULTS_REPO, filename=filename, repo_type="dataset")
-        with open(path) as fp:
-            report = json.load(fp)
-        evaluated_models.add((report["config"]["model_id"], report["config"]["model_sha"]))
-
-    # Find the models that are not associated with any results
-    pending_models = list(set(compatible_models) - evaluated_models)
+    dataset = load_dataset(
+        RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks"
+    )
+    evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
+    pending_models = list(set(sb3_models) - set(evaluated_models))
     logger.info(f"Found {len(pending_models)} pending models")
 
     if len(pending_models) == 0:
         return None
 
+    # Select a random model
+    repo_id, sha = random.choice(pending_models)
+    user_id, model_id = repo_id.split("/")
+    row = {"model_id": model_id, "user_id": user_id, "sha": sha}
+
     # Run an evaluation on the models
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if evaluations is not None:
-            report["results"] = evaluations
-            report["status"] = "DONE"
-        else:
-            report["status"] = "FAILED"
-
-        # Update the results
-        dumped = json.dumps(report, indent=2)
-        path_in_repo = f"{model_id}/results_{sha}.json"
-        local_path = os.path.join(tmp_dir, path_in_repo)
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, "w") as f:
-            f.write(dumped)
-
-        commits.append(CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=local_path))
-
-    API.create_commit(
-        repo_id=RESULTS_REPO, commit_message="Add evaluation results", operations=commits, repo_type="dataset"
-    )
+    model_info = API.model_info(repo_id, revision=sha)
+
+    # Extract the environment IDs from the tags (usually only one)
+    env_ids = pattern_match(model_info.tags, ALL_ENV_IDS)
+    if len(env_ids) > 0:
+        env = env_ids[0]
+        logger.info(f"Running evaluation on {user_id}/{model_id}")
+        algo = model_info.model_index[0]["name"].lower()
+
+        try:
+            episodic_returns = evaluate(user_id, model_id, env, "rl-trained-agents", algo, no_render=True, verbose=1)
+            row["status"] = "DONE"
+            row["env_id"] = env
+            row["episodic_returns"] = episodic_returns
+        except Exception as e:
+            logger.error(f"Error evaluating {model_id}: {e}")
+            row["status"] = "FAILED"
+
+    else:
+        logger.error(f"No environment found for {model_id}")
+        row["status"] = "FAILED"
+
+    dataset = load_dataset(
+        RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks"
+    )  # Reload the dataset, in case it was updated
+    dataset = dataset.add_item(row)
+    dataset.push_to_hub(RESULTS_REPO, split="train")
 
 
 def backend_routine():
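The routine above picks the evaluation environment with pattern_match(model_info.tags, ALL_ENV_IDS), which is why adding the PyBullet IDs to ALL_ENV_IDS is enough to make Bullet models eligible. The helper itself is not part of this diff; the snippet below is a hypothetical sketch of such a matcher built on fnmatch (already imported at the top of backend.py), not the repository's actual implementation.

import fnmatch
from typing import List


def pattern_match(patterns: List[str], source_list: List[str]) -> List[str]:
    # Hypothetical matcher: keep every known env ID that one of the model's tags matches.
    env_ids = set()
    for pattern in patterns:
        env_ids.update(fnmatch.filter(source_list, pattern))
    return sorted(env_ids)


# A model tagged ["reinforcement-learning", "stable-baselines3", "HopperBulletEnv-v0"]
# now resolves to ["HopperBulletEnv-v0"].
print(pattern_match(["stable-baselines3", "HopperBulletEnv-v0"], ["HopperBulletEnv-v0", "Walker2d-v4"]))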
@@ -593,4 +576,5 @@ def backend_routine():
 
 
 if __name__ == "__main__":
-
+    while True:
+        backend_routine()
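With these changes the backend no longer commits one JSON report per model; it keeps a single results dataset and appends one row per evaluation. Below is a minimal, self-contained sketch of that read/append/push cycle, assuming the results_v2 dataset exposes user_id, model_id and sha columns (as the new code implies) and that a write token is available in the TOKEN environment variable; the actual evaluation step is stubbed out.

import os
import random

from datasets import load_dataset
from huggingface_hub import HfApi

RESULTS_REPO = "open-rl-leaderboard/results_v2"
api = HfApi(token=os.environ.get("TOKEN"))

# Every SB3 model on the Hub, keyed by (repo_id, sha).
sb3_models = {
    (model.modelId, model.sha)
    for model in api.list_models(filter=["reinforcement-learning", "stable-baselines3"])
}

# Models that already have a row in the results dataset.
dataset = load_dataset(
    RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks"
)
evaluated = {("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset}

pending = list(sb3_models - evaluated)
if pending:
    repo_id, sha = random.choice(pending)
    user_id, model_id = repo_id.split("/")
    # The real routine also fills "env_id" and "episodic_returns" from evaluate().
    row = {"model_id": model_id, "user_id": user_id, "sha": sha, "status": "FAILED"}

    # Reload right before writing, in case another run pushed rows in the meantime.
    dataset = load_dataset(
        RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks"
    )
    dataset = dataset.add_item(row)
    dataset.push_to_hub(RESULTS_REPO, split="train", token=os.environ.get("TOKEN"))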