Skip to content

Commit

Permalink
Sampling: Copy over iterable overrides
Browse files Browse the repository at this point in the history
If an override was iterable, any modifications to the returned value
would alter the reference to the global storage dict.

Therefore, copy the structure if it's an iterable so any modification
won't alter the original override. Also apply this for the function
that checks for forced overrides.

Signed-off-by: kingbri <bdashore3@proton.me>
  • Loading branch information
bdashore3 committed May 18, 2024
1 parent 0e9385e commit b9fd855
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pathlib
import yaml
from copy import deepcopy
from loguru import logger
from pydantic import AliasChoices, BaseModel, Field
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -376,14 +377,19 @@ def get_all_presets():
def get_default_sampler_value(key, fallback=None):
"""Gets an overridden default sampler value"""

return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
default_value = unwrap(
deepcopy(overrides_container.overrides.get(key, {}).get("override")),
fallback,
)

return default_value


def apply_forced_sampler_overrides(params: BaseSamplerRequest):
"""Forcefully applies overrides if specified by the user"""

for var, value in overrides_container.overrides.items():
override = value.get("override")
override = deepcopy(value.get("override"))
original_value = getattr(params, var, None)

# Force takes precedence over additive
Expand Down
2 changes: 1 addition & 1 deletion common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ def coalesce(*args):


def prune_dict(input_dict):
"""Trim out instances of None from a dictionary"""
"""Trim out instances of None from a dictionary."""

return {k: v for k, v in input_dict.items() if v is not None}

0 comments on commit b9fd855

Please sign in to comment.