diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 9a70de5fd..4fdbd5291 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -118,10 +118,11 @@ def get_code_optimization_context( # Get FunctionSource representation of helpers of FTO fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}} + jedi_refs_cache: dict[Path, dict[str, list[Name]]] | None = None if call_graph is not None: helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input) else: - helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi( + helpers_of_fto_dict, helpers_of_fto_list, jedi_refs_cache = get_function_sources_from_jedi( fto_input, project_root_path, jedi_project=jedi_project ) @@ -141,8 +142,8 @@ def get_code_optimization_context( for qualified_names in helpers_of_fto_qualified_names_dict.values(): qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn}) - helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi( - helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project + helpers_of_helpers_dict, helpers_of_helpers_list, _ = get_function_sources_from_jedi( + helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project, refs_cache=jedi_refs_cache ) # Extract all code contexts in a single pass (one CST parse per file) @@ -312,12 +313,15 @@ def extract_all_contexts_from_files( logger.debug(f"Error while getting read-writable code: {e}") # READ_ONLY + fto_ro_code_result: str | None = None + fto_ro_pruned_code_str: str | None = None try: ro_pruned = parse_code_and_prune_cst( all_cleaned, CodeContextType.READ_ONLY, fto_names, hoh_names, remove_docstrings=False ) - if ro_pruned.code.strip(): - ro_code = add_needed_imports_from_module( + fto_ro_pruned_code_str = ro_pruned.code.strip() + if fto_ro_pruned_code_str: + fto_ro_code_result = add_needed_imports_from_module( src_module_code=original_module, dst_module_code=ro_pruned, src_path=file_path, @@ -326,7 +330,7 @@ def extract_all_contexts_from_files( helper_functions=all_helper_functions, gathered_imports=src_gathered, ) - ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path)) + ro.code_strings.append(CodeString(code=fto_ro_code_result, file_path=relative_path)) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -341,21 +345,25 @@ def extract_all_contexts_from_files( except ValueError as e: logger.debug(f"Error while getting hashing code: {e}") - # TESTGEN + # TESTGEN -- reuse RO result when pruned code is identical try: testgen_pruned = parse_code_and_prune_cst( all_cleaned, CodeContextType.TESTGEN, fto_names, hoh_names, remove_docstrings=False ) - if testgen_pruned.code.strip(): - testgen_code = add_needed_imports_from_module( - src_module_code=original_module, - dst_module_code=testgen_pruned, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=all_helper_functions, - gathered_imports=src_gathered, - ) + fto_testgen_pruned_code_str = testgen_pruned.code.strip() + if fto_testgen_pruned_code_str: + if fto_ro_code_result is not None and fto_testgen_pruned_code_str == fto_ro_pruned_code_str: + testgen_code = fto_ro_code_result + else: + testgen_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=testgen_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=all_helper_functions, + gathered_imports=src_gathered, + ) testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path)) except ValueError as e: logger.debug(f"Error while getting testgen code: {e}") @@ -404,12 +412,15 @@ def extract_all_contexts_from_files( src_gathered = gather_source_imports(original_module, file_path, project_root_path) # READ_ONLY + ro_code_result: str | None = None + ro_pruned_code_str: str | None = None try: ro_pruned = parse_code_and_prune_cst( cleaned, CodeContextType.READ_ONLY, set(), hoh_names, remove_docstrings=False ) - if ro_pruned.code.strip(): - ro_code = add_needed_imports_from_module( + ro_pruned_code_str = ro_pruned.code.strip() + if ro_pruned_code_str: + ro_code_result = add_needed_imports_from_module( src_module_code=original_module, dst_module_code=ro_pruned, src_path=file_path, @@ -418,7 +429,7 @@ def extract_all_contexts_from_files( helper_functions=helper_functions, gathered_imports=src_gathered, ) - ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path)) + ro.code_strings.append(CodeString(code=ro_code_result, file_path=relative_path)) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -433,21 +444,25 @@ def extract_all_contexts_from_files( except ValueError as e: logger.debug(f"Error while getting hashing code: {e}") - # TESTGEN + # TESTGEN -- reuse RO result when pruned code is identical (common for HoH-only files) try: testgen_pruned = parse_code_and_prune_cst( cleaned, CodeContextType.TESTGEN, set(), hoh_names, remove_docstrings=False ) - if testgen_pruned.code.strip(): - testgen_code = add_needed_imports_from_module( - src_module_code=original_module, - dst_module_code=testgen_pruned, - src_path=file_path, - dst_path=file_path, - project_root=project_root_path, - helper_functions=helper_functions, - gathered_imports=src_gathered, - ) + testgen_pruned_code_str = testgen_pruned.code.strip() + if testgen_pruned_code_str: + if ro_code_result is not None and testgen_pruned_code_str == ro_pruned_code_str: + testgen_code = ro_code_result + else: + testgen_code = add_needed_imports_from_module( + src_module_code=original_module, + dst_module_code=testgen_pruned, + src_path=file_path, + dst_path=file_path, + project_root=project_root_path, + helper_functions=helper_functions, + gathered_imports=src_gathered, + ) testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path)) except ValueError as e: logger.debug(f"Error while getting testgen code: {e}") @@ -546,33 +561,39 @@ def get_function_sources_from_jedi( project_root_path: Path, *, jedi_project: object | None = None, -) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]: + refs_cache: dict[Path, dict[str, list[Name]]] | None = None, +) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource], dict[Path, dict[str, list[Name]]]]: import jedi project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path)) file_path_to_function_source = defaultdict(set) function_source_list: list[FunctionSource] = [] + new_refs_cache: dict[Path, dict[str, list[Name]]] = {} if refs_cache is None else dict(refs_cache) for file_path, qualified_function_names in file_path_to_qualified_function_names.items(): - script = jedi.Script(path=file_path, project=project) - file_refs = script.get_names(all_scopes=True, definitions=False, references=True) + if file_path in new_refs_cache: + refs_by_parent = new_refs_cache[file_path] + else: + script = jedi.Script(path=file_path, project=project) + file_refs = script.get_names(all_scopes=True, definitions=False, references=True) - # Pre-group references by their parent function's qualified name for O(1) lookup - refs_by_parent: dict[str, list[Name]] = defaultdict(list) - for ref in file_refs: - if not ref.full_name: - continue - try: - parent = ref.parent() - if parent is None or parent.type != "function": + # Pre-group references by their parent function's qualified name for O(1) lookup + refs_by_parent = defaultdict(list) + for ref in file_refs: + if not ref.full_name: continue - parent_qn = get_qualified_name(parent.module_name, parent.full_name) - # Exclude self-references (recursive calls) — the ref's own qualified name matches the parent - ref_qn = get_qualified_name(ref.module_name, ref.full_name) - if ref_qn == parent_qn: + try: + parent = ref.parent() + if parent is None or parent.type != "function": + continue + parent_qn = get_qualified_name(parent.module_name, parent.full_name) + # Exclude self-references (recursive calls) — the ref's own qualified name matches the parent + ref_qn = get_qualified_name(ref.module_name, ref.full_name) + if ref_qn == parent_qn: + continue + refs_by_parent[parent_qn].append(ref) + except (AttributeError, ValueError): continue - refs_by_parent[parent_qn].append(ref) - except (AttributeError, ValueError): - continue + new_refs_cache[file_path] = dict(refs_by_parent) for qualified_function_name in qualified_function_names: names = refs_by_parent.get(qualified_function_name, []) @@ -623,7 +644,7 @@ def get_function_sources_from_jedi( file_path_to_function_source[definition_path].add(function_source) function_source_list.append(function_source) - return file_path_to_function_source, function_source_list + return file_path_to_function_source, function_source_list, new_refs_cache def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.Module, dict[str, str]] | None: diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 9d937e55e..899ee438f 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -617,6 +617,39 @@ def gather_source_imports( return None +def _collect_dst_referenced_names(dst_code: str) -> tuple[set[str], bool]: + """Collect all names referenced in destination code for import pre-filtering. + + Uses ast (not libcst) for speed. Collects Name nodes and base names of Attribute chains, + plus names inside string annotations. + + Returns (names, has_imports) where has_imports indicates whether the dst has any + pre-existing import statements. + """ + try: + tree = ast.parse(dst_code) + except SyntaxError: + return set(), False + names: set[str] = set() + has_imports = False + for node in ast.walk(tree): + if isinstance(node, ast.Name): + names.add(node.id) + elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name): + names.add(node.value.id) + elif isinstance(node, (ast.Import, ast.ImportFrom)): + has_imports = True + elif isinstance(node, ast.Constant) and isinstance(node.value, str): + try: + inner = ast.parse(node.value, mode="eval") + for inner_node in ast.walk(inner): + if isinstance(inner_node, ast.Name): + names.add(inner_node.id) + except SyntaxError: + pass + return names, has_imports + + def add_needed_imports_from_module( src_module_code: str | cst.Module, dst_module_code: str | cst.Module, @@ -667,13 +700,20 @@ def add_needed_imports_from_module( parsed_dst_module.visit(dotted_import_collector) + # Pre-filter: collect names referenced in destination code to avoid adding unused imports. + # This keeps the intermediate module small so RemoveImportsVisitor's scope analysis is cheap. + dst_code_str = parsed_dst_module.code if isinstance(parsed_dst_module, cst.Module) else dst_code_fallback + dst_referenced_names, dst_has_imports = _collect_dst_referenced_names(dst_code_str) + try: for mod in gatherer.module_imports: # Skip __future__ imports as they cannot be imported directly # __future__ imports should only be imported with specific objects i.e from __future__ import annotations if mod == "__future__": continue - if mod not in dotted_import_collector.imports: + # For `import foo.bar`, the bound name is `foo` + bound_name = mod.split(".")[0] + if bound_name in dst_referenced_names and mod not in dotted_import_collector.imports: AddImportsVisitor.add_needed_import(dst_context, mod) RemoveImportsVisitor.remove_unused_import(dst_context, mod) aliased_objects = set() @@ -699,13 +739,18 @@ def add_needed_imports_from_module( for symbol in resolved_symbols: if ( - f"{mod}.{symbol}" not in helper_functions_fqn + symbol in dst_referenced_names + and f"{mod}.{symbol}" not in helper_functions_fqn and f"{mod}.{symbol}" not in dotted_import_collector.imports ): AddImportsVisitor.add_needed_import(dst_context, mod, symbol) RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol) else: - if f"{mod}.{obj}" not in dotted_import_collector.imports: + # For `from foo import bar`, the bound name is `bar` + # Always include __future__ imports -- they affect parsing behavior, not naming + if ( + mod == "__future__" or obj in dst_referenced_names + ) and f"{mod}.{obj}" not in dotted_import_collector.imports: AddImportsVisitor.add_needed_import(dst_context, mod, obj) RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) except Exception as e: @@ -715,7 +760,8 @@ def add_needed_imports_from_module( for mod, asname in gatherer.module_aliases.items(): if not asname: continue - if f"{mod}.{asname}" not in dotted_import_collector.imports: + # For `import foo as bar`, the bound name is `bar` + if asname in dst_referenced_names and f"{mod}.{asname}" not in dotted_import_collector.imports: AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname) RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname) @@ -727,14 +773,22 @@ def add_needed_imports_from_module( if not alias_pair[0] or not alias_pair[1]: continue - if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports: + # For `from foo import bar as baz`, the bound name is `baz` + if ( + alias_pair[1] in dst_referenced_names + and f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports + ): AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) try: add_imports_visitor = AddImportsVisitor(dst_context) transformed_module = add_imports_visitor.transform_module(parsed_dst_module) - transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module) + # Skip RemoveImportsVisitor when the dst had no pre-existing imports. + # In that case, the only imports are those just added by AddImportsVisitor, + # which are already pre-filtered to names referenced in the dst code. + if dst_has_imports: + transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module) return transformed_module.code.lstrip("\n") except Exception as e: logger.exception(f"Error adding imports to destination module code: {e}") diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index 606292977..596073590 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -351,7 +351,7 @@ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi try: - _dict, sources = get_function_sources_from_jedi( + _dict, sources, _ = get_function_sources_from_jedi( {function.file_path: {function.qualified_name}}, project_root ) except Exception as e: diff --git a/docs/benchmark_comparison.png b/docs/benchmark_comparison.png new file mode 100644 index 000000000..0112c6986 Binary files /dev/null and b/docs/benchmark_comparison.png differ