Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.extension import install_vscode_extension

# Import Java init module
from codeflash.cli_cmds.init_java import init_java_project

# Import JS/TS init module
from codeflash.cli_cmds.init_javascript import (
ProjectLanguage,
Expand All @@ -35,9 +38,6 @@
get_js_dependency_installation_commands,
init_js_project,
)

# Import Java init module
from codeflash.cli_cmds.init_java import init_java_project
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
Expand Down Expand Up @@ -1674,9 +1674,7 @@ def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path,

# Install dependencies
install_deps_cmd = get_java_dependency_installation_commands(build_tool)
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)

return optimize_yml_content
return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)


def get_formatter_cmds(formatter: str) -> list[str]:
Expand Down
31 changes: 12 additions & 19 deletions codeflash/cli_cmds/init_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ def init_java_project() -> None:

lang_panel = Panel(
Text(
"Java project detected!\n\nI'll help you set up Codeflash for your project.",
style="cyan",
justify="center",
"Java project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center"
),
title="Java Setup",
border_style="bright_red",
Expand Down Expand Up @@ -205,7 +203,9 @@ def init_java_project() -> None:
completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:"

if did_add_new_key:
completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
completion_message += (
"\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
)
if os.name == "nt":
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
else:
Expand Down Expand Up @@ -234,9 +234,7 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]:
codeflash_config_path = project_root / "codeflash.toml"
if codeflash_config_path.exists():
return Confirm.ask(
"A Codeflash config already exists. Do you want to re-configure it?",
default=False,
show_default=True,
"A Codeflash config already exists. Do you want to re-configure it?", default=False, show_default=True
), None

return True, None
Expand Down Expand Up @@ -285,22 +283,18 @@ def collect_java_setup_info() -> JavaSetupInfo:

if Confirm.ask("Would you like to change any of these settings?", default=False):
# Source root override
module_root_override = _prompt_directory_override(
"source", detected_source_root, curdir
)
module_root_override = _prompt_directory_override("source", detected_source_root, curdir)

# Test root override
test_root_override = _prompt_directory_override(
"test", detected_test_root, curdir
)
test_root_override = _prompt_directory_override("test", detected_test_root, curdir)

# Formatter override
formatter_questions = [
inquirer.List(
"formatter",
message="Which code formatter do you use?",
choices=[
(f"keep detected (google-java-format)", "keep"),
("keep detected (google-java-format)", "keep"),
("google-java-format", "google-java-format"),
("spotless", "spotless"),
("other", "other"),
Expand Down Expand Up @@ -345,7 +339,7 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")]
subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)]

options = [keep_detected_option] + subdirs[:5] + [custom_dir_option]
options = [keep_detected_option, *subdirs[:5], custom_dir_option]

questions = [
inquirer.List(
Expand All @@ -364,10 +358,9 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
answer = answers[f"{dir_type}_root"]
if answer == keep_detected_option:
return None
elif answer == custom_dir_option:
if answer == custom_dir_option:
return _prompt_custom_directory(dir_type)
else:
return answer
return answer


def _prompt_custom_directory(dir_type: str) -> str:
Expand Down Expand Up @@ -441,7 +434,7 @@ def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[st
if formatter == "spotless":
if build_tool == JavaBuildTool.MAVEN:
return ["mvn spotless:apply -DspotlessFiles=$file"]
elif build_tool == JavaBuildTool.GRADLE:
if build_tool == JavaBuildTool.GRADLE:
return ["./gradlew spotlessApply"]
return ["spotless $file"]
if formatter == "other":
Expand Down
16 changes: 6 additions & 10 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,18 +711,12 @@ def _add_java_class_members(
if not new_fields and not new_methods:
return original_source

logger.debug(
f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}"
)
logger.debug(f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}")

# Import the insertion function from replacement module
from codeflash.languages.java.replacement import _insert_class_members

result = _insert_class_members(
original_source, class_name, new_fields, new_methods, analyzer
)

return result
return _insert_class_members(original_source, class_name, new_fields, new_methods, analyzer)

except Exception as e:
logger.debug(f"Error adding Java class members: {e}")
Expand Down Expand Up @@ -959,12 +953,14 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
for file_path_str, code in file_to_code_context.items():
if file_path_str:
# Extract filename without creating Path object repeatedly
if file_path_str.endswith(target_filename) and (len(file_path_str) == len(target_filename) or file_path_str[-len(target_filename)-1] in ('/', '\\')):
if file_path_str.endswith(target_filename) and (
len(file_path_str) == len(target_filename)
or file_path_str[-len(target_filename) - 1] in ("/", "\\")
):
module_optimized_code = code
logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
break


if module_optimized_code is None:
# Also try matching if there's only one code file, but ONLY for non-Python
# languages where path matching is less strict. For Python, we require
Expand Down
4 changes: 1 addition & 3 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,9 +721,7 @@ def inject_profiling_into_existing_test(
if is_java():
from codeflash.languages.java.instrumentation import instrument_existing_test

return instrument_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, mode.value
)
return instrument_existing_test(test_path, call_positions, function_to_optimize, tests_project_root, mode.value)

if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
Expand Down
8 changes: 4 additions & 4 deletions codeflash/languages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@
reset_current_language,
set_current_language,
)

# Java language support
# Importing the module triggers registration via @register_language decorator
from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401

# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401

# Java language support
# Importing the module triggers registration via @register_language decorator
from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,
Expand Down
53 changes: 16 additions & 37 deletions codeflash/languages/java/build_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path

logger = logging.getLogger(__name__)

Expand All @@ -29,6 +32,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree:

Raises:
ET.ParseError: If XML parsing fails.

"""
# Read file content and parse as string to avoid file-based attacks
# This prevents XXE attacks by not allowing external entity resolution
Expand All @@ -38,9 +42,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
root = ET.fromstring(content)

# Create ElementTree from root
tree = ET.ElementTree(root)

return tree
return ET.ElementTree(root)


class BuildTool(Enum):
Expand Down Expand Up @@ -390,13 +392,7 @@ def run_maven_tests(

try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=run_env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)

# Parse test results from Surefire reports
Expand All @@ -416,7 +412,7 @@ def run_maven_tests(
)

except subprocess.TimeoutExpired:
logger.error("Maven test execution timed out after %d seconds", timeout)
logger.exception("Maven test execution timed out after %d seconds", timeout)
return MavenTestResult(
success=False,
tests_run=0,
Expand Down Expand Up @@ -496,10 +492,7 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:


def compile_maven_project(
project_root: Path,
include_tests: bool = True,
env: dict[str, str] | None = None,
timeout: int = 300,
project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300
) -> tuple[bool, str, str]:
"""Compile a Maven project.

Expand Down Expand Up @@ -533,13 +526,7 @@ def compile_maven_project(

try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=run_env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)

return result.returncode == 0, result.stdout, result.stderr
Expand Down Expand Up @@ -581,14 +568,7 @@ def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> boo
]

try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=60,
)
result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)

if result.returncode == 0:
logger.info("Successfully installed codeflash-runtime to local Maven repository")
Expand Down Expand Up @@ -664,7 +644,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
return True

except ET.ParseError as e:
logger.error("Failed to parse pom.xml: %s", e)
logger.exception("Failed to parse pom.xml: %s", e)
return False
except Exception as e:
logger.exception("Failed to add dependency to pom.xml: %s", e)
Expand Down Expand Up @@ -751,11 +731,11 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
# JaCoCo plugin XML to insert (indented for typical pom.xml format)
# Note: For multi-module projects where tests are in a separate module,
# we configure the report to look in multiple directories for classes
jacoco_plugin = """
jacoco_plugin = f"""
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>{version}</version>
<version>{JACOCO_PLUGIN_VERSION}</version>
<executions>
<execution>
<id>prepare-agent</id>
Expand All @@ -777,7 +757,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
</configuration>
</execution>
</executions>
</plugin>""".format(version=JACOCO_PLUGIN_VERSION)
</plugin>"""

# Find the main <build> section (not inside <profiles>)
# We need to find a <build> that appears after </profiles> or before <profiles>
Expand All @@ -786,7 +766,6 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
profiles_end = content.find("</profiles>")

# Find all <build> tags
import re

# Find the main build section - it's the one NOT inside profiles
# Strategy: Look for <build> that comes after </profiles> or before <profiles> (or no profiles)
Expand Down Expand Up @@ -816,7 +795,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:

if build_start != -1 and build_end != -1:
# Found main build section, find plugins within it
build_section = content[build_start:build_end + len("</build>")]
build_section = content[build_start : build_end + len("</build>")]
plugins_start_in_build = build_section.find("<plugins>")
plugins_end_in_build = build_section.rfind("</plugins>")

Expand Down
Loading
Loading