Skip to content

Commit dc8daf0

Browse files
author
Alexey Smirnov
committed
Avoid code duplication in Merge DataFrame method (#5657)
1 parent f7658b2 commit dc8daf0

File tree

2 files changed

+137
-193
lines changed

2 files changed

+137
-193
lines changed

src/Microsoft.Data.Analysis/DataFrame.Join.cs

Lines changed: 113 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -142,220 +142,162 @@ public DataFrame Join(DataFrame other, string leftSuffix = "_left", string right
142142
}
143143

144144
// TODO: Merge API with an "On" parameter that merges on a column common to 2 dataframes
145-
146-
/// <summary>
147-
/// Merge DataFrames with a database style join
148-
/// </summary>
149-
/// <param name="other"></param>
150-
/// <param name="leftJoinColumn"></param>
151-
/// <param name="rightJoinColumn"></param>
152-
/// <param name="leftSuffix"></param>
153-
/// <param name="rightSuffix"></param>
154-
/// <param name="joinAlgorithm"></param>
155-
/// <returns></returns>
156-
public DataFrame Merge<TKey>(DataFrame other, string leftJoinColumn, string rightJoinColumn, string leftSuffix = "_left", string rightSuffix = "_right", JoinAlgorithm joinAlgorithm = JoinAlgorithm.Left)
145+
146+
private static Dictionary<TKey, long> Merge<TKey>(DataFrame retainedDataFrame, DataFrame supplementaryDataFame, string retainedJoinColumnName, string supplemetaryJoinColumnName, out PrimitiveDataFrameColumn<long> retainedRowIndices, out PrimitiveDataFrameColumn<long> supplementaryRowIndices, bool isInner = false, bool calculateIntersection = false)
157147
{
158-
// A simple hash join
159-
DataFrame ret = new DataFrame();
160-
DataFrame leftDataFrame = this;
161-
DataFrame rightDataFrame = other;
148+
Dictionary<TKey, long> intersection = calculateIntersection ? new Dictionary<TKey, long>(EqualityComparer<TKey>.Default) : null;
162149

163-
// The final table size is not known until runtime
164-
long rowNumber = 0;
165-
PrimitiveDataFrameColumn<long> leftRowIndices = new PrimitiveDataFrameColumn<long>("LeftIndices");
166-
PrimitiveDataFrameColumn<long> rightRowIndices = new PrimitiveDataFrameColumn<long>("RightIndices");
167-
if (joinAlgorithm == JoinAlgorithm.Left)
168-
{
169-
// First hash other dataframe on the rightJoinColumn
170-
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
171-
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>(out HashSet<long> otherColumnNullIndices);
150+
retainedRowIndices = new PrimitiveDataFrameColumn<long>("RetainedIndices");
151+
supplementaryRowIndices = new PrimitiveDataFrameColumn<long>("SupplementaryIndices");
152+
153+
// First hash supplementary dataframe
154+
DataFrameColumn supplementaryColumn = supplementaryDataFame.Columns[supplemetaryJoinColumnName];
155+
Dictionary<TKey, ICollection<long>> multimap = supplementaryColumn.GroupColumnValues<TKey>(out HashSet<long> supplementaryColumnNullIndices);
172156

173-
// Go over the records in this dataframe and match with the dictionary
174-
DataFrameColumn thisColumn = Columns[leftJoinColumn];
157+
// Go over the records in this dataframe and match with the dictionary
158+
DataFrameColumn retainedColumn = retainedDataFrame.Columns[retainedJoinColumnName];
175159

176-
for (long i = 0; i < thisColumn.Length; i++)
160+
for (long i = 0; i < retainedColumn.Length; i++)
161+
{
162+
var retainedValue = retainedColumn[i];
163+
if (retainedValue != null)
177164
{
178-
var thisColumnValue = thisColumn[i];
179-
if (thisColumnValue != null)
165+
//Get all rows from supplementary dataframe that satisfy JOIN condition
166+
if (multimap.TryGetValue((TKey)retainedValue, out ICollection<long> rowIndices))
180167
{
181-
if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection<long> rowNumbers))
168+
foreach (long rowIndex in rowIndices)
182169
{
183-
foreach (long row in rowNumbers)
170+
retainedRowIndices.Append(i);
171+
supplementaryRowIndices.Append(rowIndex);
172+
173+
//store intersection if required
174+
if (calculateIntersection)
184175
{
185-
leftRowIndices.Append(i);
186-
rightRowIndices.Append(row);
176+
if (!intersection.ContainsKey((TKey)retainedValue))
177+
{
178+
intersection.Add((TKey)retainedValue, rowIndex);
179+
}
187180
}
188181
}
189-
else
190-
{
191-
leftRowIndices.Append(i);
192-
rightRowIndices.Append(null);
193-
}
194182
}
195183
else
196184
{
197-
foreach (long row in otherColumnNullIndices)
198-
{
199-
leftRowIndices.Append(i);
200-
rightRowIndices.Append(row);
201-
}
185+
if (isInner)
186+
continue;
187+
188+
retainedRowIndices.Append(i);
189+
supplementaryRowIndices.Append(null);
202190
}
203191
}
204-
}
205-
else if (joinAlgorithm == JoinAlgorithm.Right)
206-
{
207-
DataFrameColumn thisColumn = Columns[leftJoinColumn];
208-
Dictionary<TKey, ICollection<long>> multimap = thisColumn.GroupColumnValues<TKey>(out HashSet<long> thisColumnNullIndices);
209-
210-
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
211-
for (long i = 0; i < otherColumn.Length; i++)
192+
else
212193
{
213-
var otherColumnValue = otherColumn[i];
214-
if (otherColumnValue != null)
215-
{
216-
if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection<long> rowNumbers))
217-
{
218-
foreach (long row in rowNumbers)
219-
{
220-
leftRowIndices.Append(row);
221-
rightRowIndices.Append(i);
222-
}
223-
}
224-
else
225-
{
226-
leftRowIndices.Append(null);
227-
rightRowIndices.Append(i);
228-
}
229-
}
230-
else
194+
foreach (long row in supplementaryColumnNullIndices)
231195
{
232-
foreach (long thisColumnNullIndex in thisColumnNullIndices)
233-
{
234-
leftRowIndices.Append(thisColumnNullIndex);
235-
rightRowIndices.Append(i);
236-
}
196+
retainedRowIndices.Append(i);
197+
supplementaryRowIndices.Append(row);
237198
}
238199
}
239200
}
201+
202+
return intersection;
203+
}
204+
205+
206+
/// <summary>
207+
/// Merge DataFrames with a database style join
208+
/// </summary>
209+
/// <param name="other"></param>
210+
/// <param name="leftJoinColumn"></param>
211+
/// <param name="rightJoinColumn"></param>
212+
/// <param name="leftSuffix"></param>
213+
/// <param name="rightSuffix"></param>
214+
/// <param name="joinAlgorithm"></param>
215+
/// <returns></returns>
216+
public DataFrame Merge<TKey>(DataFrame other, string leftJoinColumn, string rightJoinColumn, string leftSuffix = "_left", string rightSuffix = "_right", JoinAlgorithm joinAlgorithm = JoinAlgorithm.Left)
217+
{
218+
//In an outer join the joined dataframe retains each row — even if no matching row exists in the supplementary dataframe.
219+
//Outer joins subdivide further into left outer joins (the left dataframe is retained), right outer joins (the right dataframe is retained) and full outer joins (both are retained)
220+
221+
PrimitiveDataFrameColumn<long> retainedRowIndices;
222+
PrimitiveDataFrameColumn<long> supplementaryRowIndices;
223+
DataFrame supplementaryDataFrame;
224+
DataFrame retainedDataFrame;
225+
bool isLeftDataFrameRetained;
226+
227+
if (joinAlgorithm == JoinAlgorithm.Left || joinAlgorithm == JoinAlgorithm.Right)
228+
{
229+
isLeftDataFrameRetained = (joinAlgorithm == JoinAlgorithm.Left);
230+
231+
supplementaryDataFrame = isLeftDataFrameRetained ? other : this;
232+
var supplementaryJoinColumn = isLeftDataFrameRetained ? rightJoinColumn : leftJoinColumn;
233+
234+
retainedDataFrame = isLeftDataFrameRetained ? this : other;
235+
var retainedJoinColumn = isLeftDataFrameRetained ? leftJoinColumn : rightJoinColumn;
236+
237+
Merge<TKey>(retainedDataFrame, supplementaryDataFrame, retainedJoinColumn, supplementaryJoinColumn, out retainedRowIndices, out supplementaryRowIndices);
238+
239+
}
240240
else if (joinAlgorithm == JoinAlgorithm.Inner)
241241
{
242-
// Hash the column with the smaller RowCount
243-
long leftRowCount = Rows.Count;
244-
long rightRowCount = other.Rows.Count;
242+
// use as supplementary (for Hashing) the dataframe with the smaller RowCount
243+
isLeftDataFrameRetained = (Rows.Count > other.Rows.Count);
245244

246-
bool leftColumnIsSmaller = leftRowCount <= rightRowCount;
247-
DataFrameColumn hashColumn = leftColumnIsSmaller ? Columns[leftJoinColumn] : other.Columns[rightJoinColumn];
248-
DataFrameColumn otherColumn = ReferenceEquals(hashColumn, Columns[leftJoinColumn]) ? other.Columns[rightJoinColumn] : Columns[leftJoinColumn];
249-
Dictionary<TKey, ICollection<long>> multimap = hashColumn.GroupColumnValues<TKey>(out HashSet<long> smallerDataFrameColumnNullIndices);
245+
supplementaryDataFrame = isLeftDataFrameRetained ? other : this;
246+
var supplementaryJoinColumn = isLeftDataFrameRetained ? rightJoinColumn : leftJoinColumn;
250247

251-
for (long i = 0; i < otherColumn.Length; i++)
252-
{
253-
var otherColumnValue = otherColumn[i];
254-
if (otherColumnValue != null)
255-
{
256-
if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection<long> rowNumbers))
257-
{
258-
foreach (long row in rowNumbers)
259-
{
260-
leftRowIndices.Append(leftColumnIsSmaller ? row : i);
261-
rightRowIndices.Append(leftColumnIsSmaller ? i : row);
262-
}
263-
}
264-
}
265-
else
266-
{
267-
foreach (long nullIndex in smallerDataFrameColumnNullIndices)
268-
{
269-
leftRowIndices.Append(leftColumnIsSmaller ? nullIndex : i);
270-
rightRowIndices.Append(leftColumnIsSmaller ? i : nullIndex);
271-
}
272-
}
273-
}
248+
retainedDataFrame = isLeftDataFrameRetained ? this : other;
249+
var retainedJoinColumn = isLeftDataFrameRetained ? leftJoinColumn : rightJoinColumn;
250+
251+
Merge<TKey>(retainedDataFrame, supplementaryDataFrame, retainedJoinColumn, supplementaryJoinColumn, out retainedRowIndices, out supplementaryRowIndices, true);
274252
}
275253
else if (joinAlgorithm == JoinAlgorithm.FullOuter)
276254
{
277-
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
278-
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>(out HashSet<long> otherColumnNullIndices);
279-
Dictionary<TKey, long> intersection = new Dictionary<TKey, long>(EqualityComparer<TKey>.Default);
255+
//In a full outer join we would like to retain data from both sides, so we do it in 2 steps: first we do a LEFT JOIN and then add the lost data from the RIGHT side
256+
257+
//Step 1
258+
//Do LEFT JOIN
259+
isLeftDataFrameRetained = true;
280260

281-
// Go over the records in this dataframe and match with the dictionary
282-
DataFrameColumn thisColumn = Columns[leftJoinColumn];
283-
Int64DataFrameColumn thisColumnNullIndices = new Int64DataFrameColumn("ThisColumnNullIndices");
261+
supplementaryDataFrame = isLeftDataFrameRetained ? other : this;
262+
var supplementaryJoinColumn = isLeftDataFrameRetained ? rightJoinColumn : leftJoinColumn;
284263

285-
for (long i = 0; i < thisColumn.Length; i++)
286-
{
287-
var thisColumnValue = thisColumn[i];
288-
if (thisColumnValue != null)
289-
{
290-
if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection<long> rowNumbers))
291-
{
292-
foreach (long row in rowNumbers)
293-
{
294-
leftRowIndices.Append(i);
295-
rightRowIndices.Append(row);
296-
if (!intersection.ContainsKey((TKey)thisColumnValue))
297-
{
298-
intersection.Add((TKey)thisColumnValue, rowNumber);
299-
}
300-
}
301-
}
302-
else
303-
{
304-
leftRowIndices.Append(i);
305-
rightRowIndices.Append(null);
306-
}
307-
}
308-
else
309-
{
310-
thisColumnNullIndices.Append(i);
311-
}
312-
}
313-
for (long i = 0; i < otherColumn.Length; i++)
264+
retainedDataFrame = isLeftDataFrameRetained ? this : other;
265+
var retainedJoinColumn = isLeftDataFrameRetained ? leftJoinColumn : rightJoinColumn;
266+
267+
var intersection = Merge<TKey>(retainedDataFrame, supplementaryDataFrame, retainedJoinColumn, supplementaryJoinColumn, out retainedRowIndices, out supplementaryRowIndices, calculateIntersection: true);
268+
269+
//Step 2
270+
//Do RIGHT JOIN to retain all data from supplementary DataFrame too (take into account data intersection from the first step to avoid duplicates)
271+
DataFrameColumn supplementaryColumn = supplementaryDataFrame.Columns[supplementaryJoinColumn];
272+
273+
for (long i = 0; i < supplementaryColumn.Length; i++)
314274
{
315-
var value = otherColumn[i];
275+
var value = supplementaryColumn[i];
316276
if (value != null)
317277
{
318278
if (!intersection.ContainsKey((TKey)value))
319279
{
320-
leftRowIndices.Append(null);
321-
rightRowIndices.Append(i);
280+
retainedRowIndices.Append(null);
281+
supplementaryRowIndices.Append(i);
322282
}
323283
}
324284
}
325-
326-
// Now handle the null rows
327-
foreach (long? thisColumnNullIndex in thisColumnNullIndices)
328-
{
329-
foreach (long otherColumnNullIndex in otherColumnNullIndices)
330-
{
331-
leftRowIndices.Append(thisColumnNullIndex.Value);
332-
rightRowIndices.Append(otherColumnNullIndex);
333-
}
334-
if (otherColumnNullIndices.Count == 0)
335-
{
336-
leftRowIndices.Append(thisColumnNullIndex.Value);
337-
rightRowIndices.Append(null);
338-
}
339-
}
340-
if (thisColumnNullIndices.Length == 0)
341-
{
342-
foreach (long otherColumnNullIndex in otherColumnNullIndices)
343-
{
344-
leftRowIndices.Append(null);
345-
rightRowIndices.Append(otherColumnNullIndex);
346-
}
347-
}
348285
}
349286
else
350287
throw new NotImplementedException(nameof(joinAlgorithm));
351-
352-
for (int i = 0; i < leftDataFrame.Columns.Count; i++)
288+
289+
DataFrame ret = new DataFrame();
290+
291+
//insert columns from left dataframe (this)
292+
for (int i = 0; i < this.Columns.Count; i++)
353293
{
354-
ret.Columns.Insert(i, leftDataFrame.Columns[i].Clone(leftRowIndices));
294+
ret.Columns.Insert(i, this.Columns[i].Clone(isLeftDataFrameRetained ? retainedRowIndices : supplementaryRowIndices));
355295
}
356-
for (int i = 0; i < rightDataFrame.Columns.Count; i++)
296+
297+
//insert columns from right dataframe (other)
298+
for (int i = 0; i < other.Columns.Count; i++)
357299
{
358-
DataFrameColumn column = rightDataFrame.Columns[i].Clone(rightRowIndices);
300+
DataFrameColumn column = other.Columns[i].Clone(isLeftDataFrameRetained ? supplementaryRowIndices : retainedRowIndices);
359301
SetSuffixForDuplicatedColumnNames(ret, column, leftSuffix, rightSuffix);
360302
ret.Columns.Insert(ret.Columns.Count, column);
361303
}

0 commit comments

Comments
 (0)