diff --git a/desktop/conf.dist/hue.ini b/desktop/conf.dist/hue.ini
index fb60ad66cd7..26891840de5 100644
--- a/desktop/conf.dist/hue.ini
+++ b/desktop/conf.dist/hue.ini
@@ -1064,6 +1064,19 @@ tls=no
## Enable integration with Google Storage for RAZ
# is_raz_gs_enabled=false
+## Configuration options for the importer
+# ------------------------------------------------------------------------
+[[importer]]
+# Turns on the data importer functionality
+## is_enabled=false
+
+# A limit on the local file size (bytes) that can be uploaded through the importer. The default is 157286400 bytes (150 MiB).
+## max_local_file_size_upload_limit=157286400
+
+# Security setting to specify local file extensions that are not allowed to be uploaded through the importer.
+# Provide a comma-separated list of extensions including the dot (e.g., ".exe, .zip, .rar, .tar, .gz").
+## restrict_local_file_extensions=.exe, .zip, .rar, .tar, .gz
+
###########################################################################
# Settings to configure the snippets available in the Notebook
###########################################################################
diff --git a/desktop/conf/pseudo-distributed.ini.tmpl b/desktop/conf/pseudo-distributed.ini.tmpl
index 4afb1174f23..fabc3c72083 100644
--- a/desktop/conf/pseudo-distributed.ini.tmpl
+++ b/desktop/conf/pseudo-distributed.ini.tmpl
@@ -1049,6 +1049,19 @@
## Enable integration with Google Storage for RAZ
# is_raz_gs_enabled=false
+ ## Configuration options for the importer
+ # ------------------------------------------------------------------------
+ [[importer]]
+ # Turns on the data importer functionality
+ ## is_enabled=false
+
+ # A limit on the local file size (bytes) that can be uploaded through the importer. The default is 157286400 bytes (150 MiB).
+ ## max_local_file_size_upload_limit=157286400
+
+ # Security setting to specify local file extensions that are not allowed to be uploaded through the importer.
+ # Provide a comma-separated list of extensions including the dot (e.g., ".exe, .zip, .rar, .tar, .gz").
+ ## restrict_local_file_extensions=.exe, .zip, .rar, .tar, .gz
+
###########################################################################
# Settings to configure the snippets available in the Notebook
###########################################################################
diff --git a/desktop/core/base_requirements.txt b/desktop/core/base_requirements.txt
index 61edcb5eb54..8bb398aa656 100644
--- a/desktop/core/base_requirements.txt
+++ b/desktop/core/base_requirements.txt
@@ -41,6 +41,7 @@ Mako==1.2.3
Markdown==3.7
openpyxl==3.0.9
phoenixdb==1.2.1
+polars[calamine]==1.8.2 # Python >= 3.8
prompt-toolkit==3.0.39
protobuf==3.20.3
pyarrow==17.0.0
@@ -48,6 +49,7 @@ pyformance==0.3.2
python-dateutil==2.8.2
python-daemon==2.2.4
python-ldap==3.4.3
+python-magic==0.4.27
python-oauth2==1.1.0
python-pam==2.0.2
pytidylib==0.3.2
diff --git a/desktop/core/generate_requirements.py b/desktop/core/generate_requirements.py
index db0475efe2c..6ca37c83761 100755
--- a/desktop/core/generate_requirements.py
+++ b/desktop/core/generate_requirements.py
@@ -44,7 +44,6 @@ def __init__(self):
]
self.requirements = [
- "setuptools==70.0.0",
"apache-ranger==0.0.3",
"asn1crypto==0.24.0",
"avro-python3==1.8.2",
@@ -88,6 +87,7 @@ def __init__(self):
"Mako==1.2.3",
"openpyxl==3.0.9",
"phoenixdb==1.2.1",
+ "polars[calamine]==1.8.2", # Python >= 3.8
"prompt-toolkit==3.0.39",
"protobuf==3.20.3",
"psutil==5.8.0",
@@ -97,6 +97,7 @@ def __init__(self):
"python-daemon==2.2.4",
"python-dateutil==2.8.2",
"python-ldap==3.4.3",
+ "python-magic==0.4.27",
"python-oauth2==1.1.0",
"python-pam==2.0.2",
"pytidylib==0.3.2",
@@ -107,6 +108,7 @@ def __init__(self):
"requests-kerberos==0.14.0",
"rsa==4.7.2",
"ruff==0.11.10",
+ "setuptools==70.0.0",
"six==1.16.0",
"slack-sdk==3.31.0",
"SQLAlchemy==1.3.8",
diff --git a/desktop/core/src/desktop/api2.py b/desktop/core/src/desktop/api2.py
index 742e65f5ade..f1a1614d955 100644
--- a/desktop/core/src/desktop/api2.py
+++ b/desktop/core/src/desktop/api2.py
@@ -15,12 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-import re
import json
import logging
-import zipfile
+import os
+import re
import tempfile
+import zipfile
from builtins import map
from datetime import datetime
from io import StringIO as string_io
@@ -28,7 +28,7 @@
from celery.app.control import Control
from django.core import management
from django.db import transaction
-from django.http import HttpResponse, JsonResponse
+from django.http import HttpResponse
from django.shortcuts import redirect
from django.utils.html import escape
from django.utils.translation import gettext as _
@@ -48,10 +48,11 @@
ENABLE_NEW_STORAGE_BROWSER,
ENABLE_SHARING,
ENABLE_WORKFLOW_CREATION_ACTION,
- TASK_SERVER_V2,
get_clusters,
+ IMPORTER,
+ TASK_SERVER_V2,
)
-from desktop.lib.conf import GLOBAL_CONFIG, BoundContainer, is_anonymous
+from desktop.lib.conf import BoundContainer, GLOBAL_CONFIG, is_anonymous
from desktop.lib.connectors.models import Connector
from desktop.lib.django_util import JsonResponse, login_notrequired, render
from desktop.lib.exceptions_renderable import PopupException
@@ -60,16 +61,16 @@
from desktop.lib.paths import get_desktop_root
from desktop.log import DEFAULT_LOG_DIR
from desktop.models import (
+ __paginate,
+ _get_gist_document,
Directory,
Document,
Document2,
FilesystemException,
- UserPreferences,
- __paginate,
- _get_gist_document,
get_cluster_config,
get_user_preferences,
set_user_preferences,
+ UserPreferences,
uuid_default,
)
from desktop.views import _get_config_errors, get_banner_message, serve_403_error
@@ -91,7 +92,7 @@
search_entities_interactive as metadata_search_entities_interactive,
)
from metadata.conf import has_catalog
-from notebook.connectors.base import Notebook, get_interpreter
+from notebook.connectors.base import get_interpreter, Notebook
from notebook.management.commands import notebook_setup
from pig.management.commands import pig_setup
from search.management.commands import search_setup
@@ -132,19 +133,41 @@ def get_banners(request):
@api_error_handler
def get_config(request):
+ """
+ Returns Hue application's config information.
+ Includes settings for various components like storage, task server, importer, etc.
+ """
+ # Get base cluster configuration
config = get_cluster_config(request.user)
- config['hue_config']['is_admin'] = is_admin(request.user)
- config['hue_config']['is_yarn_enabled'] = is_yarn()
- config['hue_config']['enable_task_server'] = TASK_SERVER_V2.ENABLED.get()
- config['hue_config']['enable_workflow_creation_action'] = ENABLE_WORKFLOW_CREATION_ACTION.get()
- config['storage_browser']['enable_chunked_file_upload'] = ENABLE_CHUNKED_FILE_UPLOADER.get()
- config['storage_browser']['enable_new_storage_browser'] = ENABLE_NEW_STORAGE_BROWSER.get()
- config['storage_browser']['restrict_file_extensions'] = RESTRICT_FILE_EXTENSIONS.get()
- config['storage_browser']['concurrent_max_connection'] = CONCURRENT_MAX_CONNECTIONS.get()
- config['storage_browser']['file_upload_chunk_size'] = FILE_UPLOAD_CHUNK_SIZE.get()
- config['storage_browser']['enable_file_download_button'] = SHOW_DOWNLOAD_BUTTON.get()
- config['storage_browser']['max_file_editor_size'] = MAX_FILEEDITOR_SIZE
- config['storage_browser']['enable_extract_uploaded_archive'] = ENABLE_EXTRACT_UPLOADED_ARCHIVE.get()
+
+ # Core application configuration
+ config['hue_config'] = {
+ 'is_admin': is_admin(request.user),
+ 'is_yarn_enabled': is_yarn(),
+ 'enable_task_server': TASK_SERVER_V2.ENABLED.get(),
+ 'enable_workflow_creation_action': ENABLE_WORKFLOW_CREATION_ACTION.get(),
+ }
+
+ # Storage browser configuration
+ config['storage_browser'] = {
+ 'enable_chunked_file_upload': ENABLE_CHUNKED_FILE_UPLOADER.get(),
+ 'enable_new_storage_browser': ENABLE_NEW_STORAGE_BROWSER.get(),
+ 'restrict_file_extensions': RESTRICT_FILE_EXTENSIONS.get(),
+ 'concurrent_max_connection': CONCURRENT_MAX_CONNECTIONS.get(),
+ 'file_upload_chunk_size': FILE_UPLOAD_CHUNK_SIZE.get(),
+ 'enable_file_download_button': SHOW_DOWNLOAD_BUTTON.get(),
+ 'max_file_editor_size': MAX_FILEEDITOR_SIZE,
+ 'enable_extract_uploaded_archive': ENABLE_EXTRACT_UPLOADED_ARCHIVE.get(),
+ }
+
+ # Importer configuration
+ config['importer'] = {
+ 'is_enabled': IMPORTER.IS_ENABLED.get(),
+ 'restrict_local_file_extensions': IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.get(),
+ 'max_local_file_size_upload_limit': IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.get(),
+ }
+
+ # Other general configuration
config['clusters'] = list(get_clusters(request.user).values())
config['documents'] = {'types': list(Document2.objects.documents(user=request.user).order_by().values_list('type', flat=True).distinct())}
config['status'] = 0
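For reference, with the defaults introduced in conf.py the new importer block of the get_config() payload comes out roughly as below. This is an illustrative sketch only; actual values depend on the deployed hue.ini.

  # Illustrative shape of the 'importer' entry with default settings (not part of the patch)
  config['importer'] = {
    'is_enabled': False,                            # importer is off until enabled in hue.ini
    'restrict_local_file_extensions': None,         # no extensions blocked by default
    'max_local_file_size_upload_limit': 157286400,  # 150 MiB
  }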
@@ -624,7 +647,7 @@ def copy_document(request):
# Import workspace for all oozie jobs
if document.type == 'oozie-workflow2' or document.type == 'oozie-bundle2' or document.type == 'oozie-coordinator2':
- from oozie.models2 import Bundle, Coordinator, Workflow, _import_workspace
+ from oozie.models2 import _import_workspace, Bundle, Coordinator, Workflow
# Update the name field in the json 'data' field
if document.type == 'oozie-workflow2':
workflow = Workflow(document=document)
@@ -998,7 +1021,7 @@ def is_reserved_directory(doc):
documents = json.loads(request.POST.get('documents'))
documents = json.loads(documents)
- except ValueError as e:
+ except ValueError:
raise PopupException(_('Failed to import documents, the file does not contain valid JSON.'))
# Validate documents
@@ -1112,7 +1135,7 @@ def gist_create(request):
statement = request.POST.get('statement', '')
gist_type = request.POST.get('doc_type', 'hive')
name = request.POST.get('name', '')
- description = request.POST.get('description', '')
+ _ = request.POST.get('description', '')
response = _gist_create(request.get_host(), request.is_secure(), request.user, statement, gist_type, name)
@@ -1333,7 +1356,7 @@ def _create_or_update_document_with_owner(doc, owner, uuids_map):
doc['pk'] = existing_doc.pk
else:
create_new = True
- except FilesystemException as e:
+ except FilesystemException:
create_new = True
if create_new:
diff --git a/desktop/core/src/desktop/api_public_urls_v1.py b/desktop/core/src/desktop/api_public_urls_v1.py
index 50487e224b3..cf4eac70b9e 100644
--- a/desktop/core/src/desktop/api_public_urls_v1.py
+++ b/desktop/core/src/desktop/api_public_urls_v1.py
@@ -19,6 +19,7 @@
from desktop import api_public
from desktop.lib.botserver import api as botserver_api
+from desktop.lib.importer import api as importer_api
# "New" query API (i.e. connector based, lean arguments).
# e.g. https://demo.gethue.com/api/query/execute/hive
@@ -158,6 +159,14 @@
re_path(r'^indexer/importer/submit', api_public.importer_submit, name='indexer_importer_submit'),
]
+urlpatterns += [
+ re_path(r'^importer/upload/file/?$', importer_api.local_file_upload, name='importer_local_file_upload'),
+ re_path(r'^importer/file/guess_metadata/?$', importer_api.guess_file_metadata, name='importer_guess_file_metadata'),
+ re_path(r'^importer/file/guess_header/?$', importer_api.guess_file_header, name='importer_guess_file_header'),
+ re_path(r'^importer/file/preview/?$', importer_api.preview_file, name='importer_preview_file'),
+ re_path(r'^importer/sql_type_mapping/?$', importer_api.get_sql_type_mapping, name='importer_get_sql_type_mapping'),
+]
+
urlpatterns += [
re_path(r'^connector/types/?$', api_public.get_connector_types, name='connector_get_types'),
re_path(r'^connector/instances/?$', api_public.get_connectors_instances, name='connector_get_instances'),
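For illustration, a client could exercise the new importer routes roughly as follows. This is a sketch only: the host, the /api/v1/ prefix and the bearer-token auth are assumptions about the deployment; the parameter names mirror the serializers used by the views.

  import requests

  BASE = "https://hue.example.com/api/v1"        # hypothetical host and API prefix
  HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

  # 1. Upload a local CSV through the importer
  with open("data.csv", "rb") as f:
    upload = requests.post(f"{BASE}/importer/upload/file/", headers=HEADERS, files={"file": f}).json()

  # 2. Guess its metadata, then request a preview for the Hive dialect
  meta = requests.get(
    f"{BASE}/importer/file/guess_metadata/",
    headers=HEADERS,
    params={"file_path": upload["file_path"], "import_type": "local"},
  ).json()
  preview = requests.get(
    f"{BASE}/importer/file/preview/",
    headers=HEADERS,
    params={"file_path": upload["file_path"], "file_type": meta["type"], "import_type": "local", "sql_dialect": "hive"},
  ).json()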
diff --git a/desktop/core/src/desktop/conf.py b/desktop/core/src/desktop/conf.py
index e8cb929c52e..63954bdd1ae 100644
--- a/desktop/core/src/desktop/conf.py
+++ b/desktop/core/src/desktop/conf.py
@@ -16,13 +16,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-import sys
+import datetime
import glob
-import stat
-import socket
import logging
-import datetime
+import os
+import socket
+import stat
+import sys
from collections import OrderedDict
from django.db import connection
@@ -30,16 +30,16 @@
from desktop import appmanager
from desktop.lib.conf import (
- Config,
- ConfigSection,
- UnspecifiedConfigSection,
coerce_bool,
coerce_csv,
coerce_json_dict,
coerce_password_from_script,
coerce_str_lowercase,
coerce_string,
+ Config,
+ ConfigSection,
list_of_compiled_res,
+ UnspecifiedConfigSection,
validate_path,
)
from desktop.lib.i18n import force_unicode
@@ -106,7 +106,7 @@ def get_dn(fqdn=None):
else:
LOG.warning("allowed_hosts value to '*'. It is a security risk")
val.append('*')
- except Exception as e:
+ except Exception:
LOG.warning("allowed_hosts value to '*'. It is a security risk")
val.append('*')
return val
@@ -2952,3 +2952,32 @@ def is_ofs_enabled():
def has_ofs_access(user):
from desktop.auth.backend import is_admin
return user.is_authenticated and user.is_active and (is_admin(user) or user.has_hue_permission(action="ofs_access", app="filebrowser"))
+
+
+IMPORTER = ConfigSection(
+ key='importer',
+ help=_("""Configuration options for the importer."""),
+ members=dict(
+ IS_ENABLED=Config(
+ key='is_enabled',
+ help=_('Enable or disable the new importer functionality'),
+ type=coerce_bool,
+ default=False,
+ ),
+ RESTRICT_LOCAL_FILE_EXTENSIONS=Config(
+ key='restrict_local_file_extensions',
+ default=None,
+ type=coerce_csv,
+ help=_(
+ 'Security setting to specify local file extensions that are not allowed to be uploaded through the importer. '
+ 'Provide a comma-separated list of extensions including the dot (e.g., ".exe, .zip, .rar, .tar, .gz").'
+ ),
+ ),
+ MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT=Config(
+ key="max_local_file_size_upload_limit",
+ default=157286400, # 150 MiB
+ type=int,
+ help=_('Maximum local file size (in bytes) that users can upload through the importer. The default is 157286400 bytes (150 MiB).'),
+ ),
+ ),
+)
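A minimal sketch of how these settings could be consulted around an upload; the helper name below is hypothetical, and the actual enforcement lives in the importer serializers, which are not part of this hunk.

  from desktop.conf import IMPORTER

  def is_local_upload_allowed(filename, size_bytes):
    """Hypothetical helper illustrating the three importer settings."""
    if not IMPORTER.IS_ENABLED.get():
      return False
    if size_bytes > IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.get():
      return False
    restricted = IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.get() or []
    return not any(filename.lower().endswith(ext.strip().lower()) for ext in restricted)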
diff --git a/desktop/core/src/desktop/lib/importer/__init__.py b/desktop/core/src/desktop/lib/importer/__init__.py
new file mode 100644
index 00000000000..d053ec41e3c
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/__init__.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/desktop/core/src/desktop/lib/importer/api.py b/desktop/core/src/desktop/lib/importer/api.py
new file mode 100644
index 00000000000..8d1713905fa
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/api.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from functools import wraps
+
+from rest_framework import status
+from rest_framework.decorators import api_view, parser_classes
+from rest_framework.parsers import JSONParser, MultiPartParser
+from rest_framework.request import Request
+from rest_framework.response import Response
+
+from desktop.lib.importer import operations
+from desktop.lib.importer.serializers import (
+ GuessFileHeaderSerializer,
+ GuessFileMetadataSerializer,
+ LocalFileUploadSerializer,
+ PreviewFileSerializer,
+ SqlTypeMapperSerializer,
+)
+
+LOG = logging.getLogger()
+
+
+# TODO: Improve error response further with better context -- Error UX Phase 2
+def api_error_handler(fn):
+ """
+ Decorator to handle exceptions and return a JSON response with an error message.
+ """
+
+ @wraps(fn)
+ def decorator(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ LOG.exception(f"Error running {fn.__name__}: {str(e)}")
+ return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+ return decorator
+
+
+@api_view(["POST"])
+@parser_classes([MultiPartParser])
+@api_error_handler
+def local_file_upload(request: Request) -> Response:
+ """Handle the local file upload operation.
+
+ This endpoint allows users to upload a file from their local system.
+  The uploaded file is validated with LocalFileUploadSerializer and processed by the local_file_upload operation.
+
+ Args:
+ request: Request object containing the file to upload
+
+ Returns:
+ Response containing the result of the local upload operation, including:
+ - file_path: Path where the file was saved (if successful)
+
+ Note:
+ - File size limits apply based on MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT configuration.
+ - File type restrictions apply based on RESTRICT_LOCAL_FILE_EXTENSIONS configuration.
+ """
+
+ serializer = LocalFileUploadSerializer(data=request.data)
+
+ if not serializer.is_valid():
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+ uploaded_file = serializer.validated_data["file"]
+
+ LOG.info(f"User {request.user.username} is uploading a local file: {uploaded_file.name}")
+ result = operations.local_file_upload(uploaded_file, request.user.username)
+
+ return Response(result, status=status.HTTP_201_CREATED)
+
+
+@api_view(["GET"])
+@parser_classes([JSONParser])
+@api_error_handler
+def guess_file_metadata(request: Request) -> Response:
+ """Guess the metadata of a file based on its content or extension.
+
+ This API endpoint detects file type and extracts metadata properties such as
+ delimiters for CSV files or sheet names for Excel files.
+
+ Args:
+ request: Request object containing query parameters:
+ - file_path: Path to the file
+ - import_type: 'local' or 'remote'
+
+ Returns:
+ Response containing file metadata including:
+ - type: File type (e.g., excel, csv, tsv)
+ - sheet_names: List of sheet names (for Excel files)
+ - field_separator: Field separator character (for delimited files)
+ - quote_char: Quote character (for delimited files)
+ - record_separator: Record separator (for delimited files)
+ """
+ serializer = GuessFileMetadataSerializer(data=request.query_params)
+
+ if not serializer.is_valid():
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+ validated_data = serializer.validated_data
+ file_path = validated_data["file_path"]
+ import_type = validated_data["import_type"]
+
+ try:
+ metadata = operations.guess_file_metadata(
+ file_path=file_path, import_type=import_type, fs=request.fs if import_type == "remote" else None
+ )
+
+ return Response(metadata, status=status.HTTP_200_OK)
+
+ except ValueError as e:
+ return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
+ except Exception as e:
+ # Handle other errors
+ LOG.exception(f"Error guessing file metadata: {e}", exc_info=True)
+ return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@api_view(["GET"])
+@parser_classes([JSONParser])
+@api_error_handler
+def preview_file(request: Request) -> Response:
+ """Preview a file based on its path and import type.
+
+ Args:
+ request: Request object containing query parameters for file preview
+
+ Returns:
+ Response containing a dict preview of the file content, including:
+ - type: Type of the file (e.g., csv, tsv, excel)
+ - columns: List of column metadata
+ - preview_data: Sample data from the file
+ """
+ serializer = PreviewFileSerializer(data=request.query_params)
+ if not serializer.is_valid():
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+ # Get validated data
+ validated_data = serializer.validated_data
+ file_path = validated_data["file_path"]
+ file_type = validated_data["file_type"]
+ import_type = validated_data["import_type"]
+ sql_dialect = validated_data["sql_dialect"]
+ has_header = validated_data.get("has_header")
+
+ try:
+ if file_type == "excel":
+ sheet_name = validated_data.get("sheet_name")
+
+ preview = operations.preview_file(
+ file_path=file_path,
+ file_type=file_type,
+ import_type=import_type,
+ sql_dialect=sql_dialect,
+ has_header=has_header,
+ sheet_name=sheet_name,
+ fs=request.fs if import_type == "remote" else None,
+ )
+ else: # Delimited file types
+ field_separator = validated_data.get("field_separator")
+ quote_char = validated_data.get("quote_char")
+ record_separator = validated_data.get("record_separator")
+
+ preview = operations.preview_file(
+ file_path=file_path,
+ file_type=file_type,
+ import_type=import_type,
+ sql_dialect=sql_dialect,
+ has_header=has_header,
+ field_separator=field_separator,
+ quote_char=quote_char,
+ record_separator=record_separator,
+ fs=request.fs if import_type == "remote" else None,
+ )
+
+ return Response(preview, status=status.HTTP_200_OK)
+
+ except ValueError as e:
+ return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
+ except Exception as e:
+ LOG.exception(f"Error previewing file: {e}", exc_info=True)
+ return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@api_view(["GET"])
+@parser_classes([JSONParser])
+@api_error_handler
+def guess_file_header(request: Request) -> Response:
+ """Guess whether a file has a header row.
+
+ This API endpoint analyzes a file to determine if it contains a header row based on the
+  content pattern. It works for both Excel files and delimited text files (CSV, TSV, etc.).
+
+ Args:
+ request: Request object containing query parameters:
+ - file_path: Path to the file
+ - file_type: Type of file ('excel', 'csv', 'tsv', 'delimiter_format')
+ - import_type: 'local' or 'remote'
+ - sheet_name: Sheet name for Excel files (required for Excel)
+
+ Returns:
+ Response containing:
+ - has_header: Boolean indicating whether the file has a header row
+ """
+ serializer = GuessFileHeaderSerializer(data=request.query_params)
+
+ if not serializer.is_valid():
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+ validated_data = serializer.validated_data
+
+ try:
+ has_header = operations.guess_file_header(
+ file_path=validated_data["file_path"],
+ file_type=validated_data["file_type"],
+ import_type=validated_data["import_type"],
+ sheet_name=validated_data.get("sheet_name"),
+ fs=request.fs if validated_data["import_type"] == "remote" else None,
+ )
+
+ return Response({"has_header": has_header}, status=status.HTTP_200_OK)
+
+ except ValueError as e:
+ return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
+ except Exception as e:
+ # Handle other errors
+ LOG.exception(f"Error detecting file header: {e}", exc_info=True)
+ return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@api_view(["GET"])
+@parser_classes([JSONParser])
+@api_error_handler
+def get_sql_type_mapping(request: Request) -> Response:
+ """Get mapping from Polars data types to SQL types for a specific dialect.
+
+ This API endpoint returns a dictionary mapping Polars data types to the corresponding
+ SQL types for a specific SQL dialect.
+
+ Args:
+ request: Request object containing query parameters:
+ - sql_dialect: The SQL dialect to get mappings for (e.g., 'hive', 'impala', 'trino')
+
+ Returns:
+ Response containing a mapping dictionary:
+ - A mapping from Polars data type names to SQL type names for the specified dialect
+ """
+ serializer = SqlTypeMapperSerializer(data=request.query_params)
+
+ if not serializer.is_valid():
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+ validated_data = serializer.validated_data
+ sql_dialect = validated_data["sql_dialect"]
+
+ try:
+ type_mapping = operations.get_sql_type_mapping(sql_dialect)
+ return Response(type_mapping, status=status.HTTP_200_OK)
+
+ except ValueError as e:
+ return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
+ except Exception as e:
+ LOG.exception(f"Error getting SQL type mapping: {e}", exc_info=True)
+ return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
diff --git a/desktop/core/src/desktop/lib/importer/api_tests.py b/desktop/core/src/desktop/lib/importer/api_tests.py
new file mode 100644
index 00000000000..945ec5aa879
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/api_tests.py
@@ -0,0 +1,685 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+from django.core.files.uploadedfile import SimpleUploadedFile
+from rest_framework import status
+from rest_framework.test import APIRequestFactory
+
+from desktop.lib.importer import api
+
+
+class TestLocalFileUploadAPI:
+ @patch("desktop.lib.importer.api.LocalFileUploadSerializer")
+ @patch("desktop.lib.importer.api.operations.local_file_upload")
+ def test_local_file_upload_success(self, mock_local_file_upload, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"file": SimpleUploadedFile("test.csv", b"content")})
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_local_file_upload.return_value = {"file_path": "/tmp/user_12345_test.csv"}
+
+ request = APIRequestFactory().post("importer/upload/file/")
+ request.user = MagicMock(username="test_user")
+
+ response = api.local_file_upload(request)
+
+ assert response.status_code == status.HTTP_201_CREATED
+ assert response.data == {"file_path": "/tmp/user_12345_test.csv"}
+ mock_local_file_upload.assert_called_once_with(mock_serializer.validated_data["file"], "test_user")
+
+ @patch("desktop.lib.importer.api.LocalFileUploadSerializer")
+ def test_local_file_upload_invalid_data(self, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file": ["File too large"]})
+ mock_serializer_class.return_value = mock_serializer
+
+ request = APIRequestFactory().post("importer/upload/file/")
+ request.user = MagicMock(username="test_user")
+
+ response = api.local_file_upload(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"file": ["File too large"]}
+
+ @patch("desktop.lib.importer.api.LocalFileUploadSerializer")
+ @patch("desktop.lib.importer.api.operations.local_file_upload")
+ def test_local_file_upload_operation_error(self, mock_local_file_upload, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"file": SimpleUploadedFile("test.csv", b"content")})
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_local_file_upload.side_effect = IOError("Operation error")
+
+ request = APIRequestFactory().post("importer/upload/file/")
+ request.user = MagicMock(username="test_user")
+
+ response = api.local_file_upload(request)
+
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+ assert response.data == {"error": "Operation error"}
+
+
+class TestGuessFileMetadataAPI:
+ @patch("desktop.lib.importer.api.GuessFileMetadataSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_metadata")
+ def test_guess_csv_file_metadata_success(self, mock_guess_file_metadata, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_metadata.return_value = {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"}
+
+ request = APIRequestFactory().get("importer/file/guess_metadata/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_metadata(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"}
+ mock_guess_file_metadata.assert_called_once_with(file_path="/path/to/test.csv", import_type="local", fs=None)
+
+ @patch("desktop.lib.importer.api.GuessFileMetadataSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_metadata")
+ def test_guess_excel_file_metadata_success(self, mock_guess_file_metadata, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.xlsx", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_metadata.return_value = {"type": "excel", "sheet_names": ["Sheet1", "Sheet2"]}
+
+ request = APIRequestFactory().get("importer/file/guess_metadata/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.xlsx", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_metadata(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"type": "excel", "sheet_names": ["Sheet1", "Sheet2"]}
+ mock_guess_file_metadata.assert_called_once_with(file_path="/path/to/test.xlsx", import_type="local", fs=None)
+
+ @patch("desktop.lib.importer.api.GuessFileMetadataSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_metadata")
+ def test_guess_file_metadata_remote_csv_file(self, mock_guess_file_metadata, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "s3a://bucket/user/test_user/test.csv", "import_type": "remote"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_metadata.return_value = {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"}
+ mock_fs = MagicMock()
+
+ request = APIRequestFactory().get("importer/file/guess_metadata/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "import_type": "remote"}
+ request.fs = mock_fs
+
+ response = api.guess_file_metadata(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"}
+ mock_guess_file_metadata.assert_called_once_with(file_path="s3a://bucket/user/test_user/test.csv", import_type="remote", fs=mock_fs)
+
+ @patch("desktop.lib.importer.api.GuessFileMetadataSerializer")
+ def test_guess_file_metadata_invalid_data(self, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file_path": ["This field is required"]})
+ mock_serializer_class.return_value = mock_serializer
+
+ request = APIRequestFactory().get("importer/file/guess_metadata/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {}
+
+ response = api.guess_file_metadata(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"file_path": ["This field is required"]}
+
+ @patch("desktop.lib.importer.api.GuessFileMetadataSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_metadata")
+ def test_guess_file_metadata_value_error(self, mock_guess_file_metadata, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_metadata.side_effect = ValueError("File does not exist")
+
+ request = APIRequestFactory().get("importer/file/guess_metadata/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_metadata(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"error": "File does not exist"}
+
+ @patch("desktop.lib.importer.api.GuessFileMetadataSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_metadata")
+ def test_guess_file_metadata_operation_error(self, mock_guess_file_metadata, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_metadata.side_effect = RuntimeError("Operation error")
+
+ request = APIRequestFactory().get("importer/file/guess_metadata/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_metadata(request)
+
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+ assert response.data == {"error": "Operation error"}
+
+
+class TestPreviewFileAPI:
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ @patch("desktop.lib.importer.api.operations.preview_file")
+ def test_preview_csv_file_success(self, mock_preview_file, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={
+ "file_path": "/path/to/test.csv",
+ "file_type": "csv",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "field_separator": ",",
+ "quote_char": '"',
+ "record_separator": "\n",
+ },
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_preview_result = {
+ "type": "csv",
+ "columns": [{"name": "col1", "type": "INT"}, {"name": "col2", "type": "STRING"}],
+ "preview_data": {"col1": [1, 2], "col2": ["a", "b"]},
+ }
+ mock_preview_file.return_value = mock_preview_result
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == mock_preview_result
+ mock_preview_file.assert_called_once_with(
+ file_path="/path/to/test.csv",
+ file_type="csv",
+ import_type="local",
+ sql_dialect="hive",
+ has_header=True,
+ field_separator=",",
+ quote_char='"',
+ record_separator="\n",
+ fs=None,
+ )
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ @patch("desktop.lib.importer.api.operations.preview_file")
+ def test_preview_excel_file_success(self, mock_preview_file, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={
+ "file_path": "/path/to/test.xlsx",
+ "file_type": "excel",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "sheet_name": "Sheet1",
+ },
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_preview_result = {
+ "type": "excel",
+ "columns": [{"name": "col1", "type": "INT"}, {"name": "col2", "type": "STRING"}],
+ "preview_data": {"col1": [1, 2], "col2": ["a", "b"]},
+ }
+ mock_preview_file.return_value = mock_preview_result
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.xlsx", "file_type": "excel", "import_type": "local"}
+ request.fs = None
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == mock_preview_result
+ mock_preview_file.assert_called_once_with(
+ file_path="/path/to/test.xlsx",
+ file_type="excel",
+ import_type="local",
+ sql_dialect="hive",
+ has_header=True,
+ sheet_name="Sheet1",
+ fs=None,
+ )
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ @patch("desktop.lib.importer.api.operations.preview_file")
+ def test_preview_tsv_file_success(self, mock_preview_file, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={
+ "file_path": "/path/to/test.tsv",
+ "file_type": "tsv",
+ "import_type": "local",
+ "sql_dialect": "impala",
+ "has_header": True,
+ "field_separator": "\t",
+ "quote_char": '"',
+ "record_separator": "\n",
+ },
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_preview_result = {
+ "type": "tsv",
+ "columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "STRING"}],
+ "preview_data": {"id": [1, 2], "name": ["Product A", "Product B"]},
+ }
+ mock_preview_file.return_value = mock_preview_result
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.tsv", "file_type": "tsv", "import_type": "local"}
+ request.fs = None
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == mock_preview_result
+ mock_preview_file.assert_called_once_with(
+ file_path="/path/to/test.tsv",
+ file_type="tsv",
+ import_type="local",
+ sql_dialect="impala",
+ has_header=True,
+ field_separator="\t",
+ quote_char='"',
+ record_separator="\n",
+ fs=None,
+ )
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ @patch("desktop.lib.importer.api.operations.preview_file")
+ def test_preview_remote_csv_file_success(self, mock_preview_file, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={
+ "file_path": "s3a://bucket/user/test_user/test.csv",
+ "file_type": "csv",
+ "import_type": "remote",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "field_separator": ",",
+ "quote_char": '"',
+ "record_separator": "\n",
+ },
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_preview_result = {
+ "type": "csv",
+ "columns": [{"name": "col1", "type": "INT"}, {"name": "col2", "type": "STRING"}],
+ "preview_data": {"col1": [1, 2], "col2": ["a", "b"]},
+ }
+
+ mock_preview_file.return_value = mock_preview_result
+ mock_fs = MagicMock()
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"}
+ request.fs = mock_fs
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == mock_preview_result
+ mock_preview_file.assert_called_once_with(
+ file_path="s3a://bucket/user/test_user/test.csv",
+ file_type="csv",
+ import_type="remote",
+ sql_dialect="hive",
+ has_header=True,
+ field_separator=",",
+ quote_char='"',
+ record_separator="\n",
+ fs=mock_fs,
+ )
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ def test_preview_file_invalid_data(self, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file_type": ["Not a valid choice."]})
+ mock_serializer_class.return_value = mock_serializer
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.pdf", "file_type": "pdf", "import_type": "local"}
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"file_type": ["Not a valid choice."]}
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ def test_preview_file_missing_required_param(self, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"sql_dialect": ["This field is required."]})
+ mock_serializer_class.return_value = mock_serializer
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"sql_dialect": ["This field is required."]}
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ @patch("desktop.lib.importer.api.operations.preview_file")
+ def test_preview_file_value_error(self, mock_preview_file, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={
+ "file_path": "/path/to/test.csv",
+ "file_type": "csv",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "field_separator": ",",
+ "quote_char": '"',
+ "record_separator": "\n",
+ },
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_preview_file.side_effect = ValueError("File does not exist")
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"error": "File does not exist"}
+
+ @patch("desktop.lib.importer.api.PreviewFileSerializer")
+ @patch("desktop.lib.importer.api.operations.preview_file")
+ def test_preview_file_operation_error(self, mock_preview_file, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={
+ "file_path": "/path/to/test.csv",
+ "file_type": "csv",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "field_separator": ",",
+ "quote_char": '"',
+ "record_separator": "\n",
+ },
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_preview_file.side_effect = RuntimeError("Operation error")
+
+ request = APIRequestFactory().get("importer/file/preview/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.preview_file(request)
+
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+ assert response.data == {"error": "Operation error"}
+
+
+class TestGuessFileHeaderAPI:
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_header")
+ def test_guess_csv_file_header_success(self, mock_guess_file_header, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_header.return_value = True
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"has_header": True}
+ mock_guess_file_header.assert_called_once_with(
+ file_path="/path/to/test.csv", file_type="csv", import_type="local", sheet_name=None, fs=None
+ )
+
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_header")
+ def test_guess_excel_file_header_success(self, mock_guess_file_header, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={"file_path": "/path/to/test.xlsx", "file_type": "excel", "import_type": "local", "sheet_name": "Sheet1"},
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_header.return_value = True
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.xlsx", "file_type": "excel", "import_type": "local", "sheet_name": "Sheet1"}
+ request.fs = None
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"has_header": True}
+ mock_guess_file_header.assert_called_once_with(
+ file_path="/path/to/test.xlsx", file_type="excel", import_type="local", sheet_name="Sheet1", fs=None
+ )
+
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_header")
+ def test_guess_remote_csv_file_header_success(self, mock_guess_file_header, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"},
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_header.return_value = True
+ mock_fs = MagicMock()
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"}
+ request.fs = mock_fs
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"has_header": True}
+ mock_guess_file_header.assert_called_once_with(
+ file_path="s3a://bucket/user/test_user/test.csv", file_type="csv", import_type="remote", sheet_name=None, fs=mock_fs
+ )
+
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_header")
+ def test_guess_remote_csv_file_header_success_false_value(self, mock_guess_file_header, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True),
+ validated_data={"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"},
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_header.return_value = False
+ mock_fs = MagicMock()
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"}
+ request.fs = mock_fs
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"has_header": False}
+ mock_guess_file_header.assert_called_once_with(
+ file_path="s3a://bucket/user/test_user/test.csv", file_type="csv", import_type="remote", sheet_name=None, fs=mock_fs
+ )
+
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ def test_guess_file_header_invalid_data(self, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file_type": ["This field is required"]})
+ mock_serializer_class.return_value = mock_serializer
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"}
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"file_type": ["This field is required"]}
+
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_header")
+ def test_guess_file_header_value_error(self, mock_guess_file_header, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_header.side_effect = ValueError("File does not exist")
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"error": "File does not exist"}
+
+ @patch("desktop.lib.importer.api.GuessFileHeaderSerializer")
+ @patch("desktop.lib.importer.api.operations.guess_file_header")
+ def test_guess_file_header_operation_error(self, mock_guess_file_header, mock_serializer_class):
+ mock_serializer = MagicMock(
+ is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ )
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_guess_file_header.side_effect = RuntimeError("Operation error")
+
+ request = APIRequestFactory().get("importer/file/guess_header/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"}
+ request.fs = None
+
+ response = api.guess_file_header(request)
+
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+ assert response.data == {"error": "Operation error"}
+
+
+class TestSqlTypeMappingAPI:
+ @patch("desktop.lib.importer.api.SqlTypeMapperSerializer")
+ @patch("desktop.lib.importer.api.operations.get_sql_type_mapping")
+ def test_get_sql_type_mapping_success(self, mock_get_sql_type_mapping, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"sql_dialect": "hive"})
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_get_sql_type_mapping.return_value = {"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"}
+
+ request = APIRequestFactory().get("importer/sql_type_mapping/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"sql_dialect": "hive"}
+
+ response = api.get_sql_type_mapping(request)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"}
+ mock_get_sql_type_mapping.assert_called_once_with("hive")
+
+ @patch("desktop.lib.importer.api.SqlTypeMapperSerializer")
+ def test_get_sql_type_mapping_invalid_dialect(self, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"sql_dialect": ["Not a valid choice."]})
+ mock_serializer_class.return_value = mock_serializer
+
+ request = APIRequestFactory().get("importer/sql_type_mapping/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"sql_dialect": "invalid_dialect"}
+
+ response = api.get_sql_type_mapping(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"sql_dialect": ["Not a valid choice."]}
+
+ @patch("desktop.lib.importer.api.SqlTypeMapperSerializer")
+ @patch("desktop.lib.importer.api.operations.get_sql_type_mapping")
+ def test_get_sql_type_mapping_value_error(self, mock_get_sql_type_mapping, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"sql_dialect": "hive"})
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_get_sql_type_mapping.side_effect = ValueError("Unsupported dialect")
+
+ request = APIRequestFactory().get("importer/sql_type_mapping/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"sql_dialect": "hive"}
+
+ response = api.get_sql_type_mapping(request)
+
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
+ assert response.data == {"error": "Unsupported dialect"}
+
+ @patch("desktop.lib.importer.api.SqlTypeMapperSerializer")
+ @patch("desktop.lib.importer.api.operations.get_sql_type_mapping")
+ def test_get_sql_type_mapping_operation_error(self, mock_get_sql_type_mapping, mock_serializer_class):
+ mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"sql_dialect": "hive"})
+ mock_serializer_class.return_value = mock_serializer
+
+ mock_get_sql_type_mapping.side_effect = RuntimeError("Operation error")
+
+ request = APIRequestFactory().get("importer/sql_type_mapping/")
+ request.user = MagicMock(username="test_user")
+ request.query_params = {"sql_dialect": "hive"}
+
+ response = api.get_sql_type_mapping(request)
+
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+ assert response.data == {"error": "Operation error"}
diff --git a/desktop/core/src/desktop/lib/importer/operations.py b/desktop/core/src/desktop/lib/importer/operations.py
new file mode 100644
index 00000000000..0ecf2de6e3b
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/operations.py
@@ -0,0 +1,736 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import csv
+import logging
+import os
+import tempfile
+import uuid
+import xml.etree.ElementTree as ET
+import zipfile
+from io import BytesIO
+from typing import Any, BinaryIO, Dict, List, Optional, Union
+
+import polars as pl
+
+LOG = logging.getLogger()
+
+try:
+ # Import this at the module level to avoid checking availability in each _detect_file_type method call
+ import magic
+
+ is_magic_lib_available = True
+except ImportError as e:
+ LOG.exception(f"Failed to import python-magic: {str(e)}")
+ is_magic_lib_available = False
+
+
+# Base mapping for most SQL engines (Hive, Impala, SparkSQL)
+SQL_TYPE_BASE_MAP = {
+ # signed ints
+ "Int8": "TINYINT",
+ "Int16": "SMALLINT",
+ "Int32": "INT",
+ "Int64": "BIGINT",
+  # unsigned ints: mapped to same-size signed types by default (Hive/Impala/SparkSQL)
+ "UInt8": "TINYINT",
+ "UInt16": "SMALLINT",
+ "UInt32": "INT",
+ "UInt64": "BIGINT",
+ # floats & decimal
+ "Float32": "FLOAT",
+ "Float64": "DOUBLE",
+ "Decimal": "DECIMAL", # Hive/Impala/SparkSQL use DECIMAL(precision,scale)
+ # boolean, string, binary
+ "Boolean": "BOOLEAN",
+  "Utf8": "STRING",  # STRING is the default; it also covers Hive VARCHAR/CHAR
+ "String": "STRING",
+ "Categorical": "STRING",
+ "Enum": "STRING",
+ "Binary": "BINARY",
+ # temporal
+ "Date": "DATE",
+ "Time": "TIMESTAMP", # Hive/Impala/SparkSQL have no pure TIME type
+ "Datetime": "TIMESTAMP",
+ "Duration": "INTERVAL DAY TO SECOND",
+ # nested & other
+ "Array": "ARRAY",
+ "List": "ARRAY",
+ "Struct": "STRUCT",
+ "Object": "STRING",
+ "Null": "STRING", # no SQL NULL type—use STRING or handle as special case
+ "Unknown": "STRING",
+}
+
+# Per-dialect overrides for the few differences
+SQL_TYPE_DIALECT_OVERRIDES = {
+ "hive": {},
+ "impala": {},
+ "sparksql": {},
+ "trino": {
+ "Int32": "INTEGER",
+ "UInt32": "INTEGER",
+ "Utf8": "VARCHAR",
+ "String": "VARCHAR",
+ "Binary": "VARBINARY",
+ "Float32": "REAL",
+ "Struct": "ROW",
+ "Object": "JSON",
+ "Duration": "INTERVAL DAY TO SECOND", # explicit SQL syntax
+ },
+ "phoenix": {
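+    # Unsigned integer types expand to UNSIGNED_TINYINT, UNSIGNED_SMALLINT, UNSIGNED_INT and UNSIGNED_LONG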
+ **{f"UInt{b}": f"UNSIGNED_{t}" for b, t in [(8, "TINYINT"), (16, "SMALLINT"), (32, "INT"), (64, "LONG")]},
+ "Utf8": "VARCHAR",
+ "String": "VARCHAR",
+ "Binary": "VARBINARY",
+ "Duration": "STRING", # Phoenix treats durations as strings
+ "Struct": "STRING", # no native STRUCT type
+ "Object": "VARCHAR",
+ "Time": "TIME", # Phoenix has its own TIME type
+ "Decimal": "DECIMAL", # up to precision 38
+ },
+}
+
+
+def local_file_upload(upload_file, username: str) -> Dict[str, str]:
+ """Uploads a local file to a temporary directory with a unique filename.
+
+ This function takes an uploaded file and username, generates a unique filename,
+ and saves the file to a temporary directory. The filename is created using
+ the username, a unique ID, and a sanitized version of the original filename.
+
+ Args:
+ upload_file: The uploaded file object from Django's file upload handling.
+ username: The username of the user uploading the file.
+
+ Returns:
+ Dict[str, str]: A dictionary containing:
+ - file_path: The full path where the file was saved
+
+ Raises:
+ ValueError: If upload_file or username is None/empty
+ Exception: If there are issues with file operations
+
+ Example:
+        >>> result = local_file_upload(request.FILES["file"], "hue_user")
+ >>> print(result)
+ {'file_path': '/tmp/hue_user_a1b2c3d4_myfile.txt'}
+ """
+ if not upload_file:
+ raise ValueError("Upload file cannot be None or empty.")
+
+ if not username:
+ raise ValueError("Username cannot be None or empty.")
+
+ # Generate a unique filename
+ unique_id = uuid.uuid4().hex[:8]
+ filename = f"{username}_{unique_id}_{upload_file.name}"
+
+ # Create a temporary file with our generated filename
+ temp_dir = tempfile.gettempdir()
+ destination_path = os.path.join(temp_dir, filename)
+
+ try:
+ # Simply write the file content to temporary location
+ with open(destination_path, "wb") as destination:
+ for chunk in upload_file.chunks():
+ destination.write(chunk)
+
+ return {"file_path": destination_path}
+
+ except Exception as e:
+ if os.path.exists(destination_path):
+ LOG.debug(f"Error during local file upload, cleaning up temporary file: {destination_path}")
+ os.remove(destination_path)
+ raise e
+
+
+def guess_file_metadata(file_path: str, import_type: str, fs=None) -> Dict[str, Any]:
+ """Guess the metadata of a file based on its content or extension.
+
+ Args:
+ file_path: Path to the file to analyze
+ import_type: Type of import ('local' or 'remote')
+ fs: File system object for remote files (default: None)
+
+ Returns:
+ Dict containing the file metadata:
+ - type: File type (e.g., excel, csv, tsv)
+ - sheet_names: List of sheet names (for Excel files)
+ - field_separator: Field separator character (for delimited files)
+ - quote_char: Quote character (for delimited files)
+ - record_separator: Record separator (for delimited files)
+
+ Raises:
+ ValueError: If the file does not exist or parameters are invalid
+ Exception: For various file processing errors
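+
+    Example (illustrative, assuming a comma-separated file at the given path):
+        >>> guess_file_metadata("/tmp/sample.csv", "local")
+        {'type': 'csv', 'field_separator': ',', 'quote_char': '"', 'record_separator': '\\r\\n'}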
+ """
+ if not file_path:
+ raise ValueError("File path cannot be empty")
+
+ if import_type not in ["local", "remote"]:
+ raise ValueError(f"Unsupported import type: {import_type}")
+
+ if import_type == "remote" and fs is None:
+ raise ValueError("File system object is required for remote import type")
+
+ # Check if file exists based on import type
+ if import_type == "local" and not os.path.exists(file_path):
+ raise ValueError(f"Local file does not exist: {file_path}")
+ elif import_type == "remote" and fs and not fs.exists(file_path):
+ raise ValueError(f"Remote file does not exist: {file_path}")
+
+ error_occurred = False
+ fh = open(file_path, "rb") if import_type == "local" else fs.open(file_path, "rb")
+
+ try:
+ sample = fh.read(16 * 1024) # Read 16 KiB sample
+
+ if not sample:
+ raise ValueError("File is empty, cannot detect file format.")
+
+ file_type = _detect_file_type(sample)
+
+ if file_type == "unknown":
+ raise ValueError("Unable to detect file format.")
+
+ if file_type == "excel":
+ metadata = _get_excel_metadata(fh)
+ else:
+ # CSV, TSV, or other delimited formats
+ metadata = _get_delimited_metadata(sample, file_type)
+
+ return metadata
+
+ except Exception as e:
+ error_occurred = True
+ LOG.exception(f"Error guessing file metadata: {e}", exc_info=True)
+ raise e
+
+ finally:
+ fh.close()
+ if import_type == "local" and error_occurred and os.path.exists(file_path):
+ LOG.debug(f"Due to error in guess_file_metadata, cleaning up uploaded local file: {file_path}")
+ os.remove(file_path)
+
+
+def preview_file(
+ file_path: str,
+ file_type: str,
+ import_type: str,
+ sql_dialect: str,
+ has_header: bool = False,
+ sheet_name: Optional[str] = None,
+ field_separator: Optional[str] = ",",
+ quote_char: Optional[str] = '"',
+ record_separator: Optional[str] = "\n",
+ fs=None,
+ preview_rows: int = 50,
+) -> Dict[str, Any]:
+ """Generate a preview of a file's content with column type mapping.
+
+ This method reads a file and returns a preview of its contents, along with
+ column information and metadata for creating tables or further processing.
+
+ Args:
+ file_path: Path to the file to preview
+ file_type: Type of file ('excel', 'csv', 'tsv', 'delimiter_format')
+ import_type: Type of import ('local' or 'remote')
+ sql_dialect: SQL dialect for type mapping ('hive', 'impala', etc.)
+ has_header: Whether the file has a header row or not
+ sheet_name: Sheet name for Excel files (required for Excel)
+ field_separator: Field separator character for delimited files
+ quote_char: Quote character for delimited files
+ record_separator: Record separator for delimited files
+ fs: File system object for remote files (default: None)
+ preview_rows: Number of rows to include in preview (default: 50)
+
+ Returns:
+ Dict containing:
+ - type: File type
+ - columns: List of column metadata (name, type)
+ - preview_data: Preview of the file data
+
+ Raises:
+ ValueError: If the file does not exist or parameters are invalid
+ Exception: For various file processing errors
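+
+    Example (illustrative, assuming a small CSV with a float and a boolean column):
+        >>> preview_file("/tmp/sample.csv", "csv", "local", "hive", has_header=True)
+        {'type': 'csv', 'columns': [{'name': 'col1', 'type': 'DOUBLE'}, {'name': 'col2', 'type': 'BOOLEAN'}], 'preview_data': {...}}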
+ """
+ if not file_path:
+ raise ValueError("File path cannot be empty")
+
+ if sql_dialect.lower() not in ["hive", "impala", "trino", "phoenix", "sparksql"]:
+ raise ValueError(f"Unsupported SQL dialect: {sql_dialect}")
+
+ if file_type not in ["excel", "csv", "tsv", "delimiter_format"]:
+ raise ValueError(f"Unsupported file type: {file_type}")
+
+ if import_type not in ["local", "remote"]:
+ raise ValueError(f"Unsupported import type: {import_type}")
+
+ if import_type == "remote" and fs is None:
+ raise ValueError("File system object is required for remote import type")
+
+ # Check if file exists based on import type
+ if import_type == "local" and not os.path.exists(file_path):
+ raise ValueError(f"Local file does not exist: {file_path}")
+ elif import_type == "remote" and fs and not fs.exists(file_path):
+ raise ValueError(f"Remote file does not exist: {file_path}")
+
+ error_occurred = False
+ fh = open(file_path, "rb") if import_type == "local" else fs.open(file_path, "rb")
+
+ try:
+ if file_type == "excel":
+ if not sheet_name:
+ raise ValueError("Sheet name is required for Excel files.")
+
+ preview = _preview_excel_file(fh, file_type, sheet_name, sql_dialect, has_header, preview_rows)
+ elif file_type in ["csv", "tsv", "delimiter_format"]:
+ # Process escapable characters
+ try:
+ if field_separator:
+ field_separator = codecs.decode(field_separator, "unicode_escape")
+ if quote_char:
+ quote_char = codecs.decode(quote_char, "unicode_escape")
+ if record_separator:
+ record_separator = codecs.decode(record_separator, "unicode_escape")
+
+ except Exception as e:
+ LOG.exception(f"Error decoding escape characters: {e}", exc_info=True)
+ raise ValueError("Invalid escape characters in field_separator, quote_char, or record_separator.")
+
+ preview = _preview_delimited_file(fh, file_type, field_separator, quote_char, record_separator, sql_dialect, has_header, preview_rows)
+ else:
+ raise ValueError(f"Unsupported file type: {file_type}")
+
+ return preview
+ except Exception as e:
+ error_occurred = True
+ LOG.exception(f"Error previewing file: {e}", exc_info=True)
+ raise e
+
+ finally:
+ fh.close()
+ if import_type == "local" and error_occurred and os.path.exists(file_path):
+ LOG.debug(f"Due to error in preview_file, cleaning up uploaded local file: {file_path}")
+ os.remove(file_path)
+
+
+def _detect_file_type(file_sample: bytes) -> str:
+ """Detect the file type based on its content.
+
+ Args:
+ file_sample: Binary sample of the file content
+
+ Returns:
+ String indicating the detected file type ('excel', 'delimiter_format', or 'unknown')
+
+ Raises:
+ RuntimeError: If python-magic or libmagic is not available
+ Exception: If an error occurs during file type detection
+ """
+ # Check if magic library is available
+ if not is_magic_lib_available:
+ error = "Unable to guess file type. python-magic or its dependency libmagic is not installed."
+ LOG.error(error)
+ raise RuntimeError(error)
+
+ try:
+ # Use libmagic to detect MIME type from content
+ file_type = magic.from_buffer(file_sample, mime=True)
+
+ # Map MIME type to the simplified type categories
+ if any(keyword in file_type for keyword in ["excel", "spreadsheet", "officedocument.sheet"]):
+ return "excel"
+ elif any(keyword in file_type for keyword in ["text", "csv", "plain"]):
+ # For text files, analyze the content later to determine specific format
+ return "delimiter_format"
+
+        LOG.info(f"Detected MIME type: {file_type}, but it is not recognized as a supported format")
+ return "unknown"
+ except Exception as e:
+ message = f"Error detecting file type: {e}"
+ LOG.exception(message, exc_info=True)
+
+ raise Exception(message)
+
+
+def _get_excel_metadata(fh: BinaryIO) -> Dict[str, Any]:
+ """Extract metadata for Excel files (.xlsx, .xls).
+
+ Args:
+ fh: File handle for the Excel file
+
+ Returns:
+ Dict containing Excel metadata:
+ - type: 'excel'
+ - sheet_names: List of sheet names
+
+ Raises:
+ Exception: If there's an error processing the Excel file
+ """
+ try:
+ fh.seek(0)
+ try:
+ sheet_names = _get_sheet_names_xlsx(BytesIO(fh.read()))
+ except Exception:
+            LOG.warning("Failed to read sheet names via the Zip + XML parsing approach, falling back to Polars.")
+
+ # Possibly some other format instead of .xlsx
+ fh.seek(0)
+
+ sheet_names = pl.read_excel(
+ BytesIO(fh.read()), sheet_id=0, infer_schema_length=10000, read_options={"n_rows": 0}
+ ).keys() # No need to read rows for detecting sheet names
+
+ return {
+ "type": "excel",
+ "sheet_names": list(sheet_names),
+ }
+ except Exception as e:
+ message = f"Error extracting Excel file metadata: {e}"
+ LOG.error(message, exc_info=True)
+
+ raise Exception(message)
+
+
+def _get_sheet_names_xlsx(fh: BinaryIO) -> List[str]:
+ """Quickly list sheet names from an .xlsx file handle.
+
+ - Uses only the stdlib (zipfile + xml.etree).
+ - Parses only the small `xl/workbook.xml` metadata (~10-20KB).
+ - No full worksheet data is loaded into memory.
+
+ Args:
+ fh: Binary file-like object of the workbook.
+
+ Returns:
+ A list of worksheet names.
+ """
+ with zipfile.ZipFile(fh, "r") as z, z.open("xl/workbook.xml") as f:
+ tree = ET.parse(f)
+
+ # XML namespace for SpreadsheetML
+ ns = {"x": "http://schemas.openxmlformats.org/spreadsheetml/2006/main"}
+ sheets = tree.getroot().find("x:sheets", ns)
+
+ return [s.get("name") for s in sheets]
+
+
+def _get_delimited_metadata(file_sample: Union[bytes, str], file_type: str) -> Dict[str, Any]:
+ """Extract metadata for delimited files (CSV, TSV, etc.).
+
+ Args:
+ file_sample: Binary or string sample of the file content
+ file_type: Initial file type detection ('delimiter_format')
+
+ Returns:
+ Dict containing delimited file metadata:
+ - type: Specific file type ('csv', 'tsv', etc.)
+ - field_separator: Field separator character
+ - quote_char: Quote character
+ - record_separator: Record separator character
+
+ Raises:
+ Exception: If there's an error processing the delimited file
+ """
+ if isinstance(file_sample, bytes):
+ file_sample = file_sample.decode("utf-8", errors="replace")
+
+ # Use CSV Sniffer to detect delimiter and other formatting
+ try:
+ dialect = csv.Sniffer().sniff(file_sample)
+ except Exception as sniff_error:
+ message = f"Failed to sniff delimited file: {sniff_error}"
+ LOG.exception(message)
+
+ raise Exception(message)
+
+ # Refine file type based on detected delimiter
+ if file_type == "delimiter_format":
+ if dialect.delimiter == ",":
+ file_type = "csv"
+ elif dialect.delimiter == "\t":
+ file_type = "tsv"
+ # Other delimiters remain as 'delimiter_format'
+
+ return {
+ "type": file_type,
+ "field_separator": dialect.delimiter,
+ "quote_char": dialect.quotechar,
+ "record_separator": dialect.lineterminator,
+ }
+
+
+def _preview_excel_file(
+ fh: BinaryIO, file_type: str, sheet_name: str, dialect: str, has_header: bool, preview_rows: int = 50
+) -> Dict[str, Any]:
+ """Preview an Excel file (.xlsx, .xls)
+
+ Args:
+ fh: File handle for the Excel file
+ file_type: Type of file ('excel')
+ sheet_name: Name of the sheet to preview
+ dialect: SQL dialect for type mapping
+ has_header: Whether the file has a header row or not
+ preview_rows: Number of rows to include in preview (default: 50)
+
+ Returns:
+ Dict containing:
+ - type: 'excel'
+ - columns: List of column metadata
+ - preview_data: Preview of file data
+
+ Raises:
+ Exception: If there's an error processing the Excel file
+ """
+ try:
+ fh.seek(0)
+
+ df = pl.read_excel(
+ BytesIO(fh.read()), sheet_name=sheet_name, has_header=has_header, read_options={"n_rows": preview_rows}, infer_schema_length=10000
+ )
+
+ # Return empty result if the df is empty
+ if df.height == 0:
+ return {"type": file_type, "columns": [], "preview_data": {}}
+
+ schema = df.schema
+ preview_data = df.to_dict(as_series=False)
+
+ # Create column metadata with SQL type mapping
+ columns = []
+ for col in df.columns:
+ col_type = str(schema[col])
+ sql_type = _map_polars_dtype_to_sql_type(dialect, col_type)
+
+ columns.append({"name": col, "type": sql_type})
+
+ result = {"type": file_type, "columns": columns, "preview_data": preview_data}
+
+ return result
+
+ except Exception as e:
+ message = f"Error previewing Excel file: {e}"
+ LOG.error(message, exc_info=True)
+
+ raise Exception(message)
+
+
+def _preview_delimited_file(
+ fh: BinaryIO,
+ file_type: str,
+ field_separator: str,
+ quote_char: str,
+ record_separator: str,
+ dialect: str,
+ has_header: bool,
+ preview_rows: int = 50,
+) -> Dict[str, Any]:
+ """Preview a delimited file (CSV, TSV, etc.)
+
+ Args:
+ fh: File handle for the delimited file
+ file_type: Type of file ('csv', 'tsv', 'delimiter_format')
+ field_separator: Field separator character
+ quote_char: Quote character
+ record_separator: Record separator character
+ dialect: SQL dialect for type mapping
+ has_header: Whether the file has a header row or not
+ preview_rows: Number of rows to include in preview (default: 50)
+
+ Returns:
+ Dict containing:
+ - type: File type
+ - columns: List of column metadata
+ - preview_data: Preview of file data
+
+ Raises:
+ Exception: If there's an error processing the delimited file
+ """
+ try:
+ fh.seek(0)
+
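+        # Polars' eol_char must be a single character, so a "\r\n" record separator is normalized to "\n" below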
+ df = pl.read_csv(
+ BytesIO(fh.read()),
+ separator=field_separator,
+ quote_char=quote_char,
+ eol_char="\n" if record_separator == "\r\n" else record_separator,
+ has_header=has_header,
+ infer_schema_length=10000,
+ n_rows=preview_rows,
+ ignore_errors=True,
+ )
+
+ # Return empty result if the df is empty
+ if df.height == 0:
+ return {"type": file_type, "columns": [], "preview_data": {}}
+
+ schema = df.schema
+ preview_data = df.to_dict(as_series=False)
+
+ # Create detailed column metadata with SQL type mapping
+ columns = []
+ for col in df.columns:
+ col_type = str(schema[col])
+ sql_type = _map_polars_dtype_to_sql_type(dialect, col_type)
+
+ columns.append({"name": col, "type": sql_type})
+
+ result = {"type": file_type, "columns": columns, "preview_data": preview_data}
+
+ return result
+
+ except Exception as e:
+ message = f"Error previewing delimited file: {e}"
+
+ LOG.error(message, exc_info=True)
+ raise Exception(message)
+
+
+def guess_file_header(file_path: str, file_type: str, import_type: str, sheet_name: Optional[str] = None, fs=None) -> bool:
+ """Guess whether a file has a header row.
+
+ This function analyzes a file to determine if it contains a header row based on the
+ content pattern. It works for both Excel files and delimited text files (CSV, TSV, etc.).
+
+ Args:
+ file_path: Path to the file to analyze
+ file_type: Type of file ('excel', 'csv', 'tsv', 'delimiter_format')
+ import_type: Type of import ('local' or 'remote')
+ sheet_name: Sheet name for Excel files (required for Excel)
+ fs: File system object for remote files (default: None)
+
+ Returns:
+ has_header: Boolean indicating whether the file has a header row
+
+ Raises:
+ ValueError: If the file does not exist or parameters are invalid
+ Exception: For various file processing errors
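+
+    Example (illustrative, assuming a CSV file whose first row is a header):
+        >>> guess_file_header("/tmp/sample.csv", "csv", "local")
+        True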
+ """
+ if not file_path:
+ raise ValueError("File path cannot be empty")
+
+ if file_type not in ["excel", "csv", "tsv", "delimiter_format"]:
+ raise ValueError(f"Unsupported file type: {file_type}")
+
+ if import_type not in ["local", "remote"]:
+ raise ValueError(f"Unsupported import type: {import_type}")
+
+ if import_type == "remote" and fs is None:
+ raise ValueError("File system object is required for remote import type")
+
+ # Check if file exists based on import type
+ if import_type == "local" and not os.path.exists(file_path):
+ raise ValueError(f"Local file does not exist: {file_path}")
+ elif import_type == "remote" and fs and not fs.exists(file_path):
+ raise ValueError(f"Remote file does not exist: {file_path}")
+
+ fh = open(file_path, "rb") if import_type == "local" else fs.open(file_path, "rb")
+
+ has_header = False
+
+ try:
+ if file_type == "excel":
+ if not sheet_name:
+ raise ValueError("Sheet name is required for Excel files.")
+
+ # Convert excel sample to CSV for header detection
+ try:
+ fh.seek(0)
+
+ csv_snippet = pl.read_excel(
+ source=BytesIO(fh.read()), sheet_name=sheet_name, infer_schema_length=10000, read_options={"n_rows": 20}
+ ).write_csv(file=None)
+
+ if isinstance(csv_snippet, bytes):
+ csv_snippet = csv_snippet.decode("utf-8", errors="replace")
+
+ has_header = csv.Sniffer().has_header(csv_snippet)
+ LOG.info(f"Detected header for Excel file: {has_header}")
+
+ except Exception as e:
+ message = f"Error detecting header in Excel file: {e}"
+ LOG.exception(message, exc_info=True)
+
+ raise Exception(message)
+
+ elif file_type in ["csv", "tsv", "delimiter_format"]:
+ try:
+ # Reset file position
+ fh.seek(0)
+
+ # Read 16 KiB sample
+ sample = fh.read(16 * 1024).decode("utf-8", errors="replace")
+
+ has_header = csv.Sniffer().has_header(sample)
+ LOG.info(f"Detected header for delimited file: {has_header}")
+
+ except Exception as e:
+ message = f"Error detecting header in delimited file: {e}"
+ LOG.exception(message, exc_info=True)
+
+ raise Exception(message)
+
+ return has_header
+
+ finally:
+ fh.close()
+
+
+def get_sql_type_mapping(dialect: str) -> Dict[str, str]:
+ """Get all type mappings from Polars dtypes to SQL types for a given SQL dialect.
+
+ This function returns a dictionary mapping of all Polars data types to their
+ corresponding SQL types for a specific dialect.
+
+ Args:
+ dialect: One of "hive", "impala", "trino", "phoenix", "sparksql".
+
+ Returns:
+ A dict mapping Polars dtype names to SQL type names.
+
+ Raises:
+ ValueError: If the dialect is not supported.
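+
+    Example:
+        >>> get_sql_type_mapping("trino")["Utf8"]
+        'VARCHAR'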
+ """
+ dl = dialect.lower()
+ if dl not in SQL_TYPE_DIALECT_OVERRIDES:
+ raise ValueError(f"Unsupported dialect: {dialect}")
+
+ # Merge base_map and overrides[dl] into a new dict, giving precedence to any overlapping keys in overrides[dl]
+ return {**SQL_TYPE_BASE_MAP, **SQL_TYPE_DIALECT_OVERRIDES[dl]}
+
+
+def _map_polars_dtype_to_sql_type(dialect: str, polars_type: str) -> str:
+ """Map a Polars dtype to the corresponding SQL type for a given dialect.
+
+ Args:
+ dialect: One of "hive", "impala", "trino", "phoenix", "sparksql".
+ polars_type: Polars dtype name as string.
+
+ Returns:
+ A string representing the SQL type.
+
+ Raises:
+ ValueError: If the dialect or polars_type is not supported.
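+
+    Example:
+        >>> _map_polars_dtype_to_sql_type("hive", "Int64")
+        'BIGINT'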
+ """
+ mapping = get_sql_type_mapping(dialect)
+
+ if polars_type not in mapping:
+ raise ValueError(f"No mapping for Polars dtype {polars_type} in dialect {dialect}")
+
+ return mapping[polars_type]
diff --git a/desktop/core/src/desktop/lib/importer/operations_tests.py b/desktop/core/src/desktop/lib/importer/operations_tests.py
new file mode 100644
index 00000000000..49ade93168d
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/operations_tests.py
@@ -0,0 +1,622 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import zipfile
+from unittest.mock import MagicMock, mock_open, patch
+
+import pytest
+from django.core.files.uploadedfile import SimpleUploadedFile
+
+from desktop.lib.importer import operations
+
+
+class TestLocalFileUpload:
+ @patch("uuid.uuid4")
+ def test_local_file_upload_success(self, mock_uuid):
+ # Mock uuid to get a predictable filename
+ mock_uuid.return_value.hex = "12345678"
+
+ test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")
+
+ result = operations.local_file_upload(test_file, "test_user")
+
+ # Get the expected file path
+ temp_dir = tempfile.gettempdir()
+ expected_path = os.path.join(temp_dir, "test_user_12345678_test_file.csv")
+
+ try:
+ assert "file_path" in result
+ assert result["file_path"] == expected_path
+
+ # Verify the file was created and has the right content
+ assert os.path.exists(expected_path)
+ with open(expected_path, "rb") as f:
+ assert f.read() == b"header1,header2\nvalue1,value2"
+
+ finally:
+ # Clean up the file
+ if os.path.exists(expected_path):
+ os.remove(expected_path)
+
+ assert not os.path.exists(expected_path), "Temporary file was not cleaned up properly"
+
+ def test_local_file_upload_none_file(self):
+ with pytest.raises(ValueError, match="Upload file cannot be None or empty."):
+ operations.local_file_upload(None, "test_user")
+
+ def test_local_file_upload_none_username(self):
+ test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")
+
+ with pytest.raises(ValueError, match="Username cannot be None or empty."):
+ operations.local_file_upload(test_file, None)
+
+ @patch("os.path.join")
+ @patch("builtins.open", new_callable=mock_open)
+ def test_local_file_upload_exception_handling(self, mock_file_open, mock_join):
+ # Setup mocks to raise an exception when opening the file
+ mock_file_open.side_effect = IOError("Test IO Error")
+ mock_join.return_value = "/tmp/test_user_12345678_test_file.csv"
+
+ test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")
+
+ with pytest.raises(Exception, match="Test IO Error"):
+ operations.local_file_upload(test_file, "test_user")
+
+
+@pytest.mark.usefixtures("cleanup_temp_files")
+class TestGuessFileMetadata:
+ @pytest.fixture
+ def cleanup_temp_files(self):
+ """Fixture to clean up temporary files after tests."""
+ temp_files = []
+
+ yield temp_files
+
+ # Clean up after test
+ for file_path in temp_files:
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ @patch("desktop.lib.importer.operations.is_magic_lib_available", True)
+ @patch("desktop.lib.importer.operations.magic")
+ def test_guess_file_metadata_csv(self, mock_magic, cleanup_temp_files):
+ # Create a temporary CSV file
+ test_content = "col1,col2,col3\nval1,val2,val3\nval4,val5,val6"
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ # Mock magic.from_buffer to return text/csv MIME type
+ mock_magic.from_buffer.return_value = "text/plain"
+
+ result = operations.guess_file_metadata(temp_file.name, "local")
+
+ assert result == {
+ "type": "csv",
+ "field_separator": ",",
+ "quote_char": '"',
+ "record_separator": "\r\n",
+ }
+
+ @patch("desktop.lib.importer.operations.is_magic_lib_available", True)
+ @patch("desktop.lib.importer.operations.magic")
+ def test_guess_file_metadata_tsv(self, mock_magic, cleanup_temp_files):
+ # Create a temporary TSV file
+ test_content = "col1\tcol2\tcol3\nval1\tval2\tval3\nval4\tval5\tval6"
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ # Mock magic.from_buffer to return text/plain MIME type
+ mock_magic.from_buffer.return_value = "text/plain"
+
+ result = operations.guess_file_metadata(temp_file.name, "local")
+
+ assert result == {
+ "type": "tsv",
+ "field_separator": "\t",
+ "quote_char": '"',
+ "record_separator": "\r\n",
+ }
+
+ @patch("desktop.lib.importer.operations.is_magic_lib_available", True)
+ @patch("desktop.lib.importer.operations.magic")
+ @patch("desktop.lib.importer.operations._get_sheet_names_xlsx")
+ def test_guess_file_metadata_excel(self, mock_get_sheet_names, mock_magic, cleanup_temp_files):
+        # Create a placeholder .xlsx file (not a real Excel binary; MIME detection and sheet-name extraction are mocked below)
+        test_content = "placeholder xlsx content"
+
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx")
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ # Mock magic.from_buffer to return Excel MIME type
+ mock_magic.from_buffer.return_value = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+
+ # Mock _get_sheet_names_xlsx to return sheet names
+ mock_get_sheet_names.return_value = ["Sheet1", "Sheet2", "Sheet3"]
+
+ result = operations.guess_file_metadata(temp_file.name, "local")
+
+ assert result == {
+ "type": "excel",
+ "sheet_names": ["Sheet1", "Sheet2", "Sheet3"],
+ }
+
+ @patch("desktop.lib.importer.operations.is_magic_lib_available", True)
+ @patch("desktop.lib.importer.operations.magic")
+ def test_guess_file_metadata_unsupported_type(self, mock_magic, cleanup_temp_files):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(b"Binary content")
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ # Mock magic.from_buffer to return an unsupported MIME type
+ mock_magic.from_buffer.return_value = "application/octet-stream"
+
+ with pytest.raises(ValueError, match="Unable to detect file format."):
+ operations.guess_file_metadata(temp_file.name, "local")
+
+ def test_guess_file_metadata_nonexistent_file(self):
+ file_path = "/path/to/nonexistent/file.csv"
+
+ with pytest.raises(ValueError, match="Local file does not exist."):
+ operations.guess_file_metadata(file_path, "local")
+
+ def test_guess_remote_file_metadata_no_fs(self):
+ with pytest.raises(ValueError, match="File system object is required for remote import type"):
+ operations.guess_file_metadata(
+ file_path="s3a://bucket/user/test_user/test.csv", # Remote file path
+ import_type="remote", # Remote file but no fs provided
+ )
+
+ def test_guess_file_metadata_empty_file(self, cleanup_temp_files):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ with pytest.raises(ValueError, match="File is empty, cannot detect file format."):
+ operations.guess_file_metadata(temp_file.name, "local")
+
+ @patch("desktop.lib.importer.operations.is_magic_lib_available", False)
+ def test_guess_file_metadata_no_magic_lib(self, cleanup_temp_files):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(b"Content")
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ with pytest.raises(RuntimeError, match="Unable to guess file type. python-magic or its dependency libmagic is not installed."):
+ operations.guess_file_metadata(temp_file.name, "local")
+
+
+@pytest.mark.usefixtures("cleanup_temp_files")
+class TestPreviewFile:
+ @pytest.fixture
+ def cleanup_temp_files(self):
+ """Fixture to clean up temporary files after tests."""
+ temp_files = []
+
+ yield temp_files
+
+ # Clean up after test
+ for file_path in temp_files:
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ @patch("desktop.lib.importer.operations.pl")
+ def test_preview_excel_file(self, mock_pl, cleanup_temp_files):
+ # Minimal Excel file content (not a real Excel binary, just placeholder for test)
+        test_content = "placeholder excel content"
+
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx")
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ mock_df = MagicMock(
+ height=2,
+ columns=["col1", "col2"],
+ schema={"col1": "Int32", "col2": "Utf8"},
+ to_dict=MagicMock(return_value={"col1": [1, 2], "col2": ["foo", "bar"]}),
+ )
+
+ mock_pl.read_excel.return_value = mock_df
+
+ result = operations.preview_file(
+ file_path=temp_file.name, file_type="excel", import_type="local", sql_dialect="hive", has_header=True, sheet_name="Sheet1"
+ )
+
+ assert result == {
+ "type": "excel",
+ "columns": [
+ {"name": "col1", "type": "INT"},
+ {"name": "col2", "type": "STRING"},
+ ],
+ "preview_data": {"col1": [1, 2], "col2": ["foo", "bar"]},
+ }
+
+ def test_preview_csv_file(self, cleanup_temp_files):
+ # Create a temporary CSV file
+ test_content = "col1,col2\n1.1,true\n2.2,false"
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ result = operations.preview_file(
+ file_path=temp_file.name,
+ file_type="csv",
+ import_type="local",
+ sql_dialect="hive",
+ has_header=True,
+ field_separator=",",
+ quote_char='"',
+ record_separator="\n",
+ )
+
+ assert result == {
+ "type": "csv",
+ "columns": [
+ {"name": "col1", "type": "DOUBLE"},
+ {"name": "col2", "type": "BOOLEAN"},
+ ],
+ "preview_data": {"col1": [1.1, 2.2], "col2": [True, False]},
+ }
+
+ def test_preview_csv_file_with_header_false(self, cleanup_temp_files):
+ # Create a temporary CSV file
+ test_content = "sample1,sample2\n1.1,true\n2.2,false"
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ result = operations.preview_file(
+ file_path=temp_file.name,
+ file_type="csv",
+ import_type="local",
+ sql_dialect="hive",
+ has_header=False, # No header in this case
+ field_separator=",",
+ quote_char='"',
+ record_separator="\n",
+ )
+
+ assert result == {
+ "type": "csv",
+ "columns": [{"name": "column_1", "type": "STRING"}, {"name": "column_2", "type": "STRING"}],
+ "preview_data": {"column_1": ["sample1", "1.1", "2.2"], "column_2": ["sample2", "true", "false"]},
+ }
+
+ def test_preview_empty_csv_file(self, cleanup_temp_files):
+ # Create a temporary CSV file
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(b" ")
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ result = operations.preview_file(
+ file_path=temp_file.name,
+ file_type="csv",
+ import_type="local",
+ sql_dialect="hive",
+ has_header=True,
+ field_separator=",",
+ quote_char='"',
+ record_separator="\n",
+ )
+
+ assert result == {
+ "type": "csv",
+ "columns": [],
+ "preview_data": {},
+ }
+
+ def test_preview_invalid_file_path(self):
+ with pytest.raises(ValueError, match="File path cannot be empty"):
+ operations.preview_file(file_path="", file_type="csv", import_type="local", sql_dialect="hive", has_header=True)
+
+ def test_preview_unsupported_file_type(self):
+ with pytest.raises(ValueError, match="Unsupported file type: json"):
+ operations.preview_file(
+ file_path="/path/to/test.json",
+ file_type="json", # Unsupported type
+ import_type="local",
+ sql_dialect="hive",
+ has_header=True,
+ )
+
+ def test_preview_unsupported_sql_dialect(self):
+ with pytest.raises(ValueError, match="Unsupported SQL dialect: mysql"):
+ operations.preview_file(
+ file_path="/path/to/test.csv",
+ file_type="csv",
+ import_type="local",
+ sql_dialect="mysql", # Unsupported dialect
+ has_header=True,
+ )
+
+ def test_preview_remote_file_no_fs(self):
+ with pytest.raises(ValueError, match="File system object is required for remote import type"):
+ operations.preview_file(
+ file_path="s3a://bucket/user/test_user/test.csv", # Remote file path
+ file_type="csv",
+ import_type="remote", # Remote file but no fs provided
+ sql_dialect="hive",
+ has_header=True,
+ )
+
+ @patch("os.path.exists")
+ def test_preview_nonexistent_local_file(self, mock_exists):
+ mock_exists.return_value = False
+
+ with pytest.raises(ValueError, match="Local file does not exist: /path/to/nonexistent.csv"):
+ operations.preview_file(
+ file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local", sql_dialect="hive", has_header=True
+ )
+
+ def test_preview_trino_dialect_type_mapping(self, cleanup_temp_files):
+ test_content = "string_col\nfoo\nbar"
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ result = operations.preview_file(
+ file_path=temp_file.name,
+ file_type="csv",
+ import_type="local",
+ sql_dialect="trino", # Trino dialect for different type mapping
+ has_header=True,
+ field_separator=",",
+ )
+
+ # Check the result for Trino-specific type mapping
+ assert result["columns"][0]["type"] == "VARCHAR" # Not STRING
+
+
+@pytest.mark.usefixtures("cleanup_temp_files")
+class TestGuessFileHeader:
+ @pytest.fixture
+ def cleanup_temp_files(self):
+ """Fixture to clean up temporary files after tests."""
+ temp_files = []
+
+ yield temp_files
+
+ # Clean up after test
+ for file_path in temp_files:
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ def test_guess_header_csv(self, cleanup_temp_files):
+ test_content = "header1,header2\nvalue1,value2\nvalue3,value4"
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ result = operations.guess_file_header(file_path=temp_file.name, file_type="csv", import_type="local")
+
+ assert result
+
+ @patch("desktop.lib.importer.operations.pl")
+ @patch("desktop.lib.importer.operations.csv.Sniffer")
+ def test_guess_header_excel(self, mock_sniffer, mock_pl, cleanup_temp_files):
+        # Placeholder content; Excel reading is mocked in this test
+        test_content = "placeholder excel content"
+
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx")
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ # Mock polars read_excel and CSV conversion
+ mock_pl.read_excel.return_value.write_csv.return_value = "header1,header2\nvalue1,value2\nvalue3,value4"
+
+ # Mock csv.Sniffer
+ mock_sniffer_instance = MagicMock()
+ mock_sniffer_instance.has_header.return_value = True
+ mock_sniffer.return_value = mock_sniffer_instance
+
+ result = operations.guess_file_header(file_path=temp_file.name, file_type="excel", import_type="local", sheet_name="Sheet1")
+
+ assert result
+
+ def test_guess_header_excel_no_sheet_name(self, cleanup_temp_files):
+        # Placeholder content; a ValueError is raised before any Excel parsing happens
+        test_content = "placeholder excel content"
+
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx")
+ temp_file.write(test_content.encode("utf-8"))
+ temp_file.close()
+
+ cleanup_temp_files.append(temp_file.name)
+
+ with pytest.raises(ValueError, match="Sheet name is required for Excel files"):
+ operations.guess_file_header(
+ file_path=temp_file.name,
+ file_type="excel",
+ import_type="local",
+ # Missing sheet_name
+ )
+
+ def test_guess_header_invalid_path(self):
+ with pytest.raises(ValueError, match="File path cannot be empty"):
+ operations.guess_file_header(file_path="", file_type="csv", import_type="local")
+
+ def test_guess_header_unsupported_file_type(self):
+ with pytest.raises(ValueError, match="Unsupported file type: json"):
+ operations.guess_file_header(
+ file_path="/path/to/test.json",
+ file_type="json", # Unsupported type
+ import_type="local",
+ )
+
+ def test_guess_header_nonexistent_local_file(self):
+ with pytest.raises(ValueError, match="Local file does not exist"):
+ operations.guess_file_header(file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local")
+
+ def test_guess_header_remote_file_no_fs(self):
+ with pytest.raises(ValueError, match="File system object is required for remote import type"):
+ operations.guess_file_header(
+ file_path="hdfs:///path/to/test.csv",
+ file_type="csv",
+ import_type="remote", # Remote but no fs provided
+ )
+
+
+class TestSqlTypeMapping:
+ def test_get_sql_type_mapping_hive(self):
+ mappings = operations.get_sql_type_mapping("hive")
+
+ # Check some key mappings for Hive
+ assert mappings["Int32"] == "INT"
+ assert mappings["Utf8"] == "STRING"
+ assert mappings["Float64"] == "DOUBLE"
+ assert mappings["Boolean"] == "BOOLEAN"
+ assert mappings["Decimal"] == "DECIMAL"
+
+ def test_get_sql_type_mapping_trino(self):
+ mappings = operations.get_sql_type_mapping("trino")
+
+ # Check some key mappings for Trino that differ from Hive
+ assert mappings["Int32"] == "INTEGER"
+ assert mappings["Utf8"] == "VARCHAR"
+ assert mappings["Binary"] == "VARBINARY"
+ assert mappings["Float32"] == "REAL"
+ assert mappings["Struct"] == "ROW"
+ assert mappings["Object"] == "JSON"
+
+ def test_get_sql_type_mapping_phoenix(self):
+ mappings = operations.get_sql_type_mapping("phoenix")
+
+ # Check some key mappings for Phoenix
+ assert mappings["UInt32"] == "UNSIGNED_INT"
+ assert mappings["Utf8"] == "VARCHAR"
+ assert mappings["Time"] == "TIME"
+ assert mappings["Struct"] == "STRING" # Phoenix treats structs as strings
+ assert mappings["Duration"] == "STRING" # Phoenix treats durations as strings
+
+ def test_get_sql_type_mapping_impala(self):
+ result = operations.get_sql_type_mapping("impala")
+
+ # Impala uses the base mappings, so check those
+ assert result["Int32"] == "INT"
+ assert result["Int64"] == "BIGINT"
+ assert result["Float64"] == "DOUBLE"
+ assert result["Utf8"] == "STRING"
+
+ def test_get_sql_type_mapping_sparksql(self):
+ result = operations.get_sql_type_mapping("sparksql")
+
+ # SparkSQL uses the base mappings, so check those
+ assert result["Int32"] == "INT"
+ assert result["Int64"] == "BIGINT"
+ assert result["Float64"] == "DOUBLE"
+ assert result["Utf8"] == "STRING"
+
+ def test_get_sql_type_mapping_unsupported_dialect(self):
+ with pytest.raises(ValueError, match="Unsupported dialect: mysql"):
+ operations.get_sql_type_mapping("mysql")
+
+ def test_map_polars_dtype_to_sql_type(self):
+ # Test with Hive dialect
+ assert operations._map_polars_dtype_to_sql_type("hive", "Int64") == "BIGINT"
+ assert operations._map_polars_dtype_to_sql_type("hive", "Float32") == "FLOAT"
+
+ # Test with Trino dialect
+ assert operations._map_polars_dtype_to_sql_type("trino", "Int64") == "BIGINT"
+ assert operations._map_polars_dtype_to_sql_type("trino", "Float32") == "REAL"
+
+ # Test unsupported type
+ with pytest.raises(ValueError, match="No mapping for Polars dtype"):
+ operations._map_polars_dtype_to_sql_type("hive", "NonExistentType")
+
+
+@pytest.mark.usefixtures("cleanup_temp_files")
+class TestExcelSheetNames:
+ @pytest.fixture
+ def cleanup_temp_files(self):
+ """Fixture to clean up temporary files after tests."""
+ temp_files = []
+
+ yield temp_files
+
+ # Clean up after test
+ for file_path in temp_files:
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ def test_get_sheet_names_xlsx(self, cleanup_temp_files):
+ # Create a temporary Excel file with minimal XML structure
+ with tempfile.NamedTemporaryFile(suffix=".xlsx") as tmp_file:
+ # Create a minimal XLSX with workbook.xml
+ with zipfile.ZipFile(tmp_file.name, "w") as zip_file:
+ # Create a minimal workbook.xml file
+                workbook_xml = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
+                <workbook xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main">
+                  <sheets>
+                    <sheet name="Sheet1" sheetId="1"/>
+                    <sheet name="Sheet2" sheetId="2"/>
+                    <sheet name="CustomSheet" sheetId="3"/>
+                  </sheets>
+                </workbook>
+                """
+ zip_file.writestr("xl/workbook.xml", workbook_xml)
+
+ cleanup_temp_files.append(tmp_file.name)
+
+ # Test the function
+ with open(tmp_file.name, "rb") as f:
+ sheet_names = operations._get_sheet_names_xlsx(f)
+
+ assert sheet_names == ["Sheet1", "Sheet2", "CustomSheet"]
diff --git a/desktop/core/src/desktop/lib/importer/serializers.py b/desktop/core/src/desktop/lib/importer/serializers.py
new file mode 100644
index 00000000000..7f20efb9ee1
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/serializers.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from rest_framework import serializers
+
+from desktop.conf import IMPORTER
+
+
+class LocalFileUploadSerializer(serializers.Serializer):
+ """Serializer for file upload validation.
+
+    This serializer validates that the uploaded file is present, does not use a restricted
+    file extension, and does not exceed the configured maximum upload size.
+
+ Attributes:
+ file: File field that must be included in the request
+ """
+
+ file = serializers.FileField(required=True, help_text="CSV or Excel file to upload and process")
+
+ def validate_file(self, value):
+ # Check if the file type is restricted
+ _, file_type = os.path.splitext(value.name)
+ restricted_extensions = IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.get()
+ if restricted_extensions and file_type.lower() in [ext.lower() for ext in restricted_extensions]:
+ raise serializers.ValidationError(f'Uploading files with type "{file_type}" is not allowed. Hue is configured to restrict this type.')
+
+ # Check file size
+ max_size = IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.get()
+ if value.size > max_size:
+ max_size_mib = max_size / (1024 * 1024)
+ raise serializers.ValidationError(f"File too large. Maximum file size is {max_size_mib:.0f} MiB.")
+
+ return value
+
+
+class GuessFileMetadataSerializer(serializers.Serializer):
+ """Serializer for file metadata guessing request validation.
+
+ This serializer validates the parameters required for guessing metadata from a file.
+
+ Attributes:
+ file_path: Path to the file
+ import_type: Type of import (local or remote)
+ """
+
+ file_path = serializers.CharField(required=True, help_text="Full path to the file to analyze")
+ import_type = serializers.ChoiceField(
+ choices=["local", "remote"], required=True, help_text="Whether the file is local or on a remote filesystem"
+ )
+
+
+class PreviewFileSerializer(serializers.Serializer):
+ """Serializer for file preview request validation.
+
+ This serializer validates the parameters required for previewing file content.
+
+ Attributes:
+ file_path: Path to the file to preview
+ file_type: Type of file format (csv, tsv, excel, delimiter_format)
+ import_type: Type of import (local or remote)
+ sql_dialect: Target SQL dialect for type mapping
+ has_header: Whether the file has a header row or not
+ sheet_name: Sheet name for Excel files (required when file_type is excel)
+ field_separator: Field separator character (required for delimited files)
+ quote_char: Quote character (required for delimited files)
+ record_separator: Record separator character (required for delimited files)
+ """
+
+ file_path = serializers.CharField(required=True, help_text="Full path to the file to preview")
+ file_type = serializers.ChoiceField(
+ choices=["csv", "tsv", "excel", "delimiter_format"], required=True, help_text="Type of file (csv, tsv, excel, delimiter_format)"
+ )
+ import_type = serializers.ChoiceField(
+ choices=["local", "remote"], required=True, help_text="Whether the file is local or on a remote filesystem"
+ )
+ sql_dialect = serializers.ChoiceField(
+ choices=["hive", "impala", "trino", "phoenix", "sparksql"], required=True, help_text="SQL dialect for mapping column types"
+ )
+
+ has_header = serializers.BooleanField(required=True, help_text="Whether the file has a header row or not")
+
+ # Excel-specific fields
+ sheet_name = serializers.CharField(required=False, help_text="Sheet name for Excel files")
+
+ # Delimited file-specific fields
+ field_separator = serializers.CharField(required=False, help_text="Field separator character")
+ quote_char = serializers.CharField(required=False, help_text="Quote character")
+ record_separator = serializers.CharField(required=False, help_text="Record separator character")
+
+ def validate(self, data):
+ """Validate the complete data set with interdependent field validation."""
+
+ if data.get("file_type") == "excel" and not data.get("sheet_name"):
+ raise serializers.ValidationError({"sheet_name": "Sheet name is required for Excel files."})
+
+ if data.get("file_type") in ["csv", "tsv", "delimiter_format"]:
+ if not data.get("field_separator"):
+ # If not provided, set default value based on file type
+ if data.get("file_type") == "csv":
+ data["field_separator"] = ","
+ elif data.get("file_type") == "tsv":
+ data["field_separator"] = "\t"
+ else:
+ raise serializers.ValidationError({"field_separator": "Field separator is required for delimited files"})
+
+ if not data.get("quote_char"):
+ data["quote_char"] = '"' # Default quote character
+
+ if not data.get("record_separator"):
+ data["record_separator"] = "\n" # Default record separator
+
+ return data
+
+
+class SqlTypeMapperSerializer(serializers.Serializer):
+ """Serializer for SQL type mapping requests.
+
+ This serializer validates the parameters required for retrieving type mapping information
+ from Polars data types to SQL types for a specific dialect.
+
+ Attributes:
+ sql_dialect: Target SQL dialect for type mapping
+ """
+
+ sql_dialect = serializers.ChoiceField(
+ choices=["hive", "impala", "trino", "phoenix", "sparksql"], required=True, help_text="SQL dialect for mapping column types"
+ )
+
+
+class GuessFileHeaderSerializer(serializers.Serializer):
+ """Serializer for file header guessing request validation.
+
+ This serializer validates the parameters required for guessing if a file has a header row.
+
+ Attributes:
+ file_path: Path to the file to analyze
+ file_type: Type of file format (csv, tsv, excel, delimiter_format)
+ import_type: Type of import (local or remote)
+ sheet_name: Sheet name for Excel files (required when file_type is excel)
+ """
+
+ file_path = serializers.CharField(required=True, help_text="Full path to the file to analyze")
+ file_type = serializers.ChoiceField(
+ choices=["csv", "tsv", "excel", "delimiter_format"], required=True, help_text="Type of file (csv, tsv, excel, delimiter_format)"
+ )
+ import_type = serializers.ChoiceField(
+ choices=["local", "remote"], required=True, help_text="Whether the file is local or on a remote filesystem"
+ )
+
+ # Excel-specific fields
+ sheet_name = serializers.CharField(required=False, help_text="Sheet name for Excel files")
+
+ def validate(self, data):
+ """Validate the complete data set with interdependent field validation."""
+
+ if data.get("file_type") == "excel" and not data.get("sheet_name"):
+ raise serializers.ValidationError({"sheet_name": "Sheet name is required for Excel files."})
+
+ return data
diff --git a/desktop/core/src/desktop/lib/importer/serializers_tests.py b/desktop/core/src/desktop/lib/importer/serializers_tests.py
new file mode 100644
index 00000000000..529b8f1b7a1
--- /dev/null
+++ b/desktop/core/src/desktop/lib/importer/serializers_tests.py
@@ -0,0 +1,373 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from django.core.files.uploadedfile import SimpleUploadedFile
+
+from desktop.conf import IMPORTER
+from desktop.lib.importer.serializers import (
+ GuessFileHeaderSerializer,
+ GuessFileMetadataSerializer,
+ LocalFileUploadSerializer,
+ PreviewFileSerializer,
+ SqlTypeMapperSerializer,
+)
+
+
+class TestLocalFileUploadSerializer:
+ def test_valid_file(self):
+ resets = [
+ IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([]),
+ IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024), # 10 MiB limit
+ ]
+ try:
+ test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")
+
+ serializer = LocalFileUploadSerializer(data={"file": test_file})
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data["file"] == test_file
+ finally:
+ for reset in resets:
+ reset()
+
+ def test_restricted_file_extension(self):
+ resets = [
+ IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([".sh", ".csv", ".exe"]),
+ IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024), # 10 MiB limit
+ ]
+ try:
+ test_file = SimpleUploadedFile(
+ name="test_file.exe", content=b"This is not a real executable", content_type="application/octet-stream"
+ )
+
+ serializer = LocalFileUploadSerializer(data={"file": test_file})
+
+ assert not serializer.is_valid()
+ assert "file" in serializer.errors
+ assert serializer.errors["file"][0] == 'Uploading files with type ".exe" is not allowed. Hue is configured to restrict this type.'
+ finally:
+ for reset in resets:
+ reset()
+
+ def test_file_size_too_large(self):
+ resets = [
+ IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([]),
+ IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10), # Just 10 bytes
+ ]
+ try:
+ test_file = SimpleUploadedFile(
+ name="test_file.csv", content=b"This content is more than 10 bytes which exceeds our mock size limit", content_type="text/csv"
+ )
+
+ serializer = LocalFileUploadSerializer(data={"file": test_file})
+
+ assert not serializer.is_valid()
+ assert "file" in serializer.errors
+            assert serializer.errors["file"][0] == "File too large. Maximum file size is 0 MiB."  # 10 bytes rounds down to 0 MiB
+ finally:
+ for reset in resets:
+ reset()
+
+ def test_missing_file(self):
+ serializer = LocalFileUploadSerializer(data={})
+
+ assert not serializer.is_valid()
+ assert "file" in serializer.errors
+ assert serializer.errors["file"][0] == "No file was submitted."
+
+
+class TestGuessFileMetadataSerializer:
+ def test_valid_data(self):
+ # Test with local import type
+ local_valid_data = {"file_path": "/path/to/file.csv", "import_type": "local"}
+
+ serializer = GuessFileMetadataSerializer(data=local_valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data == local_valid_data
+
+ # Test with remote import type
+ remote_valid_data = {"file_path": "s3a://bucket/user/test_user/file.csv", "import_type": "remote"}
+
+ serializer = GuessFileMetadataSerializer(data=remote_valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data == remote_valid_data
+
+ def test_missing_required_fields(self):
+ # Test missing file_path
+ invalid_data = {"import_type": "local"}
+
+ serializer = GuessFileMetadataSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "file_path" in serializer.errors
+
+ # Test missing import_type
+ invalid_data = {"file_path": "/path/to/file.csv"}
+
+ serializer = GuessFileMetadataSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "import_type" in serializer.errors
+
+ def test_invalid_import_type(self):
+ invalid_data = {
+ "file_path": "/path/to/file.csv",
+ "import_type": "invalid_type", # Not one of 'local' or 'remote'
+ }
+
+ serializer = GuessFileMetadataSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "import_type" in serializer.errors
+ assert serializer.errors["import_type"][0] == '"invalid_type" is not a valid choice.'
+
+
+class TestPreviewFileSerializer:
+ def test_valid_csv_data(self):
+ valid_data = {
+ "file_path": "/path/to/file.csv",
+ "file_type": "csv",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "field_separator": ",",
+ }
+
+ serializer = PreviewFileSerializer(data=valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+ # Check that default values are set for quote_char and record_separator
+ assert serializer.validated_data["quote_char"] == '"'
+ assert serializer.validated_data["record_separator"] == "\n"
+
+ def test_valid_excel_data(self):
+ valid_data = {
+ "file_path": "/path/to/file.xlsx",
+ "file_type": "excel",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ "sheet_name": "Sheet1",
+ }
+
+ serializer = PreviewFileSerializer(data=valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+
+ def test_missing_required_fields(self):
+ # Test with minimal data
+ invalid_data = {
+ "file_path": "/path/to/file.csv",
+ }
+
+ serializer = PreviewFileSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ # Check that all required fields are reported as missing
+ for field in ["file_type", "import_type", "sql_dialect", "has_header"]:
+ assert field in serializer.errors
+
+ def test_invalid_file_type(self):
+ invalid_data = {
+ "file_path": "/path/to/file.json",
+ "file_type": "json", # Not a valid choice
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ }
+
+ serializer = PreviewFileSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "file_type" in serializer.errors
+ assert serializer.errors["file_type"][0] == '"json" is not a valid choice.'
+
+ def test_excel_without_sheet_name(self):
+ invalid_data = {
+ "file_path": "/path/to/file.xlsx",
+ "file_type": "excel",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ # Missing sheet_name
+ }
+
+ serializer = PreviewFileSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "sheet_name" in serializer.errors
+ assert serializer.errors["sheet_name"][0] == "Sheet name is required for Excel files."
+
+ def test_delimited_without_field_separator(self):
+ # For delimiter_format type (not csv/tsv) field separator is required
+ invalid_data = {
+ "file_path": "/path/to/file.txt",
+ "file_type": "delimiter_format",
+ "import_type": "local",
+ "sql_dialect": "hive",
+ "has_header": True,
+ # Missing field_separator
+ }
+
+ serializer = PreviewFileSerializer(data=invalid_data)
+ assert not serializer.is_valid()
+ assert "field_separator" in serializer.errors
+
+ def test_default_separators_by_file_type(self):
+ # For CSV, field_separator should default to ','
+ csv_data = {"file_path": "/path/to/file.csv", "file_type": "csv", "import_type": "local", "sql_dialect": "hive", "has_header": True}
+
+ serializer = PreviewFileSerializer(data=csv_data)
+
+ assert serializer.is_valid(), f"CSV serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data["field_separator"] == ","
+
+ # For TSV, field_separator should default to '\t'
+ tsv_data = {"file_path": "/path/to/file.tsv", "file_type": "tsv", "import_type": "local", "sql_dialect": "hive", "has_header": True}
+
+ serializer = PreviewFileSerializer(data=tsv_data)
+
+ assert serializer.is_valid(), f"TSV serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data["field_separator"] == "\t"
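+
+    # NOTE (assumption): the assertions above imply that PreviewFileSerializer defaults
+    # quote_char to '"', record_separator to "\n", and field_separator to "," for csv or
+    # "\t" for tsv, while delimiter_format requires an explicit field_separator. A minimal
+    # validate() sketch consistent with these tests (hypothetical, not the actual
+    # implementation) could look like:
+    #
+    #     def validate(self, attrs):
+    #         attrs.setdefault("quote_char", '"')
+    #         attrs.setdefault("record_separator", "\n")
+    #         defaults = {"csv": ",", "tsv": "\t"}
+    #         if not attrs.get("field_separator"):
+    #             if attrs["file_type"] in defaults:
+    #                 attrs["field_separator"] = defaults[attrs["file_type"]]
+    #             else:
+    #                 raise serializers.ValidationError({"field_separator": "This field is required."})
+    #         return attrs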
+
+
+class TestSqlTypeMapperSerializer:
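+    """Validation tests for SqlTypeMapperSerializer: each supported SQL dialect is accepted,
+    while unknown or missing dialects are rejected."""
+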
+ def test_valid_sql_dialect(self):
+ for dialect in ["hive", "impala", "trino", "phoenix", "sparksql"]:
+ valid_data = {"sql_dialect": dialect}
+
+ serializer = SqlTypeMapperSerializer(data=valid_data)
+
+ assert serializer.is_valid(), f"Failed for dialect '{dialect}': {serializer.errors}"
+ assert serializer.validated_data["sql_dialect"] == dialect
+
+ def test_invalid_sql_dialect(self):
+ invalid_data = {"sql_dialect": "invalid_dialect"}
+
+ serializer = SqlTypeMapperSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "sql_dialect" in serializer.errors
+ assert serializer.errors["sql_dialect"][0] == '"invalid_dialect" is not a valid choice.'
+
+ def test_missing_sql_dialect(self):
+ invalid_data = {} # Empty data
+
+ serializer = SqlTypeMapperSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "sql_dialect" in serializer.errors
+ assert serializer.errors["sql_dialect"][0] == "This field is required."
+
+
+class TestGuessFileHeaderSerializer:
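+    """Validation tests for GuessFileHeaderSerializer: CSV and Excel payloads, missing
+    required fields, the Excel sheet_name requirement, and invalid choice values."""
+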
+ def test_valid_data_csv(self):
+ valid_data = {"file_path": "/path/to/file.csv", "file_type": "csv", "import_type": "local"}
+
+ serializer = GuessFileHeaderSerializer(data=valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data == valid_data
+
+ def test_valid_data_excel(self):
+ valid_data = {"file_path": "/path/to/file.xlsx", "file_type": "excel", "import_type": "local", "sheet_name": "Sheet1"}
+
+ serializer = GuessFileHeaderSerializer(data=valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+ assert serializer.validated_data == valid_data
+
+ def test_missing_required_fields(self):
+ # Missing file_path
+ invalid_data = {"file_type": "csv", "import_type": "local"}
+
+ serializer = GuessFileHeaderSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "file_path" in serializer.errors
+
+ # Missing file_type
+ invalid_data = {"file_path": "/path/to/file.csv", "import_type": "local"}
+
+ serializer = GuessFileHeaderSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "file_type" in serializer.errors
+
+ # Missing import_type
+ invalid_data = {"file_path": "/path/to/file.csv", "file_type": "csv"}
+
+ serializer = GuessFileHeaderSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "import_type" in serializer.errors
+
+ def test_excel_without_sheet_name(self):
+ invalid_data = {
+ "file_path": "/path/to/file.xlsx",
+ "file_type": "excel",
+ "import_type": "local",
+ # Missing sheet_name
+ }
+
+ serializer = GuessFileHeaderSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "sheet_name" in serializer.errors
+ assert serializer.errors["sheet_name"][0] == "Sheet name is required for Excel files."
+
+ def test_non_excel_with_sheet_name(self):
+ # This should pass, as sheet_name is optional for non-Excel files
+ valid_data = {
+ "file_path": "/path/to/file.csv",
+ "file_type": "csv",
+ "import_type": "local",
+ "sheet_name": "Sheet1", # Unnecessary but not invalid
+ }
+
+ serializer = GuessFileHeaderSerializer(data=valid_data)
+
+ assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}"
+
+ def test_invalid_file_type(self):
+ invalid_data = {
+ "file_path": "/path/to/file.json",
+ "file_type": "json", # Not a valid choice
+ "import_type": "local",
+ }
+
+ serializer = GuessFileHeaderSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "file_type" in serializer.errors
+ assert serializer.errors["file_type"][0] == '"json" is not a valid choice.'
+
+ def test_invalid_import_type(self):
+ invalid_data = {
+ "file_path": "/path/to/file.csv",
+ "file_type": "csv",
+ "import_type": "invalid_type", # Not a valid choice
+ }
+
+ serializer = GuessFileHeaderSerializer(data=invalid_data)
+
+ assert not serializer.is_valid()
+ assert "import_type" in serializer.errors
+ assert serializer.errors["import_type"][0] == '"invalid_type" is not a valid choice.'