import torch
import torchvision.models as models
# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=False)
model.eval()
# 示例输入
example_input = torch.randn(1, 3, 224, 224)
# 将模型转换为 TorchScript
script_model = torch.jit.trace(model, example_input)
# 保存 TorchScript 模型
script_model.save("resnet_script_model.pt")
import torch
from torchvision.models import resnet
# 构建相应的模型架构
model = resnet.resnet50() # 根据你的模型类型进行修改
# 加载 TorchScript 模型的参数权重
model.load_state_dict(torch.jit.load("resnet_script_model.pt").state_dict())
# 保存为.pth格式
torch.save(model.state_dict(), "resnet_model.pth")
# 加载预训练的 ResNet 模型
# model = models.resnet50(pretrained=False) # 这里使用了一个预训练的 ResNet-50 模型,你可以根据自己的模型类型进行修改
# 加载模型权重
model.load_state_dict(torch.load("resnet_model.pth"))
# 设置模型为评估模式
model.eval()
# 示例输入
example_input = torch.randn(1, 3, 224, 224)
# 导出为 ONNX 格式
torch.onnx.export(model, example_input, "resnet_model.onnx", export_params=True, opset_version=12)