Skip to content

add ssim psnr metric #1282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions ppdiffusers/scripts/ssim_psnr_score/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SSIM and PSNR

SSIM(Structural Similarity Index)是一种用于衡量两幅图像结构相似度的指标,常用于图像质量评价任务。与像素级别的误差不同,SSIM 模拟人类视觉系统从亮度、对比度和结构等多个维度来评估图像之间的差异。其取值范围为 [-1, 1],其中 1 表示两张图像完全相同,值越高说明图像质量越接近参考图像。

PSNR(Peak Signal-to-Noise Ratio)是一种基于像素差异的图像质量评估指标,用于衡量图像压缩或生成后与参考图像之间的误差大小。它通过最大像素值与均方误差(MSE)之间的比值计算得出,通常以分贝(dB)为单位。PSNR 值越高表示重建图像与原图越接近,图像失真越小。对于 8-bit 图像,PSNR 高于 30dB 通常被认为质量良好。



## 依赖
- math(Python 标准库,无需额外安装)
- numpy==1.26.4
- opencv-python(导入名为 `cv2`)

## 快速使用
计算两个图片数据集的SSIM与PSNR, `path/to/dataset1`/`path/to/dataset2`为图片文件夹

```
python evaluation.py --dataset1 path/to/dataset1 --dataset2 path/to/dataset2
```
图片数据集的结构应如下:
```shell
dataset
├── 00000.png
├── 00001.png
├── ......
└── 00999.png

```

参数说明
- `num_workers`:用于加载数据的子进程个数,默认为 `1`。
- `resolution`:将图片调整到的分辨率,默认为 `None`。


## 参考
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里参考的内容是不是不对

- [https://github.com/ali-vilab/TeaCache](https://github.com/ali-vilab/TeaCache)
94 changes: 94 additions & 0 deletions ppdiffusers/scripts/ssim_psnr_score/calculate_psnr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文件头部都加下paddle的 copyright吧


import numpy as np
import paddle


def img_psnr(img1, img2, data_range=1.0):
    """Compute the peak signal-to-noise ratio (PSNR, in dB) between two images.

    Args:
        img1: First image as a numpy array (any shape).
        img2: Second image; must be broadcast-compatible with ``img1``.
        data_range: Largest possible pixel value. The default of ``1.0``
            keeps the original behaviour for images scaled to ``[0, 1]``;
            pass ``255`` for 8-bit images.

    Returns:
        The PSNR in decibels, or ``100.0`` when the mean squared error is
        below ``1e-10`` (the true PSNR of identical images is infinite).
    """
    # Dividing by 1.0 promotes integer inputs to float, so the subtraction
    # cannot wrap around for unsigned dtypes.
    mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
    # Practically identical images: return a finite cap instead of dividing
    # by zero.
    if mse < 1e-10:
        return 100.0
    return 20 * math.log10(data_range / math.sqrt(mse))




def calculate_psnr(videos1, videos2):
    """Compute per-timestamp PSNR statistics over a batch of videos.

    Args:
        videos1: Batch of videos laid out as
            ``[batch_size, timestamps, channel, h, w]``. Every frame
            (``videos1[i][t]``) must provide a ``.numpy()`` method.
        videos2: Batch with exactly the same shape as ``videos1``.

    Returns:
        A dict with the keys
        ``"value"`` (``{timestamp: mean PSNR over the batch}``),
        ``"value_std"`` (``{timestamp: std of PSNR over the batch}``),
        ``"video_setting"`` (shape of a single video) and
        ``"video_setting_name"`` (description of that shape's axes).

    Raises:
        AssertionError: If the two batches differ in shape.
        ValueError: If the batch contains no videos.
    """
    assert videos1.shape == videos2.shape

    num_videos = videos1.shape[0]
    if num_videos == 0:
        # Fail with a clear message instead of an UnboundLocalError further
        # down; there is nothing to average over.
        raise ValueError("calculate_psnr requires at least one video.")

    # Every video in the batch shares this shape: [timestamps, channel, h, w].
    video_shape = videos1[0].shape
    num_timestamps = len(videos1[0])

    per_video_scores = []
    for video_num in range(num_videos):
        video1 = videos1[video_num]
        video2 = videos2[video_num]
        # One PSNR value per frame pair; each frame is [channel, h, w].
        per_video_scores.append(
            [
                img_psnr(video1[clip_timestamp].numpy(), video2[clip_timestamp].numpy())
                for clip_timestamp in range(num_timestamps)
            ]
        )

    # Shape: [batch_size, timestamps].
    psnr_results = np.array(per_video_scores)

    psnr = {}
    psnr_std = {}
    for clip_timestamp in range(num_timestamps):
        psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
        psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])

    return {
        "value": psnr,
        "value_std": psnr_std,
        "video_setting": video_shape,
        "video_setting_name": "time, channel, heigth, width",
    }


def main():
    """Smoke-test ``calculate_psnr`` on all-zero dummy videos and print the result as JSON."""
    import json

    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64

    # Pick the device before creating any tensor, and fall back to CPU when
    # this Paddle build has no CUDA support so the demo also runs on
    # CPU-only machines.
    paddle.set_device("gpu" if paddle.device.is_compiled_with_cuda() else "cpu")

    videos1 = paddle.zeros(shape=[NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE])
    videos2 = paddle.zeros(shape=[NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE])

    result = calculate_psnr(videos1, videos2)
    print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()
126 changes: 126 additions & 0 deletions ppdiffusers/scripts/ssim_psnr_score/calculate_ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np
import paddle


def ssim(img1, img2):
    """Return the mean structural similarity (SSIM) of two 2-D images.

    Local means, variances and the covariance are estimated with an 11x11
    Gaussian window (sigma 1.5). A 5-pixel border is cropped from every
    filtered map before the SSIM map is averaged.

    Args:
        img1: First image as a 2-D numpy array.
        img2: Second image, same shape as ``img1``.

    Returns:
        The mean of the SSIM map as a numpy float.
    """
    # Stabilising constants; NOTE(review): these look like (K * L)**2 with
    # L = 1, i.e. inputs scaled to [0, 1] — confirm against callers.
    c1, c2 = 0.01**2, 0.03**2

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    gauss_2d = np.outer(gauss_1d, gauss_1d.transpose())

    def windowed(image):
        # Gaussian-weighted local average, cropped by 5 pixels per side.
        return cv2.filter2D(image, -1, gauss_2d)[5:-5, 5:-5]

    mean_first = windowed(first)
    mean_second = windowed(second)
    mean_first_sq = mean_first**2
    mean_second_sq = mean_second**2
    mean_product = mean_first * mean_second

    var_first = windowed(first**2) - mean_first_sq
    var_second = windowed(second**2) - mean_second_sq
    covariance = windowed(first * second) - mean_product

    numerator = (2 * mean_product + c1) * (2 * covariance + c2)
    denominator = (mean_first_sq + mean_second_sq + c1) * (var_first + var_second + c2)
    return (numerator / denominator).mean()


def calculate_ssim_function(img1, img2):
    """Compute SSIM for a pair of grayscale or channel-first colour images.

    SSIM is the only metric here that is extremely sensitive to gray being
    compared to black/white. Inputs are expected in ``[0, 1]`` according to
    the original author's note.

    Args:
        img1: numpy array of shape ``[h, w]``, ``[1, h, w]`` or ``[3, h, w]``.
        img2: numpy array with exactly the same shape as ``img1``.

    Returns:
        The SSIM value; for 3-channel input, the mean of the per-channel
        SSIM values.

    Raises:
        ValueError: If the shapes differ, or if the input is neither 2-D nor
            a 3-D array whose first axis has size 1 or 3.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same dimensions.")

    if img1.ndim == 2:
        return ssim(img1, img2)

    if img1.ndim == 3:
        channels = img1.shape[0]
        if channels == 3:
            return np.array([ssim(img1[i], img2[i]) for i in range(3)]).mean()
        if channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))

    # Reached for any other rank and for 3-D inputs with an unsupported
    # channel count, so the function never silently returns None.
    raise ValueError("Wrong input image dimensions.")


def trans(x):
    """Return ``x`` unchanged.

    Identity hook applied to the video batches inside ``calculate_ssim``.
    """
    # NOTE(review): a reviewer asked what this function is for; as written
    # it performs no transformation.
    return x


def calculate_ssim(videos1, videos2):
    """Compute per-timestamp SSIM statistics over a batch of videos.

    Args:
        videos1: Batch of videos laid out as
            ``[batch_size, timestamps, channel, h, w]``. Every frame
            (``videos1[i][t]``) must provide a ``.numpy()`` method.
        videos2: Batch with exactly the same shape as ``videos1``.

    Returns:
        A dict with the keys
        ``"value"`` (``{timestamp: mean SSIM over the batch}``),
        ``"value_std"`` (``{timestamp: std of SSIM over the batch}``),
        ``"video_setting"`` (shape of a single video) and
        ``"video_setting_name"`` (description of that shape's axes).

    Raises:
        AssertionError: If the two batches differ in shape.
        ValueError: If the batch contains no videos.
    """
    assert videos1.shape == videos2.shape

    num_videos = videos1.shape[0]
    if num_videos == 0:
        # Fail with a clear message instead of an UnboundLocalError further
        # down; there is nothing to average over.
        raise ValueError("calculate_ssim requires at least one video.")

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    # Every video in the batch shares this shape: [timestamps, channel, h, w].
    video_shape = videos1[0].shape
    num_timestamps = len(videos1[0])

    per_video_scores = []
    for video_num in range(num_videos):
        video1 = videos1[video_num]
        video2 = videos2[video_num]
        # One SSIM value per frame pair; each frame is [channel, h, w].
        per_video_scores.append(
            [
                calculate_ssim_function(video1[clip_timestamp].numpy(), video2[clip_timestamp].numpy())
                for clip_timestamp in range(num_timestamps)
            ]
        )

    # Shape: [batch_size, timestamps].
    ssim_results = np.array(per_video_scores)

    # Named *_mean / *_std so the module-level ``ssim`` function is not
    # shadowed inside this function.
    ssim_mean = {}
    ssim_std = {}
    for clip_timestamp in range(num_timestamps):
        ssim_mean[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
        ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])

    return {
        "value": ssim_mean,
        "value_std": ssim_std,
        "video_setting": video_shape,
        "video_setting_name": "time, channel, heigth, width",
    }

def main():
    """Smoke-test ``calculate_ssim`` on all-zero dummy videos and print the result as JSON."""
    import json

    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64

    # Pick the device before creating any tensor, and fall back to CPU when
    # this Paddle build has no CUDA support so the demo also runs on
    # CPU-only machines.
    paddle.set_device("gpu" if paddle.device.is_compiled_with_cuda() else "cpu")

    videos1 = paddle.zeros(shape=[NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE])
    videos2 = paddle.zeros(shape=[NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE])

    result = calculate_ssim(videos1, videos2)
    print(json.dumps(result, indent=4))


if __name__ == "__main__":
    main()
126 changes: 126 additions & 0 deletions ppdiffusers/scripts/ssim_psnr_score/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
import sys


from calculate_ssim import calculate_ssim_function
from calculate_psnr import img_psnr
from ppdiffusers import StableDiffusionXLPipeline, PixArtAlphaPipeline, StableVideoDiffusionPipeline
from ppdiffusers import UNet2DConditionModel, LCMScheduler,FluxPipeline
from ppdiffusers import DPMSolverMultistepScheduler
from ppdiffusers.utils import load_image, export_to_video

import paddle.vision.transforms as TF
from tqdm import tqdm
import pathlib
import re
import numpy as np
from PIL import Image
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'fid_clip_score')))
from fid_score import ImagePathDataset
def extract_number(filename):
    """Return the first run of digits in the file's base name as an int.

    Only the final path component is inspected, so digits in parent
    directories are ignored. Names without any digit yield ``float('inf')``,
    which places them last when this function is used as a sort key.
    """
    digits = re.search(r'\d+', os.path.basename(filename))
    if digits is None:
        return float('inf')
    return int(digits.group())

# File extensions (lower-case, without the leading dot) that are globbed
# when collecting the images of a dataset directory.
IMAGE_EXTENSIONS = {
    "bmp",
    "jpeg",
    "jpg",
    "pgm",
    "png",
    "ppm",
    "tif",
    "tiff",
    "webp",
}

def parse_args(argv=None):
    """Parse the command-line options of the SSIM/PSNR evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with ``dataset1``, ``dataset2``,
        ``resolution``, ``batch_size`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser(
        description="Compute the mean SSIM and PSNR between two image folders."
    )
    parser.add_argument(
        "--dataset1",
        type=str,
        required=True,
        help="Directory containing the original generated images.",
    )
    parser.add_argument(
        "--dataset2",
        type=str,
        required=True,
        help="Directory containing the sped-up generated images.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=None,
        help="The resolution to resize.",
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size to use")
    # The hyphenated spelling is accepted as an alias; argparse derives the
    # destination from the first long option, so it stays ``num_workers``.
    parser.add_argument(
        "--num_workers",
        "--num-workers",
        type=int,
        default=1,
        help="Number of workers to use for data loading",
    )

    return parser.parse_args(argv)


def _collect_image_files(directory):
    """Return the image files directly inside ``directory``, ordered by the first number in each file name."""
    return sorted(
        (file for ext in IMAGE_EXTENSIONS for file in directory.glob("*.{}".format(ext))),
        key=extract_number,
    )


def _build_dataloader(files, args):
    """Wrap ``files`` in an ``ImagePathDataset`` and return an unshuffled ``DataLoader`` over it."""
    dataset = ImagePathDataset(files, transforms=TF.ToTensor(), resolution=args.resolution)
    return paddle.io.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )


def main():
    """Compute the mean SSIM and PSNR between two image folders and save the result.

    The images of ``--dataset1`` and ``--dataset2`` are paired by their
    sorted order. The means are printed and written to
    ``<parent of dataset1>/<name of dataset2>.txt``.

    Raises:
        ValueError: If no images are found or the two folders contain a
            different number of images.
    """
    args = parse_args()

    gen_path = pathlib.Path(args.dataset1)
    speedgen_path = pathlib.Path(args.dataset2)
    gen_files = _collect_image_files(gen_path)
    speedgen_files = _collect_image_files(speedgen_path)

    # zip() below would silently drop the surplus images of the larger
    # folder, so refuse to compare folders that cannot be paired one-to-one.
    if len(gen_files) != len(speedgen_files):
        raise ValueError(
            f"dataset1 contains {len(gen_files)} images but dataset2 contains "
            f"{len(speedgen_files)}; both folders must hold the same number of images."
        )
    if not gen_files:
        raise ValueError(f"No images found in {gen_path}.")

    dataloader_gen = _build_dataloader(gen_files, args)
    dataloader_speedgen = _build_dataloader(speedgen_files, args)
    print(len(dataloader_gen))
    print(len(dataloader_speedgen))

    ssim_value_list = []
    psnr_value_list = []
    # Calculate SSIM and PSNR for every image pair.
    for batch_gen, batch_speedgen in tqdm(
        zip(dataloader_gen, dataloader_speedgen),
        total=len(dataloader_gen),
        desc="Calculating SSIM and PSNR",
    ):
        # Convert the tensors to numpy arrays.
        # NOTE(review): assumes batch["img"] is [batch, channel, h, w] — confirm
        # against ImagePathDataset.
        imgs_gen = batch_gen["img"].numpy()
        imgs_speedgen = batch_speedgen["img"].numpy()
        # Score image by image so that any batch size works; the metric
        # functions only accept single 2-D / 3-D images.
        for img_gen, img_speedgen in zip(imgs_gen, imgs_speedgen):
            ssim_value_list.append(calculate_ssim_function(img_gen, img_speedgen))
            psnr_value_list.append(img_psnr(img_gen, img_speedgen))

    mean_ssim = np.mean(ssim_value_list)
    mean_psnr = np.mean(psnr_value_list)

    # Save the result next to dataset1, named after the dataset2 folder.
    # Path.name ignores a trailing slash, whereas os.path.basename would
    # return an empty string for "path/to/dataset2/".
    result_file = gen_path.parent / f"{speedgen_path.name}.txt"
    # "w" overwrites the result file of any previous run.
    with open(result_file, "w", encoding="utf-8") as f:
        f.write(f"mean_ssim: {mean_ssim}\n")
        f.write(f"mean_psnr: {mean_psnr}\n")
    print('mean_ssim: ', mean_ssim, 'mean_psnr: ', mean_psnr)


if __name__ == '__main__':
    main()