torch max()函数
torch.max()返回的是两个值, 第一个是最大值, 第二个是最大值所在的索引, 一般情况,我们都是求最大值所在的索引
import torch
a = torch.tensor([[1, 5, 2, 1], [2, 6, 3, 8]])
print(a)
res, index = torch.max(a, 1)
print(res)
print(index)
只用最大值索引求准确率:
# 准确率的计算
# 100个样本, 10 个类别
predict = torch.rand(100, 10)
label = torch.randint(10, (100,), dtype=torch.int64)
pred_y = torch.max(predict, 1)[1].numpy()
y_label = label.numpy()
accuracy = (pred_y == y_label).sum() / len(y_label)
print("准确率:", accuracy)
结果为
准确率: 0.21
这里是取的随机数, 结果不重要