# 检查是否已经有标签文件 labels_path = "e:\\code\\.idea\\pytorch\\imagenet_classes.txt" if not os.path.exists(labels_path): # 从GitHub下载标签文件 url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" response = requests.get(url) if response.status_code == 200: with open(labels_path, 'wb') as f: f.write(response.content) print(f"标签文件已下载到 {labels_path}") # 读取下载的标签文件 with open(labels_path) as f: labels = [line.strip() for line in f.readlines()] else: print("无法下载标签文件,使用内置的部分标签") else: # 读取已有的标签文件 with open(labels_path) as f: labels = [line.strip() for line in f.readlines()]
最后就直接输出结果
1 2 3 4 5 6 7 8
# 获取预测结果 _, indices = torch.sort(out, descending=True) percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100 print("预测的前5个类别:") for idx in indices[0][:5]: print(f"{labels[idx]}: {percentage[idx].item():.2f}%") _, indices = torch.sort(out, descending=True) [(labels[idx],percentage[idx].item()) for idx in indices[0][:5]]