There's a section on the Databricks spark-xml GitHub page that covers parsing nested XML. It provides a solution using the Scala API, along with a couple of PySpark helper functions to work around the fact that there's no separate Python package for spark-xml. Using those helpers, here's one way you could solve the problem:
# 1. Copy helper functions from https://github.com/databricks/spark-xml#pyspark-notes
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.types import _parse_datatype_json_string
import pyspark.sql.functions as F
def ext_from_xml(xml_column, schema, options={}):
    # Wrapper around spark-xml's from_xml: parses an XML string column
    # into a struct column with the given schema
    java_column = _to_java_column(xml_column.cast('string'))
    java_schema = spark._jsparkSession.parseDataType(schema.json())
    scala_map = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap(options)
    jc = spark._jvm.com.databricks.spark.xml.functions.from_xml(
        java_column, java_schema, scala_map)
    return Column(jc)
def ext_schema_of_xml_df(df, options={}):
    # Infers the Spark schema from a dataframe with a single XML string column
    assert len(df.columns) == 1
    scala_options = spark._jvm.PythonUtils.toScalaMap(options)
    java_xml_module = getattr(getattr(
        spark._jvm.com.databricks.spark.xml, "package$"), "MODULE$")
    java_schema = java_xml_module.schema_of_xml_df(df._jdf, scala_options)
    return _parse_datatype_json_string(java_schema.json())
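Note that these helpers call into spark-xml's JVM code, so the spark-xml package needs to be on the classpath, e.g. by launching with something like spark-submit --packages com.databricks:spark-xml_2.12:0.14.0 (adjust the coordinates and version to match your Spark and Scala setup).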
# 2. Set up example dataframe
xml = '<?xml version="1.0" encoding="utf-8"?> <visitors> <visitor id="9615" age="68" sex="F" /> <visitor id="1882" age="34" sex="M" /> <visitor id="5987" age="23" sex="M" /> </visitors>'
df = spark.createDataFrame([('1', xml)], ['id', 'visitors'])
df.show()
# +---+--------------------+
# | id| visitors|
# +---+--------------------+
# | 1|<?xml version="1....|
# +---+--------------------+
# 3. Get xml schema and parse xml column
payloadSchema = ext_schema_of_xml_df(df.select("visitors"))
parsed = df.withColumn("parsed", ext_from_xml(F.col("visitors"), payloadSchema))
parsed.show()
# +---+--------------------+--------------------+
# | id| visitors| parsed|
# +---+--------------------+--------------------+
# | 1|<?xml version="1....|[[[, 68, 9615, F]...|
# +---+--------------------+--------------------+
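If you want to check what spark-xml inferred before flattening anything, you can inspect the schema of the parsed column (note that XML attributes get prefixed with underscores by default):

# Optional sanity check on the inferred schema
parsed.printSchema()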
# 4. Extract 'visitor' field from StructType
df2 = parsed.select(*parsed.columns[:-1], F.explode(F.col('parsed').getItem('visitor')))
df2.show()
# +---+--------------------+---------------+
# | id| visitors| col|
# +---+--------------------+---------------+
# | 1|<?xml version="1....|[, 68, 9615, F]|
# | 1|<?xml version="1....|[, 34, 1882, M]|
# | 1|<?xml version="1....|[, 23, 5987, M]|
# +---+--------------------+---------------+
# 5. Get field names, which will become new columns
# (this string parsing is hacky - a cleaner alternative is shown below)
new_col_names = [s.split(':')[0] for s in payloadSchema['visitor'].simpleString().split('<')[-1].strip('>>').split(',')]
new_col_names
# ['_VALUE', '_age', '_id', '_sex']
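As an aside, since payloadSchema is a regular StructType, you can skip the string parsing and read the names straight off the schema object, something along these lines (a sketch using the standard PySpark schema API):

# 'visitor' is an array of structs, so go via dataType.elementType
new_col_names = payloadSchema['visitor'].dataType.elementType.fieldNames()
# same result as above: ['_VALUE', '_age', '_id', '_sex']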
# 6. Create new columns
for c in new_col_names:
    df2 = df2.withColumn(c, F.col('col').getItem(c))
df2 = df2.drop('col', '_VALUE')
df2.show()
# +---+--------------------+----+----+----+
# | id| visitors|_age| _id|_sex|
# +---+--------------------+----+----+----+
# | 1|<?xml version="1....| 68|9615| F|
# | 1|<?xml version="1....| 34|1882| M|
# | 1|<?xml version="1....| 23|5987| M|
# +---+--------------------+----+----+----+
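Equivalently, the loop in step 6 could be collapsed into a single select on the exploded dataframe from step 4, which avoids creating the _VALUE column only to drop it again:

# Alternative to step 6: one select instead of the withColumn loop
df2 = df2.select(
    'id', 'visitors',
    *[F.col('col').getItem(c).alias(c) for c in new_col_names if c != '_VALUE'])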
One thing to look out for is the new column names colliding with existing column names. In this case the new names are all prefixed with underscores, so there's no clash here, but it's worth verifying that the nested XML tags don't conflict with existing column names before creating the new columns, for example:
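# Guard against column-name collisions before creating the new columns
clashes = set(new_col_names) & set(parsed.columns)
if clashes:
    raise ValueError(f'XML fields clash with existing columns: {clashes}')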