diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 5f1b895d7..87eefd5d7 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -27,6 +27,9 @@ from codeflash.cli_cmds.console import console, logger from codeflash.cli_cmds.extension import install_vscode_extension +# Import Java init module +from codeflash.cli_cmds.init_java import init_java_project + # Import JS/TS init module from codeflash.cli_cmds.init_javascript import ( ProjectLanguage, @@ -35,9 +38,6 @@ get_js_dependency_installation_commands, init_js_project, ) - -# Import Java init module -from codeflash.cli_cmds.init_java import init_java_project from codeflash.code_utils.code_utils import validate_relative_directory_path from codeflash.code_utils.compat import LF from codeflash.code_utils.config_parser import parse_config_file @@ -1674,9 +1674,7 @@ def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path, # Install dependencies install_deps_cmd = get_java_dependency_installation_commands(build_tool) - optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd) - - return optimize_yml_content + return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd) def get_formatter_cmds(formatter: str) -> list[str]: diff --git a/codeflash/cli_cmds/init_java.py b/codeflash/cli_cmds/init_java.py index 73822e626..5be5b19a9 100644 --- a/codeflash/cli_cmds/init_java.py +++ b/codeflash/cli_cmds/init_java.py @@ -165,9 +165,7 @@ def init_java_project() -> None: lang_panel = Panel( Text( - "Java project detected!\n\nI'll help you set up Codeflash for your project.", - style="cyan", - justify="center", + "Java project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center" ), title="Java Setup", border_style="bright_red", @@ -205,7 +203,9 @@ def init_java_project() -> None: completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:" if did_add_new_key: - completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!" + completion_message += ( + "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!" + ) if os.name == "nt": reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}" else: @@ -234,9 +234,7 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]: codeflash_config_path = project_root / "codeflash.toml" if codeflash_config_path.exists(): return Confirm.ask( - "A Codeflash config already exists. Do you want to re-configure it?", - default=False, - show_default=True, + "A Codeflash config already exists. Do you want to re-configure it?", default=False, show_default=True ), None return True, None @@ -285,14 +283,10 @@ def collect_java_setup_info() -> JavaSetupInfo: if Confirm.ask("Would you like to change any of these settings?", default=False): # Source root override - module_root_override = _prompt_directory_override( - "source", detected_source_root, curdir - ) + module_root_override = _prompt_directory_override("source", detected_source_root, curdir) # Test root override - test_root_override = _prompt_directory_override( - "test", detected_test_root, curdir - ) + test_root_override = _prompt_directory_override("test", detected_test_root, curdir) # Formatter override formatter_questions = [ @@ -300,7 +294,7 @@ def collect_java_setup_info() -> JavaSetupInfo: "formatter", message="Which code formatter do you use?", choices=[ - (f"keep detected (google-java-format)", "keep"), + ("keep detected (google-java-format)", "keep"), ("google-java-format", "google-java-format"), ("spotless", "spotless"), ("other", "other"), @@ -345,7 +339,7 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")] subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)] - options = [keep_detected_option] + subdirs[:5] + [custom_dir_option] + options = [keep_detected_option, *subdirs[:5], custom_dir_option] questions = [ inquirer.List( @@ -364,10 +358,9 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st answer = answers[f"{dir_type}_root"] if answer == keep_detected_option: return None - elif answer == custom_dir_option: + if answer == custom_dir_option: return _prompt_custom_directory(dir_type) - else: - return answer + return answer def _prompt_custom_directory(dir_type: str) -> str: @@ -441,7 +434,7 @@ def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[st if formatter == "spotless": if build_tool == JavaBuildTool.MAVEN: return ["mvn spotless:apply -DspotlessFiles=$file"] - elif build_tool == JavaBuildTool.GRADLE: + if build_tool == JavaBuildTool.GRADLE: return ["./gradlew spotlessApply"] return ["spotless $file"] if formatter == "other": diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index bb28fe66b..e75d4e125 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -711,18 +711,12 @@ def _add_java_class_members( if not new_fields and not new_methods: return original_source - logger.debug( - f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}" - ) + logger.debug(f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}") # Import the insertion function from replacement module from codeflash.languages.java.replacement import _insert_class_members - result = _insert_class_members( - original_source, class_name, new_fields, new_methods, analyzer - ) - - return result + return _insert_class_members(original_source, class_name, new_fields, new_methods, analyzer) except Exception as e: logger.debug(f"Error adding Java class members: {e}") @@ -959,12 +953,14 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin for file_path_str, code in file_to_code_context.items(): if file_path_str: # Extract filename without creating Path object repeatedly - if file_path_str.endswith(target_filename) and (len(file_path_str) == len(target_filename) or file_path_str[-len(target_filename)-1] in ('/', '\\')): + if file_path_str.endswith(target_filename) and ( + len(file_path_str) == len(target_filename) + or file_path_str[-len(target_filename) - 1] in ("/", "\\") + ): module_optimized_code = code logger.debug(f"Matched {file_path_str} to {relative_path} by filename") break - if module_optimized_code is None: # Also try matching if there's only one code file, but ONLY for non-Python # languages where path matching is less strict. For Python, we require diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 76cb041a1..a0f212e8d 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -721,9 +721,7 @@ def inject_profiling_into_existing_test( if is_java(): from codeflash.languages.java.instrumentation import instrument_existing_test - return instrument_existing_test( - test_path, call_positions, function_to_optimize, tests_project_root, mode.value - ) + return instrument_existing_test(test_path, call_positions, function_to_optimize, tests_project_root, mode.value) if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( diff --git a/codeflash/languages/__init__.py b/codeflash/languages/__init__.py index ffbd9d97f..416849243 100644 --- a/codeflash/languages/__init__.py +++ b/codeflash/languages/__init__.py @@ -36,15 +36,15 @@ reset_current_language, set_current_language, ) + +# Java language support +# Importing the module triggers registration via @register_language decorator +from codeflash.languages.java.support import JavaSupport # noqa: F401 from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401 # Import language support modules to trigger auto-registration # This ensures all supported languages are available when this package is imported from codeflash.languages.python import PythonSupport # noqa: F401 - -# Java language support -# Importing the module triggers registration via @register_language decorator -from codeflash.languages.java.support import JavaSupport # noqa: F401 from codeflash.languages.registry import ( detect_project_language, get_language_support, diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py index 200555488..5fb962db6 100644 --- a/codeflash/languages/java/build_tools.py +++ b/codeflash/languages/java/build_tools.py @@ -13,7 +13,10 @@ import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum -from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path logger = logging.getLogger(__name__) @@ -29,6 +32,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree: Raises: ET.ParseError: If XML parsing fails. + """ # Read file content and parse as string to avoid file-based attacks # This prevents XXE attacks by not allowing external entity resolution @@ -38,9 +42,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree: root = ET.fromstring(content) # Create ElementTree from root - tree = ET.ElementTree(root) - - return tree + return ET.ElementTree(root) class BuildTool(Enum): @@ -390,13 +392,7 @@ def run_maven_tests( try: result = subprocess.run( - cmd, - check=False, - cwd=project_root, - env=run_env, - capture_output=True, - text=True, - timeout=timeout, + cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout ) # Parse test results from Surefire reports @@ -416,7 +412,7 @@ def run_maven_tests( ) except subprocess.TimeoutExpired: - logger.error("Maven test execution timed out after %d seconds", timeout) + logger.exception("Maven test execution timed out after %d seconds", timeout) return MavenTestResult( success=False, tests_run=0, @@ -496,10 +492,7 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]: def compile_maven_project( - project_root: Path, - include_tests: bool = True, - env: dict[str, str] | None = None, - timeout: int = 300, + project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300 ) -> tuple[bool, str, str]: """Compile a Maven project. @@ -533,13 +526,7 @@ def compile_maven_project( try: result = subprocess.run( - cmd, - check=False, - cwd=project_root, - env=run_env, - capture_output=True, - text=True, - timeout=timeout, + cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout ) return result.returncode == 0, result.stdout, result.stderr @@ -581,14 +568,7 @@ def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> boo ] try: - result = subprocess.run( - cmd, - check=False, - cwd=project_root, - capture_output=True, - text=True, - timeout=60, - ) + result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60) if result.returncode == 0: logger.info("Successfully installed codeflash-runtime to local Maven repository") @@ -664,7 +644,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool: return True except ET.ParseError as e: - logger.error("Failed to parse pom.xml: %s", e) + logger.exception("Failed to parse pom.xml: %s", e) return False except Exception as e: logger.exception("Failed to add dependency to pom.xml: %s", e) @@ -751,11 +731,11 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: # JaCoCo plugin XML to insert (indented for typical pom.xml format) # Note: For multi-module projects where tests are in a separate module, # we configure the report to look in multiple directories for classes - jacoco_plugin = """ + jacoco_plugin = f""" org.jacoco jacoco-maven-plugin - {version} + {JACOCO_PLUGIN_VERSION} prepare-agent @@ -777,7 +757,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: - """.format(version=JACOCO_PLUGIN_VERSION) + """ # Find the main section (not inside ) # We need to find a that appears after or before @@ -786,7 +766,6 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: profiles_end = content.find("") # Find all tags - import re # Find the main build section - it's the one NOT inside profiles # Strategy: Look for that comes after or before (or no profiles) @@ -816,7 +795,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool: if build_start != -1 and build_end != -1: # Found main build section, find plugins within it - build_section = content[build_start:build_end + len("")] + build_section = content[build_start : build_end + len("")] plugins_start_in_build = build_section.find("") plugins_end_in_build = build_section.rfind("") diff --git a/codeflash/languages/java/comparator.py b/codeflash/languages/java/comparator.py index c30bd2446..75fa7f51f 100644 --- a/codeflash/languages/java/comparator.py +++ b/codeflash/languages/java/comparator.py @@ -47,7 +47,16 @@ def _find_comparator_jar(project_root: Path | None = None) -> Path | None: return jar_path # Check local Maven repository - m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / "codeflash-runtime-1.0.0.jar" + m2_jar = ( + Path.home() + / ".m2" + / "repository" + / "com" + / "codeflash" + / "codeflash-runtime" + / "1.0.0" + / "codeflash-runtime-1.0.0.jar" + ) if m2_jar.exists(): return m2_jar @@ -113,8 +122,7 @@ def compare_test_results( jar_path = comparator_jar or _find_comparator_jar(project_root) if not jar_path or not jar_path.exists(): logger.error( - "codeflash-runtime JAR not found. " - "Please ensure the codeflash-runtime is installed in your project." + "codeflash-runtime JAR not found. Please ensure the codeflash-runtime is installed in your project." ) return False, [] @@ -155,10 +163,10 @@ def compare_test_results( comparison = json.loads(result.stdout) except json.JSONDecodeError as e: - logger.error(f"Failed to parse Java comparator output: {e}") - logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}") + logger.exception(f"Failed to parse Java comparator output: {e}") + logger.exception(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}") if result.stderr: - logger.error(f"stderr: {result.stderr[:500]}") + logger.exception(f"stderr: {result.stderr[:500]}") return False, [] # Check for errors in the JSON response @@ -178,9 +186,7 @@ def compare_test_results( for diff in comparison.get("diffs", []): scope_str = diff.get("scope", "return_value") scope = TestDiffScope.RETURN_VALUE - if scope_str == "exception": - scope = TestDiffScope.DID_PASS - elif scope_str == "missing": + if scope_str in {"exception", "missing"}: scope = TestDiffScope.DID_PASS # Build test identifier @@ -220,20 +226,17 @@ def compare_test_results( return equivalent, test_diffs except subprocess.TimeoutExpired: - logger.error("Java comparator timed out") + logger.exception("Java comparator timed out") return False, [] except FileNotFoundError: - logger.error("Java not found. Please install Java to compare test results.") + logger.exception("Java not found. Please install Java to compare test results.") return False, [] except Exception as e: - logger.error(f"Error running Java comparator: {e}") + logger.exception(f"Error running Java comparator: {e}") return False, [] -def compare_invocations_directly( - original_results: dict, - candidate_results: dict, -) -> tuple[bool, list]: +def compare_invocations_directly(original_results: dict, candidate_results: dict) -> tuple[bool, list]: """Compare test invocations directly from Python dictionaries. This is a fallback when the Java comparator is not available. diff --git a/codeflash/languages/java/config.py b/codeflash/languages/java/config.py index 4d99c6b10..408dcecaf 100644 --- a/codeflash/languages/java/config.py +++ b/codeflash/languages/java/config.py @@ -10,7 +10,6 @@ import logging import xml.etree.ElementTree as ET from dataclasses import dataclass, field -from pathlib import Path from typing import TYPE_CHECKING from codeflash.languages.java.build_tools import ( @@ -22,7 +21,7 @@ ) if TYPE_CHECKING: - pass + from pathlib import Path logger = logging.getLogger(__name__) @@ -80,9 +79,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None: project_info = get_project_info(project_root) # Detect test framework - test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework( - project_root, build_tool - ) + test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(project_root, build_tool) # Detect other dependencies has_mockito, has_assertj = _detect_test_dependencies(project_root, build_tool) @@ -120,9 +117,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None: ) -def _detect_test_framework( - project_root: Path, build_tool: BuildTool -) -> tuple[str, bool, bool, bool]: +def _detect_test_framework(project_root: Path, build_tool: BuildTool) -> tuple[str, bool, bool, bool]: """Detect which test framework the project uses. Args: @@ -210,9 +205,7 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]: elif tag == "groupId": group_id = child.text - if group_id == "org.junit.jupiter" or ( - artifact_id and "junit-jupiter" in artifact_id - ): + if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id): has_junit5 = True elif group_id == "junit" and artifact_id == "junit": has_junit4 = True @@ -253,9 +246,7 @@ def _detect_test_deps_from_gradle(project_root: Path) -> tuple[bool, bool, bool] return has_junit5, has_junit4, has_testng -def _detect_test_dependencies( - project_root: Path, build_tool: BuildTool -) -> tuple[bool, bool]: +def _detect_test_dependencies(project_root: Path, build_tool: BuildTool) -> tuple[bool, bool]: """Detect additional test dependencies (Mockito, AssertJ). Returns: @@ -289,9 +280,7 @@ def _detect_test_dependencies( return has_mockito, has_assertj -def _get_compiler_settings( - project_root: Path, build_tool: BuildTool -) -> tuple[str | None, str | None]: +def _get_compiler_settings(project_root: Path, build_tool: BuildTool) -> tuple[str | None, str | None]: """Get compiler source and target settings. Returns: @@ -392,11 +381,7 @@ def is_java_project(project_root: Path) -> bool: return True # Check for Java source files - for pattern in ["src/**/*.java", "*.java"]: - if list(project_root.glob(pattern)): - return True - - return False + return any(list(project_root.glob(pattern)) for pattern in ["src/**/*.java", "*.java"]) def get_test_file_pattern(config: JavaProjectConfig) -> str: diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index 2ccfd34bf..a2c7f7c0e 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -8,26 +8,27 @@ from __future__ import annotations import logging -from pathlib import Path from typing import TYPE_CHECKING -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import CodeContext, HelperFunction, Language from codeflash.languages.java.discovery import discover_functions_from_source -from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files -from codeflash.languages.java.parser import JavaAnalyzer, JavaClassNode, get_java_analyzer +from codeflash.languages.java.import_resolver import find_helper_files +from codeflash.languages.java.parser import get_java_analyzer if TYPE_CHECKING: + from pathlib import Path + from tree_sitter import Node + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.java.parser import JavaAnalyzer + logger = logging.getLogger(__name__) class InvalidJavaSyntaxError(Exception): """Raised when extracted Java code is not syntactically valid.""" - pass - def extract_code_context( function: FunctionToOptimize, @@ -67,12 +68,8 @@ def extract_code_context( try: source = function.file_path.read_text(encoding="utf-8") except Exception as e: - logger.error("Failed to read %s: %s", function.file_path, e) - return CodeContext( - target_code="", - target_file=function.file_path, - language=Language.JAVA, - ) + logger.exception("Failed to read %s: %s", function.file_path, e) + return CodeContext(target_code="", target_file=function.file_path, language=Language.JAVA) # Extract target function code target_code = extract_function_source(source, function) @@ -94,9 +91,7 @@ def extract_code_context( import_statements = [_import_to_statement(imp) for imp in imports] # Extract helper functions - helper_functions = find_helper_functions( - function, project_root, max_helper_depth, analyzer - ) + helper_functions = find_helper_functions(function, project_root, max_helper_depth, analyzer) # Extract read-only context only if fields are NOT already in the skeleton # Avoid duplication between target_code and read_only_context @@ -107,9 +102,8 @@ def extract_code_context( # Validate syntax - extracted code must always be valid Java if validate_syntax and target_code: if not analyzer.validate_syntax(target_code): - raise InvalidJavaSyntaxError( - f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}" - ) + msg = f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}" + raise InvalidJavaSyntaxError(msg) return CodeContext( target_code=target_code, @@ -156,7 +150,7 @@ def __init__( enum_constants: str, type_indent: str, type_kind: str, # "class", "interface", or "enum" - outer_type_skeleton: "TypeSkeleton | None" = None, + outer_type_skeleton: TypeSkeleton | None = None, ) -> None: self.type_declaration = type_declaration self.type_javadoc = type_javadoc @@ -173,10 +167,7 @@ def __init__( def _extract_type_skeleton( - source: str, - type_name: str, - target_method_name: str, - analyzer: JavaAnalyzer, + source: str, type_name: str, target_method_name: str, analyzer: JavaAnalyzer ) -> TypeSkeleton | None: """Extract the type skeleton (class, interface, or enum) for wrapping a method. @@ -254,11 +245,7 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum". """ - type_declarations = { - "class_declaration": "class", - "interface_declaration": "interface", - "enum_declaration": "enum", - } + type_declarations = {"class_declaration": "class", "interface_declaration": "interface", "enum_declaration": "enum"} if node.type in type_declarations: name_node = node.child_by_field_name("name") @@ -283,11 +270,7 @@ def _find_class_node(node: Node, class_name: str, source_bytes: bytes) -> Node | def _get_outer_type_skeleton( - inner_type_node: Node, - source_bytes: bytes, - lines: list[str], - target_method_name: str, - analyzer: JavaAnalyzer, + inner_type_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, analyzer: JavaAnalyzer ) -> TypeSkeleton | None: """Get the outer type skeleton if this is an inner type. @@ -356,11 +339,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s parts: list[str] = [] # Determine which body node type to look for - body_types = { - "class": "class_body", - "interface": "interface_body", - "enum": "enum_body", - } + body_types = {"class": "class_body", "interface": "interface_body", "enum": "enum_body"} body_type = body_types.get(type_kind, "class_body") for child in type_node.children: @@ -374,7 +353,8 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s # Keep old function name for backwards compatibility -_extract_class_declaration = lambda node, source_bytes: _extract_type_declaration(node, source_bytes, "class") +def _extract_class_declaration(node, source_bytes): + return _extract_type_declaration(node, source_bytes, "class") def _find_javadoc(node: Node, source_bytes: bytes) -> str | None: @@ -390,11 +370,7 @@ def _find_javadoc(node: Node, source_bytes: bytes) -> str | None: def _extract_type_body_context( - body_node: Node, - source_bytes: bytes, - lines: list[str], - target_method_name: str, - type_kind: str, + body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, type_kind: str ) -> tuple[str, str, str]: """Extract fields, constructors, and enum constants from a type body. @@ -473,15 +449,10 @@ def _extract_type_body_context( # Keep old function name for backwards compatibility def _extract_class_body_context( - body_node: Node, - source_bytes: bytes, - lines: list[str], - target_method_name: str, + body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str ) -> tuple[str, str]: """Extract fields and constructors from a class body.""" - fields, constructors, _ = _extract_type_body_context( - body_node, source_bytes, lines, target_method_name, "class" - ) + fields, constructors, _ = _extract_type_body_context(body_node, source_bytes, lines, target_method_name, "class") return (fields, constructors) @@ -584,10 +555,7 @@ def extract_function_source(source: str, function: FunctionToOptimize) -> str: def find_helper_functions( - function: FunctionToOptimize, - project_root: Path, - max_depth: int = 2, - analyzer: JavaAnalyzer | None = None, + function: FunctionToOptimize, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None ) -> list[HelperFunction]: """Find helper functions that the target function depends on. @@ -606,11 +574,9 @@ def find_helper_functions( visited_functions: set[str] = set() # Find helper files through imports - helper_files = find_helper_files( - function.file_path, project_root, max_depth, analyzer - ) + helper_files = find_helper_files(function.file_path, project_root, max_depth, analyzer) - for file_path, class_names in helper_files.items(): + for file_path in helper_files: try: source = file_path.read_text(encoding="utf-8") file_functions = discover_functions_from_source(source, file_path, analyzer=analyzer) @@ -648,10 +614,7 @@ def find_helper_functions( return helpers -def _find_same_class_helpers( - function: FunctionToOptimize, - analyzer: JavaAnalyzer, -) -> list[HelperFunction]: +def _find_same_class_helpers(function: FunctionToOptimize, analyzer: JavaAnalyzer) -> list[HelperFunction]: """Find helper methods in the same class as the target function. Args: @@ -694,9 +657,7 @@ def _find_same_class_helpers( and method.class_name == function.class_name and method.name in called_methods ): - func_source = source_bytes[ - method.node.start_byte : method.node.end_byte - ].decode("utf8") + func_source = source_bytes[method.node.start_byte : method.node.end_byte].decode("utf8") helpers.append( HelperFunction( @@ -715,11 +676,7 @@ def _find_same_class_helpers( return helpers -def extract_read_only_context( - source: str, - function: FunctionToOptimize, - analyzer: JavaAnalyzer, -) -> str: +def extract_read_only_context(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer) -> str: """Extract read-only context (fields, constants, inner classes). This extracts class-level context that the function might depend on @@ -767,11 +724,7 @@ def _import_to_statement(import_info) -> str: return f"{prefix}{import_info.import_path}{suffix};" -def extract_class_context( - file_path: Path, - class_name: str, - analyzer: JavaAnalyzer | None = None, -) -> str: +def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None) -> str: """Extract the full context of a class. Args: @@ -813,5 +766,5 @@ def extract_class_context( return package_stmt + "\n".join(import_statements) + "\n\n" + class_source except Exception as e: - logger.error("Failed to extract class context: %s", e) + logger.exception("Failed to extract class context: %s", e) return "" diff --git a/codeflash/languages/java/discovery.py b/codeflash/languages/java/discovery.py index 902feca67..2d8f0b3ea 100644 --- a/codeflash/languages/java/discovery.py +++ b/codeflash/languages/java/discovery.py @@ -12,19 +12,17 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import FunctionFilterCriteria -from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer +from codeflash.languages.java.parser import get_java_analyzer from codeflash.models.function_types import FunctionParent if TYPE_CHECKING: - pass + from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode logger = logging.getLogger(__name__) def discover_functions( - file_path: Path, - filter_criteria: FunctionFilterCriteria | None = None, - analyzer: JavaAnalyzer | None = None, + file_path: Path, filter_criteria: FunctionFilterCriteria | None = None, analyzer: JavaAnalyzer | None = None ) -> list[FunctionToOptimize]: """Find all optimizable functions/methods in a Java file. @@ -115,10 +113,7 @@ def discover_functions_from_source( def _should_include_method( - method: JavaMethodNode, - criteria: FunctionFilterCriteria, - source: str, - analyzer: JavaAnalyzer, + method: JavaMethodNode, criteria: FunctionFilterCriteria, source: str, analyzer: JavaAnalyzer ) -> bool: """Check if a method should be included based on filter criteria. @@ -176,10 +171,7 @@ def _should_include_method( return True -def discover_test_methods( - file_path: Path, - analyzer: JavaAnalyzer | None = None, -) -> list[FunctionToOptimize]: +def discover_test_methods(file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]: """Find all JUnit test methods in a Java test file. Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc. @@ -232,7 +224,7 @@ def _walk_tree_for_test_methods( for child in node.children: if child.type == "modifiers": for mod_child in child.children: - if mod_child.type == "marker_annotation" or mod_child.type == "annotation": + if mod_child.type in {"marker_annotation", "annotation"}: annotation_text = analyzer.get_node_text(mod_child, source_bytes) # Check for JUnit 5 test annotations if any( @@ -278,10 +270,7 @@ def _walk_tree_for_test_methods( def get_method_by_name( - file_path: Path, - method_name: str, - class_name: str | None = None, - analyzer: JavaAnalyzer | None = None, + file_path: Path, method_name: str, class_name: str | None = None, analyzer: JavaAnalyzer | None = None ) -> FunctionToOptimize | None: """Find a specific method by name in a Java file. @@ -306,9 +295,7 @@ def get_method_by_name( def get_class_methods( - file_path: Path, - class_name: str, - analyzer: JavaAnalyzer | None = None, + file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None ) -> list[FunctionToOptimize]: """Get all methods in a specific class. diff --git a/codeflash/languages/java/formatter.py b/codeflash/languages/java/formatter.py index a9ccd2d8d..2bb228ca2 100644 --- a/codeflash/languages/java/formatter.py +++ b/codeflash/languages/java/formatter.py @@ -6,16 +6,13 @@ from __future__ import annotations +import contextlib import logging import os import shutil import subprocess import tempfile from pathlib import Path -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - pass logger = logging.getLogger(__name__) @@ -29,7 +26,7 @@ class JavaFormatter: # Version of google-java-format to use GOOGLE_JAVA_FORMAT_VERSION = "1.19.2" - def __init__(self, project_root: Path | None = None): + def __init__(self, project_root: Path | None = None) -> None: """Initialize the Java formatter. Args: @@ -107,21 +104,13 @@ def _format_with_google_java_format(self, source: str) -> str | None: try: # Write source to temp file - with tempfile.NamedTemporaryFile( - mode="w", suffix=".java", delete=False, encoding="utf-8" - ) as tmp: + with tempfile.NamedTemporaryFile(mode="w", suffix=".java", delete=False, encoding="utf-8") as tmp: tmp.write(source) tmp_path = tmp.name try: result = subprocess.run( - [ - self._java_executable, - "-jar", - str(jar_path), - "--replace", - tmp_path, - ], + [self._java_executable, "-jar", str(jar_path), "--replace", tmp_path], check=False, capture_output=True, text=True, @@ -133,16 +122,12 @@ def _format_with_google_java_format(self, source: str) -> str | None: with open(tmp_path, encoding="utf-8") as f: return f.read() else: - logger.debug( - "google-java-format failed: %s", result.stderr or result.stdout - ) + logger.debug("google-java-format failed: %s", result.stderr or result.stdout) finally: # Clean up temp file - try: + with contextlib.suppress(OSError): os.unlink(tmp_path) - except OSError: - pass except subprocess.TimeoutExpired: logger.warning("google-java-format timed out") @@ -169,9 +154,7 @@ def _get_google_java_format_jar(self) -> Path | None: if self.project_root else None, # In user's home directory - Path.home() - / ".codeflash" - / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar", + Path.home() / ".codeflash" / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar", # In system temp Path(tempfile.gettempdir()) / "codeflash" @@ -186,8 +169,7 @@ def _get_google_java_format_jar(self) -> Path | None: # Don't auto-download to avoid surprises # Users can manually download the JAR logger.debug( - "google-java-format JAR not found. " - "Download from https://github.com/google/google-java-format/releases" + "google-java-format JAR not found. Download from https://github.com/google/google-java-format/releases" ) return None @@ -239,7 +221,7 @@ def download_google_java_format(self, target_dir: Path | None = None) -> Path | logger.info("Downloaded google-java-format to %s", jar_path) return jar_path except Exception as e: - logger.error("Failed to download google-java-format: %s", e) + logger.exception("Failed to download google-java-format: %s", e) return None diff --git a/codeflash/languages/java/import_resolver.py b/codeflash/languages/java/import_resolver.py index 5ab8800ed..766434a94 100644 --- a/codeflash/languages/java/import_resolver.py +++ b/codeflash/languages/java/import_resolver.py @@ -8,14 +8,15 @@ import logging from dataclasses import dataclass -from pathlib import Path from typing import TYPE_CHECKING from codeflash.languages.java.build_tools import find_source_root, find_test_root, get_project_info -from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, get_java_analyzer +from codeflash.languages.java.parser import get_java_analyzer if TYPE_CHECKING: - pass + from pathlib import Path + + from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo logger = logging.getLogger(__name__) @@ -35,18 +36,7 @@ class JavaImportResolver: """Resolves Java imports to file paths within a project.""" # Standard Java packages that are always external - STANDARD_PACKAGES = frozenset( - [ - "java", - "javax", - "sun", - "com.sun", - "jdk", - "org.w3c", - "org.xml", - "org.ietf", - ] - ) + STANDARD_PACKAGES = frozenset(["java", "javax", "sun", "com.sun", "jdk", "org.w3c", "org.xml", "org.ietf"]) # Common third-party package prefixes COMMON_EXTERNAL_PREFIXES = frozenset( @@ -66,7 +56,7 @@ class JavaImportResolver: ] ) - def __init__(self, project_root: Path): + def __init__(self, project_root: Path) -> None: """Initialize the import resolver. Args: @@ -156,10 +146,7 @@ def resolve_imports(self, imports: list[JavaImportInfo]) -> list[ResolvedImport] def _is_standard_library(self, import_path: str) -> bool: """Check if an import is from the Java standard library.""" - for prefix in self.STANDARD_PACKAGES: - if import_path.startswith(prefix + ".") or import_path == prefix: - return True - return False + return any(import_path.startswith(prefix + ".") or import_path == prefix for prefix in self.STANDARD_PACKAGES) def _is_external_library(self, import_path: str) -> bool: """Check if an import is from a known external library.""" @@ -249,9 +236,7 @@ def find_class_file(self, class_name: str, package_hint: str | None = None) -> P return None - def get_imports_from_file( - self, file_path: Path, analyzer: JavaAnalyzer | None = None - ) -> list[ResolvedImport]: + def get_imports_from_file(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]: """Get and resolve all imports from a Java file. Args: @@ -272,9 +257,7 @@ def get_imports_from_file( logger.warning("Failed to get imports from %s: %s", file_path, e) return [] - def get_project_imports( - self, file_path: Path, analyzer: JavaAnalyzer | None = None - ) -> list[ResolvedImport]: + def get_project_imports(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]: """Get only the imports that resolve to files within the project. Args: @@ -308,10 +291,7 @@ def resolve_imports_for_file( def find_helper_files( - file_path: Path, - project_root: Path, - max_depth: int = 2, - analyzer: JavaAnalyzer | None = None, + file_path: Path, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None ) -> dict[Path, list[str]]: """Find helper files imported by a Java file, recursively. diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index b42ea7d94..78cd77d3a 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -17,16 +17,16 @@ import logging import re from functools import lru_cache -from pathlib import Path from typing import TYPE_CHECKING -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.languages.java.parser import JavaAnalyzer - if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path from typing import Any + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.java.parser import JavaAnalyzer + logger = logging.getLogger(__name__) @@ -36,7 +36,8 @@ def _get_function_name(func: Any) -> str: return func.function_name if hasattr(func, "name"): return func.name - raise AttributeError(f"Cannot get function name from {type(func)}") + msg = f"Cannot get function name from {type(func)}" + raise AttributeError(msg) def _get_qualified_name(func: Any) -> str: @@ -135,7 +136,7 @@ def instrument_existing_test( try: source = test_path.read_text(encoding="utf-8") except Exception as e: - logger.error("Failed to read test file %s: %s", test_path, e) + logger.exception("Failed to read test file %s: %s", test_path, e) return False, f"Failed to read test file: {e}" func_name = _get_function_name(function_to_optimize) @@ -227,7 +228,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) result.append(imp) imports_added = True continue - if stripped.startswith("public class") or stripped.startswith("class"): + if stripped.startswith(("public class", "class")): # No imports found, add before class for imp in import_statements: result.append(imp) @@ -244,7 +245,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) i = 0 iteration_counter = 0 - # Pre-compile the regex pattern once method_call_pattern = _get_method_call_pattern(func_name) @@ -291,11 +291,10 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) while i < len(lines) and brace_depth > 0: body_line = lines[i] # Count braces more efficiently using string methods - open_count = body_line.count('{') - close_count = body_line.count('}') + open_count = body_line.count("{") + close_count = body_line.count("}") brace_depth += open_count - close_count - if brace_depth > 0: body_lines.append(body_line) i += 1 @@ -581,7 +580,7 @@ def create_benchmark_test( method_id = _get_qualified_name(target_function) class_name = getattr(target_function, "class_name", None) or "Target" - benchmark_code = f""" + return f""" import org.junit.jupiter.api.Test; import org.junit.jupiter.api.DisplayName; @@ -615,7 +614,6 @@ def create_benchmark_test( }} }} """ - return benchmark_code def remove_instrumentation(source: str) -> str: @@ -713,7 +711,7 @@ def _add_import(source: str, import_statement: str) -> str: # Find the last import or package statement for i, line in enumerate(lines): stripped = line.strip() - if stripped.startswith("import ") or stripped.startswith("package "): + if stripped.startswith(("import ", "package ")): insert_idx = i + 1 elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"): # First non-import, non-comment line @@ -725,13 +723,11 @@ def _add_import(source: str, import_statement: str) -> str: return "".join(lines) - @lru_cache(maxsize=128) def _get_method_call_pattern(func_name: str): """Cache compiled regex patterns for method call matching.""" return re.compile( - rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", - re.MULTILINE + rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE ) @@ -739,6 +735,5 @@ def _get_method_call_pattern(func_name: str): def _get_method_call_pattern(func_name: str): """Cache compiled regex patterns for method call matching.""" return re.compile( - rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", - re.MULTILINE + rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE ) diff --git a/codeflash/languages/java/parser.py b/codeflash/languages/java/parser.py index bdffac44e..72a530179 100644 --- a/codeflash/languages/java/parser.py +++ b/codeflash/languages/java/parser.py @@ -13,8 +13,6 @@ from tree_sitter import Language, Parser if TYPE_CHECKING: - from pathlib import Path - from tree_sitter import Node, Tree logger = logging.getLogger(__name__) @@ -222,9 +220,7 @@ def _walk_tree_for_methods( current_class=new_class if node.type in type_declarations else current_class, ) - def _extract_method_info( - self, node: Node, source_bytes: bytes, current_class: str | None - ) -> JavaMethodNode | None: + def _extract_method_info(self, node: Node, source_bytes: bytes, current_class: str | None) -> JavaMethodNode | None: """Extract method information from a method_declaration node.""" name = "" is_static = False @@ -347,9 +343,7 @@ def _walk_tree_for_classes( for child in node.children: self._walk_tree_for_classes(child, source_bytes, classes, is_inner) - def _extract_class_info( - self, node: Node, source_bytes: bytes, is_inner: bool - ) -> JavaClassNode | None: + def _extract_class_info(self, node: Node, source_bytes: bytes, is_inner: bool) -> JavaClassNode | None: """Extract class information from a class_declaration node.""" name = "" is_public = False diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 75a9a78e7..92ddd44e2 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -18,10 +18,10 @@ from typing import TYPE_CHECKING from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer +from codeflash.languages.java.parser import get_java_analyzer if TYPE_CHECKING: - pass + from codeflash.languages.java.parser import JavaAnalyzer logger = logging.getLogger(__name__) @@ -35,11 +35,7 @@ class ParsedOptimization: new_helper_methods: list[str] # Source text of new helper methods to add -def _parse_optimization_source( - new_source: str, - target_method_name: str, - analyzer: JavaAnalyzer, -) -> ParsedOptimization: +def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization: """Parse optimization source to extract method and additional class members. The new_source may contain: @@ -96,18 +92,12 @@ def _parse_optimization_source( new_fields.append(field.source_text) return ParsedOptimization( - target_method_source=target_method_source, - new_fields=new_fields, - new_helper_methods=new_helper_methods, + target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods ) def _insert_class_members( - source: str, - class_name: str, - fields: list[str], - methods: list[str], - analyzer: JavaAnalyzer, + source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer ) -> str: """Insert new class members (fields and methods) into a class. @@ -212,10 +202,7 @@ def _insert_class_members( def replace_function( - source: str, - function: FunctionToOptimize, - new_source: str, - analyzer: JavaAnalyzer | None = None, + source: str, function: FunctionToOptimize, new_source: str, analyzer: JavaAnalyzer | None = None ) -> str: """Replace a function in source code with new implementation. @@ -257,9 +244,9 @@ def replace_function( # Find all methods matching the name (there may be overloads) matching_methods = [ - m for m in methods - if m.name == func_name - and (function.class_name is None or m.class_name == function.class_name) + m + for m in methods + if m.name == func_name and (function.class_name is None or m.class_name == function.class_name) ] if len(matching_methods) == 1: @@ -296,10 +283,7 @@ def replace_function( break if not target_method: # Fallback: use the first match - logger.warning( - "Multiple overloads of %s found but no line match, using first match", - func_name, - ) + logger.warning("Multiple overloads of %s found but no line match, using first match", func_name) target_method = matching_methods[0] target_overload_index = 0 @@ -342,18 +326,16 @@ def replace_function( len(new_helpers_to_add), class_name, ) - source = _insert_class_members( - source, class_name, new_fields_to_add, new_helpers_to_add, analyzer - ) + source = _insert_class_members(source, class_name, new_fields_to_add, new_helpers_to_add, analyzer) # Re-find the target method after modifications # Line numbers have shifted, but the relative order of overloads is preserved # Use the target_overload_index we saved earlier methods = analyzer.find_methods(source) matching_methods = [ - m for m in methods - if m.name == func_name - and (function.class_name is None or m.class_name == function.class_name) + m + for m in methods + if m.name == func_name and (function.class_name is None or m.class_name == function.class_name) ] if matching_methods and target_overload_index < len(matching_methods): @@ -398,9 +380,7 @@ def replace_function( before = lines[: start_line - 1] # Lines before the method after = lines[end_line:] # Lines after the method - result = "".join(before) + indented_new_source + "".join(after) - - return result + return "".join(before) + indented_new_source + "".join(after) def _get_indentation(line: str) -> str: @@ -460,10 +440,7 @@ def _apply_indentation(lines: list[str], base_indent: str) -> str: def replace_method_body( - source: str, - function: FunctionToOptimize, - new_body: str, - analyzer: JavaAnalyzer | None = None, + source: str, function: FunctionToOptimize, new_body: str, analyzer: JavaAnalyzer | None = None ) -> str: """Replace just the body of a method, preserving signature. @@ -600,11 +577,7 @@ def insert_method( return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8") -def remove_method( - source: str, - function: FunctionToOptimize, - analyzer: JavaAnalyzer | None = None, -) -> str: +def remove_method(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None) -> str: """Remove a method from source code. Args: @@ -648,9 +621,7 @@ def remove_method( def remove_test_functions( - test_source: str, - functions_to_remove: list[str], - analyzer: JavaAnalyzer | None = None, + test_source: str, functions_to_remove: list[str], analyzer: JavaAnalyzer | None = None ) -> str: """Remove specific test functions from test source code. @@ -669,9 +640,7 @@ def remove_test_functions( methods = analyzer.find_methods(test_source) # Sort by start line in reverse order (remove from end first) - methods_to_remove = [ - m for m in methods if m.name in functions_to_remove - ] + methods_to_remove = [m for m in methods if m.name in functions_to_remove] methods_to_remove.sort(key=lambda m: m.start_line, reverse=True) result = test_source @@ -728,9 +697,7 @@ def add_runtime_comments( if original_ns > 0: speedup = ((original_ns - optimized_ns) / original_ns) * 100 - summary_lines.append( - f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)" - ) + summary_lines.append(f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)") # Insert after imports lines = test_source.splitlines(keepends=True) diff --git a/codeflash/languages/java/support.py b/codeflash/languages/java/support.py index 6fb015cd2..ed1bb339c 100644 --- a/codeflash/languages/java/support.py +++ b/codeflash/languages/java/support.py @@ -7,20 +7,9 @@ from __future__ import annotations import logging -from pathlib import Path from typing import TYPE_CHECKING, Any -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.languages.base import ( - CodeContext, - FunctionFilterCriteria, - HelperFunction, - Language, - LanguageSupport, - TestInfo, - TestResult, -) -from codeflash.languages.registry import register_language +from codeflash.languages.base import Language, LanguageSupport from codeflash.languages.java.build_tools import find_test_root from codeflash.languages.java.comparator import compare_test_results as _compare_test_results from codeflash.languages.java.config import detect_java_project @@ -33,11 +22,7 @@ instrument_for_benchmarking, ) from codeflash.languages.java.parser import get_java_analyzer -from codeflash.languages.java.replacement import ( - add_runtime_comments, - remove_test_functions, - replace_function, -) +from codeflash.languages.java.replacement import add_runtime_comments, remove_test_functions, replace_function from codeflash.languages.java.test_discovery import discover_tests from codeflash.languages.java.test_runner import ( parse_test_results, @@ -45,9 +30,14 @@ run_benchmarking_tests, run_tests, ) +from codeflash.languages.registry import register_language if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult logger = logging.getLogger(__name__) @@ -112,23 +102,17 @@ def discover_tests( # === Code Analysis === - def extract_code_context( - self, function: FunctionToOptimize, project_root: Path, module_root: Path - ) -> CodeContext: + def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext: """Extract function code and its dependencies.""" return extract_code_context(function, project_root, module_root, analyzer=self._analyzer) - def find_helper_functions( - self, function: FunctionToOptimize, project_root: Path - ) -> list[HelperFunction]: + def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]: """Find helper functions called by the target function.""" return find_helper_functions(function, project_root, analyzer=self._analyzer) # === Code Transformation === - def replace_function( - self, source: str, function: FunctionToOptimize, new_source: str - ) -> str: + def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str: """Replace a function in source code with new implementation.""" return replace_function(source, function, new_source, self._analyzer) @@ -140,11 +124,7 @@ def format_code(self, source: str, file_path: Path | None = None) -> str: # === Test Execution === def run_tests( - self, - test_files: Sequence[Path], - cwd: Path, - env: dict[str, str], - timeout: int, + self, test_files: Sequence[Path], cwd: Path, env: dict[str, str], timeout: int ) -> tuple[list[TestResult], Path]: """Run tests and return results.""" return run_tests(list(test_files), cwd, env, timeout) @@ -155,15 +135,11 @@ def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResu # === Instrumentation === - def instrument_for_behavior( - self, source: str, functions: Sequence[FunctionToOptimize] - ) -> str: + def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str: """Add behavior instrumentation to capture inputs/outputs.""" return instrument_for_behavior(source, functions, self._analyzer) - def instrument_for_benchmarking( - self, test_source: str, target_function: FunctionToOptimize - ) -> str: + def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str: """Add timing instrumentation to test code.""" return instrument_for_benchmarking(test_source, target_function, self._analyzer) @@ -180,32 +156,22 @@ def normalize_code(self, source: str) -> str: # === Test Editing === def add_runtime_comments( - self, - test_source: str, - original_runtimes: dict[str, int], - optimized_runtimes: dict[str, int], + self, test_source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int] ) -> str: """Add runtime performance comments to test source code.""" return add_runtime_comments(test_source, original_runtimes, optimized_runtimes, self._analyzer) - def remove_test_functions( - self, test_source: str, functions_to_remove: list[str] - ) -> str: + def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str: """Remove specific test functions from test source code.""" return remove_test_functions(test_source, functions_to_remove, self._analyzer) # === Test Result Comparison === def compare_test_results( - self, - original_results_path: Path, - candidate_results_path: Path, - project_root: Path | None = None, + self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None ) -> tuple[bool, list]: """Compare test results between original and candidate code.""" - return _compare_test_results( - original_results_path, candidate_results_path, project_root=project_root - ) + return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root) # === Configuration === @@ -308,12 +274,7 @@ def instrument_existing_test( ) -> tuple[bool, str | None]: """Inject profiling code into an existing test file.""" return instrument_existing_test( - test_path, - call_positions, - function_to_optimize, - tests_project_root, - mode, - self._analyzer, + test_path, call_positions, function_to_optimize, tests_project_root, mode, self._analyzer ) def instrument_source_for_line_profiler( @@ -339,15 +300,7 @@ def run_behavioral_tests( candidate_index: int = 0, ) -> tuple[Path, Any, Path | None, Path | None]: """Run behavioral tests for Java.""" - return run_behavioral_tests( - test_paths, - test_env, - cwd, - timeout, - project_root, - enable_coverage, - candidate_index, - ) + return run_behavioral_tests(test_paths, test_env, cwd, timeout, project_root, enable_coverage, candidate_index) def run_benchmarking_tests( self, diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py index aef25a8cb..67c11316b 100644 --- a/codeflash/languages/java/test_discovery.py +++ b/codeflash/languages/java/test_discovery.py @@ -7,27 +7,26 @@ from __future__ import annotations import logging -import re from collections import defaultdict -from pathlib import Path from typing import TYPE_CHECKING -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import TestInfo from codeflash.languages.java.config import detect_java_project from codeflash.languages.java.discovery import discover_test_methods -from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer +from codeflash.languages.java.parser import get_java_analyzer if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.languages.java.parser import JavaAnalyzer logger = logging.getLogger(__name__) def discover_tests( - test_root: Path, - source_functions: Sequence[FunctionToOptimize], - analyzer: JavaAnalyzer | None = None, + test_root: Path, source_functions: Sequence[FunctionToOptimize], analyzer: JavaAnalyzer | None = None ) -> dict[str, list[TestInfo]]: """Map source functions to their tests via static analysis. @@ -56,9 +55,7 @@ def discover_tests( # Find all test files (various naming conventions) test_files = ( - list(test_root.rglob("*Test.java")) - + list(test_root.rglob("*Tests.java")) - + list(test_root.rglob("Test*.java")) + list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java")) ) # Result map @@ -71,16 +68,12 @@ def discover_tests( for test_method in test_methods: # Find which source functions this test might exercise - matched_functions = _match_test_to_functions( - test_method, source, function_map, analyzer - ) + matched_functions = _match_test_to_functions(test_method, source, function_map, analyzer) for func_name in matched_functions: result[func_name].append( TestInfo( - test_name=test_method.function_name, - test_file=test_file, - test_class=test_method.class_name, + test_name=test_method.function_name, test_file=test_file, test_class=test_method.class_name ) ) @@ -114,7 +107,7 @@ def _match_test_to_functions( # e.g., testAdd -> add, testCalculatorAdd -> Calculator.add test_name_lower = test_method.function_name.lower() - for func_name, func_info in function_map.items(): + for func_info in function_map.values(): if func_info.function_name.lower() in test_name_lower: matched.append(func_info.qualified_name) @@ -125,11 +118,7 @@ def _match_test_to_functions( # Find method calls within the test method's line range method_calls = _find_method_calls_in_range( - tree.root_node, - source_bytes, - test_method.starting_line, - test_method.ending_line, - analyzer, + tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer ) for call_name in method_calls: @@ -151,7 +140,7 @@ def _match_test_to_functions( source_class_name = source_class_name[4:] # Look for functions in the matching class - for func_name, func_info in function_map.items(): + for func_info in function_map.values(): if func_info.class_name == source_class_name: if func_info.qualified_name not in matched: matched.append(func_info.qualified_name) @@ -161,7 +150,7 @@ def _match_test_to_functions( # This handles cases like TestQueryBlob importing Buffer and calling Buffer methods imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer) - for func_name, func_info in function_map.items(): + for func_info in function_map.values(): if func_info.qualified_name in matched: continue @@ -172,11 +161,7 @@ def _match_test_to_functions( return matched -def _extract_imports( - node, - source_bytes: bytes, - analyzer: JavaAnalyzer, -) -> set[str]: +def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]: """Extract imported class names from a Java file. Args: @@ -224,7 +209,7 @@ def visit(n): # Regular import: extract class name from scoped_identifier for child in n.children: - if child.type == "scoped_identifier" or child.type == "identifier": + if child.type in {"scoped_identifier", "identifier"}: import_path = analyzer.get_node_text(child, source_bytes) # Extract just the class name (last part) # e.g., "com.example.Buffer" -> "Buffer" @@ -244,11 +229,7 @@ def visit(n): def _find_method_calls_in_range( - node, - source_bytes: bytes, - start_line: int, - end_line: int, - analyzer: JavaAnalyzer, + node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer ) -> list[str]: """Find method calls within a line range. @@ -278,17 +259,13 @@ def _find_method_calls_in_range( calls.append(analyzer.get_node_text(name_node, source_bytes)) for child in node.children: - calls.extend( - _find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer) - ) + calls.extend(_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer)) return calls def find_tests_for_function( - function: FunctionToOptimize, - test_root: Path, - analyzer: JavaAnalyzer | None = None, + function: FunctionToOptimize, test_root: Path, analyzer: JavaAnalyzer | None = None ) -> list[TestInfo]: """Find tests that exercise a specific function. @@ -305,10 +282,7 @@ def find_tests_for_function( return result.get(function.qualified_name, []) -def get_test_class_for_source_class( - source_class_name: str, - test_root: Path, -) -> Path | None: +def get_test_class_for_source_class(source_class_name: str, test_root: Path) -> Path | None: """Find the test class file for a source class. Args: @@ -320,11 +294,7 @@ def get_test_class_for_source_class( """ # Try common naming patterns - patterns = [ - f"{source_class_name}Test.java", - f"Test{source_class_name}.java", - f"{source_class_name}Tests.java", - ] + patterns = [f"{source_class_name}Test.java", f"Test{source_class_name}.java", f"{source_class_name}Tests.java"] for pattern in patterns: matches = list(test_root.rglob(pattern)) @@ -334,10 +304,7 @@ def get_test_class_for_source_class( return None -def discover_all_tests( - test_root: Path, - analyzer: JavaAnalyzer | None = None, -) -> list[FunctionToOptimize]: +def discover_all_tests(test_root: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]: """Discover all test methods in a test directory. Args: @@ -353,9 +320,7 @@ def discover_all_tests( # Find all test files (various naming conventions) test_files = ( - list(test_root.rglob("*Test.java")) - + list(test_root.rglob("*Tests.java")) - + list(test_root.rglob("Test*.java")) + list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java")) ) for test_file in test_files: @@ -391,24 +356,18 @@ def is_test_file(file_path: Path) -> bool: name = file_path.name # Check naming patterns - if name.endswith("Test.java") or name.endswith("Tests.java"): + if name.endswith(("Test.java", "Tests.java")): return True if name.startswith("Test") and name.endswith(".java"): return True # Check if it's in a test directory path_parts = file_path.parts - for part in path_parts: - if part in ("test", "tests", "src/test"): - return True - - return False + return any(part in ("test", "tests", "src/test") for part in path_parts) def get_test_methods_for_class( - test_file: Path, - test_class_name: str | None = None, - analyzer: JavaAnalyzer | None = None, + test_file: Path, test_class_name: str | None = None, analyzer: JavaAnalyzer | None = None ) -> list[FunctionToOptimize]: """Get all test methods in a specific test class. @@ -430,8 +389,7 @@ def get_test_methods_for_class( def build_test_mapping_for_project( - project_root: Path, - analyzer: JavaAnalyzer | None = None, + project_root: Path, analyzer: JavaAnalyzer | None = None ) -> dict[str, list[TestInfo]]: """Build a complete test mapping for a project. diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index 0455782e7..b5e0618a8 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -31,7 +31,7 @@ # Regex pattern for valid Java class names (package.ClassName format) # Allows: letters, digits, underscores, dots, and dollar signs (inner classes) -_VALID_JAVA_CLASS_NAME = re.compile(r'^[a-zA-Z_$][a-zA-Z0-9_$.]*$') +_VALID_JAVA_CLASS_NAME = re.compile(r"^[a-zA-Z_$][a-zA-Z0-9_$.]*$") def _validate_java_class_name(class_name: str) -> bool: @@ -44,6 +44,7 @@ def _validate_java_class_name(class_name: str) -> bool: Returns: True if valid, False otherwise. + """ return bool(_VALID_JAVA_CLASS_NAME.match(class_name)) @@ -62,19 +63,21 @@ def _validate_test_filter(test_filter: str) -> str: Raises: ValueError: If the test filter contains invalid characters. + """ # Split by comma for multiple test patterns - patterns = [p.strip() for p in test_filter.split(',')] + patterns = [p.strip() for p in test_filter.split(",")] for pattern in patterns: # Remove wildcards for validation (they're allowed in test filters) - name_to_validate = pattern.replace('*', 'A') # Replace * with a valid char + name_to_validate = pattern.replace("*", "A") # Replace * with a valid char if not _validate_java_class_name(name_to_validate): - raise ValueError( + msg = ( f"Invalid test class name or pattern: '{pattern}'. " f"Test names must follow Java identifier rules (letters, digits, underscores, dots, dollar signs)." ) + raise ValueError(msg) return test_filter @@ -134,6 +137,7 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path, # This is a multi-module project root # Extract modules from pom.xml import re + modules = re.findall(r"([^<]+)", content) # Check if test file is in one of the modules for test_path in test_file_paths: @@ -310,10 +314,7 @@ def run_behavioral_tests( def _compile_tests( - project_root: Path, - env: dict[str, str], - test_module: str | None = None, - timeout: int = 120, + project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 120 ) -> subprocess.CompletedProcess: """Compile test code using Maven (without running tests). @@ -330,12 +331,7 @@ def _compile_tests( mvn = find_maven_executable() if not mvn: logger.error("Maven not found") - return subprocess.CompletedProcess( - args=["mvn"], - returncode=-1, - stdout="", - stderr="Maven not found", - ) + return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found") cmd = [mvn, "test-compile", "-e"] # Show errors but not verbose output @@ -346,37 +342,20 @@ def _compile_tests( try: return subprocess.run( - cmd, - check=False, - cwd=project_root, - env=env, - capture_output=True, - text=True, - timeout=timeout, + cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout ) except subprocess.TimeoutExpired: - logger.error("Maven compilation timed out after %d seconds", timeout) + logger.exception("Maven compilation timed out after %d seconds", timeout) return subprocess.CompletedProcess( - args=cmd, - returncode=-2, - stdout="", - stderr=f"Compilation timed out after {timeout} seconds", + args=cmd, returncode=-2, stdout="", stderr=f"Compilation timed out after {timeout} seconds" ) except Exception as e: logger.exception("Maven compilation failed: %s", e) - return subprocess.CompletedProcess( - args=cmd, - returncode=-1, - stdout="", - stderr=str(e), - ) + return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e)) def _get_test_classpath( - project_root: Path, - env: dict[str, str], - test_module: str | None = None, - timeout: int = 60, + project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 60 ) -> str | None: """Get the test classpath from Maven. @@ -397,13 +376,7 @@ def _get_test_classpath( # Create temp file for classpath output cp_file = project_root / ".codeflash_classpath.txt" - cmd = [ - mvn, - "dependency:build-classpath", - "-DincludeScope=test", - f"-Dmdep.outputFile={cp_file}", - "-q", - ] + cmd = [mvn, "dependency:build-classpath", "-DincludeScope=test", f"-Dmdep.outputFile={cp_file}", "-q"] if test_module: cmd.extend(["-pl", test_module]) @@ -412,13 +385,7 @@ def _get_test_classpath( try: result = subprocess.run( - cmd, - check=False, - cwd=project_root, - env=env, - capture_output=True, - text=True, - timeout=timeout, + cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout ) if result.returncode != 0: @@ -450,7 +417,7 @@ def _get_test_classpath( return os.pathsep.join(cp_parts) except subprocess.TimeoutExpired: - logger.error("Getting classpath timed out") + logger.exception("Getting classpath timed out") return None except Exception as e: logger.exception("Failed to get classpath: %s", e) @@ -525,30 +492,16 @@ def _run_tests_direct( try: return subprocess.run( - cmd, - check=False, - cwd=working_dir, - env=env, - capture_output=True, - text=True, - timeout=timeout, + cmd, check=False, cwd=working_dir, env=env, capture_output=True, text=True, timeout=timeout ) except subprocess.TimeoutExpired: - logger.error("Direct test execution timed out after %d seconds", timeout) + logger.exception("Direct test execution timed out after %d seconds", timeout) return subprocess.CompletedProcess( - args=cmd, - returncode=-2, - stdout="", - stderr=f"Test execution timed out after {timeout} seconds", + args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds" ) except Exception as e: logger.exception("Direct test execution failed: %s", e) - return subprocess.CompletedProcess( - args=cmd, - returncode=-1, - stdout="", - stderr=str(e), - ) + return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e)) def _get_test_class_names(test_paths: Any, mode: str = "performance") -> list[str]: @@ -603,10 +556,7 @@ def _get_empty_result(maven_root: Path, test_module: str | None) -> tuple[Path, result_xml_path = _get_combined_junit_xml(surefire_dir, -1) empty_result = subprocess.CompletedProcess( - args=["java", "-cp", "...", "ConsoleLauncher"], - returncode=-1, - stdout="", - stderr="No test classes found", + args=["java", "-cp", "...", "ConsoleLauncher"], returncode=-1, stdout="", stderr="No test classes found" ) return result_xml_path, empty_result @@ -665,12 +615,7 @@ def _run_benchmarking_tests_maven( run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations) result = _run_maven_tests( - maven_root, - test_paths, - run_env, - timeout=per_loop_timeout, - mode="performance", - test_module=test_module, + maven_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module ) last_result = result @@ -683,27 +628,20 @@ def _run_benchmarking_tests_maven( elapsed = time.time() - total_start_time if loop_idx >= min_loops and elapsed >= target_duration_seconds: - logger.debug( - "Stopping Maven benchmark after %d loops (%.2fs elapsed)", - loop_idx, - elapsed, - ) + logger.debug("Stopping Maven benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed) break # Check if we have timing markers even if some tests failed # We should continue looping if we're getting valid timing data if result.returncode != 0: import re + timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!") has_timing_markers = bool(timing_pattern.search(result.stdout or "")) if not has_timing_markers: logger.warning("Tests failed in Maven loop %d with no timing markers, stopping", loop_idx) break - else: - logger.debug( - "Some tests failed in Maven loop %d but timing markers present, continuing", - loop_idx, - ) + logger.debug("Some tests failed in Maven loop %d but timing markers present, continuing", loop_idx) combined_stdout = "\n".join(all_stdout) combined_stderr = "\n".join(all_stderr) @@ -801,8 +739,15 @@ def run_benchmarking_tests( # Fall back to Maven-based execution logger.warning("Falling back to Maven-based test execution") return _run_benchmarking_tests_maven( - test_paths, test_env, cwd, timeout, project_root, - min_loops, max_loops, target_duration_seconds, inner_iterations + test_paths, + test_env, + cwd, + timeout, + project_root, + min_loops, + max_loops, + target_duration_seconds, + inner_iterations, ) logger.debug("Compilation completed in %.2fs", compile_time) @@ -814,8 +759,15 @@ def run_benchmarking_tests( if not classpath: logger.warning("Failed to get classpath, falling back to Maven-based execution") return _run_benchmarking_tests_maven( - test_paths, test_env, cwd, timeout, project_root, - min_loops, max_loops, target_duration_seconds, inner_iterations + test_paths, + test_env, + cwd, + timeout, + project_root, + min_loops, + max_loops, + target_duration_seconds, + inner_iterations, ) # Step 3: Run tests multiple times directly via JVM @@ -853,12 +805,7 @@ def run_benchmarking_tests( # Run tests directly with XML report generation loop_start = time.time() result = _run_tests_direct( - classpath, - test_classes, - run_env, - working_dir, - timeout=per_loop_timeout, - reports_dir=reports_dir, + classpath, test_classes, run_env, working_dir, timeout=per_loop_timeout, reports_dir=reports_dir ) loop_time = time.time() - loop_start @@ -875,12 +822,7 @@ def run_benchmarking_tests( # Check if JUnit Console Launcher is not available (JUnit 4 projects) # Fall back to Maven-based execution in this case - if ( - loop_idx == 1 - and result.returncode != 0 - and result.stderr - and "ConsoleLauncher" in result.stderr - ): + if loop_idx == 1 and result.returncode != 0 and result.stderr and "ConsoleLauncher" in result.stderr: logger.debug("JUnit Console Launcher not available, falling back to Maven-based execution") return _run_benchmarking_tests_maven( test_paths, @@ -909,16 +851,13 @@ def run_benchmarking_tests( # Check if tests failed - continue looping if we have timing markers if result.returncode != 0: import re + timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!") has_timing_markers = bool(timing_pattern.search(result.stdout or "")) if not has_timing_markers: logger.warning("Tests failed in loop %d with no timing markers, stopping benchmark", loop_idx) break - else: - logger.debug( - "Some tests failed in loop %d but timing markers present, continuing", - loop_idx, - ) + logger.debug("Some tests failed in loop %d but timing markers present, continuing", loop_idx) # Create a combined result with all stdout combined_stdout = "\n".join(all_stdout) @@ -1075,12 +1014,7 @@ def _run_maven_tests( mvn = find_maven_executable() if not mvn: logger.error("Maven not found") - return subprocess.CompletedProcess( - args=["mvn"], - returncode=-1, - stdout="", - stderr="Maven not found", - ) + return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found") # Build test filter test_filter = _build_test_filter(test_paths, mode=mode) @@ -1110,33 +1044,18 @@ def _run_maven_tests( logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root) try: - result = subprocess.run( - cmd, - check=False, - cwd=project_root, - env=env, - capture_output=True, - text=True, - timeout=timeout, + return subprocess.run( + cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout ) - return result except subprocess.TimeoutExpired: - logger.error("Maven test execution timed out after %d seconds", timeout) + logger.exception("Maven test execution timed out after %d seconds", timeout) return subprocess.CompletedProcess( - args=cmd, - returncode=-2, - stdout="", - stderr=f"Test execution timed out after {timeout} seconds", + args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds" ) except Exception as e: logger.exception("Maven test execution failed: %s", e) - return subprocess.CompletedProcess( - args=cmd, - returncode=-1, - stdout="", - stderr=str(e), - ) + return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e)) def _build_test_filter(test_paths: Any, mode: str = "behavior") -> str: @@ -1196,7 +1115,7 @@ def _path_to_class_name(path: Path) -> str | None: Fully qualified class name, or None if unable to determine. """ - if not path.suffix == ".java": + if path.suffix != ".java": return None # Try to extract package from path @@ -1219,7 +1138,7 @@ def _path_to_class_name(path: Path) -> str | None: break if java_idx is not None: - class_parts = parts[java_idx + 1:] + class_parts = parts[java_idx + 1 :] # Remove .java extension from last part class_parts[-1] = class_parts[-1].replace(".java", "") return ".".join(class_parts) @@ -1228,12 +1147,7 @@ def _path_to_class_name(path: Path) -> str | None: return path.stem -def run_tests( - test_files: list[Path], - cwd: Path, - env: dict[str, str], - timeout: int, -) -> tuple[list[TestResult], Path]: +def run_tests(test_files: list[Path], cwd: Path, env: dict[str, str], timeout: int) -> tuple[list[TestResult], Path]: """Run tests and return results. Args: @@ -1366,10 +1280,7 @@ def _parse_surefire_xml(xml_file: Path) -> list[TestResult]: return results -def get_test_run_command( - project_root: Path, - test_classes: list[str] | None = None, -) -> list[str]: +def get_test_run_command(project_root: Path, test_classes: list[str] | None = None) -> list[str]: """Get the command to run Java tests. Args: @@ -1389,10 +1300,8 @@ def get_test_run_command( validated_classes = [] for test_class in test_classes: if not _validate_java_class_name(test_class): - raise ValueError( - f"Invalid test class name: '{test_class}'. " - f"Test names must follow Java identifier rules." - ) + msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules." + raise ValueError(msg) validated_classes.append(test_class) cmd.append(f"-Dtest={','.join(validated_classes)}") diff --git a/codeflash/languages/javascript/find_references.py b/codeflash/languages/javascript/find_references.py index 812f7c4a7..8fe144a06 100644 --- a/codeflash/languages/javascript/find_references.py +++ b/codeflash/languages/javascript/find_references.py @@ -213,7 +213,7 @@ def find_references( if import_info: context.visited_files.add(file_path) - import_name, original_import = import_info + import_name, _original_import = import_info file_refs = self._find_references_in_file( file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True ) diff --git a/codeflash/languages/javascript/module_system.py b/codeflash/languages/javascript/module_system.py index 4e4e3bb0c..dcd2d2fc7 100644 --- a/codeflash/languages/javascript/module_system.py +++ b/codeflash/languages/javascript/module_system.py @@ -373,9 +373,14 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str: insert_index = 0 for i, line in enumerate(lines): stripped = line.strip() - if stripped and not stripped.startswith("//") and not stripped.startswith("/*") and not stripped.startswith("*"): + if ( + stripped + and not stripped.startswith("//") + and not stripped.startswith("/*") + and not stripped.startswith("*") + ): # Check if this line is an import/require - insert after imports - if stripped.startswith("import ") or stripped.startswith("const ") or stripped.startswith("let "): + if stripped.startswith(("import ", "const ", "let ")): continue insert_index = i break diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d09654722..2a034afdf 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -325,9 +325,7 @@ def file_to_path(self) -> dict[str, str]: """ if "file_to_path" in self._cache: return self._cache["file_to_path"] - result = { - str(code_string.file_path): code_string.code for code_string in self.code_strings - } + result = {str(code_string.file_path): code_string.code for code_string in self.code_strings} self._cache["file_to_path"] = result return result diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 900d3ea8c..be69bd544 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2,7 +2,6 @@ import ast import concurrent.futures -import logging import os import queue import random @@ -23,7 +22,7 @@ from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data -from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar +from codeflash.cli_cmds.console import DEBUG_MODE, code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code from codeflash.code_utils.code_replacer import ( @@ -146,9 +145,70 @@ from codeflash.verification.verification_utils import TestConfig +def log_code_after_replacement(file_path: Path, candidate_index: int) -> None: + """Log the full file content after code replacement in verbose mode.""" + if not DEBUG_MODE: + return + + try: + code = file_path.read_text(encoding="utf-8") + lang_map = {".java": "java", ".py": "python", ".js": "javascript", ".ts": "typescript"} + language = lang_map.get(file_path.suffix.lower(), "text") + + console.print( + Panel( + Syntax(code, language, line_numbers=True, theme="monokai", word_wrap=True), + title=f"[bold blue]Code After Replacement (Candidate {candidate_index})[/] [dim]({file_path.name})[/]", + border_style="blue", + ) + ) + except Exception as e: + logger.debug(f"Failed to log code after replacement: {e}") + + +def log_instrumented_test(test_source: str, test_name: str, test_type: str, language: str) -> None: + """Log instrumented test code in verbose mode.""" + if not DEBUG_MODE: + return + + display_source = test_source + if len(test_source) > 15000: + display_source = test_source[:15000] + "\n\n... [truncated] ..." + + console.print( + Panel( + Syntax(display_source, language, line_numbers=True, theme="monokai", word_wrap=True), + title=f"[bold magenta]Instrumented Test: {test_name}[/] [dim]({test_type})[/]", + border_style="magenta", + ) + ) + + +def log_test_run_output(stdout: str, stderr: str, test_type: str, returncode: int = 0) -> None: + """Log test run stdout/stderr in verbose mode.""" + if not DEBUG_MODE: + return + + max_len = 10000 + + if stdout and stdout.strip(): + display_stdout = stdout[:max_len] + ("...[truncated]" if len(stdout) > max_len else "") + console.print( + Panel( + display_stdout, + title=f"[bold green]{test_type} - stdout[/] [dim](exit: {returncode})[/]", + border_style="green" if returncode == 0 else "red", + ) + ) + + if stderr and stderr.strip(): + display_stderr = stderr[:max_len] + ("...[truncated]" if len(stderr) > max_len else "") + console.print(Panel(display_stderr, title=f"[bold yellow]{test_type} - stderr[/]", border_style="yellow")) + + def log_optimization_context(function_name: str, code_context: CodeOptimizationContext) -> None: """Log optimization context details when in verbose mode using Rich formatting.""" - if logger.getEffectiveLevel() > logging.DEBUG: + if not DEBUG_MODE: return console.rule() @@ -594,18 +654,32 @@ def generate_and_instrument_tests( generated_test.instrumented_perf_test_source = modified_perf_source used_behavior_paths.add(behavior_path) - logger.debug( - f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}" - ) + logger.debug(f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}") with behavior_path.open("w", encoding="utf8") as f: f.write(generated_test.instrumented_behavior_test_source) logger.debug(f"[PIPELINE] Wrote behavioral test to {behavior_path}") + # Verbose: Log instrumented behavior test + log_instrumented_test( + generated_test.instrumented_behavior_test_source, + behavior_path.name, + "Behavioral Test", + language=self.function_to_optimize.language, + ) + with perf_path.open("w", encoding="utf8") as f: f.write(generated_test.instrumented_perf_test_source) logger.debug(f"[PIPELINE] Wrote perf test to {perf_path}") + # Verbose: Log instrumented performance test + log_instrumented_test( + generated_test.instrumented_perf_test_source, + perf_path.name, + "Performance Test", + language=self.function_to_optimize.language, + ) + # File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.) test_file_obj = TestFile( instrumented_behavior_file_path=generated_test.behavior_file_path, @@ -675,22 +749,24 @@ def _get_java_sources_root(self) -> Path: parts = tests_root.parts # Look for standard Java package prefixes that indicate the start of package structure - standard_package_prefixes = ('com', 'org', 'net', 'io', 'edu', 'gov') + standard_package_prefixes = ("com", "org", "net", "io", "edu", "gov") for i, part in enumerate(parts): if part in standard_package_prefixes: # Found start of package path, return everything before it if i > 0: java_sources_root = Path(*parts[:i]) - logger.debug(f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})") + logger.debug( + f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})" + ) return java_sources_root # If no standard package prefix found, check if there's a 'java' directory # (standard Maven structure: src/test/java) for i, part in enumerate(parts): - if part == 'java' and i > 0: + if part == "java" and i > 0: # Return up to and including 'java' - java_sources_root = Path(*parts[:i + 1]) + java_sources_root = Path(*parts[: i + 1]) logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}") return java_sources_root @@ -721,16 +797,16 @@ def _fix_java_test_paths( import re # Extract package from behavior source - package_match = re.search(r'^\s*package\s+([\w.]+)\s*;', behavior_source, re.MULTILINE) + package_match = re.search(r"^\s*package\s+([\w.]+)\s*;", behavior_source, re.MULTILINE) package_name = package_match.group(1) if package_match else "" # Extract class name from behavior source # Use more specific pattern to avoid matching words like "command" or text in comments - class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', behavior_source, re.MULTILINE) + class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", behavior_source, re.MULTILINE) behavior_class = class_match.group(1) if class_match else "GeneratedTest" # Extract class name from perf source - perf_class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', perf_source, re.MULTILINE) + perf_class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", perf_source, re.MULTILINE) perf_class = perf_class_match.group(1) if perf_class_match else "GeneratedPerfTest" # Build paths with package structure @@ -767,22 +843,20 @@ def _fix_java_test_paths( perf_path = new_perf_path # Rename class in source code - replace the class declaration modified_behavior_source = re.sub( - rf'^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)', - rf'\g<1>{new_behavior_class}\g<2>', + rf"^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)", + rf"\g<1>{new_behavior_class}\g<2>", behavior_source, count=1, flags=re.MULTILINE, ) modified_perf_source = re.sub( - rf'^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)', - rf'\g<1>{new_perf_class}\g<2>', + rf"^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)", + rf"\g<1>{new_perf_class}\g<2>", perf_source, count=1, flags=re.MULTILINE, ) - logger.debug( - f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}" - ) + logger.debug(f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}") break index += 1 @@ -1199,6 +1273,9 @@ def process_single_candidate( logger.info("No functions were replaced in the optimized code. Skipping optimization candidate.") console.rule() return None + + # Verbose: Log code after replacement + log_code_after_replacement(self.function_to_optimize.file_path, candidate_index) except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: logger.error(e) self.write_code_and_helpers( @@ -1764,6 +1841,14 @@ def get_instrumented_path(original_path: str, suffix: str) -> Path: with new_behavioral_test_path.open("w", encoding="utf8") as _f: _f.write(injected_behavior_test) logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}") + + # Verbose: Log instrumented existing behavior test + log_instrumented_test( + injected_behavior_test, + new_behavioral_test_path.name, + "Existing Behavioral Test", + language=self.function_to_optimize.language, + ) else: msg = "injected_behavior_test is None" raise ValueError(msg) @@ -1773,6 +1858,14 @@ def get_instrumented_path(original_path: str, suffix: str) -> Path: _f.write(injected_perf_test) logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}") + # Verbose: Log instrumented existing performance test + log_instrumented_test( + injected_perf_test, + new_perf_test_path.name, + "Existing Performance Test", + language=self.function_to_optimize.language, + ) + unique_instrumented_test_files.add(new_behavioral_test_path) unique_instrumented_test_files.add(new_perf_test_path) @@ -2239,7 +2332,7 @@ def process_review( formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n" - existing_tests, replay_tests, concolic_tests = existing_tests_source_for( + existing_tests, replay_tests, _concolic_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), function_to_all_tests, test_cfg=self.test_cfg, @@ -2880,6 +2973,11 @@ def run_and_parse_tests( else: msg = f"Unexpected testing type: {testing_type}" raise ValueError(msg) + + # Verbose: Log test run output + log_test_run_output( + run_result.stdout, run_result.stderr, f"Test Run ({testing_type.name})", run_result.returncode + ) except subprocess.TimeoutExpired: logger.exception( f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error" diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 759e4ecb2..6e34648c3 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,8 +512,10 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes # Check if the file name matches the module path file_stem = test_file.instrumented_behavior_file_path.stem # The instrumented file has __perfinstrumented suffix - original_class = file_stem.replace("__perfinstrumented", "").replace("__perfonlyinstrumented", "") - if original_class == test_module_path or file_stem == test_module_path: + original_class = file_stem.replace("__perfinstrumented", "").replace( + "__perfonlyinstrumented", "" + ) + if test_module_path in (original_class, file_stem): test_file_path = test_file.instrumented_behavior_file_path break # Check original file path @@ -551,7 +553,9 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes # Default to GENERATED_REGRESSION for Jest/Java tests when test type can't be determined if test_type is None and (is_jest or is_java_test): test_type = TestType.GENERATED_REGRESSION - logger.debug(f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})") + logger.debug( + f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})" + ) elif test_type is None: # Skip results where test type cannot be determined logger.debug(f"Skipping result for {test_function_name}: could not determine test type") diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 9766a3951..45b96ff51 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -146,7 +146,9 @@ def _detect_java_test_framework(self) -> str: pom_path = current / "pom.xml" if pom_path.exists(): parent_config = detect_java_project(current) - if parent_config and (parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng): + if parent_config and ( + parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng + ): return parent_config.test_framework current = current.parent diff --git a/codeflash/version.py b/codeflash/version.py index 6225467e3..67379ab0c 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.0" +__version__ = "0.20.0.post414.dev0+2ad731d3"