Skip to content

feat(query): add time_zone param #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:
run: |
export ES_URI="http://localhost:9200"
export ES_PORT=9200
export ES_SUPPORT_DATETIME_PARSE=False
nosetests -v --with-coverage --cover-package=es es.tests
- name: Run tests on Elasticsearch 7.10.X
run: |
Expand All @@ -97,6 +98,7 @@ jobs:
export ES_PORT=19200
export ES_SCHEME=https
export ES_USER=admin
export ES_SUPPORT_DATETIME_PARSE=False
nosetests -v --with-coverage --cover-package=es es.tests
- name: Run tests on Opendistro 13
run: |
Expand All @@ -107,6 +109,7 @@ jobs:
export ES_SCHEME=https
export ES_USER=admin
export ES_V2=True
export ES_SUPPORT_DATETIME_PARSE=False
nosetests -v --with-coverage --cover-package=es es.tests
- name: Upload code coverage
run: |
Expand Down
11 changes: 7 additions & 4 deletions es/baseapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_description_from_columns(

class BaseConnection(object):

"""Connection to an ES Cluster """
"""Connection to an ES Cluster"""

def __init__(
self,
Expand Down Expand Up @@ -192,6 +192,7 @@ def __init__(self, url: str, es: Elasticsearch, **kwargs):
self.es = es
self.sql_path = kwargs.get("sql_path", DEFAULT_SQL_PATH)
self.fetch_size = kwargs.get("fetch_size", DEFAULT_FETCH_SIZE)
self.time_zone: Optional[str] = kwargs.get("time_zone")
# This read/write attribute specifies the number of rows to fetch at a
# time with .fetchmany(). It defaults to 1 meaning to fetch a single
# row at a time.
Expand All @@ -218,7 +219,7 @@ def custom_sql_to_method_dispatcher(self, command: str) -> Optional["BaseCursor"
@check_result
@check_closed
def rowcount(self) -> int:
""" Counts the number of rows on a result """
"""Counts the number of rows on a result"""
if self._results:
return len(self._results)
return 0
Expand All @@ -230,7 +231,7 @@ def close(self) -> None:

@check_closed
def execute(self, operation, parameters=None) -> "BaseCursor":
""" Children must implement their own custom execute """
"""Children must implement their own custom execute"""
raise NotImplementedError # pragma: no cover

@check_closed
Expand Down Expand Up @@ -311,11 +312,13 @@ def elastic_query(self, query: str) -> Dict[str, Any]:
payload = {"query": query}
if self.fetch_size is not None:
payload["fetch_size"] = self.fetch_size
if self.time_zone is not None:
payload["time_zone"] = self.time_zone
path = f"/{self.sql_path}/"
try:
response = self.es.transport.perform_request("POST", path, body=payload)
except es_exceptions.ConnectionError:
raise exceptions.OperationalError(f"Error connecting to Elasticsearch")
raise exceptions.OperationalError("Error connecting to Elasticsearch")
except es_exceptions.RequestError as ex:
raise exceptions.ProgrammingError(
f"Error ({ex.error}): {ex.info['error']['reason']}"
Expand Down
2 changes: 1 addition & 1 deletion es/elastic/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def connect(

class Connection(BaseConnection):

"""Connection to an ES Cluster """
"""Connection to an ES Cluster"""

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion es/opendistro/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def connect(

class Connection(BaseConnection):

"""Connection to an ES Cluster """
"""Connection to an ES Cluster"""

def __init__(
self,
Expand Down
90 changes: 78 additions & 12 deletions es/tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,35 @@
from es.opendistro.api import connect as open_connect


def convert_bool(value: str) -> bool:
return True if value == "True" else False


class TestDBAPI(unittest.TestCase):
def setUp(self):
self.driver_name = os.environ.get("ES_DRIVER", "elasticsearch")
host = os.environ.get("ES_HOST", "localhost")
port = int(os.environ.get("ES_PORT", 9200))
scheme = os.environ.get("ES_SCHEME", "http")
verify_certs = os.environ.get("ES_VERIFY_CERTS", False)
user = os.environ.get("ES_USER", None)
password = os.environ.get("ES_PASSWORD", None)
self.host = os.environ.get("ES_HOST", "localhost")
self.port = int(os.environ.get("ES_PORT", 9200))
self.scheme = os.environ.get("ES_SCHEME", "http")
self.verify_certs = os.environ.get("ES_VERIFY_CERTS", False)
self.user = os.environ.get("ES_USER", None)
self.password = os.environ.get("ES_PASSWORD", None)
self.v2 = bool(os.environ.get("ES_V2", False))
self.support_datetime_parse = convert_bool(
os.environ.get("ES_SUPPORT_DATETIME_PARSE", "True")
)

if self.driver_name == "elasticsearch":
self.connect_func = elastic_connect
else:
self.connect_func = open_connect
self.conn = self.connect_func(
host=host,
port=port,
scheme=scheme,
verify_certs=verify_certs,
user=user,
password=password,
host=self.host,
port=self.port,
scheme=self.scheme,
verify_certs=self.verify_certs,
user=self.user,
password=self.password,
v2=self.v2,
)
self.cursor = self.conn.cursor()
Expand Down Expand Up @@ -213,3 +220,62 @@ def test_https(self, mock_elasticsearch):
mock_elasticsearch.assert_called_once_with(
"https://localhost:9200/", http_auth=("user", "password")
)

def test_simple_search_with_time_zone(self):
"""
DBAPI: Test simple search with time zone
UTC -> CST
2019-10-13T00:00:00.000Z => 2019-10-13T08:00:00.000+08:00
2019-10-13T00:00:01.000Z => 2019-10-13T08:01:00.000+08:00
2019-10-13T00:00:02.000Z => 2019-10-13T08:02:00.000+08:00
"""

if not self.support_datetime_parse:
return

conn = self.connect_func(
host=self.host,
port=self.port,
scheme=self.scheme,
verify_certs=self.verify_certs,
user=self.user,
password=self.password,
v2=self.v2,
time_zone="Asia/Shanghai",
)
cursor = conn.cursor()
pattern = "yyyy-MM-dd HH:mm:ss"
sql = f"""
SELECT timestamp FROM data1
WHERE timestamp >= DATETIME_PARSE('2019-10-13 00:08:00', '{pattern}')
"""

rows = cursor.execute(sql).fetchall()
self.assertEqual(len(rows), 3)

def test_simple_search_without_time_zone(self):
"""
DBAPI: Test simple search without time zone
"""

if not self.support_datetime_parse:
return

conn = self.connect_func(
host=self.host,
port=self.port,
scheme=self.scheme,
verify_certs=self.verify_certs,
user=self.user,
password=self.password,
v2=self.v2,
)
cursor = conn.cursor()
pattern = "yyyy-MM-dd HH:mm:ss"
sql = f"""
SELECT * FROM data1
WHERE timestamp >= DATETIME_PARSE('2019-10-13 08:00:00', '{pattern}')
"""

rows = cursor.execute(sql).fetchall()
self.assertEqual(len(rows), 0)