Skip to content

Deform conv2d mps support #9026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
93e044b
Start of branch
goldfishsound Oct 7, 2024
d838bf7
Setting up for development
goldfishsound Oct 10, 2024
95eb1cd
Initial commit for deform_conv2d for MPS
goldfishsound Oct 25, 2024
c53e1bd
New mps kernel for deform_conv2d and updated shader functions in kern…
goldfishsound Nov 12, 2024
1153b84
Renaming source file.
goldfishsound Nov 15, 2024
1c87a26
Changed part of the file name from _kernal to _kernel
goldfishsound Nov 15, 2024
8a984de
Remove files in product dir
goldfishsound Nov 16, 2024
970183d
Removing framework dir and included files.
goldfishsound Nov 16, 2024
2895f4f
Removing build_xcode dir and included files.
goldfishsound Nov 16, 2024
2f06f7f
Changed location references to pytorch
goldfishsound Nov 16, 2024
66d76d3
Clean up git - Removing .DS_Store
goldfishsound Nov 16, 2024
c8eb2ea
Altering the kernel deformable_im2col to mimic the cpp kernel impleme…
goldfishsound Nov 16, 2024
c92eaa4
Re-ordering include sequence
goldfishsound Nov 16, 2024
b445aed
Including mps in TestDeformConv::test_is_leaf_node
goldfishsound Nov 16, 2024
951880c
Updates gitignore
goldfishsound Dec 1, 2024
1aa7c0b
Merge branch 'main' into deform_conv2d_mps_support
goldfishsound Dec 1, 2024
83080da
Merge branch 'pytorch:main' into deform_conv2d_mps_support
goldfishsound Dec 2, 2024
9f68fd4
Update .gitignore
goldfishsound Dec 2, 2024
dc305ae
CleanUp
goldfishsound Dec 2, 2024
e25e620
Cleaned up - removed added exclusions.
goldfishsound Dec 4, 2024
e4fb8c5
Updated
goldfishsound Dec 4, 2024
e39867f
Removed CMakePresets.json
goldfishsound Dec 4, 2024
3e2bc0e
Updated to exclude CMakePresets.json
goldfishsound Dec 4, 2024
bd62ab3
Added bilinear_interpolate_2 function which is identical to the one u…
goldfishsound Mar 4, 2025
350454f
Reorganized the numbering of argumnet indexes in img2col
goldfishsound Mar 6, 2025
b31a28c
Added threadgroups_per_grid to deformable_col2im and deformable_col2i…
goldfishsound Mar 11, 2025
358dacc
Added printTensor utility function - only temporarily
goldfishsound Mar 11, 2025
7da876a
Modifying TestDeformConv to include mps tests.
goldfishsound Mar 11, 2025
da6134d
Merge branch 'pytorch:main' into deform_conv2d_mps_support
goldfishsound Apr 19, 2025
25a2944
House Cleaning:
goldfishsound Mar 11, 2025
bf7784d
Renaming of bilinear_interpolate2 to bilinear_interpolate_deform_conv2d
goldfishsound Apr 20, 2025
9d3105f
Added constant mps_backward_eps for eps in backward test.
goldfishsound Apr 20, 2025
501d617
Removed unused includes
goldfishsound Apr 20, 2025
a294c2e
Delete
goldfishsound Apr 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Start of branch
  • Loading branch information
goldfishsound committed Oct 7, 2024
commit 93e044b0cfa0c831877c161f1409ae87b8398814
Binary file added .DS_Store
Binary file not shown.
5 changes: 3 additions & 2 deletions android/gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#Tue Aug 27 15:56:14 CEST 2024
distributionBase=GRADLE_USER_HOME
distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
Binary file added framework/.DS_Store
Binary file not shown.
Binary file added framework/include/.DS_Store
Binary file not shown.
Binary file added framework/include/torchvision/.DS_Store
Binary file not shown.
26 changes: 26 additions & 0 deletions framework/include/torchvision/io/image/cpu/common_jpeg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "common_jpeg.h"

namespace vision {
namespace image {
namespace detail {

#if JPEG_FOUND
void torch_jpeg_error_exit(j_common_ptr cinfo) {
/* cinfo->err really points to a torch_jpeg_error_mgr struct, so coerce
* pointer */
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;

/* Always display the message. */
/* We could postpone this until after returning, if we chose. */
// (*cinfo->err->output_message)(cinfo);
/* Create the message */
(*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg);

/* Return control to the setjmp point */
longjmp(myerr->setjmp_buffer, 1);
}
#endif

} // namespace detail
} // namespace image
} // namespace vision
27 changes: 27 additions & 0 deletions framework/include/torchvision/io/image/cpu/common_jpeg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#if JPEG_FOUND
#include <stdio.h>

#include <jpeglib.h>
#include <setjmp.h>

namespace vision {
namespace image {
namespace detail {

static const JOCTET EOI_BUFFER[1] = {JPEG_EOI};
struct torch_jpeg_error_mgr {
struct jpeg_error_mgr pub; /* "public" fields */
char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */
jmp_buf setjmp_buffer; /* for return to caller */
};

using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*;
void torch_jpeg_error_exit(j_common_ptr cinfo);

} // namespace detail
} // namespace image
} // namespace vision

#endif
6 changes: 6 additions & 0 deletions framework/include/torchvision/io/image/cpu/common_png.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

#if PNG_FOUND
#include <png.h>
#include <setjmp.h>
#endif
41 changes: 41 additions & 0 deletions framework/include/torchvision/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "decode_image.h"

#include "decode_jpeg.h"
#include "decode_png.h"

namespace vision {
namespace image {

torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode,
bool apply_exif_orientation) {
// Check that tensor is a CPU tensor
TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");

auto datap = data.data_ptr<uint8_t>();

const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(
data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
} else {
TORCH_CHECK(
false,
"Unsupported image file. Only jpeg and png ",
"are currently supported.");
}
}

} // namespace image
} // namespace vision
15 changes: 15 additions & 0 deletions framework/include/torchvision/io/image/cpu/decode_image.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool apply_exif_orientation = false);

} // namespace image
} // namespace vision
Loading