ph-pth-onnx

import torch
import torchvision.models as models

# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=False)
model.eval()

# 示例输入
example_input = torch.randn(1, 3, 224, 224)

# 将模型转换为 TorchScript
script_model = torch.jit.trace(model, example_input)

# 保存 TorchScript 模型
script_model.save("resnet_script_model.pt")
import torch
from torchvision.models import resnet

# 构建相应的模型架构
model = resnet.resnet50()  # 根据你的模型类型进行修改

# 加载 TorchScript 模型的参数权重
model.load_state_dict(torch.jit.load("resnet_script_model.pt").state_dict())

# 保存为.pth格式
torch.save(model.state_dict(), "resnet_model.pth")


# 加载预训练的 ResNet 模型
# model = models.resnet50(pretrained=False)  # 这里使用了一个预训练的 ResNet-50 模型,你可以根据自己的模型类型进行修改

# 加载模型权重
model.load_state_dict(torch.load("resnet_model.pth"))

# 设置模型为评估模式
model.eval()

# 示例输入
example_input = torch.randn(1, 3, 224, 224)

# 导出为 ONNX 格式
torch.onnx.export(model, example_input, "resnet_model.onnx", export_params=True, opset_version=12)