@@ -189,8 +189,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option
189
189
return ( jpegData , resizedImage ) ;
190
190
}
191
191
192
- private static Tensor Encode ( VBuffer < byte > buffer , int length )
192
+ private static Tensor Encode ( VBuffer < byte > buffer )
193
193
{
194
+ int length = buffer . Length ;
194
195
var size = c_api . TF_StringEncodedSize ( ( UIntPtr ) length ) ;
195
196
var handle = c_api . TF_AllocateTensor ( TF_DataType . TF_STRING , IntPtr . Zero , 0 , ( UIntPtr ) ( ( ulong ) size + 8 ) ) ;
196
197
//AllocationType = AllocationType.Tensorflow;
@@ -221,7 +222,7 @@ public ImageProcessor(ImageClassificationTransformer transformer)
221
222
222
223
public Tensor ProcessImage ( VBuffer < byte > imgBuf )
223
224
{
224
- var imageTensor = Encode ( imgBuf , imgBuf . Length ) ;
225
+ var imageTensor = Encode ( imgBuf ) ;
225
226
var processedTensor = _imagePreprocessingRunner . AddInput ( imageTensor , 0 ) . Run ( ) [ 0 ] ;
226
227
imageTensor . Dispose ( ) ;
227
228
return processedTensor ;
@@ -1170,7 +1171,7 @@ internal sealed class Options : TransformInputBase
1170
1171
private readonly IHost _host ;
1171
1172
private readonly Options _options ;
1172
1173
private readonly DnnModel _dnnModel ;
1173
- private readonly TF_DataType [ ] _tfInputTypes ;
1174
+ private readonly DataViewType [ ] _inputTypes ;
1174
1175
private readonly DataViewType [ ] _outputTypes ;
1175
1176
private ImageClassificationTransformer _transformer ;
1176
1177
@@ -1179,7 +1180,7 @@ internal ImageClassificationEstimator(IHostEnvironment env, Options options, Dnn
1179
1180
_host = Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( ImageClassificationEstimator ) ) ;
1180
1181
_options = options ;
1181
1182
_dnnModel = dnnModel ;
1182
- _tfInputTypes = new [ ] { TF_DataType . TF_STRING } ;
1183
+ _inputTypes = new [ ] { new VectorDataViewType ( NumberDataViewType . Byte ) } ;
1183
1184
_outputTypes = new [ ] { new VectorDataViewType ( NumberDataViewType . Single ) , NumberDataViewType . UInt32 . GetItemType ( ) } ;
1184
1185
}
1185
1186
@@ -1206,9 +1207,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
1206
1207
var input = _options . InputColumns [ i ] ;
1207
1208
if ( ! inputSchema . TryFindColumn ( input , out var col ) )
1208
1209
throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input ) ;
1209
- // var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes [i]) ;
1210
- // if (col.ItemType != expectedType)
1211
- // throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
1210
+ var expectedType = _inputTypes [ i ] ;
1211
+ if ( ! col . ItemType . Equals ( expectedType ) )
1212
+ throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , expectedType . ToString ( ) , col . ItemType . ToString ( ) ) ;
1212
1213
}
1213
1214
for ( var i = 0 ; i < _options . OutputColumns . Length ; i ++ )
1214
1215
{
0 commit comments