2023-11-12

Recursively adding columns to PySpark DataFrame nested arrays

I'm working with a PySpark DataFrame that contains multiple levels of nested arrays of structs. My goal is to add, to each nested array element, a hash column for that array plus the record's top-level hash column. Since I do not know in advance how deeply the arrays are nested, I need to do this recursively.

So, for this example schema:

from pyspark.sql.types import StructType, StructField, StringType, ArrayType

schema = StructType([
    StructField("name", StringType()),
    StructField("experience", ArrayType(StructType([
        StructField("role", StringType()),
        StructField("duration", StringType()),
        StructField("company", ArrayType(StructType([
            StructField("company_name", StringType()),
            StructField("location", StringType())
        ])))
    ])))
])
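For context, a single row that fits this schema can be created like this (the values here are made up purely to make the example concrete):

sample_data = [
    ("Alice", [("Engineer", "2 years", [("Acme", "Berlin")])])
]
sample_df = spark.createDataFrame(sample_data, schema)
sample_df.show(truncate=False)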

The desired output schema would look like this:

hashed_schema = StructType([
    StructField("name", StringType()),
    StructField("experience", ArrayType(StructType([
        StructField("role", StringType()),
        StructField("duration", StringType()),
        StructField("experience_hash", StringType()),  # Added hash for the experience collection
        StructField("company", ArrayType(StructType([
            StructField("company_name", StringType()),
            StructField("location", StringType()),
            StructField("company_hash", StringType())  # Added hash for the company collection
        ])))
    ]))),
    StructField("employee_hash", StringType()),  # Added hash for the entire record
])
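For the first level of nesting, fields like experience_hash and employee_hash from this target schema can be produced with transform and Column.withField (available in Spark 3.1+). A minimal sketch, assuming (as the code further below does) that each hash is just an MD5 of the literal level name rather than of the data itself:

from pyspark.sql.functions import col, concat_ws, lit, md5, transform

sample_df = (
    sample_df
    .withColumn("employee_hash", md5(concat_ws("_", lit("employee"))))
    .withColumn(
        "experience",
        transform(
            col("experience"),
            lambda x: x.withField("experience_hash", md5(concat_ws("_", lit("experience"))))
                       .withField("employee_hash", col("employee_hash")),
        ),
    )
)
sample_df.printSchema()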

I have tried to write code with recursion that iterates through each nested array and hashes its columns. While it seems to work for first-level nested arrays, the recursion part does not work: I get an error that the recursion is too deep.


from pyspark.sql.functions import col, concat_ws, lit, md5, transform

def hash_for_level(level_path):
    # MD5 of the underscore-joined level path (a constant per level, not derived from the data)
    return md5(concat_ws("_", *[lit(elem) for elem in level_path]))

def add_hash_columns(df, level_path, current_struct, root_hash_col=None):
    # If this is the root level, create the root hash
    if not level_path and root_hash_col is None:
        root_hash_col = 'employee_hash'
        df = df.withColumn(root_hash_col, hash_for_level(['employee']))
    
    # Traverse the current structure and add hash columns
    for field in current_struct.fields:
        new_level_path = level_path + [field.name]
        # If the field is an array of structs, add a hash for each element in the array
        if isinstance(field.dataType, ArrayType):
            nested_struct = field.dataType.elementType
            hash_expr = transform(
                col('.'.join(level_path + [field.name])),
                lambda x: x.withField(new_level_path[-1] + '_hash', hash_for_level(new_level_path))
                    .withField(root_hash_col, col(root_hash_col))  # Include the root hash
            )
            # Add the hash column to the array elements
            df = df.withColumn('.'.join(level_path + [field.name]), hash_expr)
            # Recursion call to apply the same logic for nested arrays
            df = add_hash_columns(df, new_level_path, nested_struct, root_hash_col)
            
    # Add a hash column at the current level
    if level_path:
        #print("Level path:", level_path)
        hash_col_name = '_'.join(level_path) + '_hash'
        df = df.withColumn(hash_col_name, hash_for_level(level_path))
        if root_hash_col:
            # Ensure the root hash is included at each struct level
            df = df.withColumn(root_hash_col, col(root_hash_col))
            
    return df

df = spark.createDataFrame([], schema)
df = add_hash_columns(df, [], df.schema)
df
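For comparison, the second level (company inside experience) cannot be reached by withColumn with a dotted path such as "experience.company", since withColumn only creates or replaces top-level columns; the inner array has to be rebuilt inside the outer transform. A hand-written, non-recursive sketch of that step for this particular schema, reusing hash_for_level from above, would be roughly:

df = spark.createDataFrame([], schema)
df = df.withColumn("employee_hash", hash_for_level(["employee"]))
df = df.withColumn(
    "experience",
    transform(
        col("experience"),
        lambda exp: exp.withField(
            "company",
            transform(
                exp["company"],
                lambda comp: comp.withField("company_hash", hash_for_level(["experience", "company"]))
                                 .withField("employee_hash", col("employee_hash")),
            ),
        ),
    ),
)
df.printSchema()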

