experiments
Experiment configuration and runner APIs.
1"""Experiment configuration and runner APIs.""" 2 3from experiments.config import ExperimentConfig 4from experiments.configs import get_config, list_configs 5from experiments.defaults import default_policy, default_theta0 6from experiments.reporters import ( 7 ConsoleReporter, 8 FileStepLogger, 9 JsonReporter, 10 PlotReporter, 11 ReporterStack, 12 RunContext, 13 StepReporter, 14 create_run_context, 15) 16from experiments.results import EstimatorResult, ExperimentResult, OptimizationTrace 17from experiments.run import run_experiment 18from experiments.sweep_utils import ( 19 apply_config_overrides, 20 expand_override_grid, 21 generate_sweep_runs, 22 make_sweep_name, 23 run_preset_sweep, 24) 25 26__all__ = [ 27 "ExperimentConfig", 28 "get_config", 29 "list_configs", 30 "default_theta0", 31 "default_policy", 32 "EstimatorResult", 33 "ExperimentResult", 34 "OptimizationTrace", 35 "ConsoleReporter", 36 "FileStepLogger", 37 "JsonReporter", 38 "PlotReporter", 39 "ReporterStack", 40 "RunContext", 41 "StepReporter", 42 "create_run_context", 43 "run_experiment", 44 "expand_override_grid", 45 "apply_config_overrides", 46 "make_sweep_name", 47 "generate_sweep_runs", 48 "run_preset_sweep", 49]
@dataclass(frozen=True)
class ExperimentConfig:
    """Frozen configuration for a single experiment run with validation.

    All fields are validated in ``__post_init__``; several are normalized
    in place (via ``object.__setattr__``, the sanctioned escape hatch for
    frozen dataclasses). Invalid values raise ``ValueError`` at
    construction time.
    """

    state_dim: int  # dimensionality of each sampled state row
    n_samples: int  # number of synthetic state samples (ignored when x_fixed is set)
    step_rule: str  # must be a key of STEP_RULES
    objective: Objective  # must expose value(theta, x_batch) and grad(theta, x_batch)
    theta0: np.ndarray  # initial parameter vector; normalized to a 1D float array
    perturbation_space: Literal["theta", "u"]  # 'u' requires objective.policy
    batch_size: int | None = None  # minibatch size; None means full batch
    seed: int = 7
    t_steps: int = 100
    step_size: float = 0.01
    grad_norm_tol: Optional[float] = None  # early-stop on gradient norm, when set
    ftol: Optional[float] = None  # early-stop on objective change, when set
    sigma: float = 0.1  # perturbation scale for the gradient estimators
    n_grad_samples: int = 64
    verbose: bool = False
    plot: bool = True
    plot_dir: str = "plots"
    enabled_estimators: tuple[str, ...] = ("first_order", "gauss_stein")
    wandb_enabled: bool = False
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_group: str | None = None
    wandb_job_type: str = "experiment"
    wandb_tags: tuple[str, ...] = ()
    wandb_mode: Literal["online", "offline", "disabled"] = "online"
    wandb_log_plots: bool = True
    wandb_estimator_allowlist: tuple[str, ...] | None = None
    correctness: CorrectnessSpec = field(default_factory=CorrectnessSpec)
    x_fixed: np.ndarray | None = None  # real data rows; replaces sample_states when set

    def __post_init__(self) -> None:
        """Normalize aliases/arrays and validate every field.

        Raises:
            ValueError: on any invalid or inconsistent field value.
        """
        # Accept hyphenated estimator names (e.g. from CLI/config files).
        estimator_aliases = {
            "finite-difference": "finite_difference",
            "stein-difference": "stein_difference",
        }
        enabled_estimators = tuple(estimator_aliases.get(name, name) for name in self.enabled_estimators)
        # Frozen dataclass: write the normalized tuple back explicitly.
        object.__setattr__(self, "enabled_estimators", enabled_estimators)
        if not enabled_estimators:
            raise ValueError("enabled_estimators must include at least one estimator.")
        if len(set(enabled_estimators)) != len(enabled_estimators):
            raise ValueError("enabled_estimators must not contain duplicates.")
        allowed_estimators = {
            "first_order",
            "finite_difference",
            "gauss_stein",
            "spsa",
            "stein_difference",
        }
        unknown = [name for name in enabled_estimators if name not in allowed_estimators]
        if unknown:
            allowed = ", ".join(sorted(allowed_estimators))
            unknown_list = ", ".join(unknown)
            raise ValueError(f"Unknown estimators: {unknown_list}. Allowed: {allowed}.")

        if self.perturbation_space not in {"theta", "u"}:
            raise ValueError("perturbation_space must be 'theta' or 'u'.")

        # Normalize tags to a tuple (callers may pass lists).
        wandb_tags = tuple(self.wandb_tags)
        object.__setattr__(self, "wandb_tags", wandb_tags)
        if self.wandb_mode not in {"online", "offline", "disabled"}:
            raise ValueError("wandb_mode must be 'online', 'offline', or 'disabled'.")
        if self.wandb_enabled and self.wandb_mode == "disabled":
            raise ValueError("wandb_mode='disabled' is incompatible with wandb_enabled=True.")
        if self.wandb_estimator_allowlist is not None:
            # Same alias normalization and checks as enabled_estimators.
            wandb_allowlist = tuple(estimator_aliases.get(name, name) for name in self.wandb_estimator_allowlist)
            object.__setattr__(self, "wandb_estimator_allowlist", wandb_allowlist)
            if len(set(wandb_allowlist)) != len(wandb_allowlist):
                raise ValueError("wandb_estimator_allowlist must not contain duplicates.")
            unknown_wandb = [name for name in wandb_allowlist if name not in allowed_estimators]
            if unknown_wandb:
                allowed = ", ".join(sorted(allowed_estimators))
                unknown_list = ", ".join(unknown_wandb)
                raise ValueError(
                    f"Unknown wandb estimators: {unknown_list}. Allowed: {allowed}."
                )

        if self.perturbation_space == "u":
            # Action-space perturbations need a policy with value() and grad().
            policy = getattr(self.objective, "policy", None)
            if policy is None or not callable(getattr(policy, "value", None)) or not callable(getattr(policy, "grad", None)):
                raise ValueError(
                    "perturbation_space='u' requires objective.policy with value() and grad()."
                )

        if self.state_dim <= 0:
            raise ValueError("state_dim must be positive.")
        if self.n_samples <= 0:
            raise ValueError("n_samples must be positive.")

        # Coerce theta0 to a 1D float array and store the normalized copy.
        theta0_arr = np.asarray(self.theta0, dtype=float)
        if theta0_arr.ndim != 1 or theta0_arr.size < 1:
            raise ValueError("theta0 must be a 1D array with at least one element.")
        object.__setattr__(self, "theta0", theta0_arr)

        if self.x_fixed is not None:
            # Fixed data must be a (n_rows, state_dim) float matrix.
            x_fixed_arr = np.asarray(self.x_fixed, dtype=float)
            if x_fixed_arr.ndim != 2:
                raise ValueError("x_fixed must be a 2D array of shape (n_rows, state_dim).")
            if x_fixed_arr.shape[1] != self.state_dim:
                raise ValueError(
                    f"x_fixed has {x_fixed_arr.shape[1]} columns but state_dim={self.state_dim}."
                )
            object.__setattr__(self, "x_fixed", x_fixed_arr)

        if self.batch_size is not None:
            if self.batch_size <= 0:
                raise ValueError("batch_size must be positive when provided.")
            if self.batch_size > self.n_samples:
                raise ValueError("batch_size must be <= n_samples when provided.")

        if self.step_rule not in STEP_RULES:
            allowed = ", ".join(sorted(STEP_RULES))
            raise ValueError(f"step_rule must be one of {allowed}.")
        if self.step_size <= 0.0:
            raise ValueError("step_size must be positive.")
        if self.grad_norm_tol is not None and self.grad_norm_tol <= 0.0:
            raise ValueError("grad_norm_tol must be positive when provided.")
        if self.ftol is not None and self.ftol <= 0.0:
            raise ValueError("ftol must be positive when provided.")
        if self.n_grad_samples <= 0:
            raise ValueError("n_grad_samples must be positive.")

        # Duck-typed objective contract: value() and grad() must be callable.
        objective = self.objective
        value_fn = getattr(objective, "value", None)
        grad_fn = getattr(objective, "grad", None)
        if value_fn is None or not callable(value_fn):
            raise ValueError("objective must implement value(theta, x_batch).")
        if grad_fn is None or not callable(grad_fn):
            raise ValueError("objective must implement grad(theta, x_batch).")

        policy = getattr(objective, "policy", None)
        if policy is not None:
            policy_value = getattr(policy, "value", None)
            policy_grad = getattr(policy, "grad", None)
            if not callable(policy_value) or not callable(policy_grad):
                raise ValueError("policy must implement value(theta, x_batch) and grad(theta, x_batch).")
            # Probe with a single-sample batch
            x_probe = np.zeros((1, self.state_dim), dtype=float)
            u_probe_arr = np.asarray(policy_value(theta0_arr, x_probe), dtype=float)
            if not bool(np.isfinite(u_probe_arr).all()):
                raise ValueError("policy.value(theta0, x_batch) must be finite.")
            grad_probe = np.asarray(policy_grad(theta0_arr, x_probe), dtype=float)
            if grad_probe.ndim != 2 or grad_probe.shape[1] != theta0_arr.size:
                raise ValueError("policy.grad(theta0, x_batch) must return (n_samples, theta_dim).")

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to dictionary for JSON output."""
        return {
            "state_dim": int(self.state_dim),
            "n_samples": int(self.n_samples),
            "batch_size": int(self.batch_size) if self.batch_size is not None else None,
            "step_rule": self.step_rule,
            "seed": int(self.seed),
            "t_steps": int(self.t_steps),
            "step_size": float(self.step_size),
            "grad_norm_tol": float(self.grad_norm_tol)
            if self.grad_norm_tol is not None
            else None,
            "ftol": float(self.ftol) if self.ftol is not None else None,
            "sigma": float(self.sigma),
            "n_grad_samples": int(self.n_grad_samples),
            "verbose": bool(self.verbose),
            "plot": bool(self.plot),
            "plot_dir": self.plot_dir,
            "enabled_estimators": list(self.enabled_estimators),
            "perturbation_space": self.perturbation_space,
            "theta0": _as_list(self.theta0),
            "objective": _objective_to_dict(self.objective),
            "wandb": {
                "enabled": bool(self.wandb_enabled),
                "project": self.wandb_project,
                "entity": self.wandb_entity,
                "group": self.wandb_group,
                "job_type": self.wandb_job_type,
                "tags": list(self.wandb_tags),
                "mode": self.wandb_mode,
                "log_plots": bool(self.wandb_log_plots),
                "estimator_allowlist": list(self.wandb_estimator_allowlist)
                if self.wandb_estimator_allowlist is not None
                else None,
            },
            "correctness": _correctness_to_dict(self.correctness),
            # Only the shape is serialized; the data itself may be large.
            "x_fixed_shape": list(self.x_fixed.shape) if self.x_fixed is not None else None,
        }
Frozen configuration for a single experiment run with validation.
197 def to_dict(self) -> dict[str, Any]: 198 """Serialize config to dictionary for JSON output.""" 199 return { 200 "state_dim": int(self.state_dim), 201 "n_samples": int(self.n_samples), 202 "batch_size": int(self.batch_size) if self.batch_size is not None else None, 203 "step_rule": self.step_rule, 204 "seed": int(self.seed), 205 "t_steps": int(self.t_steps), 206 "step_size": float(self.step_size), 207 "grad_norm_tol": float(self.grad_norm_tol) 208 if self.grad_norm_tol is not None 209 else None, 210 "ftol": float(self.ftol) if self.ftol is not None else None, 211 "sigma": float(self.sigma), 212 "n_grad_samples": int(self.n_grad_samples), 213 "verbose": bool(self.verbose), 214 "plot": bool(self.plot), 215 "plot_dir": self.plot_dir, 216 "enabled_estimators": list(self.enabled_estimators), 217 "perturbation_space": self.perturbation_space, 218 "theta0": _as_list(self.theta0), 219 "objective": _objective_to_dict(self.objective), 220 "wandb": { 221 "enabled": bool(self.wandb_enabled), 222 "project": self.wandb_project, 223 "entity": self.wandb_entity, 224 "group": self.wandb_group, 225 "job_type": self.wandb_job_type, 226 "tags": list(self.wandb_tags), 227 "mode": self.wandb_mode, 228 "log_plots": bool(self.wandb_log_plots), 229 "estimator_allowlist": list(self.wandb_estimator_allowlist) 230 if self.wandb_estimator_allowlist is not None 231 else None, 232 }, 233 "correctness": _correctness_to_dict(self.correctness), 234 "x_fixed_shape": list(self.x_fixed.shape) if self.x_fixed is not None else None, 235 }
Serialize config to dictionary for JSON output.
def get_config(name: str) -> ExperimentConfig:
    """Return the preset ``ExperimentConfig`` registered under ``name``.

    Config modules are imported lazily on first access and memoized in
    ``_CONFIG_CACHE`` so repeated lookups do not re-import.

    Raises:
        ValueError: if ``name`` is not a known preset.
    """
    try:
        module_name = _CONFIG_MODULES[name]
    except KeyError as exc:
        choices = ", ".join(sorted(_CONFIG_MODULES.keys()))
        raise ValueError(f"Unknown experiment config '{name}'. Available: {choices}.") from exc

    try:
        return _CONFIG_CACHE[name]
    except KeyError:
        config = import_module(module_name).CONFIG
        _CONFIG_CACHE[name] = config
        return config
def default_theta0(state_dim: int) -> np.ndarray:
    """Return the default initial theta for a policy with the given state dimension.

    Layout: one leading entry of 0.1 followed by ``state_dim`` entries of 0.01,
    giving a float vector of length ``state_dim + 1``.
    """
    theta = np.full(state_dim + 1, 0.01, dtype=float)
    theta[0] = 0.1
    return theta
Return default initial theta for a policy with given state dimension.
def default_policy(state_dim: int = 1) -> SoftmaxPolicy:
    """Return the default policy: a freshly constructed ``SoftmaxPolicy``.

    ``state_dim`` is accepted for interface symmetry with ``default_theta0``;
    the softmax policy construction itself takes no arguments here.
    """
    return SoftmaxPolicy()
Return default softmax policy.
@dataclass(frozen=True)
class EstimatorResult:
    """Final result for one estimator: theta, mean action, objective value, and wall time."""

    theta: np.ndarray  # final parameter vector after optimization
    u: float  # mean action under the policy at the final theta (NaN when no policy)
    value: float  # objective value at the final theta
    time: float  # wall-clock seconds this estimator's run took
Final result for one estimator: theta, mean action, objective value, and wall time.
@dataclass(frozen=True)
class ExperimentResult:
    """Full experiment result: config, samples, traces, and final values per estimator."""

    config: ExperimentConfig  # configuration the run was produced from
    x_samples: np.ndarray  # Shape (n_samples, state_dim)
    initial_value: float  # objective value at config.theta0 before optimization
    results: Mapping[str, EstimatorResult]  # estimator name -> final result
    traces: Mapping[str, OptimizationTrace]  # estimator name -> per-step trace
    u_star: Optional[float] = None  # known optimal action, when the objective exposes one
    value_at_u_star: Optional[float] = None  # objective value at u_star, when computable
Full experiment result: config, samples, traces, and final values per estimator.
@dataclass(frozen=True)
class OptimizationTrace:
    """Per-step trace: u values, objective values, gradient norms, and theta history."""

    steps: Sequence[int]  # step indices
    u_values: Sequence[float]  # action value recorded at each step
    objective_values: Sequence[float]  # objective value at each step
    u_grad_estimates: Sequence[float]  # estimated gradient w.r.t. u at each step
    u_true_gradients: Optional[Sequence[float]] = None  # analytic u-gradients, when available
    theta_grad_norms: Optional[Sequence[float]] = None  # norms of the estimated theta gradient
    true_theta_grad_norms: Optional[Sequence[float]] = None  # norms of the analytic theta gradient
    step_sizes: Optional[Sequence[float]] = None  # per-step sizes (plotted for the Armijo rule)
    theta_values: Optional[Sequence[np.ndarray]] = None  # theta iterates, when recorded
    optimizer_status: Optional[int] = None  # optimizer status code, if one was reported
    optimizer_message: Optional[str] = None  # optimizer status message, if one was reported
Per-step trace: u values, objective values, gradient norms, and theta history.
class ConsoleReporter:
    """Terminal reporter: start banner, end summary, and optional per-step lines.

    Per-step output is produced only when constructed with ``verbose=True``.
    """

    def __init__(self, verbose: bool = False) -> None:
        # Controls whether log_step emits anything at all.
        self._verbose = verbose

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        """Print a banner naming the experiment."""
        print(f"\n=== Running experiment: {run_context.experiment_name} ===")

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        """Print the final run summary."""
        log_summary(result)

    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        """Forward one optimization step to the console, verbose mode only."""
        if not self._verbose:
            return
        log_step(method, step, u, value, grad_norm, step_size)
Reporter that prints to terminal. Verbose mode controls per-step output.
class FileStepLogger:
    """Streams per-step optimization metrics to ``steps.csv`` in the run directory."""

    def __init__(self) -> None:
        # Both stay None outside an on_start/on_end window.
        self._file: IO[str] | None = None
        self._path: Path | None = None

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        """Open ``steps.csv`` under the run directory and write the header row."""
        self._path = run_context.run_dir / "steps.csv"
        handle = self._path.open("w", encoding="utf-8")
        handle.write("method,step,u,value,grad_norm,step_size\n")
        self._file = handle

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        """Close the CSV file if it is currently open."""
        handle, self._file = self._file, None
        if handle is not None:
            handle.close()

    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        """Append one CSV row; optional metrics render as empty cells.

        Silently drops the step when called outside an on_start/on_end window.
        """
        if self._file is None:
            return
        grad_cell = "" if grad_norm is None else f"{grad_norm:.6f}"
        step_cell = "" if step_size is None else f"{step_size:.6f}"
        self._file.write(f"{method},{step},{u:.6f},{value:.6f},{grad_cell},{step_cell}\n")
Writes per-step metrics to a CSV file in the run directory.
    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        """Append one CSV row for an optimization step.

        Silently drops the step when the file is not open (i.e. outside an
        on_start/on_end window).
        """
        if self._file is None:
            return
        # Optional metrics serialize as empty cells, not "None".
        grad_str = f"{grad_norm:.6f}" if grad_norm is not None else ""
        step_str = f"{step_size:.6f}" if step_size is not None else ""
        self._file.write(f"{method},{step},{u:.6f},{value:.6f},{grad_str},{step_str}\n")
class JsonReporter:
    """Writes a ``summary.json`` for the run when the experiment finishes."""

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        """No-op; this reporter only acts at the end of a run."""
        return None

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        """Serialize the run summary to ``<run_dir>/summary.json``."""
        summary = _build_summary_payload(run_context, result)
        target = run_context.run_dir / "summary.json"
        # Deterministic, ASCII-safe output for easy diffing and tooling.
        with target.open("w", encoding="utf-8") as fh:
            json.dump(summary, fh, indent=2, sort_keys=True, ensure_ascii=True)
146 def on_end(self, run_context: RunContext, result: ExperimentResult) -> None: 147 payload = _build_summary_payload(run_context, result) 148 summary_path = run_context.run_dir / "summary.json" 149 with summary_path.open("w", encoding="utf-8") as handle: 150 json.dump(payload, handle, indent=2, sort_keys=True, ensure_ascii=True)
class PlotReporter:
    """Renders the standard plot suite for a finished run into the plots directory."""

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        """No-op; plotting happens only at the end of a run."""
        return None

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        """Produce loss, gradient-norm, action-slice, step-size, and contour plots.

        Does nothing when plotting is disabled or no traces were recorded.
        """
        config = result.config
        if not config.plot or not result.traces:
            return
        run_context.plots_dir.mkdir(parents=True, exist_ok=True)
        plot_dir = str(run_context.plots_dir)
        objective = config.objective
        action_objective = getattr(objective, "action_objective", None)
        traces = result.traces
        u_star_plot = _u_star_for_plot(action_objective, result.u_star)
        plot_loss_curves(
            traces,
            plot_dir,
            u_star=u_star_plot,
        )
        plot_gradient_norms(traces, plot_dir)
        if action_objective is not None:
            # 1D slice of the objective along the action axis.
            plot_objective_u_slice(
                result.x_samples,
                action_objective,
                traces,
                plot_dir,
                u_star=u_star_plot,
            )
        if config.step_rule == STEP_RULE_ARMIJO:
            # Step sizes only vary per-step under the Armijo rule.
            plot_step_sizes(traces, plot_dir)
        if config.theta0.size >= 2:
            # Contour plots need at least two theta axes.
            axis_indices = (0, 1)
            axis_labels = None
            theta_path_points = [config.theta0]
            for trace in traces.values():
                if trace.theta_values:
                    theta_path_points.extend(trace.theta_values)
            if config.theta0.size > 2 and theta_path_points:
                # In >2D, project onto the two axes with the greatest
                # variance along the recorded optimization paths.
                axis_indices = select_theta_axes_max_variance(theta_path_points)
                axis_labels = (
                    f"theta[{axis_indices[0]}] (max-var axis)",
                    f"theta[{axis_indices[1]}] (max-var axis)",
                )
            # Preserve the config's estimator order for markers/legend.
            ordered_results = [
                (name, result.results[name])
                for name in config.enabled_estimators
                if name in result.results
            ]
            theta_refs = [config.theta0]
            theta_points = [(config.theta0, "initial", "#636363", "o")]
            for name, estimator_result in ordered_results:
                theta_refs.append(estimator_result.theta)
                style = ESTIMATOR_STYLES[name]
                theta_points.append(
                    (
                        estimator_result.theta,
                        style["label"],
                        style["color"],
                        style["marker"],
                    )
                )
            plot_theta_objective_contours(
                result.x_samples,
                objective,
                config.theta0,
                plot_dir,
                axis_indices=axis_indices,
                axis_labels=axis_labels,
                theta_refs=theta_refs,
                theta_points=theta_points,
                traces=traces,
            )
    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        """Produce loss, gradient-norm, action-slice, step-size, and contour plots.

        Does nothing when plotting is disabled or no traces were recorded.
        """
        config = result.config
        if not config.plot or not result.traces:
            return
        run_context.plots_dir.mkdir(parents=True, exist_ok=True)
        plot_dir = str(run_context.plots_dir)
        objective = config.objective
        action_objective = getattr(objective, "action_objective", None)
        traces = result.traces
        u_star_plot = _u_star_for_plot(action_objective, result.u_star)
        plot_loss_curves(
            traces,
            plot_dir,
            u_star=u_star_plot,
        )
        plot_gradient_norms(traces, plot_dir)
        if action_objective is not None:
            # 1D slice of the objective along the action axis.
            plot_objective_u_slice(
                result.x_samples,
                action_objective,
                traces,
                plot_dir,
                u_star=u_star_plot,
            )
        if config.step_rule == STEP_RULE_ARMIJO:
            # Step sizes only vary per-step under the Armijo rule.
            plot_step_sizes(traces, plot_dir)
        if config.theta0.size >= 2:
            # Contour plots need at least two theta axes.
            axis_indices = (0, 1)
            axis_labels = None
            theta_path_points = [config.theta0]
            for trace in traces.values():
                if trace.theta_values:
                    theta_path_points.extend(trace.theta_values)
            if config.theta0.size > 2 and theta_path_points:
                # In >2D, project onto the two axes with the greatest
                # variance along the recorded optimization paths.
                axis_indices = select_theta_axes_max_variance(theta_path_points)
                axis_labels = (
                    f"theta[{axis_indices[0]}] (max-var axis)",
                    f"theta[{axis_indices[1]}] (max-var axis)",
                )
            # Preserve the config's estimator order for markers/legend.
            ordered_results = [
                (name, result.results[name])
                for name in config.enabled_estimators
                if name in result.results
            ]
            theta_refs = [config.theta0]
            theta_points = [(config.theta0, "initial", "#636363", "o")]
            for name, estimator_result in ordered_results:
                theta_refs.append(estimator_result.theta)
                style = ESTIMATOR_STYLES[name]
                theta_points.append(
                    (
                        estimator_result.theta,
                        style["label"],
                        style["color"],
                        style["marker"],
                    )
                )
            plot_theta_objective_contours(
                result.x_samples,
                objective,
                config.theta0,
                plot_dir,
                axis_indices=axis_indices,
                axis_labels=axis_labels,
                theta_refs=theta_refs,
                theta_points=theta_points,
                traces=traces,
            )
class ReporterStack:
    """Fans out reporter callbacks to an ordered collection of reporters.

    Lifecycle hooks go to every reporter, in order; per-step logging goes
    only to members satisfying the ``StepReporter`` protocol.
    """

    def __init__(self, reporters: Sequence[Reporter]) -> None:
        # Copy so later mutation of the caller's sequence has no effect.
        self._reporters = list(reporters)

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        """Notify every reporter that the run is starting."""
        for member in self._reporters:
            member.on_start(run_context, config)

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        """Notify every reporter that the run has finished."""
        for member in self._reporters:
            member.on_end(run_context, result)

    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        """Forward one step to each reporter that implements log_step."""
        for member in self._reporters:
            if not isinstance(member, StepReporter):
                continue
            member.log_step(method, step, u, value, grad_norm, step_size)
    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        """Fan one optimization step out to every StepReporter in the stack."""
        for reporter in self._reporters:
            # Lifecycle-only reporters (no log_step) are skipped.
            if isinstance(reporter, StepReporter):
                reporter.log_step(method, step, u, value, grad_norm, step_size)
@dataclass(frozen=True)
class RunContext:
    """Identifies one run on disk: its name, id, directories, and start time."""

    experiment_name: str  # human-readable name as supplied by the caller
    run_id: str  # timestamp-derived identifier (YYYYMMDD_HHMMSS)
    run_dir: Path  # root directory for this run's outputs
    plots_dir: Path  # run_dir / "plots"; created by plotting reporters, not here
    started_at: datetime  # timestamp the run id was derived from
@runtime_checkable
class StepReporter(Protocol):
    """Protocol for per-step metric logging during optimization.

    Marked runtime_checkable so reporter stacks can filter members with
    isinstance() at dispatch time.
    """

    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        """Record metrics for one optimization step of ``method``."""
        ...
Protocol for per-step metric logging during optimization.
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder ``__init__`` installed on Protocol subclasses.

    NOTE(review): this matches CPython's private ``typing._no_init_or_replace_init``
    verbatim — confirm it is intentionally vendored here rather than relying on
    the stdlib implementation.
    """
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        cls.__init__ = object.__init__

    cls.__init__(self, *args, **kwargs)
def create_run_context(
    experiment_name: str,
    runs_root: str = "outputs",
    started_at: datetime | None = None,
) -> RunContext:
    """Create a timestamped run directory and return its ``RunContext``.

    The run id is the start timestamp formatted ``YYYYMMDD_HHMMSS``. The run
    directory is created here; the plots directory path is derived but left
    for plotting reporters to create.
    """
    stamp = started_at or datetime.now()
    run_id = stamp.strftime("%Y%m%d_%H%M%S")
    base = Path(runs_root) / _sanitize_name(experiment_name) / run_id
    base.mkdir(parents=True, exist_ok=True)
    return RunContext(
        experiment_name=experiment_name,
        run_id=run_id,
        run_dir=base,
        plots_dir=base / "plots",
        started_at=stamp,
    )
def run_experiment(
    config: ExperimentConfig,
    step_reporter: StepReporter | None = None,
) -> ExperimentResult:
    """Run optimization with every enabled estimator; return traces and final values.

    Args:
        config: Validated experiment configuration.
        step_reporter: Optional per-step metric sink passed through to each
            estimator's optimization loop.

    Returns:
        ExperimentResult holding the state samples, the initial objective
        value, and per-estimator final results and traces.
    """
    objective = config.objective

    rng = default_rng(config.seed)
    if config.x_fixed is not None:
        # Real data rows supplied directly; skip synthetic sampling.
        x_samples = np.asarray(config.x_fixed, dtype=float)
    else:
        x_samples = sample_states(rng, config.n_samples, config.state_dim)
    true_grad_theta_fn = resolve_true_grad_theta_fn(objective, config.correctness)

    theta_initial = np.asarray(config.theta0, dtype=float)
    initial_value = float(objective.value(theta_initial, x_samples))

    # Optimal action and its objective value, when the objective exposes them.
    u_star = optimal_u(objective)
    value_at_u_star = None
    if u_star is not None:
        try:
            value_at_u_star = _action_value_at_u(objective, x_samples, u_star)
        except ValueError:
            # Best-effort: objectives without an action slice simply skip this.
            pass

    # Policy (if any) is used to report the mean action at the final theta.
    policy = getattr(objective, "policy", None)

    # Canonical estimator order. Iterating this table — not
    # config.enabled_estimators — preserves the historical execution order,
    # which matters because all estimators consume the same RNG stream.
    runners = {
        "first_order": run_first_order,
        "finite_difference": run_finite_difference,
        "gauss_stein": run_gauss_stein,
        "spsa": run_spsa,
        "stein_difference": run_stein_difference,
    }
    enabled = set(config.enabled_estimators)

    results: dict[str, EstimatorResult] = {}
    traces: dict[str, OptimizationTrace] = {}
    for name, runner in runners.items():
        if name not in enabled:
            continue
        start = time.perf_counter()
        theta_final, trace = runner(
            theta_initial,
            x_samples,
            objective,
            rng,
            config.t_steps,
            config.step_rule,
            config.step_size,
            config.n_grad_samples,
            config.sigma,
            config.batch_size,
            perturbation_space=config.perturbation_space,
            true_grad_theta_fn=true_grad_theta_fn,
            grad_norm_tol=config.grad_norm_tol,
            ftol=config.ftol,
            step_reporter=step_reporter,
        )
        elapsed = time.perf_counter() - start
        # Mean action is only defined when the objective carries a policy.
        u_final = (
            _mean_action(policy, theta_final, x_samples)
            if policy is not None
            else float("nan")
        )
        results[name] = EstimatorResult(
            theta=theta_final,
            u=u_final,
            value=float(objective.value(theta_final, x_samples)),
            time=elapsed,
        )
        traces[name] = trace

    return ExperimentResult(
        config=config,
        x_samples=x_samples,
        initial_value=initial_value,
        results=results,
        traces=traces,
        u_star=u_star,
        value_at_u_star=value_at_u_star,
    )
Run optimization with all enabled estimators; returns traces and final values.
def expand_override_grid(grid: Mapping[str, Sequence[Any]]) -> list[dict[str, Any]]:
    """Expand a field->values grid into cartesian-product override dicts.

    An empty grid yields ``[{}]``: a single run with no overrides.
    """
    if not grid:
        return [{}]
    keys = tuple(grid)
    combos = product(*(list(grid[key]) for key in keys))
    return [dict(zip(keys, combo)) for combo in combos]
Build cartesian-product override dictionaries from a field-value grid.
def apply_config_overrides(config: ExperimentConfig, overrides: Mapping[str, Any]) -> ExperimentConfig:
    """Return a copy of ``config`` with the given top-level fields replaced.

    Raises:
        ValueError: if any override key is not an ``ExperimentConfig`` field.
    """
    known = {f.name for f in fields(ExperimentConfig)}
    bad = sorted(set(overrides) - known)
    if bad:
        raise ValueError(f"Unknown config override fields: {', '.join(bad)}.")
    return replace(config, **dict(overrides))
Return a config copy with top-level ExperimentConfig fields overridden.
def make_sweep_name(base_name: str, index: int, overrides: Mapping[str, Any]) -> str:
    """Build a deterministic run name: ``<base>__sweep_<idx>[__k-v__k-v…]``.

    Override keys are sorted so the same overrides always yield the same name.
    """
    stem = f"{base_name}__sweep_{index:03d}"
    if not overrides:
        return stem
    rendered = "__".join(
        f"{key}-{_stringify_override_value(overrides[key])}"
        for key in sorted(overrides)
    )
    return f"{stem}__{rendered}"
Build a readable, deterministic run name for one sweep variant.
def generate_sweep_runs(
    *,
    base_preset: str,
    override_grid: Mapping[str, Sequence[Any]] | None = None,
    override_list: Sequence[Mapping[str, Any]] | None = None,
    display_keys: Sequence[str] | None = None,
) -> list[tuple[str, ExperimentConfig, dict[str, Any]]]:
    """Build ``(run_name, config, overrides)`` triples from a base preset.

    At most one of ``override_grid`` / ``override_list`` may be given; with
    neither, a single unmodified run is produced.

    Raises:
        ValueError: if both override sources are provided.
    """
    if override_grid is not None and override_list is not None:
        raise ValueError("Specify either override_grid or override_list, not both.")

    if override_list is not None:
        # An explicitly empty list still yields one run with no overrides.
        override_dicts = [dict(item) for item in override_list] or [{}]
    elif override_grid is not None:
        override_dicts = expand_override_grid(override_grid)
    else:
        override_dicts = [{}]

    base_config = get_config(base_preset)
    runs: list[tuple[str, ExperimentConfig, dict[str, Any]]] = []
    for idx, override in enumerate(override_dicts, start=1):
        variant = apply_config_overrides(base_config, override)
        name = make_display_name(
            base_preset,
            index=idx,
            overrides=override,
            display_keys=display_keys,
        )
        runs.append((name, variant, override))
    return runs
Generate named configs by applying overrides to a base preset.
def run_preset_sweep(
    *,
    base_preset: str,
    override_grid: Mapping[str, Sequence[Any]] | None = None,
    override_list: Sequence[Mapping[str, Any]] | None = None,
    runs_root: str = "outputs",
    project_name: str | None = None,
    display_keys: Sequence[str] | None = None,
) -> list[tuple[str, ExperimentResult]]:
    """Execute a preset sweep and return `(run_name, result)` pairs.

    Each variant gets its own run context and a fresh reporter stack
    (console, CSV step log, JSON summary, plots, plus wandb when enabled).
    """
    sweep_runs = generate_sweep_runs(
        base_preset=base_preset,
        override_grid=override_grid,
        override_list=override_list,
        display_keys=display_keys,
    )
    root = _project_runs_root(runs_root, project_name)
    outcomes: list[tuple[str, ExperimentResult]] = []

    for run_name, config, _overrides in sweep_runs:
        context = create_run_context(run_name, runs_root=root)
        members = [
            ConsoleReporter(verbose=config.verbose),
            FileStepLogger(),
            JsonReporter(),
            PlotReporter(),
        ]
        if config.wandb_enabled:
            members.append(WandbReporter())
        stack = ReporterStack(members)
        stack.on_start(context, config)
        run_result = run_experiment(config, step_reporter=stack)
        stack.on_end(context, run_result)
        outcomes.append((run_name, run_result))

    return outcomes
Execute a preset sweep and return (run_name, result) pairs.