torch.utils.tensorboard
是 PyTorch 中用于集成 TensorBoard 的一个模块,方便用户在训练模型时记录和可视化各种指标。TensorBoard 是一个强大的可视化工具,常用于监控和调试机器学习模型的训练过程。
主要功能
- 标量记录:可以记录训练过程中的标量数据,如损失值和准确率。这些数据可以在 TensorBoard 中以图表的形式呈现,帮助分析模型的训练效果。
- 图像记录:支持记录训练过程中生成的图像,可以用于可视化模型生成的输出或输入数据。
- 模型图记录:能够记录和可视化模型的计算图,这对于理解模型的结构和数据流非常有帮助。
- 分布记录:可以记录数据的分布情况,帮助观察模型在训练过程中参数或损失的变化。
- 直方图记录:可以记录权重、梯度等的直方图,方便观察这些参数在训练过程中的变化。
示例代码
import torch
from torch.utils.tensorboard import SummaryWriter
# 创建一个 SummaryWriter 实例
writer = SummaryWriter('logs')
# 记录训练损失
for epoch in range(10):
loss = 0.1 * (10 - epoch) # 示例损失
writer.add_scalar('Loss/train', loss, epoch)
# 记录模型图
dummy_input = torch.randn(1, 3, 224, 224) # 示例输入
model = ... # 假设 model 是你的模型
writer.add_graph(model, dummy_input)
# 关闭 SummaryWriter
writer.close()
命令行运行
tensorboard --logdir=logs