Skip to content

littleotherut/GAN

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

GAN图像修复项目

基于生成对抗网络(GAN)的图像修复系统,支持多种掩码类型和先进的网络架构。


📋 项目简介

本项目实现了基于GAN的图像修复(Image Inpainting)功能,能够智能填补图像中的缺损区域。项目支持两种生成器架构:

  • 原始GAN生成器:基于U-Net的编码器-解码器结构
  • FFC生成器:采用Fast Fourier Convolution的先进架构,支持预训练权重

主要特性

多种掩码类型

  • rect: 规则矩形遮挡(适合测试)
  • random: 随机小矩形遮挡
  • qd_imd: QD-IMD数据集的真实不规则掩码(最贴近实际应用)

训练技术

  • 感知损失 (Perceptual Loss):基于VGG16提取语义特征
  • 谱归一化 (Spectral Normalization):提升训练稳定性
  • 实例归一化 (Instance Normalization):适应小批量训练
  • 学习率调度器 (Cosine Annealing):平滑优化过程
  • 混合精度训练 (AMP):加速训练并节省显存

完善的评估体系

  • FID (Fréchet Inception Distance):衡量生成图像与真实图像的分布差异
  • SSIM (Structural Similarity Index):评估结构一致性
  • 感知损失 (Perceptual Loss):度量高层语义相似度

🏗️ 项目结构(省略数据集部分)

GAN/
├── final.py                    # 主训练脚本(支持命令行参数配置)
├── model.py                    # 原始GAN生成器/判别器定义
├── dataset.py                  # 数据集加载与掩码生成
├── RepairTrainer.py            # 训练器类(封装训练逻辑)
├── ImageMetrics.py             # 评估指标计算(FID/SSIM/Perceptual Loss)
├── train_examples.sh           # 训练命令示例(非核心)
├── verify_system.py            # 系统环境验证脚本(非核心)
├── test.py                     # 模型测试脚本(非核心)
├── saicinpainting/             # FFC模块实现
│   ├── training/modules/ffc.py # Fast Fourier Convolution实现
│   └── ...
├── output/                     # 训练输出目录(自动创建)
│   └── run_YYYYMMDD_HHMMSS/   # 每次运行的专属目录
│       ├── training.log        # 训练日志
│       ├── metrics.png         # 损失/FID/SSIM曲线图
│       ├── checkpoint_epochX.pth  # 检查点模型
│       └── gan_repair_optimized.pth  # 最终模型
├── output/                     # 训练输出目录(自动创建)
│   └── ffc_prepretrain
│       ├── training.log        # 训练日志
│       ├── metrics.png         # 损失/FID/SSIM曲线图
│       ├── checkpoint_epochX.pth  # 检查点模型
│       └── gan_repair_optimized.pth  # 最终模型
└── README.md                   # 本文件

🚀 快速开始

1. 环境验证

验证环境

python verify_system.py

2. 数据准备

下载CelebA数据集

  1. 访问 CelebA官网
  2. 下载 img_align_celeba.zip(约1.4GB,包含202,599张对齐的人脸图像)
  3. 解压到 data/img_align_celeba/

下载QD-IMD掩码数据集(可选,用于真实不规则遮挡)

  1. 访问 QD-IMD GitHub
  2. 下载训练集和测试集掩码
  3. 解压到 qd_imd/train/qd_imd/test/

3. 训练模型

基础训练(使用原始GAN生成器)

python final.py \
    --image_folder "data/img_align_celeba" \
    --batch_size 16 \
    --epochs 20 \
    --lr 0.0002 \
    --mask_type "random" \
    --mask_ratio 0.15 \
    --output_dir "./output"

使用QD-IMD真实掩码训练

python final.py \
    --image_folder "data/img_align_celeba" \
    --mask_type "qd_imd" \
    --qd_imd_path "qd_imd/train" \
    --batch_size 16 \
    --epochs 20 \
    --output_dir "./output"

使用FFC生成器(支持预训练权重)

python final.py \
    --image_folder "data/img_align_celeba" \
    --generator_type "ffc" \
    --pretrained_weights "generator_weights.pth" \
    --mask_type "qd_imd" \
    --batch_size 12 \
    --epochs 10 \
    --output_dir "./training_output"

4. 测试模型

python test.py \
    --model_path "output/run_20251015_120000/gan_repair_optimized.pth" \
    --test_image "path/to/test/image.jpg" \
    --mask_type "qd_imd"

📊 命令行参数详解

数据相关

参数 类型 默认值 说明
--image_folder str - CelebA数据集路径
--mask_type str qd_imd 掩码类型:rect/random/qd_imd
--qd_imd_path str qd_imd/train QD-IMD掩码路径(仅qd_imd模式)
--mask_ratio float 0.15 遮挡占比(0-1)
--use_augmentation bool True 是否启用数据增强(水平翻转)

训练相关

参数 类型 默认值 说明
--batch_size int 16 批次大小(根据显存调整)
--epochs int 10 训练轮数
--train_batches_per_epoch int 4000 每epoch训练批次数
--lr float 0.0002 学习率
--beta1 float 0.5 Adam优化器beta1参数
--l1_weight float 100.0 L1损失权重
--perceptual_weight float 10.0 感知损失权重
--n_critic int 5 判别器相对生成器的训练次数
--save_interval int 1 模型保存间隔(epoch)
--disable_amp flag - 禁用混合精度训练

模型相关

参数 类型 默认值 说明
--generator_type str original 生成器类型:original/ffc
--pretrained_weights str None 预训练权重路径(仅ffc模式)
--freeze_generator flag - 冻结生成器参数(仅微调判别器)
--num_filters int 64 基础滤波器数量

输出相关

参数 类型 默认值 说明
--output_dir str ./training_output 输出根目录
--run_name str run_时间戳 本次运行名称(用于创建子目录)
--log_interval int 10 日志打印间隔(batch)

📈 训练监控

实时日志

训练过程中会实时打印:

Epoch 1/20 [Batch 100/4000] D_loss: 0.4523 | G_loss: 2.1345 | Perceptual: 0.0856

指标曲线

训练结束后生成 metrics.png,包含6个子图:

  1. Discriminator Loss:判别器损失(real/fake)
  2. Generator Loss:生成器损失(total/gan/l1/perceptual)
  3. FID Curve:FID值变化(越低越好,优秀模型<50)
  4. SSIM Curve:SSIM值变化(越高越好,理想值接近1)
  5. Generator Learning Rate:生成器学习率衰减曲线
  6. Discriminator Learning Rate:判别器学习率衰减曲线

模型检查点

  • checkpoint_epoch{X}.pth:每个epoch保存的检查点
  • gan_repair_optimized.pth:最终训练完成的模型

🧪 核心模块说明

1. Generator(生成器)

原始GAN生成器 (model.py)

  • 架构:U-Net(编码器-解码器 + 跳跃连接)
  • 输入:(batch, 3, 218, 178) - 缺损图像
  • 输出:(batch, 3, 218, 178) - 修复图像
  • 特点:Instance Normalization + LeakyReLU

FFC生成器 (saicinpainting/training/modules/ffc.py)

  • 架构:FFCResNetGenerator(Fast Fourier Convolution + ResNet)
  • 输入:(batch, 4, 218, 178) - RGB + Mask通道
  • 输出:(batch, 3, 218, 178) - 修复图像
  • 特点:全局感受野 + 预训练权重支持

2. Discriminator(判别器)

  • 架构:PatchGAN(局部真实性判断)
  • 输入:(batch, 3, 218, 178)
  • 输出:(batch, 1, 12, 10) - 局部真实性得分
  • 特点:Spectral Normalization(提升训练稳定性)

3. MaskGenerator(掩码生成器)

QD-IMD模式 (dataset.py)

  • 加载真实不规则掩码图像
  • 支持随机旋转、缩放增强
  • 二值化处理(1=保留,0=遮挡)

算法生成模式

  • rect: 单一大矩形遮挡
  • random: 随机x个小矩形遮挡

4. ImageMetrics(评估指标)

FID (Fréchet Inception Distance)

  • 使用InceptionV3提取2048维特征
  • 计算真实图像与生成图像的分布距离
  • 范围:[0, +∞),越小越好(优秀模型<50)

SSIM (Structural Similarity Index)

  • 评估亮度、对比度、结构三方面相似度
  • 范围:[0, 1],越接近1越好

Perceptual Loss(感知损失)

  • 使用VGG16提取relu3_3层特征
  • 度量高层语义差异(比像素级L1更符合人类感知)

掩码类型选择

掩码类型 难度 训练速度 实际应用
rect 简单 适合测试模型基本能力
random 中等 模拟噪声干扰
qd_imd 困难 最贴近真实场景(推荐)

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors