diff --git a/desktop/conf.dist/hue.ini b/desktop/conf.dist/hue.ini index fb60ad66cd7..26891840de5 100644 --- a/desktop/conf.dist/hue.ini +++ b/desktop/conf.dist/hue.ini @@ -1064,6 +1064,19 @@ tls=no ## Enable integration with Google Storage for RAZ # is_raz_gs_enabled=false +## Configuration options for the importer +# ------------------------------------------------------------------------ +[[importer]] +# Turns on the data importer functionality +## is_enabled=false + +# A limit on the local file size (bytes) that can be uploaded through the importer. The default is 157286400 bytes (150 MiB). +## max_local_file_size_upload_limit=157286400 + +# Security setting to specify local file extensions that are not allowed to be uploaded through the importer. +# Provide a comma-separated list of extensions including the dot (e.g., ".exe, .zip, .rar, .tar, .gz"). +## restrict_local_file_extensions=.exe, .zip, .rar, .tar, .gz + ########################################################################### # Settings to configure the snippets available in the Notebook ########################################################################### diff --git a/desktop/conf/pseudo-distributed.ini.tmpl b/desktop/conf/pseudo-distributed.ini.tmpl index 4afb1174f23..fabc3c72083 100644 --- a/desktop/conf/pseudo-distributed.ini.tmpl +++ b/desktop/conf/pseudo-distributed.ini.tmpl @@ -1049,6 +1049,19 @@ ## Enable integration with Google Storage for RAZ # is_raz_gs_enabled=false + ## Configuration options for the importer + # ------------------------------------------------------------------------ + [[importer]] + # Turns on the data importer functionality + ## is_enabled=false + + # A limit on the local file size (bytes) that can be uploaded through the importer. The default is 157286400 bytes (150 MiB). + ## max_local_file_size_upload_limit=157286400 + + # Security setting to specify local file extensions that are not allowed to be uploaded through the importer. + # Provide a comma-separated list of extensions including the dot (e.g., ".exe, .zip, .rar, .tar, .gz"). 
+ ## restrict_local_file_extensions=.exe, .zip, .rar, .tar, .gz + ########################################################################### # Settings to configure the snippets available in the Notebook ########################################################################### diff --git a/desktop/core/base_requirements.txt b/desktop/core/base_requirements.txt index 61edcb5eb54..8bb398aa656 100644 --- a/desktop/core/base_requirements.txt +++ b/desktop/core/base_requirements.txt @@ -41,6 +41,7 @@ Mako==1.2.3 Markdown==3.7 openpyxl==3.0.9 phoenixdb==1.2.1 +polars[calamine]==1.8.2 # Python >= 3.8 prompt-toolkit==3.0.39 protobuf==3.20.3 pyarrow==17.0.0 @@ -48,6 +49,7 @@ pyformance==0.3.2 python-dateutil==2.8.2 python-daemon==2.2.4 python-ldap==3.4.3 +python-magic==0.4.27 python-oauth2==1.1.0 python-pam==2.0.2 pytidylib==0.3.2 diff --git a/desktop/core/generate_requirements.py b/desktop/core/generate_requirements.py index db0475efe2c..6ca37c83761 100755 --- a/desktop/core/generate_requirements.py +++ b/desktop/core/generate_requirements.py @@ -44,7 +44,6 @@ def __init__(self): ] self.requirements = [ - "setuptools==70.0.0", "apache-ranger==0.0.3", "asn1crypto==0.24.0", "avro-python3==1.8.2", @@ -88,6 +87,7 @@ def __init__(self): "Mako==1.2.3", "openpyxl==3.0.9", "phoenixdb==1.2.1", + "polars[calamine]==1.8.2", # Python >= 3.8 "prompt-toolkit==3.0.39", "protobuf==3.20.3", "psutil==5.8.0", @@ -97,6 +97,7 @@ def __init__(self): "python-daemon==2.2.4", "python-dateutil==2.8.2", "python-ldap==3.4.3", + "python-magic==0.4.27", "python-oauth2==1.1.0", "python-pam==2.0.2", "pytidylib==0.3.2", @@ -107,6 +108,7 @@ def __init__(self): "requests-kerberos==0.14.0", "rsa==4.7.2", "ruff==0.11.10", + "setuptools==70.0.0", "six==1.16.0", "slack-sdk==3.31.0", "SQLAlchemy==1.3.8", diff --git a/desktop/core/src/desktop/api2.py b/desktop/core/src/desktop/api2.py index 742e65f5ade..f1a1614d955 100644 --- a/desktop/core/src/desktop/api2.py +++ b/desktop/core/src/desktop/api2.py @@ -15,12 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import re import json import logging -import zipfile +import os +import re import tempfile +import zipfile from builtins import map from datetime import datetime from io import StringIO as string_io @@ -28,7 +28,7 @@ from celery.app.control import Control from django.core import management from django.db import transaction -from django.http import HttpResponse, JsonResponse +from django.http import HttpResponse from django.shortcuts import redirect from django.utils.html import escape from django.utils.translation import gettext as _ @@ -48,10 +48,11 @@ ENABLE_NEW_STORAGE_BROWSER, ENABLE_SHARING, ENABLE_WORKFLOW_CREATION_ACTION, - TASK_SERVER_V2, get_clusters, + IMPORTER, + TASK_SERVER_V2, ) -from desktop.lib.conf import GLOBAL_CONFIG, BoundContainer, is_anonymous +from desktop.lib.conf import BoundContainer, GLOBAL_CONFIG, is_anonymous from desktop.lib.connectors.models import Connector from desktop.lib.django_util import JsonResponse, login_notrequired, render from desktop.lib.exceptions_renderable import PopupException @@ -60,16 +61,16 @@ from desktop.lib.paths import get_desktop_root from desktop.log import DEFAULT_LOG_DIR from desktop.models import ( + __paginate, + _get_gist_document, Directory, Document, Document2, FilesystemException, - UserPreferences, - __paginate, - _get_gist_document, get_cluster_config, get_user_preferences, set_user_preferences, + UserPreferences, uuid_default, ) from desktop.views import _get_config_errors, get_banner_message, serve_403_error @@ -91,7 +92,7 @@ search_entities_interactive as metadata_search_entities_interactive, ) from metadata.conf import has_catalog -from notebook.connectors.base import Notebook, get_interpreter +from notebook.connectors.base import get_interpreter, Notebook from notebook.management.commands import notebook_setup from pig.management.commands import pig_setup from search.management.commands import search_setup @@ -132,19 +133,41 @@ def get_banners(request): @api_error_handler def get_config(request): + """ + Returns Hue application's config information. + Includes settings for various components like storage, task server, importer, etc. 
+ """ + # Get base cluster configuration config = get_cluster_config(request.user) - config['hue_config']['is_admin'] = is_admin(request.user) - config['hue_config']['is_yarn_enabled'] = is_yarn() - config['hue_config']['enable_task_server'] = TASK_SERVER_V2.ENABLED.get() - config['hue_config']['enable_workflow_creation_action'] = ENABLE_WORKFLOW_CREATION_ACTION.get() - config['storage_browser']['enable_chunked_file_upload'] = ENABLE_CHUNKED_FILE_UPLOADER.get() - config['storage_browser']['enable_new_storage_browser'] = ENABLE_NEW_STORAGE_BROWSER.get() - config['storage_browser']['restrict_file_extensions'] = RESTRICT_FILE_EXTENSIONS.get() - config['storage_browser']['concurrent_max_connection'] = CONCURRENT_MAX_CONNECTIONS.get() - config['storage_browser']['file_upload_chunk_size'] = FILE_UPLOAD_CHUNK_SIZE.get() - config['storage_browser']['enable_file_download_button'] = SHOW_DOWNLOAD_BUTTON.get() - config['storage_browser']['max_file_editor_size'] = MAX_FILEEDITOR_SIZE - config['storage_browser']['enable_extract_uploaded_archive'] = ENABLE_EXTRACT_UPLOADED_ARCHIVE.get() + + # Core application configuration + config['hue_config'] = { + 'is_admin': is_admin(request.user), + 'is_yarn_enabled': is_yarn(), + 'enable_task_server': TASK_SERVER_V2.ENABLED.get(), + 'enable_workflow_creation_action': ENABLE_WORKFLOW_CREATION_ACTION.get(), + } + + # Storage browser configuration + config['storage_browser'] = { + 'enable_chunked_file_upload': ENABLE_CHUNKED_FILE_UPLOADER.get(), + 'enable_new_storage_browser': ENABLE_NEW_STORAGE_BROWSER.get(), + 'restrict_file_extensions': RESTRICT_FILE_EXTENSIONS.get(), + 'concurrent_max_connection': CONCURRENT_MAX_CONNECTIONS.get(), + 'file_upload_chunk_size': FILE_UPLOAD_CHUNK_SIZE.get(), + 'enable_file_download_button': SHOW_DOWNLOAD_BUTTON.get(), + 'max_file_editor_size': MAX_FILEEDITOR_SIZE, + 'enable_extract_uploaded_archive': ENABLE_EXTRACT_UPLOADED_ARCHIVE.get(), + } + + # Importer configuration + config['importer'] = { + 'is_enabled': IMPORTER.IS_ENABLED.get(), + 'restrict_local_file_extensions': IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.get(), + 'max_local_file_size_upload_limit': IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.get(), + } + + # Other general configuration config['clusters'] = list(get_clusters(request.user).values()) config['documents'] = {'types': list(Document2.objects.documents(user=request.user).order_by().values_list('type', flat=True).distinct())} config['status'] = 0 @@ -624,7 +647,7 @@ def copy_document(request): # Import workspace for all oozie jobs if document.type == 'oozie-workflow2' or document.type == 'oozie-bundle2' or document.type == 'oozie-coordinator2': - from oozie.models2 import Bundle, Coordinator, Workflow, _import_workspace + from oozie.models2 import _import_workspace, Bundle, Coordinator, Workflow # Update the name field in the json 'data' field if document.type == 'oozie-workflow2': workflow = Workflow(document=document) @@ -998,7 +1021,7 @@ def is_reserved_directory(doc): documents = json.loads(request.POST.get('documents')) documents = json.loads(documents) - except ValueError as e: + except ValueError: raise PopupException(_('Failed to import documents, the file does not contain valid JSON.')) # Validate documents @@ -1112,7 +1135,7 @@ def gist_create(request): statement = request.POST.get('statement', '') gist_type = request.POST.get('doc_type', 'hive') name = request.POST.get('name', '') - description = request.POST.get('description', '') + _ = request.POST.get('description', '') response = 
_gist_create(request.get_host(), request.is_secure(), request.user, statement, gist_type, name) @@ -1333,7 +1356,7 @@ def _create_or_update_document_with_owner(doc, owner, uuids_map): doc['pk'] = existing_doc.pk else: create_new = True - except FilesystemException as e: + except FilesystemException: create_new = True if create_new: diff --git a/desktop/core/src/desktop/api_public_urls_v1.py b/desktop/core/src/desktop/api_public_urls_v1.py index 50487e224b3..cf4eac70b9e 100644 --- a/desktop/core/src/desktop/api_public_urls_v1.py +++ b/desktop/core/src/desktop/api_public_urls_v1.py @@ -19,6 +19,7 @@ from desktop import api_public from desktop.lib.botserver import api as botserver_api +from desktop.lib.importer import api as importer_api # "New" query API (i.e. connector based, lean arguments). # e.g. https://demo.gethue.com/api/query/execute/hive @@ -158,6 +159,14 @@ re_path(r'^indexer/importer/submit', api_public.importer_submit, name='indexer_importer_submit'), ] +urlpatterns += [ + re_path(r'^importer/upload/file/?$', importer_api.local_file_upload, name='importer_local_file_upload'), + re_path(r'^importer/file/guess_metadata/?$', importer_api.guess_file_metadata, name='importer_guess_file_metadata'), + re_path(r'^importer/file/guess_header/?$', importer_api.guess_file_header, name='importer_guess_file_header'), + re_path(r'^importer/file/preview/?$', importer_api.preview_file, name='importer_preview_file'), + re_path(r'^importer/sql_type_mapping/?$', importer_api.get_sql_type_mapping, name='importer_get_sql_type_mapping'), +] + urlpatterns += [ re_path(r'^connector/types/?$', api_public.get_connector_types, name='connector_get_types'), re_path(r'^connector/instances/?$', api_public.get_connectors_instances, name='connector_get_instances'), diff --git a/desktop/core/src/desktop/conf.py b/desktop/core/src/desktop/conf.py index e8cb929c52e..63954bdd1ae 100644 --- a/desktop/core/src/desktop/conf.py +++ b/desktop/core/src/desktop/conf.py @@ -16,13 +16,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import sys +import datetime import glob -import stat -import socket import logging -import datetime +import os +import socket +import stat +import sys from collections import OrderedDict from django.db import connection @@ -30,16 +30,16 @@ from desktop import appmanager from desktop.lib.conf import ( - Config, - ConfigSection, - UnspecifiedConfigSection, coerce_bool, coerce_csv, coerce_json_dict, coerce_password_from_script, coerce_str_lowercase, coerce_string, + Config, + ConfigSection, list_of_compiled_res, + UnspecifiedConfigSection, validate_path, ) from desktop.lib.i18n import force_unicode @@ -106,7 +106,7 @@ def get_dn(fqdn=None): else: LOG.warning("allowed_hosts value to '*'. It is a security risk") val.append('*') - except Exception as e: + except Exception: LOG.warning("allowed_hosts value to '*'. 
It is a security risk") val.append('*') return val @@ -2952,3 +2952,32 @@ def is_ofs_enabled(): def has_ofs_access(user): from desktop.auth.backend import is_admin return user.is_authenticated and user.is_active and (is_admin(user) or user.has_hue_permission(action="ofs_access", app="filebrowser")) + + +IMPORTER = ConfigSection( + key='importer', + help=_("""Configuration options for the importer."""), + members=dict( + IS_ENABLED=Config( + key='is_enabled', + help=_('Enable or disable the new importer functionality'), + type=coerce_bool, + default=False, + ), + RESTRICT_LOCAL_FILE_EXTENSIONS=Config( + key='restrict_local_file_extensions', + default=None, + type=coerce_csv, + help=_( + 'Security setting to specify local file extensions that are not allowed to be uploaded through the importer. ' + 'Provide a comma-separated list of extensions including the dot (e.g., ".exe, .zip, .rar, .tar, .gz").' + ), + ), + MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT=Config( + key="max_local_file_size_upload_limit", + default=157286400, # 150 MiB + type=int, + help=_('Maximum local file size (in bytes) that users can upload through the importer. The default is 157286400 bytes (150 MiB).'), + ), + ), +) diff --git a/desktop/core/src/desktop/lib/importer/__init__.py b/desktop/core/src/desktop/lib/importer/__init__.py new file mode 100644 index 00000000000..d053ec41e3c --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/desktop/core/src/desktop/lib/importer/api.py b/desktop/core/src/desktop/lib/importer/api.py new file mode 100644 index 00000000000..8d1713905fa --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/api.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from functools import wraps + +from rest_framework import status +from rest_framework.decorators import api_view, parser_classes +from rest_framework.parsers import JSONParser, MultiPartParser +from rest_framework.request import Request +from rest_framework.response import Response + +from desktop.lib.importer import operations +from desktop.lib.importer.serializers import ( + GuessFileHeaderSerializer, + GuessFileMetadataSerializer, + LocalFileUploadSerializer, + PreviewFileSerializer, + SqlTypeMapperSerializer, +) + +LOG = logging.getLogger() + + +# TODO: Improve error response further with better context -- Error UX Phase 2 +def api_error_handler(fn): + """ + Decorator to handle exceptions and return a JSON response with an error message. + """ + + @wraps(fn) + def decorator(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + LOG.exception(f"Error running {fn.__name__}: {str(e)}") + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + return decorator + + +@api_view(["POST"]) +@parser_classes([MultiPartParser]) +@api_error_handler +def local_file_upload(request: Request) -> Response: + """Handle the local file upload operation. + + This endpoint allows users to upload a file from their local system. + Uploaded file is validated using LocalFileUploadSerializer and processed using local_file_upload operation. + + Args: + request: Request object containing the file to upload + + Returns: + Response containing the result of the local upload operation, including: + - file_path: Path where the file was saved (if successful) + + Note: + - File size limits apply based on MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT configuration. + - File type restrictions apply based on RESTRICT_LOCAL_FILE_EXTENSIONS configuration. + """ + + serializer = LocalFileUploadSerializer(data=request.data) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + uploaded_file = serializer.validated_data["file"] + + LOG.info(f"User {request.user.username} is uploading a local file: {uploaded_file.name}") + result = operations.local_file_upload(uploaded_file, request.user.username) + + return Response(result, status=status.HTTP_201_CREATED) + + +@api_view(["GET"]) +@parser_classes([JSONParser]) +@api_error_handler +def guess_file_metadata(request: Request) -> Response: + """Guess the metadata of a file based on its content or extension. + + This API endpoint detects file type and extracts metadata properties such as + delimiters for CSV files or sheet names for Excel files. 
+ + Args: + request: Request object containing query parameters: + - file_path: Path to the file + - import_type: 'local' or 'remote' + + Returns: + Response containing file metadata including: + - type: File type (e.g., excel, csv, tsv) + - sheet_names: List of sheet names (for Excel files) + - field_separator: Field separator character (for delimited files) + - quote_char: Quote character (for delimited files) + - record_separator: Record separator (for delimited files) + """ + serializer = GuessFileMetadataSerializer(data=request.query_params) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + validated_data = serializer.validated_data + file_path = validated_data["file_path"] + import_type = validated_data["import_type"] + + try: + metadata = operations.guess_file_metadata( + file_path=file_path, import_type=import_type, fs=request.fs if import_type == "remote" else None + ) + + return Response(metadata, status=status.HTTP_200_OK) + + except ValueError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + # Handle other errors + LOG.exception(f"Error guessing file metadata: {e}", exc_info=True) + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@api_view(["GET"]) +@parser_classes([JSONParser]) +@api_error_handler +def preview_file(request: Request) -> Response: + """Preview a file based on its path and import type. + + Args: + request: Request object containing query parameters for file preview + + Returns: + Response containing a dict preview of the file content, including: + - type: Type of the file (e.g., csv, tsv, excel) + - columns: List of column metadata + - preview_data: Sample data from the file + """ + serializer = PreviewFileSerializer(data=request.query_params) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + # Get validated data + validated_data = serializer.validated_data + file_path = validated_data["file_path"] + file_type = validated_data["file_type"] + import_type = validated_data["import_type"] + sql_dialect = validated_data["sql_dialect"] + has_header = validated_data.get("has_header") + + try: + if file_type == "excel": + sheet_name = validated_data.get("sheet_name") + + preview = operations.preview_file( + file_path=file_path, + file_type=file_type, + import_type=import_type, + sql_dialect=sql_dialect, + has_header=has_header, + sheet_name=sheet_name, + fs=request.fs if import_type == "remote" else None, + ) + else: # Delimited file types + field_separator = validated_data.get("field_separator") + quote_char = validated_data.get("quote_char") + record_separator = validated_data.get("record_separator") + + preview = operations.preview_file( + file_path=file_path, + file_type=file_type, + import_type=import_type, + sql_dialect=sql_dialect, + has_header=has_header, + field_separator=field_separator, + quote_char=quote_char, + record_separator=record_separator, + fs=request.fs if import_type == "remote" else None, + ) + + return Response(preview, status=status.HTTP_200_OK) + + except ValueError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + LOG.exception(f"Error previewing file: {e}", exc_info=True) + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@api_view(["GET"]) +@parser_classes([JSONParser]) +@api_error_handler +def guess_file_header(request: Request) -> 
Response: + """Guess whether a file has a header row. + + This API endpoint analyzes a file to determine if it contains a header row based on the + content pattern. It works for both Excel files and delimited text files (CSV, TSV, etc.) + + Args: + request: Request object containing query parameters: + - file_path: Path to the file + - file_type: Type of file ('excel', 'csv', 'tsv', 'delimiter_format') + - import_type: 'local' or 'remote' + - sheet_name: Sheet name for Excel files (required for Excel) + + Returns: + Response containing: + - has_header: Boolean indicating whether the file has a header row + """ + serializer = GuessFileHeaderSerializer(data=request.query_params) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + validated_data = serializer.validated_data + + try: + has_header = operations.guess_file_header( + file_path=validated_data["file_path"], + file_type=validated_data["file_type"], + import_type=validated_data["import_type"], + sheet_name=validated_data.get("sheet_name"), + fs=request.fs if validated_data["import_type"] == "remote" else None, + ) + + return Response({"has_header": has_header}, status=status.HTTP_200_OK) + + except ValueError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + # Handle other errors + LOG.exception(f"Error detecting file header: {e}", exc_info=True) + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@api_view(["GET"]) +@parser_classes([JSONParser]) +@api_error_handler +def get_sql_type_mapping(request: Request) -> Response: + """Get mapping from Polars data types to SQL types for a specific dialect. + + This API endpoint returns a dictionary mapping Polars data types to the corresponding + SQL types for a specific SQL dialect. + + Args: + request: Request object containing query parameters: + - sql_dialect: The SQL dialect to get mappings for (e.g., 'hive', 'impala', 'trino') + + Returns: + Response containing a mapping dictionary: + - A mapping from Polars data type names to SQL type names for the specified dialect + """ + serializer = SqlTypeMapperSerializer(data=request.query_params) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + validated_data = serializer.validated_data + sql_dialect = validated_data["sql_dialect"] + + try: + type_mapping = operations.get_sql_type_mapping(sql_dialect) + return Response(type_mapping, status=status.HTTP_200_OK) + + except ValueError as e: + return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) + except Exception as e: + LOG.exception(f"Error getting SQL type mapping: {e}", exc_info=True) + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/desktop/core/src/desktop/lib/importer/api_tests.py b/desktop/core/src/desktop/lib/importer/api_tests.py new file mode 100644 index 00000000000..945ec5aa879 --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/api_tests.py @@ -0,0 +1,685 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +from django.core.files.uploadedfile import SimpleUploadedFile +from rest_framework import status +from rest_framework.test import APIRequestFactory + +from desktop.lib.importer import api + + +class TestLocalFileUploadAPI: + @patch("desktop.lib.importer.api.LocalFileUploadSerializer") + @patch("desktop.lib.importer.api.operations.local_file_upload") + def test_local_file_upload_success(self, mock_local_file_upload, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"file": SimpleUploadedFile("test.csv", b"content")}) + mock_serializer_class.return_value = mock_serializer + + mock_local_file_upload.return_value = {"file_path": "/tmp/user_12345_test.csv"} + + request = APIRequestFactory().post("importer/upload/file/") + request.user = MagicMock(username="test_user") + + response = api.local_file_upload(request) + + assert response.status_code == status.HTTP_201_CREATED + assert response.data == {"file_path": "/tmp/user_12345_test.csv"} + mock_local_file_upload.assert_called_once_with(mock_serializer.validated_data["file"], "test_user") + + @patch("desktop.lib.importer.api.LocalFileUploadSerializer") + def test_local_file_upload_invalid_data(self, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file": ["File too large"]}) + mock_serializer_class.return_value = mock_serializer + + request = APIRequestFactory().post("importer/upload/file/") + request.user = MagicMock(username="test_user") + + response = api.local_file_upload(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"file": ["File too large"]} + + @patch("desktop.lib.importer.api.LocalFileUploadSerializer") + @patch("desktop.lib.importer.api.operations.local_file_upload") + def test_local_file_upload_operation_error(self, mock_local_file_upload, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"file": SimpleUploadedFile("test.csv", b"content")}) + mock_serializer_class.return_value = mock_serializer + + mock_local_file_upload.side_effect = IOError("Operation error") + + request = APIRequestFactory().post("importer/upload/file/") + request.user = MagicMock(username="test_user") + + response = api.local_file_upload(request) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert response.data == {"error": "Operation error"} + + +class TestGuessFileMetadataAPI: + @patch("desktop.lib.importer.api.GuessFileMetadataSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_metadata") + def test_guess_csv_file_metadata_success(self, mock_guess_file_metadata, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_metadata.return_value = {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"} + + request = 
APIRequestFactory().get("importer/file/guess_metadata/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"} + request.fs = None + + response = api.guess_file_metadata(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"} + mock_guess_file_metadata.assert_called_once_with(file_path="/path/to/test.csv", import_type="local", fs=None) + + @patch("desktop.lib.importer.api.GuessFileMetadataSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_metadata") + def test_guess_excel_file_metadata_success(self, mock_guess_file_metadata, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.xlsx", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_metadata.return_value = {"type": "excel", "sheet_names": ["Sheet1", "Sheet2"]} + + request = APIRequestFactory().get("importer/file/guess_metadata/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.xlsx", "import_type": "local"} + request.fs = None + + response = api.guess_file_metadata(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"type": "excel", "sheet_names": ["Sheet1", "Sheet2"]} + mock_guess_file_metadata.assert_called_once_with(file_path="/path/to/test.xlsx", import_type="local", fs=None) + + @patch("desktop.lib.importer.api.GuessFileMetadataSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_metadata") + def test_guess_file_metadata_remote_csv_file(self, mock_guess_file_metadata, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "s3a://bucket/user/test_user/test.csv", "import_type": "remote"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_metadata.return_value = {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"} + mock_fs = MagicMock() + + request = APIRequestFactory().get("importer/file/guess_metadata/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "import_type": "remote"} + request.fs = mock_fs + + response = api.guess_file_metadata(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\n"} + mock_guess_file_metadata.assert_called_once_with(file_path="s3a://bucket/user/test_user/test.csv", import_type="remote", fs=mock_fs) + + @patch("desktop.lib.importer.api.GuessFileMetadataSerializer") + def test_guess_file_metadata_invalid_data(self, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file_path": ["This field is required"]}) + mock_serializer_class.return_value = mock_serializer + + request = APIRequestFactory().get("importer/file/guess_metadata/") + request.user = MagicMock(username="test_user") + request.query_params = {} + + response = api.guess_file_metadata(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"file_path": ["This field is required"]} + + @patch("desktop.lib.importer.api.GuessFileMetadataSerializer") + 
@patch("desktop.lib.importer.api.operations.guess_file_metadata") + def test_guess_file_metadata_value_error(self, mock_guess_file_metadata, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_metadata.side_effect = ValueError("File does not exist") + + request = APIRequestFactory().get("importer/file/guess_metadata/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"} + request.fs = None + + response = api.guess_file_metadata(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "File does not exist"} + + @patch("desktop.lib.importer.api.GuessFileMetadataSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_metadata") + def test_guess_file_metadata_operation_error(self, mock_guess_file_metadata, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_metadata.side_effect = RuntimeError("Operation error") + + request = APIRequestFactory().get("importer/file/guess_metadata/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"} + request.fs = None + + response = api.guess_file_metadata(request) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert response.data == {"error": "Operation error"} + + +class TestPreviewFileAPI: + @patch("desktop.lib.importer.api.PreviewFileSerializer") + @patch("desktop.lib.importer.api.operations.preview_file") + def test_preview_csv_file_success(self, mock_preview_file, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={ + "file_path": "/path/to/test.csv", + "file_type": "csv", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + "field_separator": ",", + "quote_char": '"', + "record_separator": "\n", + }, + ) + mock_serializer_class.return_value = mock_serializer + + mock_preview_result = { + "type": "csv", + "columns": [{"name": "col1", "type": "INT"}, {"name": "col2", "type": "STRING"}], + "preview_data": {"col1": [1, 2], "col2": ["a", "b"]}, + } + mock_preview_file.return_value = mock_preview_result + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + request.fs = None + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == mock_preview_result + mock_preview_file.assert_called_once_with( + file_path="/path/to/test.csv", + file_type="csv", + import_type="local", + sql_dialect="hive", + has_header=True, + field_separator=",", + quote_char='"', + record_separator="\n", + fs=None, + ) + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + @patch("desktop.lib.importer.api.operations.preview_file") + def test_preview_excel_file_success(self, mock_preview_file, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={ + "file_path": 
"/path/to/test.xlsx", + "file_type": "excel", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + "sheet_name": "Sheet1", + }, + ) + mock_serializer_class.return_value = mock_serializer + + mock_preview_result = { + "type": "excel", + "columns": [{"name": "col1", "type": "INT"}, {"name": "col2", "type": "STRING"}], + "preview_data": {"col1": [1, 2], "col2": ["a", "b"]}, + } + mock_preview_file.return_value = mock_preview_result + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.xlsx", "file_type": "excel", "import_type": "local"} + request.fs = None + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == mock_preview_result + mock_preview_file.assert_called_once_with( + file_path="/path/to/test.xlsx", + file_type="excel", + import_type="local", + sql_dialect="hive", + has_header=True, + sheet_name="Sheet1", + fs=None, + ) + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + @patch("desktop.lib.importer.api.operations.preview_file") + def test_preview_tsv_file_success(self, mock_preview_file, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={ + "file_path": "/path/to/test.tsv", + "file_type": "tsv", + "import_type": "local", + "sql_dialect": "impala", + "has_header": True, + "field_separator": "\t", + "quote_char": '"', + "record_separator": "\n", + }, + ) + mock_serializer_class.return_value = mock_serializer + + mock_preview_result = { + "type": "tsv", + "columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "STRING"}], + "preview_data": {"id": [1, 2], "name": ["Product A", "Product B"]}, + } + mock_preview_file.return_value = mock_preview_result + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.tsv", "file_type": "tsv", "import_type": "local"} + request.fs = None + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == mock_preview_result + mock_preview_file.assert_called_once_with( + file_path="/path/to/test.tsv", + file_type="tsv", + import_type="local", + sql_dialect="impala", + has_header=True, + field_separator="\t", + quote_char='"', + record_separator="\n", + fs=None, + ) + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + @patch("desktop.lib.importer.api.operations.preview_file") + def test_preview_remote_csv_file_success(self, mock_preview_file, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={ + "file_path": "s3a://bucket/user/test_user/test.csv", + "file_type": "csv", + "import_type": "remote", + "sql_dialect": "hive", + "has_header": True, + "field_separator": ",", + "quote_char": '"', + "record_separator": "\n", + }, + ) + mock_serializer_class.return_value = mock_serializer + + mock_preview_result = { + "type": "csv", + "columns": [{"name": "col1", "type": "INT"}, {"name": "col2", "type": "STRING"}], + "preview_data": {"col1": [1, 2], "col2": ["a", "b"]}, + } + + mock_preview_file.return_value = mock_preview_result + mock_fs = MagicMock() + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", 
"file_type": "csv", "import_type": "remote"} + request.fs = mock_fs + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == mock_preview_result + mock_preview_file.assert_called_once_with( + file_path="s3a://bucket/user/test_user/test.csv", + file_type="csv", + import_type="remote", + sql_dialect="hive", + has_header=True, + field_separator=",", + quote_char='"', + record_separator="\n", + fs=mock_fs, + ) + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + def test_preview_file_invalid_data(self, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file_type": ["Not a valid choice."]}) + mock_serializer_class.return_value = mock_serializer + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.pdf", "file_type": "pdf", "import_type": "local"} + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"file_type": ["Not a valid choice."]} + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + def test_preview_file_missing_required_param(self, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"sql_dialect": ["This field is required."]}) + mock_serializer_class.return_value = mock_serializer + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"sql_dialect": ["This field is required."]} + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + @patch("desktop.lib.importer.api.operations.preview_file") + def test_preview_file_value_error(self, mock_preview_file, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={ + "file_path": "/path/to/test.csv", + "file_type": "csv", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + "field_separator": ",", + "quote_char": '"', + "record_separator": "\n", + }, + ) + mock_serializer_class.return_value = mock_serializer + + mock_preview_file.side_effect = ValueError("File does not exist") + + request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + request.fs = None + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "File does not exist"} + + @patch("desktop.lib.importer.api.PreviewFileSerializer") + @patch("desktop.lib.importer.api.operations.preview_file") + def test_preview_file_operation_error(self, mock_preview_file, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={ + "file_path": "/path/to/test.csv", + "file_type": "csv", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + "field_separator": ",", + "quote_char": '"', + "record_separator": "\n", + }, + ) + mock_serializer_class.return_value = mock_serializer + + mock_preview_file.side_effect = RuntimeError("Operation error") + + 
request = APIRequestFactory().get("importer/file/preview/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + request.fs = None + + response = api.preview_file(request) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert response.data == {"error": "Operation error"} + + +class TestGuessFileHeaderAPI: + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_header") + def test_guess_csv_file_header_success(self, mock_guess_file_header, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_header.return_value = True + + request = APIRequestFactory().get("importer/file/guess_header/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + request.fs = None + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"has_header": True} + mock_guess_file_header.assert_called_once_with( + file_path="/path/to/test.csv", file_type="csv", import_type="local", sheet_name=None, fs=None + ) + + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_header") + def test_guess_excel_file_header_success(self, mock_guess_file_header, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={"file_path": "/path/to/test.xlsx", "file_type": "excel", "import_type": "local", "sheet_name": "Sheet1"}, + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_header.return_value = True + + request = APIRequestFactory().get("importer/file/guess_header/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.xlsx", "file_type": "excel", "import_type": "local", "sheet_name": "Sheet1"} + request.fs = None + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"has_header": True} + mock_guess_file_header.assert_called_once_with( + file_path="/path/to/test.xlsx", file_type="excel", import_type="local", sheet_name="Sheet1", fs=None + ) + + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_header") + def test_guess_remote_csv_file_header_success(self, mock_guess_file_header, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"}, + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_header.return_value = True + mock_fs = MagicMock() + + request = APIRequestFactory().get("importer/file/guess_header/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"} + request.fs = mock_fs + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"has_header": True} + 
mock_guess_file_header.assert_called_once_with( + file_path="s3a://bucket/user/test_user/test.csv", file_type="csv", import_type="remote", sheet_name=None, fs=mock_fs + ) + + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_header") + def test_guess_remote_csv_file_header_success_false_value(self, mock_guess_file_header, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), + validated_data={"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"}, + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_header.return_value = False + mock_fs = MagicMock() + + request = APIRequestFactory().get("importer/file/guess_header/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "s3a://bucket/user/test_user/test.csv", "file_type": "csv", "import_type": "remote"} + request.fs = mock_fs + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"has_header": False} + mock_guess_file_header.assert_called_once_with( + file_path="s3a://bucket/user/test_user/test.csv", file_type="csv", import_type="remote", sheet_name=None, fs=mock_fs + ) + + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + def test_guess_file_header_invalid_data(self, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"file_type": ["This field is required"]}) + mock_serializer_class.return_value = mock_serializer + + request = APIRequestFactory().get("importer/file/guess_header/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "import_type": "local"} + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"file_type": ["This field is required"]} + + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_header") + def test_guess_file_header_value_error(self, mock_guess_file_header, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_header.side_effect = ValueError("File does not exist") + + request = APIRequestFactory().get("importer/file/guess_header/") + request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + request.fs = None + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "File does not exist"} + + @patch("desktop.lib.importer.api.GuessFileHeaderSerializer") + @patch("desktop.lib.importer.api.operations.guess_file_header") + def test_guess_file_header_operation_error(self, mock_guess_file_header, mock_serializer_class): + mock_serializer = MagicMock( + is_valid=MagicMock(return_value=True), validated_data={"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + ) + mock_serializer_class.return_value = mock_serializer + + mock_guess_file_header.side_effect = RuntimeError("Operation error") + + request = APIRequestFactory().get("importer/file/guess_header/") + 
request.user = MagicMock(username="test_user") + request.query_params = {"file_path": "/path/to/test.csv", "file_type": "csv", "import_type": "local"} + request.fs = None + + response = api.guess_file_header(request) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert response.data == {"error": "Operation error"} + + +class TestSqlTypeMappingAPI: + @patch("desktop.lib.importer.api.SqlTypeMapperSerializer") + @patch("desktop.lib.importer.api.operations.get_sql_type_mapping") + def test_get_sql_type_mapping_success(self, mock_get_sql_type_mapping, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"sql_dialect": "hive"}) + mock_serializer_class.return_value = mock_serializer + + mock_get_sql_type_mapping.return_value = {"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"} + + request = APIRequestFactory().get("importer/sql_type_mapping/") + request.user = MagicMock(username="test_user") + request.query_params = {"sql_dialect": "hive"} + + response = api.get_sql_type_mapping(request) + + assert response.status_code == status.HTTP_200_OK + assert response.data == {"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"} + mock_get_sql_type_mapping.assert_called_once_with("hive") + + @patch("desktop.lib.importer.api.SqlTypeMapperSerializer") + def test_get_sql_type_mapping_invalid_dialect(self, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=False), errors={"sql_dialect": ["Not a valid choice."]}) + mock_serializer_class.return_value = mock_serializer + + request = APIRequestFactory().get("importer/sql_type_mapping/") + request.user = MagicMock(username="test_user") + request.query_params = {"sql_dialect": "invalid_dialect"} + + response = api.get_sql_type_mapping(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"sql_dialect": ["Not a valid choice."]} + + @patch("desktop.lib.importer.api.SqlTypeMapperSerializer") + @patch("desktop.lib.importer.api.operations.get_sql_type_mapping") + def test_get_sql_type_mapping_value_error(self, mock_get_sql_type_mapping, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"sql_dialect": "hive"}) + mock_serializer_class.return_value = mock_serializer + + mock_get_sql_type_mapping.side_effect = ValueError("Unsupported dialect") + + request = APIRequestFactory().get("importer/sql_type_mapping/") + request.user = MagicMock(username="test_user") + request.query_params = {"sql_dialect": "hive"} + + response = api.get_sql_type_mapping(request) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.data == {"error": "Unsupported dialect"} + + @patch("desktop.lib.importer.api.SqlTypeMapperSerializer") + @patch("desktop.lib.importer.api.operations.get_sql_type_mapping") + def test_get_sql_type_mapping_operation_error(self, mock_get_sql_type_mapping, mock_serializer_class): + mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data={"sql_dialect": "hive"}) + mock_serializer_class.return_value = mock_serializer + + mock_get_sql_type_mapping.side_effect = RuntimeError("Operation error") + + request = APIRequestFactory().get("importer/sql_type_mapping/") + request.user = MagicMock(username="test_user") + request.query_params = {"sql_dialect": "hive"} + + response = api.get_sql_type_mapping(request) + + assert response.status_code == 
status.HTTP_500_INTERNAL_SERVER_ERROR + assert response.data == {"error": "Operation error"} diff --git a/desktop/core/src/desktop/lib/importer/operations.py b/desktop/core/src/desktop/lib/importer/operations.py new file mode 100644 index 00000000000..0ecf2de6e3b --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/operations.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import csv +import logging +import os +import tempfile +import uuid +import xml.etree.ElementTree as ET +import zipfile +from io import BytesIO +from typing import Any, BinaryIO, Dict, List, Optional, Union + +import polars as pl + +LOG = logging.getLogger() + +try: + # Import this at the module level to avoid checking availability in each _detect_file_type method call + import magic + + is_magic_lib_available = True +except ImportError as e: + LOG.exception(f"Failed to import python-magic: {str(e)}") + is_magic_lib_available = False + + +# Base mapping for most SQL engines (Hive, Impala, SparkSQL) +SQL_TYPE_BASE_MAP = { + # signed ints + "Int8": "TINYINT", + "Int16": "SMALLINT", + "Int32": "INT", + "Int64": "BIGINT", + # unsigned ints: same size signed by default (Hive/Impala/SparkSQL) + "UInt8": "TINYINT", + "UInt16": "SMALLINT", + "UInt32": "INT", + "UInt64": "BIGINT", + # floats & decimal + "Float32": "FLOAT", + "Float64": "DOUBLE", + "Decimal": "DECIMAL", # Hive/Impala/SparkSQL use DECIMAL(precision,scale) + # boolean, string, binary + "Boolean": "BOOLEAN", + "Utf8": "STRING", # STRING covers Hive/VARCHAR/CHAR for default + "String": "STRING", + "Categorical": "STRING", + "Enum": "STRING", + "Binary": "BINARY", + # temporal + "Date": "DATE", + "Time": "TIMESTAMP", # Hive/Impala/SparkSQL have no pure TIME type + "Datetime": "TIMESTAMP", + "Duration": "INTERVAL DAY TO SECOND", + # nested & other + "Array": "ARRAY", + "List": "ARRAY", + "Struct": "STRUCT", + "Object": "STRING", + "Null": "STRING", # no SQL NULL type—use STRING or handle as special case + "Unknown": "STRING", +} + +# Per‑dialect overrides for the few differences +SQL_TYPE_DIALECT_OVERRIDES = { + "hive": {}, + "impala": {}, + "sparksql": {}, + "trino": { + "Int32": "INTEGER", + "UInt32": "INTEGER", + "Utf8": "VARCHAR", + "String": "VARCHAR", + "Binary": "VARBINARY", + "Float32": "REAL", + "Struct": "ROW", + "Object": "JSON", + "Duration": "INTERVAL DAY TO SECOND", # explicit SQL syntax + }, + "phoenix": { + **{f"UInt{b}": f"UNSIGNED_{t}" for b, t in [(8, "TINYINT"), (16, "SMALLINT"), (32, "INT"), (64, "LONG")]}, + "Utf8": "VARCHAR", + "String": "VARCHAR", + "Binary": "VARBINARY", + "Duration": "STRING", # Phoenix treats durations as strings + "Struct": "STRING", # no native STRUCT type + "Object": "VARCHAR", + "Time": "TIME", # Phoenix has its own TIME type + "Decimal": 
"DECIMAL", # up to precision 38 + }, +} + + +def local_file_upload(upload_file, username: str) -> Dict[str, str]: + """Uploads a local file to a temporary directory with a unique filename. + + This function takes an uploaded file and username, generates a unique filename, + and saves the file to a temporary directory. The filename is created using + the username, a unique ID, and a sanitized version of the original filename. + + Args: + upload_file: The uploaded file object from Django's file upload handling. + username: The username of the user uploading the file. + + Returns: + Dict[str, str]: A dictionary containing: + - file_path: The full path where the file was saved + + Raises: + ValueError: If upload_file or username is None/empty + Exception: If there are issues with file operations + + Example: + >>> result = upload_local_file(request.FILES["file"], "hue_user") + >>> print(result) + {'file_path': '/tmp/hue_user_a1b2c3d4_myfile.txt'} + """ + if not upload_file: + raise ValueError("Upload file cannot be None or empty.") + + if not username: + raise ValueError("Username cannot be None or empty.") + + # Generate a unique filename + unique_id = uuid.uuid4().hex[:8] + filename = f"{username}_{unique_id}_{upload_file.name}" + + # Create a temporary file with our generated filename + temp_dir = tempfile.gettempdir() + destination_path = os.path.join(temp_dir, filename) + + try: + # Simply write the file content to temporary location + with open(destination_path, "wb") as destination: + for chunk in upload_file.chunks(): + destination.write(chunk) + + return {"file_path": destination_path} + + except Exception as e: + if os.path.exists(destination_path): + LOG.debug(f"Error during local file upload, cleaning up temporary file: {destination_path}") + os.remove(destination_path) + raise e + + +def guess_file_metadata(file_path: str, import_type: str, fs=None) -> Dict[str, Any]: + """Guess the metadata of a file based on its content or extension. 
+ + Args: + file_path: Path to the file to analyze + import_type: Type of import ('local' or 'remote') + fs: File system object for remote files (default: None) + + Returns: + Dict containing the file metadata: + - type: File type (e.g., excel, csv, tsv) + - sheet_names: List of sheet names (for Excel files) + - field_separator: Field separator character (for delimited files) + - quote_char: Quote character (for delimited files) + - record_separator: Record separator (for delimited files) + + Raises: + ValueError: If the file does not exist or parameters are invalid + Exception: For various file processing errors + """ + if not file_path: + raise ValueError("File path cannot be empty") + + if import_type not in ["local", "remote"]: + raise ValueError(f"Unsupported import type: {import_type}") + + if import_type == "remote" and fs is None: + raise ValueError("File system object is required for remote import type") + + # Check if file exists based on import type + if import_type == "local" and not os.path.exists(file_path): + raise ValueError(f"Local file does not exist: {file_path}") + elif import_type == "remote" and fs and not fs.exists(file_path): + raise ValueError(f"Remote file does not exist: {file_path}") + + error_occurred = False + fh = open(file_path, "rb") if import_type == "local" else fs.open(file_path, "rb") + + try: + sample = fh.read(16 * 1024) # Read 16 KiB sample + + if not sample: + raise ValueError("File is empty, cannot detect file format.") + + file_type = _detect_file_type(sample) + + if file_type == "unknown": + raise ValueError("Unable to detect file format.") + + if file_type == "excel": + metadata = _get_excel_metadata(fh) + else: + # CSV, TSV, or other delimited formats + metadata = _get_delimited_metadata(sample, file_type) + + return metadata + + except Exception as e: + error_occurred = True + LOG.exception(f"Error guessing file metadata: {e}", exc_info=True) + raise e + + finally: + fh.close() + if import_type == "local" and error_occurred and os.path.exists(file_path): + LOG.debug(f"Due to error in guess_file_metadata, cleaning up uploaded local file: {file_path}") + os.remove(file_path) + + +def preview_file( + file_path: str, + file_type: str, + import_type: str, + sql_dialect: str, + has_header: bool = False, + sheet_name: Optional[str] = None, + field_separator: Optional[str] = ",", + quote_char: Optional[str] = '"', + record_separator: Optional[str] = "\n", + fs=None, + preview_rows: int = 50, +) -> Dict[str, Any]: + """Generate a preview of a file's content with column type mapping. + + This method reads a file and returns a preview of its contents, along with + column information and metadata for creating tables or further processing. + + Args: + file_path: Path to the file to preview + file_type: Type of file ('excel', 'csv', 'tsv', 'delimiter_format') + import_type: Type of import ('local' or 'remote') + sql_dialect: SQL dialect for type mapping ('hive', 'impala', etc.) 
+ has_header: Whether the file has a header row or not + sheet_name: Sheet name for Excel files (required for Excel) + field_separator: Field separator character for delimited files + quote_char: Quote character for delimited files + record_separator: Record separator for delimited files + fs: File system object for remote files (default: None) + preview_rows: Number of rows to include in preview (default: 50) + + Returns: + Dict containing: + - type: File type + - columns: List of column metadata (name, type) + - preview_data: Preview of the file data + + Raises: + ValueError: If the file does not exist or parameters are invalid + Exception: For various file processing errors + """ + if not file_path: + raise ValueError("File path cannot be empty") + + if sql_dialect.lower() not in ["hive", "impala", "trino", "phoenix", "sparksql"]: + raise ValueError(f"Unsupported SQL dialect: {sql_dialect}") + + if file_type not in ["excel", "csv", "tsv", "delimiter_format"]: + raise ValueError(f"Unsupported file type: {file_type}") + + if import_type not in ["local", "remote"]: + raise ValueError(f"Unsupported import type: {import_type}") + + if import_type == "remote" and fs is None: + raise ValueError("File system object is required for remote import type") + + # Check if file exists based on import type + if import_type == "local" and not os.path.exists(file_path): + raise ValueError(f"Local file does not exist: {file_path}") + elif import_type == "remote" and fs and not fs.exists(file_path): + raise ValueError(f"Remote file does not exist: {file_path}") + + error_occurred = False + fh = open(file_path, "rb") if import_type == "local" else fs.open(file_path, "rb") + + try: + if file_type == "excel": + if not sheet_name: + raise ValueError("Sheet name is required for Excel files.") + + preview = _preview_excel_file(fh, file_type, sheet_name, sql_dialect, has_header, preview_rows) + elif file_type in ["csv", "tsv", "delimiter_format"]: + # Process escapable characters + try: + if field_separator: + field_separator = codecs.decode(field_separator, "unicode_escape") + if quote_char: + quote_char = codecs.decode(quote_char, "unicode_escape") + if record_separator: + record_separator = codecs.decode(record_separator, "unicode_escape") + + except Exception as e: + LOG.exception(f"Error decoding escape characters: {e}", exc_info=True) + raise ValueError("Invalid escape characters in field_separator, quote_char, or record_separator.") + + preview = _preview_delimited_file(fh, file_type, field_separator, quote_char, record_separator, sql_dialect, has_header, preview_rows) + else: + raise ValueError(f"Unsupported file type: {file_type}") + + return preview + except Exception as e: + error_occurred = True + LOG.exception(f"Error previewing file: {e}", exc_info=True) + raise e + + finally: + fh.close() + if import_type == "local" and error_occurred and os.path.exists(file_path): + LOG.debug(f"Due to error in preview_file, cleaning up uploaded local file: {file_path}") + os.remove(file_path) + + +def _detect_file_type(file_sample: bytes) -> str: + """Detect the file type based on its content. + + Args: + file_sample: Binary sample of the file content + + Returns: + String indicating the detected file type ('excel', 'delimiter_format', or 'unknown') + + Raises: + RuntimeError: If python-magic or libmagic is not available + Exception: If an error occurs during file type detection + """ + # Check if magic library is available + if not is_magic_lib_available: + error = "Unable to guess file type. 
python-magic or its dependency libmagic is not installed." + LOG.error(error) + raise RuntimeError(error) + + try: + # Use libmagic to detect MIME type from content + file_type = magic.from_buffer(file_sample, mime=True) + + # Map MIME type to the simplified type categories + if any(keyword in file_type for keyword in ["excel", "spreadsheet", "officedocument.sheet"]): + return "excel" + elif any(keyword in file_type for keyword in ["text", "csv", "plain"]): + # For text files, analyze the content later to determine specific format + return "delimiter_format" + + LOG.info(f"Detected MIME type: {file_type}, but not recognized as supported format") + return "unknown" + except Exception as e: + message = f"Error detecting file type: {e}" + LOG.exception(message, exc_info=True) + + raise Exception(message) + + +def _get_excel_metadata(fh: BinaryIO) -> Dict[str, Any]: + """Extract metadata for Excel files (.xlsx, .xls). + + Args: + fh: File handle for the Excel file + + Returns: + Dict containing Excel metadata: + - type: 'excel' + - sheet_names: List of sheet names + + Raises: + Exception: If there's an error processing the Excel file + """ + try: + fh.seek(0) + try: + sheet_names = _get_sheet_names_xlsx(BytesIO(fh.read())) + except Exception: + LOG.warning("Failed to read Excel file for sheet names with Zip + XML parsing approach, trying next with Polars.") + + # Possibly some other format instead of .xlsx + fh.seek(0) + + sheet_names = pl.read_excel( + BytesIO(fh.read()), sheet_id=0, infer_schema_length=10000, read_options={"n_rows": 0} + ).keys() # No need to read rows for detecting sheet names + + return { + "type": "excel", + "sheet_names": list(sheet_names), + } + except Exception as e: + message = f"Error extracting Excel file metadata: {e}" + LOG.error(message, exc_info=True) + + raise Exception(message) + + +def _get_sheet_names_xlsx(fh: BinaryIO) -> List[str]: + """Quickly list sheet names from an .xlsx file handle. + + - Uses only the stdlib (zipfile + xml.etree). + - Parses only the small `xl/workbook.xml` metadata (~10-20KB). + - No full worksheet data is loaded into memory. + + Args: + fh: Binary file-like object of the workbook. + + Returns: + A list of worksheet names. + """ + with zipfile.ZipFile(fh, "r") as z, z.open("xl/workbook.xml") as f: + tree = ET.parse(f) + + # XML namespace for SpreadsheetML + ns = {"x": "http://schemas.openxmlformats.org/spreadsheetml/2006/main"} + sheets = tree.getroot().find("x:sheets", ns) + + return [s.get("name") for s in sheets] + + +def _get_delimited_metadata(file_sample: Union[bytes, str], file_type: str) -> Dict[str, Any]: + """Extract metadata for delimited files (CSV, TSV, etc.). + + Args: + file_sample: Binary or string sample of the file content + file_type: Initial file type detection ('delimiter_format') + + Returns: + Dict containing delimited file metadata: + - type: Specific file type ('csv', 'tsv', etc.) 
+ - field_separator: Field separator character + - quote_char: Quote character + - record_separator: Record separator character + + Raises: + Exception: If there's an error processing the delimited file + """ + if isinstance(file_sample, bytes): + file_sample = file_sample.decode("utf-8", errors="replace") + + # Use CSV Sniffer to detect delimiter and other formatting + try: + dialect = csv.Sniffer().sniff(file_sample) + except Exception as sniff_error: + message = f"Failed to sniff delimited file: {sniff_error}" + LOG.exception(message) + + raise Exception(message) + + # Refine file type based on detected delimiter + if file_type == "delimiter_format": + if dialect.delimiter == ",": + file_type = "csv" + elif dialect.delimiter == "\t": + file_type = "tsv" + # Other delimiters remain as 'delimiter_format' + + return { + "type": file_type, + "field_separator": dialect.delimiter, + "quote_char": dialect.quotechar, + "record_separator": dialect.lineterminator, + } + + +def _preview_excel_file( + fh: BinaryIO, file_type: str, sheet_name: str, dialect: str, has_header: bool, preview_rows: int = 50 +) -> Dict[str, Any]: + """Preview an Excel file (.xlsx, .xls) + + Args: + fh: File handle for the Excel file + file_type: Type of file ('excel') + sheet_name: Name of the sheet to preview + dialect: SQL dialect for type mapping + has_header: Whether the file has a header row or not + preview_rows: Number of rows to include in preview (default: 50) + + Returns: + Dict containing: + - type: 'excel' + - columns: List of column metadata + - preview_data: Preview of file data + + Raises: + Exception: If there's an error processing the Excel file + """ + try: + fh.seek(0) + + df = pl.read_excel( + BytesIO(fh.read()), sheet_name=sheet_name, has_header=has_header, read_options={"n_rows": preview_rows}, infer_schema_length=10000 + ) + + # Return empty result if the df is empty + if df.height == 0: + return {"type": file_type, "columns": [], "preview_data": {}} + + schema = df.schema + preview_data = df.to_dict(as_series=False) + + # Create column metadata with SQL type mapping + columns = [] + for col in df.columns: + col_type = str(schema[col]) + sql_type = _map_polars_dtype_to_sql_type(dialect, col_type) + + columns.append({"name": col, "type": sql_type}) + + result = {"type": file_type, "columns": columns, "preview_data": preview_data} + + return result + + except Exception as e: + message = f"Error previewing Excel file: {e}" + LOG.error(message, exc_info=True) + + raise Exception(message) + + +def _preview_delimited_file( + fh: BinaryIO, + file_type: str, + field_separator: str, + quote_char: str, + record_separator: str, + dialect: str, + has_header: bool, + preview_rows: int = 50, +) -> Dict[str, Any]: + """Preview a delimited file (CSV, TSV, etc.) 
+ + Args: + fh: File handle for the delimited file + file_type: Type of file ('csv', 'tsv', 'delimiter_format') + field_separator: Field separator character + quote_char: Quote character + record_separator: Record separator character + dialect: SQL dialect for type mapping + has_header: Whether the file has a header row or not + preview_rows: Number of rows to include in preview (default: 50) + + Returns: + Dict containing: + - type: File type + - columns: List of column metadata + - preview_data: Preview of file data + + Raises: + Exception: If there's an error processing the delimited file + """ + try: + fh.seek(0) + + df = pl.read_csv( + BytesIO(fh.read()), + separator=field_separator, + quote_char=quote_char, + eol_char="\n" if record_separator == "\r\n" else record_separator, + has_header=has_header, + infer_schema_length=10000, + n_rows=preview_rows, + ignore_errors=True, + ) + + # Return empty result if the df is empty + if df.height == 0: + return {"type": file_type, "columns": [], "preview_data": {}} + + schema = df.schema + preview_data = df.to_dict(as_series=False) + + # Create detailed column metadata with SQL type mapping + columns = [] + for col in df.columns: + col_type = str(schema[col]) + sql_type = _map_polars_dtype_to_sql_type(dialect, col_type) + + columns.append({"name": col, "type": sql_type}) + + result = {"type": file_type, "columns": columns, "preview_data": preview_data} + + return result + + except Exception as e: + message = f"Error previewing delimited file: {e}" + + LOG.error(message, exc_info=True) + raise Exception(message) + + +def guess_file_header(file_path: str, file_type: str, import_type: str, sheet_name: Optional[str] = None, fs=None) -> bool: + """Guess whether a file has a header row. + + This function analyzes a file to determine if it contains a header row based on the + content pattern. It works for both Excel files and delimited text files (CSV, TSV, etc.). 
+ + Args: + file_path: Path to the file to analyze + file_type: Type of file ('excel', 'csv', 'tsv', 'delimiter_format') + import_type: Type of import ('local' or 'remote') + sheet_name: Sheet name for Excel files (required for Excel) + fs: File system object for remote files (default: None) + + Returns: + has_header: Boolean indicating whether the file has a header row + + Raises: + ValueError: If the file does not exist or parameters are invalid + Exception: For various file processing errors + """ + if not file_path: + raise ValueError("File path cannot be empty") + + if file_type not in ["excel", "csv", "tsv", "delimiter_format"]: + raise ValueError(f"Unsupported file type: {file_type}") + + if import_type not in ["local", "remote"]: + raise ValueError(f"Unsupported import type: {import_type}") + + if import_type == "remote" and fs is None: + raise ValueError("File system object is required for remote import type") + + # Check if file exists based on import type + if import_type == "local" and not os.path.exists(file_path): + raise ValueError(f"Local file does not exist: {file_path}") + elif import_type == "remote" and fs and not fs.exists(file_path): + raise ValueError(f"Remote file does not exist: {file_path}") + + fh = open(file_path, "rb") if import_type == "local" else fs.open(file_path, "rb") + + has_header = False + + try: + if file_type == "excel": + if not sheet_name: + raise ValueError("Sheet name is required for Excel files.") + + # Convert excel sample to CSV for header detection + try: + fh.seek(0) + + csv_snippet = pl.read_excel( + source=BytesIO(fh.read()), sheet_name=sheet_name, infer_schema_length=10000, read_options={"n_rows": 20} + ).write_csv(file=None) + + if isinstance(csv_snippet, bytes): + csv_snippet = csv_snippet.decode("utf-8", errors="replace") + + has_header = csv.Sniffer().has_header(csv_snippet) + LOG.info(f"Detected header for Excel file: {has_header}") + + except Exception as e: + message = f"Error detecting header in Excel file: {e}" + LOG.exception(message, exc_info=True) + + raise Exception(message) + + elif file_type in ["csv", "tsv", "delimiter_format"]: + try: + # Reset file position + fh.seek(0) + + # Read 16 KiB sample + sample = fh.read(16 * 1024).decode("utf-8", errors="replace") + + has_header = csv.Sniffer().has_header(sample) + LOG.info(f"Detected header for delimited file: {has_header}") + + except Exception as e: + message = f"Error detecting header in delimited file: {e}" + LOG.exception(message, exc_info=True) + + raise Exception(message) + + return has_header + + finally: + fh.close() + + +def get_sql_type_mapping(dialect: str) -> Dict[str, str]: + """Get all type mappings from Polars dtypes to SQL types for a given SQL dialect. + + This function returns a dictionary mapping of all Polars data types to their + corresponding SQL types for a specific dialect. + + Args: + dialect: One of "hive", "impala", "trino", "phoenix", "sparksql". + + Returns: + A dict mapping Polars dtype names to SQL type names. + + Raises: + ValueError: If the dialect is not supported. + """ + dl = dialect.lower() + if dl not in SQL_TYPE_DIALECT_OVERRIDES: + raise ValueError(f"Unsupported dialect: {dialect}") + + # Merge base_map and overrides[dl] into a new dict, giving precedence to any overlapping keys in overrides[dl] + return {**SQL_TYPE_BASE_MAP, **SQL_TYPE_DIALECT_OVERRIDES[dl]} + + +def _map_polars_dtype_to_sql_type(dialect: str, polars_type: str) -> str: + """Map a Polars dtype to the corresponding SQL type for a given dialect. 
+ + Args: + dialect: One of "hive", "impala", "trino", "phoenix", "sparksql". + polars_type: Polars dtype name as string. + + Returns: + A string representing the SQL type. + + Raises: + ValueError: If the dialect or polars_type is not supported. + """ + mapping = get_sql_type_mapping(dialect) + + if polars_type not in mapping: + raise ValueError(f"No mapping for Polars dtype {polars_type} in dialect {dialect}") + + return mapping[polars_type] diff --git a/desktop/core/src/desktop/lib/importer/operations_tests.py b/desktop/core/src/desktop/lib/importer/operations_tests.py new file mode 100644 index 00000000000..49ade93168d --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/operations_tests.py @@ -0,0 +1,622 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import zipfile +from unittest.mock import MagicMock, mock_open, patch + +import pytest +from django.core.files.uploadedfile import SimpleUploadedFile + +from desktop.lib.importer import operations + + +class TestLocalFileUpload: + @patch("uuid.uuid4") + def test_local_file_upload_success(self, mock_uuid): + # Mock uuid to get a predictable filename + mock_uuid.return_value.hex = "12345678" + + test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv") + + result = operations.local_file_upload(test_file, "test_user") + + # Get the expected file path + temp_dir = tempfile.gettempdir() + expected_path = os.path.join(temp_dir, "test_user_12345678_test_file.csv") + + try: + assert "file_path" in result + assert result["file_path"] == expected_path + + # Verify the file was created and has the right content + assert os.path.exists(expected_path) + with open(expected_path, "rb") as f: + assert f.read() == b"header1,header2\nvalue1,value2" + + finally: + # Clean up the file + if os.path.exists(expected_path): + os.remove(expected_path) + + assert not os.path.exists(expected_path), "Temporary file was not cleaned up properly" + + def test_local_file_upload_none_file(self): + with pytest.raises(ValueError, match="Upload file cannot be None or empty."): + operations.local_file_upload(None, "test_user") + + def test_local_file_upload_none_username(self): + test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv") + + with pytest.raises(ValueError, match="Username cannot be None or empty."): + operations.local_file_upload(test_file, None) + + @patch("os.path.join") + @patch("builtins.open", new_callable=mock_open) + def test_local_file_upload_exception_handling(self, mock_file_open, mock_join): + # Setup mocks to raise an exception when opening the file + mock_file_open.side_effect = IOError("Test IO Error") + mock_join.return_value = 
"/tmp/test_user_12345678_test_file.csv" + + test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv") + + with pytest.raises(Exception, match="Test IO Error"): + operations.local_file_upload(test_file, "test_user") + + +@pytest.mark.usefixtures("cleanup_temp_files") +class TestGuessFileMetadata: + @pytest.fixture + def cleanup_temp_files(self): + """Fixture to clean up temporary files after tests.""" + temp_files = [] + + yield temp_files + + # Clean up after test + for file_path in temp_files: + if os.path.exists(file_path): + os.remove(file_path) + + @patch("desktop.lib.importer.operations.is_magic_lib_available", True) + @patch("desktop.lib.importer.operations.magic") + def test_guess_file_metadata_csv(self, mock_magic, cleanup_temp_files): + # Create a temporary CSV file + test_content = "col1,col2,col3\nval1,val2,val3\nval4,val5,val6" + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + # Mock magic.from_buffer to return text/csv MIME type + mock_magic.from_buffer.return_value = "text/plain" + + result = operations.guess_file_metadata(temp_file.name, "local") + + assert result == { + "type": "csv", + "field_separator": ",", + "quote_char": '"', + "record_separator": "\r\n", + } + + @patch("desktop.lib.importer.operations.is_magic_lib_available", True) + @patch("desktop.lib.importer.operations.magic") + def test_guess_file_metadata_tsv(self, mock_magic, cleanup_temp_files): + # Create a temporary TSV file + test_content = "col1\tcol2\tcol3\nval1\tval2\tval3\nval4\tval5\tval6" + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + # Mock magic.from_buffer to return text/plain MIME type + mock_magic.from_buffer.return_value = "text/plain" + + result = operations.guess_file_metadata(temp_file.name, "local") + + assert result == { + "type": "tsv", + "field_separator": "\t", + "quote_char": '"', + "record_separator": "\r\n", + } + + @patch("desktop.lib.importer.operations.is_magic_lib_available", True) + @patch("desktop.lib.importer.operations.magic") + @patch("desktop.lib.importer.operations._get_sheet_names_xlsx") + def test_guess_file_metadata_excel(self, mock_get_sheet_names, mock_magic, cleanup_temp_files): + # Create a simple .xlsx file + test_content = """ + + + + + + + """ + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + # Mock magic.from_buffer to return Excel MIME type + mock_magic.from_buffer.return_value = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + + # Mock _get_sheet_names_xlsx to return sheet names + mock_get_sheet_names.return_value = ["Sheet1", "Sheet2", "Sheet3"] + + result = operations.guess_file_metadata(temp_file.name, "local") + + assert result == { + "type": "excel", + "sheet_names": ["Sheet1", "Sheet2", "Sheet3"], + } + + @patch("desktop.lib.importer.operations.is_magic_lib_available", True) + @patch("desktop.lib.importer.operations.magic") + def test_guess_file_metadata_unsupported_type(self, mock_magic, cleanup_temp_files): + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(b"Binary content") + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + # Mock 
magic.from_buffer to return an unsupported MIME type + mock_magic.from_buffer.return_value = "application/octet-stream" + + with pytest.raises(ValueError, match="Unable to detect file format."): + operations.guess_file_metadata(temp_file.name, "local") + + def test_guess_file_metadata_nonexistent_file(self): + file_path = "/path/to/nonexistent/file.csv" + + with pytest.raises(ValueError, match="Local file does not exist."): + operations.guess_file_metadata(file_path, "local") + + def test_guess_remote_file_metadata_no_fs(self): + with pytest.raises(ValueError, match="File system object is required for remote import type"): + operations.guess_file_metadata( + file_path="s3a://bucket/user/test_user/test.csv", # Remote file path + import_type="remote", # Remote file but no fs provided + ) + + def test_guess_file_metadata_empty_file(self, cleanup_temp_files): + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + with pytest.raises(ValueError, match="File is empty, cannot detect file format."): + operations.guess_file_metadata(temp_file.name, "local") + + @patch("desktop.lib.importer.operations.is_magic_lib_available", False) + def test_guess_file_metadata_no_magic_lib(self, cleanup_temp_files): + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(b"Content") + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + with pytest.raises(RuntimeError, match="Unable to guess file type. python-magic or its dependency libmagic is not installed."): + operations.guess_file_metadata(temp_file.name, "local") + + +@pytest.mark.usefixtures("cleanup_temp_files") +class TestPreviewFile: + @pytest.fixture + def cleanup_temp_files(self): + """Fixture to clean up temporary files after tests.""" + temp_files = [] + + yield temp_files + + # Clean up after test + for file_path in temp_files: + if os.path.exists(file_path): + os.remove(file_path) + + @patch("desktop.lib.importer.operations.pl") + def test_preview_excel_file(self, mock_pl, cleanup_temp_files): + # Minimal Excel file content (not a real Excel binary, just placeholder for test) + test_content = """ + + + + + """ + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + mock_df = MagicMock( + height=2, + columns=["col1", "col2"], + schema={"col1": "Int32", "col2": "Utf8"}, + to_dict=MagicMock(return_value={"col1": [1, 2], "col2": ["foo", "bar"]}), + ) + + mock_pl.read_excel.return_value = mock_df + + result = operations.preview_file( + file_path=temp_file.name, file_type="excel", import_type="local", sql_dialect="hive", has_header=True, sheet_name="Sheet1" + ) + + assert result == { + "type": "excel", + "columns": [ + {"name": "col1", "type": "INT"}, + {"name": "col2", "type": "STRING"}, + ], + "preview_data": {"col1": [1, 2], "col2": ["foo", "bar"]}, + } + + def test_preview_csv_file(self, cleanup_temp_files): + # Create a temporary CSV file + test_content = "col1,col2\n1.1,true\n2.2,false" + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + result = operations.preview_file( + file_path=temp_file.name, + file_type="csv", + import_type="local", + sql_dialect="hive", + has_header=True, + field_separator=",", + quote_char='"', + record_separator="\n", + ) + + assert result == { + "type": "csv", + 
"columns": [ + {"name": "col1", "type": "DOUBLE"}, + {"name": "col2", "type": "BOOLEAN"}, + ], + "preview_data": {"col1": [1.1, 2.2], "col2": [True, False]}, + } + + def test_preview_csv_file_with_header_false(self, cleanup_temp_files): + # Create a temporary CSV file + test_content = "sample1,sample2\n1.1,true\n2.2,false" + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + result = operations.preview_file( + file_path=temp_file.name, + file_type="csv", + import_type="local", + sql_dialect="hive", + has_header=False, # No header in this case + field_separator=",", + quote_char='"', + record_separator="\n", + ) + + assert result == { + "type": "csv", + "columns": [{"name": "column_1", "type": "STRING"}, {"name": "column_2", "type": "STRING"}], + "preview_data": {"column_1": ["sample1", "1.1", "2.2"], "column_2": ["sample2", "true", "false"]}, + } + + def test_preview_empty_csv_file(self, cleanup_temp_files): + # Create a temporary CSV file + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(b" ") + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + result = operations.preview_file( + file_path=temp_file.name, + file_type="csv", + import_type="local", + sql_dialect="hive", + has_header=True, + field_separator=",", + quote_char='"', + record_separator="\n", + ) + + assert result == { + "type": "csv", + "columns": [], + "preview_data": {}, + } + + def test_preview_invalid_file_path(self): + with pytest.raises(ValueError, match="File path cannot be empty"): + operations.preview_file(file_path="", file_type="csv", import_type="local", sql_dialect="hive", has_header=True) + + def test_preview_unsupported_file_type(self): + with pytest.raises(ValueError, match="Unsupported file type: json"): + operations.preview_file( + file_path="/path/to/test.json", + file_type="json", # Unsupported type + import_type="local", + sql_dialect="hive", + has_header=True, + ) + + def test_preview_unsupported_sql_dialect(self): + with pytest.raises(ValueError, match="Unsupported SQL dialect: mysql"): + operations.preview_file( + file_path="/path/to/test.csv", + file_type="csv", + import_type="local", + sql_dialect="mysql", # Unsupported dialect + has_header=True, + ) + + def test_preview_remote_file_no_fs(self): + with pytest.raises(ValueError, match="File system object is required for remote import type"): + operations.preview_file( + file_path="s3a://bucket/user/test_user/test.csv", # Remote file path + file_type="csv", + import_type="remote", # Remote file but no fs provided + sql_dialect="hive", + has_header=True, + ) + + @patch("os.path.exists") + def test_preview_nonexistent_local_file(self, mock_exists): + mock_exists.return_value = False + + with pytest.raises(ValueError, match="Local file does not exist: /path/to/nonexistent.csv"): + operations.preview_file( + file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local", sql_dialect="hive", has_header=True + ) + + def test_preview_trino_dialect_type_mapping(self, cleanup_temp_files): + test_content = "string_col\nfoo\nbar" + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + result = operations.preview_file( + file_path=temp_file.name, + file_type="csv", + import_type="local", + sql_dialect="trino", # Trino dialect for different type mapping + has_header=True, + 
field_separator=",", + ) + + # Check the result for Trino-specific type mapping + assert result["columns"][0]["type"] == "VARCHAR" # Not STRING + + +@pytest.mark.usefixtures("cleanup_temp_files") +class TestGuessFileHeader: + @pytest.fixture + def cleanup_temp_files(self): + """Fixture to clean up temporary files after tests.""" + temp_files = [] + + yield temp_files + + # Clean up after test + for file_path in temp_files: + if os.path.exists(file_path): + os.remove(file_path) + + def test_guess_header_csv(self, cleanup_temp_files): + test_content = "header1,header2\nvalue1,value2\nvalue3,value4" + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + result = operations.guess_file_header(file_path=temp_file.name, file_type="csv", import_type="local") + + assert result + + @patch("desktop.lib.importer.operations.pl") + @patch("desktop.lib.importer.operations.csv.Sniffer") + def test_guess_header_excel(self, mock_sniffer, mock_pl, cleanup_temp_files): + test_content = """ + + + + + """ + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + # Mock polars read_excel and CSV conversion + mock_pl.read_excel.return_value.write_csv.return_value = "header1,header2\nvalue1,value2\nvalue3,value4" + + # Mock csv.Sniffer + mock_sniffer_instance = MagicMock() + mock_sniffer_instance.has_header.return_value = True + mock_sniffer.return_value = mock_sniffer_instance + + result = operations.guess_file_header(file_path=temp_file.name, file_type="excel", import_type="local", sheet_name="Sheet1") + + assert result + + def test_guess_header_excel_no_sheet_name(self, cleanup_temp_files): + test_content = """ + + """ + + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") + temp_file.write(test_content.encode("utf-8")) + temp_file.close() + + cleanup_temp_files.append(temp_file.name) + + with pytest.raises(ValueError, match="Sheet name is required for Excel files"): + operations.guess_file_header( + file_path=temp_file.name, + file_type="excel", + import_type="local", + # Missing sheet_name + ) + + def test_guess_header_invalid_path(self): + with pytest.raises(ValueError, match="File path cannot be empty"): + operations.guess_file_header(file_path="", file_type="csv", import_type="local") + + def test_guess_header_unsupported_file_type(self): + with pytest.raises(ValueError, match="Unsupported file type: json"): + operations.guess_file_header( + file_path="/path/to/test.json", + file_type="json", # Unsupported type + import_type="local", + ) + + def test_guess_header_nonexistent_local_file(self): + with pytest.raises(ValueError, match="Local file does not exist"): + operations.guess_file_header(file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local") + + def test_guess_header_remote_file_no_fs(self): + with pytest.raises(ValueError, match="File system object is required for remote import type"): + operations.guess_file_header( + file_path="hdfs:///path/to/test.csv", + file_type="csv", + import_type="remote", # Remote but no fs provided + ) + + +class TestSqlTypeMapping: + def test_get_sql_type_mapping_hive(self): + mappings = operations.get_sql_type_mapping("hive") + + # Check some key mappings for Hive + assert mappings["Int32"] == "INT" + assert mappings["Utf8"] == "STRING" + assert mappings["Float64"] == "DOUBLE" + assert 
mappings["Boolean"] == "BOOLEAN" + assert mappings["Decimal"] == "DECIMAL" + + def test_get_sql_type_mapping_trino(self): + mappings = operations.get_sql_type_mapping("trino") + + # Check some key mappings for Trino that differ from Hive + assert mappings["Int32"] == "INTEGER" + assert mappings["Utf8"] == "VARCHAR" + assert mappings["Binary"] == "VARBINARY" + assert mappings["Float32"] == "REAL" + assert mappings["Struct"] == "ROW" + assert mappings["Object"] == "JSON" + + def test_get_sql_type_mapping_phoenix(self): + mappings = operations.get_sql_type_mapping("phoenix") + + # Check some key mappings for Phoenix + assert mappings["UInt32"] == "UNSIGNED_INT" + assert mappings["Utf8"] == "VARCHAR" + assert mappings["Time"] == "TIME" + assert mappings["Struct"] == "STRING" # Phoenix treats structs as strings + assert mappings["Duration"] == "STRING" # Phoenix treats durations as strings + + def test_get_sql_type_mapping_impala(self): + result = operations.get_sql_type_mapping("impala") + + # Impala uses the base mappings, so check those + assert result["Int32"] == "INT" + assert result["Int64"] == "BIGINT" + assert result["Float64"] == "DOUBLE" + assert result["Utf8"] == "STRING" + + def test_get_sql_type_mapping_sparksql(self): + result = operations.get_sql_type_mapping("sparksql") + + # SparkSQL uses the base mappings, so check those + assert result["Int32"] == "INT" + assert result["Int64"] == "BIGINT" + assert result["Float64"] == "DOUBLE" + assert result["Utf8"] == "STRING" + + def test_get_sql_type_mapping_unsupported_dialect(self): + with pytest.raises(ValueError, match="Unsupported dialect: mysql"): + operations.get_sql_type_mapping("mysql") + + def test_map_polars_dtype_to_sql_type(self): + # Test with Hive dialect + assert operations._map_polars_dtype_to_sql_type("hive", "Int64") == "BIGINT" + assert operations._map_polars_dtype_to_sql_type("hive", "Float32") == "FLOAT" + + # Test with Trino dialect + assert operations._map_polars_dtype_to_sql_type("trino", "Int64") == "BIGINT" + assert operations._map_polars_dtype_to_sql_type("trino", "Float32") == "REAL" + + # Test unsupported type + with pytest.raises(ValueError, match="No mapping for Polars dtype"): + operations._map_polars_dtype_to_sql_type("hive", "NonExistentType") + + +@pytest.mark.usefixtures("cleanup_temp_files") +class TestExcelSheetNames: + @pytest.fixture + def cleanup_temp_files(self): + """Fixture to clean up temporary files after tests.""" + temp_files = [] + + yield temp_files + + # Clean up after test + for file_path in temp_files: + if os.path.exists(file_path): + os.remove(file_path) + + def test_get_sheet_names_xlsx(self, cleanup_temp_files): + # Create a temporary Excel file with minimal XML structure + with tempfile.NamedTemporaryFile(suffix=".xlsx") as tmp_file: + # Create a minimal XLSX with workbook.xml + with zipfile.ZipFile(tmp_file.name, "w") as zip_file: + # Create a minimal workbook.xml file + workbook_xml = """ + + + + + + + """ + zip_file.writestr("xl/workbook.xml", workbook_xml) + + cleanup_temp_files.append(tmp_file.name) + + # Test the function + with open(tmp_file.name, "rb") as f: + sheet_names = operations._get_sheet_names_xlsx(f) + + assert sheet_names == ["Sheet1", "Sheet2", "CustomSheet"] diff --git a/desktop/core/src/desktop/lib/importer/serializers.py b/desktop/core/src/desktop/lib/importer/serializers.py new file mode 100644 index 00000000000..7f20efb9ee1 --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/serializers.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# Licensed to 
Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from rest_framework import serializers + +from desktop.conf import IMPORTER + + +class LocalFileUploadSerializer(serializers.Serializer): + """Serializer for file upload validation. + + This serializer validates that the uploaded file is present and has an + acceptable file format and size. + + Attributes: + file: File field that must be included in the request + """ + + file = serializers.FileField(required=True, help_text="CSV or Excel file to upload and process") + + def validate_file(self, value): + # Check if the file type is restricted + _, file_type = os.path.splitext(value.name) + restricted_extensions = IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.get() + if restricted_extensions and file_type.lower() in [ext.lower() for ext in restricted_extensions]: + raise serializers.ValidationError(f'Uploading files with type "{file_type}" is not allowed. Hue is configured to restrict this type.') + + # Check file size + max_size = IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.get() + if value.size > max_size: + max_size_mib = max_size / (1024 * 1024) + raise serializers.ValidationError(f"File too large. Maximum file size is {max_size_mib:.0f} MiB.") + + return value + + +class GuessFileMetadataSerializer(serializers.Serializer): + """Serializer for file metadata guessing request validation. + + This serializer validates the parameters required for guessing metadata from a file. + + Attributes: + file_path: Path to the file + import_type: Type of import (local or remote) + """ + + file_path = serializers.CharField(required=True, help_text="Full path to the file to analyze") + import_type = serializers.ChoiceField( + choices=["local", "remote"], required=True, help_text="Whether the file is local or on a remote filesystem" + ) + + +class PreviewFileSerializer(serializers.Serializer): + """Serializer for file preview request validation. + + This serializer validates the parameters required for previewing file content. 
+ + Attributes: + file_path: Path to the file to preview + file_type: Type of file format (csv, tsv, excel, delimiter_format) + import_type: Type of import (local or remote) + sql_dialect: Target SQL dialect for type mapping + has_header: Whether the file has a header row or not + sheet_name: Sheet name for Excel files (required when file_type is excel) + field_separator: Field separator character (required for delimited files) + quote_char: Quote character (required for delimited files) + record_separator: Record separator character (required for delimited files) + """ + + file_path = serializers.CharField(required=True, help_text="Full path to the file to preview") + file_type = serializers.ChoiceField( + choices=["csv", "tsv", "excel", "delimiter_format"], required=True, help_text="Type of file (csv, tsv, excel, delimiter_format)" + ) + import_type = serializers.ChoiceField( + choices=["local", "remote"], required=True, help_text="Whether the file is local or on a remote filesystem" + ) + sql_dialect = serializers.ChoiceField( + choices=["hive", "impala", "trino", "phoenix", "sparksql"], required=True, help_text="SQL dialect for mapping column types" + ) + + has_header = serializers.BooleanField(required=True, help_text="Whether the file has a header row or not") + + # Excel-specific fields + sheet_name = serializers.CharField(required=False, help_text="Sheet name for Excel files") + + # Delimited file-specific fields + field_separator = serializers.CharField(required=False, help_text="Field separator character") + quote_char = serializers.CharField(required=False, help_text="Quote character") + record_separator = serializers.CharField(required=False, help_text="Record separator character") + + def validate(self, data): + """Validate the complete data set with interdependent field validation.""" + + if data.get("file_type") == "excel" and not data.get("sheet_name"): + raise serializers.ValidationError({"sheet_name": "Sheet name is required for Excel files."}) + + if data.get("file_type") in ["csv", "tsv", "delimiter_format"]: + if not data.get("field_separator"): + # If not provided, set default value based on file type + if data.get("file_type") == "csv": + data["field_separator"] = "," + elif data.get("file_type") == "tsv": + data["field_separator"] = "\t" + else: + raise serializers.ValidationError({"field_separator": "Field separator is required for delimited files"}) + + if not data.get("quote_char"): + data["quote_char"] = '"' # Default quote character + + if not data.get("record_separator"): + data["record_separator"] = "\n" # Default record separator + + return data + + +class SqlTypeMapperSerializer(serializers.Serializer): + """Serializer for SQL type mapping requests. + + This serializer validates the parameters required for retrieving type mapping information + from Polars data types to SQL types for a specific dialect. + + Attributes: + sql_dialect: Target SQL dialect for type mapping + """ + + sql_dialect = serializers.ChoiceField( + choices=["hive", "impala", "trino", "phoenix", "sparksql"], required=True, help_text="SQL dialect for mapping column types" + ) + + +class GuessFileHeaderSerializer(serializers.Serializer): + """Serializer for file header guessing request validation. + + This serializer validates the parameters required for guessing if a file has a header row. 
+ + Attributes: + file_path: Path to the file to analyze + file_type: Type of file format (csv, tsv, excel, delimiter_format) + import_type: Type of import (local or remote) + sheet_name: Sheet name for Excel files (required when file_type is excel) + """ + + file_path = serializers.CharField(required=True, help_text="Full path to the file to analyze") + file_type = serializers.ChoiceField( + choices=["csv", "tsv", "excel", "delimiter_format"], required=True, help_text="Type of file (csv, tsv, excel, delimiter_format)" + ) + import_type = serializers.ChoiceField( + choices=["local", "remote"], required=True, help_text="Whether the file is local or on a remote filesystem" + ) + + # Excel-specific fields + sheet_name = serializers.CharField(required=False, help_text="Sheet name for Excel files") + + def validate(self, data): + """Validate the complete data set with interdependent field validation.""" + + if data.get("file_type") == "excel" and not data.get("sheet_name"): + raise serializers.ValidationError({"sheet_name": "Sheet name is required for Excel files."}) + + return data diff --git a/desktop/core/src/desktop/lib/importer/serializers_tests.py b/desktop/core/src/desktop/lib/importer/serializers_tests.py new file mode 100644 index 00000000000..529b8f1b7a1 --- /dev/null +++ b/desktop/core/src/desktop/lib/importer/serializers_tests.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python +# Licensed to Cloudera, Inc. under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. Cloudera, Inc. licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
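A minimal sketch, not part of the patch, of how the serializers defined above are typically driven: construct with a dict of request parameters, call is_valid(), then read validated_data, which PreviewFileSerializer.validate() tops up with per-file-type defaults. The helper name is hypothetical.

from desktop.lib.importer.serializers import PreviewFileSerializer

def validate_preview_params(query_params):
  # Hypothetical helper: returns (validated_data, errors) for a preview request
  serializer = PreviewFileSerializer(data=query_params)
  if not serializer.is_valid():
    # errors look like {"sheet_name": ["Sheet name is required for Excel files."]}
    return None, serializer.errors
  # For csv/tsv, validate() fills field_separator (',' or '\t'), quote_char ('"') and record_separator ('\n')
  return serializer.validated_data, None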
+ + +from django.core.files.uploadedfile import SimpleUploadedFile + +from desktop.conf import IMPORTER +from desktop.lib.importer.serializers import ( + GuessFileHeaderSerializer, + GuessFileMetadataSerializer, + LocalFileUploadSerializer, + PreviewFileSerializer, + SqlTypeMapperSerializer, +) + + +class TestLocalFileUploadSerializer: + def test_valid_file(self): + resets = [ + IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([]), + IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024), # 10 MiB limit + ] + try: + test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv") + + serializer = LocalFileUploadSerializer(data={"file": test_file}) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + assert serializer.validated_data["file"] == test_file + finally: + for reset in resets: + reset() + + def test_restricted_file_extension(self): + resets = [ + IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([".sh", ".csv", ".exe"]), + IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024), # 10 MiB limit + ] + try: + test_file = SimpleUploadedFile( + name="test_file.exe", content=b"This is not a real executable", content_type="application/octet-stream" + ) + + serializer = LocalFileUploadSerializer(data={"file": test_file}) + + assert not serializer.is_valid() + assert "file" in serializer.errors + assert serializer.errors["file"][0] == 'Uploading files with type ".exe" is not allowed. Hue is configured to restrict this type.' + finally: + for reset in resets: + reset() + + def test_file_size_too_large(self): + resets = [ + IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([]), + IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10), # Just 10 bytes + ] + try: + test_file = SimpleUploadedFile( + name="test_file.csv", content=b"This content is more than 10 bytes which exceeds our mock size limit", content_type="text/csv" + ) + + serializer = LocalFileUploadSerializer(data={"file": test_file}) + + assert not serializer.is_valid() + assert "file" in serializer.errors + assert serializer.errors["file"][0] == "File too large. Maximum file size is 0 MiB." # 10 bytes is far below 1 MiB, so the limit renders as 0 MiB + finally: + for reset in resets: + reset() + + def test_missing_file(self): + serializer = LocalFileUploadSerializer(data={}) + + assert not serializer.is_valid() + assert "file" in serializer.errors + assert serializer.errors["file"][0] == "No file was submitted." 
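A minimal sketch, not part of the patch, of how LocalFileUploadSerializer is meant to pair with operations.local_file_upload from earlier in this change; the view-style helper below is hypothetical and only illustrates the intended flow.

from desktop.lib.importer import operations
from desktop.lib.importer.serializers import LocalFileUploadSerializer

def handle_local_upload(request):
  # Hypothetical helper: validate the uploaded file, then stage it in a temp location
  serializer = LocalFileUploadSerializer(data=request.FILES)
  if not serializer.is_valid():
    return {"errors": serializer.errors}
  upload_file = serializer.validated_data["file"]
  # local_file_upload returns {"file_path": "/tmp/<username>_<uuid8>_<original name>"}
  return operations.local_file_upload(upload_file, request.user.username)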
+ + +class TestGuessFileMetadataSerializer: + def test_valid_data(self): + # Test with local import type + local_valid_data = {"file_path": "/path/to/file.csv", "import_type": "local"} + + serializer = GuessFileMetadataSerializer(data=local_valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + assert serializer.validated_data == local_valid_data + + # Test with remote import type + remote_valid_data = {"file_path": "s3a://bucket/user/test_user/file.csv", "import_type": "remote"} + + serializer = GuessFileMetadataSerializer(data=remote_valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + assert serializer.validated_data == remote_valid_data + + def test_missing_required_fields(self): + # Test missing file_path + invalid_data = {"import_type": "local"} + + serializer = GuessFileMetadataSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "file_path" in serializer.errors + + # Test missing import_type + invalid_data = {"file_path": "/path/to/file.csv"} + + serializer = GuessFileMetadataSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "import_type" in serializer.errors + + def test_invalid_import_type(self): + invalid_data = { + "file_path": "/path/to/file.csv", + "import_type": "invalid_type", # Not one of 'local' or 'remote' + } + + serializer = GuessFileMetadataSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "import_type" in serializer.errors + assert serializer.errors["import_type"][0] == '"invalid_type" is not a valid choice.' + + +class TestPreviewFileSerializer: + def test_valid_csv_data(self): + valid_data = { + "file_path": "/path/to/file.csv", + "file_type": "csv", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + "field_separator": ",", + } + + serializer = PreviewFileSerializer(data=valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + # Check that default values are set for quote_char and record_separator + assert serializer.validated_data["quote_char"] == '"' + assert serializer.validated_data["record_separator"] == "\n" + + def test_valid_excel_data(self): + valid_data = { + "file_path": "/path/to/file.xlsx", + "file_type": "excel", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + "sheet_name": "Sheet1", + } + + serializer = PreviewFileSerializer(data=valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + + def test_missing_required_fields(self): + # Test with minimal data + invalid_data = { + "file_path": "/path/to/file.csv", + } + + serializer = PreviewFileSerializer(data=invalid_data) + + assert not serializer.is_valid() + # Check that all required fields are reported as missing + for field in ["file_type", "import_type", "sql_dialect", "has_header"]: + assert field in serializer.errors + + def test_invalid_file_type(self): + invalid_data = { + "file_path": "/path/to/file.json", + "file_type": "json", # Not a valid choice + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + } + + serializer = PreviewFileSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "file_type" in serializer.errors + assert serializer.errors["file_type"][0] == '"json" is not a valid choice.' 
+ + def test_excel_without_sheet_name(self): + invalid_data = { + "file_path": "/path/to/file.xlsx", + "file_type": "excel", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + # Missing sheet_name + } + + serializer = PreviewFileSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "sheet_name" in serializer.errors + assert serializer.errors["sheet_name"][0] == "Sheet name is required for Excel files." + + def test_delimited_without_field_separator(self): + # For delimiter_format type (not csv/tsv) field separator is required + invalid_data = { + "file_path": "/path/to/file.txt", + "file_type": "delimiter_format", + "import_type": "local", + "sql_dialect": "hive", + "has_header": True, + # Missing field_separator + } + + serializer = PreviewFileSerializer(data=invalid_data) + assert not serializer.is_valid() + assert "field_separator" in serializer.errors + + def test_default_separators_by_file_type(self): + # For CSV, field_separator should default to ',' + csv_data = {"file_path": "/path/to/file.csv", "file_type": "csv", "import_type": "local", "sql_dialect": "hive", "has_header": True} + + serializer = PreviewFileSerializer(data=csv_data) + + assert serializer.is_valid(), f"CSV serializer validation failed: {serializer.errors}" + assert serializer.validated_data["field_separator"] == "," + + # For TSV, field_separator should default to '\t' + tsv_data = {"file_path": "/path/to/file.tsv", "file_type": "tsv", "import_type": "local", "sql_dialect": "hive", "has_header": True} + + serializer = PreviewFileSerializer(data=tsv_data) + + assert serializer.is_valid(), f"TSV serializer validation failed: {serializer.errors}" + assert serializer.validated_data["field_separator"] == "\t" + + +class TestSqlTypeMapperSerializer: + def test_valid_sql_dialect(self): + for dialect in ["hive", "impala", "trino", "phoenix", "sparksql"]: + valid_data = {"sql_dialect": dialect} + + serializer = SqlTypeMapperSerializer(data=valid_data) + + assert serializer.is_valid(), f"Failed for dialect '{dialect}': {serializer.errors}" + assert serializer.validated_data["sql_dialect"] == dialect + + def test_invalid_sql_dialect(self): + invalid_data = {"sql_dialect": "invalid_dialect"} + + serializer = SqlTypeMapperSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "sql_dialect" in serializer.errors + assert serializer.errors["sql_dialect"][0] == '"invalid_dialect" is not a valid choice.' + + def test_missing_sql_dialect(self): + invalid_data = {} # Empty data + + serializer = SqlTypeMapperSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "sql_dialect" in serializer.errors + assert serializer.errors["sql_dialect"][0] == "This field is required." 
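The dialect expectations asserted above follow directly from how operations.get_sql_type_mapping merges the base map with the per-dialect overrides (overriding keys win). A small illustrative sketch using values from the mapping tables earlier in this patch:

from desktop.lib.importer import operations

hive_types = operations.get_sql_type_mapping("hive")
trino_types = operations.get_sql_type_mapping("trino")

# The base mapping applies where no override exists; Trino's override replaces INT with INTEGER
assert hive_types["Int32"] == "INT"
assert trino_types["Int32"] == "INTEGER"
assert trino_types["Utf8"] == "VARCHAR"  # overrides the base "STRING"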
+ + +class TestGuessFileHeaderSerializer: + def test_valid_data_csv(self): + valid_data = {"file_path": "/path/to/file.csv", "file_type": "csv", "import_type": "local"} + + serializer = GuessFileHeaderSerializer(data=valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + assert serializer.validated_data == valid_data + + def test_valid_data_excel(self): + valid_data = {"file_path": "/path/to/file.xlsx", "file_type": "excel", "import_type": "local", "sheet_name": "Sheet1"} + + serializer = GuessFileHeaderSerializer(data=valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + assert serializer.validated_data == valid_data + + def test_missing_required_fields(self): + # Missing file_path + invalid_data = {"file_type": "csv", "import_type": "local"} + + serializer = GuessFileHeaderSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "file_path" in serializer.errors + + # Missing file_type + invalid_data = {"file_path": "/path/to/file.csv", "import_type": "local"} + + serializer = GuessFileHeaderSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "file_type" in serializer.errors + + # Missing import_type + invalid_data = {"file_path": "/path/to/file.csv", "file_type": "csv"} + + serializer = GuessFileHeaderSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "import_type" in serializer.errors + + def test_excel_without_sheet_name(self): + invalid_data = { + "file_path": "/path/to/file.xlsx", + "file_type": "excel", + "import_type": "local", + # Missing sheet_name + } + + serializer = GuessFileHeaderSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "sheet_name" in serializer.errors + assert serializer.errors["sheet_name"][0] == "Sheet name is required for Excel files." + + def test_non_excel_with_sheet_name(self): + # This should pass, as sheet_name is optional for non-Excel files + valid_data = { + "file_path": "/path/to/file.csv", + "file_type": "csv", + "import_type": "local", + "sheet_name": "Sheet1", # Unnecessary but not invalid + } + + serializer = GuessFileHeaderSerializer(data=valid_data) + + assert serializer.is_valid(), f"Serializer validation failed: {serializer.errors}" + + def test_invalid_file_type(self): + invalid_data = { + "file_path": "/path/to/file.json", + "file_type": "json", # Not a valid choice + "import_type": "local", + } + + serializer = GuessFileHeaderSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "file_type" in serializer.errors + assert serializer.errors["file_type"][0] == '"json" is not a valid choice.' + + def test_invalid_import_type(self): + invalid_data = { + "file_path": "/path/to/file.csv", + "file_type": "csv", + "import_type": "invalid_type", # Not a valid choice + } + + serializer = GuessFileHeaderSerializer(data=invalid_data) + + assert not serializer.is_valid() + assert "import_type" in serializer.errors + assert serializer.errors["import_type"][0] == '"invalid_type" is not a valid choice.'
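Taken together, the pieces added in this change are intended to be driven in sequence: upload (or point at) a file, guess its metadata, guess whether it has a header, then preview it with dialect-aware column types. A minimal sketch of that flow for a local CSV file; the path is hypothetical and error handling is omitted.

from desktop.lib.importer import operations

file_path = "/tmp/hue_user_a1b2c3d4_sales.csv"  # hypothetical result of local_file_upload()

metadata = operations.guess_file_metadata(file_path, import_type="local")
# e.g. {"type": "csv", "field_separator": ",", "quote_char": '"', "record_separator": "\r\n"}

has_header = operations.guess_file_header(file_path, file_type=metadata["type"], import_type="local")

preview = operations.preview_file(
  file_path=file_path,
  file_type=metadata["type"],
  import_type="local",
  sql_dialect="hive",
  has_header=has_header,
  field_separator=metadata["field_separator"],
  quote_char=metadata["quote_char"],
  record_separator=metadata["record_separator"],
)
# preview["columns"] -> [{"name": ..., "type": ...}, ...]; preview["preview_data"] holds up to 50 rows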