Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #7295

Merged
merged 6 commits
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
- --exclude=binder/
- --exclude=versioneer.py
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.1
rev: v0.4.5
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
Expand Down
8 changes: 4 additions & 4 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@
for stats in sampler_vars:
for key, dtype in stats.items():
if dtypes.setdefault(key, dtype) != dtype:
raise ValueError("Sampler statistic %s appears with " "different types." % key)
raise ValueError(f"Sampler statistic {key} appears with different types.")

self.sampler_vars = sampler_vars

Expand Down Expand Up @@ -247,7 +247,7 @@

sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if stat_name in s]
if not sampler_idxs:
raise KeyError("Unknown sampler stat %s" % stat_name)
raise KeyError(f"Unknown sampler stat {stat_name}")

Check warning on line 250 in pymc/backends/base.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/base.py#L250

Added line #L250 was not covered by tests

vals = np.stack(
[self._get_sampler_stats(stat_name, i, burn, thin) for i in sampler_idxs], axis=-1
Expand Down Expand Up @@ -388,7 +388,7 @@
return self.get_values(var, burn=burn, thin=thin)
if var in self.stat_names:
return self.get_sampler_stats(var, burn=burn, thin=thin)
raise KeyError("Unknown variable %s" % var)
raise KeyError(f"Unknown variable {var}")

Check warning on line 391 in pymc/backends/base.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/base.py#L391

Added line #L391 was not covered by tests

_attrs = {"_straces", "varnames", "chains", "stat_names", "_report"}

Expand Down Expand Up @@ -512,7 +512,7 @@
List or ndarray depending on parameters.
"""
if stat_name not in self.stat_names:
raise KeyError("Unknown sampler statistic %s" % stat_name)
raise KeyError(f"Unknown sampler statistic {stat_name}")

Check warning on line 515 in pymc/backends/base.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/base.py#L515

Added line #L515 was not covered by tests

if chains is None:
chains = self.chains
Expand Down
8 changes: 4 additions & 4 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,11 +1047,11 @@
tril_testval = None

c = pt.sqrt(
ChiSquared("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval)
ChiSquared(f"{name}_c", nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval)
)
pm._log.info("Added new variable %s_c to model diagonal of Wishart." % name)
z = Normal("%s_z" % name, 0.0, 1.0, shape=n_tril, initval=tril_testval)
pm._log.info("Added new variable %s_z to model off-diagonals of Wishart." % name)
pm._log.info(f"Added new variable {name}_c to model diagonal of Wishart.")
z = Normal(f"{name}_z", 0.0, 1.0, shape=n_tril, initval=tril_testval)
pm._log.info(f"Added new variable {name}_z to model off-diagonals of Wishart.")

Check warning on line 1054 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L1052-L1054

Added lines #L1052 - L1054 were not covered by tests
# Construct A matrix
A = pt.zeros(S.shape, dtype=np.float32)
A = pt.set_subtensor(A[diag_idx], c)
Expand Down
14 changes: 4 additions & 10 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,8 @@
if size is None:
broadcasted_shape = np.broadcast_shapes(*shapes)
if broadcasted_shape is None:
raise ValueError(
"Cannot broadcast provided shapes {} given size: {}".format(
", ".join([f"{s}" for s in shapes]), size
)
)
tmp = ", ".join([f"{s}" for s in shapes])
raise ValueError(f"Cannot broadcast provided shapes {tmp} given size: {size}")

Check warning on line 145 in pymc/distributions/shape_utils.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/shape_utils.py#L144-L145

Added lines #L144 - L145 were not covered by tests
return broadcasted_shape
shapes = [_check_shape_type(s) for s in shapes]
_size = to_tuple(size)
Expand All @@ -154,11 +151,8 @@
try:
broadcast_shape = np.broadcast_shapes(*sp_shapes)
except ValueError:
raise ValueError(
"Cannot broadcast provided shapes {} given size: {}".format(
", ".join([f"{s}" for s in shapes]), size
)
)
tmp = ", ".join([f"{s}" for s in shapes])
raise ValueError(f"Cannot broadcast provided shapes {tmp} given size: {size}")
broadcastable_shapes = []
for shape, sp_shape in zip(shapes, sp_shapes):
if _size == shape[: len(_size)]:
Expand Down
5 changes: 2 additions & 3 deletions pymc/gp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ def getter(self):
value = getattr(self, name, None)
if value is None:
raise AttributeError(
"'{}' not set. Provide as argument "
"to condition, or call 'prior' "
"first".format(name.lstrip("_"))
f"'{name.lstrip('_')}' not set. Provide as argument "
"to condition, or call 'prior' first"
)
else:
return value
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init_subclass__(cls, **kwargs):
def __call__(cls, *args, **kwargs):
# We type hint Model here so type checkers understand that Model is a context manager.
# This metaclass is only used for Model, so this is safe to do. See #6809 for more info.
instance: "Model" = cls.__new__(cls, *args, **kwargs)
instance: Model = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance
Expand Down
12 changes: 6 additions & 6 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ def str_for_dist(
else r"\\operatorname{Unknown}"
)
if include_params:
params = ",~".join([d.strip("$") for d in dist_args])
if print_name:
return r"${} \sim {}({})$".format(
print_name, op_name, ",~".join([d.strip("$") for d in dist_args])
)
return rf"${print_name} \sim {op_name}({params})$"
else:
return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args]))
return rf"${op_name}({params})$"

else:
if print_name:
Expand All @@ -83,10 +82,11 @@ def str_for_dist(
dist.owner.op._print_name[0] if hasattr(dist.owner.op, "_print_name") else "Unknown"
)
if include_params:
params = ", ".join(dist_args)
if print_name:
return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args))
return rf"{print_name} ~ {dist_name}({params})"
else:
return r"{}({})".format(dist_name, ", ".join(dist_args))
return rf"{dist_name}({params})"
else:
if print_name:
return rf"{print_name} ~ {dist_name}"
Expand Down
5 changes: 2 additions & 3 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@

data = {k: np.stack(v) for k, v in zip(names, values)}
if data is None:
raise AssertionError("No variables sampled: attempting to sample %s" % names)
raise AssertionError(f"No variables sampled: attempting to sample {names}")

Check warning on line 422 in pymc/sampling/forward.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/forward.py#L422

Added line #L422 was not covered by tests

prior: dict[str, np.ndarray] = {}
for var_name in vars_:
Expand Down Expand Up @@ -765,8 +765,7 @@
samples = len(_trace)
else:
raise TypeError(
"Do not know how to compute number of samples for trace argument of type %s"
% type(_trace)
f"Do not know how to compute number of samples for trace argument of type {type(_trace)}"
)

assert samples is not None
Expand Down
5 changes: 1 addition & 4 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,10 +697,7 @@ def joined_blas_limiter():
msg = "Tuning was enabled throughout the whole trace."
_log.warning(msg)
elif draws < 100:
msg = (
"Only %s samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
% draws
)
msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
_log.warning(msg)

auto_nuts_init = True
Expand Down
8 changes: 4 additions & 4 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
tb = traceback.format_exception(type(exc), exc, tb)
tb = "".join(tb)
self.exc = exc
self.tb = '\n"""\n%s"""' % tb
self.tb = f'\n"""\n{tb}"""'

def __reduce__(self):
return rebuild_exc, (self.exc, self.tb)
Expand Down Expand Up @@ -216,7 +216,7 @@
mp_ctx,
):
self.chain = chain
process_name = "worker_chain_%s" % chain
process_name = f"worker_chain_{chain}"
self._msg_pipe, remote_conn = multiprocessing.Pipe()

self._shared_point = {}
Expand All @@ -228,7 +228,7 @@
size *= int(dim)
size *= dtype.itemsize
if size != ctypes.c_size_t(size).value:
raise ValueError("Variable %s is too large" % name)
raise ValueError(f"Variable {name} is too large")

Check warning on line 231 in pymc/sampling/parallel.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/parallel.py#L231

Added line #L231 was not covered by tests

array = mp_ctx.RawArray("c", size)
self._shared_point[name] = (array, shape, dtype)
Expand Down Expand Up @@ -388,7 +388,7 @@
mp_ctx=None,
):
if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError("Number of seeds and start_points must be %s." % chains)
raise ValueError(f"Number of seeds and start_points must be {chains}.")

Check warning on line 391 in pymc/sampling/parallel.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/parallel.py#L391

Added line #L391 was not covered by tests

if mp_ctx is None or isinstance(mp_ctx, str):
# Closes issue https://github.com/pymc-devs/pymc/issues/3849
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
def compute_state(self, q: RaveledVars, p: RaveledVars):
"""Compute Hamiltonian functions using a position and momentum."""
if q.data.dtype != self._dtype or p.data.dtype != self._dtype:
raise ValueError("Invalid dtype. Must be %s" % self._dtype)
raise ValueError(f"Invalid dtype. Must be {self._dtype}")

Check warning on line 54 in pymc/step_methods/hmc/integration.py

View check run for this annotation

Codecov / codecov/patch

pymc/step_methods/hmc/integration.py#L54

Added line #L54 was not covered by tests

logp, dlogp = self._logp_dlogp_func(q)

Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
elif S.ndim == 2:
self.proposal_dist = MultivariateNormalProposal(S)
else:
raise ValueError("Invalid rank for variance: %s" % S.ndim)
raise ValueError(f"Invalid rank for variance: {S.ndim}")

Check warning on line 181 in pymc/step_methods/metropolis.py

View check run for this annotation

Codecov / codecov/patch

pymc/step_methods/metropolis.py#L181

Added line #L181 was not covered by tests

self.scaling = np.atleast_1d(scaling).astype("d")
self.tune = tune
Expand Down
4 changes: 2 additions & 2 deletions pymc/variational/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def _maybe_score(self, score):
score = returns_loss
elif score and not returns_loss:
warnings.warn(
"method `fit` got `score == True` but %s "
"does not return loss. Ignoring `score` argument" % self.objective.op
f"method `fit` got `score == True` but {self.objective.op} "
"does not return loss. Ignoring `score` argument"
)
score = False
else:
Expand Down
12 changes: 6 additions & 6 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@
if fn_kwargs is None:
fn_kwargs = {}
if score and not self.op.returns_loss:
raise NotImplementedError("%s does not have loss" % self.op)
raise NotImplementedError(f"{self.op} does not have loss")
updates = self.updates(
obj_n_mc=obj_n_mc,
tf_n_mc=tf_n_mc,
Expand Down Expand Up @@ -416,7 +416,7 @@
if fn_kwargs is None:
fn_kwargs = {}
if not self.op.returns_loss:
raise NotImplementedError("%s does not have loss" % self.op)
raise NotImplementedError(f"{self.op} does not have loss")
if more_replacements is None:
more_replacements = {}
loss = self(sc_n_mc, more_replacements=more_replacements)
Expand Down Expand Up @@ -496,13 +496,13 @@
def __call__(self, f=None):
if self.has_test_function:
if f is None:
raise ParametrizationError("Operator %s requires TestFunction" % self)
raise ParametrizationError(f"Operator {self} requires TestFunction")

Check warning on line 499 in pymc/variational/opvi.py

View check run for this annotation

Codecov / codecov/patch

pymc/variational/opvi.py#L499

Added line #L499 was not covered by tests
else:
if not isinstance(f, TestFunction):
f = TestFunction.from_function(f)
else:
if f is not None:
warnings.warn("TestFunction for %s is redundant and removed" % self, stacklevel=3)
warnings.warn(f"TestFunction for {self} is redundant and removed", stacklevel=3)

Check warning on line 505 in pymc/variational/opvi.py

View check run for this annotation

Codecov / codecov/patch

pymc/variational/opvi.py#L505

Added line #L505 was not covered by tests
else:
pass
f = TestFunction()
Expand Down Expand Up @@ -555,7 +555,7 @@
@classmethod
def from_function(cls, f):
if not callable(f):
raise ParametrizationError("Need callable, got %r" % f)
raise ParametrizationError(f"Need callable, got {f!r}")

Check warning on line 558 in pymc/variational/opvi.py

View check run for this annotation

Codecov / codecov/patch

pymc/variational/opvi.py#L558

Added line #L558 was not covered by tests
obj = TestFunction()
obj.__call__ = f
return obj
Expand Down Expand Up @@ -1512,7 +1512,7 @@
found.name = name + "_vi_random_slice"
break
else:
raise KeyError("%r not found" % name)
raise KeyError(f"{name!r} not found")

Check warning on line 1515 in pymc/variational/opvi.py

View check run for this annotation

Codecov / codecov/patch

pymc/variational/opvi.py#L1515

Added line #L1515 was not covered by tests
return found

@node_property
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ addopts = ["--color=yes"]
[tool.ruff]
line-length = 100
target-version = "py310"
exclude = ["versioneer.py"]
extend-exclude = ["versioneer.py", "_version.py"]

[tool.ruff.lint]
select = ["D", "E", "F", "I", "UP", "W", "RUF"]
ignore-init-module-imports = true
ignore = [
"E501",
"F841", # Local variable name is assigned to but never used
Expand Down
2 changes: 1 addition & 1 deletion tests/variational/test_opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_logq_mini_2_sample_2_var(parametric_grouped_approxes, three_var_model):

def test_logq_globals(three_var_approx):
if not three_var_approx.has_logq:
pytest.skip("%s does not implement logq" % three_var_approx)
pytest.skip(f"{three_var_approx} does not implement logq")
approx = three_var_approx
logq, symbolic_logq = approx.set_size_and_deterministic(
[approx.logq, approx.symbolic_logq], 1, 0
Expand Down
Loading