Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def optimize_code(
is_async: bool = False,
n_candidates: int = 5,
is_numerical_code: bool | None = None,
rerun_trace_id: str | None = None,
) -> list[OptimizedCandidate]:
"""Optimize the given code for performance by making a request to the Django endpoint.

Expand Down Expand Up @@ -194,6 +195,7 @@ def optimize_code(
"call_sequence": self.get_next_sequence(),
"n_candidates": n_candidates,
"is_numerical_code": is_numerical_code,
"rerun_trace_id": rerun_trace_id,
}

self.add_language_metadata(payload, language_version, module_system)
Expand Down Expand Up @@ -234,6 +236,7 @@ def optimize_python_code_line_profiler(
is_numerical_code: bool | None = None,
language: str = "python",
language_version: str | None = None,
rerun_trace_id: str | None = None,
) -> list[OptimizedCandidate]:
"""Optimize code for performance using line profiler results.

Expand Down Expand Up @@ -272,6 +275,7 @@ def optimize_python_code_line_profiler(
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
"is_numerical_code": is_numerical_code,
"rerun_trace_id": rerun_trace_id,
}

try:
Expand Down Expand Up @@ -318,7 +322,9 @@ def adaptive_optimize(self, request: AIServiceAdaptiveOptimizeRequest) -> Optimi
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
return None

def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
def optimize_code_refinement(
self, request: list[AIServiceRefinerRequest], rerun_trace_id: str | None = None
) -> list[OptimizedCandidate]:
"""Refine optimization candidates for improved performance.

Supports Python, JavaScript, and TypeScript code refinement with optional
Expand Down Expand Up @@ -349,6 +355,7 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li
"call_sequence": self.get_next_sequence(),
# Multi-language support
"language": opt.language,
"rerun_trace_id": rerun_trace_id,
}

self.add_language_metadata(item, opt.language_version)
Expand All @@ -375,7 +382,9 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li
console.rule()
return []

def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
def code_repair(
self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None
) -> OptimizedCandidate | None:
console.rule()
try:
payload = {
Expand All @@ -385,6 +394,7 @@ def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate
"trace_id": request.trace_id,
"test_diffs": request.test_diffs,
"language": request.language,
"rerun_trace_id": rerun_trace_id,
}
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=self.timeout)
except (requests.exceptions.RequestException, TypeError) as e:
Expand Down Expand Up @@ -607,6 +617,7 @@ def generate_regression_tests(
language_version: str | None = None,
module_system: str | None = None,
is_numerical_code: bool | None = None,
rerun_trace_id: str | None = None,
) -> tuple[str, str, str, str | None] | None:
"""Generate regression tests for the given function by making a request to the Django endpoint.

Expand Down Expand Up @@ -655,6 +666,7 @@ def generate_regression_tests(
"is_numerical_code": is_numerical_code,
"class_name": function_to_optimize.class_name,
"qualified_name": function_to_optimize.qualified_name,
"rerun_trace_id": rerun_trace_id,
}

self.add_language_metadata(payload, language_version, module_system)
Expand Down Expand Up @@ -913,6 +925,34 @@ def generate_workflow_steps(
logger.debug("[aiservice.py:generate_workflow_steps] Could not parse error response")
return None

def prescreen_functions(
self, functions: list[dict[str, str]], trace_id: str | None = None
) -> dict[str, dict[str, Any]] | None:
"""Pre-screen functions for optimization potential using LLM assessment.

Args:
functions: List of dicts with keys: qualified_name, source_code, language
trace_id: Optional trace ID for logging

Returns:
Dict mapping qualified_name to {score: int, optimizable: bool, reason: str},
or None if the call fails.

"""
payload: dict[str, Any] = {"functions": functions, "trace_id": trace_id}
try:
response = self.make_ai_service_request("/prescreen", payload=payload, timeout=30)
except requests.exceptions.RequestException as e:
logger.debug(f"Prescreening request failed: {e}")
return None

if response.status_code == 200:
logger.info(f"loading|Pre-screening {len(functions)} function(s) for optimization potential")
result: dict[str, dict[str, Any]] = response.json().get("functions", {})
return result
logger.debug(f"Prescreening returned status {response.status_code}")
return None


class LocalAiServiceClient(AiServiceClient):
"""Client for interacting with the local AI service."""
Expand Down
30 changes: 30 additions & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"disable_imports_sorting",
"git_remote",
"override_fixtures",
"max_functions",
"max_time",
]
for key in supported_keys:
if key in pyproject_config and (
Expand Down Expand Up @@ -489,5 +491,33 @@ def _build_parser() -> ArgumentParser:
action="store_true",
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
)
parser.add_argument(
"--max-functions",
type=int,
default=None,
help="Maximum number of functions to optimize per run. After ranking, only the top N functions are optimized. "
"Useful for controlling CI runtime.",
)
parser.add_argument(
"--max-time",
type=int,
default=None,
help="Maximum total optimization time in minutes. Stops optimizing new functions after this budget is exhausted. "
"Functions already in progress will complete.",
)
parser.add_argument(
"--prescreening",
action="store_true",
default=False,
help="Enable LLM-based pre-screening of functions before optimization. "
"Uses a fast LLM call to assess whether each function has meaningful optimization potential. "
"Filters out functions unlikely to benefit from optimization.",
)
parser.add_argument(
"--no-prescreening",
action="store_true",
default=False,
help="Disable LLM pre-screening even when it would be auto-enabled (e.g., in CI with --all).",
)

return parser
232 changes: 232 additions & 0 deletions codeflash/discovery/diff_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
from __future__ import annotations

import re
from dataclasses import dataclass
from enum import Enum
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING

import git
from unidiff import PatchSet

from codeflash.cli_cmds.console import logger

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize


class DiffCategory(str, Enum):
COSMETIC = "cosmetic"
TRIVIAL = "trivial"
MEANINGFUL = "meaningful"
MAJOR = "major"


@dataclass(frozen=True)
class FunctionDiffInfo:
category: DiffCategory
added_logic_lines: int
removed_logic_lines: int
total_changed_lines: int
is_comment_only: bool
is_whitespace_only: bool


# Patterns for comments across languages
_COMMENT_PATTERNS = [
re.compile(r"^\s*#"), # Python
re.compile(r"^\s*//"), # JS/TS/Java
re.compile(r"^\s*\*"), # Multiline comment body
re.compile(r"^\s*/\*"), # Multiline comment start
re.compile(r"^\s*\*/"), # Multiline comment end
re.compile(r'^\s*"""'), # Python docstring
re.compile(r"^\s*'''"), # Python docstring
]

# Patterns for import/require statements
_IMPORT_PATTERNS = [
re.compile(r"^\s*(import |from \S+ import )"), # Python
re.compile(r"^\s*(const|let|var)\s+.*=\s*require\("), # JS require
re.compile(r"^\s*import\s+"), # JS/TS/Java import
]


def _is_comment_line(line: str) -> bool:
return any(p.match(line) for p in _COMMENT_PATTERNS)


def _is_import_line(line: str) -> bool:
return any(p.match(line) for p in _IMPORT_PATTERNS)


def _is_logic_line(line: str) -> bool:
stripped = line.strip()
if not stripped:
return False
if _is_comment_line(line):
return False
# String-only lines (just a string literal)
if stripped.startswith(('"""', "'''", '"', "'")) and stripped.endswith(('"""', "'''", '"', "'")):
return False
return True


def classify_function_diff(func: FunctionToOptimize, repo_directory: Path | None = None) -> FunctionDiffInfo:
"""Classify the type of change made to a function based on git diff content."""
if func.starting_line is None or func.ending_line is None:
return FunctionDiffInfo(
category=DiffCategory.MEANINGFUL,
added_logic_lines=0,
removed_logic_lines=0,
total_changed_lines=0,
is_comment_only=False,
is_whitespace_only=False,
)

diff_lines = _get_function_diff_lines(func, repo_directory)
if not diff_lines:
return FunctionDiffInfo(
category=DiffCategory.COSMETIC,
added_logic_lines=0,
removed_logic_lines=0,
total_changed_lines=0,
is_comment_only=False,
is_whitespace_only=True,
)

added_lines = [line for line in diff_lines if line.startswith("+")]
removed_lines = [line for line in diff_lines if line.startswith("-")]
total_changed = len(added_lines) + len(removed_lines)

# Strip the +/- prefix for content analysis
added_content = [line[1:] for line in added_lines]
removed_content = [line[1:] for line in removed_lines]

# Check if all changes are whitespace-only
added_stripped = [line.strip() for line in added_content]
removed_stripped = [line.strip() for line in removed_content]
if all(not s for s in added_stripped) and all(not s for s in removed_stripped):
return FunctionDiffInfo(
category=DiffCategory.COSMETIC,
added_logic_lines=0,
removed_logic_lines=0,
total_changed_lines=total_changed,
is_comment_only=False,
is_whitespace_only=True,
)

# Check if all changes are comment-only
added_logic = [line for line in added_content if _is_logic_line(line)]
removed_logic = [line for line in removed_content if _is_logic_line(line)]

is_comment_only = len(added_logic) == 0 and len(removed_logic) == 0
if is_comment_only:
return FunctionDiffInfo(
category=DiffCategory.COSMETIC,
added_logic_lines=0,
removed_logic_lines=0,
total_changed_lines=total_changed,
is_comment_only=True,
is_whitespace_only=False,
)

# Classify by logic change magnitude
logic_change_count = len(added_logic) + len(removed_logic)

if logic_change_count <= 2:
category = DiffCategory.TRIVIAL
elif logic_change_count <= 10:
category = DiffCategory.MEANINGFUL
else:
category = DiffCategory.MAJOR

return FunctionDiffInfo(
category=category,
added_logic_lines=len(added_logic),
removed_logic_lines=len(removed_logic),
total_changed_lines=total_changed,
is_comment_only=False,
is_whitespace_only=False,
)


def _get_function_diff_lines(func: FunctionToOptimize, repo_directory: Path | None = None) -> list[str]:
"""Extract diff lines that fall within a function's line range."""
if repo_directory is None:
repo_directory = Path.cwd()

try:
repository = git.Repo(repo_directory, search_parent_directories=True)
except git.InvalidGitRepositoryError:
return []

commit = repository.head.commit
try:
uni_diff_text = repository.git.diff(
commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True
)
except git.GitCommandError:
return []

patch_set = PatchSet(StringIO(uni_diff_text))
func_start = func.starting_line or 0
func_end = func.ending_line or 0
result: list[str] = []

for patched_file in patch_set:
file_path = Path(repository.working_dir) / patched_file.path
if file_path != func.file_path:
continue

for hunk in patched_file:
for line in hunk:
if line.is_added and line.target_line_no and func_start <= line.target_line_no <= func_end:
result.append(f"+{line.value}")
elif line.is_removed and line.source_line_no and func_start <= line.source_line_no <= func_end:
result.append(f"-{line.value}")

return result


def filter_cosmetic_diff_functions(
functions: dict[Path, list[FunctionToOptimize]], repo_directory: Path | None = None
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
"""Remove functions where the diff is purely cosmetic (comments/whitespace only)."""
filtered: dict[Path, list[FunctionToOptimize]] = {}
skipped_count = 0

for file_path, funcs in functions.items():
kept: list[FunctionToOptimize] = []
for func in funcs:
try:
diff_info = classify_function_diff(func, repo_directory)
except Exception:
kept.append(func)
continue

if diff_info.category == DiffCategory.COSMETIC:
skipped_count += 1
logger.debug(
f"Skipping {func.qualified_name} — diff is cosmetic "
f"({'comments only' if diff_info.is_comment_only else 'whitespace only'})"
)
else:
kept.append(func)

if kept:
filtered[file_path] = kept

if skipped_count > 0:
logger.info(f"Diff analysis: skipped {skipped_count} function(s) with cosmetic-only changes")

return filtered, skipped_count


def get_effort_for_diff(diff_info: FunctionDiffInfo) -> str | None:
"""Suggest an effort level based on diff category. Returns None to use default."""
if diff_info.category == DiffCategory.TRIVIAL:
return "low"
if diff_info.category == DiffCategory.MAJOR:
return "high"
return None
Loading
Loading