diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py
index 5f1b895d7..87eefd5d7 100644
--- a/codeflash/cli_cmds/cmd_init.py
+++ b/codeflash/cli_cmds/cmd_init.py
@@ -27,6 +27,9 @@
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.extension import install_vscode_extension
+# Import Java init module
+from codeflash.cli_cmds.init_java import init_java_project
+
# Import JS/TS init module
from codeflash.cli_cmds.init_javascript import (
ProjectLanguage,
@@ -35,9 +38,6 @@
get_js_dependency_installation_commands,
init_js_project,
)
-
-# Import Java init module
-from codeflash.cli_cmds.init_java import init_java_project
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
@@ -1674,9 +1674,7 @@ def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path,
# Install dependencies
install_deps_cmd = get_java_dependency_installation_commands(build_tool)
- optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
-
- return optimize_yml_content
+ return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
def get_formatter_cmds(formatter: str) -> list[str]:
diff --git a/codeflash/cli_cmds/init_java.py b/codeflash/cli_cmds/init_java.py
index 73822e626..5be5b19a9 100644
--- a/codeflash/cli_cmds/init_java.py
+++ b/codeflash/cli_cmds/init_java.py
@@ -165,9 +165,7 @@ def init_java_project() -> None:
lang_panel = Panel(
Text(
- "Java project detected!\n\nI'll help you set up Codeflash for your project.",
- style="cyan",
- justify="center",
+ "Java project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center"
),
title="Java Setup",
border_style="bright_red",
@@ -205,7 +203,9 @@ def init_java_project() -> None:
completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:"
if did_add_new_key:
- completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
+ completion_message += (
+ "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
+ )
if os.name == "nt":
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
else:
@@ -234,9 +234,7 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]:
codeflash_config_path = project_root / "codeflash.toml"
if codeflash_config_path.exists():
return Confirm.ask(
- "A Codeflash config already exists. Do you want to re-configure it?",
- default=False,
- show_default=True,
+ "A Codeflash config already exists. Do you want to re-configure it?", default=False, show_default=True
), None
return True, None
@@ -285,14 +283,10 @@ def collect_java_setup_info() -> JavaSetupInfo:
if Confirm.ask("Would you like to change any of these settings?", default=False):
# Source root override
- module_root_override = _prompt_directory_override(
- "source", detected_source_root, curdir
- )
+ module_root_override = _prompt_directory_override("source", detected_source_root, curdir)
# Test root override
- test_root_override = _prompt_directory_override(
- "test", detected_test_root, curdir
- )
+ test_root_override = _prompt_directory_override("test", detected_test_root, curdir)
# Formatter override
formatter_questions = [
@@ -300,7 +294,7 @@ def collect_java_setup_info() -> JavaSetupInfo:
"formatter",
message="Which code formatter do you use?",
choices=[
- (f"keep detected (google-java-format)", "keep"),
+ ("keep detected (google-java-format)", "keep"),
("google-java-format", "google-java-format"),
("spotless", "spotless"),
("other", "other"),
@@ -345,7 +339,7 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")]
subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)]
- options = [keep_detected_option] + subdirs[:5] + [custom_dir_option]
+ options = [keep_detected_option, *subdirs[:5], custom_dir_option]
questions = [
inquirer.List(
@@ -364,10 +358,9 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
answer = answers[f"{dir_type}_root"]
if answer == keep_detected_option:
return None
- elif answer == custom_dir_option:
+ if answer == custom_dir_option:
return _prompt_custom_directory(dir_type)
- else:
- return answer
+ return answer
def _prompt_custom_directory(dir_type: str) -> str:
@@ -441,7 +434,7 @@ def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[st
if formatter == "spotless":
if build_tool == JavaBuildTool.MAVEN:
return ["mvn spotless:apply -DspotlessFiles=$file"]
- elif build_tool == JavaBuildTool.GRADLE:
+ if build_tool == JavaBuildTool.GRADLE:
return ["./gradlew spotlessApply"]
return ["spotless $file"]
if formatter == "other":
diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py
index bb28fe66b..e75d4e125 100644
--- a/codeflash/code_utils/code_replacer.py
+++ b/codeflash/code_utils/code_replacer.py
@@ -711,18 +711,12 @@ def _add_java_class_members(
if not new_fields and not new_methods:
return original_source
- logger.debug(
- f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}"
- )
+ logger.debug(f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}")
# Import the insertion function from replacement module
from codeflash.languages.java.replacement import _insert_class_members
- result = _insert_class_members(
- original_source, class_name, new_fields, new_methods, analyzer
- )
-
- return result
+ return _insert_class_members(original_source, class_name, new_fields, new_methods, analyzer)
except Exception as e:
logger.debug(f"Error adding Java class members: {e}")
@@ -959,12 +953,14 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
for file_path_str, code in file_to_code_context.items():
if file_path_str:
# Extract filename without creating Path object repeatedly
- if file_path_str.endswith(target_filename) and (len(file_path_str) == len(target_filename) or file_path_str[-len(target_filename)-1] in ('/', '\\')):
+ if file_path_str.endswith(target_filename) and (
+ len(file_path_str) == len(target_filename)
+ or file_path_str[-len(target_filename) - 1] in ("/", "\\")
+ ):
module_optimized_code = code
logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
break
-
if module_optimized_code is None:
# Also try matching if there's only one code file, but ONLY for non-Python
# languages where path matching is less strict. For Python, we require
diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py
index 76cb041a1..a0f212e8d 100644
--- a/codeflash/code_utils/instrument_existing_tests.py
+++ b/codeflash/code_utils/instrument_existing_tests.py
@@ -721,9 +721,7 @@ def inject_profiling_into_existing_test(
if is_java():
from codeflash.languages.java.instrumentation import instrument_existing_test
- return instrument_existing_test(
- test_path, call_positions, function_to_optimize, tests_project_root, mode.value
- )
+ return instrument_existing_test(test_path, call_positions, function_to_optimize, tests_project_root, mode.value)
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py
index ffbd9d97f..416849243 100644
--- a/codeflash/languages/__init__.py
+++ b/codeflash/languages/__init__.py
@@ -36,15 +36,15 @@
reset_current_language,
set_current_language,
)
+
+# Java language support
+# Importing the module triggers registration via @register_language decorator
+from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
-
-# Java language support
-# Importing the module triggers registration via @register_language decorator
-from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,
diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py
index 200555488..5fb962db6 100644
--- a/codeflash/languages/java/build_tools.py
+++ b/codeflash/languages/java/build_tools.py
@@ -13,7 +13,10 @@
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
-from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from pathlib import Path
logger = logging.getLogger(__name__)
@@ -29,6 +32,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
Raises:
ET.ParseError: If XML parsing fails.
+
"""
# Read file content and parse as string to avoid file-based attacks
# This prevents XXE attacks by not allowing external entity resolution
@@ -38,9 +42,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
root = ET.fromstring(content)
# Create ElementTree from root
- tree = ET.ElementTree(root)
-
- return tree
+ return ET.ElementTree(root)
class BuildTool(Enum):
@@ -390,13 +392,7 @@ def run_maven_tests(
try:
result = subprocess.run(
- cmd,
- check=False,
- cwd=project_root,
- env=run_env,
- capture_output=True,
- text=True,
- timeout=timeout,
+ cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)
# Parse test results from Surefire reports
@@ -416,7 +412,7 @@ def run_maven_tests(
)
except subprocess.TimeoutExpired:
- logger.error("Maven test execution timed out after %d seconds", timeout)
+ logger.exception("Maven test execution timed out after %d seconds", timeout)
return MavenTestResult(
success=False,
tests_run=0,
@@ -496,10 +492,7 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
def compile_maven_project(
- project_root: Path,
- include_tests: bool = True,
- env: dict[str, str] | None = None,
- timeout: int = 300,
+ project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300
) -> tuple[bool, str, str]:
"""Compile a Maven project.
@@ -533,13 +526,7 @@ def compile_maven_project(
try:
result = subprocess.run(
- cmd,
- check=False,
- cwd=project_root,
- env=run_env,
- capture_output=True,
- text=True,
- timeout=timeout,
+ cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)
return result.returncode == 0, result.stdout, result.stderr
@@ -581,14 +568,7 @@ def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> boo
]
try:
- result = subprocess.run(
- cmd,
- check=False,
- cwd=project_root,
- capture_output=True,
- text=True,
- timeout=60,
- )
+ result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)
if result.returncode == 0:
logger.info("Successfully installed codeflash-runtime to local Maven repository")
@@ -664,7 +644,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
return True
except ET.ParseError as e:
- logger.error("Failed to parse pom.xml: %s", e)
+ logger.exception("Failed to parse pom.xml: %s", e)
return False
except Exception as e:
logger.exception("Failed to add dependency to pom.xml: %s", e)
@@ -751,11 +731,11 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
# JaCoCo plugin XML to insert (indented for typical pom.xml format)
# Note: For multi-module projects where tests are in a separate module,
# we configure the report to look in multiple directories for classes
- jacoco_plugin = """
+ jacoco_plugin = f"""
org.jacoco
jacoco-maven-plugin
- {version}
+ {JACOCO_PLUGIN_VERSION}
prepare-agent
@@ -777,7 +757,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
- """.format(version=JACOCO_PLUGIN_VERSION)
+ """
# Find the main section (not inside )
# We need to find a that appears after or before
@@ -786,7 +766,6 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
profiles_end = content.find("")
# Find all tags
- import re
# Find the main build section - it's the one NOT inside profiles
# Strategy: Look for that comes after or before (or no profiles)
@@ -816,7 +795,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
if build_start != -1 and build_end != -1:
# Found main build section, find plugins within it
- build_section = content[build_start:build_end + len("")]
+ build_section = content[build_start : build_end + len("")]
plugins_start_in_build = build_section.find("")
plugins_end_in_build = build_section.rfind("")
diff --git a/codeflash/languages/java/comparator.py b/codeflash/languages/java/comparator.py
index c30bd2446..75fa7f51f 100644
--- a/codeflash/languages/java/comparator.py
+++ b/codeflash/languages/java/comparator.py
@@ -47,7 +47,16 @@ def _find_comparator_jar(project_root: Path | None = None) -> Path | None:
return jar_path
# Check local Maven repository
- m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / "codeflash-runtime-1.0.0.jar"
+ m2_jar = (
+ Path.home()
+ / ".m2"
+ / "repository"
+ / "com"
+ / "codeflash"
+ / "codeflash-runtime"
+ / "1.0.0"
+ / "codeflash-runtime-1.0.0.jar"
+ )
if m2_jar.exists():
return m2_jar
@@ -113,8 +122,7 @@ def compare_test_results(
jar_path = comparator_jar or _find_comparator_jar(project_root)
if not jar_path or not jar_path.exists():
logger.error(
- "codeflash-runtime JAR not found. "
- "Please ensure the codeflash-runtime is installed in your project."
+ "codeflash-runtime JAR not found. Please ensure the codeflash-runtime is installed in your project."
)
return False, []
@@ -155,10 +163,10 @@ def compare_test_results(
comparison = json.loads(result.stdout)
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse Java comparator output: {e}")
- logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
+ logger.exception(f"Failed to parse Java comparator output: {e}")
+ logger.exception(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
if result.stderr:
- logger.error(f"stderr: {result.stderr[:500]}")
+ logger.exception(f"stderr: {result.stderr[:500]}")
return False, []
# Check for errors in the JSON response
@@ -178,9 +186,7 @@ def compare_test_results(
for diff in comparison.get("diffs", []):
scope_str = diff.get("scope", "return_value")
scope = TestDiffScope.RETURN_VALUE
- if scope_str == "exception":
- scope = TestDiffScope.DID_PASS
- elif scope_str == "missing":
+ if scope_str in {"exception", "missing"}:
scope = TestDiffScope.DID_PASS
# Build test identifier
@@ -220,20 +226,17 @@ def compare_test_results(
return equivalent, test_diffs
except subprocess.TimeoutExpired:
- logger.error("Java comparator timed out")
+ logger.exception("Java comparator timed out")
return False, []
except FileNotFoundError:
- logger.error("Java not found. Please install Java to compare test results.")
+ logger.exception("Java not found. Please install Java to compare test results.")
return False, []
except Exception as e:
- logger.error(f"Error running Java comparator: {e}")
+ logger.exception(f"Error running Java comparator: {e}")
return False, []
-def compare_invocations_directly(
- original_results: dict,
- candidate_results: dict,
-) -> tuple[bool, list]:
+def compare_invocations_directly(original_results: dict, candidate_results: dict) -> tuple[bool, list]:
"""Compare test invocations directly from Python dictionaries.
This is a fallback when the Java comparator is not available.
diff --git a/codeflash/languages/java/config.py b/codeflash/languages/java/config.py
index 4d99c6b10..408dcecaf 100644
--- a/codeflash/languages/java/config.py
+++ b/codeflash/languages/java/config.py
@@ -10,7 +10,6 @@
import logging
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
-from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.java.build_tools import (
@@ -22,7 +21,7 @@
)
if TYPE_CHECKING:
- pass
+ from pathlib import Path
logger = logging.getLogger(__name__)
@@ -80,9 +79,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
project_info = get_project_info(project_root)
# Detect test framework
- test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(
- project_root, build_tool
- )
+ test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(project_root, build_tool)
# Detect other dependencies
has_mockito, has_assertj = _detect_test_dependencies(project_root, build_tool)
@@ -120,9 +117,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
)
-def _detect_test_framework(
- project_root: Path, build_tool: BuildTool
-) -> tuple[str, bool, bool, bool]:
+def _detect_test_framework(project_root: Path, build_tool: BuildTool) -> tuple[str, bool, bool, bool]:
"""Detect which test framework the project uses.
Args:
@@ -210,9 +205,7 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
elif tag == "groupId":
group_id = child.text
- if group_id == "org.junit.jupiter" or (
- artifact_id and "junit-jupiter" in artifact_id
- ):
+ if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id):
has_junit5 = True
elif group_id == "junit" and artifact_id == "junit":
has_junit4 = True
@@ -253,9 +246,7 @@ def _detect_test_deps_from_gradle(project_root: Path) -> tuple[bool, bool, bool]
return has_junit5, has_junit4, has_testng
-def _detect_test_dependencies(
- project_root: Path, build_tool: BuildTool
-) -> tuple[bool, bool]:
+def _detect_test_dependencies(project_root: Path, build_tool: BuildTool) -> tuple[bool, bool]:
"""Detect additional test dependencies (Mockito, AssertJ).
Returns:
@@ -289,9 +280,7 @@ def _detect_test_dependencies(
return has_mockito, has_assertj
-def _get_compiler_settings(
- project_root: Path, build_tool: BuildTool
-) -> tuple[str | None, str | None]:
+def _get_compiler_settings(project_root: Path, build_tool: BuildTool) -> tuple[str | None, str | None]:
"""Get compiler source and target settings.
Returns:
@@ -392,11 +381,7 @@ def is_java_project(project_root: Path) -> bool:
return True
# Check for Java source files
- for pattern in ["src/**/*.java", "*.java"]:
- if list(project_root.glob(pattern)):
- return True
-
- return False
+ return any(list(project_root.glob(pattern)) for pattern in ["src/**/*.java", "*.java"])
def get_test_file_pattern(config: JavaProjectConfig) -> str:
diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py
index 2ccfd34bf..a2c7f7c0e 100644
--- a/codeflash/languages/java/context.py
+++ b/codeflash/languages/java/context.py
@@ -8,26 +8,27 @@
from __future__ import annotations
import logging
-from pathlib import Path
from typing import TYPE_CHECKING
-from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, HelperFunction, Language
from codeflash.languages.java.discovery import discover_functions_from_source
-from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
-from codeflash.languages.java.parser import JavaAnalyzer, JavaClassNode, get_java_analyzer
+from codeflash.languages.java.import_resolver import find_helper_files
+from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
+ from pathlib import Path
+
from tree_sitter import Node
+ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
+ from codeflash.languages.java.parser import JavaAnalyzer
+
logger = logging.getLogger(__name__)
class InvalidJavaSyntaxError(Exception):
"""Raised when extracted Java code is not syntactically valid."""
- pass
-
def extract_code_context(
function: FunctionToOptimize,
@@ -67,12 +68,8 @@ def extract_code_context(
try:
source = function.file_path.read_text(encoding="utf-8")
except Exception as e:
- logger.error("Failed to read %s: %s", function.file_path, e)
- return CodeContext(
- target_code="",
- target_file=function.file_path,
- language=Language.JAVA,
- )
+ logger.exception("Failed to read %s: %s", function.file_path, e)
+ return CodeContext(target_code="", target_file=function.file_path, language=Language.JAVA)
# Extract target function code
target_code = extract_function_source(source, function)
@@ -94,9 +91,7 @@ def extract_code_context(
import_statements = [_import_to_statement(imp) for imp in imports]
# Extract helper functions
- helper_functions = find_helper_functions(
- function, project_root, max_helper_depth, analyzer
- )
+ helper_functions = find_helper_functions(function, project_root, max_helper_depth, analyzer)
# Extract read-only context only if fields are NOT already in the skeleton
# Avoid duplication between target_code and read_only_context
@@ -107,9 +102,8 @@ def extract_code_context(
# Validate syntax - extracted code must always be valid Java
if validate_syntax and target_code:
if not analyzer.validate_syntax(target_code):
- raise InvalidJavaSyntaxError(
- f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
- )
+ msg = f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
+ raise InvalidJavaSyntaxError(msg)
return CodeContext(
target_code=target_code,
@@ -156,7 +150,7 @@ def __init__(
enum_constants: str,
type_indent: str,
type_kind: str, # "class", "interface", or "enum"
- outer_type_skeleton: "TypeSkeleton | None" = None,
+ outer_type_skeleton: TypeSkeleton | None = None,
) -> None:
self.type_declaration = type_declaration
self.type_javadoc = type_javadoc
@@ -173,10 +167,7 @@ def __init__(
def _extract_type_skeleton(
- source: str,
- type_name: str,
- target_method_name: str,
- analyzer: JavaAnalyzer,
+ source: str, type_name: str, target_method_name: str, analyzer: JavaAnalyzer
) -> TypeSkeleton | None:
"""Extract the type skeleton (class, interface, or enum) for wrapping a method.
@@ -254,11 +245,7 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No
Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum".
"""
- type_declarations = {
- "class_declaration": "class",
- "interface_declaration": "interface",
- "enum_declaration": "enum",
- }
+ type_declarations = {"class_declaration": "class", "interface_declaration": "interface", "enum_declaration": "enum"}
if node.type in type_declarations:
name_node = node.child_by_field_name("name")
@@ -283,11 +270,7 @@ def _find_class_node(node: Node, class_name: str, source_bytes: bytes) -> Node |
def _get_outer_type_skeleton(
- inner_type_node: Node,
- source_bytes: bytes,
- lines: list[str],
- target_method_name: str,
- analyzer: JavaAnalyzer,
+ inner_type_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, analyzer: JavaAnalyzer
) -> TypeSkeleton | None:
"""Get the outer type skeleton if this is an inner type.
@@ -356,11 +339,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
parts: list[str] = []
# Determine which body node type to look for
- body_types = {
- "class": "class_body",
- "interface": "interface_body",
- "enum": "enum_body",
- }
+ body_types = {"class": "class_body", "interface": "interface_body", "enum": "enum_body"}
body_type = body_types.get(type_kind, "class_body")
for child in type_node.children:
@@ -374,7 +353,8 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
# Keep old function name for backwards compatibility
-_extract_class_declaration = lambda node, source_bytes: _extract_type_declaration(node, source_bytes, "class")
+def _extract_class_declaration(node, source_bytes):
+ return _extract_type_declaration(node, source_bytes, "class")
def _find_javadoc(node: Node, source_bytes: bytes) -> str | None:
@@ -390,11 +370,7 @@ def _find_javadoc(node: Node, source_bytes: bytes) -> str | None:
def _extract_type_body_context(
- body_node: Node,
- source_bytes: bytes,
- lines: list[str],
- target_method_name: str,
- type_kind: str,
+ body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, type_kind: str
) -> tuple[str, str, str]:
"""Extract fields, constructors, and enum constants from a type body.
@@ -473,15 +449,10 @@ def _extract_type_body_context(
# Keep old function name for backwards compatibility
def _extract_class_body_context(
- body_node: Node,
- source_bytes: bytes,
- lines: list[str],
- target_method_name: str,
+ body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str
) -> tuple[str, str]:
"""Extract fields and constructors from a class body."""
- fields, constructors, _ = _extract_type_body_context(
- body_node, source_bytes, lines, target_method_name, "class"
- )
+ fields, constructors, _ = _extract_type_body_context(body_node, source_bytes, lines, target_method_name, "class")
return (fields, constructors)
@@ -584,10 +555,7 @@ def extract_function_source(source: str, function: FunctionToOptimize) -> str:
def find_helper_functions(
- function: FunctionToOptimize,
- project_root: Path,
- max_depth: int = 2,
- analyzer: JavaAnalyzer | None = None,
+ function: FunctionToOptimize, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None
) -> list[HelperFunction]:
"""Find helper functions that the target function depends on.
@@ -606,11 +574,9 @@ def find_helper_functions(
visited_functions: set[str] = set()
# Find helper files through imports
- helper_files = find_helper_files(
- function.file_path, project_root, max_depth, analyzer
- )
+ helper_files = find_helper_files(function.file_path, project_root, max_depth, analyzer)
- for file_path, class_names in helper_files.items():
+ for file_path in helper_files:
try:
source = file_path.read_text(encoding="utf-8")
file_functions = discover_functions_from_source(source, file_path, analyzer=analyzer)
@@ -648,10 +614,7 @@ def find_helper_functions(
return helpers
-def _find_same_class_helpers(
- function: FunctionToOptimize,
- analyzer: JavaAnalyzer,
-) -> list[HelperFunction]:
+def _find_same_class_helpers(function: FunctionToOptimize, analyzer: JavaAnalyzer) -> list[HelperFunction]:
"""Find helper methods in the same class as the target function.
Args:
@@ -694,9 +657,7 @@ def _find_same_class_helpers(
and method.class_name == function.class_name
and method.name in called_methods
):
- func_source = source_bytes[
- method.node.start_byte : method.node.end_byte
- ].decode("utf8")
+ func_source = source_bytes[method.node.start_byte : method.node.end_byte].decode("utf8")
helpers.append(
HelperFunction(
@@ -715,11 +676,7 @@ def _find_same_class_helpers(
return helpers
-def extract_read_only_context(
- source: str,
- function: FunctionToOptimize,
- analyzer: JavaAnalyzer,
-) -> str:
+def extract_read_only_context(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer) -> str:
"""Extract read-only context (fields, constants, inner classes).
This extracts class-level context that the function might depend on
@@ -767,11 +724,7 @@ def _import_to_statement(import_info) -> str:
return f"{prefix}{import_info.import_path}{suffix};"
-def extract_class_context(
- file_path: Path,
- class_name: str,
- analyzer: JavaAnalyzer | None = None,
-) -> str:
+def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None) -> str:
"""Extract the full context of a class.
Args:
@@ -813,5 +766,5 @@ def extract_class_context(
return package_stmt + "\n".join(import_statements) + "\n\n" + class_source
except Exception as e:
- logger.error("Failed to extract class context: %s", e)
+ logger.exception("Failed to extract class context: %s", e)
return ""
diff --git a/codeflash/languages/java/discovery.py b/codeflash/languages/java/discovery.py
index 902feca67..2d8f0b3ea 100644
--- a/codeflash/languages/java/discovery.py
+++ b/codeflash/languages/java/discovery.py
@@ -12,19 +12,17 @@
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import FunctionFilterCriteria
-from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
+from codeflash.languages.java.parser import get_java_analyzer
from codeflash.models.function_types import FunctionParent
if TYPE_CHECKING:
- pass
+ from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
logger = logging.getLogger(__name__)
def discover_functions(
- file_path: Path,
- filter_criteria: FunctionFilterCriteria | None = None,
- analyzer: JavaAnalyzer | None = None,
+ file_path: Path, filter_criteria: FunctionFilterCriteria | None = None, analyzer: JavaAnalyzer | None = None
) -> list[FunctionToOptimize]:
"""Find all optimizable functions/methods in a Java file.
@@ -115,10 +113,7 @@ def discover_functions_from_source(
def _should_include_method(
- method: JavaMethodNode,
- criteria: FunctionFilterCriteria,
- source: str,
- analyzer: JavaAnalyzer,
+ method: JavaMethodNode, criteria: FunctionFilterCriteria, source: str, analyzer: JavaAnalyzer
) -> bool:
"""Check if a method should be included based on filter criteria.
@@ -176,10 +171,7 @@ def _should_include_method(
return True
-def discover_test_methods(
- file_path: Path,
- analyzer: JavaAnalyzer | None = None,
-) -> list[FunctionToOptimize]:
+def discover_test_methods(file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]:
"""Find all JUnit test methods in a Java test file.
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
@@ -232,7 +224,7 @@ def _walk_tree_for_test_methods(
for child in node.children:
if child.type == "modifiers":
for mod_child in child.children:
- if mod_child.type == "marker_annotation" or mod_child.type == "annotation":
+ if mod_child.type in {"marker_annotation", "annotation"}:
annotation_text = analyzer.get_node_text(mod_child, source_bytes)
# Check for JUnit 5 test annotations
if any(
@@ -278,10 +270,7 @@ def _walk_tree_for_test_methods(
def get_method_by_name(
- file_path: Path,
- method_name: str,
- class_name: str | None = None,
- analyzer: JavaAnalyzer | None = None,
+ file_path: Path, method_name: str, class_name: str | None = None, analyzer: JavaAnalyzer | None = None
) -> FunctionToOptimize | None:
"""Find a specific method by name in a Java file.
@@ -306,9 +295,7 @@ def get_method_by_name(
def get_class_methods(
- file_path: Path,
- class_name: str,
- analyzer: JavaAnalyzer | None = None,
+ file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None
) -> list[FunctionToOptimize]:
"""Get all methods in a specific class.
diff --git a/codeflash/languages/java/formatter.py b/codeflash/languages/java/formatter.py
index a9ccd2d8d..2bb228ca2 100644
--- a/codeflash/languages/java/formatter.py
+++ b/codeflash/languages/java/formatter.py
@@ -6,16 +6,13 @@
from __future__ import annotations
+import contextlib
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- pass
logger = logging.getLogger(__name__)
@@ -29,7 +26,7 @@ class JavaFormatter:
# Version of google-java-format to use
GOOGLE_JAVA_FORMAT_VERSION = "1.19.2"
- def __init__(self, project_root: Path | None = None):
+ def __init__(self, project_root: Path | None = None) -> None:
"""Initialize the Java formatter.
Args:
@@ -107,21 +104,13 @@ def _format_with_google_java_format(self, source: str) -> str | None:
try:
# Write source to temp file
- with tempfile.NamedTemporaryFile(
- mode="w", suffix=".java", delete=False, encoding="utf-8"
- ) as tmp:
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".java", delete=False, encoding="utf-8") as tmp:
tmp.write(source)
tmp_path = tmp.name
try:
result = subprocess.run(
- [
- self._java_executable,
- "-jar",
- str(jar_path),
- "--replace",
- tmp_path,
- ],
+ [self._java_executable, "-jar", str(jar_path), "--replace", tmp_path],
check=False,
capture_output=True,
text=True,
@@ -133,16 +122,12 @@ def _format_with_google_java_format(self, source: str) -> str | None:
with open(tmp_path, encoding="utf-8") as f:
return f.read()
else:
- logger.debug(
- "google-java-format failed: %s", result.stderr or result.stdout
- )
+ logger.debug("google-java-format failed: %s", result.stderr or result.stdout)
finally:
# Clean up temp file
- try:
+ with contextlib.suppress(OSError):
os.unlink(tmp_path)
- except OSError:
- pass
except subprocess.TimeoutExpired:
logger.warning("google-java-format timed out")
@@ -169,9 +154,7 @@ def _get_google_java_format_jar(self) -> Path | None:
if self.project_root
else None,
# In user's home directory
- Path.home()
- / ".codeflash"
- / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
+ Path.home() / ".codeflash" / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
# In system temp
Path(tempfile.gettempdir())
/ "codeflash"
@@ -186,8 +169,7 @@ def _get_google_java_format_jar(self) -> Path | None:
# Don't auto-download to avoid surprises
# Users can manually download the JAR
logger.debug(
- "google-java-format JAR not found. "
- "Download from https://github.com/google/google-java-format/releases"
+ "google-java-format JAR not found. Download from https://github.com/google/google-java-format/releases"
)
return None
@@ -239,7 +221,7 @@ def download_google_java_format(self, target_dir: Path | None = None) -> Path |
logger.info("Downloaded google-java-format to %s", jar_path)
return jar_path
except Exception as e:
- logger.error("Failed to download google-java-format: %s", e)
+ logger.exception("Failed to download google-java-format: %s", e)
return None
diff --git a/codeflash/languages/java/import_resolver.py b/codeflash/languages/java/import_resolver.py
index 5ab8800ed..766434a94 100644
--- a/codeflash/languages/java/import_resolver.py
+++ b/codeflash/languages/java/import_resolver.py
@@ -8,14 +8,15 @@
import logging
from dataclasses import dataclass
-from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.java.build_tools import find_source_root, find_test_root, get_project_info
-from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, get_java_analyzer
+from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
- pass
+ from pathlib import Path
+
+ from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo
logger = logging.getLogger(__name__)
@@ -35,18 +36,7 @@ class JavaImportResolver:
"""Resolves Java imports to file paths within a project."""
# Standard Java packages that are always external
- STANDARD_PACKAGES = frozenset(
- [
- "java",
- "javax",
- "sun",
- "com.sun",
- "jdk",
- "org.w3c",
- "org.xml",
- "org.ietf",
- ]
- )
+ STANDARD_PACKAGES = frozenset(["java", "javax", "sun", "com.sun", "jdk", "org.w3c", "org.xml", "org.ietf"])
# Common third-party package prefixes
COMMON_EXTERNAL_PREFIXES = frozenset(
@@ -66,7 +56,7 @@ class JavaImportResolver:
]
)
- def __init__(self, project_root: Path):
+ def __init__(self, project_root: Path) -> None:
"""Initialize the import resolver.
Args:
@@ -156,10 +146,7 @@ def resolve_imports(self, imports: list[JavaImportInfo]) -> list[ResolvedImport]
def _is_standard_library(self, import_path: str) -> bool:
"""Check if an import is from the Java standard library."""
- for prefix in self.STANDARD_PACKAGES:
- if import_path.startswith(prefix + ".") or import_path == prefix:
- return True
- return False
+ return any(import_path.startswith(prefix + ".") or import_path == prefix for prefix in self.STANDARD_PACKAGES)
def _is_external_library(self, import_path: str) -> bool:
"""Check if an import is from a known external library."""
@@ -249,9 +236,7 @@ def find_class_file(self, class_name: str, package_hint: str | None = None) -> P
return None
- def get_imports_from_file(
- self, file_path: Path, analyzer: JavaAnalyzer | None = None
- ) -> list[ResolvedImport]:
+ def get_imports_from_file(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]:
"""Get and resolve all imports from a Java file.
Args:
@@ -272,9 +257,7 @@ def get_imports_from_file(
logger.warning("Failed to get imports from %s: %s", file_path, e)
return []
- def get_project_imports(
- self, file_path: Path, analyzer: JavaAnalyzer | None = None
- ) -> list[ResolvedImport]:
+ def get_project_imports(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]:
"""Get only the imports that resolve to files within the project.
Args:
@@ -308,10 +291,7 @@ def resolve_imports_for_file(
def find_helper_files(
- file_path: Path,
- project_root: Path,
- max_depth: int = 2,
- analyzer: JavaAnalyzer | None = None,
+ file_path: Path, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None
) -> dict[Path, list[str]]:
"""Find helper files imported by a Java file, recursively.
diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py
index b42ea7d94..78cd77d3a 100644
--- a/codeflash/languages/java/instrumentation.py
+++ b/codeflash/languages/java/instrumentation.py
@@ -17,16 +17,16 @@
import logging
import re
from functools import lru_cache
-from pathlib import Path
from typing import TYPE_CHECKING
-from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.languages.java.parser import JavaAnalyzer
-
if TYPE_CHECKING:
from collections.abc import Sequence
+ from pathlib import Path
from typing import Any
+ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
+ from codeflash.languages.java.parser import JavaAnalyzer
+
logger = logging.getLogger(__name__)
@@ -36,7 +36,8 @@ def _get_function_name(func: Any) -> str:
return func.function_name
if hasattr(func, "name"):
return func.name
- raise AttributeError(f"Cannot get function name from {type(func)}")
+ msg = f"Cannot get function name from {type(func)}"
+ raise AttributeError(msg)
def _get_qualified_name(func: Any) -> str:
@@ -135,7 +136,7 @@ def instrument_existing_test(
try:
source = test_path.read_text(encoding="utf-8")
except Exception as e:
- logger.error("Failed to read test file %s: %s", test_path, e)
+ logger.exception("Failed to read test file %s: %s", test_path, e)
return False, f"Failed to read test file: {e}"
func_name = _get_function_name(function_to_optimize)
@@ -227,7 +228,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
result.append(imp)
imports_added = True
continue
- if stripped.startswith("public class") or stripped.startswith("class"):
+ if stripped.startswith(("public class", "class")):
# No imports found, add before class
for imp in import_statements:
result.append(imp)
@@ -244,7 +245,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
i = 0
iteration_counter = 0
-
# Pre-compile the regex pattern once
method_call_pattern = _get_method_call_pattern(func_name)
@@ -291,11 +291,10 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
while i < len(lines) and brace_depth > 0:
body_line = lines[i]
# Count braces more efficiently using string methods
- open_count = body_line.count('{')
- close_count = body_line.count('}')
+ open_count = body_line.count("{")
+ close_count = body_line.count("}")
brace_depth += open_count - close_count
-
if brace_depth > 0:
body_lines.append(body_line)
i += 1
@@ -581,7 +580,7 @@ def create_benchmark_test(
method_id = _get_qualified_name(target_function)
class_name = getattr(target_function, "class_name", None) or "Target"
- benchmark_code = f"""
+ return f"""
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
@@ -615,7 +614,6 @@ def create_benchmark_test(
}}
}}
"""
- return benchmark_code
def remove_instrumentation(source: str) -> str:
@@ -713,7 +711,7 @@ def _add_import(source: str, import_statement: str) -> str:
# Find the last import or package statement
for i, line in enumerate(lines):
stripped = line.strip()
- if stripped.startswith("import ") or stripped.startswith("package "):
+ if stripped.startswith(("import ", "package ")):
insert_idx = i + 1
elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
# First non-import, non-comment line
@@ -725,13 +723,11 @@ def _add_import(source: str, import_statement: str) -> str:
return "".join(lines)
-
@lru_cache(maxsize=128)
def _get_method_call_pattern(func_name: str):
"""Cache compiled regex patterns for method call matching."""
return re.compile(
- rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
- re.MULTILINE
+ rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)
@@ -739,6 +735,5 @@ def _get_method_call_pattern(func_name: str):
def _get_method_call_pattern(func_name: str):
"""Cache compiled regex patterns for method call matching."""
return re.compile(
- rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
- re.MULTILINE
+ rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)
diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py
index bdffac44e..72a530179 100644
--- a/codeflash/languages/java/parser.py
+++ b/codeflash/languages/java/parser.py
@@ -13,8 +13,6 @@
from tree_sitter import Language, Parser
if TYPE_CHECKING:
- from pathlib import Path
-
from tree_sitter import Node, Tree
logger = logging.getLogger(__name__)
@@ -222,9 +220,7 @@ def _walk_tree_for_methods(
current_class=new_class if node.type in type_declarations else current_class,
)
- def _extract_method_info(
- self, node: Node, source_bytes: bytes, current_class: str | None
- ) -> JavaMethodNode | None:
+ def _extract_method_info(self, node: Node, source_bytes: bytes, current_class: str | None) -> JavaMethodNode | None:
"""Extract method information from a method_declaration node."""
name = ""
is_static = False
@@ -347,9 +343,7 @@ def _walk_tree_for_classes(
for child in node.children:
self._walk_tree_for_classes(child, source_bytes, classes, is_inner)
- def _extract_class_info(
- self, node: Node, source_bytes: bytes, is_inner: bool
- ) -> JavaClassNode | None:
+ def _extract_class_info(self, node: Node, source_bytes: bytes, is_inner: bool) -> JavaClassNode | None:
"""Extract class information from a class_declaration node."""
name = ""
is_public = False
diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py
index 75a9a78e7..92ddd44e2 100644
--- a/codeflash/languages/java/replacement.py
+++ b/codeflash/languages/java/replacement.py
@@ -18,10 +18,10 @@
from typing import TYPE_CHECKING
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
+from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
- pass
+ from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
@@ -35,11 +35,7 @@ class ParsedOptimization:
new_helper_methods: list[str] # Source text of new helper methods to add
-def _parse_optimization_source(
- new_source: str,
- target_method_name: str,
- analyzer: JavaAnalyzer,
-) -> ParsedOptimization:
+def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
"""Parse optimization source to extract method and additional class members.
The new_source may contain:
@@ -96,18 +92,12 @@ def _parse_optimization_source(
new_fields.append(field.source_text)
return ParsedOptimization(
- target_method_source=target_method_source,
- new_fields=new_fields,
- new_helper_methods=new_helper_methods,
+ target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
)
def _insert_class_members(
- source: str,
- class_name: str,
- fields: list[str],
- methods: list[str],
- analyzer: JavaAnalyzer,
+ source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
) -> str:
"""Insert new class members (fields and methods) into a class.
@@ -212,10 +202,7 @@ def _insert_class_members(
def replace_function(
- source: str,
- function: FunctionToOptimize,
- new_source: str,
- analyzer: JavaAnalyzer | None = None,
+ source: str, function: FunctionToOptimize, new_source: str, analyzer: JavaAnalyzer | None = None
) -> str:
"""Replace a function in source code with new implementation.
@@ -257,9 +244,9 @@ def replace_function(
# Find all methods matching the name (there may be overloads)
matching_methods = [
- m for m in methods
- if m.name == func_name
- and (function.class_name is None or m.class_name == function.class_name)
+ m
+ for m in methods
+ if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
]
if len(matching_methods) == 1:
@@ -296,10 +283,7 @@ def replace_function(
break
if not target_method:
# Fallback: use the first match
- logger.warning(
- "Multiple overloads of %s found but no line match, using first match",
- func_name,
- )
+ logger.warning("Multiple overloads of %s found but no line match, using first match", func_name)
target_method = matching_methods[0]
target_overload_index = 0
@@ -342,18 +326,16 @@ def replace_function(
len(new_helpers_to_add),
class_name,
)
- source = _insert_class_members(
- source, class_name, new_fields_to_add, new_helpers_to_add, analyzer
- )
+ source = _insert_class_members(source, class_name, new_fields_to_add, new_helpers_to_add, analyzer)
# Re-find the target method after modifications
# Line numbers have shifted, but the relative order of overloads is preserved
# Use the target_overload_index we saved earlier
methods = analyzer.find_methods(source)
matching_methods = [
- m for m in methods
- if m.name == func_name
- and (function.class_name is None or m.class_name == function.class_name)
+ m
+ for m in methods
+ if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
]
if matching_methods and target_overload_index < len(matching_methods):
@@ -398,9 +380,7 @@ def replace_function(
before = lines[: start_line - 1] # Lines before the method
after = lines[end_line:] # Lines after the method
- result = "".join(before) + indented_new_source + "".join(after)
-
- return result
+ return "".join(before) + indented_new_source + "".join(after)
def _get_indentation(line: str) -> str:
@@ -460,10 +440,7 @@ def _apply_indentation(lines: list[str], base_indent: str) -> str:
def replace_method_body(
- source: str,
- function: FunctionToOptimize,
- new_body: str,
- analyzer: JavaAnalyzer | None = None,
+ source: str, function: FunctionToOptimize, new_body: str, analyzer: JavaAnalyzer | None = None
) -> str:
"""Replace just the body of a method, preserving signature.
@@ -600,11 +577,7 @@ def insert_method(
return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8")
-def remove_method(
- source: str,
- function: FunctionToOptimize,
- analyzer: JavaAnalyzer | None = None,
-) -> str:
+def remove_method(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None) -> str:
"""Remove a method from source code.
Args:
@@ -648,9 +621,7 @@ def remove_method(
def remove_test_functions(
- test_source: str,
- functions_to_remove: list[str],
- analyzer: JavaAnalyzer | None = None,
+ test_source: str, functions_to_remove: list[str], analyzer: JavaAnalyzer | None = None
) -> str:
"""Remove specific test functions from test source code.
@@ -669,9 +640,7 @@ def remove_test_functions(
methods = analyzer.find_methods(test_source)
# Sort by start line in reverse order (remove from end first)
- methods_to_remove = [
- m for m in methods if m.name in functions_to_remove
- ]
+ methods_to_remove = [m for m in methods if m.name in functions_to_remove]
methods_to_remove.sort(key=lambda m: m.start_line, reverse=True)
result = test_source
@@ -728,9 +697,7 @@ def add_runtime_comments(
if original_ns > 0:
speedup = ((original_ns - optimized_ns) / original_ns) * 100
- summary_lines.append(
- f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)"
- )
+ summary_lines.append(f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)")
# Insert after imports
lines = test_source.splitlines(keepends=True)
diff --git a/codeflash/languages/java/support.py b/codeflash/languages/java/support.py
index 6fb015cd2..ed1bb339c 100644
--- a/codeflash/languages/java/support.py
+++ b/codeflash/languages/java/support.py
@@ -7,20 +7,9 @@
from __future__ import annotations
import logging
-from pathlib import Path
from typing import TYPE_CHECKING, Any
-from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.languages.base import (
- CodeContext,
- FunctionFilterCriteria,
- HelperFunction,
- Language,
- LanguageSupport,
- TestInfo,
- TestResult,
-)
-from codeflash.languages.registry import register_language
+from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.java.build_tools import find_test_root
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
from codeflash.languages.java.config import detect_java_project
@@ -33,11 +22,7 @@
instrument_for_benchmarking,
)
from codeflash.languages.java.parser import get_java_analyzer
-from codeflash.languages.java.replacement import (
- add_runtime_comments,
- remove_test_functions,
- replace_function,
-)
+from codeflash.languages.java.replacement import add_runtime_comments, remove_test_functions, replace_function
from codeflash.languages.java.test_discovery import discover_tests
from codeflash.languages.java.test_runner import (
parse_test_results,
@@ -45,9 +30,14 @@
run_benchmarking_tests,
run_tests,
)
+from codeflash.languages.registry import register_language
if TYPE_CHECKING:
from collections.abc import Sequence
+ from pathlib import Path
+
+ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
+ from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult
logger = logging.getLogger(__name__)
@@ -112,23 +102,17 @@ def discover_tests(
# === Code Analysis ===
- def extract_code_context(
- self, function: FunctionToOptimize, project_root: Path, module_root: Path
- ) -> CodeContext:
+ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
"""Extract function code and its dependencies."""
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
- def find_helper_functions(
- self, function: FunctionToOptimize, project_root: Path
- ) -> list[HelperFunction]:
+ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
"""Find helper functions called by the target function."""
return find_helper_functions(function, project_root, analyzer=self._analyzer)
# === Code Transformation ===
- def replace_function(
- self, source: str, function: FunctionToOptimize, new_source: str
- ) -> str:
+ def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation."""
return replace_function(source, function, new_source, self._analyzer)
@@ -140,11 +124,7 @@ def format_code(self, source: str, file_path: Path | None = None) -> str:
# === Test Execution ===
def run_tests(
- self,
- test_files: Sequence[Path],
- cwd: Path,
- env: dict[str, str],
- timeout: int,
+ self, test_files: Sequence[Path], cwd: Path, env: dict[str, str], timeout: int
) -> tuple[list[TestResult], Path]:
"""Run tests and return results."""
return run_tests(list(test_files), cwd, env, timeout)
@@ -155,15 +135,11 @@ def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResu
# === Instrumentation ===
- def instrument_for_behavior(
- self, source: str, functions: Sequence[FunctionToOptimize]
- ) -> str:
+ def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
"""Add behavior instrumentation to capture inputs/outputs."""
return instrument_for_behavior(source, functions, self._analyzer)
- def instrument_for_benchmarking(
- self, test_source: str, target_function: FunctionToOptimize
- ) -> str:
+ def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
"""Add timing instrumentation to test code."""
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
@@ -180,32 +156,22 @@ def normalize_code(self, source: str) -> str:
# === Test Editing ===
def add_runtime_comments(
- self,
- test_source: str,
- original_runtimes: dict[str, int],
- optimized_runtimes: dict[str, int],
+ self, test_source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
) -> str:
"""Add runtime performance comments to test source code."""
return add_runtime_comments(test_source, original_runtimes, optimized_runtimes, self._analyzer)
- def remove_test_functions(
- self, test_source: str, functions_to_remove: list[str]
- ) -> str:
+ def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str:
"""Remove specific test functions from test source code."""
return remove_test_functions(test_source, functions_to_remove, self._analyzer)
# === Test Result Comparison ===
def compare_test_results(
- self,
- original_results_path: Path,
- candidate_results_path: Path,
- project_root: Path | None = None,
+ self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
) -> tuple[bool, list]:
"""Compare test results between original and candidate code."""
- return _compare_test_results(
- original_results_path, candidate_results_path, project_root=project_root
- )
+ return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
# === Configuration ===
@@ -308,12 +274,7 @@ def instrument_existing_test(
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file."""
return instrument_existing_test(
- test_path,
- call_positions,
- function_to_optimize,
- tests_project_root,
- mode,
- self._analyzer,
+ test_path, call_positions, function_to_optimize, tests_project_root, mode, self._analyzer
)
def instrument_source_for_line_profiler(
@@ -339,15 +300,7 @@ def run_behavioral_tests(
candidate_index: int = 0,
) -> tuple[Path, Any, Path | None, Path | None]:
"""Run behavioral tests for Java."""
- return run_behavioral_tests(
- test_paths,
- test_env,
- cwd,
- timeout,
- project_root,
- enable_coverage,
- candidate_index,
- )
+ return run_behavioral_tests(test_paths, test_env, cwd, timeout, project_root, enable_coverage, candidate_index)
def run_benchmarking_tests(
self,
diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py
index aef25a8cb..67c11316b 100644
--- a/codeflash/languages/java/test_discovery.py
+++ b/codeflash/languages/java/test_discovery.py
@@ -7,27 +7,26 @@
from __future__ import annotations
import logging
-import re
from collections import defaultdict
-from pathlib import Path
from typing import TYPE_CHECKING
-from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import TestInfo
from codeflash.languages.java.config import detect_java_project
from codeflash.languages.java.discovery import discover_test_methods
-from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
+from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from collections.abc import Sequence
+ from pathlib import Path
+
+ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
+ from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
def discover_tests(
- test_root: Path,
- source_functions: Sequence[FunctionToOptimize],
- analyzer: JavaAnalyzer | None = None,
+ test_root: Path, source_functions: Sequence[FunctionToOptimize], analyzer: JavaAnalyzer | None = None
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
@@ -56,9 +55,7 @@ def discover_tests(
# Find all test files (various naming conventions)
test_files = (
- list(test_root.rglob("*Test.java"))
- + list(test_root.rglob("*Tests.java"))
- + list(test_root.rglob("Test*.java"))
+ list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
)
# Result map
@@ -71,16 +68,12 @@ def discover_tests(
for test_method in test_methods:
# Find which source functions this test might exercise
- matched_functions = _match_test_to_functions(
- test_method, source, function_map, analyzer
- )
+ matched_functions = _match_test_to_functions(test_method, source, function_map, analyzer)
for func_name in matched_functions:
result[func_name].append(
TestInfo(
- test_name=test_method.function_name,
- test_file=test_file,
- test_class=test_method.class_name,
+ test_name=test_method.function_name, test_file=test_file, test_class=test_method.class_name
)
)
@@ -114,7 +107,7 @@ def _match_test_to_functions(
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
test_name_lower = test_method.function_name.lower()
- for func_name, func_info in function_map.items():
+ for func_info in function_map.values():
if func_info.function_name.lower() in test_name_lower:
matched.append(func_info.qualified_name)
@@ -125,11 +118,7 @@ def _match_test_to_functions(
# Find method calls within the test method's line range
method_calls = _find_method_calls_in_range(
- tree.root_node,
- source_bytes,
- test_method.starting_line,
- test_method.ending_line,
- analyzer,
+ tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer
)
for call_name in method_calls:
@@ -151,7 +140,7 @@ def _match_test_to_functions(
source_class_name = source_class_name[4:]
# Look for functions in the matching class
- for func_name, func_info in function_map.items():
+ for func_info in function_map.values():
if func_info.class_name == source_class_name:
if func_info.qualified_name not in matched:
matched.append(func_info.qualified_name)
@@ -161,7 +150,7 @@ def _match_test_to_functions(
# This handles cases like TestQueryBlob importing Buffer and calling Buffer methods
imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer)
- for func_name, func_info in function_map.items():
+ for func_info in function_map.values():
if func_info.qualified_name in matched:
continue
@@ -172,11 +161,7 @@ def _match_test_to_functions(
return matched
-def _extract_imports(
- node,
- source_bytes: bytes,
- analyzer: JavaAnalyzer,
-) -> set[str]:
+def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
"""Extract imported class names from a Java file.
Args:
@@ -224,7 +209,7 @@ def visit(n):
# Regular import: extract class name from scoped_identifier
for child in n.children:
- if child.type == "scoped_identifier" or child.type == "identifier":
+ if child.type in {"scoped_identifier", "identifier"}:
import_path = analyzer.get_node_text(child, source_bytes)
# Extract just the class name (last part)
# e.g., "com.example.Buffer" -> "Buffer"
@@ -244,11 +229,7 @@ def visit(n):
def _find_method_calls_in_range(
- node,
- source_bytes: bytes,
- start_line: int,
- end_line: int,
- analyzer: JavaAnalyzer,
+ node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
) -> list[str]:
"""Find method calls within a line range.
@@ -278,17 +259,13 @@ def _find_method_calls_in_range(
calls.append(analyzer.get_node_text(name_node, source_bytes))
for child in node.children:
- calls.extend(
- _find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer)
- )
+ calls.extend(_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer))
return calls
def find_tests_for_function(
- function: FunctionToOptimize,
- test_root: Path,
- analyzer: JavaAnalyzer | None = None,
+ function: FunctionToOptimize, test_root: Path, analyzer: JavaAnalyzer | None = None
) -> list[TestInfo]:
"""Find tests that exercise a specific function.
@@ -305,10 +282,7 @@ def find_tests_for_function(
return result.get(function.qualified_name, [])
-def get_test_class_for_source_class(
- source_class_name: str,
- test_root: Path,
-) -> Path | None:
+def get_test_class_for_source_class(source_class_name: str, test_root: Path) -> Path | None:
"""Find the test class file for a source class.
Args:
@@ -320,11 +294,7 @@ def get_test_class_for_source_class(
"""
# Try common naming patterns
- patterns = [
- f"{source_class_name}Test.java",
- f"Test{source_class_name}.java",
- f"{source_class_name}Tests.java",
- ]
+ patterns = [f"{source_class_name}Test.java", f"Test{source_class_name}.java", f"{source_class_name}Tests.java"]
for pattern in patterns:
matches = list(test_root.rglob(pattern))
@@ -334,10 +304,7 @@ def get_test_class_for_source_class(
return None
-def discover_all_tests(
- test_root: Path,
- analyzer: JavaAnalyzer | None = None,
-) -> list[FunctionToOptimize]:
+def discover_all_tests(test_root: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]:
"""Discover all test methods in a test directory.
Args:
@@ -353,9 +320,7 @@ def discover_all_tests(
# Find all test files (various naming conventions)
test_files = (
- list(test_root.rglob("*Test.java"))
- + list(test_root.rglob("*Tests.java"))
- + list(test_root.rglob("Test*.java"))
+ list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
)
for test_file in test_files:
@@ -391,24 +356,18 @@ def is_test_file(file_path: Path) -> bool:
name = file_path.name
# Check naming patterns
- if name.endswith("Test.java") or name.endswith("Tests.java"):
+ if name.endswith(("Test.java", "Tests.java")):
return True
if name.startswith("Test") and name.endswith(".java"):
return True
# Check if it's in a test directory
path_parts = file_path.parts
- for part in path_parts:
- if part in ("test", "tests", "src/test"):
- return True
-
- return False
+ return any(part in ("test", "tests", "src/test") for part in path_parts)
def get_test_methods_for_class(
- test_file: Path,
- test_class_name: str | None = None,
- analyzer: JavaAnalyzer | None = None,
+ test_file: Path, test_class_name: str | None = None, analyzer: JavaAnalyzer | None = None
) -> list[FunctionToOptimize]:
"""Get all test methods in a specific test class.
@@ -430,8 +389,7 @@ def get_test_methods_for_class(
def build_test_mapping_for_project(
- project_root: Path,
- analyzer: JavaAnalyzer | None = None,
+ project_root: Path, analyzer: JavaAnalyzer | None = None
) -> dict[str, list[TestInfo]]:
"""Build a complete test mapping for a project.
diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py
index 0455782e7..b5e0618a8 100644
--- a/codeflash/languages/java/test_runner.py
+++ b/codeflash/languages/java/test_runner.py
@@ -31,7 +31,7 @@
# Regex pattern for valid Java class names (package.ClassName format)
# Allows: letters, digits, underscores, dots, and dollar signs (inner classes)
-_VALID_JAVA_CLASS_NAME = re.compile(r'^[a-zA-Z_$][a-zA-Z0-9_$.]*$')
+_VALID_JAVA_CLASS_NAME = re.compile(r"^[a-zA-Z_$][a-zA-Z0-9_$.]*$")
def _validate_java_class_name(class_name: str) -> bool:
@@ -44,6 +44,7 @@ def _validate_java_class_name(class_name: str) -> bool:
Returns:
True if valid, False otherwise.
+
"""
return bool(_VALID_JAVA_CLASS_NAME.match(class_name))
@@ -62,19 +63,21 @@ def _validate_test_filter(test_filter: str) -> str:
Raises:
ValueError: If the test filter contains invalid characters.
+
"""
# Split by comma for multiple test patterns
- patterns = [p.strip() for p in test_filter.split(',')]
+ patterns = [p.strip() for p in test_filter.split(",")]
for pattern in patterns:
# Remove wildcards for validation (they're allowed in test filters)
- name_to_validate = pattern.replace('*', 'A') # Replace * with a valid char
+ name_to_validate = pattern.replace("*", "A") # Replace * with a valid char
if not _validate_java_class_name(name_to_validate):
- raise ValueError(
+ msg = (
f"Invalid test class name or pattern: '{pattern}'. "
f"Test names must follow Java identifier rules (letters, digits, underscores, dots, dollar signs)."
)
+ raise ValueError(msg)
return test_filter
@@ -134,6 +137,7 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
# This is a multi-module project root
# Extract modules from pom.xml
import re
+
modules = re.findall(r"([^<]+)", content)
# Check if test file is in one of the modules
for test_path in test_file_paths:
@@ -310,10 +314,7 @@ def run_behavioral_tests(
def _compile_tests(
- project_root: Path,
- env: dict[str, str],
- test_module: str | None = None,
- timeout: int = 120,
+ project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 120
) -> subprocess.CompletedProcess:
"""Compile test code using Maven (without running tests).
@@ -330,12 +331,7 @@ def _compile_tests(
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found")
- return subprocess.CompletedProcess(
- args=["mvn"],
- returncode=-1,
- stdout="",
- stderr="Maven not found",
- )
+ return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
cmd = [mvn, "test-compile", "-e"] # Show errors but not verbose output
@@ -346,37 +342,20 @@ def _compile_tests(
try:
return subprocess.run(
- cmd,
- check=False,
- cwd=project_root,
- env=env,
- capture_output=True,
- text=True,
- timeout=timeout,
+ cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
)
except subprocess.TimeoutExpired:
- logger.error("Maven compilation timed out after %d seconds", timeout)
+ logger.exception("Maven compilation timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
- args=cmd,
- returncode=-2,
- stdout="",
- stderr=f"Compilation timed out after {timeout} seconds",
+ args=cmd, returncode=-2, stdout="", stderr=f"Compilation timed out after {timeout} seconds"
)
except Exception as e:
logger.exception("Maven compilation failed: %s", e)
- return subprocess.CompletedProcess(
- args=cmd,
- returncode=-1,
- stdout="",
- stderr=str(e),
- )
+ return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def _get_test_classpath(
- project_root: Path,
- env: dict[str, str],
- test_module: str | None = None,
- timeout: int = 60,
+ project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 60
) -> str | None:
"""Get the test classpath from Maven.
@@ -397,13 +376,7 @@ def _get_test_classpath(
# Create temp file for classpath output
cp_file = project_root / ".codeflash_classpath.txt"
- cmd = [
- mvn,
- "dependency:build-classpath",
- "-DincludeScope=test",
- f"-Dmdep.outputFile={cp_file}",
- "-q",
- ]
+ cmd = [mvn, "dependency:build-classpath", "-DincludeScope=test", f"-Dmdep.outputFile={cp_file}", "-q"]
if test_module:
cmd.extend(["-pl", test_module])
@@ -412,13 +385,7 @@ def _get_test_classpath(
try:
result = subprocess.run(
- cmd,
- check=False,
- cwd=project_root,
- env=env,
- capture_output=True,
- text=True,
- timeout=timeout,
+ cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
)
if result.returncode != 0:
@@ -450,7 +417,7 @@ def _get_test_classpath(
return os.pathsep.join(cp_parts)
except subprocess.TimeoutExpired:
- logger.error("Getting classpath timed out")
+ logger.exception("Getting classpath timed out")
return None
except Exception as e:
logger.exception("Failed to get classpath: %s", e)
@@ -525,30 +492,16 @@ def _run_tests_direct(
try:
return subprocess.run(
- cmd,
- check=False,
- cwd=working_dir,
- env=env,
- capture_output=True,
- text=True,
- timeout=timeout,
+ cmd, check=False, cwd=working_dir, env=env, capture_output=True, text=True, timeout=timeout
)
except subprocess.TimeoutExpired:
- logger.error("Direct test execution timed out after %d seconds", timeout)
+ logger.exception("Direct test execution timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
- args=cmd,
- returncode=-2,
- stdout="",
- stderr=f"Test execution timed out after {timeout} seconds",
+ args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds"
)
except Exception as e:
logger.exception("Direct test execution failed: %s", e)
- return subprocess.CompletedProcess(
- args=cmd,
- returncode=-1,
- stdout="",
- stderr=str(e),
- )
+ return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def _get_test_class_names(test_paths: Any, mode: str = "performance") -> list[str]:
@@ -603,10 +556,7 @@ def _get_empty_result(maven_root: Path, test_module: str | None) -> tuple[Path,
result_xml_path = _get_combined_junit_xml(surefire_dir, -1)
empty_result = subprocess.CompletedProcess(
- args=["java", "-cp", "...", "ConsoleLauncher"],
- returncode=-1,
- stdout="",
- stderr="No test classes found",
+ args=["java", "-cp", "...", "ConsoleLauncher"], returncode=-1, stdout="", stderr="No test classes found"
)
return result_xml_path, empty_result
@@ -665,12 +615,7 @@ def _run_benchmarking_tests_maven(
run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)
result = _run_maven_tests(
- maven_root,
- test_paths,
- run_env,
- timeout=per_loop_timeout,
- mode="performance",
- test_module=test_module,
+ maven_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
)
last_result = result
@@ -683,27 +628,20 @@ def _run_benchmarking_tests_maven(
elapsed = time.time() - total_start_time
if loop_idx >= min_loops and elapsed >= target_duration_seconds:
- logger.debug(
- "Stopping Maven benchmark after %d loops (%.2fs elapsed)",
- loop_idx,
- elapsed,
- )
+ logger.debug("Stopping Maven benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
break
# Check if we have timing markers even if some tests failed
# We should continue looping if we're getting valid timing data
if result.returncode != 0:
import re
+
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
if not has_timing_markers:
logger.warning("Tests failed in Maven loop %d with no timing markers, stopping", loop_idx)
break
- else:
- logger.debug(
- "Some tests failed in Maven loop %d but timing markers present, continuing",
- loop_idx,
- )
+ logger.debug("Some tests failed in Maven loop %d but timing markers present, continuing", loop_idx)
combined_stdout = "\n".join(all_stdout)
combined_stderr = "\n".join(all_stderr)
@@ -801,8 +739,15 @@ def run_benchmarking_tests(
# Fall back to Maven-based execution
logger.warning("Falling back to Maven-based test execution")
return _run_benchmarking_tests_maven(
- test_paths, test_env, cwd, timeout, project_root,
- min_loops, max_loops, target_duration_seconds, inner_iterations
+ test_paths,
+ test_env,
+ cwd,
+ timeout,
+ project_root,
+ min_loops,
+ max_loops,
+ target_duration_seconds,
+ inner_iterations,
)
logger.debug("Compilation completed in %.2fs", compile_time)
@@ -814,8 +759,15 @@ def run_benchmarking_tests(
if not classpath:
logger.warning("Failed to get classpath, falling back to Maven-based execution")
return _run_benchmarking_tests_maven(
- test_paths, test_env, cwd, timeout, project_root,
- min_loops, max_loops, target_duration_seconds, inner_iterations
+ test_paths,
+ test_env,
+ cwd,
+ timeout,
+ project_root,
+ min_loops,
+ max_loops,
+ target_duration_seconds,
+ inner_iterations,
)
# Step 3: Run tests multiple times directly via JVM
@@ -853,12 +805,7 @@ def run_benchmarking_tests(
# Run tests directly with XML report generation
loop_start = time.time()
result = _run_tests_direct(
- classpath,
- test_classes,
- run_env,
- working_dir,
- timeout=per_loop_timeout,
- reports_dir=reports_dir,
+ classpath, test_classes, run_env, working_dir, timeout=per_loop_timeout, reports_dir=reports_dir
)
loop_time = time.time() - loop_start
@@ -875,12 +822,7 @@ def run_benchmarking_tests(
# Check if JUnit Console Launcher is not available (JUnit 4 projects)
# Fall back to Maven-based execution in this case
- if (
- loop_idx == 1
- and result.returncode != 0
- and result.stderr
- and "ConsoleLauncher" in result.stderr
- ):
+ if loop_idx == 1 and result.returncode != 0 and result.stderr and "ConsoleLauncher" in result.stderr:
logger.debug("JUnit Console Launcher not available, falling back to Maven-based execution")
return _run_benchmarking_tests_maven(
test_paths,
@@ -909,16 +851,13 @@ def run_benchmarking_tests(
# Check if tests failed - continue looping if we have timing markers
if result.returncode != 0:
import re
+
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
if not has_timing_markers:
logger.warning("Tests failed in loop %d with no timing markers, stopping benchmark", loop_idx)
break
- else:
- logger.debug(
- "Some tests failed in loop %d but timing markers present, continuing",
- loop_idx,
- )
+ logger.debug("Some tests failed in loop %d but timing markers present, continuing", loop_idx)
# Create a combined result with all stdout
combined_stdout = "\n".join(all_stdout)
@@ -1075,12 +1014,7 @@ def _run_maven_tests(
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found")
- return subprocess.CompletedProcess(
- args=["mvn"],
- returncode=-1,
- stdout="",
- stderr="Maven not found",
- )
+ return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
# Build test filter
test_filter = _build_test_filter(test_paths, mode=mode)
@@ -1110,33 +1044,18 @@ def _run_maven_tests(
logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root)
try:
- result = subprocess.run(
- cmd,
- check=False,
- cwd=project_root,
- env=env,
- capture_output=True,
- text=True,
- timeout=timeout,
+ return subprocess.run(
+ cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
)
- return result
except subprocess.TimeoutExpired:
- logger.error("Maven test execution timed out after %d seconds", timeout)
+ logger.exception("Maven test execution timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
- args=cmd,
- returncode=-2,
- stdout="",
- stderr=f"Test execution timed out after {timeout} seconds",
+ args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds"
)
except Exception as e:
logger.exception("Maven test execution failed: %s", e)
- return subprocess.CompletedProcess(
- args=cmd,
- returncode=-1,
- stdout="",
- stderr=str(e),
- )
+ return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def _build_test_filter(test_paths: Any, mode: str = "behavior") -> str:
@@ -1196,7 +1115,7 @@ def _path_to_class_name(path: Path) -> str | None:
Fully qualified class name, or None if unable to determine.
"""
- if not path.suffix == ".java":
+ if path.suffix != ".java":
return None
# Try to extract package from path
@@ -1219,7 +1138,7 @@ def _path_to_class_name(path: Path) -> str | None:
break
if java_idx is not None:
- class_parts = parts[java_idx + 1:]
+ class_parts = parts[java_idx + 1 :]
# Remove .java extension from last part
class_parts[-1] = class_parts[-1].replace(".java", "")
return ".".join(class_parts)
@@ -1228,12 +1147,7 @@ def _path_to_class_name(path: Path) -> str | None:
return path.stem
-def run_tests(
- test_files: list[Path],
- cwd: Path,
- env: dict[str, str],
- timeout: int,
-) -> tuple[list[TestResult], Path]:
+def run_tests(test_files: list[Path], cwd: Path, env: dict[str, str], timeout: int) -> tuple[list[TestResult], Path]:
"""Run tests and return results.
Args:
@@ -1366,10 +1280,7 @@ def _parse_surefire_xml(xml_file: Path) -> list[TestResult]:
return results
-def get_test_run_command(
- project_root: Path,
- test_classes: list[str] | None = None,
-) -> list[str]:
+def get_test_run_command(project_root: Path, test_classes: list[str] | None = None) -> list[str]:
"""Get the command to run Java tests.
Args:
@@ -1389,10 +1300,8 @@ def get_test_run_command(
validated_classes = []
for test_class in test_classes:
if not _validate_java_class_name(test_class):
- raise ValueError(
- f"Invalid test class name: '{test_class}'. "
- f"Test names must follow Java identifier rules."
- )
+ msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
+ raise ValueError(msg)
validated_classes.append(test_class)
cmd.append(f"-Dtest={','.join(validated_classes)}")
diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py
index 812f7c4a7..8fe144a06 100644
--- a/codeflash/languages/javascript/find_references.py
+++ b/codeflash/languages/javascript/find_references.py
@@ -213,7 +213,7 @@ def find_references(
if import_info:
context.visited_files.add(file_path)
- import_name, original_import = import_info
+ import_name, _original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
)
diff --git a/codeflash/languages/javascript/module_system.py b/codeflash/languages/javascript/module_system.py
index 4e4e3bb0c..dcd2d2fc7 100644
--- a/codeflash/languages/javascript/module_system.py
+++ b/codeflash/languages/javascript/module_system.py
@@ -373,9 +373,14 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
insert_index = 0
for i, line in enumerate(lines):
stripped = line.strip()
- if stripped and not stripped.startswith("//") and not stripped.startswith("/*") and not stripped.startswith("*"):
+ if (
+ stripped
+ and not stripped.startswith("//")
+ and not stripped.startswith("/*")
+ and not stripped.startswith("*")
+ ):
# Check if this line is an import/require - insert after imports
- if stripped.startswith("import ") or stripped.startswith("const ") or stripped.startswith("let "):
+ if stripped.startswith(("import ", "const ", "let ")):
continue
insert_index = i
break
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index d09654722..2a034afdf 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -325,9 +325,7 @@ def file_to_path(self) -> dict[str, str]:
"""
if "file_to_path" in self._cache:
return self._cache["file_to_path"]
- result = {
- str(code_string.file_path): code_string.code for code_string in self.code_strings
- }
+ result = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
self._cache["file_to_path"] = result
return result
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 900d3ea8c..be69bd544 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -2,7 +2,6 @@
import ast
import concurrent.futures
-import logging
import os
import queue
import random
@@ -23,7 +22,7 @@
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
from codeflash.benchmarking.utils import process_benchmark_data
-from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
+from codeflash.cli_cmds.console import DEBUG_MODE, code_print, console, logger, lsp_log, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code
from codeflash.code_utils.code_replacer import (
@@ -146,9 +145,70 @@
from codeflash.verification.verification_utils import TestConfig
+def log_code_after_replacement(file_path: Path, candidate_index: int) -> None:
+ """Log the full file content after code replacement in verbose mode."""
+ if not DEBUG_MODE:
+ return
+
+ try:
+ code = file_path.read_text(encoding="utf-8")
+ lang_map = {".java": "java", ".py": "python", ".js": "javascript", ".ts": "typescript"}
+ language = lang_map.get(file_path.suffix.lower(), "text")
+
+ console.print(
+ Panel(
+ Syntax(code, language, line_numbers=True, theme="monokai", word_wrap=True),
+ title=f"[bold blue]Code After Replacement (Candidate {candidate_index})[/] [dim]({file_path.name})[/]",
+ border_style="blue",
+ )
+ )
+ except Exception as e:
+ logger.debug(f"Failed to log code after replacement: {e}")
+
+
+def log_instrumented_test(test_source: str, test_name: str, test_type: str, language: str) -> None:
+ """Log instrumented test code in verbose mode."""
+ if not DEBUG_MODE:
+ return
+
+ display_source = test_source
+ if len(test_source) > 15000:
+ display_source = test_source[:15000] + "\n\n... [truncated] ..."
+
+ console.print(
+ Panel(
+ Syntax(display_source, language, line_numbers=True, theme="monokai", word_wrap=True),
+ title=f"[bold magenta]Instrumented Test: {test_name}[/] [dim]({test_type})[/]",
+ border_style="magenta",
+ )
+ )
+
+
+def log_test_run_output(stdout: str, stderr: str, test_type: str, returncode: int = 0) -> None:
+ """Log test run stdout/stderr in verbose mode."""
+ if not DEBUG_MODE:
+ return
+
+ max_len = 10000
+
+ if stdout and stdout.strip():
+ display_stdout = stdout[:max_len] + ("...[truncated]" if len(stdout) > max_len else "")
+ console.print(
+ Panel(
+ display_stdout,
+ title=f"[bold green]{test_type} - stdout[/] [dim](exit: {returncode})[/]",
+ border_style="green" if returncode == 0 else "red",
+ )
+ )
+
+ if stderr and stderr.strip():
+ display_stderr = stderr[:max_len] + ("...[truncated]" if len(stderr) > max_len else "")
+ console.print(Panel(display_stderr, title=f"[bold yellow]{test_type} - stderr[/]", border_style="yellow"))
+
+
def log_optimization_context(function_name: str, code_context: CodeOptimizationContext) -> None:
"""Log optimization context details when in verbose mode using Rich formatting."""
- if logger.getEffectiveLevel() > logging.DEBUG:
+ if not DEBUG_MODE:
return
console.rule()
@@ -594,18 +654,32 @@ def generate_and_instrument_tests(
generated_test.instrumented_perf_test_source = modified_perf_source
used_behavior_paths.add(behavior_path)
- logger.debug(
- f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}"
- )
+ logger.debug(f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}")
with behavior_path.open("w", encoding="utf8") as f:
f.write(generated_test.instrumented_behavior_test_source)
logger.debug(f"[PIPELINE] Wrote behavioral test to {behavior_path}")
+ # Verbose: Log instrumented behavior test
+ log_instrumented_test(
+ generated_test.instrumented_behavior_test_source,
+ behavior_path.name,
+ "Behavioral Test",
+ language=self.function_to_optimize.language,
+ )
+
with perf_path.open("w", encoding="utf8") as f:
f.write(generated_test.instrumented_perf_test_source)
logger.debug(f"[PIPELINE] Wrote perf test to {perf_path}")
+ # Verbose: Log instrumented performance test
+ log_instrumented_test(
+ generated_test.instrumented_perf_test_source,
+ perf_path.name,
+ "Performance Test",
+ language=self.function_to_optimize.language,
+ )
+
# File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.)
test_file_obj = TestFile(
instrumented_behavior_file_path=generated_test.behavior_file_path,
@@ -675,22 +749,24 @@ def _get_java_sources_root(self) -> Path:
parts = tests_root.parts
# Look for standard Java package prefixes that indicate the start of package structure
- standard_package_prefixes = ('com', 'org', 'net', 'io', 'edu', 'gov')
+ standard_package_prefixes = ("com", "org", "net", "io", "edu", "gov")
for i, part in enumerate(parts):
if part in standard_package_prefixes:
# Found start of package path, return everything before it
if i > 0:
java_sources_root = Path(*parts[:i])
- logger.debug(f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})")
+ logger.debug(
+ f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})"
+ )
return java_sources_root
# If no standard package prefix found, check if there's a 'java' directory
# (standard Maven structure: src/test/java)
for i, part in enumerate(parts):
- if part == 'java' and i > 0:
+ if part == "java" and i > 0:
# Return up to and including 'java'
- java_sources_root = Path(*parts[:i + 1])
+ java_sources_root = Path(*parts[: i + 1])
logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}")
return java_sources_root
@@ -721,16 +797,16 @@ def _fix_java_test_paths(
import re
# Extract package from behavior source
- package_match = re.search(r'^\s*package\s+([\w.]+)\s*;', behavior_source, re.MULTILINE)
+ package_match = re.search(r"^\s*package\s+([\w.]+)\s*;", behavior_source, re.MULTILINE)
package_name = package_match.group(1) if package_match else ""
# Extract class name from behavior source
# Use more specific pattern to avoid matching words like "command" or text in comments
- class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', behavior_source, re.MULTILINE)
+ class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", behavior_source, re.MULTILINE)
behavior_class = class_match.group(1) if class_match else "GeneratedTest"
# Extract class name from perf source
- perf_class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', perf_source, re.MULTILINE)
+ perf_class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", perf_source, re.MULTILINE)
perf_class = perf_class_match.group(1) if perf_class_match else "GeneratedPerfTest"
# Build paths with package structure
@@ -767,22 +843,20 @@ def _fix_java_test_paths(
perf_path = new_perf_path
# Rename class in source code - replace the class declaration
modified_behavior_source = re.sub(
- rf'^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)',
- rf'\g<1>{new_behavior_class}\g<2>',
+ rf"^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)",
+ rf"\g<1>{new_behavior_class}\g<2>",
behavior_source,
count=1,
flags=re.MULTILINE,
)
modified_perf_source = re.sub(
- rf'^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)',
- rf'\g<1>{new_perf_class}\g<2>',
+ rf"^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)",
+ rf"\g<1>{new_perf_class}\g<2>",
perf_source,
count=1,
flags=re.MULTILINE,
)
- logger.debug(
- f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}"
- )
+ logger.debug(f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}")
break
index += 1
@@ -1199,6 +1273,9 @@ def process_single_candidate(
logger.info("No functions were replaced in the optimized code. Skipping optimization candidate.")
console.rule()
return None
+
+ # Verbose: Log code after replacement
+ log_code_after_replacement(self.function_to_optimize.file_path, candidate_index)
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
logger.error(e)
self.write_code_and_helpers(
@@ -1764,6 +1841,14 @@ def get_instrumented_path(original_path: str, suffix: str) -> Path:
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_behavior_test)
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
+
+ # Verbose: Log instrumented existing behavior test
+ log_instrumented_test(
+ injected_behavior_test,
+ new_behavioral_test_path.name,
+ "Existing Behavioral Test",
+ language=self.function_to_optimize.language,
+ )
else:
msg = "injected_behavior_test is None"
raise ValueError(msg)
@@ -1773,6 +1858,14 @@ def get_instrumented_path(original_path: str, suffix: str) -> Path:
_f.write(injected_perf_test)
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
+ # Verbose: Log instrumented existing performance test
+ log_instrumented_test(
+ injected_perf_test,
+ new_perf_test_path.name,
+ "Existing Performance Test",
+ language=self.function_to_optimize.language,
+ )
+
unique_instrumented_test_files.add(new_behavioral_test_path)
unique_instrumented_test_files.add(new_perf_test_path)
@@ -2239,7 +2332,7 @@ def process_review(
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n"
- existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
+ existing_tests, replay_tests, _concolic_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
test_cfg=self.test_cfg,
@@ -2880,6 +2973,11 @@ def run_and_parse_tests(
else:
msg = f"Unexpected testing type: {testing_type}"
raise ValueError(msg)
+
+ # Verbose: Log test run output
+ log_test_run_output(
+ run_result.stdout, run_result.stderr, f"Test Run ({testing_type.name})", run_result.returncode
+ )
except subprocess.TimeoutExpired:
logger.exception(
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py
index 759e4ecb2..6e34648c3 100644
--- a/codeflash/verification/parse_test_output.py
+++ b/codeflash/verification/parse_test_output.py
@@ -512,8 +512,10 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# Check if the file name matches the module path
file_stem = test_file.instrumented_behavior_file_path.stem
# The instrumented file has __perfinstrumented suffix
- original_class = file_stem.replace("__perfinstrumented", "").replace("__perfonlyinstrumented", "")
- if original_class == test_module_path or file_stem == test_module_path:
+ original_class = file_stem.replace("__perfinstrumented", "").replace(
+ "__perfonlyinstrumented", ""
+ )
+ if test_module_path in (original_class, file_stem):
test_file_path = test_file.instrumented_behavior_file_path
break
# Check original file path
@@ -551,7 +553,9 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# Default to GENERATED_REGRESSION for Jest/Java tests when test type can't be determined
if test_type is None and (is_jest or is_java_test):
test_type = TestType.GENERATED_REGRESSION
- logger.debug(f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})")
+ logger.debug(
+ f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})"
+ )
elif test_type is None:
# Skip results where test type cannot be determined
logger.debug(f"Skipping result for {test_function_name}: could not determine test type")
diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py
index 9766a3951..45b96ff51 100644
--- a/codeflash/verification/verification_utils.py
+++ b/codeflash/verification/verification_utils.py
@@ -146,7 +146,9 @@ def _detect_java_test_framework(self) -> str:
pom_path = current / "pom.xml"
if pom_path.exists():
parent_config = detect_java_project(current)
- if parent_config and (parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng):
+ if parent_config and (
+ parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng
+ ):
return parent_config.test_framework
current = current.parent
diff --git a/codeflash/version.py b/codeflash/version.py
index 6225467e3..67379ab0c 100644
--- a/codeflash/version.py
+++ b/codeflash/version.py
@@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
-__version__ = "0.20.0"
+__version__ = "0.20.0.post414.dev0+2ad731d3"