Skip to content

TensorTypeExtensions: Added conversion between Tensor to primitive C# types instead of throwing NotSupportedException #4290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Replaced Unsafe.AsPointer with fixed(...)
  • Loading branch information
Nucs committed Oct 3, 2019
commit f067269b29c94206d61afb321abe93ad73d2c74f
309 changes: 156 additions & 153 deletions src/Microsoft.ML.Dnn/TensorTypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,99 +289,101 @@ public static void CopyTo<T>(this Tensor tensor, Span<T> destination) where T :
throw new ArgumentException("Destinion was too short to perform CopyTo.");

//Perform cast to type <T>.
var dst = (T*) Unsafe.AsPointer(ref destination.GetPinnableReference());
switch (tensor.dtype)
fixed (T* dst_ = destination)
{
#if _REGEN
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
case TF_DataType.#1:
var dst = dst_;
switch (tensor.dtype)
{
var converter = Converts.FindConverter<#3, T>();
var src = (#3*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
%
#if _REGEN
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
case TF_DataType.#1:
{
var converter = Converts.FindConverter<#3, T>();
var src = (#3*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
%
#else

case TF_DataType.TF_BOOL:
{
var converter = Converts.FindConverter<bool, T>();
var src = (bool*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT8:
{
var converter = Converts.FindConverter<byte, T>();
var src = (byte*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_INT16:
{
var converter = Converts.FindConverter<short, T>();
var src = (short*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT16:
{
var converter = Converts.FindConverter<ushort, T>();
var src = (ushort*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_INT32:
{
var converter = Converts.FindConverter<int, T>();
var src = (int*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT32:
{
var converter = Converts.FindConverter<uint, T>();
var src = (uint*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_INT64:
{
var converter = Converts.FindConverter<long, T>();
var src = (long*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT64:
{
var converter = Converts.FindConverter<ulong, T>();
var src = (ulong*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_DOUBLE:
{
var converter = Converts.FindConverter<double, T>();
var src = (double*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_FLOAT:
{
var converter = Converts.FindConverter<float, T>();
var src = (float*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_BOOL:
{
var converter = Converts.FindConverter<bool, T>();
var src = (bool*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT8:
{
var converter = Converts.FindConverter<byte, T>();
var src = (byte*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_INT16:
{
var converter = Converts.FindConverter<short, T>();
var src = (short*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT16:
{
var converter = Converts.FindConverter<ushort, T>();
var src = (ushort*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_INT32:
{
var converter = Converts.FindConverter<int, T>();
var src = (int*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT32:
{
var converter = Converts.FindConverter<uint, T>();
var src = (uint*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_INT64:
{
var converter = Converts.FindConverter<long, T>();
var src = (long*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_UINT64:
{
var converter = Converts.FindConverter<ulong, T>();
var src = (ulong*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_DOUBLE:
{
var converter = Converts.FindConverter<double, T>();
var src = (double*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
case TF_DataType.TF_FLOAT:
{
var converter = Converts.FindConverter<float, T>();
var src = (float*) tensor.buffer;
Parallel.For(0, len, i => *(dst + i) = converter(unchecked(*(src + i))));
return;
}
#endif
case TF_DataType.TF_STRING:
{
var src = tensor.StringData();
var culture = CultureInfo.InvariantCulture;

switch (typeof(T).as_dtype())
case TF_DataType.TF_STRING:
{
var src = tensor.StringData();
var culture = CultureInfo.InvariantCulture;

switch (typeof(T).as_dtype())
{
#if _REGEN
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
case TF_DataType.#1: {
Expand All @@ -392,75 +394,76 @@ public static void CopyTo<T>(this Tensor tensor, Span<T> destination) where T :
%
#else

case TF_DataType.TF_BOOL:
{
var sdst = (bool*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToBoolean(culture));
return;
}
case TF_DataType.TF_UINT8:
{
var sdst = (byte*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToByte(culture));
return;
}
case TF_DataType.TF_INT16:
{
var sdst = (short*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToInt16(culture));
return;
}
case TF_DataType.TF_UINT16:
{
var sdst = (ushort*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToUInt16(culture));
return;
}
case TF_DataType.TF_INT32:
{
var sdst = (int*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToInt32(culture));
return;
}
case TF_DataType.TF_UINT32:
{
var sdst = (uint*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToUInt32(culture));
return;
}
case TF_DataType.TF_INT64:
{
var sdst = (long*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToInt64(culture));
return;
}
case TF_DataType.TF_UINT64:
{
var sdst = (ulong*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToUInt64(culture));
return;
}
case TF_DataType.TF_DOUBLE:
{
var sdst = (double*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToDouble(culture));
return;
}
case TF_DataType.TF_FLOAT:
{
var sdst = (float*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToSingle(culture));
return;
}
case TF_DataType.TF_BOOL:
{
var sdst = (bool*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToBoolean(culture));
return;
}
case TF_DataType.TF_UINT8:
{
var sdst = (byte*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToByte(culture));
return;
}
case TF_DataType.TF_INT16:
{
var sdst = (short*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToInt16(culture));
return;
}
case TF_DataType.TF_UINT16:
{
var sdst = (ushort*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToUInt16(culture));
return;
}
case TF_DataType.TF_INT32:
{
var sdst = (int*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToInt32(culture));
return;
}
case TF_DataType.TF_UINT32:
{
var sdst = (uint*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToUInt32(culture));
return;
}
case TF_DataType.TF_INT64:
{
var sdst = (long*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToInt64(culture));
return;
}
case TF_DataType.TF_UINT64:
{
var sdst = (ulong*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToUInt64(culture));
return;
}
case TF_DataType.TF_DOUBLE:
{
var sdst = (double*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToDouble(culture));
return;
}
case TF_DataType.TF_FLOAT:
{
var sdst = (float*) Unsafe.AsPointer(ref destination.GetPinnableReference());
Parallel.For(0, len, i => *(sdst + i) = ((IConvertible) src[i]).ToSingle(culture));
return;
}
#endif
default:
throw new NotSupportedException();
default:
throw new NotSupportedException();
}
}
case TF_DataType.TF_COMPLEX64:
case TF_DataType.TF_COMPLEX128:
default:
throw new NotSupportedException();
}
case TF_DataType.TF_COMPLEX64:
case TF_DataType.TF_COMPLEX128:
default:
throw new NotSupportedException();
}
}
}
Expand Down