@@ -102,7 +102,7 @@ private static void ValidateTrainData(IDataView trainData, ColumnInformation col
102
102
}
103
103
}
104
104
105
- private static void ValidateColumnInformation ( IDataView trainData , ColumnInformation columnInformation , TaskKind task )
105
+ private static void ValidateColumnInformation ( IDataView trainData , ColumnInformation columnInformation , TaskKind task )
106
106
{
107
107
ValidateColumnInformation ( columnInformation ) ;
108
108
ValidateTrainDataColumn ( trainData , columnInformation . LabelColumnName , LabelColumnPurposeName , GetAllowedLabelTypes ( task ) ) ;
@@ -217,7 +217,7 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
217
217
throw new ArgumentException ( $ "{ schemaMismatchError } Column '{ trainCol . Name } ' exists in train data, but not in validation data.", nameof ( validationData ) ) ;
218
218
}
219
219
220
- if ( trainCol . Type != validCol . Value . Type )
220
+ if ( trainCol . Type != validCol . Value . Type && ! trainCol . Type . Equals ( validCol . Value . Type ) )
221
221
{
222
222
throw new ArgumentException ( $ "{ schemaMismatchError } Column '{ trainCol . Name } ' is of type { trainCol . Type } in train data, and type " +
223
223
$ "{ validCol . Value . Type } in validation data.", nameof ( validationData ) ) ;
@@ -260,7 +260,7 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
260
260
throw new ArgumentException ( exceptionMessage ) ;
261
261
}
262
262
263
- if ( allowedTypes == null )
263
+ if ( allowedTypes == null )
264
264
{
265
265
return ;
266
266
}
0 commit comments