diff --git a/disco_xform_utils.py b/disco_xform_utils.py index 107b602..1e75a7c 100644 --- a/disco_xform_utils.py +++ b/disco_xform_utils.py @@ -28,7 +28,7 @@ def transform_image_3d(img_filepath, midas_model, midas_transform, device, rot_m predictions using nyu dataset """ print("Running AdaBins depth estimation implementation...") - infer_helper = InferenceHelper(dataset='nyu') + infer_helper = InferenceHelper(dataset='nyu', device=device) image_pil_area = w*h if image_pil_area > MAX_ADABINS_AREA: @@ -128,4 +128,4 @@ def get_spherical_projection(H, W, center, magnitude,device): d = center - grid d_sum = torch.sqrt((d**2).sum(axis=-1)) grid += d * d_sum.unsqueeze(-1) * magnitude - return grid.unsqueeze(0) \ No newline at end of file + return grid.unsqueeze(0)