import datetime
import itertools
import os
import re
import subprocess
import sys
from collections import defaultdict
from pathlib import Path

import numpy as np
import yaml


def flatten_conf(conf, to=None, parents=None):
    """
    Flattens a configuration dict: nested dictionaries are flattened
    as key1.key2.key3 = value

    conf.yaml:
    ```yaml
    a: 1
    b:
      c: 2
      d:
        e: 3
      g:
        sample: sequential
        from: [4, 5]
    ```

    Is flattened to

    {
        "a": 1,
        "b.c": 2,
        "b.d.e": 3,
        "b.g": {
            "sample": "sequential",
            "from": [4, 5]
        }
    }

    Does not affect sampling dicts.

    Args:
        conf (dict): the configuration to flatten
        to (dict, optional): the target flattened dict. Defaults to None (a new dict).
        parents (list, optional): a final value's list of parents. Defaults to None.
    """
    # create fresh containers per call: a mutable default would be shared
    # (and mutated) across calls
    if to is None:
        to = {}
    if parents is None:
        parents = []
    for k, v in conf.items():
        if isinstance(v, dict) and "sample" not in v:
            flatten_conf(v, to, parents + [k])
        else:
            new_k = ".".join([str(p) for p in parents + [k]])
            to[new_k] = v
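
# Hedged usage sketch for flatten_conf (values assumed; kept as a comment so
# nothing runs when the script is executed):
#     flat = {}
#     flatten_conf({"a": 1, "b": {"c": 2, "g": {"sample": "list", "from": [0, 1]}}}, to=flat)
#     # flat == {"a": 1, "b.c": 2, "b.g": {"sample": "list", "from": [0, 1]}}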


def env_to_path(path):
    """Transforms an environment variable mention in a path
    into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds

    Args:
        path (str): path potentially containing the env variable

    Returns:
        str: path with the env variable resolved
    """
    path_elements = path.split("/")
    new_path = []
    for el in path_elements:
        if "$" in el:
            new_path.append(os.environ[el.replace("$", "")])
        else:
            new_path.append(el)
    return "/".join(new_path)


class C:
    """ANSI escape codes to colorize terminal output"""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    ITALIC = "\33[3m"
    BEIGE = "\33[36m"


def escape_path(path):
    """Escapes spaces and parentheses in a path so it can be used in a shell command"""
    p = str(path)
    return p.replace(" ", r"\ ").replace("(", r"\(").replace(")", r"\)")


def warn(*args, **kwargs):
    print("{}{}{}".format(C.WARNING, " ".join(args), C.ENDC), **kwargs)


def parse_jobID(command_output):
    """
    get job id from successful sbatch command output like
    `Submitted batch job 599583`

    Args:
        command_output (str): sbatch command's output

    Returns:
        int: the slurm job's ID, or -1 if it could not be parsed
    """
    if isinstance(command_output, str):
        command_output = command_output.strip()
        if "Submitted batch job" in command_output:
            return int(command_output.split()[-1])
    return -1
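
# Hedged sketch of parse_jobID on typical sbatch outputs (values assumed):
#     parse_jobID("Submitted batch job 599583\n")        # -> 599583
#     parse_jobID("sbatch: error: invalid partition")    # -> -1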


def now():
    return str(datetime.datetime.now()).replace(" ", "_")


def cols():
    try:
        col = os.get_terminal_size().columns
    except Exception:
        col = 50
    return col


def print_box(txt):
    if not txt:
        txt = "{}{}ERROR ⇪{}".format(C.BOLD, C.FAIL, C.ENDC)
        lt = 7
    else:
        lt = len(txt)
    nlt = lt + 12
    txt = "|" + " " * 5 + txt + " " * 5 + "|"
    line = "-" * nlt
    empty = "|" + " " * (nlt - 2) + "|"
    print(line)
    print(empty)
    print(txt)
    print(empty)
    print(line)


def print_header(idx):
    b = C.BOLD
    bl = C.OKBLUE
    e = C.ENDC
    char = "≡"
    c = cols()

    txt = " " * 20
    txt += f"{b}{bl}Run {idx}{e}"
    txt += " " * 20
    ln = len(txt) - len(b) - len(bl) - len(e)
    t = int(np.floor((c - ln) / 2))
    tt = int(np.ceil((c - ln) / 2))

    print(char * c)
    print(char * t + " " * ln + char * tt)
    print(char * t + txt + char * tt)
    print(char * t + " " * ln + char * tt)
    print(char * c)


def print_footer():
    c = cols()
    char = "﹎"
    print()
    print(char * (c // len(char)))
    print()
    print(" " * (c // 2) + "•" + " " * (c - c // 2 - 1))
    print()


def extend_summary(summary, tmp_train_args_dict, tmp_template_dict, exclude=[]):
    """Appends the current run's template and train-arg values to the summary dict"""
    exclude = set(exclude)
    if summary is None:
        summary = defaultdict(list)
    for k, v in tmp_template_dict.items():
        if k not in exclude:
            summary[k].append(v)
    for k, v in tmp_train_args_dict.items():
        if k not in exclude:
            if isinstance(v, list):
                v = str(v)
            summary[k].append(v)
    return summary


def search_summary_table(summary, summary_dir=None):
    """Builds a markdown table of the values that vary across runs; prints it in
    terminal-width column groups and optionally saves the full table to summary_dir"""
    # filter out constant values
    summary = {k: v for k, v in summary.items() if len(set(v)) > 1}

    # if everything is constant: no summary
    if not summary:
        return None, None

    # find number of searches
    n_searches = len(list(summary.values())[0])

    # print section title
    print(
        "{}{}{}Varying values across {} experiments:{}\n".format(
            C.OKBLUE,
            C.BOLD,
            C.UNDERLINE,
            n_searches,
            C.ENDC,
        )
    )

    # first column holds the Exp. number
    first_col = {
        "len": 8,  # length of a column, to split columns according to terminal width
        "str": ["| Exp. |", "|:----:|"]
        + ["| {0:^{1}} |".format(i, 4) for i in range(n_searches)],  # values to print
    }

    print_columns = [[first_col]]
    file_columns = [first_col]
    for k in sorted(summary.keys()):
        v = summary[k]
        col_title = f" {k} |"
        col_blank_line = f":{'-' * len(k)}-|"
        col_values = [
            " {0:{1}} |".format(
                crop_string(
                    str(crop_float(v[idx], min([5, len(k) - 2]))), len(k)
                ),  # crop floats and long strings
                len(k),
            )
            for idx in range(len(v))
        ]
        # create column object
        col = {"len": len(k) + 3, "str": [col_title, col_blank_line] + col_values}

        # if adding a new column would overflow the terminal and mess up printing,
        # start a new set of columns
        if sum(c["len"] for c in print_columns[-1]) + col["len"] >= cols():
            print_columns.append([first_col])

        # store current column in the latest group of columns
        print_columns[-1].append(col)
        file_columns.append(col)

    print_table = ""
    # print each column group individually
    for colgroup in print_columns:
        # print columns line by line
        for i in range(n_searches + 2):
            # get value of column for current line i
            for col in colgroup:
                print_table += col["str"][i]
            # next line for current columns
            print_table += "\n"
        # new lines for new column group
        print_table += "\n"

    file_table = ""
    for i in range(n_searches + 2):
        # get value of column for current line i
        for col in file_columns:
            file_table += col["str"][i]
        # next line for current columns
        file_table += "\n"

    summary_path = None
    if summary_dir is not None:
        summary_path = summary_dir / (now() + ".md")
        with summary_path.open("w") as f:
            f.write(file_table.strip())

    return print_table, summary_path


def clean_arg(v):
    """
    Chains the cleaning functions below so a value can safely be passed
    to train.py on the command line

    Args:
        v (any): arg to pass to train.py

    Returns:
        any: cleaned value, a string whenever quoting or cropping applied
    """
    return stringify_list(crop_float(quote_string(resolve_env(v))))
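
# Hedged sketch of the clean_arg chain (values assumed):
#     clean_arg([1, 2])          # -> '"[1, 2]"'  (lists are double-quoted)
#     clean_arg(0.123456789)     # -> '0.12346'   (floats cropped to 5 digits)
#     clean_arg("gen.opt lr=1")  # -> '"gen.opt lr=1"' (spaces or = trigger quoting)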


def resolve_env(v):
    """
    resolve env variables in paths

    Args:
        v (any): arg to pass to train.py

    Returns:
        any: `v` with environment variables resolved if it is a string
    """
    if isinstance(v, str):
        try:
            if "$" in v:
                if "/" in v:
                    v = env_to_path(v)
                else:
                    # strip the "$" so the lookup matches the env variable's name
                    _v = os.environ.get(v.replace("$", ""))
                    if _v is not None:
                        v = _v
        except Exception:
            pass
    return v


def stringify_list(v):
    """
    Stringify list (with double quotes) so that it can be passed as an argument
    to train.py's hydra command-line parsing

    Args:
        v (any): value to clean

    Returns:
        any: type of v, str if v was a list
    """
    if isinstance(v, list):
        return '"{}"'.format(str(v).replace('"', "'"))
    if isinstance(v, str):
        if v.startswith("[") and v.endswith("]"):
            return f'"{v}"'
    return v
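
# Hedged sketch (values assumed):
#     stringify_list([0.1, 0.01])  # -> '"[0.1, 0.01]"'
#     stringify_list("[1, 2]")     # -> '"[1, 2]"'
#     stringify_list(4)            # -> 4 (unchanged)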


def quote_string(v):
    """
    Add double quotes around a string if it contains a " " or an =

    Args:
        v (any): value to clean

    Returns:
        any: type of v, quoted if v is a string with " " or =
    """
    if isinstance(v, str):
        if " " in v or "=" in v:
            return f'"{v}"'
    return v


def crop_float(v, k=5):
    """
    If v is a float, crop its precision to k significant digits (default 5)
    and return it as a str

    Args:
        v (any): value to crop if float
        k (int, optional): number of significant digits. Defaults to 5.

    Returns:
        any: cropped float as str if v is a float, original v otherwise
    """
    if isinstance(v, float):
        return "{0:.{1}g}".format(v, k)
    return v
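
# Hedged sketch (values assumed):
#     crop_float(0.123456789)  # -> '0.12346'
#     crop_float(1234567.0)    # -> '1.2346e+06'
#     crop_float("text")       # -> 'text' (unchanged)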


def compute_n_search(conf):
    """
    Compute the number of searches to run when using -1 as n_search with
    cartesian or sequential sampling

    Args:
        conf (dict): experimental configuration

    Returns:
        int: size of the cartesian product or length of longest sequential field
    """
    samples = defaultdict(list)
    for k, v in conf.items():
        if not isinstance(v, dict) or "sample" not in v:
            continue
        samples[v["sample"]].append(v)

    totals = []
    if "cartesian" in samples:
        total = 1
        for s in samples["cartesian"]:
            total *= len(s["from"])
        totals.append(total)
    if "sequential" in samples:
        total = max(map(len, [s["from"] for s in samples["sequential"]]))
        totals.append(total)

    if totals:
        return max(totals)

    raise ValueError(
        "Used n_search=-1 without any field being 'cartesian' or 'sequential'"
    )
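
# Hedged sketch (config values assumed):
#     conf = {
#         "lr": {"sample": "cartesian", "from": [0.1, 0.01, 0.001]},
#         "bs": {"sample": "cartesian", "from": [32, 64]},
#     }
#     compute_n_search(conf)  # -> 6 (3 lrs x 2 batch sizes)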


def crop_string(s, k=10):
    """Crops s to k characters, ending with ".." when it was truncated"""
    if len(s) <= k:
        return s
    return s[: k - 2] + ".."


def sample_param(sample_dict):
    """sample a value (hyperparameter) from the instruction in the
    sample dict:
    {
        "sample": "range | list",
        "from": [min, max, step] | [v0, v1, v2 etc.]
    }
    if range, as np.arange is used, "from" MUST be a list, but may contain
    only 1 (=min) or 2 (min and max) values, not necessarily 3

    Args:
        sample_dict (dict): instructions to sample a value

    Returns:
        scalar: sampled value
    """
    if not isinstance(sample_dict, dict) or "sample" not in sample_dict:
        return sample_dict

    if sample_dict["sample"] == "cartesian":
        assert isinstance(
            sample_dict["from"], list
        ), "{}'s `from` field MUST be a list, found {}".format(
            sample_dict["sample"], sample_dict["from"]
        )
        return "__cartesian__"

    if sample_dict["sample"] == "sequential":
        assert isinstance(
            sample_dict["from"], list
        ), "{}'s `from` field MUST be a list, found {}".format(
            sample_dict["sample"], sample_dict["from"]
        )
        return "__sequential__"

    if sample_dict["sample"] == "range":
        return np.random.choice(np.arange(*sample_dict["from"]))

    if sample_dict["sample"] == "list":
        return np.random.choice(sample_dict["from"])

    if sample_dict["sample"] == "uniform":
        return np.random.uniform(*sample_dict["from"])

    raise ValueError("Unknown sample type in dict " + str(sample_dict))
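
# Hedged sketch (values assumed; random draws shown as one possible outcome):
#     sample_param(0.1)                                      # -> 0.1 (constants pass through)
#     sample_param({"sample": "list", "from": [1, 2, 3]})    # -> e.g. 2
#     sample_param({"sample": "uniform", "from": [0, 1]})    # -> e.g. 0.42
#     sample_param({"sample": "cartesian", "from": [1, 2]})  # -> "__cartesian__"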


def sample_sequentials(sequential_keys, exp, idx):
    """
    Samples sequentially from the "from" values specified in each key of the
    experimental configuration which have sample == "sequential"
    Unlike `cartesian` sampling, `sequential` sampling iterates *independently*
    over each key

    Args:
        sequential_keys (list): keys to be sampled sequentially
        exp (dict): experimental config
        idx (int): index of the current sample

    Returns:
        dict: sampled config (keys = sequential_keys)
    """
    conf = {}
    for k in sequential_keys:
        v = exp[k]["from"]
        conf[k] = v[idx % len(v)]
    return conf
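
# Hedged sketch (values assumed): each key cycles through its own "from" list
#     exp = {
#         "a": {"sample": "sequential", "from": [1, 2, 3]},
#         "b": {"sample": "sequential", "from": [10, 20]},
#     }
#     sample_sequentials(["a", "b"], exp, idx=2)  # -> {"a": 3, "b": 10}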


def sample_cartesians(cartesian_keys, exp, idx):
    """
    Returns the `idx`th item in the cartesian product of all cartesian keys to
    be sampled.

    Args:
        cartesian_keys (list): keys in the experimental configuration that are to
            be used in the full cartesian product
        exp (dict): experimental configuration
        idx (int): index of the current sample

    Returns:
        dict: sampled point in the cartesian space (with keys = cartesian_keys)
    """
    conf = {}
    cartesian_values = [exp[key]["from"] for key in cartesian_keys]
    product = list(itertools.product(*cartesian_values))
    for k, v in zip(cartesian_keys, product[idx % len(product)]):
        conf[k] = v
    return conf
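
# Hedged sketch (values assumed): idx indexes into the full cartesian product
#     exp = {
#         "a": {"sample": "cartesian", "from": [1, 2]},
#         "b": {"sample": "cartesian", "from": ["x", "y"]},
#     }
#     sample_cartesians(["a", "b"], exp, idx=1)  # -> {"a": 1, "b": "y"}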


def resolve(hp_conf, nb):
    """
    Samples parameters parametrized in `hp_conf`: should be a dict with
    values which fit `sample_param(dic)`'s API

    Args:
        hp_conf (dict): experiment's parametrization
        nb (int): number of experiments to sample

    Returns:
        list(dict): sampled configurations
    """
    if nb == -1:
        nb = compute_n_search(hp_conf)

    confs = []
    for idx in range(nb):
        conf = {}
        cartesians = []
        sequentials = []
        for k, v in hp_conf.items():
            candidate = sample_param(v)
            if candidate == "__cartesian__":
                cartesians.append(k)
            elif candidate == "__sequential__":
                sequentials.append(k)
            else:
                conf[k] = candidate
        if sequentials:
            conf.update(sample_sequentials(sequentials, hp_conf, idx))
        if cartesians:
            conf.update(sample_cartesians(cartesians, hp_conf, idx))
        confs.append(conf)
    return confs
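
# Hedged sketch (values assumed): nb=-1 infers the number of runs
#     hp_conf = {"epochs": 10, "lr": {"sample": "cartesian", "from": [0.1, 0.01]}}
#     resolve(hp_conf, -1)
#     # -> [{"epochs": 10, "lr": 0.1}, {"epochs": 10, "lr": 0.01}]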


def get_template_params(template):
    """
    extract args in template str as {arg}

    Args:
        template (str): sbatch template string

    Returns:
        list(str): args required to format the template string
    """
    return [
        s.replace("{", "").replace("}", "") for s in re.findall(r"\{.*?\}", template)
    ]


def read_exp_conf(name):
    """
    Read hp search configuration from shared/experiment/
    specified with or without the .yaml extension

    Args:
        name (str): name of the config to find in shared/experiment/

    Returns:
        Tuple(Path, dict): file path and loaded, flattened dict
    """
    if ".yaml" not in name:
        name += ".yaml"
    paths = []
    dirs = ["shared", "config"]
    for d in dirs:
        path = Path(__file__).parent / d / "experiment" / name
        if path.exists():
            paths.append(path.resolve())

    if len(paths) == 0:
        failed = [Path(__file__).parent / d / "experiment" for d in dirs]
        s = "Could not find search config {} in :\n".format(name)
        for fd in failed:
            s += str(fd) + "\nAvailable:\n"
            for ym in fd.glob("*.yaml"):
                s += " " + ym.name + "\n"
        raise ValueError(s)

    if len(paths) == 2:
        print(
            "Warning: found 2 relevant files for search config:\n{}".format(
                "\n".join(map(str, paths))  # Paths must be stringified before joining
            )
        )
        print("Using {}".format(paths[-1]))

    with paths[-1].open("r") as f:
        conf = yaml.safe_load(f)

    flat_conf = {}
    flatten_conf(conf, to=flat_conf)
    return (paths[-1], flat_conf)


def read_template(name):
    """
    Read template from shared/template/ specified with or without the .sh extension

    Args:
        name (str): name of the template to find in shared/template/

    Returns:
        str: file's content as 1 string
    """
    if ".sh" not in name:
        name += ".sh"
    paths = []
    dirs = ["shared", "config"]
    for d in dirs:
        path = Path(__file__).parent / d / "template" / name
        if path.exists():
            paths.append(path)

    if len(paths) == 0:
        failed = [Path(__file__).parent / d / "template" for d in dirs]
        s = "Could not find template {} in :\n".format(name)
        for fd in failed:
            s += str(fd) + "\nAvailable:\n"
            for ym in fd.glob("*.sh"):
                s += " " + ym.name + "\n"
        raise ValueError(s)

    if len(paths) == 2:
        print(
            "Warning: found 2 relevant template files:\n{}".format(
                "\n".join(map(str, paths))  # Paths must be stringified before joining
            )
        )
        print("Using {}".format(paths[-1]))

    with paths[-1].open("r") as f:
        return f.read()


def is_sampled(key, conf):
    """
    Is a key sampled or constant? Returns True if conf is empty

    Args:
        key (str): key to check
        conf (dict): hyperparameter search configuration dict

    Returns:
        bool: whether the key is sampled
    """
    return not conf or (
        key in conf and isinstance(conf[key], dict) and "sample" in conf[key]
    )
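
# Hedged sketch (values assumed):
#     is_sampled("lr", {})                                       # -> True
#     is_sampled("lr", {"lr": {"sample": "list", "from": [1]}})  # -> True
#     is_sampled("lr", {"lr": 0.1})                              # -> False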


if __name__ == "__main__":
    """
    Notes:
        * Must provide template name as template=name
        * `name`.sh should be in shared/template/
    """

    # -------------------------------
    # ----- Default Variables -----
    # -------------------------------

    args = sys.argv[1:]
    command_output = ""
    user = os.environ.get("USER")
    home = os.environ.get("HOME")
    exp_conf = {}
    dev = False
    escape = False
    verbose = False
    template_name = None
    hp_exp_name = None
    hp_search_nb = None
    exp_path = None
    resume = None
    force_sbatchs = False
    sbatch_base = Path(home) / "climategan_sbatchs"
    summary_dir = Path(home) / "climategan_exp_summaries"
    hp_search_private = {"n_search", "template", "search", "summary_dir"}
    sbatch_path = "hash"

    # --------------------------
    # ----- Sanity Check -----
    # --------------------------

    for arg in args:
        if "=" not in arg or " = " in arg:
            raise ValueError(
                "Args should be passed as `key=value`. Received `{}`".format(arg)
            )

    # --------------------------------
    # ----- Parse Command Line -----
    # --------------------------------

    # split on the first "=" only so values may themselves contain "="
    args_dict = {arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in args}
    assert "template" in args_dict, "Please specify template=xxx"
    template = read_template(args_dict["template"])
    template_dict = {k: None for k in get_template_params(template)}
    train_args = []
    for k, v in args_dict.items():
        if k == "verbose":
            if v != "0":
                verbose = True
        elif k == "sbatch_path":
            sbatch_path = v
        elif k == "sbatch_base":
            sbatch_base = Path(v).resolve()
        elif k == "force_sbatchs":
            force_sbatchs = v.lower() == "true"
        elif k == "dev":
            if v.lower() != "false":
                dev = True
        elif k == "escape":
            if v.lower() != "false":
                escape = True
        elif k == "template":
            template_name = v
        elif k == "exp":
            hp_exp_name = v
        elif k == "n_search":
            hp_search_nb = int(v)
        elif k == "resume":
            resume = f'"{v}"'
            template_dict[k] = f'"{v}"'
        elif k == "summary_dir":
            if v.lower() == "none":
                summary_dir = None
            else:
                summary_dir = Path(v)
        elif k in template_dict:
            template_dict[k] = v
        else:
            train_args.append(f"{k}={v}")

    # ------------------------------------
    # ----- Load Experiment Config -----
    # ------------------------------------

    if hp_exp_name is not None:
        exp_path, exp_conf = read_exp_conf(hp_exp_name)
        if "n_search" in exp_conf and hp_search_nb is None:
            hp_search_nb = exp_conf["n_search"]

        assert (
            hp_search_nb is not None
        ), "n_search should be specified in a yaml file or from the command line"

        hps = resolve(exp_conf, hp_search_nb)
    else:
        hps = [None]

    # ---------------------------------
    # ----- Run All Experiments -----
    # ---------------------------------

    if summary_dir is not None:
        summary_dir.mkdir(exist_ok=True, parents=True)
    summary = None

    for hp_idx, hp in enumerate(hps):

        # copy shared values
        tmp_template_dict = template_dict.copy()
        tmp_train_args = train_args.copy()
        # split on the first "=" only so values may themselves contain "="
        tmp_train_args_dict = {
            arg.split("=", 1)[0]: arg.split("=", 1)[1] for arg in tmp_train_args
        }
        print_header(hp_idx)

        # override shared values with run-specific values for run hp_idx/n_search
        if hp is not None:
            for k, v in hp.items():
                if k == "resume" and resume is None:
                    resume = f'"{v}"'
                # hp-search params to ignore
                if k in hp_search_private:
                    continue

                if k == "codeloc":
                    v = escape_path(v)
                if k == "output":
                    Path(v).parent.mkdir(parents=True, exist_ok=True)

                # override template params depending on exp config
                if k in tmp_template_dict:
                    if template_dict[k] is None or is_sampled(k, exp_conf):
                        tmp_template_dict[k] = v
                # store sampled / specified params in current tmp_train_args_dict
                else:
                    if k in tmp_train_args_dict:
                        if is_sampled(k, exp_conf):
                            # warn if key was specified from the command line
                            tv = tmp_train_args_dict[k]
                            warn(
                                "\nWarning: overriding sampled config-file arg",
                                "{} to command-line value {}\n".format(k, tv),
                            )
                    else:
                        tmp_train_args_dict[k] = v

        # create sbatch file where required
        tmp_sbatch_path = None
        if sbatch_path == "hash":
            tmp_sbatch_name = "" if hp_exp_name is None else hp_exp_name[:14] + "_"
            tmp_sbatch_name += now() + ".sh"
            tmp_sbatch_path = sbatch_base / tmp_sbatch_name
            tmp_sbatch_path.parent.mkdir(parents=True, exist_ok=True)
            tmp_train_args_dict["sbatch_file"] = str(tmp_sbatch_path)
            tmp_train_args_dict["exp_file"] = str(exp_path)
        else:
            tmp_sbatch_path = Path(sbatch_path).resolve()

        summary = extend_summary(
            summary, tmp_train_args_dict, tmp_template_dict, exclude=["sbatch_file"]
        )

        # format train.py's args and crop floats' precision to 5 digits
        tmp_template_dict["train_args"] = " ".join(
            sorted(
                [
                    "{}={}".format(k, clean_arg(v))
                    for k, v in tmp_train_args_dict.items()
                ]
            )
        )
| if "resume.py" in template and resume is None: | |
| raise ValueError("No `resume` value but using a resume.py template") | |
| # format template with clean dict (replace None with "") | |
| sbatch = template.format( | |
| **{ | |
| k: v if v is not None else "" | |
| for k, v in tmp_template_dict.items() | |
| if k in template_dict | |
| } | |
| ) | |
| # -------------------------------------- | |
| # ----- Execute `sbatch` Command ----- | |
| # -------------------------------------- | |
| if not dev or force_sbatchs: | |
| if tmp_sbatch_path.exists(): | |
| print(f"Warning: overwriting {sbatch_path}") | |
| # write sbatch file | |
| with open(tmp_sbatch_path, "w") as f: | |
| f.write(sbatch) | |
| if not dev: | |
| # escape special characters such as " " from sbatch_path's parent dir | |
| parent = str(tmp_sbatch_path.parent) | |
| if escape: | |
| parent = escape_path(parent) | |
| # create command to execute in a subprocess | |
| command = "sbatch {}".format(tmp_sbatch_path.name) | |
| # execute sbatch command & store output | |
| command_output = subprocess.run( | |
| command.split(), stdout=subprocess.PIPE, cwd=parent | |
| ) | |
| command_output = "\n" + command_output.stdout.decode("utf-8") + "\n" | |
| print(f"Running from {parent}:") | |
| print(f"$ {command}") | |

        # ---------------------------------
        # ----- Summarize Execution -----
        # ---------------------------------

        if verbose:
            print(C.BEIGE + C.ITALIC, "\n" + sbatch + C.ENDC)
        if not dev:
            print_box(command_output.strip())
            jobID = parse_jobID(command_output.strip())
            summary["Slurm JOBID"].append(jobID)

        summary["Comet Link"].append(f"[{hp_idx}][{hp_idx}]")

        print(
            "{}{}Summary{} {}:".format(
                C.UNDERLINE,
                C.OKGREEN,
                C.ENDC,
                f"{C.WARNING}(DEV){C.ENDC}" if dev else "",
            )
        )
        print(
            " "
            + "\n ".join(
                "{:10}: {}".format(k, v) for k, v in tmp_template_dict.items()
            )
        )
        print_footer()

    print(f"\nRan a total of {len(hps)} jobs{' in dev mode.' if dev else '.'}\n")

    table, sum_path = search_summary_table(summary, summary_dir if not dev else None)
    if table is not None:
        print(table)
        print(
            "Add `[i]: https://...` at the end of a markdown document",
            "to fill in the comet links.\n",
        )
        if summary_dir is None:
            print("Add summary_dir=path to store the printed markdown table ⇪")
        else:
            print("Saved table in", str(sum_path))

    if not dev:
        print(
            "Cancel entire experiment? \n$ scancel",
            " ".join(map(str, summary["Slurm JOBID"])),
        )