基于生成对抗网络(GAN)的图像修复系统,支持多种掩码类型和先进的网络架构。
本项目实现了基于GAN的图像修复(Image Inpainting)功能,能够智能填补图像中的缺损区域。项目支持两种生成器架构:
- 原始GAN生成器:基于U-Net的编码器-解码器结构
- FFC生成器:采用Fast Fourier Convolution的先进架构,支持预训练权重
✅ 多种掩码类型
rect: 规则矩形遮挡(适合测试)random: 随机小矩形遮挡qd_imd: QD-IMD数据集的真实不规则掩码(最贴近实际应用)
✅ 训练技术
- 感知损失 (Perceptual Loss):基于VGG16提取语义特征
- 谱归一化 (Spectral Normalization):提升训练稳定性
- 实例归一化 (Instance Normalization):适应小批量训练
- 学习率调度器 (Cosine Annealing):平滑优化过程
- 混合精度训练 (AMP):加速训练并节省显存
✅ 完善的评估体系
- FID (Fréchet Inception Distance):衡量生成图像与真实图像的分布差异
- SSIM (Structural Similarity Index):评估结构一致性
- 感知损失 (Perceptual Loss):度量高层语义相似度
GAN/
├── final.py # 主训练脚本(支持命令行参数配置)
├── model.py # 原始GAN生成器/判别器定义
├── dataset.py # 数据集加载与掩码生成
├── RepairTrainer.py # 训练器类(封装训练逻辑)
├── ImageMetrics.py # 评估指标计算(FID/SSIM/Perceptual Loss)
├── train_examples.sh # 训练命令示例(非核心)
├── verify_system.py # 系统环境验证脚本(非核心)
├── test.py # 模型测试脚本(非核心)
├── saicinpainting/ # FFC模块实现
│ ├── training/modules/ffc.py # Fast Fourier Convolution实现
│ └── ...
├── output/ # 训练输出目录(自动创建)
│ └── run_YYYYMMDD_HHMMSS/ # 每次运行的专属目录
│ ├── training.log # 训练日志
│ ├── metrics.png # 损失/FID/SSIM曲线图
│ ├── checkpoint_epochX.pth # 检查点模型
│ └── gan_repair_optimized.pth # 最终模型
├── output/ # 训练输出目录(自动创建)
│ └── ffc_prepretrain
│ ├── training.log # 训练日志
│ ├── metrics.png # 损失/FID/SSIM曲线图
│ ├── checkpoint_epochX.pth # 检查点模型
│ └── gan_repair_optimized.pth # 最终模型
└── README.md # 本文件
验证环境
python verify_system.py下载CelebA数据集
- 访问 CelebA官网
- 下载
img_align_celeba.zip(约1.4GB,包含202,599张对齐的人脸图像) - 解压到
data/img_align_celeba/
下载QD-IMD掩码数据集(可选,用于真实不规则遮挡)
- 访问 QD-IMD GitHub
- 下载训练集和测试集掩码
- 解压到
qd_imd/train/和qd_imd/test/
基础训练(使用原始GAN生成器)
python final.py \
--image_folder "data/img_align_celeba" \
--batch_size 16 \
--epochs 20 \
--lr 0.0002 \
--mask_type "random" \
--mask_ratio 0.15 \
--output_dir "./output"使用QD-IMD真实掩码训练
python final.py \
--image_folder "data/img_align_celeba" \
--mask_type "qd_imd" \
--qd_imd_path "qd_imd/train" \
--batch_size 16 \
--epochs 20 \
--output_dir "./output"使用FFC生成器(支持预训练权重)
python final.py \
--image_folder "data/img_align_celeba" \
--generator_type "ffc" \
--pretrained_weights "generator_weights.pth" \
--mask_type "qd_imd" \
--batch_size 12 \
--epochs 10 \
--output_dir "./training_output"python test.py \
--model_path "output/run_20251015_120000/gan_repair_optimized.pth" \
--test_image "path/to/test/image.jpg" \
--mask_type "qd_imd"| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--image_folder |
str | - | CelebA数据集路径 |
--mask_type |
str | qd_imd |
掩码类型:rect/random/qd_imd |
--qd_imd_path |
str | qd_imd/train |
QD-IMD掩码路径(仅qd_imd模式) |
--mask_ratio |
float | 0.15 | 遮挡占比(0-1) |
--use_augmentation |
bool | True | 是否启用数据增强(水平翻转) |
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--batch_size |
int | 16 | 批次大小(根据显存调整) |
--epochs |
int | 10 | 训练轮数 |
--train_batches_per_epoch |
int | 4000 | 每epoch训练批次数 |
--lr |
float | 0.0002 | 学习率 |
--beta1 |
float | 0.5 | Adam优化器beta1参数 |
--l1_weight |
float | 100.0 | L1损失权重 |
--perceptual_weight |
float | 10.0 | 感知损失权重 |
--n_critic |
int | 5 | 判别器相对生成器的训练次数 |
--save_interval |
int | 1 | 模型保存间隔(epoch) |
--disable_amp |
flag | - | 禁用混合精度训练 |
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--generator_type |
str | original |
生成器类型:original/ffc |
--pretrained_weights |
str | None | 预训练权重路径(仅ffc模式) |
--freeze_generator |
flag | - | 冻结生成器参数(仅微调判别器) |
--num_filters |
int | 64 | 基础滤波器数量 |
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--output_dir |
str | ./training_output |
输出根目录 |
--run_name |
str | run_时间戳 |
本次运行名称(用于创建子目录) |
--log_interval |
int | 10 | 日志打印间隔(batch) |
训练过程中会实时打印:
Epoch 1/20 [Batch 100/4000] D_loss: 0.4523 | G_loss: 2.1345 | Perceptual: 0.0856
训练结束后生成 metrics.png,包含6个子图:
- Discriminator Loss:判别器损失(real/fake)
- Generator Loss:生成器损失(total/gan/l1/perceptual)
- FID Curve:FID值变化(越低越好,优秀模型<50)
- SSIM Curve:SSIM值变化(越高越好,理想值接近1)
- Generator Learning Rate:生成器学习率衰减曲线
- Discriminator Learning Rate:判别器学习率衰减曲线
checkpoint_epoch{X}.pth:每个epoch保存的检查点gan_repair_optimized.pth:最终训练完成的模型
原始GAN生成器 (model.py)
- 架构:U-Net(编码器-解码器 + 跳跃连接)
- 输入:(batch, 3, 218, 178) - 缺损图像
- 输出:(batch, 3, 218, 178) - 修复图像
- 特点:Instance Normalization + LeakyReLU
FFC生成器 (saicinpainting/training/modules/ffc.py)
- 架构:FFCResNetGenerator(Fast Fourier Convolution + ResNet)
- 输入:(batch, 4, 218, 178) - RGB + Mask通道
- 输出:(batch, 3, 218, 178) - 修复图像
- 特点:全局感受野 + 预训练权重支持
- 架构:PatchGAN(局部真实性判断)
- 输入:(batch, 3, 218, 178)
- 输出:(batch, 1, 12, 10) - 局部真实性得分
- 特点:Spectral Normalization(提升训练稳定性)
QD-IMD模式 (dataset.py)
- 加载真实不规则掩码图像
- 支持随机旋转、缩放增强
- 二值化处理(1=保留,0=遮挡)
算法生成模式
rect: 单一大矩形遮挡random: 随机x个小矩形遮挡
FID (Fréchet Inception Distance)
- 使用InceptionV3提取2048维特征
- 计算真实图像与生成图像的分布距离
- 范围:[0, +∞),越小越好(优秀模型<50)
SSIM (Structural Similarity Index)
- 评估亮度、对比度、结构三方面相似度
- 范围:[0, 1],越接近1越好
Perceptual Loss(感知损失)
- 使用VGG16提取relu3_3层特征
- 度量高层语义差异(比像素级L1更符合人类感知)
| 掩码类型 | 难度 | 训练速度 | 实际应用 |
|---|---|---|---|
| rect | 简单 | 快 | 适合测试模型基本能力 |
| random | 中等 | 中 | 模拟噪声干扰 |
| qd_imd | 困难 | 慢 | 最贴近真实场景(推荐) |