Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ __pycache__/
*.pyo
*.pyd
.Python
*.egg-info/
env/
venv/
.venv/
uv.lock

# Environment Variables
.env
Expand Down
6 changes: 3 additions & 3 deletions config/config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
sam3:
checkpoint_path: "models/sam3_ms/sam3.pt"
bpe_path: "models/bpe_simple_vocab_16e6.txt.gz"
# device: "cuda" | "cpu" | leave empty for auto (cuda if available)
# Set "cpu" if GPU does not match current PyTorch
# device: "cpu"
# device: "cuda" | "mps" | "cpu" | leave empty for auto (cuda > mps > cpu)
# Apple Silicon: use "mps" (requires PYTORCH_ENABLE_MPS_FALLBACK=1)
# device: "mps"
score_threshold: 0.5
epsilon_factor: 0.02
min_area: 100
Expand Down
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

# Skip PaddleX model host connectivity check to avoid startup delay
os.environ.setdefault("PADDLE_PDX_DISABLE_MODEL_SOURCE_CHECK", "True")
# MPS (Apple Silicon) lacks a few ops; let PyTorch fall back to CPU for those.
# Must be set before `import torch`.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# Suppress requests urllib3/chardet version warning
warnings.filterwarnings("ignore", message=".*doesn't match a supported version.*")

Expand Down
9 changes: 8 additions & 1 deletion modules/sam3_info_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,14 @@ def __init__(self, checkpoint_path: str, bpe_path: str, device: str = None):
super().__init__()
self.checkpoint_path = checkpoint_path
self.bpe_path = bpe_path
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if device:
self.device = device
elif torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self._processor = None

# Image state cache
Expand Down
58 changes: 58 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
[project]
name = "edit-banana"
version = "0.1.0"
description = "Universal Content Re-Editor: Make the Uneditable, Editable"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.10"
dependencies = [
"pyyaml",
"opencv-python-headless",
"numpy",
"Pillow",
"scikit-image",
"requests",
"fastapi",
"uvicorn[standard]",
"pytesseract",
]

[project.optional-dependencies]
paddleocr = [
"paddlepaddle==3.2.2",
"paddleocr",
]
formula = [
"pix2text",
"onnxruntime",
]
rmbg = [
"onnxruntime",
]
torch = [
"torch>=2.0",
"torchvision",
]
sam3 = [
"edit-banana[torch]",
"sam3",
"einops",
"pycocotools",
"psutil",
]

[project.scripts]
edit-banana = "main:main"

[tool.setuptools.packages.find]
include = ["modules*"]

[tool.uv]
dev-dependencies = []

[tool.uv.sources]
sam3 = { path = "sam3_src", editable = true }
# On macOS Apple Silicon, use default PyPI torch (includes MPS support).
# For Linux+CUDA, override via:
# torch = { index = "pytorch-cu121" }
# torchvision = { index = "pytorch-cu121" }