
Commit 583a86d

Add API for mapping Polars types to SQL types
- Introduces a new endpoint to retrieve mappings from Polars data types to SQL types for various SQL dialects, improving support for type-aware file imports and downstream table creation.
- Refactors the type mapping logic for reusability and maintainability, and adds serializer validation for dialect selection.
1 parent: 34c8567

4 files changed: 156 additions, 95 deletions


desktop/core/src/desktop/api_public_urls_v1.py

Lines changed: 1 addition & 0 deletions
@@ -164,6 +164,7 @@
   re_path(r'^importer/file/guess_metadata/?$', importer_api.guess_file_metadata, name='importer_guess_file_metadata'),
   re_path(r'^importer/file/guess_header/?$', importer_api.guess_file_header, name='importer_guess_file_header'),
   re_path(r'^importer/file/preview/?$', importer_api.preview_file, name='importer_preview_file'),
+  re_path(r'^importer/file/sql_type_mapping/?$', importer_api.get_sql_type_mapping, name='importer_get_sql_type_mapping'),
 ]
 
 urlpatterns += [
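For orientation, the new route maps GET requests onto importer_api.get_sql_type_mapping, and the /?$ suffix makes the trailing slash optional. A minimal client sketch, assuming the usual /api/v1/ mount point for this urlconf and a hypothetical host and token:

import requests

# Hypothetical host and credentials; the path segment comes from the re_path above.
response = requests.get(
  'https://hue.example.com/api/v1/importer/file/sql_type_mapping/',
  params={'sql_dialect': 'trino'},
  headers={'Authorization': 'Bearer <token>'},
)
print(response.json())  # dict of Polars dtype -> SQL type for Trino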

desktop/core/src/desktop/lib/importer/api.py

Lines changed: 39 additions & 4 deletions
@@ -30,6 +30,7 @@
   GuessFileMetadataSerializer,
   LocalFileUploadSerializer,
   PreviewFileSerializer,
+  SqlTypeMapperSerializer,
 )
 
 LOG = logging.getLogger()
@@ -90,8 +91,7 @@ def local_file_upload(request: Request) -> Response:
 @parser_classes([JSONParser])
 @api_error_handler
 def guess_file_metadata(request: Request) -> Response:
-  """
-  Guess the metadata of a file based on its content or extension.
+  """Guess the metadata of a file based on its content or extension.
 
   This API endpoint detects file type and extracts metadata properties such as
   delimiters for CSV files or sheet names for Excel files.
@@ -203,8 +203,7 @@ def preview_file(request: Request) -> Response:
 @parser_classes([JSONParser])
 @api_error_handler
 def guess_file_header(request: Request) -> Response:
-  """
-  Guess whether a file has a header row.
+  """Guess whether a file has a header row.
 
   This API endpoint analyzes a file to determine if it contains a header row based on the
   content pattern. It works for both Excel files and delimited text files (CSV, TSV, etc.)
@@ -244,3 +243,39 @@ def guess_file_header(request: Request) -> Response:
     # Handle other errors
     LOG.exception(f"Error detecting file header: {e}", exc_info=True)
     return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@api_view(['GET'])
+@parser_classes([JSONParser])
+@api_error_handler
+def get_sql_type_mapping(request: Request) -> Response:
+  """Get mapping from Polars data types to SQL types for a specific dialect.
+
+  This API endpoint returns a dictionary mapping Polars data types to the corresponding
+  SQL types for a specific SQL dialect.
+
+  Args:
+    request: Request object containing query parameters:
+      - sql_dialect: The SQL dialect to get mappings for (e.g., 'hive', 'impala', 'trino')
+
+  Returns:
+    Response containing a mapping dictionary:
+    - A mapping from Polars data type names to SQL type names for the specified dialect
+  """
+  serializer = SqlTypeMapperSerializer(data=request.query_params)
+
+  if not serializer.is_valid():
+    return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+  validated_data = serializer.validated_data
+  sql_dialect = validated_data['sql_dialect']
+
+  try:
+    type_mapping = operations.get_sql_type_mapping(sql_dialect)
+    return Response(type_mapping, status=status.HTTP_200_OK)
+
+  except ValueError as e:
+    return Response({'error': str(e)}, status=status.HTTP_400_BAD_REQUEST)
+  except Exception as e:
+    LOG.exception(f"Error getting SQL type mapping: {e}", exc_info=True)
+    return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
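A sketch of how the handler behaves end to end, using DRF's APIRequestFactory in test style (the import paths are assumptions based on the file locations in this commit, and @api_error_handler is assumed to pass successful responses through):

from rest_framework.test import APIRequestFactory

from desktop.lib.importer import api as importer_api

factory = APIRequestFactory()

ok = importer_api.get_sql_type_mapping(factory.get('/importer/file/sql_type_mapping/', {'sql_dialect': 'trino'}))
assert ok.status_code == 200
assert ok.data['Int32'] == 'INTEGER'  # Trino override defined in operations.py

bad = importer_api.get_sql_type_mapping(factory.get('/importer/file/sql_type_mapping/', {'sql_dialect': 'mysql'}))
assert bad.status_code == 400  # serializer rejects dialects outside the ChoiceField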

desktop/core/src/desktop/lib/importer/operations.py

Lines changed: 97 additions & 91 deletions
@@ -41,6 +41,73 @@
   is_magic_lib_available = False
 
 
+# Base mapping for most SQL engines (Hive, Impala, SparkSQL)
+SQL_TYPE_BASE_MAP = {
+  # signed ints
+  "Int8": "TINYINT",
+  "Int16": "SMALLINT",
+  "Int32": "INT",
+  "Int64": "BIGINT",
+  # unsigned ints: same size signed by default (Hive/Impala/SparkSQL)
+  "UInt8": "TINYINT",
+  "UInt16": "SMALLINT",
+  "UInt32": "INT",
+  "UInt64": "BIGINT",
+  # floats & decimal
+  "Float32": "FLOAT",
+  "Float64": "DOUBLE",
+  "Decimal": "DECIMAL",  # Hive/Impala/SparkSQL use DECIMAL(precision,scale)
+  # boolean, string, binary
+  "Boolean": "BOOLEAN",
+  "Utf8": "STRING",  # STRING covers Hive VARCHAR/CHAR for the default case
+  "String": "STRING",
+  "Categorical": "STRING",
+  "Enum": "STRING",
+  "Binary": "BINARY",
+  # temporal
+  "Date": "DATE",
+  "Time": "TIMESTAMP",  # Hive/Impala/SparkSQL have no pure TIME type
+  "Datetime": "TIMESTAMP",
+  "Duration": "INTERVAL DAY TO SECOND",
+  # nested & other
+  "Array": "ARRAY",
+  "List": "ARRAY",
+  "Struct": "STRUCT",
+  "Object": "STRING",
+  "Null": "STRING",  # no SQL NULL type; use STRING or handle as a special case
+  "Unknown": "STRING",
+}
+
+# Per-dialect overrides for the few differences
+SQL_TYPE_DIALECT_OVERRIDES = {
+  "hive": {},
+  "impala": {},
+  "sparksql": {},
+  "trino": {
+    "Int32": "INTEGER",
+    "UInt32": "INTEGER",
+    "Utf8": "VARCHAR",
+    "String": "VARCHAR",
+    "Binary": "VARBINARY",
+    "Float32": "REAL",
+    "Struct": "ROW",
+    "Object": "JSON",
+    "Duration": "INTERVAL DAY TO SECOND",  # explicit SQL syntax
+  },
+  "phoenix": {
+    **{f"UInt{b}": f"UNSIGNED_{t}" for b, t in [(8, "TINYINT"), (16, "SMALLINT"), (32, "INT"), (64, "LONG")]},
+    "Utf8": "VARCHAR",
+    "String": "VARCHAR",
+    "Binary": "VARBINARY",
+    "Duration": "STRING",  # Phoenix treats durations as strings
+    "Struct": "STRING",  # no native STRUCT type
+    "Object": "VARCHAR",
+    "Time": "TIME",  # Phoenix has its own TIME type
+    "Decimal": "DECIMAL",  # up to precision 38
+  },
+}
+
+
 def local_file_upload(upload_file, username: str) -> Dict[str, str]:
   """Uploads a local file to a temporary directory with a unique filename.
 
@@ -95,8 +162,7 @@ def local_file_upload(upload_file, username: str) -> Dict[str, str]:
 
 
 def guess_file_metadata(file_path: str, import_type: str, fs=None) -> Dict[str, Any]:
-  """
-  Guess the metadata of a file based on its content or extension.
+  """Guess the metadata of a file based on its content or extension.
 
   Args:
     file_path: Path to the file to analyze
@@ -177,8 +243,7 @@ def preview_file(
   fs=None,
   preview_rows: int = 50,
 ) -> Dict[str, Any]:
-  """
-  Generate a preview of a file's content with column type mapping.
+  """Generate a preview of a file's content with column type mapping.
 
   This method reads a file and returns a preview of its contents, along with
   column information and metadata for creating tables or further processing.
@@ -268,8 +333,7 @@
 
 
 def _detect_file_type(file_sample: bytes) -> str:
-  """
-  Detect the file type based on its content.
+  """Detect the file type based on its content.
 
   Args:
     file_sample: Binary sample of the file content
@@ -308,8 +372,7 @@ def _detect_file_type(file_sample: bytes) -> str:
 
 
 def _get_excel_metadata(fh: BinaryIO) -> Dict[str, Any]:
-  """
-  Extract metadata for Excel files (.xlsx, .xls).
+  """Extract metadata for Excel files (.xlsx, .xls).
 
   Args:
     fh: File handle for the Excel file
@@ -371,8 +434,7 @@ def _get_sheet_names_xlsx(fh: BinaryIO) -> List[str]:
 
 
 def _get_delimited_metadata(file_sample: Union[bytes, str], file_type: str) -> Dict[str, Any]:
-  """
-  Extract metadata for delimited files (CSV, TSV, etc.).
+  """Extract metadata for delimited files (CSV, TSV, etc.).
 
   Args:
     file_sample: Binary or string sample of the file content
@@ -543,8 +605,7 @@ def _preview_delimited_file(
 
 
 def guess_file_header(file_path: str, file_type: str, import_type: str, sheet_name: Optional[str] = None, fs=None) -> bool:
-  """
-  Guess whether a file has a header row.
+  """Guess whether a file has a header row.
 
   This function analyzes a file to determine if it contains a header row based on the
   content pattern. It works for both Excel files and delimited text files (CSV, TSV, etc.).
@@ -633,98 +694,43 @@ def guess_file_header(file_path: str, file_type: str, import_type: str, sheet_name: Optional[str] = None, fs=None) -> bool:
     fh.close()
 
 
-def _map_polars_dtype_to_sql_type(dialect: str, polars_type: str) -> str:
-  """
-  Map a Polars dtype to the corresponding SQL type for a given dialect.
+def get_sql_type_mapping(dialect: str) -> Dict[str, str]:
+  """Get all type mappings from Polars dtypes to SQL types for a given SQL dialect.
 
-  Supports all Polars dtypes as listed in the Polars docs:
-    Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
-    Float32, Float64, Decimal, Boolean, Utf8/String, Categorical, Enum,
-    Binary, Date, Time, Datetime, Duration, Array, List, Struct,
-    Object, Null, Unknown.
+  This function returns a dictionary mapping of all Polars data types to their
+  corresponding SQL types for a specific dialect.
 
   Args:
     dialect: One of "hive", "impala", "trino", "phoenix", "sparksql".
-    polars_type: Polars dtype name as string.
 
   Returns:
-    A string representing the SQL type.
+    A dict mapping Polars dtype names to SQL type names.
 
   Raises:
-    ValueError: If the dialect or polars_type is not supported.
+    ValueError: If the dialect is not supported.
   """
-  # Base mapping for most engines (Hive, Impala, SparkSQL)
-  base_map = {
-    # signed ints
-    "Int8": "TINYINT",
-    "Int16": "SMALLINT",
-    "Int32": "INT",
-    "Int64": "BIGINT",
-    # unsigned ints: same size signed by default (Hive/Impala/SparkSQL)
-    "UInt8": "TINYINT",
-    "UInt16": "SMALLINT",
-    "UInt32": "INT",
-    "UInt64": "BIGINT",
-    # floats & decimal
-    "Float32": "FLOAT",
-    "Float64": "DOUBLE",
-    "Decimal": "DECIMAL",  # Hive/Impala/SparkSQL use DECIMAL(precision,scale)
-    # boolean, string, binary
-    "Boolean": "BOOLEAN",
-    "Utf8": "STRING",  # STRING covers Hive VARCHAR/CHAR for the default case
-    "String": "STRING",
-    "Categorical": "STRING",
-    "Enum": "STRING",
-    "Binary": "BINARY",
-    # temporal
-    "Date": "DATE",
-    "Time": "TIMESTAMP",  # Hive/Impala/SparkSQL have no pure TIME type
-    "Datetime": "TIMESTAMP",
-    "Duration": "INTERVAL DAY TO SECOND",
-    # nested & other
-    "Array": "ARRAY",
-    "List": "ARRAY",
-    "Struct": "STRUCT",
-    "Object": "STRING",
-    "Null": "STRING",  # no SQL NULL type; use STRING or handle as a special case
-    "Unknown": "STRING",
-  }
-
-  # Per-dialect overrides for the few differences
-  overrides = {
-    "hive": {},
-    "impala": {},
-    "sparksql": {},
-    "trino": {
-      "Int32": "INTEGER",
-      "UInt32": "INTEGER",
-      "Utf8": "VARCHAR",
-      "String": "VARCHAR",
-      "Binary": "VARBINARY",
-      "Float32": "REAL",
-      "Struct": "ROW",
-      "Object": "JSON",
-      "Duration": "INTERVAL DAY TO SECOND",  # explicit SQL syntax
-    },
-    "phoenix": {
-      **{f"UInt{b}": f"UNSIGNED_{t}" for b, t in [(8, "TINYINT"), (16, "SMALLINT"), (32, "INT"), (64, "LONG")]},
-      "Utf8": "VARCHAR",
-      "String": "VARCHAR",
-      "Binary": "VARBINARY",
-      "Duration": "STRING",  # Phoenix treats durations as strings
-      "Struct": "STRING",  # no native STRUCT type
-      "Object": "VARCHAR",
-      "Time": "TIME",  # Phoenix has its own TIME type
-      "Decimal": "DECIMAL",  # up to precision 38
-    },
-  }
-
   dl = dialect.lower()
-  if dl not in overrides:
+  if dl not in SQL_TYPE_DIALECT_OVERRIDES:
     raise ValueError(f"Unsupported dialect: {dialect}")
 
   # Merge base_map and overrides[dl] into a new dict, giving precedence to any overlapping keys in overrides[dl]
-  mapping = {**base_map, **overrides[dl]}
+  return {**SQL_TYPE_BASE_MAP, **SQL_TYPE_DIALECT_OVERRIDES[dl]}
+
+
+def _map_polars_dtype_to_sql_type(dialect: str, polars_type: str) -> str:
+  """Map a Polars dtype to the corresponding SQL type for a given dialect.
+
+  Args:
+    dialect: One of "hive", "impala", "trino", "phoenix", "sparksql".
+    polars_type: Polars dtype name as string.
+
+  Returns:
+    A string representing the SQL type.
+
+  Raises:
+    ValueError: If the dialect or polars_type is not supported.
+  """
+  mapping = get_sql_type_mapping(dialect)
 
   if polars_type not in mapping:
     raise ValueError(f"No mapping for Polars dtype {polars_type} in dialect {dialect}")
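The merge in get_sql_type_mapping relies on dict unpacking, where later entries win, so a dialect override always replaces the base entry while everything else is inherited. A small illustration of the resulting behavior (assertions follow from the tables above):

mapping = get_sql_type_mapping('Trino')  # lookup is case-insensitive via dialect.lower()
assert mapping['Int32'] == 'INTEGER'  # trino override beats the base 'INT'
assert mapping['Int64'] == 'BIGINT'   # inherited unchanged from SQL_TYPE_BASE_MAP
assert mapping['Struct'] == 'ROW'

try:
  get_sql_type_mapping('mysql')
except ValueError as err:
  print(err)  # Unsupported dialect: mysql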

desktop/core/src/desktop/lib/importer/serializers.py

Lines changed: 19 additions & 0 deletions
@@ -134,6 +134,25 @@ def validate(self, data):
     return data
 
 
+class SqlTypeMapperSerializer(serializers.Serializer):
+  """Serializer for SQL type mapping requests.
+
+  This serializer validates the parameters required for retrieving type mapping information
+  from Polars data types to SQL types for a specific dialect.
+
+  Attributes:
+    sql_dialect: Target SQL dialect for type mapping
+  """
+
+  sql_dialect = serializers.ChoiceField(
+    choices=['hive', 'impala', 'trino', 'phoenix', 'sparksql'], required=True, help_text="SQL dialect for mapping column types"
+  )
+
+  def validate(self, data):
+    """Validate the complete data set."""
+    return data
+
+
 class GuessFileHeaderSerializer(serializers.Serializer):
   """Serializer for file header guessing request validation.
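Standard DRF semantics apply here: the ChoiceField rejects anything outside the five supported dialects before the view ever calls into operations. A quick sketch of the expected validation behavior:

serializer = SqlTypeMapperSerializer(data={'sql_dialect': 'impala'})
assert serializer.is_valid()

bad = SqlTypeMapperSerializer(data={'sql_dialect': 'mysql'})  # unsupported dialect
assert not bad.is_valid()
assert 'sql_dialect' in bad.errors  # DRF reports the invalid choice here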