Skip to content

Commit 0a893ab

Browse files
adamlererapaszke
authored andcommitted
fix serialization bug for large files
1 parent 34fa5e0 commit 0a893ab

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

torch/csrc/generic/serialization.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,19 @@ void THPStorage_(writeFileRaw)(THStorage *self, int fd)
3939
SYSCHECK(write(fd, &self->size, sizeof(long)));
4040
// fast track for bytes and little endian
4141
if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
42-
SYSCHECK(write(fd, data, sizeof(real) * self->size));
42+
char *bytes = (char *) data;
43+
uint64_t remaining = sizeof(real) * self->size;
44+
while (remaining > 0) {
45+
ssize_t result = write(fd, bytes, remaining);
46+
if (result < 0)
47+
throw std::system_error(result, std::system_category());
48+
bytes += result;
49+
remaining -= result;
50+
}
4351
} else {
4452
long buffer_size = std::min(self->size, (long)5000);
4553
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
46-
for (long i = 0; i < self->size; i += buffer_size) {
54+
for (int64_t i = 0; i < self->size; i += buffer_size) {
4755
size_t to_convert = std::min(self->size - i, buffer_size);
4856
if (sizeof(real) == 2) {
4957
THP_encodeInt16Buffer((uint8_t*)le_buffer.get(),
@@ -61,7 +69,7 @@ void THPStorage_(writeFileRaw)(THStorage *self, int fd)
6169
THPByteOrder::THP_LITTLE_ENDIAN,
6270
to_convert);
6371
}
64-
SYSCHECK(write(fd, data, to_convert * sizeof(real)));
72+
SYSCHECK(write(fd, le_buffer.get(), to_convert * sizeof(real)));
6573
}
6674
}
6775
}
@@ -82,11 +90,19 @@ THStorage * THPStorage_(readFileRaw)(int fd)
8290

8391
// fast track for bytes and little endian
8492
if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
85-
SYSCHECK(read(fd, data, sizeof(real) * storage->size));
93+
char *bytes = (char *) data;
94+
uint64_t remaining = sizeof(real) * storage->size;
95+
while (remaining > 0) {
96+
ssize_t result = read(fd, bytes, remaining);
97+
if (result < 0)
98+
throw std::system_error(result, std::system_category());
99+
bytes += result;
100+
remaining -= result;
101+
}
86102
} else {
87103
long buffer_size = std::min(size, (long)5000);
88104
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
89-
for (long i = 0; i < size; i += buffer_size) {
105+
for (int64_t i = 0; i < size; i += buffer_size) {
90106
size_t to_convert = std::min(size - i, buffer_size);
91107
SYSCHECK(read(fd, le_buffer.get(), sizeof(real) * to_convert));
92108
if (sizeof(real) == 2) {

0 commit comments

Comments
 (0)