Skip to content

Commit 8ae2a19

Browse files
committed
Remove dict.keys() when unnecessary
1 parent 63da6d1 commit 8ae2a19

File tree

28 files changed

+53
-64
lines changed

28 files changed

+53
-64
lines changed

pytensor/compile/function/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def checkSV(sv_ori, sv_rpl):
659659
exist_svs = [i.variable for i in maker.inputs]
660660

661661
# Check if given ShareVariables exist
662-
for sv in swap.keys():
662+
for sv in swap:
663663
if sv not in exist_svs:
664664
raise ValueError(f"SharedVariable: {sv.name} not found")
665665

@@ -711,9 +711,9 @@ def checkSV(sv_ori, sv_rpl):
711711
# it is well tested, we don't share the part of the storage_map.
712712
if share_memory:
713713
i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
714-
for key in storage_map.keys():
714+
for key, val in storage_map.items():
715715
if key not in i_o_vars:
716-
new_storage_map[memo[key]] = storage_map[key]
716+
new_storage_map[memo[key]] = val
717717

718718
if not name and self.name:
719719
name = self.name + " copy"
@@ -1446,7 +1446,7 @@ def prepare_fgraph(
14461446
if not hasattr(mode.linker, "accept"):
14471447
raise ValueError(
14481448
"'linker' parameter of FunctionMaker should be "
1449-
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}"
1449+
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers)}"
14501450
)
14511451

14521452
def __init__(

pytensor/compile/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1446,7 +1446,7 @@ def summary(self, file=sys.stderr, n_ops_to_print=20, n_apply_to_print=20):
14461446
file=file,
14471447
)
14481448
if config.profiling__debugprint:
1449-
fcts = {fgraph for (fgraph, n) in self.apply_time.keys()}
1449+
fcts = {fgraph for (fgraph, n) in self.apply_time}
14501450
pytensor.printing.debugprint(fcts, print_type=True)
14511451
if self.variable_shape or self.variable_strides:
14521452
self.summary_memory(file, n_apply_to_print)

pytensor/configdefaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars():
13181318
_compiledir_format_dict["short_platform"] = short_platform()
13191319
# Allow to have easily one compiledir per device.
13201320
_compiledir_format_dict["device"] = config.device
1321-
compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict.keys()))
1321+
compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict))
13221322
_default_compiledir_format = (
13231323
"compiledir_%(short_platform)s-%(processor)s-"
13241324
"%(python_version)s-%(python_bitwidth)s"

pytensor/configparser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def change_flags(self, *args, **kwargs) -> _ChangeFlagsDecorator:
214214
return _ChangeFlagsDecorator(*args, _root=self, **kwargs)
215215

216216
def warn_unused_flags(self):
217-
for key in self._flags_dict.keys():
217+
for key in self._flags_dict:
218218
warnings.warn(f"PyTensor does not recognise this flag: {key}")
219219

220220

pytensor/gradient.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def grad(
500500
if cost is not None:
501501
outputs.append(cost)
502502
if known_grads is not None:
503-
outputs.extend(list(known_grads.keys()))
503+
outputs.extend(list(known_grads))
504504

505505
var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant)
506506

@@ -966,7 +966,7 @@ def visit(var):
966966
visit(elem)
967967

968968
# Remove variables that don't have wrt as a true ancestor
969-
orig_vars = list(var_to_app_to_idx.keys())
969+
orig_vars = list(var_to_app_to_idx)
970970
for var in orig_vars:
971971
if var not in visited:
972972
del var_to_app_to_idx[var]

pytensor/graph/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
631631
if not hasattr(self, "_fn_cache"):
632632
self._fn_cache: dict = dict()
633633

634-
inputs = tuple(sorted(parsed_inputs_to_values.keys(), key=id))
634+
inputs = tuple(sorted(parsed_inputs_to_values, key=id))
635635
cache_key = (inputs, tuple(kwargs.items()))
636636
try:
637637
fn = self._fn_cache[cache_key]

pytensor/graph/destroyhandler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def recursive_destroys_finder(protected_var):
406406
# If True means that the apply node, destroys the protected_var.
407407
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
408408
return True
409-
for var_idx in app.op.view_map.keys():
409+
for var_idx in app.op.view_map:
410410
if idx in app.op.view_map[var_idx]:
411411
# We need to recursively check the destroy_map of all the
412412
# outputs that we have a view_map on.

pytensor/graph/rewriting/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from collections.abc import Iterable as IterableType
1616
from functools import _compose_mro, partial, reduce # type: ignore
1717
from itertools import chain
18-
from typing import TYPE_CHECKING, Literal, cast
18+
from typing import TYPE_CHECKING, Literal
1919

2020
import pytensor
2121
from pytensor.configdefaults import config
@@ -1924,9 +1924,9 @@ def process_node(
19241924
remove: list[Variable] = []
19251925
if isinstance(replacements, dict):
19261926
if "remove" in replacements:
1927-
remove = list(cast(Sequence[Variable], replacements.pop("remove")))
1928-
old_vars = list(cast(Sequence[Variable], replacements.keys()))
1929-
replacements = list(cast(Sequence[Variable], replacements.values()))
1927+
remove = list(replacements.pop("remove"))
1928+
old_vars = list(replacements)
1929+
replacements = list(replacements.values())
19301930
elif not isinstance(replacements, tuple | list):
19311931
raise TypeError(
19321932
f"Node rewriter {node_rewriter} gave wrong type of replacement. "

pytensor/graph/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class MissingInputError(Exception):
168168
def __init__(self, *args, **kwargs):
169169
if kwargs:
170170
# The call to list is needed for Python 3
171-
assert list(kwargs.keys()) == ["variable"]
171+
assert list(kwargs) == ["variable"]
172172
error_msg = get_variable_trace_string(kwargs["variable"])
173173
if error_msg:
174174
args = (*args, error_msg)

pytensor/link/c/params_type.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,7 @@ def __init__(self, params_type, **kwargs):
264264
def __repr__(self):
265265
return "Params({})".format(
266266
", ".join(
267-
[
268-
(f"{k}:{type(self[k]).__name__}:{self[k]}")
269-
for k in sorted(self.keys())
270-
]
267+
[(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
271268
)
272269
)
273270

@@ -365,7 +362,7 @@ def __init__(self, **kwargs):
365362
)
366363

367364
self.length = len(kwargs)
368-
self.fields = tuple(sorted(kwargs.keys()))
365+
self.fields = tuple(sorted(kwargs))
369366
self.types = tuple(kwargs[field] for field in self.fields)
370367
self.name = self.generate_struct_name()
371368

0 commit comments

Comments (0)