Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion personal_python_ast_optimizer/parser/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from enum import Enum, EnumType
from types import EllipsisType

from personal_python_ast_optimizer.python_info import (
default_functions_safe_to_exclude_in_test_expr,
)


class TypeHintsToSkip(Enum):
NONE = 0
Expand Down Expand Up @@ -132,18 +136,20 @@ class OptimizationsConfig(_Config):
__slots__ = (
"vars_to_fold",
"enums_to_fold",
"functions_safe_to_exclude_in_test_expr",
"remove_unused_imports",
"fold_constants",
"assume_this_machine",
)

def __init__(
def __init__( # noqa: PLR0913
self,
vars_to_fold: dict[
str, str | bytes | bool | int | float | complex | None | EllipsisType
]
| None = None,
enums_to_fold: Iterable[EnumType] | None = None,
functions_safe_to_exclude_in_test_expr: set[str] | None = None,
fold_constants: bool = True,
remove_unused_imports: bool = True,
assume_this_machine: bool = False,
Expand All @@ -156,6 +162,10 @@ def __init__(
if enums_to_fold is None
else self._format_enums_to_fold_as_dict(enums_to_fold)
)
self.functions_safe_to_exclude_in_test_expr: set[str] = (
functions_safe_to_exclude_in_test_expr
or default_functions_safe_to_exclude_in_test_expr
)
self.remove_unused_imports: bool = remove_unused_imports
self.assume_this_machine: bool = assume_this_machine
self.fold_constants: bool = fold_constants
Expand Down
76 changes: 50 additions & 26 deletions personal_python_ast_optimizer/parser/skipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
is_return_none,
remove_duplicate_slots,
skip_base_classes,
skip_dangling_expressions,
skip_decorators,
)
from personal_python_ast_optimizer.python_info import (
default_functions_safe_to_exclude_in_test_expr,
)


class _NodeContext(Enum):
Expand Down Expand Up @@ -103,11 +105,18 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
for value in old_value:
if isinstance(value, ast.AST):
value = self.visit(value) # noqa: PLW2901
if value is None:

if value is None or (
self.token_types_config.skip_dangling_expressions
and isinstance(value, ast.Expr)
and isinstance(value.value, ast.Constant)
):
continue

if not isinstance(value, ast.AST):
new_values.extend(value)
continue

new_values.append(value)

if (
Expand Down Expand Up @@ -154,9 +163,6 @@ def _combine_imports(body: list) -> None:
body[:] = new_body

def visit_Module(self, node: ast.Module) -> ast.AST:
if self.token_types_config.skip_dangling_expressions:
skip_dangling_expressions(node)

self.generic_visit(node)

if self._simplified_named_tuple:
Expand Down Expand Up @@ -189,9 +195,6 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None:
if self._use_version_optimization((3, 0)):
skip_base_classes(node, ["object"])

if self.token_types_config.skip_dangling_expressions:
skip_dangling_expressions(node)

skip_base_classes(node, self.tokens_config.classes_to_skip)
skip_decorators(node, self.tokens_config.decorators_to_skip)

Expand Down Expand Up @@ -272,9 +275,6 @@ def _handle_function_node(
if self.token_types_config.skip_type_hints:
node.returns = None

if self.token_types_config.skip_dangling_expressions:
skip_dangling_expressions(node)

skip_decorators(node, self.tokens_config.decorators_to_skip)

if node.body:
Expand All @@ -298,26 +298,26 @@ def _should_skip_function(
def visit_Try(self, node: ast.Try) -> ast.AST | list[ast.stmt] | None:
parsed_node = self.generic_visit(node)

if isinstance(
parsed_node, (ast.Try, ast.TryStar)
) and self._is_useless_try_node(parsed_node):
if isinstance(parsed_node, (ast.Try, ast.TryStar)) and self._body_is_only_pass(
parsed_node.body
):
return parsed_node.finalbody or None

return parsed_node

def visit_TryStar(self, node: ast.TryStar) -> ast.AST | list[ast.stmt] | None:
parsed_node = self.generic_visit(node)

if isinstance(
parsed_node, (ast.Try, ast.TryStar)
) and self._is_useless_try_node(parsed_node):
if isinstance(parsed_node, (ast.Try, ast.TryStar)) and self._body_is_only_pass(
parsed_node.body
):
return parsed_node.finalbody or None

return parsed_node

@staticmethod
def _is_useless_try_node(node: ast.Try | ast.TryStar) -> bool:
return all(isinstance(n, ast.Pass) for n in node.body)
def _body_is_only_pass(node_body: list[ast.stmt]) -> bool:
return all(isinstance(n, ast.Pass) for n in node_body)

def visit_Attribute(self, node: ast.Attribute) -> ast.AST | None:
if isinstance(node.value, ast.Name):
Expand Down Expand Up @@ -497,13 +497,19 @@ def visit_Dict(self, node: ast.Dict) -> ast.AST:
def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt] | None:
parsed_node: ast.AST = self.generic_visit(node)

if isinstance(parsed_node, ast.If) and isinstance(
parsed_node.test, ast.Constant
):
if_body: list[ast.stmt] = (
parsed_node.body if parsed_node.test.value else parsed_node.orelse
)
return if_body or None
if isinstance(parsed_node, ast.If):
if isinstance(parsed_node.test, ast.Constant):
if_body: list[ast.stmt] = (
parsed_node.body if parsed_node.test.value else parsed_node.orelse
)
return if_body or None

if not parsed_node.orelse and self._body_is_only_pass(parsed_node.body):
call_finder = _DanglingExprCallFinder(
self.optimizations_config.functions_safe_to_exclude_in_test_expr
)
call_finder.visit(parsed_node.test)
return [ast.Expr(expr) for expr in call_finder.calls]

return parsed_node

Expand Down Expand Up @@ -832,3 +838,21 @@ def visit_Continue(self, node: ast.Continue) -> ast.Continue:

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node


class _DanglingExprCallFinder(ast.NodeTransformer):
"""Finds all calls in a given dangling expression
except for a subset of builtin functions that have
no side effects."""

__slots__ = ("calls", "excludes")

def __init__(self, excludes: set[str]) -> None:
self.calls: list[ast.Call] = []
self.excludes: set[str] = excludes

def visit_Call(self, node: ast.Call) -> ast.Call:
if get_node_name(node) not in default_functions_safe_to_exclude_in_test_expr:
self.calls.append(node)

return node
13 changes: 0 additions & 13 deletions personal_python_ast_optimizer/parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,6 @@ def is_return_none(node: ast.Return) -> bool:
return isinstance(node.value, ast.Constant) and node.value.value is None


def skip_dangling_expressions(
node: ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
) -> None:
"""Removes constant dangling expression like doc strings"""
node.body = [
element
for element in node.body
if not (
isinstance(element, ast.Expr) and isinstance(element.value, ast.Constant)
)
]


def skip_base_classes(
node: ast.ClassDef, classes_to_ignore: Iterable[str] | TokensToSkip
) -> None:
Expand Down
12 changes: 12 additions & 0 deletions personal_python_ast_optimizer/python_info.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
"""Various tokens in Python that the ast module writes"""

# Functions that have no side effects and thus are safe to remove
# if a test expression is found to be useless. For example:
# if "str(a) == 'a':pass" will be turned into just "str(a) == 'a'"
# but if its known str has no side effects then it can be fully removed
default_functions_safe_to_exclude_in_test_expr: set[str] = {
"int",
"str",
"isinstance",
"getattr",
"hasattr",
}

comparison_and_conjunctions: list[str] = [
" if ",
" else ",
Expand Down
31 changes: 27 additions & 4 deletions tests/parser/test_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
_if_cases = [
BeforeAndAfter(
"""
if a() == b:pass
if a() == b:eggs()
else:pass
""",
"if a()==b:pass",
"if a()==b:eggs()",
),
BeforeAndAfter(
"""
if a == b:pass
if a == b:eggs()
else:print()""",
"""
if a==b:pass
if a==b:eggs()
else:print()
""".strip(),
),
Expand Down Expand Up @@ -56,6 +56,29 @@
else:bar()""",
"foo()",
),
BeforeAndAfter(
"if test():pass\nelse:foo()",
"if test():pass\nelse:foo()",
),
BeforeAndAfter(
"if test():pass\nelse:pass",
"test()",
),
BeforeAndAfter(
"if str(a) == 'a':pass",
"",
),
BeforeAndAfter(
"if a < 3:pass",
"",
),
BeforeAndAfter(
"""
try:foo()
except:raise OSError
if test():pass""",
"try:foo()\nexcept:raise OSError\ntest()",
),
]


Expand Down
8 changes: 4 additions & 4 deletions tests/parser/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def test_one_line_if():
"""
'a' if 'True' == b else 'b'
'a' if b == 'True' else 'b'
'a' if 1==1 else 'b'
'a' if 1==2 else 'b'
a='a' if 1==1 else 'b'
b='a' if 1==2 else 'b'
""",
"""
'a'if'True'==b else'b'
'a'if b=='True'else'b'
'a'
'b'
a='a'
b='b'
""".strip(),
)
run_minifier_and_assert_correct(before_and_after)
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
6.1.0
6.1.1