Commit b64262f

wip

1 parent 7b84aee commit b64262f

2 files changed: +169 -43

dbldatagen/data_analyzer.py  (+121 -12)
```diff
@@ -12,20 +12,18 @@
 
 """
 import logging
-from collections import namedtuple
 import pprint
+from collections import namedtuple
 
 import numpy as np
-
+import pyspark.sql.functions as F
+from pyspark import sql
 from pyspark.sql.types import LongType, FloatType, IntegerType, StringType, DoubleType, BooleanType, ShortType, \
     TimestampType, DateType, DecimalType, ByteType, BinaryType, StructType, ArrayType, DataType, MapType
 
-from pyspark import sql
-import pyspark.sql.functions as F
-
-from .utils import strip_margins, json_value_from_path
-from .spark_singleton import SparkSingleton
 from .html_utils import HtmlUtils
+from .spark_singleton import SparkSingleton
+from .utils import strip_margins, json_value_from_path
 
 
 class DataAnalyzer:
```
```diff
@@ -148,12 +146,9 @@ def _addMeasureToSummary(self, measureName, summaryExpr="''", fieldExprs=None, d
         # add measures for fields
         exprs.extend(fieldExprs)
 
-        if dfSummary is not None:
-            dfResult = dfSummary.union(dfData.selectExpr(*exprs).limit(rowLimit))
-        else:
-            dfResult = dfData.selectExpr(*exprs).limit(rowLimit)
+        dfMeasure = dfData.selectExpr(*exprs).limit(rowLimit) if rowLimit is not None else dfData.selectExpr(*exprs)
 
-        return dfResult
+        return dfSummary.union(dfMeasure) if dfSummary is not None else dfMeasure
 
     @staticmethod
     def _is_numeric_type(dtype):
```
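The rewrite routes both branches through a single `union`, and the new `rowLimit=None` path skips the `limit()` entirely. Note that Spark's `union` resolves columns by position, not by name, which is why every measure row must be built from the same expression list; a small self-contained illustration (toy column names, not from the library):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# union() matches columns positionally, so every measure row appended to the
# running summary must select the same columns in the same order -- the reason
# _addMeasureToSummary builds one shared `exprs` list per call
dfCount = spark.range(5).selectExpr("'count' as measure", "string(count(id)) as value")
dfMax = spark.range(5).selectExpr("'max' as measure", "string(max(id)) as value")
dfCount.union(dfMax).show()
```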
```diff
@@ -223,6 +218,112 @@ def _compute_pattern_match_clauses(self):
         result = stmts # "\n".join(stmts)
         return result
 
+    def generateTextFeatures(self, sourceDf):
+        """ Generate text features from source dataframe
+
+        Generates a set of text features for each column (analyzing the string representation of each column value)
+
+        :param sourceDf: Source dataframe
+        :return: Dataframe of text features
+        """
+        # generate named struct of text features for each column
+
+        # we need to double escape backslashes in regular expressions as they will be lost in string expansion
+        WORD_REGEX = r"\\b\\w+\\b"
+        SPACE_REGEX = r"\\s+"
+        DIGIT_REGEX = r"\\d"
+        PUNCTUATION_REGEX = r"[\\?\\.\\;\\,\\!\\{\\}\\[\\]\\(\\)\\>\\<]"
+        AT_REGEX = r"\\@"
+        PERIOD_REGEX = r"\\."
+        HTTP_REGEX = r"^http[s]?\\:\\/\\/"
+        ALPHA_REGEX = r"[a-zA-Z]"
+        ALPHA_UPPER_REGEX = r"[A-Z]"
+        ALPHA_LOWER_REGEX = r"[a-z]"
+        HEX_REGEX = r"[0-9a-fA-F]"
+
+        # for each column, extract text features from string representation of column value (leftmost 4096 characters)
+        def left4k(name):
+            return f"left(string({name}), 4096)"
+
+        fieldTextFeatures = []
+
+        for colInfo in self.columnsInfo:
+            fieldTextFeatures.append(
+                strip_margins(
+                    f"""named_struct(
+                       | 'print_len', length(string({colInfo.name})),
+                       | 'word_count', size(regexp_extract_all({left4k(colInfo.name)}, '{WORD_REGEX}', 0)),
+                       | 'space_count', size(regexp_extract_all({left4k(colInfo.name)}, '{SPACE_REGEX}', 0)),
+                       | 'digit_count', size(regexp_extract_all({left4k(colInfo.name)}, '{DIGIT_REGEX}', 0)),
+                       | 'punctuation_count', size(regexp_extract_all({left4k(colInfo.name)}, '{PUNCTUATION_REGEX}', 0)),
+                       | 'at_count', size(regexp_extract_all({left4k(colInfo.name)}, '{AT_REGEX}', 0)),
+                       | 'period_count', size(regexp_extract_all({left4k(colInfo.name)}, '{PERIOD_REGEX}', 0)),
+                       | 'http_count', size(regexp_extract_all({left4k(colInfo.name)}, '{HTTP_REGEX}', 0)),
+                       | 'alpha_count', size(regexp_extract_all({left4k(colInfo.name)}, '{ALPHA_REGEX}', 0)),
+                       | 'alpha_lower_count', size(regexp_extract_all({left4k(colInfo.name)}, '{ALPHA_LOWER_REGEX}', 0)),
+                       | 'alpha_upper_count', size(regexp_extract_all({left4k(colInfo.name)}, '{ALPHA_UPPER_REGEX}', 0)),
+                       | 'hex_digit_count', size(regexp_extract_all({left4k(colInfo.name)}, '{HEX_REGEX}', 0))
+                       | )
+                       | as {colInfo.name}""", marginChar="|")
+            )
+
+        dfTextFeatures = self._addMeasureToSummary(
+            'text_features',
+            fieldExprs=fieldTextFeatures,
+            dfData=sourceDf,
+            dfSummary=None,
+            rowLimit=None)
+
+        return dfTextFeatures
+
+    def _summarizeTextFeatures(self, textFeaturesDf):
+        """
+        Generate summary of text features
+
+        :param textFeaturesDf: Text features dataframe
+        :return: dataframe of summary text features
+        """
+        assert textFeaturesDf is not None, "textFeaturesDf must be specified"
+
+        # generate named struct of summary text features for each column
+        fieldTextFeatures = []
+
+        # TODO: use json syntax asin:print_len when migrating to Databricks Runtime 10.4 LTS as minimum version
+
+        for colInfo in self.columnsInfo:
+            cname = colInfo.name
+            fieldTextFeatures.append(strip_margins(
+                f"""to_json(named_struct(
+                   | 'print_len', array(min({cname}.print_len), max({cname}.print_len), avg({cname}.print_len)),
+                   | 'word_count', array(min({cname}.word_count), max({cname}.word_count), avg({cname}.word_count)),
+                   | 'space_count', array(min({cname}.space_count), max({cname}.space_count), avg({cname}.space_count)),
+                   | 'digit_count', array(min({cname}.digit_count), max({cname}.digit_count), avg({cname}.digit_count)),
+                   | 'punctuation_count', array(min({cname}.punctuation_count), max({cname}.punctuation_count),
+                   |     avg({cname}.punctuation_count)),
+                   | 'at_count', array(min({cname}.at_count), max({cname}.at_count), avg({cname}.at_count)),
+                   | 'period_count', array(min({cname}.period_count), max({cname}.period_count),
+                   |     avg({cname}.period_count)),
+                   | 'http_count', array(min({cname}.http_count), max({cname}.http_count), avg({cname}.http_count)),
+                   | 'alpha_count', array(min({cname}.alpha_count), max({cname}.alpha_count), avg({cname}.alpha_count)),
+                   | 'alpha_lower_count', array(min({cname}.alpha_lower_count), max({cname}.alpha_lower_count),
+                   |     avg({cname}.alpha_lower_count)),
+                   | 'alpha_upper_count', array(min({cname}.alpha_upper_count), max({cname}.alpha_upper_count),
+                   |     avg({cname}.alpha_upper_count)),
+                   | 'hex_digit_count', array(min({cname}.hex_digit_count), max({cname}.hex_digit_count),
+                   |     avg({cname}.hex_digit_count))
+                   | ))
+                   | as {cname}""", marginChar="|")
+            )
+
+        dfSummaryTextFeatures = self._addMeasureToSummary(
+            'summary_text_features',
+            fieldExprs=fieldTextFeatures,
+            dfData=textFeaturesDf,
+            dfSummary=None,
+            rowLimit=1)
+
+        return dfSummaryTextFeatures
+
     def summarizeToDF(self):
         """ Generate summary analysis of data set as dataframe
 
```
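To see why the patterns are double-escaped, it helps to trace one generated expression end to end; a minimal standalone sketch (the `title` column is illustrative, not from the library):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("hello world 42",)], ["title"])

# r"\\b\\w+\\b" holds the literal characters \\b\\w+\\b; the SQL string parser
# then consumes one level of escaping, so the regex engine finally sees \b\w+\b
WORD_REGEX = r"\\b\\w+\\b"
expr = f"size(regexp_extract_all(left(string(title), 4096), '{WORD_REGEX}', 0)) as word_count"
df.selectExpr(expr).show()  # word_count = 3 for "hello world 42"
```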
```diff
@@ -368,6 +469,14 @@ def summarizeToDF(self):
             dfData=df_under_analysis,
             dfSummary=dfDataSummary)
 
+        logger.info("Analyzing text features")
+        dfTextFeatures = self.generateTextFeatures(self._getExpandedSourceDf())
+
+        logger.info("Summarizing text features")
+        dfTextFeaturesSummary = self._summarizeTextFeatures(dfTextFeatures)
+
+        dfDataSummary = dfDataSummary.union(dfTextFeaturesSummary)
+
         return dfDataSummary
 
     def summarize(self, suppressOutput=False):
```
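End to end, the new measures surface through the existing entry point; a hedged usage sketch, assuming an active SparkSession `spark` and some source DataFrame `df`:

```python
import dbldatagen as dg

# the summary dataframe now carries an extra 'summary_text_features' row in
# which each source column holds a JSON document of [min, max, avg] triples
analyzer = dg.DataAnalyzer(sparkSession=spark, df=df)
dfSummary = analyzer.summarizeToDF()
dfSummary.show(truncate=False)

# the raw per-row text features can also be generated directly
dfFeatures = analyzer.generateTextFeatures(df)
dfFeatures.show(truncate=False)
```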

tests/test_generation_from_data.py  (+48 -31)
```diff
@@ -17,7 +17,7 @@ def setupLogging():
 
 
 class TestGenerationFromData:
-    SMALL_ROW_COUNT = 10000
+    SMALL_ROW_COUNT = 1000
 
     @pytest.fixture
     def testLogger(self):
```
```diff
@@ -68,66 +68,86 @@ def generation_spec(self):
             .withColumn("int_value", "int", min=100, max=200, percentNulls=0.1)
             .withColumn("byte_value", "tinyint", max=127)
             .withColumn("decimal_value", "decimal(10,2)", max=1000000)
-            .withColumn("decimal_value", "decimal(10,2)", max=1000000)
             .withColumn("date_value", "date", expr="current_date()", random=True)
             .withColumn("binary_value", "binary", expr="cast('spark' as binary)", random=True)
 
         )
         return spec
 
-    def test_code_generation1(self, generation_spec, setupLogging):
+    @pytest.fixture
+    def source_data_df(self, generation_spec):
         df_source_data = generation_spec.build()
-        df_source_data.show()
+        return df_source_data.cache()
+
+    def test_code_generation1(self, source_data_df, setupLogging):
+        source_data_df.show()
 
-        analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data)
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df)
 
         generatedCode = analyzer.scriptDataGeneratorFromData()
 
-        for fld in df_source_data.schema:
+        for fld in source_data_df.schema:
             assert f"withColumn('{fld.name}'" in generatedCode
 
         # check generated code for syntax errors
         ast_tree = ast.parse(generatedCode)
         assert ast_tree is not None
 
-    def test_code_generation_from_schema(self, generation_spec, setupLogging):
-        df_source_data = generation_spec.build()
-        generatedCode = dg.DataAnalyzer.scriptDataGeneratorFromSchema(df_source_data.schema)
+    def test_code_generation_from_schema(self, source_data_df, setupLogging):
+        generatedCode = dg.DataAnalyzer.scriptDataGeneratorFromSchema(source_data_df.schema)
 
-        for fld in df_source_data.schema:
+        for fld in source_data_df.schema:
             assert f"withColumn('{fld.name}'" in generatedCode
 
         # check generated code for syntax errors
         ast_tree = ast.parse(generatedCode)
         assert ast_tree is not None
 
-    def test_summarize(self, testLogger, generation_spec):
-        testLogger.info("Building test data")
-
-        df_source_data = generation_spec.build()
+    def test_summarize(self, testLogger, source_data_df):
 
         testLogger.info("Creating data analyzer")
 
-        analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data)
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df)
 
         testLogger.info("Summarizing data analyzer results")
         analyzer.summarize()
 
-    def test_summarize_to_df(self, generation_spec, testLogger):
-        testLogger.info("Building test data")
-
-        df_source_data = generation_spec.build()
-
+    def test_summarize_to_df(self, source_data_df, testLogger):
         testLogger.info("Creating data analyzer")
 
-        analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data)
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df)
 
         testLogger.info("Summarizing data analyzer results")
         df = analyzer.summarizeToDF()
 
-        #df.show()
+        df.show()
+
+    def test_generate_text_features(self, source_data_df, testLogger):
+        testLogger.info("Creating data analyzer")
+
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df)
+
+        df_text_features = analyzer.generateTextFeatures(source_data_df).limit(10)
+        df_text_features.show()
+
+        # data = df_text_features.selectExpr("get_json_object(asin, '$.print_len') as asin").limit(10).collect()
+        data = df_text_features.selectExpr("asin.print_len as asin").limit(10).collect()
+        assert data[0]['asin'] is not None
+        print(data[0]['asin'])
+
+    def test_summarize_text_features(self, source_data_df, testLogger):
+        testLogger.info("Creating data analyzer")
+
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df)
+
+        df_text_features = analyzer.generateTextFeatures(source_data_df)
+        df_summary_text_features = analyzer._summarizeTextFeatures(df_text_features)
+        df_summary_text_features.show()
+
+        data = df_summary_text_features.selectExpr("get_json_object(asin, '$.print_len') as asin").limit(10).collect()
+        assert data[0]['asin'] is not None
+        print(data[0]['asin'])
 
-        df_source_data.where("title is null or length(title) = 0").show()
 
     @pytest.mark.parametrize("sampleString, expectedMatch",
                              [("0234", "digits"),
```
```diff
@@ -141,10 +161,8 @@ def test_summarize_to_df(self, generation_spec, testLogger):
                               ("test_function", "identifier"),
                               ("10.0.0.1", "ip_addr")
                               ])
-    def test_match_patterns(self, sampleString, expectedMatch, generation_spec):
-        df_source_data = generation_spec.build()
-
-        analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data)
+    def test_match_patterns(self, sampleString, expectedMatch, source_data_df):
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df)
 
         pattern_match_result = ""
         for k, v in analyzer._regex_patterns.items():
@@ -156,11 +174,10 @@ def test_match_patterns(self, sampleString, expectedMatch, generation_spec):
 
         assert pattern_match_result == expectedMatch, f"expected match to be {expectedMatch}"
 
-    def test_source_data_property(self, generation_spec):
-        df_source_data = generation_spec.build()
-        analyzer = dg.DataAnalyzer(sparkSession=spark, df=df_source_data, maxRows=1000)
+    def test_source_data_property(self, source_data_df):
+        analyzer = dg.DataAnalyzer(sparkSession=spark, df=source_data_df, maxRows=500)
 
         count_rows = analyzer.sourceSampleDf.count()
         print(count_rows)
-        assert abs(count_rows - 1000) < 100, "expected count to be close to 1000"
+        assert abs(count_rows - 500) < 50, "expected count to be close to 500"
 
```
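The widened tolerance reflects that the analyzer's row sampling is presumably fraction-based rather than exact, so the sampled count varies from run to run; a brief illustration of the underlying Spark behaviour:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# DataFrame.sample takes a fraction, not a target row count, so the result
# size is only approximately fraction * rows -- hence the +/- 50 band above
df = spark.range(10000)
print(df.sample(fraction=500 / 10000).count())  # close to 500, rarely exact
```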