diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 281dd2ec..00000000 --- a/mypy.ini +++ /dev/null @@ -1,3 +0,0 @@ -[mypy] -ignore_missing_imports = True -check_untyped_defs = True diff --git a/pyproject.toml b/pyproject.toml index 35b0361f..7eb8a42e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,3 +51,61 @@ features = ["dev"] [project.scripts] splat = "splat.__main__:splat_main" + + +[tool.mypy] +files = ["src"] +enable_error_code = [ + "truthy-bool", + "mutable-override", + "exhaustive-match", +] +show_column_numbers = true +show_error_codes = true +show_traceback = true +disallow_any_decorated = true +disallow_any_unimported = false # TODO: change true +ignore_missing_imports = true +local_partial_types = true +no_implicit_optional = true +#strict = true +warn_unreachable = true +check_untyped_defs = false # TODO: change true + +[tool.ruff] +#line-length = 79 +fix = true + +include = ["*.py", "*.pyi", "**/pyproject.toml"] + +[tool.ruff.lint] +extend-select = [ + #"A", # flake8-builtins + #"COM", # flake8-commas + "E", # Error + "FA", # flake8-future-annotations + #"I", # isort + "ICN", # flake8-import-conventions + "PYI", # flake8-pyi + "R", # Refactor + "RET", # flake8-return + "RUF", # Ruff-specific rules + #"SIM", # flake8-simplify + "SLOT", # flake8-slots + "TCH", # flake8-type-checking + "UP", # pyupgrade + "W", # Warning + "YTT", # flake8-2020 +] +extend-ignore = [ + "E501", # line-too-long + "PYI041", # redundant-numeric-union + "SIM117", # multiple-with-statements +] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ + "D100", # undocumented-public-module + "D103", # undocumented-public-function + "D107", # undocumented-public-init +] diff --git a/src/splat/__main__.py b/src/splat/__main__.py index 5959d38c..724d5604 100644 --- a/src/splat/__main__.py +++ b/src/splat/__main__.py @@ -5,7 +5,7 @@ import splat -def splat_main(): +def splat_main() -> None: parser = argparse.ArgumentParser( description="A binary splitting tool to assist with decompilation and modding projects", prog="splat", diff --git a/src/splat/disassembler/disassembler.py b/src/splat/disassembler/disassembler.py index e06cad15..7d5a1531 100644 --- a/src/splat/disassembler/disassembler.py +++ b/src/splat/disassembler/disassembler.py @@ -1,10 +1,13 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Set class Disassembler(ABC): + __slots__ = () + @abstractmethod - def configure(self): + def configure(self) -> None: raise NotImplementedError("configure") @abstractmethod @@ -12,5 +15,5 @@ def check_version(self, skip_version_check: bool, splat_version: str): raise NotImplementedError("check_version") @abstractmethod - def known_types(self) -> Set[str]: + def known_types(self) -> set[str]: raise NotImplementedError("known_types") diff --git a/src/splat/disassembler/disassembler_instance.py b/src/splat/disassembler/disassembler_instance.py index 1745a442..cb1f2009 100644 --- a/src/splat/disassembler/disassembler_instance.py +++ b/src/splat/disassembler/disassembler_instance.py @@ -27,5 +27,4 @@ def get_instance() -> Disassembler: global __initialized if not __initialized: raise Exception("Disassembler instance not initialized") - return None return __instance diff --git a/src/splat/disassembler/disassembler_section.py b/src/splat/disassembler/disassembler_section.py index 5493581d..4fc8b5af 100644 --- a/src/splat/disassembler/disassembler_section.py +++ b/src/splat/disassembler/disassembler_section.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Optional import spimdisasm @@ -7,88 +8,92 @@ class DisassemblerSection(ABC): + __slots__ = () + @abstractmethod - def disassemble(self): + def disassemble(self) -> str: raise NotImplementedError("disassemble") @abstractmethod - def analyze(self): + def analyze(self) -> None: raise NotImplementedError("analyze") @abstractmethod - def set_comment_offset(self, rom_start: int): + def set_comment_offset(self, rom_start: int) -> None: raise NotImplementedError("set_comment_offset") @abstractmethod def make_bss_section( self, - rom_start, - rom_end, - vram_start, - bss_end, - name, - segment_rom_start, - exclusive_ram_id, - ): + rom_start: int, + rom_end: int, + vram_start: int, + bss_end: int, + name: str, + segment_rom_start: int, + exclusive_ram_id: str | None, + ) -> None: raise NotImplementedError("make_bss_section") @abstractmethod def make_data_section( self, - rom_start, - rom_end, - vram_start, - name, - rom_bytes, - segment_rom_start, - exclusive_ram_id, - ): + rom_start: int, + rom_end: int, + vram_start: int, + name: str, + rom_bytes: bytes, + segment_rom_start: int, + exclusive_ram_id: str | None, + ) -> None: raise NotImplementedError("make_data_section") @abstractmethod - def get_section(self): + def get_section(self) -> spimdisasm.mips.sections.SectionBase | None: raise NotImplementedError("get_section") @abstractmethod def make_rodata_section( self, - rom_start, - rom_end, - vram_start, - name, - rom_bytes, - segment_rom_start, - exclusive_ram_id, - ): + rom_start: int, + rom_end: int, + vram_start: int, + name: str, + rom_bytes: bytes, + segment_rom_start: int, + exclusive_ram_id: str | None, + ) -> None: raise NotImplementedError("make_rodata_section") @abstractmethod def make_text_section( self, - rom_start, - rom_end, - vram_start, - name, - rom_bytes, - segment_rom_start, - exclusive_ram_id, - ): + rom_start: int, + rom_end: int, + vram_start: int, + name: str, + rom_bytes: bytes, + segment_rom_start: int, + exclusive_ram_id: str | None, + ) -> None: raise NotImplementedError("make_text_section") class SpimdisasmDisassemberSection(DisassemblerSection): - def __init__(self): - self.spim_section: Optional[spimdisasm.mips.sections.SectionBase] = None + __slots__ = ("spim_section",) + + def __init__(self) -> None: + self.spim_section: spimdisasm.mips.sections.SectionBase | None = None def disassemble(self) -> str: assert self.spim_section is not None return self.spim_section.disassemble() - def analyze(self): + def analyze(self) -> None: assert self.spim_section is not None self.spim_section.analyze() - def set_comment_offset(self, rom_start: int): + def set_comment_offset(self, rom_start: int) -> None: assert self.spim_section is not None self.spim_section.setCommentOffset(rom_start) @@ -100,8 +105,8 @@ def make_bss_section( bss_end: int, name: str, segment_rom_start: int, - exclusive_ram_id, - ): + exclusive_ram_id: str | None, + ) -> None: self.spim_section = spimdisasm.mips.sections.SectionBss( symbols.spim_context, rom_start, @@ -121,8 +126,8 @@ def make_data_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, - ): + exclusive_ram_id: str | None, + ) -> None: self.spim_section = spimdisasm.mips.sections.SectionData( symbols.spim_context, rom_start, @@ -134,7 +139,7 @@ def make_data_section( exclusive_ram_id, ) - def get_section(self) -> Optional[spimdisasm.mips.sections.SectionBase]: + def get_section(self) -> spimdisasm.mips.sections.SectionBase | None: return self.spim_section def make_rodata_section( @@ -145,8 +150,8 @@ def make_rodata_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, - ): + exclusive_ram_id: str | None, + ) -> None: self.spim_section = spimdisasm.mips.sections.SectionRodata( symbols.spim_context, rom_start, @@ -166,8 +171,8 @@ def make_text_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, - ): + exclusive_ram_id: str | None, + ) -> None: self.spim_section = spimdisasm.mips.sections.SectionText( symbols.spim_context, rom_start, @@ -187,8 +192,8 @@ def make_gcc_except_table_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, - ): + exclusive_ram_id: str | None, + ) -> None: self.spim_section = spimdisasm.mips.sections.SectionGccExceptTable( symbols.spim_context, rom_start, @@ -201,12 +206,11 @@ def make_gcc_except_table_section( ) -def make_disassembler_section() -> Optional[SpimdisasmDisassemberSection]: +def make_disassembler_section() -> SpimdisasmDisassemberSection | None: if options.opts.platform in ["n64", "psx", "ps2", "psp"]: return SpimdisasmDisassemberSection() raise NotImplementedError("No disassembler section for requested platform") - return None def make_text_section( @@ -216,7 +220,7 @@ def make_text_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, + exclusive_ram_id: str | None, ) -> DisassemblerSection: section = make_disassembler_section() assert section is not None @@ -239,7 +243,7 @@ def make_data_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, + exclusive_ram_id: str | None, ) -> DisassemblerSection: section = make_disassembler_section() assert section is not None @@ -262,7 +266,7 @@ def make_rodata_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, + exclusive_ram_id: str | None, ) -> DisassemblerSection: section = make_disassembler_section() assert section is not None @@ -285,7 +289,7 @@ def make_bss_section( bss_end: int, name: str, segment_rom_start: int, - exclusive_ram_id, + exclusive_ram_id: str | None, ) -> DisassemblerSection: section = make_disassembler_section() assert section is not None @@ -308,7 +312,7 @@ def make_gcc_except_table_section( name: str, rom_bytes: bytes, segment_rom_start: int, - exclusive_ram_id, + exclusive_ram_id: str | None, ) -> DisassemblerSection: section = make_disassembler_section() assert section is not None diff --git a/src/splat/disassembler/null_disassembler.py b/src/splat/disassembler/null_disassembler.py index 5436e3af..7fdbb7f5 100644 --- a/src/splat/disassembler/null_disassembler.py +++ b/src/splat/disassembler/null_disassembler.py @@ -1,13 +1,16 @@ +from __future__ import annotations + from . import disassembler -from typing import Set class NullDisassembler(disassembler.Disassembler): - def configure(self): + __slots__ = () + + def configure(self) -> None: pass def check_version(self, skip_version_check: bool, splat_version: str): pass - def known_types(self) -> Set[str]: + def known_types(self) -> set[str]: return set() diff --git a/src/splat/disassembler/spimdisasm_disassembler.py b/src/splat/disassembler/spimdisasm_disassembler.py index cf2406a3..966965bf 100644 --- a/src/splat/disassembler/spimdisasm_disassembler.py +++ b/src/splat/disassembler/spimdisasm_disassembler.py @@ -2,14 +2,13 @@ import spimdisasm import rabbitizer from ..util import log, compiler, options -from typing import Set class SpimdisasmDisassembler(disassembler.Disassembler): # This value should be kept in sync with the version listed on requirements.txt and pyproject.toml SPIMDISASM_MIN = (1, 40, 0) - def configure(self): + def configure(self) -> None: # Configure spimdisasm spimdisasm.common.GlobalConfig.PRODUCE_SYMBOLS_PLUS_OFFSET = True spimdisasm.common.GlobalConfig.TRUST_USER_FUNCTIONS = True @@ -137,5 +136,5 @@ def check_version(self, skip_version_check: bool, splat_version: str): f"splat {splat_version} (powered by spimdisasm {spimdisasm.__version__})" ) - def known_types(self) -> Set[str]: + def known_types(self) -> set[str]: return spimdisasm.common.gKnownTypes diff --git a/src/splat/scripts/capy.py b/src/splat/scripts/capy.py index 8bd48a95..a821bfd6 100644 --- a/src/splat/scripts/capy.py +++ b/src/splat/scripts/capy.py @@ -27,18 +27,18 @@ """ -def print_capybara(): +def print_capybara() -> None: print(capybara) -def process_arguments(args: argparse.Namespace): +def process_arguments(args: argparse.Namespace) -> None: print_capybara() script_description = "Capybara" -def add_subparser(subparser: argparse._SubParsersAction): +def add_subparser(subparser: argparse._SubParsersAction) -> None: parser = subparser.add_parser( "capy", help=script_description, description=script_description ) diff --git a/src/splat/scripts/create_config.py b/src/splat/scripts/create_config.py index eed507cd..809b51ad 100644 --- a/src/splat/scripts/create_config.py +++ b/src/splat/scripts/create_config.py @@ -1,11 +1,11 @@ #! /usr/bin/env python3 +from __future__ import annotations import argparse import hashlib from pathlib import Path import subprocess import sys -from typing import Optional from ..util.n64 import find_code_length, rominfo from ..util.psx import psxexeinfo @@ -13,7 +13,7 @@ from ..util import log, file_presets, conf -def main(file_path: Path, objcopy: Optional[str]): +def main(file_path: Path, objcopy: str | None): if not file_path.exists(): sys.exit(f"File {file_path} does not exist ({file_path.absolute()})") if file_path.is_dir(): @@ -195,7 +195,7 @@ def create_n64_config(rom_path: Path): # Write reloc_addrs.txt file reloc_addrs: list[str] = [] - addresses_info: list[tuple[Optional[rominfo.EntryAddressInfo], str]] = [ + addresses_info: list[tuple[rominfo.EntryAddressInfo | None, str]] = [ (rom.entrypoint_info.main_address, "main"), (rom.entrypoint_info.bss_start_address, "main_BSS_START"), (rom.entrypoint_info.bss_size, "main_BSS_SIZE"), @@ -374,7 +374,7 @@ def create_psx_config(exe_path: Path, exe_bytes: bytes): file_presets.write_all_files() -def do_elf(elf_path: Path, elf_bytes: bytes, objcopy: Optional[str]): +def do_elf(elf_path: Path, elf_bytes: bytes, objcopy: str | None): elf = ps2elfinfo.Ps2Elf.get_info(elf_path, elf_bytes) if elf is None: log.error(f"Unsupported elf file '{elf_path}'") @@ -531,7 +531,7 @@ def do_elf(elf_path: Path, elf_bytes: bytes, objcopy: Optional[str]): print("```") -def find_objcopy() -> str: +def find_objcopy() -> str: # noqa: RET503 # First we try to figure out if the user has objcopy on their pc, and under # which name. # We just try a bunch and hope for the best diff --git a/src/splat/scripts/split.py b/src/splat/scripts/split.py index a4fc643b..4e63d556 100644 --- a/src/splat/scripts/split.py +++ b/src/splat/scripts/split.py @@ -1,9 +1,11 @@ #! /usr/bin/env python3 +from __future__ import annotations + import argparse import hashlib import importlib -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, TYPE_CHECKING from pathlib import Path from collections import defaultdict, deque @@ -24,22 +26,25 @@ from ..segtypes.common.group import CommonSegGroup from ..util import conf, log, options, palettes, symbols, relocs +if TYPE_CHECKING: + from types import ModuleType + linker_writer: LinkerWriter -config: Dict[str, Any] +config: dict[str, Any] segment_roms: IntervalTree = IntervalTree() segment_rams: IntervalTree = IntervalTree() -def initialize_segments(config_segments: Union[dict, list]) -> List[Segment]: +def initialize_segments(config_segments: dict | list) -> list[Segment]: global segment_roms global segment_rams segment_roms = IntervalTree() segment_rams = IntervalTree() - segments_by_name: Dict[str, Segment] = {} - ret: List[Segment] = [] + segments_by_name: dict[str, Segment] = {} + ret: list[Segment] = [] # Cross segment pairing can be quite expensive, so we try to avoid it if the user haven't requested it. do_cross_segment_pairing = False @@ -55,11 +60,11 @@ def initialize_segments(config_segments: Union[dict, list]) -> List[Segment]: segment_class = Segment.get_class_for_type(seg_type) - this_start, is_auto_segment = Segment.parse_segment_start(seg_yaml) + this_start, _is_auto_segment = Segment.parse_segment_start(seg_yaml) j = i + 1 while j < len(config_segments): - next_start, next_is_auto_segment = Segment.parse_segment_start( + next_start, _next_is_auto_segment = Segment.parse_segment_start( config_segments[j] ) if next_start is not None: @@ -77,7 +82,11 @@ def initialize_segments(config_segments: Union[dict, list]) -> List[Segment]: next_start = last_rom_end segment: Segment = Segment.from_yaml( - segment_class, seg_yaml, this_start, next_start, None + segment_class, + seg_yaml, + this_start, + next_start, + None, ) if segment.require_unique_name: @@ -168,13 +177,13 @@ def initialize_segments(config_segments: Union[dict, list]) -> List[Segment]: return ret -def assign_symbols_to_segments(): +def assign_symbols_to_segments() -> None: for symbol in symbols.all_symbols: if symbol.segment: continue if symbol.rom: - cands: Set[Interval] = segment_roms[symbol.rom] + cands: set[Interval] = segment_roms[symbol.rom] if len(cands) > 1: log.error("multiple segments rom overlap symbol", symbol) elif len(cands) == 0: @@ -185,13 +194,13 @@ def assign_symbols_to_segments(): seg.add_symbol(symbol) else: cands = segment_rams[symbol.vram_start] - segs: List[Segment] = [cand.data for cand in cands] + segs: list[Segment] = [cand.data for cand in cands] for seg in segs: if not seg.get_exclusive_ram_id(): seg.add_symbol(symbol) -def brief_seg_name(seg: Segment, limit: int, ellipsis="…") -> str: +def brief_seg_name(seg: Segment, limit: int, ellipsis: str = "…") -> str: s = seg.name.strip() if len(s) > limit: return s[:limit].strip() + ellipsis @@ -200,10 +209,10 @@ def brief_seg_name(seg: Segment, limit: int, ellipsis="…") -> str: # Return a mapping of vram classes to segments that need to be part of their vram symbol's calculation def calc_segment_dependences( - all_segments: List[Segment], -) -> Dict[vram_classes.VramClass, List[Segment]]: + all_segments: list[Segment], +) -> dict[vram_classes.VramClass, list[Segment]]: # Map vram class names to segments that have that vram class - vram_class_to_segments: Dict[str, List[Segment]] = {} + vram_class_to_segments: dict[str, list[Segment]] = {} for seg in all_segments: if seg.vram_class is not None: if seg.vram_class.name not in vram_class_to_segments: @@ -211,7 +220,7 @@ def calc_segment_dependences( vram_class_to_segments[seg.vram_class.name].append(seg) # Map vram class names to segments that the vram class follows - vram_class_to_follows_segments: Dict[vram_classes.VramClass, List[Segment]] = {} + vram_class_to_follows_segments: dict[vram_classes.VramClass, list[Segment]] = {} for vram_class in vram_classes._vram_classes.values(): if vram_class.follows_classes: vram_class_to_follows_segments[vram_class] = [] @@ -225,16 +234,16 @@ def calc_segment_dependences( def sort_segments_by_vram_class_dependency( - all_segments: List[Segment], -) -> List[Segment]: + all_segments: list[Segment], +) -> list[Segment]: # map all "_VRAM_END" strings to segments - end_sym_to_seg: Dict[str, Segment] = {} + end_sym_to_seg: dict[str, Segment] = {} for seg in all_segments: end_sym_to_seg[get_segment_vram_end_symbol_name(seg)] = seg # build dependency graph: A -> B means "A must come before B" - graph: Dict[Segment, List[Segment]] = defaultdict(list) - indeg: Dict[Segment, int] = {seg: 0 for seg in all_segments} + graph: dict[Segment, list[Segment]] = defaultdict(list) + indeg: dict[Segment, int] = {seg: 0 for seg in all_segments} for seg in all_segments: sym = seg.vram_symbol @@ -248,7 +257,7 @@ def sort_segments_by_vram_class_dependency( # stable topo sort with queue seeded in original order q = deque([seg for seg in all_segments if indeg[seg] == 0]) - out: List[Segment] = [] + out: list[Segment] = [] while q: n = q.popleft() @@ -279,7 +288,7 @@ def read_target_binary() -> bytes: return rom_bytes -def initialize_platform(rom_bytes: bytes): +def initialize_platform(rom_bytes: bytes) -> ModuleType: platform_module = importlib.import_module( f"{__package_name__}.platforms.{options.opts.platform}" ) @@ -289,7 +298,7 @@ def initialize_platform(rom_bytes: bytes): return platform_module -def initialize_all_symbols(all_segments: List[Segment]): +def initialize_all_symbols(all_segments: list[Segment]) -> None: # Load and process symbols symbols.initialize(all_segments) relocs.initialize() @@ -303,12 +312,12 @@ def initialize_all_symbols(all_segments: List[Segment]): def do_scan( - all_segments: List[Segment], + all_segments: list[Segment], rom_bytes: bytes, stats: statistics.Statistics, cache: cache_handler.Cache, -): - processed_segments: List[Segment] = [] +) -> list[Segment]: + processed_segments: list[Segment] = [] scan_bar = progress_bar.get_progress_bar(all_segments) for segment in scan_bar: @@ -336,11 +345,11 @@ def do_scan( def do_split( - all_segments: List[Segment], + all_segments: list[Segment], rom_bytes: bytes, stats: statistics.Statistics, cache: cache_handler.Cache, -): +) -> None: split_bar = progress_bar.get_progress_bar(all_segments) for segment in split_bar: assert isinstance(segment, Segment) @@ -359,14 +368,14 @@ def do_split( segment.split(segment_bytes) -def write_linker_script(all_segments: List[Segment]) -> LinkerWriter: +def write_linker_script(all_segments: list[Segment]) -> LinkerWriter: if options.opts.ld_sort_segments_by_vram_class_dependency: all_segments = sort_segments_by_vram_class_dependency(all_segments) vram_class_dependencies = calc_segment_dependences(all_segments) vram_classes_to_search = set(vram_class_dependencies.keys()) - max_vram_end_insertion_points: Dict[Segment, List[Tuple[str, List[Segment]]]] = {} + max_vram_end_insertion_points: dict[Segment, list[tuple[str, list[Segment]]]] = {} for seg in reversed(all_segments): if seg.vram_class in vram_classes_to_search: assert seg.vram_class.vram_symbol is not None @@ -441,7 +450,7 @@ def write_linker_script(all_segments: List[Segment]) -> LinkerWriter: return linker_writer -def write_ld_dependencies(linker_writer: LinkerWriter): +def write_ld_dependencies(linker_writer: LinkerWriter) -> None: if options.opts.ld_dependencies: elf_path = options.opts.elf_path if elf_path is None: @@ -453,7 +462,7 @@ def write_ld_dependencies(linker_writer: LinkerWriter): ) -def write_elf_sections_file(all_segments: List[Segment]): +def write_elf_sections_file(all_segments: list[Segment]) -> None: # write elf_sections.txt - this only lists the generated sections in the elf, not subsections # that the elf combines into one section if options.opts.elf_section_list_path: @@ -465,40 +474,44 @@ def write_elf_sections_file(all_segments: List[Segment]): f.write(section_list) -def write_undefined_auto(to_write: List[symbols.Symbol], file_path: Path): +def write_undefined_auto(to_write: list[symbols.Symbol], file_path: Path) -> None: file_path.parent.mkdir(parents=True, exist_ok=True) with file_path.open("w", newline="\n") as f: for symbol in to_write: f.write(f"{symbol.name} = 0x{symbol.vram_start:X};\n") -def write_undefined_funcs_auto(): - if options.opts.create_undefined_funcs_auto: - to_write = [ - s - for s in symbols.all_symbols - if s.referenced and not s.defined and s.type == "func" - ] - to_write.sort(key=lambda x: x.vram_start) +def write_undefined_funcs_auto() -> None: + if not options.opts.create_undefined_funcs_auto: + return + + to_write = [ + s + for s in symbols.all_symbols + if s.referenced and not s.defined and s.type == "func" + ] + to_write.sort(key=lambda x: x.vram_start) + + write_undefined_auto(to_write, options.opts.undefined_funcs_auto_path) - write_undefined_auto(to_write, options.opts.undefined_funcs_auto_path) +def write_undefined_syms_auto() -> None: + if not options.opts.create_undefined_syms_auto: + return -def write_undefined_syms_auto(): - if options.opts.create_undefined_syms_auto: - to_write = [ - s - for s in symbols.all_symbols - if s.referenced - and not s.defined - and s.type not in {"func", "label", "jtbl_label"} - ] - to_write.sort(key=lambda x: x.vram_start) + to_write = [ + s + for s in symbols.all_symbols + if s.referenced + and not s.defined + and s.type not in {"func", "label", "jtbl_label"} + ] + to_write.sort(key=lambda x: x.vram_start) - write_undefined_auto(to_write, options.opts.undefined_syms_auto_path) + write_undefined_auto(to_write, options.opts.undefined_syms_auto_path) -def print_segment_warnings(all_segments: List[Segment]): +def print_segment_warnings(all_segments: list[Segment]) -> None: for segment in all_segments: if len(segment.warnings) > 0: log.write( @@ -539,15 +552,15 @@ def dump_symbols() -> None: def main( - config_path: List[Path], - modes: Optional[List[str]], + config_path: list[Path], + modes: list[str] | None, verbose: bool, use_cache: bool = True, skip_version_check: bool = False, stdout_only: bool = False, disassemble_all: bool = False, - make_full_disasm_for_code=False, -): + make_full_disasm_for_code: Any = False, +) -> None: if stdout_only: log.write("--stdout-only flag is deprecated", status="warn") progress_bar.out_file = sys.stdout @@ -620,7 +633,7 @@ def main( file_presets.write_all_files() -def add_arguments_to_parser(parser: argparse.ArgumentParser): +def add_arguments_to_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument( "config", help="path to a compatible config .yaml file", @@ -654,7 +667,7 @@ def add_arguments_to_parser(parser: argparse.ArgumentParser): ) -def process_arguments(args: argparse.Namespace): +def process_arguments(args: argparse.Namespace) -> None: main( args.config, args.modes, @@ -670,7 +683,7 @@ def process_arguments(args: argparse.Namespace): script_description = "Split a rom given a rom, a config, and output directory" -def add_subparser(subparser: argparse._SubParsersAction): +def add_subparser(subparser: argparse._SubParsersAction) -> None: parser = subparser.add_parser( "split", help=script_description, description=script_description ) diff --git a/src/splat/segtypes/common/asm.py b/src/splat/segtypes/common/asm.py index c1e5f9cc..3d784797 100644 --- a/src/splat/segtypes/common/asm.py +++ b/src/splat/segtypes/common/asm.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations from .codesubsegment import CommonSegCodeSubsegment @@ -9,7 +9,7 @@ class CommonSegAsm(CommonSegCodeSubsegment): def is_text() -> bool: return True - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "ax" def scan(self, rom_bytes: bytes): diff --git a/src/splat/segtypes/common/bin.py b/src/splat/segtypes/common/bin.py index 7c4b80d9..e363e577 100644 --- a/src/splat/segtypes/common/bin.py +++ b/src/splat/segtypes/common/bin.py @@ -1,10 +1,13 @@ -from pathlib import Path -from typing import Optional +from __future__ import annotations +from typing import TYPE_CHECKING from ...util import log, options from .segment import CommonSegment -from ..segment import SegmentType + +if TYPE_CHECKING: + from ..segment import SegmentType + from pathlib import Path class CommonSegBin(CommonSegment): @@ -12,7 +15,7 @@ class CommonSegBin(CommonSegment): def is_data() -> bool: return True - def out_path(self) -> Optional[Path]: + def out_path(self) -> Path | None: return options.opts.asset_path / self.dir / f"{self.name}.bin" def split(self, rom_bytes): diff --git a/src/splat/segtypes/common/bss.py b/src/splat/segtypes/common/bss.py index 5a9adbc9..9aae73b3 100644 --- a/src/splat/segtypes/common/bss.py +++ b/src/splat/segtypes/common/bss.py @@ -1,8 +1,9 @@ -from typing import Optional +from __future__ import annotations from ...util import options, symbols, log from .data import CommonSegData +from .group import CommonSegGroup from ...disassembler.disassembler_section import DisassemblerSection, make_bss_section @@ -13,7 +14,7 @@ class CommonSegBss(CommonSegData): def get_linker_section(self) -> str: return ".bss" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "wa" @staticmethod @@ -35,7 +36,7 @@ def configure_disassembler_section( pass - def disassemble_data(self, rom_bytes: bytes): + def disassemble_data(self, rom_bytes: bytes) -> None: if not options.opts.ld_bss_is_noload: super().disassemble_data(rom_bytes) return @@ -60,6 +61,7 @@ def disassemble_data(self, rom_bytes: bytes): f"Segment '{self.name}' (type '{self.type}') requires a vram address. Got '{self.vram_start}'" ) + assert isinstance(self.parent, CommonSegGroup) next_subsegment = self.parent.get_next_subsegment_for_ram( self.vram_start, self.index_within_group ) @@ -86,7 +88,10 @@ def disassemble_data(self, rom_bytes: bytes): self.spim_section.analyze() self.spim_section.set_comment_offset(self.rom_start) - for spim_sym in self.spim_section.get_section().symbolList: + section = self.spim_section.get_section() + assert section is not None + + for spim_sym in section.symbolList: symbols.create_symbol_from_spim_symbol( self.get_most_parent(), spim_sym.contextSym, force_in_segment=True ) diff --git a/src/splat/segtypes/common/c.py b/src/splat/segtypes/common/c.py index 77990eb4..bbc9a806 100644 --- a/src/splat/segtypes/common/c.py +++ b/src/splat/segtypes/common/c.py @@ -1,18 +1,23 @@ +from __future__ import annotations + import os import re from pathlib import Path -from typing import Optional, Set, List +from typing import TYPE_CHECKING import rabbitizer import spimdisasm from ...util import log, options, symbols from ...util.compiler import IDO -from ...util.symbols import Symbol from .codesubsegment import CommonSegCodeSubsegment from .rodata import CommonSegRodata +if TYPE_CHECKING: + from ...util.symbols import Symbol + from collections.abc import Generator + STRIP_C_COMMENTS_RE = re.compile( r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', @@ -27,37 +32,36 @@ class CommonSegC(CommonSegCodeSubsegment): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.defined_funcs: Set[str] = set() - self.global_asm_funcs: Set[str] = set() - self.global_asm_rodata_syms: Set[str] = set() + self.defined_funcs: set[str] = set() + self.global_asm_funcs: set[str] = set() + self.global_asm_rodata_syms: set[str] = set() self.file_extension = "c" self.use_gp_rel_macro = options.opts.use_gp_rel_macro_nonmatching @staticmethod - def strip_c_comments(text): - def replacer(match): + def strip_c_comments(text: str) -> str: + def replacer(match: re.Match[str]) -> str: s = match.group(0) if s.startswith("/"): return " " - else: - return s + return s return re.sub(STRIP_C_COMMENTS_RE, replacer, text) @staticmethod - def get_funcs_defined_in_c(c_file: Path) -> Set[str]: - with open(c_file, "r", encoding="utf-8") as f: + def get_funcs_defined_in_c(c_file: Path) -> set[str]: + with open(c_file, encoding="utf-8") as f: text = CommonSegC.strip_c_comments(f.read()) return set(m.group(1) for m in C_FUNC_RE.finditer(text)) @staticmethod - def find_all_instances(string: str, sub: str): + def find_all_instances(string: str, sub: str) -> Generator[int, None, None]: start = 0 while True: start = string.find(sub, start) @@ -67,7 +71,7 @@ def find_all_instances(string: str, sub: str): start += len(sub) @staticmethod - def get_close_parenthesis(string: str, pos: int): + def get_close_parenthesis(string: str, pos: int) -> int: paren_count = 0 while True: cur_char = string[pos] @@ -76,14 +80,15 @@ def get_close_parenthesis(string: str, pos: int): elif cur_char == ")": if paren_count == 0: return pos + 1 - else: - paren_count -= 1 + paren_count -= 1 pos += 1 - @staticmethod - def find_include_macro(text: str, macro_name: str): - for pos in CommonSegC.find_all_instances(text, f"{macro_name}("): - close_paren_pos = CommonSegC.get_close_parenthesis( + @classmethod + def find_include_macro( + cls, text: str, macro_name: str + ) -> Generator[str, None, None]: + for pos in cls.find_all_instances(text, f"{macro_name}("): + close_paren_pos = cls.get_close_parenthesis( text, pos + len(f"{macro_name}(") ) macro_contents = text[pos:close_paren_pos] @@ -95,43 +100,41 @@ def find_include_macro(text: str, macro_name: str): if len(macro_args) >= 2: yield macro_args[1].strip(" )") - @staticmethod - def find_include_asm(text: str): - return CommonSegC.find_include_macro(text, "INCLUDE_ASM") + @classmethod + def find_include_asm(cls, text: str) -> Generator[str, None, None]: + return cls.find_include_macro(text, "INCLUDE_ASM") - @staticmethod - def find_include_rodata(text: str): - return CommonSegC.find_include_macro(text, "INCLUDE_RODATA") + @classmethod + def find_include_rodata(cls, text: str) -> Generator[str, None, None]: + return cls.find_include_macro(text, "INCLUDE_RODATA") - @staticmethod - def get_global_asm_funcs(c_file: Path) -> Set[str]: + @classmethod + def get_global_asm_funcs(cls, c_file: Path) -> set[str]: with c_file.open(encoding="utf-8") as f: - text = CommonSegC.strip_c_comments(f.read()) + text = cls.strip_c_comments(f.read()) if options.opts.compiler == IDO: return set(m.group(2) for m in C_GLOBAL_ASM_IDO_RE.finditer(text)) - else: - return set(CommonSegC.find_include_asm(text)) + return set(cls.find_include_asm(text)) - @staticmethod - def get_global_asm_rodata_syms(c_file: Path) -> Set[str]: + @classmethod + def get_global_asm_rodata_syms(cls, c_file: Path) -> set[str]: with c_file.open(encoding="utf-8") as f: - text = CommonSegC.strip_c_comments(f.read()) + text = cls.strip_c_comments(f.read()) if options.opts.compiler == IDO: return set(m.group(2) for m in C_GLOBAL_ASM_IDO_RE.finditer(text)) - else: - return set(CommonSegC.find_include_rodata(text)) + return set(cls.find_include_rodata(text)) @staticmethod def is_text() -> bool: return True - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "ax" - def out_path(self) -> Optional[Path]: + def out_path(self) -> Path | None: return options.opts.src_path / self.dir / f"{self.name}.{self.file_extension}" - def scan(self, rom_bytes: bytes): + def scan(self, rom_bytes: bytes) -> None: if ( self.rom_start is not None and self.rom_end is not None @@ -150,7 +153,7 @@ def scan(self, rom_bytes: bytes): self.scan_code(rom_bytes) - def split(self, rom_bytes: bytes): + def split(self, rom_bytes: bytes) -> None: if self.is_auto_segment: if options.opts.make_full_disasm_for_code: self.split_as_asmtu_file(self.asm_out_path()) @@ -162,13 +165,22 @@ def split(self, rom_bytes: bytes): self.print_file_boundaries() - assert self.spim_section is not None and isinstance( - self.spim_section.get_section(), spimdisasm.mips.sections.SectionText - ), f"{self.name}, rom_start:{self.rom_start}, rom_end:{self.rom_end}" + assert self.spim_section is not None, ( + f"{self.name}, rom_start:{self.rom_start}, rom_end:{self.rom_end}" + ) + spim_section = self.spim_section.get_section() + + assert isinstance(spim_section, spimdisasm.mips.sections.SectionText), ( + f"{self.name}, rom_start:{self.rom_start}, rom_end:{self.rom_end}" + ) # We want to know if this C section has a corresponding rodata section so we can migrate its rodata rodata_section_type = "" - rodata_spim_segment: Optional[spimdisasm.mips.sections.SectionRodata] = None + rodata_spim_segment: ( + spimdisasm.mips.sections.SectionRodata + | spimdisasm.mips.sections.SectionBase + | None + ) = None if options.opts.migrate_rodata_to_functions: # We don't know if the rodata section is .rodata or .rdata, so we need to check both for sect in [".rodata", ".rdata"]: @@ -204,19 +216,23 @@ def split(self, rom_bytes: bytes): ) assert rodata_sibling.spim_section is not None, f"{rodata_sibling}" + rodata_spim_segment = rodata_sibling.spim_section.get_section() assert isinstance( - rodata_sibling.spim_section.get_section(), + rodata_spim_segment, spimdisasm.mips.sections.SectionRodata, ) - rodata_spim_segment = rodata_sibling.spim_section.get_section() # Stop searching break + assert rodata_spim_segment is None or isinstance( + rodata_spim_segment, spimdisasm.mips.sections.SectionRodata + ), rodata_spim_segment + # Precompute function-rodata pairings symbols_entries = ( spimdisasm.mips.FunctionRodataEntry.getAllEntriesFromSections( - self.spim_section.get_section(), rodata_spim_segment + spim_section, rodata_spim_segment ) ) @@ -281,16 +297,16 @@ def split(self, rom_bytes: bytes): ) if options.opts.make_full_disasm_for_code: + # TODO: Figure out why mypy thinks these attributes don't exist # Disable gpRelHack since this file is expected to be built with modern gas - section = self.spim_section.get_section() - old_value = section.getGpRelHack() - section.setGpRelHack(False) + old_value = spim_section.getGpRelHack() # type: ignore[attr-defined] + spim_section.setGpRelHack(False) # type: ignore[attr-defined] if options.opts.platform == "ps2": # Modern gas requires `$` on the special r5900 registers. from rabbitizer import TrinaryValue - for func in section.symbolList: + for func in spim_section.symbolList: assert isinstance(func, spimdisasm.mips.symbols.SymbolFunction) for inst in func.instructions: inst.flag_r5900UseDollar = TrinaryValue.TRUE @@ -298,14 +314,15 @@ def split(self, rom_bytes: bytes): self.split_as_asmtu_file(self.asm_out_path()) if options.opts.platform == "ps2": - for func in section.symbolList: + for func in spim_section.symbolList: assert isinstance(func, spimdisasm.mips.symbols.SymbolFunction) for inst in func.instructions: inst.flag_r5900UseDollar = TrinaryValue.FALSE - section.setGpRelHack(old_value) + # See comment above + spim_section.setGpRelHack(old_value) # type: ignore[attr-defined] - def get_c_preamble(self): + def get_c_preamble(self) -> list[str]: ret = [] preamble = options.opts.generated_c_preamble @@ -317,8 +334,8 @@ def get_c_preamble(self): def check_gaps_in_migrated_rodata( self, func: spimdisasm.mips.symbols.SymbolFunction, - rodata_list: List[spimdisasm.mips.symbols.SymbolBase], - ): + rodata_list: list[spimdisasm.mips.symbols.SymbolBase], + ) -> None: for index in range(len(rodata_list) - 1): rodata_sym = rodata_list[index] next_rodata_sym = rodata_list[index + 1] @@ -342,7 +359,7 @@ def create_c_asm_file( func_rodata_entry: spimdisasm.mips.FunctionRodataEntry, out_dir: Path, func_sym: Symbol, - ): + ) -> None: outpath = out_dir / self.name / f"{func_sym.filename}.s" # Skip extraction if the file exists and the symbol is marked as extract=false @@ -380,7 +397,7 @@ def create_unmigrated_rodata_file( spim_rodata_sym: spimdisasm.mips.symbols.SymbolBase, out_dir: Path, rodata_sym: Symbol, - ): + ) -> None: outpath = out_dir / self.name / f"{rodata_sym.filename}.s" # Skip extraction if the file exists and the symbol is marked as extract=false @@ -389,7 +406,7 @@ def create_unmigrated_rodata_file( outpath.parent.mkdir(parents=True, exist_ok=True) - with outpath.open("w", newline="\n") as f: + with outpath.open("w", encoding="utf-8", newline="\n") as f: preamble = options.opts.generated_s_preamble if preamble: f.write(preamble + "\n") @@ -425,7 +442,7 @@ def get_c_lines_for_function( sym: Symbol, spim_sym: spimdisasm.mips.symbols.SymbolFunction, asm_out_dir: Path, - ) -> List[str]: + ) -> list[str]: c_lines = [] # Terrible hack to "auto-decompile" empty functions @@ -444,7 +461,7 @@ def get_c_lines_for_function( c_lines.append("") return c_lines - def get_c_lines_for_rodata_sym(self, sym: Symbol, asm_out_dir: Path): + def get_c_lines_for_rodata_sym(self, sym: Symbol, asm_out_dir: Path) -> list[str]: c_lines = [self.get_c_line_include_macro(sym, asm_out_dir, "INCLUDE_RODATA")] c_lines.append("") return c_lines @@ -453,8 +470,8 @@ def create_c_file( self, asm_out_dir: Path, c_path: Path, - symbols_entries: List[spimdisasm.mips.FunctionRodataEntry], - ): + symbols_entries: list[spimdisasm.mips.FunctionRodataEntry], + ) -> None: c_lines = self.get_c_preamble() for entry in symbols_entries: @@ -489,8 +506,8 @@ def create_asm_dependencies_file( c_path: Path, asm_out_dir: Path, is_new_c_file: bool, - symbols_entries: List[spimdisasm.mips.FunctionRodataEntry], - ): + symbols_entries: list[spimdisasm.mips.FunctionRodataEntry], + ) -> None: if not options.opts.create_asm_dependencies: return if ( diff --git a/src/splat/segtypes/common/code.py b/src/splat/segtypes/common/code.py index 050c980f..08525db4 100644 --- a/src/splat/segtypes/common/code.py +++ b/src/splat/segtypes/common/code.py @@ -1,28 +1,33 @@ +from __future__ import annotations + from collections import OrderedDict -from typing import List, Optional, Type, Tuple +from typing import TYPE_CHECKING, cast from ...util import log, options, utils from .group import CommonSegGroup from ..segment import Segment, parse_segment_align +if TYPE_CHECKING: + from ...util.vram_classes import SerializedSegmentData + -def dotless_type(type: str) -> str: - return type[1:] if type[0] == "." else type +def dotless_type(type_: str) -> str: + return type_[1:] if type_[0] == "." else type_ # code group class CommonSegCode(CommonSegGroup): def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], - args: list, - yaml, - ): + vram_start: int | None, + args: list[str], + yaml: SerializedSegmentData | list[str], + ) -> None: self.bss_size: int = yaml.get("bss_size", 0) if isinstance(yaml, dict) else 0 super().__init__( @@ -46,22 +51,21 @@ def needs_symbols(self) -> bool: return True @property - def vram_end(self) -> Optional[int]: + def vram_end(self) -> int | None: if self.vram_start is not None and self.size is not None: return self.vram_start + self.size + self.bss_size - else: - return None + return None # Generates a placeholder segment for the auto_link_sections option def _generate_segment_from_all( self, rep_type: str, - replace_class: Type[Segment], + replace_class: type[Segment], base_name: str, base_seg: Segment, - rom_start: Optional[int] = None, - rom_end: Optional[int] = None, - vram_start: Optional[int] = None, + rom_start: int | None = None, + rom_end: int | None = None, + vram_start: int | None = None, ) -> Segment: rep: Segment = replace_class( rom_start=rom_start, @@ -70,7 +74,7 @@ def _generate_segment_from_all( name=base_name, vram_start=vram_start, args=[], - yaml={}, + yaml=cast("SerializedSegmentData", {}), ) rep.extract = False rep.given_subalign = self.given_subalign @@ -90,14 +94,14 @@ def _generate_segment_from_all( def _insert_all_auto_sections( self, - ret: List[Segment], + ret: list[Segment], base_segments: OrderedDict[str, Segment], readonly_before: bool, - ) -> List[Segment]: + ) -> list[Segment]: if len(options.opts.auto_link_sections) == 0: return ret - base_segments_list: List[Tuple[str, Segment]] = list(base_segments.items()) + base_segments_list: list[tuple[str, Segment]] = list(base_segments.items()) # Determine what will be the min insertion index last_inserted_index = len(base_segments_list) - 1 @@ -145,7 +149,9 @@ def _insert_all_auto_sections( return ret - def parse_subsegments(self, segment_yaml) -> List[Segment]: + def parse_subsegments( + self, segment_yaml: dict[str, list[SerializedSegmentData | list[str]]] + ) -> list[Segment]: if "subsegments" not in segment_yaml: if not self.parent: raise Exception( @@ -154,9 +160,9 @@ def parse_subsegments(self, segment_yaml) -> List[Segment]: return [] base_segments: OrderedDict[str, Segment] = OrderedDict() - ret: List[Segment] = [] - prev_start: Optional[int] = -1 - prev_vram: Optional[int] = -1 + ret: list[Segment] = [] + prev_start: int | None = -1 + prev_vram: int | None = -1 last_rom_end = None @@ -194,9 +200,9 @@ def parse_subsegments(self, segment_yaml) -> List[Segment]: # First, try to get the end address from the next segment's start address # Second, try to get the end address from the estimated size of this segment # Third, try to get the end address from the next segment with a start address - end: Optional[int] = None + end: int | None = None if i < len(segment_yaml["subsegments"]) - 1: - end, end_is_auto_segment = Segment.parse_segment_start( + end, _end_is_auto_segment = Segment.parse_segment_start( segment_yaml["subsegments"][i + 1] ) if start is not None and end is None: @@ -275,7 +281,7 @@ def parse_subsegments(self, segment_yaml) -> List[Segment]: return ret - def scan(self, rom_bytes): + def scan(self, rom_bytes: bytes) -> None: # Always scan code first for sub in self.subsegments: if sub.is_text() and sub.should_scan(): diff --git a/src/splat/segtypes/common/codesubsegment.py b/src/splat/segtypes/common/codesubsegment.py index 18a5ea74..7d578143 100644 --- a/src/splat/segtypes/common/codesubsegment.py +++ b/src/splat/segtypes/common/codesubsegment.py @@ -1,5 +1,5 @@ -from pathlib import Path -from typing import Optional, List +from __future__ import annotations + import spimdisasm import rabbitizer @@ -8,25 +8,46 @@ from .code import CommonSegCode -from ..segment import Segment, parse_segment_vram +from ..segment import Segment, parse_segment_vram, SerializedSegment from ...disassembler.disassembler_section import DisassemblerSection, make_text_section +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path # abstract class for c, asm, data, etc class CommonSegCodeSubsegment(Segment): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + rom_start: int | None, + rom_end: int | None, + type: str, + name: str, + vram_start: int | None, + args: list[str], + yaml: SerializedSegment, + ) -> None: + super().__init__( + rom_start=rom_start, + rom_end=rom_end, + type=type, + name=name, + vram_start=vram_start, + args=args, + yaml=yaml, + ) vram = parse_segment_vram(self.yaml) if vram is not None: self.vram_start = vram - self.str_encoding: Optional[str] = ( + self.str_encoding: str | None = ( self.yaml.get("str_encoding", None) if isinstance(self.yaml, dict) else None ) - self.spim_section: Optional[DisassemblerSection] = None + self.spim_section: DisassemblerSection | None = None self.instr_category = rabbitizer.InstrCategory.CPU if options.opts.platform == "ps2": self.instr_category = rabbitizer.InstrCategory.R5900 @@ -35,7 +56,7 @@ def __init__(self, *args, **kwargs): elif options.opts.platform == "psp": self.instr_category = rabbitizer.InstrCategory.R4000ALLEGREX - self.detect_redundant_function_end: Optional[bool] = ( + self.detect_redundant_function_end: bool | None = ( self.yaml.get("detect_redundant_function_end", None) if isinstance(self.yaml, dict) else None @@ -43,6 +64,7 @@ def __init__(self, *args, **kwargs): self.is_hasm = False self.use_gp_rel_macro = options.opts.use_gp_rel_macro + # self.parent: CommonSegCode @property def needs_symbols(self) -> bool: @@ -57,13 +79,15 @@ def configure_disassembler_section( "Allows to configure the section before running the analysis on it" section = disassembler_section.get_section() + assert section is not None + # TODO: Figure out why mypy thinks these attributes don't exist section.isHandwritten = self.is_hasm - section.instrCat = self.instr_category - section.detectRedundantFunctionEnd = self.detect_redundant_function_end - section.setGpRelHack(not self.use_gp_rel_macro) + section.instrCat = self.instr_category # type: ignore[attr-defined] + section.detectRedundantFunctionEnd = self.detect_redundant_function_end # type: ignore[attr-defined] + section.setGpRelHack(not self.use_gp_rel_macro) # type: ignore[attr-defined] - def scan_code(self, rom_bytes, is_hasm=False): + def scan_code(self, rom_bytes: bytes, is_hasm: bool = False) -> None: self.is_hasm = is_hasm if self.is_auto_segment: @@ -103,7 +127,11 @@ def scan_code(self, rom_bytes, is_hasm=False): self.spim_section.analyze() self.spim_section.set_comment_offset(self.rom_start) - for func in self.spim_section.get_section().symbolList: + section = self.spim_section.get_section() + + assert section is not None + + for func in section.symbolList: assert isinstance(func, spimdisasm.mips.symbols.SymbolFunction) self.process_insns(func) @@ -111,12 +139,11 @@ def scan_code(self, rom_bytes, is_hasm=False): def process_insns( self, func_spim: spimdisasm.mips.symbols.SymbolFunction, - ): + ) -> None: assert isinstance(self.parent, CommonSegCode) assert func_spim.vram is not None assert func_spim.vramEnd is not None assert self.spim_section is not None - self.parent: CommonSegCode = self.parent symbols.create_symbol_from_spim_symbol( self.get_most_parent(), func_spim.contextSym, force_in_segment=False @@ -124,9 +151,9 @@ def process_insns( # Gather symbols found by spimdisasm and create those symbols in splat's side for referenced_vram in func_spim.referencedVrams: - context_sym = self.spim_section.get_section().getSymbol( - referenced_vram, tryPlusOffset=False - ) + section = self.spim_section.get_section() + assert section is not None + context_sym = section.getSymbol(referenced_vram, tryPlusOffset=False) if context_sym is not None: symbols.create_symbol_from_spim_symbol( self.get_most_parent(), context_sym, force_in_segment=False @@ -150,28 +177,32 @@ def process_insns( if instr_offset in func_spim.instrAnalyzer.symbolInstrOffset: sym_address = func_spim.instrAnalyzer.symbolInstrOffset[instr_offset] - context_sym = self.spim_section.get_section().getSymbol(sym_address) + section = self.spim_section.get_section() + assert section is not None + context_sym = section.getSymbol(sym_address) if context_sym is not None: symbols.create_symbol_from_spim_symbol( self.get_most_parent(), context_sym, force_in_segment=False ) - def print_file_boundaries(self): + def print_file_boundaries(self) -> None: if not self.show_file_boundaries or not self.spim_section: return + assert isinstance(self.parent, CommonSegCode) assert isinstance(self.rom_start, int) - for in_file_offset in self.spim_section.get_section().fileBoundaries: + section = self.spim_section.get_section() + assert section is not None + + for in_file_offset in section.fileBoundaries: if not self.parent.reported_file_split: self.parent.reported_file_split = True # Look up for the last symbol in this boundary sym_addr = 0 - for sym in self.spim_section.get_section().symbolList: - symOffset = ( - sym.inFileOffset - self.spim_section.get_section().inFileOffset - ) + for sym in section.symbolList: + symOffset = sym.inFileOffset - section.inFileOffset if in_file_offset == symOffset: break sym_addr = sym.vram @@ -199,7 +230,7 @@ def should_split(self) -> bool: def should_self_split(self) -> bool: return self.should_split() - def get_asm_file_header(self) -> List[str]: + def get_asm_file_header(self) -> list[str]: ret = [] ret.append('.include "macro.inc"') @@ -217,7 +248,7 @@ def get_asm_file_header(self) -> List[str]: return ret - def get_asm_file_extra_directives(self) -> List[str]: + def get_asm_file_extra_directives(self) -> list[str]: ret = [] ret.append(".set noat") # allow manual use of $at @@ -231,10 +262,10 @@ def get_asm_file_extra_directives(self) -> List[str]: def asm_out_path(self) -> Path: return options.opts.asm_path / self.dir / f"{self.name}.s" - def out_path(self) -> Optional[Path]: + def out_path(self) -> Path | None: return self.asm_out_path() - def split_as_asm_file(self, out_path: Optional[Path]): + def split_as_asm_file(self, out_path: Path | None) -> None: if self.spim_section is None: return @@ -252,7 +283,7 @@ def split_as_asm_file(self, out_path: Optional[Path]): f.write(self.spim_section.disassemble()) # Same as above but write all sections from siblings - def split_as_asmtu_file(self, out_path: Path): + def split_as_asmtu_file(self, out_path: Path) -> None: out_path.parent.mkdir(parents=True, exist_ok=True) self.print_file_boundaries() diff --git a/src/splat/segtypes/common/data.py b/src/splat/segtypes/common/data.py index 19f7467f..0c58519d 100644 --- a/src/splat/segtypes/common/data.py +++ b/src/splat/segtypes/common/data.py @@ -1,5 +1,6 @@ -from pathlib import Path -from typing import Optional, List +from __future__ import annotations + +from typing import TYPE_CHECKING from ...util import options, symbols, log from .codesubsegment import CommonSegCodeSubsegment @@ -7,6 +8,10 @@ from ...disassembler.disassembler_section import DisassemblerSection, make_data_section +if TYPE_CHECKING: + from pathlib import Path + from ...segtypes.linker_entry import LinkerEntry + class CommonSegData(CommonSegCodeSubsegment, CommonSegGroup): @staticmethod @@ -20,28 +25,26 @@ def asm_out_path(self) -> Path: return options.opts.data_path / self.dir / f"{self.name}.{typ}.s" - def out_path(self) -> Optional[Path]: + def out_path(self) -> Path | None: if self.type.startswith("."): if self.sibling: # C file return self.sibling.out_path() - else: - # Implied C file - return options.opts.src_path / self.dir / f"{self.name}.c" - else: - # ASM - return self.asm_out_path() - - def scan(self, rom_bytes: bytes): + # Implied C file + return options.opts.src_path / self.dir / f"{self.name}.c" + # ASM + return self.asm_out_path() + + def scan(self, rom_bytes: bytes) -> None: CommonSegGroup.scan(self, rom_bytes) if self.rom_start is not None and self.rom_end is not None: self.disassemble_data(rom_bytes) - def get_asm_file_extra_directives(self) -> List[str]: + def get_asm_file_extra_directives(self) -> list[str]: return [] - def split(self, rom_bytes: bytes): + def split(self, rom_bytes: bytes) -> None: super().split(rom_bytes) if self.spim_section is None or not self.should_self_split(): @@ -66,10 +69,10 @@ def cache(self): def get_linker_section(self) -> str: return ".data" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "wa" - def get_linker_entries(self): + def get_linker_entries(self) -> list[LinkerEntry]: return CommonSegCodeSubsegment.get_linker_entries(self) def configure_disassembler_section( @@ -78,6 +81,7 @@ def configure_disassembler_section( "Allows to configure the section before running the analysis on it" section = disassembler_section.get_section() + assert section is not None # Set data string encoding # First check the global configuration @@ -88,7 +92,7 @@ def configure_disassembler_section( if self.str_encoding is not None: section.stringEncoding = self.str_encoding - def disassemble_data(self, rom_bytes): + def disassemble_data(self, rom_bytes: bytes) -> None: if self.is_auto_segment: return @@ -128,16 +132,17 @@ def disassemble_data(self, rom_bytes): rodata_encountered = False - for symbol in self.spim_section.get_section().symbolList: + section = self.spim_section.get_section() + assert section is not None + + for symbol in section.symbolList: symbols.create_symbol_from_spim_symbol( self.get_most_parent(), symbol.contextSym, force_in_segment=True ) # Gather symbols found by spimdisasm and create those symbols in splat's side for referenced_vram in symbol.referencedVrams: - context_sym = self.spim_section.get_section().getSymbol( - referenced_vram, tryPlusOffset=False - ) + context_sym = section.getSymbol(referenced_vram, tryPlusOffset=False) if context_sym is not None: symbols.create_symbol_from_spim_symbol( self.get_most_parent(), context_sym, force_in_segment=False diff --git a/src/splat/segtypes/common/databin.py b/src/splat/segtypes/common/databin.py index 25ffd257..6d207653 100644 --- a/src/splat/segtypes/common/databin.py +++ b/src/splat/segtypes/common/databin.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations from ...util import log, options @@ -17,7 +17,7 @@ def is_data() -> bool: def get_linker_section(self) -> str: return ".data" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "wa" def split(self, rom_bytes): diff --git a/src/splat/segtypes/common/eh_frame.py b/src/splat/segtypes/common/eh_frame.py index ef1a72b9..18a445b7 100644 --- a/src/splat/segtypes/common/eh_frame.py +++ b/src/splat/segtypes/common/eh_frame.py @@ -1,7 +1,10 @@ -from typing import Optional +from __future__ import annotations from .data import CommonSegData -from ...disassembler.disassembler_section import DisassemblerSection +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...disassembler.disassembler_section import DisassemblerSection class CommonSegEh_frame(CommonSegData): @@ -10,7 +13,7 @@ class CommonSegEh_frame(CommonSegData): def get_linker_section(self) -> str: return ".eh_frame" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "aw" def configure_disassembler_section( @@ -21,6 +24,7 @@ def configure_disassembler_section( super().configure_disassembler_section(disassembler_section) section = disassembler_section.get_section() + assert section is not None # We use s32 to make sure spimdisasm disassembles the data from this section as words/references to other symbols section.enableStringGuessing = False diff --git a/src/splat/segtypes/common/gcc_except_table.py b/src/splat/segtypes/common/gcc_except_table.py index 1ff7d4c6..a0cf5529 100644 --- a/src/splat/segtypes/common/gcc_except_table.py +++ b/src/splat/segtypes/common/gcc_except_table.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations from .data import CommonSegData from ...util import log @@ -15,7 +15,7 @@ class CommonSegGcc_except_table(CommonSegData): def get_linker_section(self) -> str: return ".gcc_except_table" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "aw" def configure_disassembler_section( @@ -26,10 +26,11 @@ def configure_disassembler_section( super().configure_disassembler_section(disassembler_section) section = disassembler_section.get_section() + assert section is not None section.enableStringGuessing = False - def disassemble_data(self, rom_bytes): + def disassemble_data(self, rom_bytes: bytes) -> None: if self.is_auto_segment: return diff --git a/src/splat/segtypes/common/group.py b/src/splat/segtypes/common/group.py index 94f02ff2..11429ac4 100644 --- a/src/splat/segtypes/common/group.py +++ b/src/splat/segtypes/common/group.py @@ -1,22 +1,28 @@ -from typing import List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from ...util import log from .segment import CommonSegment from ..segment import empty_statistics, Segment, SegmentStatistics +if TYPE_CHECKING: + from ...util.vram_classes import SerializedSegmentData + from ..linker_entry import LinkerEntry + class CommonSegGroup(CommonSegment): def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], - args: list, - yaml, - ): + vram_start: int | None, + args: list[str], + yaml: SerializedSegmentData | list[str], + ) -> None: super().__init__( rom_start, rom_end, @@ -27,12 +33,15 @@ def __init__( yaml=yaml, ) - self.subsegments: List[Segment] = self.parse_subsegments(yaml) + # TODO: Fix + self.subsegments: list[Segment] = self.parse_subsegments(yaml) # type: ignore[arg-type] - def get_next_seg_start(self, i, subsegment_yamls) -> Optional[int]: + def get_next_seg_start( + self, i: int, subsegment_yamls: list[SerializedSegmentData | list[str]] + ) -> int | None: j = i + 1 while j < len(subsegment_yamls): - ret, is_auto_segment = Segment.parse_segment_start(subsegment_yamls[j]) + ret, _is_auto_segment = Segment.parse_segment_start(subsegment_yamls[j]) if ret is not None: return ret j += 1 @@ -40,13 +49,15 @@ def get_next_seg_start(self, i, subsegment_yamls) -> Optional[int]: # Fallback return self.rom_end - def parse_subsegments(self, yaml) -> List[Segment]: - ret: List[Segment] = [] + def parse_subsegments( + self, yaml: dict[str, list[SerializedSegmentData | list[str]]] + ) -> list[Segment]: + ret: list[Segment] = [] if not yaml or "subsegments" not in yaml: return ret - prev_start: Optional[int] = -1 + prev_start: int | None = -1 last_rom_end = 0 for i, subsegment_yaml in enumerate(yaml["subsegments"]): @@ -71,9 +82,9 @@ def parse_subsegments(self, yaml) -> List[Segment]: # First, try to get the end address from the next segment's start address # Second, try to get the end address from the estimated size of this segment # Third, try to get the end address from the next segment with a start address - end: Optional[int] = None + end: int | None = None if i < len(yaml["subsegments"]) - 1: - end, end_is_auto_segment = Segment.parse_segment_start( + end, _end_is_auto_segment = Segment.parse_segment_start( yaml["subsegments"][i + 1] ) if start is not None and end is None: @@ -104,7 +115,12 @@ def parse_subsegments(self, yaml) -> List[Segment]: end = last_rom_end segment: Segment = Segment.from_yaml( - segment_class, subsegment_yaml, start, end, self, vram + segment_class, + subsegment_yaml, + start, + end, + self, + vram, ) if segment.special_vram_segment: self.special_vram_segment = True @@ -137,15 +153,15 @@ def statistics(self) -> SegmentStatistics: stats[ty] = stats[ty].merge(info) return stats - def get_linker_entries(self): + def get_linker_entries(self) -> list[LinkerEntry]: return [entry for sub in self.subsegments for entry in sub.get_linker_entries()] - def scan(self, rom_bytes): + def scan(self, rom_bytes: bytes) -> None: for sub in self.subsegments: if sub.should_scan(): sub.scan(rom_bytes) - def split(self, rom_bytes): + def split(self, rom_bytes: bytes) -> None: for sub in self.subsegments: if sub.should_split(): sub.split(rom_bytes) @@ -156,7 +172,7 @@ def should_split(self) -> bool: def should_scan(self) -> bool: return self.extract - def cache(self): + def cache(self) -> list[tuple[SerializedSegmentData | list[str], int | None]]: # type: ignore[override] c = [] for sub in self.subsegments: @@ -164,7 +180,7 @@ def cache(self): return c - def get_subsegment_for_ram(self, addr: int) -> Optional[Segment]: + def get_subsegment_for_ram(self, addr: int) -> Segment | None: for sub in self.subsegments: if sub.contains_vram(addr): return sub @@ -175,8 +191,8 @@ def get_subsegment_for_ram(self, addr: int) -> Optional[Segment]: return None def get_next_subsegment_for_ram( - self, addr: int, current_subseg_index: Optional[int] - ) -> Optional[Segment]: + self, addr: int, current_subseg_index: int | None + ) -> Segment | None: """ Returns the first subsegment which comes after the specified address, or None in case this address belongs to the last subsegment of this group @@ -194,8 +210,8 @@ def get_next_subsegment_for_ram( def pair_subsegments_to_other_segment( self, - other_segment: "CommonSegGroup", - ): + other_segment: CommonSegGroup, + ) -> None: # Pair cousins with the same name for segment in self.subsegments: for sibling in other_segment.subsegments: diff --git a/src/splat/segtypes/common/header.py b/src/splat/segtypes/common/header.py index f8c41c66..3514d46c 100644 --- a/src/splat/segtypes/common/header.py +++ b/src/splat/segtypes/common/header.py @@ -1,8 +1,13 @@ -from pathlib import Path +from __future__ import annotations + from ...util import options from .segment import CommonSegment +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path class CommonSegHeader(CommonSegment): @@ -10,11 +15,11 @@ class CommonSegHeader(CommonSegment): def is_data() -> bool: return True - def should_split(self): + def should_split(self) -> bool: return self.extract and options.opts.is_mode_active("code") @staticmethod - def get_line(typ, data, comment): + def get_line(typ: str, data: bytes, comment: str) -> str: if typ == "ascii": text = data.decode("ASCII").strip() text = text.replace("\x00", "\\0") # escape NUL chars @@ -29,18 +34,18 @@ def get_line(typ, data, comment): def out_path(self) -> Path: return options.opts.asm_path / self.dir / f"{self.name}.s" - def parse_header(self, rom_bytes): + def parse_header(self, rom_bytes: bytes) -> list[str]: return [] - def split(self, rom_bytes): + def split(self, rom_bytes: bytes) -> None: header_lines = self.parse_header(rom_bytes) src_path = self.out_path() src_path.parent.mkdir(parents=True, exist_ok=True) - with open(src_path, "w", newline="\n") as f: + with open(src_path, "w", newline="\n", encoding="utf-8") as f: f.write("\n".join(header_lines)) self.log(f"Wrote {self.name} to {src_path}") @staticmethod - def get_default_name(addr): + def get_default_name(addr: int) -> str: return "header" diff --git a/src/splat/segtypes/common/lib.py b/src/splat/segtypes/common/lib.py index 409fa282..49228ab4 100644 --- a/src/splat/segtypes/common/lib.py +++ b/src/splat/segtypes/common/lib.py @@ -1,5 +1,5 @@ +from __future__ import annotations from pathlib import Path -from typing import Optional, List from ...util import log, options @@ -13,7 +13,7 @@ class LinkerEntryLib(LinkerEntry): def __init__( self, segment: Segment, - src_paths: List[Path], + src_paths: list[Path], object_path: Path, section_order: str, section_link: str, @@ -31,11 +31,11 @@ def emit_entry(self, linker_writer: LinkerWriter): class CommonSegLib(CommonSegment): def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], + vram_start: int | None, args: list, yaml, ): @@ -73,7 +73,7 @@ def __init__( def get_linker_section(self) -> str: return self.section - def get_linker_entries(self) -> List[LinkerEntry]: + def get_linker_entries(self) -> list[LinkerEntry]: path = options.opts.lib_path / self.name object_path = Path(f"{path}.a:{self.object}.o") diff --git a/src/splat/segtypes/common/linker_offset.py b/src/splat/segtypes/common/linker_offset.py index e941715a..7fd3d768 100644 --- a/src/splat/segtypes/common/linker_offset.py +++ b/src/splat/segtypes/common/linker_offset.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List from ..linker_entry import LinkerEntry, LinkerWriter from ..segment import Segment @@ -24,5 +23,5 @@ def get_linker_section_order(self) -> str: def get_linker_section_linksection(self) -> str: return "" - def get_linker_entries(self) -> List[LinkerEntry]: + def get_linker_entries(self) -> list[LinkerEntry]: return [LinkerEntryOffset(self)] diff --git a/src/splat/segtypes/common/pad.py b/src/splat/segtypes/common/pad.py index c99362d9..0471b577 100644 --- a/src/splat/segtypes/common/pad.py +++ b/src/splat/segtypes/common/pad.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List from .segment import CommonSegment from ..linker_entry import LinkerEntry, LinkerWriter @@ -25,5 +24,5 @@ def get_linker_section_order(self) -> str: def get_linker_section_linksection(self) -> str: return "" - def get_linker_entries(self) -> List[LinkerEntry]: + def get_linker_entries(self) -> list[LinkerEntry]: return [LinkerEntryPad(self)] diff --git a/src/splat/segtypes/common/rodata.py b/src/splat/segtypes/common/rodata.py index 39baeba9..217543bd 100644 --- a/src/splat/segtypes/common/rodata.py +++ b/src/splat/segtypes/common/rodata.py @@ -1,21 +1,26 @@ -from typing import Optional, Set, Tuple, List -import spimdisasm -from ..segment import Segment +from __future__ import annotations + from ...util import log, options, symbols +from .code import CommonSegCode from .data import CommonSegData from ...disassembler.disassembler_section import ( DisassemblerSection, make_rodata_section, ) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..segment import Segment + import spimdisasm class CommonSegRodata(CommonSegData): def get_linker_section(self) -> str: return ".rodata" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "a" @staticmethod @@ -28,7 +33,7 @@ def is_rodata() -> bool: def get_possible_text_subsegment_for_symbol( self, rodata_sym: spimdisasm.mips.symbols.SymbolBase - ) -> Optional[Tuple[Segment, spimdisasm.common.ContextSymbol]]: + ) -> tuple[Segment, spimdisasm.common.ContextSymbol] | None: # Check if this rodata segment does not have a corresponding code file, try to look for one if self.sibling is not None or not options.opts.pair_rodata_to_text: @@ -40,7 +45,8 @@ def get_possible_text_subsegment_for_symbol( if len(rodata_sym.contextSym.referenceFunctions) != 1: return None - func = list(rodata_sym.contextSym.referenceFunctions)[0] + func = next(iter(rodata_sym.contextSym.referenceFunctions)) + assert isinstance(self.parent, CommonSegCode) text_segment = self.parent.get_subsegment_for_ram(func.vram) if text_segment is None or not text_segment.is_text(): @@ -53,6 +59,7 @@ def configure_disassembler_section( "Allows to configure the section before running the analysis on it" section = disassembler_section.get_section() + assert section is not None # Set rodata string encoding # First check the global configuration @@ -63,7 +70,7 @@ def configure_disassembler_section( if self.str_encoding is not None: section.stringEncoding = self.str_encoding - def disassemble_data(self, rom_bytes): + def disassemble_data(self, rom_bytes: bytes) -> None: if self.is_auto_segment: return @@ -101,12 +108,15 @@ def disassemble_data(self, rom_bytes): self.spim_section.analyze() self.spim_section.set_comment_offset(self.rom_start) - possible_text_segments: Set[Segment] = set() + possible_text_segments: set[Segment] = set() last_jumptable_addr_remainder = 0 - misaligned_jumptable_offsets: List[int] = [] + misaligned_jumptable_offsets: list[int] = [] + + section = self.spim_section.get_section() + assert section is not None - for symbol in self.spim_section.get_section().symbolList: + for symbol in section.symbolList: generated_symbol = symbols.create_symbol_from_spim_symbol( self.get_most_parent(), symbol.contextSym, force_in_segment=True ) @@ -114,9 +124,7 @@ def disassemble_data(self, rom_bytes): # Gather symbols found by spimdisasm and create those symbols in splat's side for referenced_vram in symbol.referencedVrams: - context_sym = self.spim_section.get_section().getSymbol( - referenced_vram, tryPlusOffset=False - ) + context_sym = section.getSymbol(referenced_vram, tryPlusOffset=False) if context_sym is not None: symbols.create_symbol_from_spim_symbol( self.get_most_parent(), context_sym, force_in_segment=False diff --git a/src/splat/segtypes/common/rodatabin.py b/src/splat/segtypes/common/rodatabin.py index f3e8575a..4e77072f 100644 --- a/src/splat/segtypes/common/rodatabin.py +++ b/src/splat/segtypes/common/rodatabin.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations from ...util import log, options @@ -17,7 +17,7 @@ def is_rodata() -> bool: def get_linker_section(self) -> str: return ".rodata" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "a" def split(self, rom_bytes): diff --git a/src/splat/segtypes/common/segment.py b/src/splat/segtypes/common/segment.py index 20c89a5e..091d02b8 100644 --- a/src/splat/segtypes/common/segment.py +++ b/src/splat/segtypes/common/segment.py @@ -1,5 +1,5 @@ -from ...segtypes.segment import Segment +from ..segment import Segment class CommonSegment(Segment): - pass + __slots__ = () diff --git a/src/splat/segtypes/common/textbin.py b/src/splat/segtypes/common/textbin.py index e24eb6f7..d7895d5d 100644 --- a/src/splat/segtypes/common/textbin.py +++ b/src/splat/segtypes/common/textbin.py @@ -1,20 +1,23 @@ -from pathlib import Path +from __future__ import annotations import re -from typing import Optional, TextIO +from typing import TextIO, TYPE_CHECKING from ...util import log, options from .segment import CommonSegment +if TYPE_CHECKING: + from pathlib import Path + class CommonSegTextbin(CommonSegment): def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], + vram_start: int | None, args: list, yaml, ): @@ -38,10 +41,10 @@ def is_text() -> bool: def get_linker_section(self) -> str: return ".text" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "ax" - def out_path(self) -> Optional[Path]: + def out_path(self) -> Path | None: if self.use_src_path: return options.opts.src_path / self.dir / f"{self.name}.s" diff --git a/src/splat/segtypes/linker_entry.py b/src/splat/segtypes/linker_entry.py index 192a203d..76358ba8 100644 --- a/src/splat/segtypes/linker_entry.py +++ b/src/splat/segtypes/linker_entry.py @@ -1,17 +1,22 @@ +from __future__ import annotations + import os import re -from functools import lru_cache +from functools import cache from pathlib import Path -from typing import Dict, List, OrderedDict, Set, Tuple, Union, Optional +from typing import TYPE_CHECKING +from collections import OrderedDict from ..util import options, log -from .segment import Segment from ..util.symbols import to_cname +if TYPE_CHECKING: + from .segment import Segment + # clean 'foo/../bar' to 'bar' -@lru_cache(maxsize=None) +@cache def clean_up_path(path: Path) -> Path: path_resolved = path.resolve() base_resolved = options.opts.base_path.resolve() @@ -43,7 +48,7 @@ def path_to_object_path(path: Path) -> Path: return clean_up_path(path.with_suffix(full_suffix)) -def write_file_if_different(path: Path, new_content: str): +def write_file_if_different(path: Path, new_content: str) -> None: if path.exists(): old_content = path.read_text() else: @@ -121,7 +126,7 @@ class LinkerEntry: def __init__( self, segment: Segment, - src_paths: List[Path], + src_paths: list[Path], object_path: Path, section_order: str, section_link: str, @@ -133,23 +138,21 @@ def __init__( self.section_link = section_link self.noload = noload self.bss_contains_common = segment.bss_contains_common - self.object_path: Optional[Path] = path_to_object_path(object_path) + self.object_path: Path | None = path_to_object_path(object_path) @property def section_order_type(self) -> str: if self.section_order == ".rdata": return ".rodata" - else: - return self.section_order + return self.section_order @property def section_link_type(self) -> str: if self.section_link == ".rdata": return ".rodata" - else: - return self.section_link + return self.section_link - def emit_symbol_for_data(self, linker_writer: "LinkerWriter"): + def emit_symbol_for_data(self, linker_writer: LinkerWriter) -> None: if not options.opts.ld_generate_symbol_per_data_segment: return @@ -161,7 +164,7 @@ def emit_symbol_for_data(self, linker_writer: "LinkerWriter"): ) linker_writer._write_symbol(path_cname, ".") - def emit_path(self, linker_writer: "LinkerWriter"): + def emit_path(self, linker_writer: LinkerWriter) -> None: assert self.object_path is not None, ( f"{self.segment.name}, {self.segment.rom_start}" ) @@ -177,7 +180,7 @@ def emit_path(self, linker_writer: "LinkerWriter"): self.object_path, f"{self.section_link}{wildcard}" ) - def emit_entry(self, linker_writer: "LinkerWriter"): + def emit_entry(self, linker_writer: LinkerWriter) -> None: self.emit_symbol_for_data(linker_writer) self.emit_path(linker_writer) @@ -185,14 +188,14 @@ def emit_entry(self, linker_writer: "LinkerWriter"): class LinkerWriter: def __init__(self, is_partial: bool = False): self.linker_discard_section: bool = options.opts.ld_discard_section - self.sections_allowlist: List[str] = options.opts.ld_sections_allowlist - self.sections_denylist: List[str] = options.opts.ld_sections_denylist + self.sections_allowlist: list[str] = options.opts.ld_sections_allowlist + self.sections_denylist: list[str] = options.opts.ld_sections_denylist # Used to store all the linker entries - build tools may want this information - self.entries: List[LinkerEntry] = [] - self.dependencies_entries: List[LinkerEntry] = [] + self.entries: list[LinkerEntry] = [] + self.dependencies_entries: list[LinkerEntry] = [] - self.buffer: List[str] = [] - self.header_symbols: Set[str] = set() + self.buffer: list[str] = [] + self.header_symbols: set[str] = set() self.is_partial: bool = is_partial @@ -210,7 +213,7 @@ def __init__(self, is_partial: bool = False): self._writeln(f"_gp = 0x{options.opts.gp:X};") # Write a series of statements which compute a symbol that represents the highest address among a list of segments' end addresses - def write_max_vram_end_sym(self, symbol: str, overlays: List[Segment]): + def write_max_vram_end_sym(self, symbol: str, overlays: list[Segment]) -> None: for segment in overlays: if segment == overlays[0]: self._writeln( @@ -222,7 +225,9 @@ def write_max_vram_end_sym(self, symbol: str, overlays: List[Segment]): ) # Adds all the entries of a segment to the linker script buffer - def add(self, segment: Segment, max_vram_syms: List[Tuple[str, List[Segment]]]): + def add( + self, segment: Segment, max_vram_syms: list[tuple[str, list[Segment]]] + ) -> None: entries = segment.get_linker_entries() self.entries.extend(entries) self.dependencies_entries.extend(entries) @@ -236,7 +241,7 @@ def add(self, segment: Segment, max_vram_syms: List[Tuple[str, List[Segment]]]): self.add_legacy(segment, entries) return - section_entries: OrderedDict[str, List[LinkerEntry]] = OrderedDict() + section_entries: OrderedDict[str, list[LinkerEntry]] = OrderedDict() for section_name in segment.section_order: if section_name in options.opts.section_order: section_entries[section_name] = [] @@ -296,16 +301,16 @@ def add(self, segment: Segment, max_vram_syms: List[Tuple[str, List[Segment]]]): self._end_segment(segment, all_bss=not any_load) - def add_legacy(self, segment: Segment, entries: List[LinkerEntry]): + def add_legacy(self, segment: Segment, entries: list[LinkerEntry]) -> None: seg_name = segment.get_cname() # To keep track which sections has been started - started_sections: Dict[str, bool] = { + started_sections: dict[str, bool] = { section_name: False for section_name in options.opts.section_order } # Find where sections are last seen - last_seen_sections: Dict[LinkerEntry, str] = {} + last_seen_sections: dict[LinkerEntry, str] = {} for entry in reversed(entries): if ( entry.section_order_type in options.opts.section_order @@ -360,8 +365,8 @@ def add_legacy(self, segment: Segment, entries: List[LinkerEntry]): self._end_segment(segment, all_bss=False) def add_referenced_partial_segment( - self, segment: Segment, max_vram_syms: List[Tuple[str, List[Segment]]] - ): + self, segment: Segment, max_vram_syms: list[tuple[str, list[Segment]]] + ) -> None: entries = segment.get_linker_entries() self.entries.extend(entries) @@ -436,14 +441,14 @@ def add_referenced_partial_segment( self._end_segment(segment, all_bss=not any_load) - def add_partial_segment(self, segment: Segment): + def add_partial_segment(self, segment: Segment) -> None: entries = segment.get_linker_entries() self.entries.extend(entries) self.dependencies_entries.extend(entries) seg_name = segment.get_cname() - section_entries: OrderedDict[str, List[LinkerEntry]] = OrderedDict() + section_entries: OrderedDict[str, list[LinkerEntry]] = OrderedDict() for section_name in segment.section_order: if section_name in options.opts.section_order: section_entries[section_name] = [] @@ -482,7 +487,7 @@ def add_partial_segment(self, segment: Segment): self._end_partial_segment(section_name) - def save_linker_script(self, output_path: Path): + def save_linker_script(self, output_path: Path) -> None: if len(self.sections_allowlist) > 0: address = " 0" if self.is_partial: @@ -510,7 +515,7 @@ def save_linker_script(self, output_path: Path): write_file_if_different(output_path, "\n".join(self.buffer) + "\n") - def save_symbol_header(self): + def save_symbol_header(self) -> None: path = options.opts.ld_symbol_header_path if path: @@ -528,7 +533,7 @@ def save_symbol_header(self): "#endif\n", ) - def save_dependencies_file(self, output_path: Path, target_elf_path: Path): + def save_dependencies_file(self, output_path: Path, target_elf_path: Path) -> None: output = f"{clean_up_path(target_elf_path).as_posix()}:" for entry in self.dependencies_entries: @@ -543,21 +548,21 @@ def save_dependencies_file(self, output_path: Path, target_elf_path: Path): output += f"{entry.object_path.as_posix()}:\n" write_file_if_different(output_path, output) - def _writeln(self, line: str): + def _writeln(self, line: str) -> None: if len(line) == 0: self.buffer.append(line) else: self.buffer.append(" " * self._indent_level + line) - def _begin_block(self): + def _begin_block(self) -> None: self._writeln("{") self._indent_level += 1 - def _end_block(self): + def _end_block(self) -> None: self._indent_level -= 1 self._writeln("}") - def _write_symbol(self, symbol: str, value: Union[str, int]): + def _write_symbol(self, symbol: str, value: str | int) -> None: symbol = to_cname(symbol) if isinstance(value, int): @@ -567,12 +572,12 @@ def _write_symbol(self, symbol: str, value: Union[str, int]): self.header_symbols.add(symbol) - def _write_object_path_section(self, object_path: Path, section: str): + def _write_object_path_section(self, object_path: Path, section: str) -> None: self._writeln(f"{object_path.as_posix()}({section});") def _begin_segment( self, segment: Segment, seg_name: str, noload: bool, is_first: bool - ): + ) -> None: if ( options.opts.ld_use_symbolic_vram_addresses and segment.vram_symbol is not None @@ -608,7 +613,7 @@ def _begin_segment( if segment.ld_fill_value is not None: self._writeln(f"FILL(0x{segment.ld_fill_value:08X});") - def _end_segment(self, segment: Segment, all_bss=False): + def _end_segment(self, segment: Segment, all_bss: bool = False) -> None: self._end_block() name = segment.get_cname() @@ -636,7 +641,9 @@ def _end_segment(self, segment: Segment, all_bss=False): self._writeln("") - def _begin_partial_segment(self, section_name: str, segment: Segment, noload: bool): + def _begin_partial_segment( + self, section_name: str, segment: Segment, noload: bool + ) -> None: line = f"{section_name}" if noload: line += " (NOLOAD)" @@ -647,7 +654,7 @@ def _begin_partial_segment(self, section_name: str, segment: Segment, noload: bo self._writeln(line) self._begin_block() - def _end_partial_segment(self, section_name: str, all_bss=False): + def _end_partial_segment(self, section_name: str, all_bss: bool = False) -> None: self._end_block() self._writeln("") @@ -672,10 +679,10 @@ def _write_segment_sections( self, segment: Segment, seg_name: str, - section_entries: OrderedDict[str, List[LinkerEntry]], + section_entries: OrderedDict[str, list[LinkerEntry]], noload: bool, is_first: bool, - ): + ) -> None: if not is_first: self._end_block() diff --git a/src/splat/segtypes/n64/ci.py b/src/splat/segtypes/n64/ci.py index 68e6e4da..85b9136e 100644 --- a/src/splat/segtypes/n64/ci.py +++ b/src/splat/segtypes/n64/ci.py @@ -1,18 +1,23 @@ -from pathlib import Path -from typing import List, TYPE_CHECKING +from __future__ import annotations + +from typing import TYPE_CHECKING from ...util import log, options from .img import N64SegImg if TYPE_CHECKING: + from pathlib import Path from .palette import N64SegPalette + from ...utils.vram_classes import SerializedSegmentData # Base class for CI4/CI8 class N64SegCi(N64SegImg): - def parse_palette_names(self, yaml, args) -> List[str]: - ret = [self.name] + def parse_palette_names( + self, yaml: SerializedSegmentData | list[str], args: list[str] + ) -> list[str]: + ret: list[str] | str = [self.name] if isinstance(yaml, dict): if "palettes" in yaml: ret = yaml["palettes"] @@ -20,19 +25,19 @@ def parse_palette_names(self, yaml, args) -> List[str]: ret = args[2] if isinstance(ret, str): - ret = [ret] + return [ret] return ret - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.palettes: "List[N64SegPalette]" = [] + self.palettes: list[N64SegPalette] = [] self.palette_names = self.parse_palette_names(self.yaml, self.args) def scan(self, rom_bytes: bytes) -> None: self.n64img.data = rom_bytes[self.rom_start : self.rom_end] - def out_path_pal(self, pal_name) -> Path: + def out_path_pal(self, pal_name: str) -> Path: type_extension = f".{self.type}" if options.opts.image_type_in_extension else "" if len(self.palettes) == 1: @@ -47,7 +52,7 @@ def out_path_pal(self, pal_name) -> Path: return options.opts.asset_path / self.dir / f"{out_name}{type_extension}.png" - def split(self, rom_bytes): + def split(self, rom_bytes: bytes) -> None: self.check_len() assert self.palettes is not None diff --git a/src/splat/segtypes/n64/gfx.py b/src/splat/segtypes/n64/gfx.py index dbf0afd3..4cb7761b 100644 --- a/src/splat/segtypes/n64/gfx.py +++ b/src/splat/segtypes/n64/gfx.py @@ -3,10 +3,11 @@ Dumps out Gfx[] as a .inc.c file. """ +from __future__ import annotations + import re -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING -from pathlib import Path from pygfxd import ( gfxd_buffer_to_string, @@ -44,20 +45,24 @@ from ...util import symbols +if TYPE_CHECKING: + from pathlib import Path + from ...util.vram_classes import SerializedSegmentData + LIGHTS_RE = re.compile(r"\*\(Lightsn \*\)0x[0-9A-F]{8}") class N64SegGfx(CommonSegCodeSubsegment): def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], - args: list, - yaml, - ): + vram_start: int | None, + args: list[str], + yaml: SerializedSegmentData | list[str], + ) -> None: super().__init__( rom_start, rom_end, @@ -67,11 +72,11 @@ def __init__( args=args, yaml=yaml, ) - self.file_text = None + self.file_text: str | None = None self.data_only = isinstance(yaml, dict) and yaml.get("data_only", False) self.in_segment = not isinstance(yaml, dict) or yaml.get("in_segment", True) - def format_sym_name(self, sym) -> str: + def format_sym_name(self, sym: symbols.Symbol) -> str: return sym.name def get_linker_section(self) -> str: @@ -80,54 +85,55 @@ def get_linker_section(self) -> str: def out_path(self) -> Path: return options.opts.asset_path / self.dir / f"{self.name}.gfx.inc.c" - def scan(self, rom_bytes: bytes): + def scan(self, rom_bytes: bytes) -> None: self.file_text = self.disassemble_data(rom_bytes) - def get_gfxd_target(self): + def get_gfxd_target(self) -> gfxd_f3d: # noqa: RET503 opt = options.opts.gfx_ucode if opt == "f3d": return gfxd_f3d - elif opt == "f3db": + if opt == "f3db": return gfxd_f3db - elif opt == "f3dex": + if opt == "f3dex": return gfxd_f3dex - elif opt == "f3dexb": + if opt == "f3dexb": return gfxd_f3dexb - elif opt == "f3dex2": + if opt == "f3dex2": return gfxd_f3dex2 - else: - log.error(f"Unknown target {opt}") + log.error(f"Unknown target {opt}") - def tlut_handler(self, addr, idx, count): + def tlut_handler(self, addr: int, idx: int, count: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(self.format_sym_name(sym)) return 1 - def timg_handler(self, addr, fmt, size, width, height, pal): + def timg_handler( + self, addr: int, fmt, size: int, width: int, height: int, pal + ) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(self.format_sym_name(sym)) return 1 - def cimg_handler(self, addr, fmt, size, width): + def cimg_handler(self, addr: int, fmt, size: int, width: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(self.format_sym_name(sym)) return 1 - def zimg_handler(self, addr): + def zimg_handler(self, addr: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(self.format_sym_name(sym)) return 1 - def dl_handler(self, addr): + def dl_handler(self, addr: int) -> int: # Look for 'Gfx'-typed symbols first sym = self.retrieve_sym_type(symbols.all_symbols_dict, addr, "Gfx") @@ -138,28 +144,28 @@ def dl_handler(self, addr): gfxd_printf(self.format_sym_name(sym)) return 1 - def mtx_handler(self, addr): + def mtx_handler(self, addr: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(f"&{self.format_sym_name(sym)}") return 1 - def lookat_handler(self, addr, count): + def lookat_handler(self, addr: int, count: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(self.format_sym_name(sym)) return 1 - def light_handler(self, addr, count): + def light_handler(self, addr: int, count: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(self.format_sym_name(sym)) return 1 - def vtx_handler(self, addr, count): + def vtx_handler(self, addr: int, count: int) -> int: # Look for 'Vtx'-typed symbols first sym = self.retrieve_sym_type(symbols.all_symbols_dict, addr, "Vtx") @@ -176,20 +182,20 @@ def vtx_handler(self, addr, count): gfxd_printf(f"&{self.format_sym_name(sym)}[{index}]") return 1 - def vp_handler(self, addr): + def vp_handler(self, addr: int) -> int: sym = self.create_symbol( addr=addr, in_segment=self.in_segment, type="data", reference=True ) gfxd_printf(f"&{self.format_sym_name(sym)}") return 1 - def macro_fn(self): + def macro_fn(self) -> int: gfxd_puts(" ") gfxd_macro_dflt() gfxd_puts(",\n") return 0 - def disassemble_data(self, rom_bytes): + def disassemble_data(self, rom_bytes: bytes) -> str: assert isinstance(self.rom_start, int) assert isinstance(self.rom_end, int) assert isinstance(self.vram_start, int) @@ -246,7 +252,7 @@ def disassemble_data(self, rom_bytes): out_str += "};\n" # Poor man's light fix until we get my libgfxd PR merged - def light_sub_func(match): + def light_sub_func(match: re.Match[str]) -> str: light = match.group(0) addr = int(light[12:], 0) sym = self.create_symbol( @@ -254,15 +260,14 @@ def light_sub_func(match): ) return self.format_sym_name(sym) - out_str = re.sub(LIGHTS_RE, light_sub_func, out_str) - - return out_str + return re.sub(LIGHTS_RE, light_sub_func, out_str) - def split(self, rom_bytes: bytes): - if self.file_text and self.out_path(): - self.out_path().parent.mkdir(parents=True, exist_ok=True) + def split(self, rom_bytes: bytes) -> None: + out_path = self.out_path() + if self.file_text and out_path is not None: + out_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.out_path(), "w", newline="\n") as f: + with open(self.out_path(), "w", encoding="utf-8", newline="\n") as f: f.write(self.file_text) def should_scan(self) -> bool: @@ -276,7 +281,7 @@ def should_split(self) -> bool: return self.extract and options.opts.is_mode_active("gfx") @staticmethod - def estimate_size(yaml: Union[Dict, List]) -> Optional[int]: + def estimate_size(yaml: SerializedSegmentData | list[str]) -> int | None: if isinstance(yaml, dict) and "length" in yaml: - return yaml["length"] * 0x10 + return int(yaml["length"]) * 0x10 return None diff --git a/src/splat/segtypes/n64/img.py b/src/splat/segtypes/n64/img.py index 6bb09a68..1bb1b0bd 100644 --- a/src/splat/segtypes/n64/img.py +++ b/src/splat/segtypes/n64/img.py @@ -1,33 +1,37 @@ -from pathlib import Path -from typing import Dict, List, Tuple, Type, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING -from n64img.image import Image from ...util import log, options from ..segment import Segment +if TYPE_CHECKING: + from n64img.image import Image + from pathlib import Path + from ...util.vram_classes import SerializedSegmentData + class N64SegImg(Segment): @staticmethod - def parse_dimensions(yaml: Union[Dict, List]) -> Tuple[int, int]: + def parse_dimensions(yaml: SerializedSegmentData | list[str]) -> tuple[int, int]: if isinstance(yaml, dict): return yaml["width"], yaml["height"] - else: - if len(yaml) < 5: - log.error(f"Error: {yaml} is missing width and height parameters") - return yaml[3], yaml[4] + if len(yaml) < 5: + log.error(f"Error: {yaml} is missing width and height parameters") + return int(yaml[3]), int(yaml[4]) def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], - args: list, - yaml, - img_cls: Type[Image], - ): + vram_start: int | None, + args: list[str], + yaml: SerializedSegmentData | list[str], + img_cls: type[Image], + ) -> None: super().__init__( rom_start, rom_end, @@ -77,7 +81,7 @@ def out_path(self) -> Path: def should_split(self) -> bool: return options.opts.is_mode_active("img") - def split(self, rom_bytes): + def split(self, rom_bytes: bytes) -> None: self.check_len() path = self.out_path() @@ -90,15 +94,14 @@ def split(self, rom_bytes): self.log(f"Wrote {self.name} to {path}") @staticmethod - def estimate_size(yaml: Union[Dict, List]) -> int: + def estimate_size(yaml: SerializedSegmentData | list[str]) -> int: width, height = N64SegImg.parse_dimensions(yaml) typ = Segment.parse_segment_type(yaml) if typ == "ci4" or typ == "i4" or typ == "ia4": return width * height // 2 - elif typ in ("ia16", "rgba16"): + if typ in ("ia16", "rgba16"): return width * height * 2 - elif typ == "rgba32": + if typ == "rgba32": return width * height * 4 - else: - return width * height + return width * height diff --git a/src/splat/segtypes/n64/palette.py b/src/splat/segtypes/n64/palette.py index de0d830b..4c58dab1 100644 --- a/src/splat/segtypes/n64/palette.py +++ b/src/splat/segtypes/n64/palette.py @@ -1,20 +1,31 @@ +from __future__ import annotations + from itertools import zip_longest -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, TypeVar from ...util import log, options from ...util.color import unpack_color from ..segment import Segment +if TYPE_CHECKING: + from pathlib import Path + from collections.abc import Iterable + from typing import Final + + from ..linker_entry import LinkerEntry + from ...util.vram_classes import SerializedSegmentData + +T = TypeVar("T") + -VALID_SIZES = [0x20, 0x40, 0x80, 0x100, 0x200] +VALID_SIZES: Final = (0x20, 0x40, 0x80, 0x100, 0x200) class N64SegPalette(Segment): require_unique_name = False - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) if self.extract: @@ -58,7 +69,7 @@ def __init__(self, *args, **kwargs): size = 0 self.palette_size: int = size - self.global_id: Optional[str] = ( + self.global_id: str | None = ( self.yaml.get("global_id") if isinstance(self.yaml, dict) else None ) @@ -66,8 +77,10 @@ def get_cname(self) -> str: return super().get_cname() + "_pal" @staticmethod - def parse_palette_bytes(data) -> List[Tuple[int, int, int, int]]: - def iter_in_groups(iterable, n, fillvalue=None): + def parse_palette_bytes(data: bytes) -> list[tuple[int, int, int, int]]: + def iter_in_groups( + iterable: Iterable[T], n: int, fillvalue: object = None + ) -> zip_longest[tuple[T, ...]]: args = [iter(iterable)] * n return zip_longest(*args, fillvalue=fillvalue) @@ -78,7 +91,7 @@ def iter_in_groups(iterable, n, fillvalue=None): return palette - def parse_palette(self, rom_bytes) -> List[Tuple[int, int, int, int]]: + def parse_palette(self, rom_bytes: bytes) -> list[tuple[int, int, int, int]]: assert self.rom_start is not None data = rom_bytes[self.rom_start : self.rom_start + self.palette_size] @@ -91,7 +104,7 @@ def out_path(self) -> Path: return options.opts.asset_path / self.dir / f"{self.name}.png" # TODO NEED NAMES... - def get_linker_entries(self): + def get_linker_entries(self) -> list[LinkerEntry]: from ..linker_entry import LinkerEntry return [ @@ -106,7 +119,7 @@ def get_linker_entries(self): ] @staticmethod - def estimate_size(yaml: Union[Dict, List]) -> int: + def estimate_size(yaml: SerializedSegmentData | list[str]) -> int: if isinstance(yaml, dict): if "size" in yaml: return int(yaml["size"]) diff --git a/src/splat/segtypes/n64/vtx.py b/src/splat/segtypes/n64/vtx.py index acf4bcbb..5911fc39 100644 --- a/src/splat/segtypes/n64/vtx.py +++ b/src/splat/segtypes/n64/vtx.py @@ -5,26 +5,32 @@ Originally written by Mark Street (https://github.com/mkst) """ +from __future__ import annotations + import struct -from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING from ...util import options, log from ..common.codesubsegment import CommonSegCodeSubsegment +if TYPE_CHECKING: + from pathlib import Path + from ...util.vram_classes import SerializedSegmentData + from ...util.symbols import Symbol + class N64SegVtx(CommonSegCodeSubsegment): def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], - args: list, - yaml, - ): + vram_start: int | None, + args: list[str], + yaml: SerializedSegmentData | list[str], + ) -> None: super().__init__( rom_start, rom_end, @@ -34,10 +40,10 @@ def __init__( args=args, yaml=yaml, ) - self.file_text: Optional[str] = None + self.file_text: str | None = None self.data_only = isinstance(yaml, dict) and yaml.get("data_only", False) - def format_sym_name(self, sym) -> str: + def format_sym_name(self, sym: Symbol) -> str: return sym.name def get_linker_section(self) -> str: @@ -46,10 +52,10 @@ def get_linker_section(self) -> str: def out_path(self) -> Path: return options.opts.asset_path / self.dir / f"{self.name}.vtx.inc.c" - def scan(self, rom_bytes: bytes): + def scan(self, rom_bytes: bytes) -> None: self.file_text = self.disassemble_data(rom_bytes) - def disassemble_data(self, rom_bytes) -> str: + def disassemble_data(self, rom_bytes: bytes) -> str: assert isinstance(self.rom_start, int) assert isinstance(self.rom_end, int) assert isinstance(self.vram_start, int) @@ -88,11 +94,12 @@ def disassemble_data(self, rom_bytes) -> str: lines.append("") return "\n".join(lines) - def split(self, rom_bytes: bytes): - if self.file_text and self.out_path(): - self.out_path().parent.mkdir(parents=True, exist_ok=True) + def split(self, rom_bytes: bytes) -> None: + out_path = self.out_path() + if self.file_text and out_path is not None: + out_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.out_path(), "w", newline="\n") as f: + with open(out_path, "w", encoding="utf-8", newline="\n") as f: f.write(self.file_text) def should_scan(self) -> bool: @@ -102,7 +109,7 @@ def should_split(self) -> bool: return self.extract and options.opts.is_mode_active("vtx") @staticmethod - def estimate_size(yaml: Union[Dict, List]) -> Optional[int]: + def estimate_size(yaml: SerializedSegmentData | list[str]) -> int | None: if isinstance(yaml, dict) and "length" in yaml: return yaml["length"] * 0x10 return None diff --git a/src/splat/segtypes/ps2/ctor.py b/src/splat/segtypes/ps2/ctor.py index 003059ac..192b6ea6 100644 --- a/src/splat/segtypes/ps2/ctor.py +++ b/src/splat/segtypes/ps2/ctor.py @@ -1,7 +1,10 @@ -from typing import Optional +from __future__ import annotations from ..common.data import CommonSegData -from ...disassembler.disassembler_section import DisassemblerSection +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...disassembler.disassembler_section import DisassemblerSection class Ps2SegCtor(CommonSegData): @@ -10,7 +13,7 @@ class Ps2SegCtor(CommonSegData): def get_linker_section(self) -> str: return ".ctor" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "a" def configure_disassembler_section( @@ -21,6 +24,7 @@ def configure_disassembler_section( super().configure_disassembler_section(disassembler_section) section = disassembler_section.get_section() + assert section is not None # We use s32 to make sure spimdisasm disassembles the data from this section as words/references to other symbols section.enableStringGuessing = False diff --git a/src/splat/segtypes/ps2/lit4.py b/src/splat/segtypes/ps2/lit4.py index d2f1d581..064deab3 100644 --- a/src/splat/segtypes/ps2/lit4.py +++ b/src/splat/segtypes/ps2/lit4.py @@ -1,7 +1,10 @@ -from typing import Optional +from __future__ import annotations from ..common.data import CommonSegData -from ...disassembler.disassembler_section import DisassemblerSection +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...disassembler.disassembler_section import DisassemblerSection class Ps2SegLit4(CommonSegData): @@ -10,7 +13,7 @@ class Ps2SegLit4(CommonSegData): def get_linker_section(self) -> str: return ".lit4" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "wa" def configure_disassembler_section( @@ -21,6 +24,7 @@ def configure_disassembler_section( super().configure_disassembler_section(disassembler_section) section = disassembler_section.get_section() + assert section is not None # Tell spimdisasm this section only contains floats section.enableStringGuessing = False diff --git a/src/splat/segtypes/ps2/lit8.py b/src/splat/segtypes/ps2/lit8.py index 81b5cf99..462d34dc 100644 --- a/src/splat/segtypes/ps2/lit8.py +++ b/src/splat/segtypes/ps2/lit8.py @@ -1,7 +1,10 @@ -from typing import Optional +from __future__ import annotations from ..common.data import CommonSegData -from ...disassembler.disassembler_section import DisassemblerSection +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...disassembler.disassembler_section import DisassemblerSection class Ps2SegLit8(CommonSegData): @@ -10,7 +13,7 @@ class Ps2SegLit8(CommonSegData): def get_linker_section(self) -> str: return ".lit8" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "wa" def configure_disassembler_section( @@ -21,6 +24,7 @@ def configure_disassembler_section( super().configure_disassembler_section(disassembler_section) section = disassembler_section.get_section() + assert section is not None # Tell spimdisasm this section only contains doubles section.enableStringGuessing = False diff --git a/src/splat/segtypes/ps2/vtables.py b/src/splat/segtypes/ps2/vtables.py index afbdbd17..25a29dcd 100644 --- a/src/splat/segtypes/ps2/vtables.py +++ b/src/splat/segtypes/ps2/vtables.py @@ -1,7 +1,10 @@ -from typing import Optional +from __future__ import annotations from ..common.data import CommonSegData -from ...disassembler.disassembler_section import DisassemblerSection +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...disassembler.disassembler_section import DisassemblerSection class Ps2SegVtables(CommonSegData): @@ -10,7 +13,7 @@ class Ps2SegVtables(CommonSegData): def get_linker_section(self) -> str: return ".vtables" - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: return "a" def configure_disassembler_section( @@ -21,6 +24,7 @@ def configure_disassembler_section( super().configure_disassembler_section(disassembler_section) section = disassembler_section.get_section() + assert section is not None # We use s32 to make sure spimdisasm disassembles the data from this section as words/references to other symbols section.enableStringGuessing = False diff --git a/src/splat/segtypes/psx/header.py b/src/splat/segtypes/psx/header.py index 6b0990cb..fc61059f 100644 --- a/src/splat/segtypes/psx/header.py +++ b/src/splat/segtypes/psx/header.py @@ -4,7 +4,7 @@ class PsxSegHeader(CommonSegHeader): # little endian so reverse words, TODO: use struct.unpack(" list[str]: header_lines = [] header_lines.append(".section .data\n") header_lines.append( diff --git a/src/splat/segtypes/segment.py b/src/splat/segtypes/segment.py index 784ca4a2..9d744df7 100644 --- a/src/splat/segtypes/segment.py +++ b/src/splat/segtypes/segment.py @@ -1,15 +1,18 @@ +from __future__ import annotations + import collections import dataclasses import importlib import importlib.util from pathlib import Path -from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING, Union, Tuple +from typing import TYPE_CHECKING, Union from intervaltree import Interval, IntervalTree from ..util import vram_classes -from ..util.vram_classes import VramClass + +from ..util.vram_classes import VramClass, SerializedSegmentData from ..util import log, options, symbols from ..util.symbols import Symbol, to_cname @@ -18,54 +21,57 @@ # circular import if TYPE_CHECKING: from ..segtypes.linker_entry import LinkerEntry + from typing_extensions import TypeAlias + +SerializedSegment: TypeAlias = Union[SerializedSegmentData, list[str]] -def parse_segment_vram(segment: Union[dict, list]) -> Optional[int]: + +def parse_segment_vram(segment: SerializedSegment) -> int | None: if isinstance(segment, dict) and "vram" in segment: return int(segment["vram"]) - else: - return None + return None -def parse_segment_vram_symbol(segment: Union[dict, list]) -> Optional[str]: +def parse_segment_vram_symbol(segment: SerializedSegment) -> str | None: if isinstance(segment, dict) and "vram_symbol" in segment: return str(segment["vram_symbol"]) - else: - return None + return None -def parse_segment_vram_class(segment: Union[dict, list]) -> Optional[VramClass]: +def parse_segment_vram_class(segment: SerializedSegment) -> VramClass | None: if isinstance(segment, dict) and "vram_class" in segment: return vram_classes.resolve(segment["vram_class"]) return None -def parse_segment_follows_vram(segment: Union[dict, list]) -> Optional[str]: +def parse_segment_follows_vram(segment: SerializedSegment) -> str | None: if isinstance(segment, dict): return segment.get("follows_vram", None) return None -def parse_segment_align(segment: Union[dict, list]) -> Optional[int]: +def parse_segment_align(segment: SerializedSegment) -> int | None: if isinstance(segment, dict) and "align" in segment: return int(segment["align"]) return None -def parse_segment_subalign(segment: Union[dict, list]) -> Optional[int]: +def parse_segment_subalign(segment: SerializedSegment) -> int | None: default = options.opts.subalign if isinstance(segment, dict): - subalign = segment.get("subalign", default) + subalign: int | str | None = segment.get("subalign", default) if subalign is not None: - subalign = int(subalign) - return subalign + return int(subalign) + return None return default -def parse_segment_section_order(segment: Union[dict, list]) -> List[str]: +def parse_segment_section_order(segment: SerializedSegment) -> list[str]: default = options.opts.section_order if isinstance(segment, dict): - return segment.get("section_order", default) + section_order: list[str] = segment.get("section_order", default) + return section_order return default @@ -77,7 +83,7 @@ class SegmentStatisticsInfo: size: int count: int - def merge(self, other: "SegmentStatisticsInfo") -> "SegmentStatisticsInfo": + def merge(self, other: SegmentStatisticsInfo) -> SegmentStatisticsInfo: return SegmentStatisticsInfo( size=self.size + other.size, count=self.count + other.count ) @@ -94,7 +100,7 @@ class Segment: require_unique_name = True @staticmethod - def get_class_for_type(seg_type) -> Type["Segment"]: + def get_class_for_type(seg_type: str) -> type[Segment]: # so .data loads SegData, for example if seg_type.startswith("."): seg_type = seg_type[1:] @@ -117,7 +123,7 @@ def get_class_for_type(seg_type) -> Type["Segment"]: return segment_class @staticmethod - def get_base_segment_class(seg_type): + def get_base_segment_class(seg_type: str) -> type[Segment] | None: platform = options.opts.platform is_platform_seg = False @@ -136,10 +142,12 @@ def get_base_segment_class(seg_type): return None seg_prefix = platform.capitalize() if is_platform_seg else "Common" - return getattr(segmodule, f"{seg_prefix}Seg{seg_type.capitalize()}") + return getattr( # type: ignore[no-any-return] + segmodule, f"{seg_prefix}Seg{seg_type.capitalize()}" + ) @staticmethod - def get_extension_segment_class(seg_type): + def get_extension_segment_class(seg_type: str) -> type[Segment] | None: platform = options.opts.platform ext_path = options.opts.extensions_path @@ -161,12 +169,12 @@ def get_extension_segment_class(seg_type): except Exception: return None - return getattr( + return getattr( # type: ignore[no-any-return] ext_mod, f"{platform.upper()}Seg{seg_type[0].upper()}{seg_type[1:]}" ) @staticmethod - def parse_segment_start(segment: Union[dict, list]) -> Tuple[Optional[int], bool]: + def parse_segment_start(segment: SerializedSegment) -> tuple[int | None, bool]: """ Parses the rom start address of a given segment. @@ -178,7 +186,7 @@ def parse_segment_start(segment: Union[dict, list]) -> Tuple[Optional[int], bool """ if isinstance(segment, dict): - s = segment.get("start", None) + s: str | None = segment.get("start", None) else: s = segment[0] @@ -186,85 +194,82 @@ def parse_segment_start(segment: Union[dict, list]) -> Tuple[Optional[int], bool return None, False if s == "auto": return None, True - else: - return int(s), False + return int(s), False @staticmethod - def parse_segment_type(segment: Union[dict, list]) -> str: + def parse_segment_type(segment: SerializedSegment) -> str: if isinstance(segment, dict): return str(segment["type"]) - else: - return str(segment[1]) + return str(segment[1]) - @staticmethod - def parse_segment_name(cls, rom_start, segment: Union[dict, list]) -> str: - if isinstance(segment, dict) and "name" in segment: - return str(segment["name"]) - elif isinstance(segment, dict) and "dir" in segment: - return str(segment["dir"]) + @classmethod + def parse_segment_name( + cls, rom_start: int | None, segment: SerializedSegment + ) -> str: + if isinstance(segment, dict): + if "name" in segment: + return str(segment["name"]) + if "dir" in segment: + return str(segment["dir"]) elif isinstance(segment, list) and len(segment) >= 3: return str(segment[2]) - else: - return str(cls.get_default_name(rom_start)) + assert rom_start is not None + return str(cls.get_default_name(rom_start)) @staticmethod - def parse_segment_symbol_name_format(segment: Union[dict, list]) -> str: + def parse_segment_symbol_name_format(segment: SerializedSegment) -> str: if isinstance(segment, dict) and "symbol_name_format" in segment: return str(segment["symbol_name_format"]) - else: - return options.opts.symbol_name_format + return options.opts.symbol_name_format @staticmethod - def parse_segment_symbol_name_format_no_rom(segment: Union[dict, list]) -> str: + def parse_segment_symbol_name_format_no_rom(segment: SerializedSegment) -> str: if isinstance(segment, dict) and "symbol_name_format_no_rom" in segment: return str(segment["symbol_name_format_no_rom"]) - else: - return options.opts.symbol_name_format_no_rom + return options.opts.symbol_name_format_no_rom @staticmethod - def parse_segment_file_path(segment: Union[dict, list]) -> Optional[Path]: + def parse_segment_file_path(segment: SerializedSegment) -> Path | None: if isinstance(segment, dict) and "path" in segment: return Path(segment["path"]) return None @staticmethod def parse_segment_bss_contains_common( - segment: Union[dict, list], default: bool + segment: SerializedSegment, default: bool ) -> bool: if isinstance(segment, dict) and "bss_contains_common" in segment: return bool(segment["bss_contains_common"]) return default @staticmethod - def parse_linker_section_order(yaml: Union[dict, list]) -> Optional[str]: + def parse_linker_section_order(yaml: SerializedSegment) -> str | None: if isinstance(yaml, dict) and "linker_section_order" in yaml: return str(yaml["linker_section_order"]) return None @staticmethod - def parse_linker_section(yaml: Union[dict, list]) -> Optional[str]: + def parse_linker_section(yaml: SerializedSegment) -> str | None: if isinstance(yaml, dict) and "linker_section" in yaml: return str(yaml["linker_section"]) return None @staticmethod - def parse_ld_fill_value( - yaml: Union[dict, list], default: Optional[int] - ) -> Optional[int]: + def parse_ld_fill_value(yaml: SerializedSegment, default: int | None) -> int | None: if isinstance(yaml, dict) and "ld_fill_value" in yaml: return yaml["ld_fill_value"] return default @staticmethod - def parse_ld_align_segment_start(yaml: Union[dict, list]) -> Optional[int]: + def parse_ld_align_segment_start(yaml: SerializedSegment) -> int | None: if isinstance(yaml, dict) and "ld_align_segment_start" in yaml: return yaml["ld_align_segment_start"] return options.opts.ld_align_segment_start @staticmethod def parse_suggestion_rodata_section_start( - yaml: Union[dict, list], - ) -> Optional[bool]: + yaml: SerializedSegment, + ) -> bool | None: if isinstance(yaml, dict): suggestion_rodata_section_start = yaml.get( "suggestion_rodata_section_start" @@ -275,62 +280,62 @@ def parse_suggestion_rodata_section_start( return None @staticmethod - def parse_pair_segment(yaml: Union[dict, list]) -> Optional[str]: + def parse_pair_segment(yaml: SerializedSegment) -> str | None: if isinstance(yaml, dict) and "pair_segment" in yaml: return yaml["pair_segment"] return None def __init__( self, - rom_start: Optional[int], - rom_end: Optional[int], + rom_start: int | None, + rom_end: int | None, type: str, name: str, - vram_start: Optional[int], - args: list, - yaml, - ): + vram_start: int | None, + args: list[str], + yaml: SerializedSegment, + ) -> None: self.rom_start = rom_start self.rom_end = rom_end self.type = type self.name = name - self.vram_start: Optional[int] = vram_start + self.vram_start: int | None = vram_start - self.align: Optional[int] = None - self.given_subalign: Optional[int] = options.opts.subalign - self.exclusive_ram_id: Optional[str] = None + self.align: int | None = None + self.given_subalign: int | None = options.opts.subalign + self.exclusive_ram_id: str | None = None self.given_dir: Path = Path() # Default to global options. - self.given_find_file_boundaries: Optional[bool] = None + self.given_find_file_boundaries: bool | None = None # Symbols known to be in this segment - self.given_seg_symbols: Dict[int, List[Symbol]] = {} + self.given_seg_symbols: dict[int, list[Symbol]] = {} # Ranges for faster symbol lookup self.symbol_ranges_ram: IntervalTree = IntervalTree() self.symbol_ranges_rom: IntervalTree = IntervalTree() - self.given_section_order: List[str] = options.opts.section_order + self.given_section_order: list[str] = options.opts.section_order - self.vram_class: Optional[VramClass] = None - self.given_follows_vram: Optional[str] = None - self.given_vram_symbol: Optional[str] = None + self.vram_class: VramClass | None = None + self.given_follows_vram: str | None = None + self.given_vram_symbol: str | None = None self.given_symbol_name_format: str = options.opts.symbol_name_format self.given_symbol_name_format_no_rom: str = ( options.opts.symbol_name_format_no_rom ) - self.parent: Optional[Segment] = None - self.sibling: Optional[Segment] = None - self.siblings: Dict[str, Segment] = {} - self.pair_segment_name: Optional[str] = self.parse_pair_segment(yaml) - self.paired_segment: Optional[Segment] = None + self.parent: Segment | None = None + self.sibling: Segment | None = None + self.siblings: dict[str, Segment] = {} + self.pair_segment_name: str | None = self.parse_pair_segment(yaml) + self.paired_segment: Segment | None = None - self.file_path: Optional[Path] = None + self.file_path: Path | None = None - self.args: List[str] = args + self.args: list[str] = args self.yaml = yaml self.extract: bool = True @@ -340,7 +345,7 @@ def __init__( elif self.type.startswith("."): self.extract = False - self.warnings: List[str] = [] + self.warnings: list[str] = [] self.did_run = False self.bss_contains_common = Segment.parse_segment_bss_contains_common( yaml, options.opts.ld_bss_contains_common @@ -349,29 +354,29 @@ def __init__( # For segments which are not in the usual VRAM segment space, like N64's IPL3 which lives in 0xA4... self.special_vram_segment: bool = False - self.linker_section_order: Optional[str] = self.parse_linker_section_order(yaml) - self.linker_section: Optional[str] = self.parse_linker_section(yaml) + self.linker_section_order: str | None = self.parse_linker_section_order(yaml) + self.linker_section: str | None = self.parse_linker_section(yaml) # If not defined on the segment then default to the global option - self.ld_fill_value: Optional[int] = self.parse_ld_fill_value( + self.ld_fill_value: int | None = self.parse_ld_fill_value( yaml, options.opts.ld_fill_value ) - self.ld_align_segment_start: Optional[int] = self.parse_ld_align_segment_start( + self.ld_align_segment_start: int | None = self.parse_ld_align_segment_start( yaml ) # True if this segment was generated based on auto_link_sections self.is_generated: bool = False - self.given_suggestion_rodata_section_start: Optional[bool] = ( + self.given_suggestion_rodata_section_start: bool | None = ( self.parse_suggestion_rodata_section_start(yaml) ) # Is an automatic segment, generated automatically or declared on the yaml by the user self.is_auto_segment: bool = False - self.index_within_group: Optional[int] = None + self.index_within_group: int | None = None if self.rom_start is not None and self.rom_end is not None: if self.rom_start > self.rom_end: @@ -381,18 +386,20 @@ def __init__( @staticmethod def from_yaml( - cls: Type["Segment"], - yaml: Union[dict, list], - rom_start: Optional[int], - rom_end: Optional[int], - parent: Optional["Segment"], - vram=None, - ): - type = Segment.parse_segment_type(yaml) - name = Segment.parse_segment_name(cls, rom_start, yaml) + cls: type[Segment], + yaml: SerializedSegment, + rom_start: int | None, + rom_end: int | None, + parent: Segment | None, + vram: int | None = None, + ) -> Segment: + type = cls.parse_segment_type(yaml) + name = cls.parse_segment_name(rom_start, yaml) vram_class = parse_segment_vram_class(yaml) + vram_start: int | None + if vram is not None: vram_start = vram elif vram_class: @@ -400,7 +407,7 @@ def from_yaml( else: vram_start = parse_segment_vram(yaml) - args: List[str] = [] if isinstance(yaml, dict) else yaml[3:] + args: list[str] = [] if isinstance(yaml, dict) else yaml[3:] ret = cls( rom_start=rom_start, @@ -502,7 +509,7 @@ def is_noload() -> bool: return False @staticmethod - def estimate_size(yaml: Union[Dict, List]) -> Optional[int]: + def estimate_size(yaml: SerializedSegment) -> int | None: return None @property @@ -513,8 +520,7 @@ def needs_symbols(self) -> bool: def dir(self) -> Path: if self.parent: return self.parent.dir / self.given_dir - else: - return self.given_dir + return self.given_dir @property def show_file_boundaries(self) -> bool: @@ -537,27 +543,26 @@ def symbol_name_format_no_rom(self) -> str: return self.given_symbol_name_format_no_rom @property - def subalign(self) -> Optional[int]: + def subalign(self) -> int | None: assert self.parent is None, ( f"subalign is not valid for non-top-level segments. ({self})" ) return self.given_subalign @property - def vram_symbol(self) -> Optional[str]: + def vram_symbol(self) -> str | None: if self.vram_class and self.vram_class.vram_symbol: return self.vram_class.vram_symbol - elif self.given_vram_symbol: + if self.given_vram_symbol: return self.given_vram_symbol - else: - return None + return None - def get_exclusive_ram_id(self) -> Optional[str]: + def get_exclusive_ram_id(self) -> str | None: if self.parent: return self.parent.get_exclusive_ram_id() return self.exclusive_ram_id - def add_symbol(self, symbol: Symbol): + def add_symbol(self, symbol: Symbol) -> None: if symbol.vram_start not in self.given_seg_symbols: self.given_seg_symbols[symbol.vram_start] = [] self.given_seg_symbols[symbol.vram_start].append(symbol) @@ -569,18 +574,16 @@ def add_symbol(self, symbol: Symbol): self.symbol_ranges_rom.addi(symbol.rom, symbol.rom_end, symbol) @property - def seg_symbols(self) -> Dict[int, List[Symbol]]: + def seg_symbols(self) -> dict[int, list[Symbol]]: if self.parent: return self.parent.seg_symbols - else: - return self.given_seg_symbols + return self.given_seg_symbols @property - def size(self) -> Optional[int]: + def size(self) -> int | None: if self.rom_start is not None and self.rom_end is not None: return self.rom_end - self.rom_start - else: - return None + return None @property def statistics(self) -> SegmentStatistics: @@ -594,14 +597,13 @@ def statistics_type(self) -> SegmentType: return self.type @property - def vram_end(self) -> Optional[int]: + def vram_end(self) -> int | None: if self.vram_start is not None and self.size is not None: return self.vram_start + self.size - else: - return None + return None @property - def section_order(self) -> List[str]: + def section_order(self) -> list[str]: return self.given_section_order @property @@ -630,29 +632,25 @@ def get_cname(self) -> str: def contains_vram(self, vram: int) -> bool: if self.vram_start is not None and self.vram_end is not None: return vram >= self.vram_start and vram < self.vram_end - else: - return False + return False def contains_rom(self, rom: int) -> bool: if self.rom_start is not None and self.rom_end is not None: return rom >= self.rom_start and rom < self.rom_end - else: - return False + return False - def rom_to_ram(self, rom_addr: int) -> Optional[int]: + def rom_to_ram(self, rom_addr: int) -> int | None: if self.vram_start is not None and self.rom_start is not None: return self.vram_start + rom_addr - self.rom_start - else: - return None + return None - def ram_to_rom(self, ram_addr: int) -> Optional[int]: + def ram_to_rom(self, ram_addr: int) -> int | None: if not self.contains_vram(ram_addr) and ram_addr != self.vram_end: return None if self.vram_start is not None and self.rom_start is not None: return self.rom_start + ram_addr - self.vram_start - else: - return None + return None def should_scan(self) -> bool: return self.should_split() @@ -660,13 +658,13 @@ def should_scan(self) -> bool: def should_split(self) -> bool: return self.extract and options.opts.is_mode_active(self.type) - def scan(self, rom_bytes: bytes): + def scan(self, rom_bytes: bytes) -> None: pass - def split(self, rom_bytes: bytes): + def split(self, rom_bytes: bytes) -> None: pass - def cache(self): + def cache(self) -> tuple[SerializedSegment, int | None]: return (self.yaml, self.rom_end) def get_linker_section(self) -> str: @@ -690,7 +688,7 @@ def get_linker_section_linksection(self) -> str: return self.linker_section return self.get_linker_section() - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: """ Allows specifying flags for a section. @@ -701,7 +699,7 @@ def get_section_flags(self) -> Optional[str]: Example: ``` - def get_section_flags(self) -> Optional[str]: + def get_section_flags(self) -> str | None: # Tells the linker to allocate this section return "a" ``` @@ -715,10 +713,10 @@ def get_section_asm_line(self) -> str: line += f', "{section_flags}"' return line - def out_path(self) -> Optional[Path]: + def out_path(self) -> Path | None: return None - def get_most_parent(self) -> "Segment": + def get_most_parent(self) -> Segment: seg = self while seg.parent: @@ -726,7 +724,7 @@ def get_most_parent(self) -> "Segment": return seg - def get_linker_entries(self) -> "List[LinkerEntry]": + def get_linker_entries(self) -> list[LinkerEntry]: from ..segtypes.linker_entry import LinkerEntry if not self.has_linker_entry: @@ -745,24 +743,24 @@ def get_linker_entries(self) -> "List[LinkerEntry]": self.is_noload(), ) ] - else: - return [] + return [] - def log(self, msg): + def log(self, msg: str) -> None: if options.opts.verbose: log.write(f"{self.type} {self.name}: {msg}") - def warn(self, msg: str): + def warn(self, msg: str) -> None: self.warnings.append(msg) @staticmethod - def get_default_name(addr) -> str: + def get_default_name(addr: int) -> str: return f"{addr:X}" - def is_name_default(self): + def is_name_default(self) -> bool: + assert self.rom_start is not None return self.name == self.get_default_name(self.rom_start) - def unique_id(self): + def unique_id(self) -> str: if self.parent: s = self.parent.unique_id() + "_" else: @@ -771,7 +769,7 @@ def unique_id(self): return s + self.type + "_" + self.name @staticmethod - def visible_ram(seg1: "Segment", seg2: "Segment") -> bool: + def visible_ram(seg1: Segment, seg2: Segment) -> bool: if seg1.get_most_parent() == seg2.get_most_parent(): return True if seg1.get_exclusive_ram_id() is None or seg2.get_exclusive_ram_id() is None: @@ -779,8 +777,8 @@ def visible_ram(seg1: "Segment", seg2: "Segment") -> bool: return seg1.get_exclusive_ram_id() != seg2.get_exclusive_ram_id() def retrieve_symbol( - self, syms: Dict[int, List[Symbol]], addr: int - ) -> Optional[Symbol]: + self, syms: dict[int, list[Symbol]], addr: int + ) -> Symbol | None: if addr not in syms: return None @@ -801,8 +799,8 @@ def retrieve_symbol( return items[0] def retrieve_sym_type( - self, syms: Dict[int, List[Symbol]], addr: int, type: str - ) -> Optional[symbols.Symbol]: + self, syms: dict[int, list[Symbol]], addr: int, type: str + ) -> symbols.Symbol | None: if addr not in syms: return None @@ -824,15 +822,15 @@ def get_symbol( self, addr: int, in_segment: bool = False, - type: Optional[str] = None, + type: str | None = None, create: bool = False, define: bool = False, reference: bool = False, search_ranges: bool = False, local_only: bool = False, - ) -> Optional[Symbol]: - ret: Optional[Symbol] = None - rom: Optional[int] = None + ) -> Symbol | None: + ret: Symbol | None = None + rom: int | None = None most_parent = self.get_most_parent() @@ -844,7 +842,7 @@ def get_symbol( if not ret and search_ranges: # Search ranges first, starting with rom if rom is not None: - cands: Set[Interval] = most_parent.symbol_ranges_rom[rom] + cands: set[Interval] = most_parent.symbol_ranges_rom[rom] if cands: ret = cands.pop().data # and then vram if we can't find a rom match @@ -890,7 +888,7 @@ def create_symbol( self, addr: int, in_segment: bool, - type: Optional[str] = None, + type: str | None = None, define: bool = False, reference: bool = False, search_ranges: bool = False, diff --git a/src/splat/util/cache_handler.py b/src/splat/util/cache_handler.py index deb55780..0e42821a 100644 --- a/src/splat/util/cache_handler.py +++ b/src/splat/util/cache_handler.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import pickle -from typing import Any, Dict +from typing import TYPE_CHECKING, Any + +from . import log, options -from . import options, log -from ..segtypes.common.segment import Segment +if TYPE_CHECKING: + from ..segtypes.segment import Segment class Cache: - def __init__(self, config: Dict[str, Any], use_cache: bool, verbose: bool): + def __init__(self, config: dict[str, Any], use_cache: bool, verbose: bool) -> None: self.use_cache: bool = use_cache - self.cache: Dict[str, Any] = {} + self.cache: dict[str, Any] = {} # Load cache if use_cache and options.opts.cache_path.exists(): @@ -32,7 +36,7 @@ def __init__(self, config: Dict[str, Any], use_cache: bool, verbose: bool): "__options__": config.get("options"), } - def save(self, verbose: bool): + def save(self, verbose: bool) -> None: if self.cache != {} and self.use_cache: if verbose: log.write("Writing cache") diff --git a/src/splat/util/color.py b/src/splat/util/color.py index f8250ae0..0d143d67 100644 --- a/src/splat/util/color.py +++ b/src/splat/util/color.py @@ -1,10 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING from math import ceil from . import options +if TYPE_CHECKING: + from collections.abc import Sequence + # RRRRRGGG GGBBBBBA -def unpack_color(data): +def unpack_color(data: Sequence[int]) -> tuple[int, int, int, int]: s = int.from_bytes(data[0:2], byteorder=options.opts.endianness) r = (s >> 11) & 0x1F diff --git a/src/splat/util/compiler.py b/src/splat/util/compiler.py index eeea374a..0b30719e 100644 --- a/src/splat/util/compiler.py +++ b/src/splat/util/compiler.py @@ -1,5 +1,5 @@ +from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Dict @dataclass @@ -15,7 +15,7 @@ class Compiler: asm_nonmatching_label_macro: str = "nonmatching" c_newline: str = "\n" asm_inc_header: str = "" - asm_emit_size_directive: Optional[bool] = None + asm_emit_size_directive: bool | None = None j_as_branch: bool = False uses_include_asm: bool = True align_on_branch_labels: bool = False @@ -64,7 +64,7 @@ class Compiler: MWCCPS2 = Compiler("MWCCPS2", uses_include_asm=False) EEGCC = Compiler("EEGCC", align_on_branch_labels=True) -compiler_for_name: Dict[str, Compiler] = { +compiler_for_name: dict[str, Compiler] = { x.name: x for x in [ GCC, diff --git a/src/splat/util/conf.py b/src/splat/util/conf.py index ea45069b..76241096 100644 --- a/src/splat/util/conf.py +++ b/src/splat/util/conf.py @@ -6,8 +6,9 @@ config = conf.load("path/to/splat.yaml") """ -from typing import Any, Dict, List, Optional -from pathlib import Path +from __future__ import annotations + +from typing import Any, TYPE_CHECKING # This unused import makes the yaml library faster. don't remove import pylibyaml # noqa: F401 @@ -15,6 +16,9 @@ from . import options, vram_classes +if TYPE_CHECKING: + from pathlib import Path + def _merge_configs(main_config, additional_config, additional_config_path): # Merge rules are simple @@ -28,7 +32,7 @@ def _merge_configs(main_config, additional_config, additional_config_path): main_config[curkey] = additional_config[curkey] elif type(main_config[curkey]) is not type(additional_config[curkey]): raise TypeError( - f"Could not merge {str(additional_config_path)}: type for key '{curkey}' in configs does not match" + f"Could not merge {additional_config_path!s}: type for key '{curkey}' in configs does not match" ) else: # keys exist and match, see if a list to append @@ -49,12 +53,12 @@ def _merge_configs(main_config, additional_config, additional_config_path): def load( - config_path: List[Path], - modes: Optional[List[str]] = None, + config_path: list[Path], + modes: list[str] | None = None, verbose: bool = False, disassemble_all: bool = False, make_full_disasm_for_code=False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Returns a `dict` with resolved splat config. @@ -76,7 +80,7 @@ def load( Config with invalid options may raise an error. """ - config: Dict[str, Any] = {} + config: dict[str, Any] = {} for entry in config_path: with entry.open() as f: additional_config = yaml.load(f.read(), Loader=yaml.SafeLoader) diff --git a/src/splat/util/file_presets.py b/src/splat/util/file_presets.py index 7965f68b..c94ad8d2 100644 --- a/src/splat/util/file_presets.py +++ b/src/splat/util/file_presets.py @@ -15,7 +15,7 @@ from . import options, log -def write_all_files(): +def write_all_files() -> None: if not options.opts.generate_asm_macros_files: return @@ -36,7 +36,7 @@ def _write(filepath: str, contents: str): f.write(contents) -def write_include_asm_h(): +def write_include_asm_h() -> None: if not options.opts.compiler.uses_include_asm: # These compilers do not use the `INCLUDE_ASM` macro. return @@ -115,7 +115,7 @@ def write_include_asm_h(): _write(f"{directory}/include_asm.h", file_data) -def write_assembly_inc_files(): +def write_assembly_inc_files() -> None: directory = options.opts.generated_asm_macros_directory.as_posix() func_macros = f"""\ @@ -352,7 +352,7 @@ def write_assembly_inc_files(): _write(f"{directory}/macro.inc", f"{preamble}\n{gas}") -def write_gte_macros(): +def write_gte_macros() -> None: # Taken directly from https://github.com/Decompollaborate/rabbitizer/blob/-/docs/r3000gte/gte_macros.s # Please try to upstream any fix/update done here. gte_macros = """\ diff --git a/src/splat/util/log.py b/src/splat/util/log.py index 6dc8ff64..66cd7735 100644 --- a/src/splat/util/log.py +++ b/src/splat/util/log.py @@ -1,17 +1,29 @@ +from __future__ import annotations + import sys -from typing import NoReturn, Optional -from pathlib import Path +from typing import TYPE_CHECKING, NoReturn, Optional, TextIO + +from colorama import Fore, Style, init + +if TYPE_CHECKING: + from pathlib import Path -from colorama import Fore, init, Style + from typing_extensions import TypeAlias init(autoreset=True) newline = True -Status = Optional[str] +Status: TypeAlias = Optional[str] -def write(*args, status=None, **kwargs): +def write( + *args: object, + status: Status = None, + sep: str | None = None, + end: str | None = None, + flush: bool = False, +) -> None: global newline if not newline: @@ -21,37 +33,40 @@ def write(*args, status=None, **kwargs): print( status_to_ansi(status) + str(args[0]), *args[1:], - **kwargs, + sep=sep, + end=end, file=output_file(status), + flush=flush, ) -def error(*args, **kwargs) -> NoReturn: - write(*args, **kwargs, status="error") +def error( + *args: object, sep: str | None = None, end: str | None = None, flush: bool = False +) -> NoReturn: + write(*args, status="error", sep=sep, end=end, flush=flush) sys.exit(2) # The line_num is expected to be zero-indexed -def parsing_error_preamble(path: Path, line_num: int, line: str): +def parsing_error_preamble(path: Path, line_num: int, line: str) -> None: write("") write(f"error reading {path}, line {line_num + 1}:", status="error") write(f"\t{line}") -def status_to_ansi(status: Status): +def status_to_ansi(status: Status) -> Fore | str: if status == "ok": return Fore.GREEN - elif status == "warn": + if status == "warn": return Fore.YELLOW + Style.BRIGHT - elif status == "error": + if status == "error": return Fore.RED + Style.BRIGHT - elif status == "skip": + if status == "skip": return Style.DIM - else: - return "" + return "" -def output_file(status: Status): +def output_file(status: Status) -> TextIO: if status == "warn" or status == "error": return sys.stderr return sys.stdout diff --git a/src/splat/util/n64/find_code_length.py b/src/splat/util/n64/find_code_length.py index 1f9053ca..84b76026 100755 --- a/src/splat/util/n64/find_code_length.py +++ b/src/splat/util/n64/find_code_length.py @@ -1,4 +1,5 @@ #! /usr/bin/env python3 +from __future__ import annotations import argparse @@ -6,7 +7,7 @@ import spimdisasm -def int_any_base(x): +def int_any_base(x: str) -> int: return int(x, 0) @@ -24,7 +25,9 @@ def int_any_base(x): ) -def run(rom_bytes, start_offset, vram, end_offset=None): +def run( + rom_bytes: bytes, start_offset: int, vram: int, end_offset: int | None = None +) -> int: rom_addr = start_offset last_return = rom_addr @@ -46,7 +49,7 @@ def run(rom_bytes, start_offset, vram, end_offset=None): return end -def main(): +def main() -> None: args = parser.parse_args() with open(args.rom, "rb") as f: diff --git a/src/splat/util/n64/rominfo.py b/src/splat/util/n64/rominfo.py index 772c273d..f65eb3b9 100755 --- a/src/splat/util/n64/rominfo.py +++ b/src/splat/util/n64/rominfo.py @@ -1,17 +1,14 @@ #! /usr/bin/env python3 +from __future__ import annotations import argparse - import hashlib import itertools import struct - import sys import zlib from dataclasses import dataclass - from pathlib import Path -from typing import Optional, List import rabbitizer import spimdisasm @@ -78,8 +75,11 @@ class EntryAddressInfo: @staticmethod def new( - value: Optional[int], hi: Optional[int], lo: Optional[int], ori: Optional[int] - ) -> Optional["EntryAddressInfo"]: + value: int | None, + hi: int | None, + lo: int | None, + ori: int | None, + ) -> EntryAddressInfo | None: if value is not None and hi is not None and lo is not None: return EntryAddressInfo(value, hi, lo, ori == lo) return None @@ -88,12 +88,12 @@ def new( @dataclass class N64EntrypointInfo: entry_size: int - data_size: Optional[int] - bss_start_address: Optional[EntryAddressInfo] - bss_size: Optional[EntryAddressInfo] - bss_end_address: Optional[EntryAddressInfo] - main_address: Optional[EntryAddressInfo] - stack_top: Optional[EntryAddressInfo] + data_size: int | None + bss_start_address: EntryAddressInfo | None + bss_size: EntryAddressInfo | None + bss_end_address: EntryAddressInfo | None + main_address: EntryAddressInfo | None + stack_top: EntryAddressInfo | None traditional_entrypoint: bool ori_entrypoint: bool @@ -102,7 +102,7 @@ def segment_size(self) -> int: return self.entry_size + self.data_size return self.entry_size - def get_bss_size(self) -> Optional[int]: + def get_bss_size(self) -> int | None: if self.bss_size is not None: return self.bss_size.value if self.bss_start_address is not None and self.bss_end_address is not None: @@ -111,35 +111,40 @@ def get_bss_size(self) -> Optional[int]: @staticmethod def parse_rom_bytes( - rom_bytes, vram: int, offset: int = 0x1000, size: int = 0x60 - ) -> "N64EntrypointInfo": + rom_bytes: bytes, + vram: int, + offset: int = 0x1000, + size: int = 0x60, + ) -> N64EntrypointInfo: word_list = spimdisasm.common.Utils.bytesToWords( - rom_bytes, offset, offset + size + rom_bytes, + offset, + offset + size, ) nops_count = 0 register_values = [0 for _ in range(32)] completed_pair = [False for _ in range(32)] - hi_assignments: List[Optional[int]] = [None for _ in range(32)] - lo_assignments: List[Optional[int]] = [None for _ in range(32)] + hi_assignments: list[int | None] = [None for _ in range(32)] + lo_assignments: list[int | None] = [None for _ in range(32)] # We need to track if something was paired using an ori instead of an # addiu or similar, because if that's the case we can't emit normal # relocations in the generated symbol_addrs file for it. - ori_assignments: List[Optional[int]] = [None for _ in range(32)] + ori_assignments: list[int | None] = [None for _ in range(32)] - register_bss_address: Optional[int] = None - register_bss_size: Optional[int] = None - register_main_address: Optional[int] = None + register_bss_address: int | None = None + register_bss_size: int | None = None + register_main_address: int | None = None - bss_address: Optional[EntryAddressInfo] = None - bss_size: Optional[EntryAddressInfo] = None - bss_end_address: Optional[EntryAddressInfo] = None + bss_address: EntryAddressInfo | None = None + bss_size: EntryAddressInfo | None = None + bss_end_address: EntryAddressInfo | None = None traditional_entrypoint = True ori_entrypoint = False decrementing_bss_routine = True - data_size: Optional[int] = None - func_call_target: Optional[EntryAddressInfo] = None + data_size: int | None = None + func_call_target: EntryAddressInfo | None = None size = 0 i = 0 @@ -173,11 +178,10 @@ def parse_rom_bytes( if insn.isUnsigned(): ori_assignments[insn.rt.value] = current_rom ori_entrypoint = True - elif insn.doesStore(): - if insn.rt == rabbitizer.RegGprO32.zero: - # Try to detect the zero-ing bss algorithm - # sw $zero, 0x0($t0) - register_bss_address = insn.rs.value + elif insn.doesStore() and insn.rt == rabbitizer.RegGprO32.zero: + # Try to detect the zero-ing bss algorithm + # sw $zero, 0x0($t0) + register_bss_address = insn.rs.value elif insn.isBranch(): if insn.uniqueId == rabbitizer.InstrId.cpu_beq: traditional_entrypoint = False @@ -233,7 +237,10 @@ def parse_rom_bytes( # entrypoint to actual code. traditional_entrypoint = False func_call_target = EntryAddressInfo( - insn.getInstrIndexAsVram(), current_rom, current_rom, False + insn.getInstrIndexAsVram(), + current_rom, + current_rom, + False, ) elif insn.uniqueId == rabbitizer.InstrId.cpu_break: @@ -293,18 +300,17 @@ def parse_rom_bytes( ori_assignments[rabbitizer.RegGprO32.sp.value], ) - if not traditional_entrypoint: - if func_call_target is not None: - main_address = func_call_target - if func_call_target.value > vram: - # Some weird-entrypoint games have non-code between the - # entrypoint and the actual user code. - # We try to find where actual code may begin, and tag - # everything in between as "entrypoint data". + if not traditional_entrypoint and func_call_target is not None: + main_address = func_call_target + if func_call_target.value > vram: + # Some weird-entrypoint games have non-code between the + # entrypoint and the actual user code. + # We try to find where actual code may begin, and tag + # everything in between as "entrypoint data". - code_start = find_code_after_data(rom_bytes, offset + i * 4, vram) - if code_start is not None and code_start > offset + size: - data_size = code_start - (offset + size) + code_start = find_code_after_data(rom_bytes, offset + i * 4, vram) + if code_start is not None and code_start > offset + size: + data_size = code_start - (offset + size) return N64EntrypointInfo( size, @@ -320,9 +326,12 @@ def parse_rom_bytes( def find_code_after_data( - rom_bytes: bytes, offset: int, vram: int, threshold: int = 0x18000 -) -> Optional[int]: - code_offset: Optional[int] = None + rom_bytes: bytes, + offset: int, + vram: int, + threshold: int = 0x18000, +) -> int | None: + code_offset: int | None = None # We loop through every word until we find a valid `jr $ra` instruction and # hope for it to be part of valid code. @@ -337,7 +346,9 @@ def find_code_after_data( if insn.isValid() and insn.isReturn(): # Check the instruction on the delay slot of the `jr $ra` is valid too. next_word = spimdisasm.common.Utils.bytesToWords( - rom_bytes, offset + 4, offset + 4 + 4 + rom_bytes, + offset + 4, + offset + 4 + 4, )[0] if rabbitizer.Instruction(next_word, vram + 4).isValid(): jr_ra_found = True @@ -389,7 +400,7 @@ def get_country_name(self) -> str: return country_codes[self.country_code] -def swap_bytes(data): +def swap_bytes(data: bytes) -> bytes: return bytes( itertools.chain.from_iterable( struct.pack(">H", x) for (x,) in struct.iter_unpack(" bytes: rom_bytes = rom_path.read_bytes() if rom_path.suffix.lower() == ".n64": @@ -410,17 +421,17 @@ def read_rom(rom_path: Path): return rom_bytes -def get_cic(rom_bytes: bytes): +def get_cic(rom_bytes: bytes) -> CIC: ipl3_crc = zlib.crc32(rom_bytes[0x40:0x1000]) return crc_to_cic.get(ipl3_crc, unknown_cic) -def get_entry_point(program_counter: int, cic: CIC): +def get_entry_point(program_counter: int, cic: CIC) -> int: return program_counter - cic.offset -def guess_header_encoding(rom_bytes: bytes): +def guess_header_encoding(rom_bytes: bytes) -> str: header = rom_bytes[0x20:0x34] encodings = ["ASCII", "shift_jis", "euc-jp"] for encoding in encodings: @@ -435,7 +446,9 @@ def guess_header_encoding(rom_bytes: bytes): def get_info( - rom_path: Path, rom_bytes: Optional[bytes] = None, header_encoding=None + rom_path: Path, + rom_bytes: bytes | None = None, + header_encoding: str | None = None, ) -> N64Rom: if rom_bytes is None: rom_bytes = read_rom(rom_path) @@ -457,7 +470,7 @@ def get_info_bytes(rom_bytes: bytes, header_encoding: str) -> N64Rom: sys.exit( "splat could not decode the game name;" " try using a different encoding by passing the --header-encoding argument" - " (see docs.python.org/3/library/codecs.html#standard-encodings for valid encodings)" + " (see docs.python.org/3/library/codecs.html#standard-encodings for valid encodings)", ) country_code = rom_bytes[0x3E] @@ -470,7 +483,9 @@ def get_info_bytes(rom_bytes: bytes, header_encoding: str) -> N64Rom: sha1 = hashlib.sha1(rom_bytes).hexdigest() entrypoint_info = N64EntrypointInfo.parse_rom_bytes( - rom_bytes, entry_point, size=0x100 + rom_bytes, + entry_point, + size=0x100, ) return N64Rom( @@ -488,7 +503,9 @@ def get_info_bytes(rom_bytes: bytes, header_encoding: str) -> N64Rom: ) -def get_compiler_info(rom_bytes, entry_point, print_result=True): +def get_compiler_info( + rom_bytes: bytes, entry_point: int, print_result: bool = True +) -> str: jumps = 0 branches = 0 @@ -507,12 +524,12 @@ def get_compiler_info(rom_bytes, entry_point, print_result=True): if print_result: print( f"{branches} branches and {jumps} jumps detected in the first code segment." - f" Compiler is most likely {compiler}" + f" Compiler is most likely {compiler}", ) return compiler -def main(): +def main() -> None: rabbitizer.config.pseudos_pseudoB = True args = parser.parse_args() diff --git a/src/splat/util/options.py b/src/splat/util/options.py index 2a6337f0..c0b8b65b 100644 --- a/src/splat/util/options.py +++ b/src/splat/util/options.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from dataclasses import dataclass import os from pathlib import Path -from typing import cast, Dict, List, Literal, Mapping, Optional, Set, Type, TypeVar +from typing import cast, Literal, TypeVar, TYPE_CHECKING from . import compiler -from .compiler import Compiler + +if TYPE_CHECKING: + from collections.abc import Mapping + from .compiler import Compiler @dataclass @@ -12,7 +17,7 @@ class SplatOpts: # Debug / logging verbose: bool dump_symbols: bool - modes: List[str] + modes: list[str] # Project configuration @@ -21,7 +26,7 @@ class SplatOpts: # Determines the path to the target binary target_path: Path # Path to the final elf target - elf_path: Optional[Path] + elf_path: Path | None # Determines the platform of the target binary platform: str # Determines the compiler used to compile the target binary @@ -30,13 +35,13 @@ class SplatOpts: endianness: Literal["big", "little"] # Determines the default section order of the target binary # this can be overridden per-segment - section_order: List[str] + section_order: list[str] # Determines the code that is inserted by default in generated .c files generated_c_preamble: str # Determines the code that is inserted by default in generated .s files generated_s_preamble: str # Determines any extra content to be added in the generated macro.inc file - generated_macro_inc_content: Optional[str] + generated_macro_inc_content: str | None # Determines if files related to assembly macros should be regenerated by splat generate_asm_macros_files: bool # Changes the definition of the generated `INCLUDE_ASM`. @@ -48,7 +53,7 @@ class SplatOpts: # Determines whether to use .o as the suffix for all binary files?... TODO document use_o_as_suffix: bool # the value of the $gp register to correctly calculate offset to %gp_rel relocs - gp: Optional[int] + gp: int | None # Checks and errors if there are any non consecutive segment types check_consecutive_segment_types: bool # Disable checks on `platform` option. @@ -63,8 +68,8 @@ class SplatOpts: # as well as optional metadata such as rom address, type, and more # # It's possible to use more than one file by supplying a list instead of a string - symbol_addrs_paths: List[Path] - reloc_addrs_paths: List[Path] + symbol_addrs_paths: list[Path] + reloc_addrs_paths: list[Path] # Determines the path to the project build directory build_path: Path # Determines the path to the source code directory @@ -95,32 +100,32 @@ class SplatOpts: undefined_syms_auto_path: Path # Determines the path in which to search for custom splat extensions - extensions_path: Optional[Path] + extensions_path: Path | None # Determines the path to library files that are to be linked into the target binary lib_path: Path # TODO document - elf_section_list_path: Optional[Path] + elf_section_list_path: Path | None # Linker script # Determines the default subalign value to be specified in the generated linker script - subalign: Optional[int] + subalign: int | None # Determines whether to emit the subalign directive in the generated linker script emit_subalign: bool # The following option determines a list of sections for which automatic linker script entries should be added - auto_link_sections: List[str] + auto_link_sections: list[str] # Determines the desired path to the linker script that splat will generate ld_script_path: Path # Determines the desired path to the linker symbol header, # which exposes externed definitions for all segment ram/rom start/end locations - ld_symbol_header_path: Optional[Path] + ld_symbol_header_path: Path | None # Determines whether to add a discard section with a wildcard to the linker script ld_discard_section: bool # A list of sections to preserve during link time. It can be useful to preserve debugging sections - ld_sections_allowlist: List[str] + ld_sections_allowlist: list[str] # A list of sections to discard during link time. It can be useful to avoid using the wildcard discard. Note that this option does not turn off `ld_discard_section` - ld_sections_denylist: List[str] + ld_sections_denylist: list[str] # Determines whether to add wildcards for section linking in the linker script (.rodata* for example) ld_wildcard_sections: bool # Determines whether to use `follows_vram` (segment option) and @@ -132,9 +137,9 @@ class SplatOpts: # Change linker script generation to allow partially linking segments. Requires both `ld_partial_scripts_path` and `ld_partial_build_segments_path` to be set. ld_partial_linking: bool # Folder were each intermediary linker script will be written to. - ld_partial_scripts_path: Optional[Path] + ld_partial_scripts_path: Path | None # Folder where the built partially linked segments will be placed by the build system. - ld_partial_build_segments_path: Optional[Path] + ld_partial_build_segments_path: Path | None # Generate a dependency file for every linker script generated. Dependency files will have the same path and name as the corresponding linker script, but changing the extension to `.d`. Requires `elf_path` to be set. ld_dependencies: bool # Legacy linker script generation does not impose the section_order specified in the yaml options or per-segment options. @@ -146,11 +151,11 @@ class SplatOpts: # Specifies the starting offset for rom address symbols in the linker script. ld_rom_start: int # The value passed to the FILL statement on each segment. `None` disables using FILL statements on the linker script. Defaults to a fill value of 0. - ld_fill_value: Optional[int] + ld_fill_value: int | None # Allows to control if `bss` sections (and derivatived sections) will be put on a `NOLOAD` segment on the generated linker script or not. ld_bss_is_noload: bool # Aligns the start of the segment to the given value - ld_align_segment_start: Optional[int] + ld_align_segment_start: int | None # Allows to toggle aligning the `*_VRAM_END` linker symbol for each segment. ld_align_segment_vram_end: bool # Allows to toggle aligning the `*_END` linker symbol for each section of each section. @@ -160,7 +165,7 @@ class SplatOpts: # Sets the default option for the `bss_contains_common` attribute of all segments. ld_bss_contains_common: bool # Specify an expression to be used for the `_gp` symbol in the generated linker script instead of a hardcoded value. - ld_gp_expression: Optional[str] + ld_gp_expression: str | None ################################################################################ # C file options @@ -208,7 +213,7 @@ class SplatOpts: # Determines the macro used to declare the given symbol is a non matching one. asm_nonmatching_label_macro: str # Toggles the .size directive emitted by the disassembler - asm_emit_size_directive: Optional[bool] + asm_emit_size_directive: bool | None # Determines the number of characters to left align before the TODO finish documenting mnemonic_ljust: int # Determines whether to pad the rom address @@ -227,13 +232,13 @@ class SplatOpts: # Generate .asmproc.d dependency files for each C file which still reference functions in assembly files create_asm_dependencies: bool # Global option for rodata string encoding. This can be overriden per segment - string_encoding: Optional[str] + string_encoding: str | None # Global option for data string encoding. This can be overriden per segment - data_string_encoding: Optional[str] + data_string_encoding: str | None # Global option for the rodata string guesser. 0 disables the guesser completely. - rodata_string_guesser_level: Optional[int] + rodata_string_guesser_level: int | None # Global option for the data string guesser. 0 disables the guesser completely. - data_string_guesser_level: Optional[int] + data_string_guesser_level: int | None # Global option for allowing data symbols using addends on symbol references. It can be overriden per symbol allow_data_addends: bool # Tells the disassembler to try disassembling functions with unknown instructions instead of falling back to disassembling as raw data @@ -246,8 +251,8 @@ class SplatOpts: make_full_disasm_for_code: bool # Allow specifying that the global memory range may be larger than what was automatically detected. # Useful for projects where splat is used in multiple individual files, meaning the expected global segment may not be properly detected because each instance of splat can't see the info from other files. - global_vram_start: Optional[int] - global_vram_end: Optional[int] + global_vram_start: int | None + global_vram_end: int | None # For `c` segments (functions under the nonmatchings folder). # If True then use the `%gp_rel` explicit relocation parameter on instructions that use the $gp register, # otherwise strip the `%gp_rel` parameter entirely and convert those instructions into macro instructions that may not assemble to the original @@ -297,13 +302,13 @@ def is_mode_active(self, mode: str) -> bool: class OptParser: - _read_opts: Set[str] + _read_opts: set[str] def __init__(self, yaml: Mapping[str, object]) -> None: self._yaml = yaml self._read_opts = set() - def parse_opt(self, opt: str, t: Type[T], default: Optional[T] = None) -> T: + def parse_opt(self, opt: str, t: type[T], default: T | None = None) -> T: if opt not in self._yaml: if default is not None: return default @@ -313,17 +318,17 @@ def parse_opt(self, opt: str, t: Type[T], default: Optional[T] = None) -> T: if isinstance(value, t): return value if t is float and isinstance(value, int): - return cast(T, float(value)) + return cast("T", float(value)) raise ValueError(f"Expected {opt} to have type {t}, got {type(value)}") - def parse_optional_opt(self, opt: str, t: Type[T]) -> Optional[T]: + def parse_optional_opt(self, opt: str, t: type[T]) -> T | None: if opt not in self._yaml: return None return self.parse_opt(opt, t) def parse_optional_opt_with_default( - self, opt: str, t: Type[T], default: Optional[T] - ) -> Optional[T]: + self, opt: str, t: type[T], default: T | None + ) -> T | None: if opt not in self._yaml: return default self._read_opts.add(opt) @@ -331,36 +336,33 @@ def parse_optional_opt_with_default( if value is None or isinstance(value, t): return value if t is float and isinstance(value, int): - return cast(T, float(value)) + return cast("T", float(value)) raise ValueError(f"Expected {opt} to have type {t}, got {type(value)}") def parse_opt_within( - self, opt: str, t: Type[T], within: List[T], default: Optional[T] = None + self, opt: str, t: type[T], within: list[T], default: T | None = None ) -> T: value = self.parse_opt(opt, t, default) if value not in within: raise ValueError(f"Invalid value for {opt}: {value}") return value - def parse_path( - self, base_path: Path, opt: str, default: Optional[str] = None - ) -> Path: + def parse_path(self, base_path: Path, opt: str, default: str | None = None) -> Path: return Path(os.path.normpath(base_path / self.parse_opt(opt, str, default))) - def parse_optional_path(self, base_path: Path, opt: str) -> Optional[Path]: + def parse_optional_path(self, base_path: Path, opt: str) -> Path | None: if opt not in self._yaml: return None return self.parse_path(base_path, opt) - def parse_path_list(self, base_path: Path, opt: str, default: str) -> List[Path]: + def parse_path_list(self, base_path: Path, opt: str, default: str) -> list[Path]: paths = self.parse_opt(opt, object, default) if isinstance(paths, str): return [base_path / paths] - elif isinstance(paths, list): + if isinstance(paths, list): return [base_path / path for path in paths] - else: - raise ValueError(f"Expected str or list for '{opt}', got {type(paths)}") + raise ValueError(f"Expected str or list for '{opt}', got {type(paths)}") def check_no_unread_opts(self) -> None: opts = [opt for opt in self._yaml if opt not in self._read_opts] @@ -369,9 +371,9 @@ def check_no_unread_opts(self) -> None: def _parse_yaml( - yaml: Dict, - config_paths: List[Path], - modes: List[str], + yaml: Mapping[str, object], + config_paths: list[Path], + modes: list[str], verbose: bool = False, disasm_all: bool = False, make_full_disasm_for_code: bool = False, @@ -413,10 +415,9 @@ def parse_endianness() -> Literal["big", "little"]: if endianness == "big": return "big" - elif endianness == "little": + if endianness == "little": return "little" - else: - raise ValueError(f"Invalid endianness: {endianness}") + raise ValueError(f"Invalid endianness: {endianness}") def parse_include_asm_macro_style() -> Literal["default", "maspsx_hack"]: include_asm_macro_style = p.parse_opt_within( @@ -428,10 +429,9 @@ def parse_include_asm_macro_style() -> Literal["default", "maspsx_hack"]: if include_asm_macro_style == "default": return "default" - elif include_asm_macro_style == "maspsx_hack": + if include_asm_macro_style == "maspsx_hack": return "maspsx_hack" - else: - raise ValueError(f"Invalid endianness: {include_asm_macro_style}") + raise ValueError(f"Invalid endianness: {include_asm_macro_style}") default_ld_bss_is_noload = True if platform == "psx": @@ -651,13 +651,13 @@ def parse_include_asm_macro_style() -> Literal["default", "maspsx_hack"]: def initialize( - config: Dict, - config_paths: List[Path], - modes: Optional[List[str]] = None, - verbose=False, - disasm_all=False, - make_full_disasm_for_code=False, -): + config: Mapping[str, Mapping[str, object]], + config_paths: list[Path], + modes: list[str] | None = None, + verbose: bool = False, + disasm_all: bool = False, + make_full_disasm_for_code: bool = False, +) -> None: global opts if not modes: diff --git a/src/splat/util/palettes.py b/src/splat/util/palettes.py index 03d33e6d..76e29cd9 100644 --- a/src/splat/util/palettes.py +++ b/src/splat/util/palettes.py @@ -1,4 +1,6 @@ -from typing import Dict +from __future__ import annotations + +from typing import TYPE_CHECKING from ..util import log @@ -6,12 +8,15 @@ from ..segtypes.n64.ci import N64SegCi from ..segtypes.n64.palette import N64SegPalette -global_ids: Dict[str, N64SegPalette] +if TYPE_CHECKING: + from ..segtypes.segment import Segment + +global_ids: dict[str, N64SegPalette] = {} # Resolve Raster#palette and Palette#raster links -def initialize(all_segments): - def collect_global_ids(segments): +def initialize(all_segments: list[Segment]) -> None: + def collect_global_ids(segments: list[Segment]) -> None: for segment in segments: if isinstance(segment, N64SegPalette): if segment.global_id is not None: @@ -20,9 +25,9 @@ def collect_global_ids(segments): if isinstance(segment, CommonSegGroup): collect_global_ids(segment.subsegments) - def process(segments): - raster_map: Dict[str, N64SegCi] = {} - palette_map: Dict[str, N64SegPalette] = {} + def process(segments: list[Segment]) -> None: + raster_map: dict[str, N64SegCi] = {} + palette_map: dict[str, N64SegPalette] = {} for segment in segments: if isinstance(segment, N64SegPalette): @@ -73,7 +78,7 @@ def process(segments): log.error(f"Palette {pal.name} has no linked rasters") global global_ids - global_ids = {} + global_ids.clear() collect_global_ids(all_segments) diff --git a/src/splat/util/progress_bar.py b/src/splat/util/progress_bar.py index caec2cee..a809df0c 100644 --- a/src/splat/util/progress_bar.py +++ b/src/splat/util/progress_bar.py @@ -1,8 +1,14 @@ +from __future__ import annotations + import tqdm import sys +from typing import TYPE_CHECKING, TextIO + +if TYPE_CHECKING: + from collections.abc import Sequence -out_file = sys.stderr +out_file: TextIO = sys.stderr -def get_progress_bar(elements: list) -> tqdm.tqdm: +def get_progress_bar(elements: Sequence[object]) -> tqdm.tqdm: return tqdm.tqdm(elements, total=len(elements), file=out_file) diff --git a/src/splat/util/ps2/ps2elfinfo.py b/src/splat/util/ps2/ps2elfinfo.py index 2b9419fe..4dd98301 100644 --- a/src/splat/util/ps2/ps2elfinfo.py +++ b/src/splat/util/ps2/ps2elfinfo.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses -from pathlib import Path import spimdisasm from spimdisasm.elf32 import ( Elf32File, @@ -11,10 +10,13 @@ Elf32SectionHeaderFlag, Elf32ObjectFileType, ) -from typing import Optional +from typing import TYPE_CHECKING from .. import log +if TYPE_CHECKING: + from pathlib import Path + ELF_SECTION_MAPPING: dict[str, str] = { ".text": "asm", @@ -56,11 +58,11 @@ class Ps2Elf: size: int compiler: str elf_section_names: list[tuple[str, bool]] - gp: Optional[int] - ld_gp_expression: Optional[str] + gp: int | None + ld_gp_expression: str | None @staticmethod - def get_info(elf_path: Path, elf_bytes: bytes) -> Optional[Ps2Elf]: + def get_info(elf_path: Path, elf_bytes: bytes) -> Ps2Elf | None: # Avoid spimdisasm from complaining about unknown sections. spimdisasm.common.GlobalConfig.QUIET = True @@ -81,7 +83,7 @@ def get_info(elf_path: Path, elf_bytes: bytes) -> Optional[Ps2Elf]: gp = elf.reginfo.gpValue else: gp = None - first_small_section_info: Optional[tuple[str, int]] = None + first_small_section_info: tuple[str, int] | None = None first_segment_name = "cod" segs = [FakeSegment(first_segment_name, 0, 0, [])] @@ -91,7 +93,7 @@ def get_info(elf_path: Path, elf_bytes: bytes) -> Optional[Ps2Elf]: elf_section_names: list[tuple[str, bool]] = [] - first_offset: Optional[int] = None + first_offset: int | None = None rom_size = 0 previous_type = Elf32Constants.Elf32SectionHeaderType.PROGBITS diff --git a/src/splat/util/psx/psxexeinfo.py b/src/splat/util/psx/psxexeinfo.py index f6b13451..924da41a 100755 --- a/src/splat/util/psx/psxexeinfo.py +++ b/src/splat/util/psx/psxexeinfo.py @@ -3,12 +3,9 @@ from __future__ import annotations import argparse - +import dataclasses import hashlib import struct - -import dataclasses - from pathlib import Path import rabbitizer @@ -72,22 +69,18 @@ } -def is_valid(insn) -> bool: +def is_valid(insn: rabbitizer.Instruction) -> bool: if not insn.isValid(): - if insn.instrIdType.name in ("CPU_SPECIAL", "CPU_COP2"): - return True - else: - return False + return insn.instrIdType.name in ("CPU_SPECIAL", "CPU_COP2") opcode = insn.getOpcodeName() - if opcode in UNSUPPORTED_OPS: - return False - - return True + return opcode not in UNSUPPORTED_OPS def try_find_text( - rom_bytes, start_offset=PAYLOAD_OFFSET, valid_threshold=32 + rom_bytes: bytes, + start_offset: int = PAYLOAD_OFFSET, + valid_threshold: int = 32, ) -> tuple[int, int]: start = end = 0 good_count = valid_count = 0 @@ -123,7 +116,7 @@ def try_find_text( return (start, end) -def try_get_gp(rom_bytes, start_offset, max_instructions=50) -> int: +def try_get_gp(rom_bytes: bytes, start_offset: int, max_instructions: int = 50) -> int: # $gp is set like this: # /* A7738 800B7138 0E801C3C */ lui $gp, (0x800E0000 >> 16) # /* A773C 800B713C 90409C27 */ addiu $gp, $gp, 0x4090 @@ -142,7 +135,7 @@ def try_get_gp(rom_bytes, start_offset, max_instructions=50) -> int: return gp -def read_word(exe_bytes, offset) -> int: +def read_word(exe_bytes: bytes, offset: int) -> int: return struct.unpack(" PsxExe: ) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Gives information on PSX EXEs") parser.add_argument("exe", help="Path to an PSX EXE") diff --git a/src/splat/util/relocs.py b/src/splat/util/relocs.py index 37a7a01a..99c72310 100644 --- a/src/splat/util/relocs.py +++ b/src/splat/util/relocs.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import Dict import spimdisasm @@ -15,14 +16,14 @@ class Reloc: addend: int = 0 -all_relocs: Dict[int, Reloc] = {} +all_relocs: dict[int, Reloc] = {} -def add_reloc(reloc: Reloc): +def add_reloc(reloc: Reloc) -> None: all_relocs[reloc.rom_address] = reloc -def initialize(): +def initialize() -> None: global all_relocs all_relocs = {} @@ -114,7 +115,7 @@ def initialize(): add_reloc(reloc) -def initialize_spim_context(): +def initialize_spim_context() -> None: for rom_address, reloc in all_relocs.items(): reloc_type = spimdisasm.common.RelocType.fromStr(reloc.reloc_type) diff --git a/src/splat/util/statistics.py b/src/splat/util/statistics.py index 78efd197..74ea8917 100644 --- a/src/splat/util/statistics.py +++ b/src/splat/util/statistics.py @@ -1,40 +1,42 @@ +from __future__ import annotations + from colorama import Fore, Style -from typing import Dict, Optional from . import log def fmt_size(size: int) -> str: if size > 1000000: - return str(size // 1000000) + " MB" - elif size > 1000: - return str(size // 1000) + " KB" - else: - return str(size) + " B" + return f"{size // 1000000} MB" + if size > 1000: + return f"{size // 1000} KB" + return f"{size} B" class Statistics: - def __init__(self): - self.seg_sizes: Dict[str, int] = {} - self.seg_split: Dict[str, int] = {} - self.seg_cached: Dict[str, int] = {} + __slots__ = ("seg_cached", "seg_sizes", "seg_split") + + def __init__(self) -> None: + self.seg_sizes: dict[str, int] = {} + self.seg_split: dict[str, int] = {} + self.seg_cached: dict[str, int] = {} - def add_size(self, typ: str, size: Optional[int]): + def add_size(self, typ: str, size: int | None) -> None: if typ not in self.seg_sizes: self.seg_sizes[typ] = 0 self.seg_sizes[typ] += 0 if size is None else size - def count_split(self, typ: str, count: int = 1): + def count_split(self, typ: str, count: int = 1) -> None: if typ not in self.seg_split: self.seg_split[typ] = 0 self.seg_split[typ] += count - def count_cached(self, typ: str, count: int = 1): + def count_cached(self, typ: str, count: int = 1) -> None: if typ not in self.seg_cached: self.seg_cached[typ] = 0 self.seg_cached[typ] += count - def print_statistics(self, total_size: int): + def print_statistics(self, total_size: int) -> None: unk_size = self.seg_sizes.get("unk", 0) rest_size = 0 diff --git a/src/splat/util/symbols.py b/src/splat/util/symbols.py index cc38bb10..0e047f25 100644 --- a/src/splat/util/symbols.py +++ b/src/splat/util/symbols.py @@ -1,24 +1,26 @@ +from __future__ import annotations + from dataclasses import dataclass import re -from typing import Dict, List, Optional, Set, TYPE_CHECKING +from typing import TYPE_CHECKING import spimdisasm from intervaltree import IntervalTree from ..disassembler import disassembler_instance -from pathlib import Path # circular import if TYPE_CHECKING: + from pathlib import Path from ..segtypes.segment import Segment from . import log, options, progress_bar -all_symbols: List["Symbol"] = [] -all_symbols_dict: Dict[int, List["Symbol"]] = {} +all_symbols: list[Symbol] = [] +all_symbols_dict: dict[int, list[Symbol]] = {} all_symbols_ranges = IntervalTree() -ignored_addresses: Set[int] = set() -to_mark_as_defined: Set[str] = set() +ignored_addresses: set[int] = set() +to_mark_as_defined: set[str] = set() # Initialize a spimdisasm context, used to store symbols and functions spim_context = spimdisasm.common.Context() @@ -52,7 +54,7 @@ def is_falsey(str: str) -> bool: return str.lower() in FALSEY_VALS -def add_symbol(sym: "Symbol"): +def add_symbol(sym: Symbol) -> None: all_symbols.append(sym) if sym.vram_start is not None: if sym.vram_start not in all_symbols_dict: @@ -74,21 +76,21 @@ def to_cname(symbol_name: str) -> str: def handle_sym_addrs( - path: Path, sym_addrs_lines: List[str], all_segments: "List[Segment]" -): - def get_seg_for_name(name: str) -> Optional["Segment"]: + path: Path, sym_addrs_lines: list[str], all_segments: list[Segment] +) -> None: + def get_seg_for_name(name: str) -> Segment | None: for segment in all_segments: if segment.name == name: return segment return None - def get_seg_for_rom(rom: int) -> Optional["Segment"]: + def get_seg_for_rom(rom: int) -> Segment | None: for segment in all_segments: if segment.contains_rom(rom): return segment return None - seen_symbols: Dict[str, "Symbol"] = dict() + seen_symbols: dict[str, Symbol] = dict() prog_bar = progress_bar.get_progress_bar(sym_addrs_lines) prog_bar.set_description(f"Loading symbols ({path.stem})") line: str @@ -332,7 +334,7 @@ def get_seg_for_rom(rom: int) -> Optional["Segment"]: add_symbol(sym) -def initialize(all_segments: "List[Segment]"): +def initialize(all_segments: list[Segment]) -> None: global all_symbols global all_symbols_dict global all_symbols_ranges @@ -344,23 +346,23 @@ def initialize(all_segments: "List[Segment]"): # Manual list of func name / addrs for path in options.opts.symbol_addrs_paths: if path.exists(): - with open(path) as f: + with open(path, encoding="utf-8") as f: sym_addrs_lines = f.readlines() handle_sym_addrs(path, sym_addrs_lines, all_segments) -def initialize_spim_context(all_segments: "List[Segment]") -> None: +def initialize_spim_context(all_segments: list[Segment]) -> None: global_vrom_start = None global_vrom_end = None global_vram_start = options.opts.global_vram_start global_vram_end = options.opts.global_vram_end - overlay_segments: Set[spimdisasm.common.SymbolsSegment] = set() + overlay_segments: set[spimdisasm.common.SymbolsSegment] = set() spim_context.bannedSymbols |= ignored_addresses from ..segtypes.common.code import CommonSegCode - global_segments_after_overlays: List[CommonSegCode] = [] + global_segments_after_overlays: list[CommonSegCode] = [] for segment in all_segments: if not isinstance(segment, CommonSegCode): @@ -495,7 +497,7 @@ def initialize_spim_context(all_segments: "List[Segment]") -> None: def add_symbol_to_spim_segment( - segment: spimdisasm.common.SymbolsSegment, sym: "Symbol" + segment: spimdisasm.common.SymbolsSegment, sym: Symbol ) -> spimdisasm.common.ContextSymbol: if sym.type == "func": context_sym = segment.addFunction( @@ -555,7 +557,7 @@ def add_symbol_to_spim_segment( def add_symbol_to_spim_section( - section: spimdisasm.mips.sections.SectionBase, sym: "Symbol" + section: spimdisasm.mips.sections.SectionBase, sym: Symbol ) -> spimdisasm.common.ContextSymbol: if sym.type == "func": context_sym = section.addFunction( @@ -609,11 +611,11 @@ def add_symbol_to_spim_section( # force_in_segment=True when the symbol belongs to this specific segment. # force_in_segment=False when this symbol is just a reference. def create_symbol_from_spim_symbol( - segment: "Segment", + segment: Segment, context_sym: spimdisasm.common.ContextSymbol, *, force_in_segment: bool, -) -> "Symbol": +) -> Symbol: in_segment = False sym_type = None @@ -663,7 +665,7 @@ def create_symbol_from_spim_symbol( return sym -def mark_c_funcs_as_defined(): +def mark_c_funcs_as_defined() -> None: for symbol in all_symbols: if len(to_mark_as_defined) == 0: return @@ -677,12 +679,12 @@ def mark_c_funcs_as_defined(): class Symbol: vram_start: int - given_name: Optional[str] = None - given_name_end: Optional[str] = None - rom: Optional[int] = None - type: Optional[str] = None - given_size: Optional[int] = None - segment: Optional["Segment"] = None + given_name: str | None = None + given_name_end: str | None = None + rom: int | None = None + type: str | None = None + given_size: int | None = None + segment: Segment | None = None defined: bool = False referenced: bool = False @@ -691,29 +693,29 @@ class Symbol: force_migration: bool = False force_not_migration: bool = False - function_owner: Optional[str] = None + function_owner: str | None = None allow_addend: bool = False dont_allow_addend: bool = False - can_reference: Optional[bool] = None - can_be_referenced: Optional[bool] = None + can_reference: bool | None = None + can_be_referenced: bool | None = None - linker_section: Optional[str] = None + linker_section: str | None = None allow_duplicated: bool = False - given_filename: Optional[str] = None - given_visibility: Optional[str] = None + given_filename: str | None = None + given_visibility: str | None = None - given_align: Optional[int] = None + given_align: int | None = None - use_non_matching_label: Optional[bool] = None + use_non_matching_label: bool | None = None - _generated_default_name: Optional[str] = None - _last_type: Optional[str] = None + _generated_default_name: str | None = None + _last_type: str | None = None - def __str__(self): + def __str__(self) -> str: return self.name def __eq__(self, other: object) -> bool: @@ -722,7 +724,7 @@ def __eq__(self, other: object) -> bool: return self.vram_start == other.vram_start and self.segment == other.segment # https://stackoverflow.com/a/56915493/6292472 - def __hash__(self): + def __hash__(self) -> int: return hash((self.vram_start, self.segment)) def format_name(self, format: str) -> str: @@ -777,11 +779,11 @@ def default_name(self) -> str: return self._generated_default_name @property - def rom_end(self): + def rom_end(self) -> int | None: return None if not self.rom else self.rom + self.size @property - def vram_end(self): + def vram_end(self) -> int: return self.vram_start + self.size @property @@ -800,19 +802,21 @@ def filename(self) -> str: return self.given_filename return self.name - def contains_vram(self, offset): + def contains_vram(self, offset: int) -> bool: return offset >= self.vram_start and offset < self.vram_end - def contains_rom(self, offset): + def contains_rom(self, offset: int) -> bool: + if self.rom is None or self.rom_end is None: + return False return offset >= self.rom and offset < self.rom_end -def get_all_symbols(): +def get_all_symbols() -> list[Symbol]: global all_symbols return all_symbols -def reset_symbols(): +def reset_symbols() -> None: global all_symbols global all_symbols_dict global all_symbols_ranges diff --git a/src/splat/util/utils.py b/src/splat/util/utils.py index a311dd3b..ec0656c7 100644 --- a/src/splat/util/utils.py +++ b/src/splat/util/utils.py @@ -1,9 +1,11 @@ -from typing import List, Optional, TypeVar +from __future__ import annotations + +from typing import TypeVar T = TypeVar("T") -def list_index(the_list: List[T], value: T) -> Optional[int]: +def list_index(the_list: list[T], value: T) -> int | None: if value not in the_list: return None return the_list.index(value) diff --git a/src/splat/util/vram_classes.py b/src/splat/util/vram_classes.py index b28c6d7f..b6ddb5f5 100644 --- a/src/splat/util/vram_classes.py +++ b/src/splat/util/vram_classes.py @@ -1,30 +1,70 @@ +from __future__ import annotations + from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, TypedDict from . import log +if TYPE_CHECKING: + from typing_extensions import NotRequired + @dataclass(frozen=True) class VramClass: name: str vram: int - given_vram_symbol: Optional[str] = None - follows_classes: List[str] = field(default_factory=list, compare=False) + given_vram_symbol: str | None = None + follows_classes: list[str] = field(default_factory=list, compare=False) @property - def vram_symbol(self) -> Optional[str]: + def vram_symbol(self) -> str | None: if self.given_vram_symbol is not None: return self.given_vram_symbol - elif self.follows_classes: + if self.follows_classes: return self.name + "_CLASS_VRAM" - else: - return None + return None -_vram_classes: Dict[str, VramClass] = {} +_vram_classes: dict[str, VramClass] = {} -def initialize(yaml: Any): +class SerializedSegmentData(TypedDict): + name: NotRequired[str] + vram: int + vram_symbol: str | None + follows_classes: list[str] + vram_class: NotRequired[str] + follows_vram: NotRequired[str | None] + align: NotRequired[str] + subalign: NotRequired[str] + section_order: NotRequired[list[str]] + start: NotRequired[str] + type: NotRequired[str] + dir: NotRequired[str] + symbol_name_format: NotRequired[str] + symbol_name_format_no_rom: NotRequired[str] + path: NotRequired[str] + bss_contains_common: NotRequired[bool] + linker_section_order: NotRequired[str] + linker_section: NotRequired[str] + ld_fill_value: NotRequired[int] + ld_align_segment_start: NotRequired[int] + pair_segment: NotRequired[str] + exclusive_ram_id: NotRequired[str] + find_file_boundaries: NotRequired[bool] + size: NotRequired[int] + global_id: NotRequired[str] + length: NotRequired[int] + in_segment: NotRequired[bool] + data_only: NotRequired[bool] + bss_size: NotRequired[int] + str_encoding: NotRequired[str] + detect_redundant_function_end: NotRequired[bool] + width: NotRequired[int] + height: NotRequired[int] + + +def initialize(yaml: list[SerializedSegmentData | list[str]] | None) -> None: global _vram_classes _vram_classes = {} @@ -47,8 +87,8 @@ def initialize(yaml: Any): for vram_class in yaml: name: str vram: int - vram_symbol: Optional[str] = None - follows_classes: List[str] = [] + vram_symbol: str | None = None + follows_classes: list[str] = [] if isinstance(vram_class, dict): if "name" not in vram_class: @@ -83,7 +123,7 @@ def initialize(yaml: Any): f"vram_class ({vram_class}) must have 2 elements, got {len(vram_class)}" ) name = vram_class[0] - vram = vram_class[1] + vram = int(vram_class[1]) else: log.error(f"vram_class must be a dict or list, got {type(vram_class)}") diff --git a/test.py b/test.py index d8dfc9b0..a54d5464 100755 --- a/test.py +++ b/test.py @@ -2,11 +2,9 @@ import difflib import filecmp -import io from pathlib import Path import spimdisasm import unittest -from typing import List, Tuple from src.splat import __version__ from src.splat.disassembler import disassembler_instance @@ -21,17 +19,17 @@ class Testing(unittest.TestCase): def compare_files(self, test_path, ref_path): - with io.open(test_path) as test_f, io.open(ref_path) as ref_f: + with open(test_path) as test_f, open(ref_path) as ref_f: self.assertListEqual(list(test_f), list(ref_f)) - def get_same_files(self, dcmp: filecmp.dircmp, out: List[Tuple[str, str, str]]): + def get_same_files(self, dcmp: filecmp.dircmp, out: list[tuple[str, str, str]]): for name in dcmp.same_files: out.append((name, dcmp.left, dcmp.right)) for sub_dcmp in dcmp.subdirs.values(): self.get_same_files(sub_dcmp, out) - def get_diff_files(self, dcmp: filecmp.dircmp, out: List[Tuple[str, str, str]]): + def get_diff_files(self, dcmp: filecmp.dircmp, out: list[tuple[str, str, str]]): for name in dcmp.diff_files: out.append((name, dcmp.left, dcmp.right)) @@ -39,7 +37,7 @@ def get_diff_files(self, dcmp: filecmp.dircmp, out: List[Tuple[str, str, str]]): self.get_diff_files(sub_dcmp, out) def get_left_only_files( - self, dcmp: filecmp.dircmp, out: List[Tuple[str, str, str]] + self, dcmp: filecmp.dircmp, out: list[tuple[str, str, str]] ): for name in dcmp.left_only: out.append((name, dcmp.left, dcmp.right)) @@ -48,7 +46,7 @@ def get_left_only_files( self.get_left_only_files(sub_dcmp, out) def get_right_only_files( - self, dcmp: filecmp.dircmp, out: List[Tuple[str, str, str]] + self, dcmp: filecmp.dircmp, out: list[tuple[str, str, str]] ): for name in dcmp.right_only: out.append((name, dcmp.left, dcmp.right)) @@ -64,16 +62,16 @@ def test_basic_app(self): "test/basic_app/split", "test/basic_app/expected", [".gitkeep"] ) - diff_files: List[Tuple[str, str, str]] = [] + diff_files: list[tuple[str, str, str]] = [] self.get_diff_files(comparison, diff_files) - same_files: List[Tuple[str, str, str]] = [] + same_files: list[tuple[str, str, str]] = [] self.get_same_files(comparison, same_files) - left_only_files: List[Tuple[str, str, str]] = [] + left_only_files: list[tuple[str, str, str]] = [] self.get_left_only_files(comparison, left_only_files) - right_only_files: List[Tuple[str, str, str]] = [] + right_only_files: list[tuple[str, str, str]] = [] self.get_right_only_files(comparison, right_only_files) print("same_files", same_files) @@ -470,7 +468,7 @@ def test_overlay(self): ], } - all_segments: List["Segment"] = [ + all_segments: list[Segment] = [ CommonSegCode( rom_start=0x1000, rom_end=0x1140, @@ -512,7 +510,7 @@ def test_global(self): ], } - all_segments: List["Segment"] = [ + all_segments: list[Segment] = [ CommonSegCode( rom_start=0x1000, rom_end=0x1140,