TensorFlow知识图谱实战
上QQ阅读APP看书,第一时间看更新

2.3.1 ResNet50模型和参数的载入

首先是模型的载入,笔者选择ResNet模型作为载入的目标,即将图2.5中倒数第4个模型作为载入,导入代码如下:

     resnet = tf.keras.applications.ResNet50() #(载入可能卡住,下文有解决办法)

如果是第一次载入这个模型,那么在终端上会显示如图2.16所示的信息。

图2.16 第一次载入

这是因为第一次载入时Keras在载入模型的同时会将模型默认参数下载并载入,可能会由于网络原因卡住,因此模型终端有可能在此停止运行。解决的办法非常简单,使用PyCharm运行程序时会在交互端出现蓝色字符链接,点击下载即可,之后显式地告诉Keras参数的位置,代码如下:

这里weight函数显式地告诉模型所需要载入的参数位置。

注意

由于是显式地引入参数地址,因此需要写成绝对地址。

下面看一下ResNet50模型在Keras中的源码定义,代码如图2.17所示。

图2.17 ResNet50模型的源码定义

classes参数是ResNet基于imagenet数据集预训练的分类数,一般而言,使用预训练模型是用作特征提取,而不是完整的使用模型作为同样的“分类器”,因此直接屏蔽掉最上面一层的分类层即可,代码可以改成如下形式:

使用summary函数可以将ResNet模型的结构打印出来,如图2.18所示。

图2.18 ResNet模型的结构

可以看到这里的模型最后几层的名称和参数,这是已经载入模型参数后的模型结构。

可能有读者对include_top=False这个参数设置有疑问,实际上笔者在这里做的是基于已训练模型为基础的“迁移学习”任务。迁移学习是将已训练模型去掉最高层的顶端输出层作为新任务的特征提取器,即这里利用“imagenet”预训练的特征提取方法迁移到目标数据集上,并根据目标任务追加新的层作为特定的“接口层”,从而在目标任务上快速、高效地学习新的任务。

【程序2-12】