experiments

Experiment configuration and runner APIs.

"""Experiment configuration and runner APIs."""

# Re-exports forming the public surface of the `experiments` package.
from experiments.config import ExperimentConfig
from experiments.configs import get_config, list_configs
from experiments.defaults import default_policy, default_theta0
from experiments.reporters import (
    ConsoleReporter,
    FileStepLogger,
    JsonReporter,
    PlotReporter,
    ReporterStack,
    RunContext,
    StepReporter,
    create_run_context,
)
from experiments.results import EstimatorResult, ExperimentResult, OptimizationTrace
from experiments.run import run_experiment
from experiments.sweep_utils import (
    apply_config_overrides,
    expand_override_grid,
    generate_sweep_runs,
    make_sweep_name,
    run_preset_sweep,
)

# Explicit public API, grouped roughly by submodule of origin.
__all__ = [
    "ExperimentConfig",
    "get_config",
    "list_configs",
    "default_theta0",
    "default_policy",
    "EstimatorResult",
    "ExperimentResult",
    "OptimizationTrace",
    "ConsoleReporter",
    "FileStepLogger",
    "JsonReporter",
    "PlotReporter",
    "ReporterStack",
    "RunContext",
    "StepReporter",
    "create_run_context",
    "run_experiment",
    "expand_override_grid",
    "apply_config_overrides",
    "make_sweep_name",
    "generate_sweep_runs",
    "run_preset_sweep",
]
@dataclass(frozen=True)
class ExperimentConfig:
    """Frozen configuration for a single experiment run with validation.

    All fields are validated (and some normalized in place via
    ``object.__setattr__``, since the dataclass is frozen) in
    ``__post_init__``; construction raises ``ValueError`` on any invalid
    setting, so a successfully built instance is always self-consistent.
    """

    # Problem definition.
    state_dim: int  # dimensionality of each state sample row; must be > 0
    n_samples: int  # number of state samples; must be > 0
    step_rule: str  # must be a key of STEP_RULES (validated below)
    objective: Objective  # must expose callable value(theta, x_batch) and grad(theta, x_batch)
    theta0: np.ndarray  # initial parameters; coerced to a float 1D array with >= 1 element
    perturbation_space: Literal["theta", "u"]  # 'u' additionally requires objective.policy
    batch_size: int | None = None  # minibatch size; None means no minibatching; must be <= n_samples
    seed: int = 7
    t_steps: int = 100
    step_size: float = 0.01  # must be > 0
    grad_norm_tol: Optional[float] = None  # must be > 0 when provided
    ftol: Optional[float] = None  # must be > 0 when provided
    sigma: float = 0.1  # presumably the estimator perturbation scale — confirm in estimator code
    n_grad_samples: int = 64  # must be > 0
    verbose: bool = False
    plot: bool = True
    plot_dir: str = "plots"
    enabled_estimators: tuple[str, ...] = ("first_order", "gauss_stein")
    # Weights & Biases logging options.
    wandb_enabled: bool = False
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_group: str | None = None
    wandb_job_type: str = "experiment"
    wandb_tags: tuple[str, ...] = ()
    wandb_mode: Literal["online", "offline", "disabled"] = "online"
    wandb_log_plots: bool = True
    wandb_estimator_allowlist: tuple[str, ...] | None = None  # None means log all estimators
    correctness: CorrectnessSpec = field(default_factory=CorrectnessSpec)
    x_fixed: np.ndarray | None = None  # real data rows; replaces sample_states when set

    def __post_init__(self) -> None:
        """Normalize aliases and array fields, then validate every setting.

        Raises:
            ValueError: if any field fails validation.
        """
        # Accept hyphenated estimator spellings and normalize them to the
        # canonical underscore names before any further checks.
        estimator_aliases = {
            "finite-difference": "finite_difference",
            "stein-difference": "stein_difference",
        }
        enabled_estimators = tuple(estimator_aliases.get(name, name) for name in self.enabled_estimators)
        # Frozen dataclass: normalized values must be written via object.__setattr__.
        object.__setattr__(self, "enabled_estimators", enabled_estimators)
        if not enabled_estimators:
            raise ValueError("enabled_estimators must include at least one estimator.")
        if len(set(enabled_estimators)) != len(enabled_estimators):
            raise ValueError("enabled_estimators must not contain duplicates.")
        allowed_estimators = {
            "first_order",
            "finite_difference",
            "gauss_stein",
            "spsa",
            "stein_difference",
        }
        unknown = [name for name in enabled_estimators if name not in allowed_estimators]
        if unknown:
            allowed = ", ".join(sorted(allowed_estimators))
            unknown_list = ", ".join(unknown)
            raise ValueError(f"Unknown estimators: {unknown_list}. Allowed: {allowed}.")

        if self.perturbation_space not in {"theta", "u"}:
            raise ValueError("perturbation_space must be 'theta' or 'u'.")

        # Normalize wandb_tags to a tuple (callers may pass any iterable).
        wandb_tags = tuple(self.wandb_tags)
        object.__setattr__(self, "wandb_tags", wandb_tags)
        if self.wandb_mode not in {"online", "offline", "disabled"}:
            raise ValueError("wandb_mode must be 'online', 'offline', or 'disabled'.")
        if self.wandb_enabled and self.wandb_mode == "disabled":
            raise ValueError("wandb_mode='disabled' is incompatible with wandb_enabled=True.")
        if self.wandb_estimator_allowlist is not None:
            # The allowlist accepts the same hyphenated aliases as enabled_estimators.
            wandb_allowlist = tuple(estimator_aliases.get(name, name) for name in self.wandb_estimator_allowlist)
            object.__setattr__(self, "wandb_estimator_allowlist", wandb_allowlist)
            if len(set(wandb_allowlist)) != len(wandb_allowlist):
                raise ValueError("wandb_estimator_allowlist must not contain duplicates.")
            unknown_wandb = [name for name in wandb_allowlist if name not in allowed_estimators]
            if unknown_wandb:
                allowed = ", ".join(sorted(allowed_estimators))
                unknown_list = ", ".join(unknown_wandb)
                raise ValueError(
                    f"Unknown wandb estimators: {unknown_list}. Allowed: {allowed}."
                )

        # Perturbing in action space needs a differentiable policy on the objective.
        if self.perturbation_space == "u":
            policy = getattr(self.objective, "policy", None)
            if policy is None or not callable(getattr(policy, "value", None)) or not callable(getattr(policy, "grad", None)):
                raise ValueError(
                    "perturbation_space='u' requires objective.policy with value() and grad()."
                )

        if self.state_dim <= 0:
            raise ValueError("state_dim must be positive.")
        if self.n_samples <= 0:
            raise ValueError("n_samples must be positive.")

        # Coerce theta0 to a float 1D ndarray so downstream code can rely on it.
        theta0_arr = np.asarray(self.theta0, dtype=float)
        if theta0_arr.ndim != 1 or theta0_arr.size < 1:
            raise ValueError("theta0 must be a 1D array with at least one element.")
        object.__setattr__(self, "theta0", theta0_arr)

        if self.x_fixed is not None:
            # Fixed data must be a 2D float array whose columns match state_dim.
            x_fixed_arr = np.asarray(self.x_fixed, dtype=float)
            if x_fixed_arr.ndim != 2:
                raise ValueError("x_fixed must be a 2D array of shape (n_rows, state_dim).")
            if x_fixed_arr.shape[1] != self.state_dim:
                raise ValueError(
                    f"x_fixed has {x_fixed_arr.shape[1]} columns but state_dim={self.state_dim}."
                )
            object.__setattr__(self, "x_fixed", x_fixed_arr)

        if self.batch_size is not None:
            if self.batch_size <= 0:
                raise ValueError("batch_size must be positive when provided.")
            if self.batch_size > self.n_samples:
                raise ValueError("batch_size must be <= n_samples when provided.")

        if self.step_rule not in STEP_RULES:
            allowed = ", ".join(sorted(STEP_RULES))
            raise ValueError(f"step_rule must be one of {allowed}.")
        if self.step_size <= 0.0:
            raise ValueError("step_size must be positive.")
        if self.grad_norm_tol is not None and self.grad_norm_tol <= 0.0:
            raise ValueError("grad_norm_tol must be positive when provided.")
        if self.ftol is not None and self.ftol <= 0.0:
            raise ValueError("ftol must be positive when provided.")
        if self.n_grad_samples <= 0:
            raise ValueError("n_grad_samples must be positive.")

        # Structural (duck-typed) check of the objective interface.
        objective = self.objective
        value_fn = getattr(objective, "value", None)
        grad_fn = getattr(objective, "grad", None)
        if value_fn is None or not callable(value_fn):
            raise ValueError("objective must implement value(theta, x_batch).")
        if grad_fn is None or not callable(grad_fn):
            raise ValueError("objective must implement grad(theta, x_batch).")

        # If a policy is attached (regardless of perturbation_space), smoke-test
        # it by evaluating value/grad on a single zero state.
        policy = getattr(objective, "policy", None)
        if policy is not None:
            policy_value = getattr(policy, "value", None)
            policy_grad = getattr(policy, "grad", None)
            if not callable(policy_value) or not callable(policy_grad):
                raise ValueError("policy must implement value(theta, x_batch) and grad(theta, x_batch).")
            # Probe with a single-sample batch
            x_probe = np.zeros((1, self.state_dim), dtype=float)
            u_probe_arr = np.asarray(policy_value(theta0_arr, x_probe), dtype=float)
            if not bool(np.isfinite(u_probe_arr).all()):
                raise ValueError("policy.value(theta0, x_batch) must be finite.")
            grad_probe = np.asarray(policy_grad(theta0_arr, x_probe), dtype=float)
            if grad_probe.ndim != 2 or grad_probe.shape[1] != theta0_arr.size:
                raise ValueError("policy.grad(theta0, x_batch) must return (n_samples, theta_dim).")

    def to_dict(self) -> dict[str, Any]:
        """Serialize config to dictionary for JSON output."""
        return {
            "state_dim": int(self.state_dim),
            "n_samples": int(self.n_samples),
            "batch_size": int(self.batch_size) if self.batch_size is not None else None,
            "step_rule": self.step_rule,
            "seed": int(self.seed),
            "t_steps": int(self.t_steps),
            "step_size": float(self.step_size),
            "grad_norm_tol": float(self.grad_norm_tol)
            if self.grad_norm_tol is not None
            else None,
            "ftol": float(self.ftol) if self.ftol is not None else None,
            "sigma": float(self.sigma),
            "n_grad_samples": int(self.n_grad_samples),
            "verbose": bool(self.verbose),
            "plot": bool(self.plot),
            "plot_dir": self.plot_dir,
            "enabled_estimators": list(self.enabled_estimators),
            "perturbation_space": self.perturbation_space,
            "theta0": _as_list(self.theta0),
            "objective": _objective_to_dict(self.objective),
            "wandb": {
                "enabled": bool(self.wandb_enabled),
                "project": self.wandb_project,
                "entity": self.wandb_entity,
                "group": self.wandb_group,
                "job_type": self.wandb_job_type,
                "tags": list(self.wandb_tags),
                "mode": self.wandb_mode,
                "log_plots": bool(self.wandb_log_plots),
                "estimator_allowlist": list(self.wandb_estimator_allowlist)
                if self.wandb_estimator_allowlist is not None
                else None,
            },
            "correctness": _correctness_to_dict(self.correctness),
            # Only the shape of x_fixed is recorded, not the data itself.
            "x_fixed_shape": list(self.x_fixed.shape) if self.x_fixed is not None else None,
        }

Frozen configuration for a single experiment run with validation.

ExperimentConfig( state_dim: int, n_samples: int, step_rule: str, objective: objective.Objective, theta0: numpy.ndarray, perturbation_space: Literal['theta', 'u'], batch_size: int | None = None, seed: int = 7, t_steps: int = 100, step_size: float = 0.01, grad_norm_tol: float | None = None, ftol: float | None = None, sigma: float = 0.1, n_grad_samples: int = 64, verbose: bool = False, plot: bool = True, plot_dir: str = 'plots', enabled_estimators: tuple[str, ...] = ('first_order', 'gauss_stein'), wandb_enabled: bool = False, wandb_project: str | None = None, wandb_entity: str | None = None, wandb_group: str | None = None, wandb_job_type: str = 'experiment', wandb_tags: tuple[str, ...] = (), wandb_mode: Literal['online', 'offline', 'disabled'] = 'online', wandb_log_plots: bool = True, wandb_estimator_allowlist: tuple[str, ...] | None = None, correctness: experiments.config.CorrectnessSpec = <factory>, x_fixed: numpy.ndarray | None = None)
state_dim: int
n_samples: int
step_rule: str
objective: objective.Objective
theta0: numpy.ndarray
perturbation_space: Literal['theta', 'u']
batch_size: int | None = None
seed: int = 7
t_steps: int = 100
step_size: float = 0.01
grad_norm_tol: float | None = None
ftol: float | None = None
sigma: float = 0.1
n_grad_samples: int = 64
verbose: bool = False
plot: bool = True
plot_dir: str = 'plots'
enabled_estimators: tuple[str, ...] = ('first_order', 'gauss_stein')
wandb_enabled: bool = False
wandb_project: str | None = None
wandb_entity: str | None = None
wandb_group: str | None = None
wandb_job_type: str = 'experiment'
wandb_tags: tuple[str, ...] = ()
wandb_mode: Literal['online', 'offline', 'disabled'] = 'online'
wandb_log_plots: bool = True
wandb_estimator_allowlist: tuple[str, ...] | None = None
correctness: experiments.config.CorrectnessSpec
x_fixed: numpy.ndarray | None = None
def to_dict(self) -> dict[str, typing.Any]:
197    def to_dict(self) -> dict[str, Any]:
198        """Serialize config to dictionary for JSON output."""
199        return {
200            "state_dim": int(self.state_dim),
201            "n_samples": int(self.n_samples),
202            "batch_size": int(self.batch_size) if self.batch_size is not None else None,
203            "step_rule": self.step_rule,
204            "seed": int(self.seed),
205            "t_steps": int(self.t_steps),
206            "step_size": float(self.step_size),
207            "grad_norm_tol": float(self.grad_norm_tol)
208            if self.grad_norm_tol is not None
209            else None,
210            "ftol": float(self.ftol) if self.ftol is not None else None,
211            "sigma": float(self.sigma),
212            "n_grad_samples": int(self.n_grad_samples),
213            "verbose": bool(self.verbose),
214            "plot": bool(self.plot),
215            "plot_dir": self.plot_dir,
216            "enabled_estimators": list(self.enabled_estimators),
217            "perturbation_space": self.perturbation_space,
218            "theta0": _as_list(self.theta0),
219            "objective": _objective_to_dict(self.objective),
220            "wandb": {
221                "enabled": bool(self.wandb_enabled),
222                "project": self.wandb_project,
223                "entity": self.wandb_entity,
224                "group": self.wandb_group,
225                "job_type": self.wandb_job_type,
226                "tags": list(self.wandb_tags),
227                "mode": self.wandb_mode,
228                "log_plots": bool(self.wandb_log_plots),
229                "estimator_allowlist": list(self.wandb_estimator_allowlist)
230                if self.wandb_estimator_allowlist is not None
231                else None,
232            },
233            "correctness": _correctness_to_dict(self.correctness),
234            "x_fixed_shape": list(self.x_fixed.shape) if self.x_fixed is not None else None,
235        }

Serialize config to dictionary for JSON output.

def get_config(name: str) -> ExperimentConfig:
    """Return the preset config registered under `name`.

    The config's module is imported lazily on first access and the resulting
    CONFIG object is cached, so repeated lookups are cheap.

    Raises:
        ValueError: if `name` is not a registered config.
    """
    try:
        module_name = _CONFIG_MODULES[name]
    except KeyError as missing:
        available = ", ".join(sorted(_CONFIG_MODULES))
        raise ValueError(f"Unknown experiment config '{name}'. Available: {available}.") from missing

    if name not in _CONFIG_CACHE:
        _CONFIG_CACHE[name] = import_module(module_name).CONFIG
    return _CONFIG_CACHE[name]
def list_configs() -> tuple[str, ...]:
    """Return the names of all registered preset configs, in registration order."""
    # Iterating the registry dict yields its keys in insertion order.
    return tuple(_CONFIG_MODULES)
def default_theta0(state_dim: int) -> np.ndarray:
    """Return the default initial theta: a 0.1 entry followed by `state_dim` entries of 0.01."""
    weights = [0.01] * state_dim
    return np.asarray([0.1, *weights], dtype=float)

Return default initial theta for a policy with given state dimension.

def default_policy(state_dim: int = 1) -> SoftmaxPolicy:
    """Return default softmax policy.

    NOTE(review): `state_dim` is accepted (presumably for signature symmetry
    with `default_theta0`) but is not used — SoftmaxPolicy() takes no
    arguments here. Confirm whether it should be forwarded.
    """
    return SoftmaxPolicy()

Return default softmax policy.

@dataclass(frozen=True)
class EstimatorResult:
    """Final result for one estimator: theta, mean action, objective value, and wall time."""

    theta: np.ndarray  # final parameter vector reached by this estimator
    u: float  # mean action at the final theta
    value: float  # objective value at the final theta
    time: float  # wall-clock seconds spent by this estimator

Final result for one estimator: theta, mean action, objective value, and wall time.

EstimatorResult(theta: numpy.ndarray, u: float, value: float, time: float)
theta: numpy.ndarray
u: float
value: float
time: float
@dataclass(frozen=True)
class ExperimentResult:
    """Full experiment result: config, samples, traces, and final values per estimator."""

    config: ExperimentConfig  # the configuration that produced this result
    x_samples: np.ndarray  # Shape (n_samples, state_dim)
    initial_value: float  # objective value before optimization (per name; set by the runner)
    results: Mapping[str, EstimatorResult]  # final result keyed by estimator name
    traces: Mapping[str, OptimizationTrace]  # per-step trace keyed by estimator name
    u_star: Optional[float] = None  # presumably the known-optimal action, when available — confirm
    value_at_u_star: Optional[float] = None  # objective value at u_star, when available

Full experiment result: config, samples, traces, and final values per estimator.

ExperimentResult( config: ExperimentConfig, x_samples: numpy.ndarray, initial_value: float, results: Mapping[str, EstimatorResult], traces: Mapping[str, OptimizationTrace], u_star: float | None = None, value_at_u_star: float | None = None)
x_samples: numpy.ndarray
initial_value: float
results: Mapping[str, EstimatorResult]
traces: Mapping[str, OptimizationTrace]
u_star: float | None = None
value_at_u_star: float | None = None
@dataclass(frozen=True)
class OptimizationTrace:
    """Per-step trace: u values, objective values, gradient norms, and theta history."""

    steps: Sequence[int]  # step indices
    u_values: Sequence[float]  # action value recorded at each step
    objective_values: Sequence[float]  # objective value recorded at each step
    u_grad_estimates: Sequence[float]  # estimated gradient in u-space at each step
    u_true_gradients: Optional[Sequence[float]] = None  # analytic u-gradients, when computed
    theta_grad_norms: Optional[Sequence[float]] = None  # norms of estimated theta-gradients
    true_theta_grad_norms: Optional[Sequence[float]] = None  # norms of analytic theta-gradients
    step_sizes: Optional[Sequence[float]] = None  # per-step step sizes (e.g. from a line search) — confirm
    theta_values: Optional[Sequence[np.ndarray]] = None  # theta history, when recorded
    optimizer_status: Optional[int] = None  # backend optimizer status code, when one is used
    optimizer_message: Optional[str] = None  # backend optimizer message, when one is used

Per-step trace: u values, objective values, gradient norms, and theta history.

OptimizationTrace( steps: Sequence[int], u_values: Sequence[float], objective_values: Sequence[float], u_grad_estimates: Sequence[float], u_true_gradients: Sequence[float] | None = None, theta_grad_norms: Sequence[float] | None = None, true_theta_grad_norms: Sequence[float] | None = None, step_sizes: Sequence[float] | None = None, theta_values: Sequence[numpy.ndarray] | None = None, optimizer_status: int | None = None, optimizer_message: str | None = None)
steps: Sequence[int]
u_values: Sequence[float]
objective_values: Sequence[float]
u_grad_estimates: Sequence[float]
u_true_gradients: Sequence[float] | None = None
theta_grad_norms: Sequence[float] | None = None
true_theta_grad_norms: Sequence[float] | None = None
step_sizes: Sequence[float] | None = None
theta_values: Sequence[numpy.ndarray] | None = None
optimizer_status: int | None = None
optimizer_message: str | None = None
class ConsoleReporter:
    """Terminal reporter: prints a run banner, delegates the end-of-run summary
    to ``log_summary``, and emits per-step lines only in verbose mode."""

    def __init__(self, verbose: bool = False) -> None:
        # Gate for per-step output; start/end output is always printed.
        self._verbose = verbose

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        print(f"\n=== Running experiment: {run_context.experiment_name} ===")

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        log_summary(result)

    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        if not self._verbose:
            return
        log_step(method, step, u, value, grad_norm, step_size)

Reporter that prints to terminal. Verbose mode controls per-step output.

ConsoleReporter(verbose: bool = False)
120    def __init__(self, verbose: bool = False) -> None:
121        self._verbose = verbose
def on_start( self, run_context: RunContext, config: ExperimentConfig) -> None:
123    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
124        print(f"\n=== Running experiment: {run_context.experiment_name} ===")
def on_end( self, run_context: RunContext, result: ExperimentResult) -> None:
126    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
127        log_summary(result)
def log_step( self, method: str, step: int, u: float, value: float, grad_norm: float | None = None, step_size: float | None = None) -> None:
129    def log_step(
130        self,
131        method: str,
132        step: int,
133        u: float,
134        value: float,
135        grad_norm: float | None = None,
136        step_size: float | None = None,
137    ) -> None:
138        if self._verbose:
139            log_step(method, step, u, value, grad_norm, step_size)
class FileStepLogger:
    """Streams per-step metrics as CSV rows to ``steps.csv`` in the run directory."""

    def __init__(self) -> None:
        # Both stay None until on_start opens the file for the current run.
        self._file: IO[str] | None = None
        self._path: Path | None = None

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        # Truncate any previous file and write the header row once per run.
        self._path = run_context.run_dir / "steps.csv"
        self._file = self._path.open("w", encoding="utf-8")
        self._file.write("method,step,u,value,grad_norm,step_size\n")

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        if self._file is not None:
            self._file.close()
            self._file = None

    def log_step(
        self,
        method: str,
        step: int,
        u: float,
        value: float,
        grad_norm: float | None = None,
        step_size: float | None = None,
    ) -> None:
        # Silently ignore steps logged outside an active run.
        if self._file is None:
            return
        grad_field = "" if grad_norm is None else f"{grad_norm:.6f}"
        step_field = "" if step_size is None else f"{step_size:.6f}"
        self._file.write(f"{method},{step},{u:.6f},{value:.6f},{grad_field},{step_field}\n")

Writes per-step metrics to a CSV file in the run directory.

def on_start( self, run_context: RunContext, config: ExperimentConfig) -> None:
350    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
351        self._path = run_context.run_dir / "steps.csv"
352        self._file = self._path.open("w", encoding="utf-8")
353        self._file.write("method,step,u,value,grad_norm,step_size\n")  # type: ignore[union-attr]
def on_end( self, run_context: RunContext, result: ExperimentResult) -> None:
355    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
356        if self._file is not None:
357            self._file.close()
358            self._file = None
def log_step( self, method: str, step: int, u: float, value: float, grad_norm: float | None = None, step_size: float | None = None) -> None:
360    def log_step(
361        self,
362        method: str,
363        step: int,
364        u: float,
365        value: float,
366        grad_norm: float | None = None,
367        step_size: float | None = None,
368    ) -> None:
369        if self._file is None:
370            return
371        grad_str = f"{grad_norm:.6f}" if grad_norm is not None else ""
372        step_str = f"{step_size:.6f}" if step_size is not None else ""
373        self._file.write(f"{method},{step},{u:.6f},{value:.6f},{grad_str},{step_str}\n")
class JsonReporter:
    """Writes the run summary to ``summary.json`` at the end of the run."""

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        # Nothing to prepare; the summary is produced once at the end.
        return None

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        payload = _build_summary_payload(run_context, result)
        summary_path = run_context.run_dir / "summary.json"
        with summary_path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, sort_keys=True, ensure_ascii=True)
def on_start( self, run_context: RunContext, config: ExperimentConfig) -> None:
143    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
144        return None
def on_end( self, run_context: RunContext, result: ExperimentResult) -> None:
146    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
147        payload = _build_summary_payload(run_context, result)
148        summary_path = run_context.run_dir / "summary.json"
149        with summary_path.open("w", encoding="utf-8") as handle:
150            json.dump(payload, handle, indent=2, sort_keys=True, ensure_ascii=True)
class PlotReporter:
    """End-of-run reporter that renders the experiment's plots into the run's
    plots directory (loss curves, gradient norms, and — when applicable —
    u-slice, step-size, and theta-contour plots)."""

    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
        # All plotting happens at the end of the run; nothing to set up.
        return None

    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
        config = result.config
        # Skip entirely when plotting is disabled or there is nothing to plot.
        if not config.plot or not result.traces:
            return
        run_context.plots_dir.mkdir(parents=True, exist_ok=True)
        plot_dir = str(run_context.plots_dir)
        objective = config.objective
        # Optional action-space objective used for the u-slice plot (may be absent).
        action_objective = getattr(objective, "action_objective", None)
        traces = result.traces
        u_star_plot = _u_star_for_plot(action_objective, result.u_star)
        plot_loss_curves(
            traces,
            plot_dir,
            u_star=u_star_plot,
        )
        plot_gradient_norms(traces, plot_dir)
        if action_objective is not None:
            plot_objective_u_slice(
                result.x_samples,
                action_objective,
                traces,
                plot_dir,
                u_star=u_star_plot,
            )
        # Step-size plots only apply to the Armijo line-search rule.
        if config.step_rule == STEP_RULE_ARMIJO:
            plot_step_sizes(traces, plot_dir)
        # Contour plots need at least two theta dimensions.
        if config.theta0.size >= 2:
            axis_indices = (0, 1)
            axis_labels = None
            # Collect the full theta path (initial point plus all traces).
            theta_path_points = [config.theta0]
            for trace in traces.values():
                if trace.theta_values:
                    theta_path_points.extend(trace.theta_values)
            # With more than two dims, project onto the two max-variance axes of the path.
            if config.theta0.size > 2 and theta_path_points:
                axis_indices = select_theta_axes_max_variance(theta_path_points)
                axis_labels = (
                    f"theta[{axis_indices[0]}] (max-var axis)",
                    f"theta[{axis_indices[1]}] (max-var axis)",
                )
            # Preserve configured estimator order; skip estimators with no result.
            ordered_results = [
                (name, result.results[name])
                for name in config.enabled_estimators
                if name in result.results
            ]
            theta_refs = [config.theta0]
            theta_points = [(config.theta0, "initial", "#636363", "o")]
            for name, estimator_result in ordered_results:
                theta_refs.append(estimator_result.theta)
                style = ESTIMATOR_STYLES[name]
                theta_points.append(
                    (
                        estimator_result.theta,
                        style["label"],
                        style["color"],
                        style["marker"],
                    )
                )
            plot_theta_objective_contours(
                result.x_samples,
                objective,
                config.theta0,
                plot_dir,
                axis_indices=axis_indices,
                axis_labels=axis_labels,
                theta_refs=theta_refs,
                theta_points=theta_points,
                traces=traces,
            )
def on_start( self, run_context: RunContext, config: ExperimentConfig) -> None:
270    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
271        return None
def on_end( self, run_context: RunContext, result: ExperimentResult) -> None:
273    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
274        config = result.config
275        if not config.plot or not result.traces:
276            return
277        run_context.plots_dir.mkdir(parents=True, exist_ok=True)
278        plot_dir = str(run_context.plots_dir)
279        objective = config.objective
280        action_objective = getattr(objective, "action_objective", None)
281        traces = result.traces
282        u_star_plot = _u_star_for_plot(action_objective, result.u_star)
283        plot_loss_curves(
284            traces,
285            plot_dir,
286            u_star=u_star_plot,
287        )
288        plot_gradient_norms(traces, plot_dir)
289        if action_objective is not None:
290            plot_objective_u_slice(
291                result.x_samples,
292                action_objective,
293                traces,
294                plot_dir,
295                u_star=u_star_plot,
296            )
297        if config.step_rule == STEP_RULE_ARMIJO:
298            plot_step_sizes(traces, plot_dir)
299        if config.theta0.size >= 2:
300            axis_indices = (0, 1)
301            axis_labels = None
302            theta_path_points = [config.theta0]
303            for trace in traces.values():
304                if trace.theta_values:
305                    theta_path_points.extend(trace.theta_values)
306            if config.theta0.size > 2 and theta_path_points:
307                axis_indices = select_theta_axes_max_variance(theta_path_points)
308                axis_labels = (
309                    f"theta[{axis_indices[0]}] (max-var axis)",
310                    f"theta[{axis_indices[1]}] (max-var axis)",
311                )
312            ordered_results = [
313                (name, result.results[name])
314                for name in config.enabled_estimators
315                if name in result.results
316            ]
317            theta_refs = [config.theta0]
318            theta_points = [(config.theta0, "initial", "#636363", "o")]
319            for name, estimator_result in ordered_results:
320                theta_refs.append(estimator_result.theta)
321                style = ESTIMATOR_STYLES[name]
322                theta_points.append(
323                    (
324                        estimator_result.theta,
325                        style["label"],
326                        style["color"],
327                        style["marker"],
328                    )
329                )
330            plot_theta_objective_contours(
331                result.x_samples,
332                objective,
333                config.theta0,
334                plot_dir,
335                axis_indices=axis_indices,
336                axis_labels=axis_labels,
337                theta_refs=theta_refs,
338                theta_points=theta_points,
339                traces=traces,
340            )
class ReporterStack:
 91class ReporterStack:
 92    def __init__(self, reporters: Sequence[Reporter]) -> None:
 93        self._reporters = list(reporters)
 94
 95    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
 96        for reporter in self._reporters:
 97            reporter.on_start(run_context, config)
 98
 99    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
100        for reporter in self._reporters:
101            reporter.on_end(run_context, result)
102
103    def log_step(
104        self,
105        method: str,
106        step: int,
107        u: float,
108        value: float,
109        grad_norm: float | None = None,
110        step_size: float | None = None,
111    ) -> None:
112        for reporter in self._reporters:
113            if isinstance(reporter, StepReporter):
114                reporter.log_step(method, step, u, value, grad_norm, step_size)
ReporterStack(reporters: Sequence[experiments.reporters.Reporter])
92    def __init__(self, reporters: Sequence[Reporter]) -> None:
93        self._reporters = list(reporters)
def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
95    def on_start(self, run_context: RunContext, config: ExperimentConfig) -> None:
96        for reporter in self._reporters:
97            reporter.on_start(run_context, config)
def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
 99    def on_end(self, run_context: RunContext, result: ExperimentResult) -> None:
100        for reporter in self._reporters:
101            reporter.on_end(run_context, result)
def log_step(self, method: str, step: int, u: float, value: float, grad_norm: float | None = None, step_size: float | None = None) -> None:
103    def log_step(
104        self,
105        method: str,
106        step: int,
107        u: float,
108        value: float,
109        grad_norm: float | None = None,
110        step_size: float | None = None,
111    ) -> None:
112        for reporter in self._reporters:
113            if isinstance(reporter, StepReporter):
114                reporter.log_step(method, step, u, value, grad_norm, step_size)
@dataclass(frozen=True)
class RunContext:
31@dataclass(frozen=True)
32class RunContext:
33    experiment_name: str
34    run_id: str
35    run_dir: Path
36    plots_dir: Path
37    started_at: datetime
RunContext(experiment_name: str, run_id: str, run_dir: pathlib.Path, plots_dir: pathlib.Path, started_at: datetime.datetime)
experiment_name: str
run_id: str
run_dir: pathlib.Path
plots_dir: pathlib.Path
started_at: datetime.datetime
@runtime_checkable
class StepReporter(typing.Protocol):
66@runtime_checkable
67class StepReporter(Protocol):
68    """Protocol for per-step metric logging during optimization."""
69
70    def log_step(
71        self,
72        method: str,
73        step: int,
74        u: float,
75        value: float,
76        grad_norm: float | None = None,
77        step_size: float | None = None,
78    ) -> None:
79        ...

Protocol for per-step metric logging during optimization.

StepReporter(*args, **kwargs)
1866def _no_init_or_replace_init(self, *args, **kwargs):
1867    cls = type(self)
1868
1869    if cls._is_protocol:
1870        raise TypeError('Protocols cannot be instantiated')
1871
1872    # Already using a custom `__init__`. No need to calculate correct
1873    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1874    if cls.__init__ is not _no_init_or_replace_init:
1875        return
1876
1877    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1878    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1879    # searches for a proper new `__init__` in the MRO. The new `__init__`
1880    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1881    # instantiation of the protocol subclass will thus use the new
1882    # `__init__` and no longer call `_no_init_or_replace_init`.
1883    for base in cls.__mro__:
1884        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1885        if init is not _no_init_or_replace_init:
1886            cls.__init__ = init
1887            break
1888    else:
1889        # should not happen
1890        cls.__init__ = object.__init__
1891
1892    cls.__init__(self, *args, **kwargs)
def log_step(self, method: str, step: int, u: float, value: float, grad_norm: float | None = None, step_size: float | None = None) -> None:
70    def log_step(
71        self,
72        method: str,
73        step: int,
74        u: float,
75        value: float,
76        grad_norm: float | None = None,
77        step_size: float | None = None,
78    ) -> None:
79        ...
def create_run_context(experiment_name: str, runs_root: str = 'outputs', started_at: datetime.datetime | None = None) -> RunContext:
46def create_run_context(
47    experiment_name: str,
48    runs_root: str = "outputs",
49    started_at: datetime | None = None,
50) -> RunContext:
51    timestamp = started_at or datetime.now()
52    run_id = timestamp.strftime("%Y%m%d_%H%M%S")
53    safe_name = _sanitize_name(experiment_name)
54    run_dir = Path(runs_root) / safe_name / run_id
55    run_dir.mkdir(parents=True, exist_ok=True)
56    plots_dir = run_dir / "plots"
57    return RunContext(
58        experiment_name=experiment_name,
59        run_id=run_id,
60        run_dir=run_dir,
61        plots_dir=plots_dir,
62        started_at=timestamp,
63    )
def run_experiment(config: ExperimentConfig, step_reporter: StepReporter | None = None) -> ExperimentResult:
 25def run_experiment(
 26    config: ExperimentConfig,
 27    step_reporter: StepReporter | None = None,
 28) -> ExperimentResult:
 29    """Run optimization with all enabled estimators; returns traces and final values."""
 30    objective = config.objective
 31    enabled_estimators = tuple(config.enabled_estimators)
 32
 33    rng = default_rng(config.seed)
 34    if config.x_fixed is not None:
 35        x_samples = np.asarray(config.x_fixed, dtype=float)
 36    else:
 37        x_samples = sample_states(rng, config.n_samples, config.state_dim)
 38    true_grad_theta_fn = resolve_true_grad_theta_fn(objective, config.correctness)
 39
 40    theta_initial = np.asarray(config.theta0, dtype=float)
 41    initial_value = float(objective.value(theta_initial, x_samples))
 42
 43    # Get optimal u if available
 44    u_star = optimal_u(objective)
 45
 46    # Compute value at u* if available
 47    value_at_u_star = None
 48    if u_star is not None:
 49        try:
 50            value_at_u_star = _action_value_at_u(objective, x_samples, u_star)
 51        except ValueError:
 52            pass
 53
 54    # Get policy from objective for mean_action computation
 55    policy = getattr(objective, "policy", None)
 56
 57    results: dict[str, EstimatorResult] = {}
 58    traces = {}
 59
 60    if "first_order" in enabled_estimators:
 61        start_first = time.perf_counter()
 62        theta_first, trace_first = run_first_order(
 63            theta_initial,
 64            x_samples,
 65            objective,
 66            rng,
 67            config.t_steps,
 68            config.step_rule,
 69            config.step_size,
 70            config.n_grad_samples,
 71            config.sigma,
 72            config.batch_size,
 73            perturbation_space=config.perturbation_space,
 74            true_grad_theta_fn=true_grad_theta_fn,
 75            grad_norm_tol=config.grad_norm_tol,
 76            ftol=config.ftol,
 77            step_reporter=step_reporter,
 78        )
 79        time_first = time.perf_counter() - start_first
 80        u_first = _mean_action(policy, theta_first, x_samples) if policy is not None else float("nan")
 81        value_first = float(objective.value(theta_first, x_samples))
 82        results["first_order"] = EstimatorResult(theta=theta_first, u=u_first, value=value_first, time=time_first)
 83        traces["first_order"] = trace_first
 84
 85    if "finite_difference" in enabled_estimators:
 86        start_fd = time.perf_counter()
 87        theta_fd, trace_fd = run_finite_difference(
 88            theta_initial,
 89            x_samples,
 90            objective,
 91            rng,
 92            config.t_steps,
 93            config.step_rule,
 94            config.step_size,
 95            config.n_grad_samples,
 96            config.sigma,
 97            config.batch_size,
 98            perturbation_space=config.perturbation_space,
 99            true_grad_theta_fn=true_grad_theta_fn,
100            grad_norm_tol=config.grad_norm_tol,
101            ftol=config.ftol,
102            step_reporter=step_reporter,
103        )
104        time_fd = time.perf_counter() - start_fd
105        u_fd = _mean_action(policy, theta_fd, x_samples) if policy is not None else float("nan")
106        value_fd = float(objective.value(theta_fd, x_samples))
107        results["finite_difference"] = EstimatorResult(theta=theta_fd, u=u_fd, value=value_fd, time=time_fd)
108        traces["finite_difference"] = trace_fd
109
110    if "gauss_stein" in enabled_estimators:
111        start_zero = time.perf_counter()
112        theta_zero, trace_zero = run_gauss_stein(
113            theta_initial,
114            x_samples,
115            objective,
116            rng,
117            config.t_steps,
118            config.step_rule,
119            config.step_size,
120            config.n_grad_samples,
121            config.sigma,
122            config.batch_size,
123            perturbation_space=config.perturbation_space,
124            true_grad_theta_fn=true_grad_theta_fn,
125            grad_norm_tol=config.grad_norm_tol,
126            ftol=config.ftol,
127            step_reporter=step_reporter,
128        )
129        time_zero = time.perf_counter() - start_zero
130        u_zero = _mean_action(policy, theta_zero, x_samples) if policy is not None else float("nan")
131        value_zero = float(objective.value(theta_zero, x_samples))
132        results["gauss_stein"] = EstimatorResult(theta=theta_zero, u=u_zero, value=value_zero, time=time_zero)
133        traces["gauss_stein"] = trace_zero
134
135    if "spsa" in enabled_estimators:
136        start_spsa = time.perf_counter()
137        theta_spsa, trace_spsa = run_spsa(
138            theta_initial,
139            x_samples,
140            objective,
141            rng,
142            config.t_steps,
143            config.step_rule,
144            config.step_size,
145            config.n_grad_samples,
146            config.sigma,
147            config.batch_size,
148            perturbation_space=config.perturbation_space,
149            true_grad_theta_fn=true_grad_theta_fn,
150            grad_norm_tol=config.grad_norm_tol,
151            ftol=config.ftol,
152            step_reporter=step_reporter,
153        )
154        time_spsa = time.perf_counter() - start_spsa
155        u_spsa = _mean_action(policy, theta_spsa, x_samples) if policy is not None else float("nan")
156        value_spsa = float(objective.value(theta_spsa, x_samples))
157        results["spsa"] = EstimatorResult(theta=theta_spsa, u=u_spsa, value=value_spsa, time=time_spsa)
158        traces["spsa"] = trace_spsa
159
160    if "stein_difference" in enabled_estimators:
161        start_stein = time.perf_counter()
162        theta_stein, trace_stein = run_stein_difference(
163            theta_initial,
164            x_samples,
165            objective,
166            rng,
167            config.t_steps,
168            config.step_rule,
169            config.step_size,
170            config.n_grad_samples,
171            config.sigma,
172            config.batch_size,
173            perturbation_space=config.perturbation_space,
174            true_grad_theta_fn=true_grad_theta_fn,
175            grad_norm_tol=config.grad_norm_tol,
176            ftol=config.ftol,
177            step_reporter=step_reporter,
178        )
179        time_stein = time.perf_counter() - start_stein
180        u_stein = _mean_action(policy, theta_stein, x_samples) if policy is not None else float("nan")
181        value_stein = float(objective.value(theta_stein, x_samples))
182        results["stein_difference"] = EstimatorResult(
183            theta=theta_stein,
184            u=u_stein,
185            value=value_stein,
186            time=time_stein,
187        )
188        traces["stein_difference"] = trace_stein
189
190    return ExperimentResult(
191        config=config,
192        x_samples=x_samples,
193        initial_value=initial_value,
194        results=results,
195        traces=traces,
196        u_star=u_star,
197        value_at_u_star=value_at_u_star,
198    )

Run optimization with all enabled estimators; returns traces and final values.

def expand_override_grid(grid: Mapping[str, Sequence[Any]]) -> list[dict[str, typing.Any]]:
26def expand_override_grid(grid: Mapping[str, Sequence[Any]]) -> list[dict[str, Any]]:
27    """Build cartesian-product override dictionaries from a field-value grid."""
28    if not grid:
29        return [{}]
30    keys = list(grid.keys())
31    value_lists = [list(grid[key]) for key in keys]
32    return [dict(zip(keys, combo)) for combo in product(*value_lists)]

Build cartesian-product override dictionaries from a field-value grid.

def apply_config_overrides(config: ExperimentConfig, overrides: Mapping[str, Any]) -> ExperimentConfig:
35def apply_config_overrides(config: ExperimentConfig, overrides: Mapping[str, Any]) -> ExperimentConfig:
36    """Return a config copy with top-level ExperimentConfig fields overridden."""
37    valid_fields = {field.name for field in fields(ExperimentConfig)}
38    unknown = sorted(key for key in overrides.keys() if key not in valid_fields)
39    if unknown:
40        unknown_text = ", ".join(unknown)
41        raise ValueError(f"Unknown config override fields: {unknown_text}.")
42    return replace(config, **dict(overrides))

Return a config copy with top-level ExperimentConfig fields overridden.

def make_sweep_name(base_name: str, index: int, overrides: Mapping[str, Any]) -> str:
45def make_sweep_name(base_name: str, index: int, overrides: Mapping[str, Any]) -> str:
46    """Build a readable, deterministic run name for one sweep variant."""
47    if not overrides:
48        return f"{base_name}__sweep_{index:03d}"
49    parts = [f"{key}-{_stringify_override_value(overrides[key])}" for key in sorted(overrides.keys())]
50    suffix = "__".join(parts)
51    return f"{base_name}__sweep_{index:03d}__{suffix}"

Build a readable, deterministic run name for one sweep variant.

def generate_sweep_runs(*, base_preset: str, override_grid: Mapping[str, Sequence[Any]] | None = None, override_list: Sequence[Mapping[str, Any]] | None = None, display_keys: Sequence[str] | None = None) -> list[tuple[str, ExperimentConfig, dict[str, typing.Any]]]:
 72def generate_sweep_runs(
 73    *,
 74    base_preset: str,
 75    override_grid: Mapping[str, Sequence[Any]] | None = None,
 76    override_list: Sequence[Mapping[str, Any]] | None = None,
 77    display_keys: Sequence[str] | None = None,
 78) -> list[tuple[str, ExperimentConfig, dict[str, Any]]]:
 79    """Generate named configs by applying overrides to a base preset."""
 80    if override_grid is not None and override_list is not None:
 81        raise ValueError("Specify either override_grid or override_list, not both.")
 82
 83    if override_list is not None:
 84        overrides = [dict(item) for item in override_list] if override_list else [{}]
 85    elif override_grid is not None:
 86        overrides = expand_override_grid(override_grid)
 87    else:
 88        overrides = [{}]
 89
 90    base_config = get_config(base_preset)
 91    runs: list[tuple[str, ExperimentConfig, dict[str, Any]]] = []
 92    for index, override in enumerate(overrides, start=1):
 93        config = apply_config_overrides(base_config, override)
 94        run_name = make_display_name(
 95            base_preset,
 96            index=index,
 97            overrides=override,
 98            display_keys=display_keys,
 99        )
100        runs.append((run_name, config, override))
101    return runs

Generate named configs by applying overrides to a base preset.

def run_preset_sweep(*, base_preset: str, override_grid: Mapping[str, Sequence[Any]] | None = None, override_list: Sequence[Mapping[str, Any]] | None = None, runs_root: str = 'outputs', project_name: str | None = None, display_keys: Sequence[str] | None = None) -> list[tuple[str, ExperimentResult]]:
104def run_preset_sweep(
105    *,
106    base_preset: str,
107    override_grid: Mapping[str, Sequence[Any]] | None = None,
108    override_list: Sequence[Mapping[str, Any]] | None = None,
109    runs_root: str = "outputs",
110    project_name: str | None = None,
111    display_keys: Sequence[str] | None = None,
112) -> list[tuple[str, ExperimentResult]]:
113    """Execute a preset sweep and return `(run_name, result)` pairs."""
114    sweep_runs = generate_sweep_runs(
115        base_preset=base_preset,
116        override_grid=override_grid,
117        override_list=override_list,
118        display_keys=display_keys,
119    )
120    results: list[tuple[str, ExperimentResult]] = []
121    runs_root_path = _project_runs_root(runs_root, project_name)
122
123    for run_name, config, _ in sweep_runs:
124        run_context = create_run_context(run_name, runs_root=runs_root_path)
125        reporter_list = [
126            ConsoleReporter(verbose=config.verbose),
127            FileStepLogger(),
128            JsonReporter(),
129            PlotReporter(),
130        ]
131        if config.wandb_enabled:
132            reporter_list.append(WandbReporter())
133        reporters = ReporterStack(reporter_list)
134        reporters.on_start(run_context, config)
135        result = run_experiment(config, step_reporter=reporters)
136        reporters.on_end(run_context, result)
137        results.append((run_name, result))
138
139    return results

Execute a preset sweep and return (run_name, result) pairs.