0%

数据可视化tensorboard

数据可视化,指的是训练神经网络过程中对loss、preview等做出的记录和展示。

为什么需要可视化

我曾经有很长一段时间是闭着眼睛训练的。原因是初出茅庐,拿到开源的代码之后虔诚学习,不敢改动,生怕代码被我改坏了。于是便按照github上的指示输入类似“python3 train.py”的进行训练。论文已经开源的代码通常能够训练出收敛的效果,因为中间探索正确训练参数的过程已经由论文作者一个人走过了,其中的艰辛甚至不必写入论文。最终只需要把探索过程所用的代码删去,保留美丽的最终成果,挂到github上,并告诉追随者“汝只需输入这样一行训练指令,训练xxx个epoch,便可以证实我论文中说的结果是对的”。

假如只是复现,后来者确实可以这样闭着眼睛训练出成果。但问题是假如我已经熟读那份开源代码,觉得它有些地方可以改进,但改进之后势必产生不同的训练结果,我又需要观察训练过程中loss是否收敛,训练中实时的效果展示,该怎么做呢?或者有的人就算不改动代码,也想亲眼看看loss收敛下降那美丽的过程,该怎么做呢?

用tensorboard!

tensorboard本是tensorflow的一个组件。现在很多训练用的是pytorch,pytorch也可以用tensorboard来可视化。使用方法可以参照这篇:PyTorch下的Tensorboard 使用

loss

具体操作的时候,比方说我想确定一个自创的网络模型,应当使用多大的学习率lr比较好,那么我就设立多次对比试验,每次存到一个单独的文件夹下:

\tf-logs\lr=0.001
\tf-logs\lr=0.01
……

但是在写入loss的时候,都写入同一个loss里:

1
2
#loss输出到tensorboard可视化
writer.add_scalar('shlight_loss/total_loss', loss, epoch)

最终效果就像这样:
tensorboard

将两次(lr不同的)训练结果绘制到同一张图里。可以清晰地看出两种lr设定下loss下降的快慢高低。
其中,total_loss子图是shlight_loss图的一部分。只改变total_loss,但是“shlight_loss/”保持不变,都会绘制到shlight_loss类别下。

preview

多次训练的preview可以统一写成:

1
2
writer.add_image('preview', np.array(im), epoch, walltime=None, dataformats='HWC') ###当前训练的preview图示仅仅保存最新的几十张

这样的效果是:

preview

注意点

tensorboard的版本

tensorboard的版本需要固定。旧版本tensorboard写入的log用新版本tensorboard很可能就看不了了,反之亦然。

由于训练很长,或者训练新的项目会搭建新的环境,tensorboard的版本就很容易被改变。这时候要尽量保证tensorboard的版本在写log和读取log时一致。

与tensorflow的关系

tensorboard本身是tensorflow的一部分。假如环境中原本有新版本tensorboard,再安装一个旧版本的tensorflow就会把tensorboard版本连带退回到旧版本。

这时候kill掉之前后台运行的新版本tensorboard,并重新运行tensorboard,后台中跑的就是旧版本tensorboard。再试图用旧版本tensorboard读取之前新版本tensorboard写入的那些log,你将面对一片空白,啥也看不到。

所以,当环境中已经有了tensorboard,且已经用它写过很多log,这时候安装tensorflow就要小心了。最好新建一个全新的环境安装tensorflow,写入和读取log都在新环境里做,和旧的环境隔离。

tensorboard不是tensorboard

仅仅运行一句pip install tensorboard==2.0.2,安装的就仅仅是这一个包。

但是单独这个包可能没法完成log写入和读取。

因为执行pip install tensorboard,你会发现装了三个包:

1
2
3
tensorboard              2.11.2
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1

所以,假如你的tensorboard没法读取之前写的log了,你要同时确认以上三个包在曾经和现在的版本是否一致。我就是往之前的环境里装tensorflow导致tensorboard读不了log了,手动把以上三个包都卸掉之后,回顾了之前的教程,但是只安装了tensorboard==2.0.2,发现还是打不开log。最终看了之前用过的另一台服务器上的pip list,发现tensorboard是由以上三个包共同组成的。重新pip install tensorboard不加版本号,自动安装三个包,才把问题解决。