×

签到

分享到微信

打开微信,使用扫一扫进入页面后,点击右上角菜单,

点击“发送给朋友”或“分享到朋友圈”完成分享

detr转离线模型报错 Ashpool2024-04-22 20:11:58 回复 查看 技术答疑 使用求助 社区交流
detr转离线模型报错

【星空体育硬件产品型号】必填*:
例如:MLU220
必填项,例如:MLU220


【使用操作系统】必填*:
例如:ubuntu 16.04




【出错信息】必填*:(2]BVQLKORBB3GHIBUQCV8I.png


【当前已做了哪些信息确认】选填:
基于cpu运行下模型能成功trace,但当tensor和model放到mlu上报错

【出错代码链接】选填:

    if args.cfg == 'qua':

        qconfig={'use_avg':False, 'data_scale':1.0, 'firstconv':False, 'per_channel': False}

        quantized_net = mlu_quantize.quantize_dynamic_mlu(model,qconfig_spec=qconfig, dtype='int8', gen_quant=True)

    elif args.cfg == 'mlu':

        if args.fake_device:

            ct.set_device(-1)

            ct.set_core_version('MLU220')

        model.to(torch.device('cpu'))

        quantized_net = torch_mlu.core.mlu_quantize.quantize_dynamic_mlu(model)

        state_dict = torch.load("./detr_int8.pt")

        quantized_net.load_state_dict(state_dict, strict=False)

        quantized_net.eval()

        quantized_net.to(ct.mlu_device())

        # quantized_net.to(torch.device("cpu"))

        if args.jit:

            print("### jit")

            ct.save_as_cambricon('detr_int8_1_4')

            torch.set_grad_enabled(False)

            ct.set_core_number(4)

            trace_input = torch.randn(1, 3, 640, 640, dtype=torch.float)

            input_mlu_data = trace_input.to(ct.mlu_device())

            # input_mlu_data = trace_input.to(torch.device("cpu"))

            quantized_net = torch.jit.trace(quantized_net, input_mlu_data, check_trace = False)

            outputs_class,outputs_coord = quantized_net(trace_input)

            # quantized_net = torch.jit. (quantized_net)

            ct.save_as_cambricon('')


版权所有 © 2024 星空体育 备案/许可证号:京ICP备17003415号-1
关闭