打开微信,使用扫一扫进入页面后,点击右上角菜单,
点击“发送给朋友”或“分享到朋友圈”完成分享
【星空体育硬件产品型号】必填*:
例如:MLU220
必填项,例如:MLU220
【使用操作系统】必填*:
例如:ubuntu 16.04
【出错信息】必填*:
【当前已做了哪些信息确认】选填:
基于cpu运行下模型能成功trace,但当tensor和model放到mlu上报错
【出错代码链接】选填:
if args.cfg == 'qua':
qconfig={'use_avg':False, 'data_scale':1.0, 'firstconv':False, 'per_channel': False}
quantized_net = mlu_quantize.quantize_dynamic_mlu(model,qconfig_spec=qconfig, dtype='int8', gen_quant=True)
elif args.cfg == 'mlu':
if args.fake_device:
ct.set_device(-1)
ct.set_core_version('MLU220')
model.to(torch.device('cpu'))
quantized_net = torch_mlu.core.mlu_quantize.quantize_dynamic_mlu(model)
state_dict = torch.load("./detr_int8.pt")
quantized_net.load_state_dict(state_dict, strict=False)
quantized_net.eval()
quantized_net.to(ct.mlu_device())
# quantized_net.to(torch.device("cpu"))
if args.jit:
print("### jit")
ct.save_as_cambricon('detr_int8_1_4')
torch.set_grad_enabled(False)
ct.set_core_number(4)
trace_input = torch.randn(1, 3, 640, 640, dtype=torch.float)
input_mlu_data = trace_input.to(ct.mlu_device())
# input_mlu_data = trace_input.to(torch.device("cpu"))
quantized_net = torch.jit.trace(quantized_net, input_mlu_data, check_trace = False)
outputs_class,outputs_coord = quantized_net(trace_input)
# quantized_net = torch.jit. (quantized_net)
ct.save_as_cambricon('')
热门帖子
精华帖子