使用 TensorFlow 2.x API 介绍图像中的显着性图
磐创AITensorFlow 2.x 简介
在计算机视觉领域中,显着性图是图像在人类视线最初聚焦的区域。显着性图的主要目标是突出特定像素对人类视觉感知的重要性。
例如,在下图中,花和手是人们首先注意到的东西,因此必须在显着性图上强调它们。还有一点需要注意的是,人工神经网络创建的显着性图并不总是与生物或自然视觉产生的显着性图相同。
什么是显着图?
显着性图是深度学习和计算机视觉的一个关键主题。在深度卷积神经网络(CNN)的训练过程中,了解每一层的特征图变得至关重要。CNN 的特征图告诉我们模型的学习特征。而显着性图主要关注图像的特定像素,而忽略其他像素。
显着性图
图像的显着性图表示图像中最突出和最集中的像素。有时,图像中较亮的像素会告诉我们像素的显着性。这意味着像素的亮度与图像的显着性成正比。
假设我们想要关注图像的特定部分,比如想要关注鸟的图像而不是天空、鸟巢等其他部分。然后通过计算显着图,我们将实现这一点。它将有助于降低计算成本,通常是灰度图像,但可以根据我们的视觉舒适度转换为另一种格式的彩色图像。
显着性图也称为“热图”,因为图像的热度/亮度对识别对象的类别有影响。显着性图旨在确定中央凹(高分辨率的颜色)中每个地方显着或可观察的区域,并根据显着性的空间模式影响注意力区域的决策。它用于各种视觉注意模型。
如何使用 TensorFlow 计算显着性图?
显着图可以通过对输入图像 X 取类别概率 Pk的导数来计算。
等一下!这似乎很熟悉!是的,这与我们用于训练模型的反向传播相同。我们只需要再迈出一步:梯度不会在我们网络的第一层停止。相反,我们必须将其返回给输入图像 X。
因此,显着性图根据特定的类别预测 Pi 为每个输入像素提供合适的表征。对花卉预测具有重要意义的像素应聚集在花卉像素周围。否则,经过训练的模型会发生一些非常奇怪的事情。
显着图的优势在于,由于它们完全依赖于梯度计算,许多常用的深度学习模型可以免费为我们提供显着图。我们根本不需要修改网络架构;我们只需要稍微调整梯度计算。
不同类型的显着图
静态显着性:针对图像的每个静态像素点计算出重要的感兴趣区域,进行显着性图分析。
动态显着性:关注视频数据的动态特征。视频中的显着性图是通过计算视频的光流来计算的。移动实体/对象被认为是显着对象。
代码
我们将逐步研究 ResNet50 架构,该架构已在 ImageNet 上进行了预训练。但是你可以采用其他预训练的深度学习模型或你自己的训练模型。
我们将说明如何利用 TensorFlow 2.x 中最著名的 DL 模型开发基本的显着性图。在教程中,我们使用了 Wikimedia 图像作为测试图像。
我们首先创建一个具有 ImageNet 权重的 ResNet50。使用简单的辅助函数,我们将图像导入并准备将其馈送到 ResNet50。
# Import necessary packages
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def input_img(path):
image = tf.image.decode_png(tf.io.read_file(path))
image = tf.expand_dims(image, axis=0)
image = tf.cast(image, tf.float32)
image = tf.image.resize(image, [224,224])
return image
def normalize_image(img):
grads_norm = img[:,:,0]+ img[:,:,1]+ img[:,:,2]
grads_norm = (grads_norm - tf.reduce_min(grads_norm))/ (tf.reduce_max(grads_norm)- tf.reduce_min(grads_norm))
return grads_norm
def get_image():
import urllib.request
filename = 'image.jpg'
img_url = r"https://upload.wikimedia.org/wikipedia/commons/d/d7/White_stork_%28Ciconia_ciconia%29_on_nest.jpg"
urllib.request.urlretrieve(img_url, filename)
def plot_maps(img1, img2,vmin=0.3,vmax=0.7, mix_val=2):
f = plt.figure(figsize=(15,45))
plt.subplot(1,3,1)
plt.imshow(img1,vmin=vmin, vmax=vmax, cmap="ocean")
plt.axis("off")
plt.subplot(1,3,2)
plt.imshow(img2, cmap = "ocean")
plt.axis("off")
plt.subplot(1,3,3)
plt.imshow(img1*mix_val+img2/mix_val, cmap = "ocean" )
plt.axis("off")
图1:输入图像
为了获得预测向量,ResNet50 将直接从 Keras 应用程序中加载。
test_model = tf.keras.applications.resnet50.ResNet50()
#test_model.summary()
get_image()
img_path = "image.jpg"
input_img = input_img(img_path)
input_img = tf.keras.applications.densenet.preprocess_input(input_img)
plt.imshow(normalize_image(input_img[0]), cmap = "ocean")
result = test_model(input_img)
max_idx = tf.argmax(result,axis = 1)
tf.keras.applications.imagenet_utils.decode_predictions(result.numpy())
TensorFlow 2.x 上提供了 GradientTape 函数,该函数能够处理反向传播相关操作。在这里,我们将利用 GradientTape 的优势来计算给定图像的显着性图。
with tf.GradientTape() as tape:
tape.watch(input_img)
result = test_model(input_img)
max_score = result[0,max_idx[0]]
grads = tape.gradient(max_score, input_img)
plot_maps(normalize_image(grads[0]), normalize_image(input_img[0]))
图2:(1)Saliency_map,(2)input_image,(3)overlayed_image
关于Tensorflow 2.x 的结论
在这篇博客中,我们从不同方面定义了显着性图。我们添加了一个图形表示来深入理解“显着性地图”这个术语。此外,我们通过使用 TensorFlow API 在 python 中实现它来理解它。结果似乎很容易理解。
在本文中,我们学习了:
使用 tensorflow 的图像的显着性图
实现了一个 python 代码来计算图像的显着性图
显着性图的数学背景
计算了显着性图
github_repo:https://github.com/Rabusi/Saliency_Map_in_DL