
I've seen various people suggesting that DataFrame.explode is a useful way to do this, but it results in more rows than the original DataFrame, which isn't what I want at all. I simply want to do the DataFrame equivalent of the very simple:

rdd.map(lambda row: row + row.my_str_col.split('-'))

which takes something looking like:

col1 | my_str_col
-----+-----------
  18 |  856-yygrm
 201 |  777-psgdg

and converts it to this:

col1 | my_str_col | _col3 | _col4
-----+------------+-------+------
  18 |  856-yygrm |   856 | yygrm
 201 |  777-psgdg |   777 | psgdg

I am aware of pyspark.sql.functions.split(), but it results in a nested array column instead of two top-level columns like I want.

Ideally, I want these new columns to be named as well.

5 Answers


pyspark.sql.functions.split() is the right approach here - you simply need to flatten the nested ArrayType column into multiple top-level columns. In this case, where each array contains only 2 items, it's very easy: use Column.getItem() to retrieve each part of the array as a column in its own right:

import pyspark.sql.functions

# split() returns an ArrayType column; getItem(i) pulls element i out as its own column
split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
df = df.withColumn('NAME1', split_col.getItem(0))
df = df.withColumn('NAME2', split_col.getItem(1))

The result will be:

col1 | my_str_col | NAME1 | NAME2
-----+------------+-------+------
  18 |  856-yygrm |   856 | yygrm
 201 |  777-psgdg |   777 | psgdg

I am not sure how I would solve this in a general case where the nested arrays were not the same size from Row to Row.
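
In case the arrays do vary in size, one possible workaround (a sketch only, assuming you are willing to pick an upper bound on the number of parts): getItem() on an out-of-range index returns null by default, so a simple loop over the indices should work without erroring.

import pyspark.sql.functions

# Sketch only: max_parts is an assumed upper bound, not something given in the question.
max_parts = 3
split_col = pyspark.sql.functions.split(df['my_str_col'], '-')
for i in range(max_parts):
    # Indices past the end of a given row's array come back as null.
    df = df.withColumn('part_{}'.format(i), split_col.getItem(i))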

  • Is there a way to put the remaining items in a single column? i.e. split_col.getItem(2 - n) in a third column. I guess something like the above loop to make columns for all items and then concatenating them might work, but I don't know whether that's very efficient.
    – Chris
    Commented Oct 18, 2017 at 19:28
  • Use df.withColumn('NAME_remaining', pyspark.sql.functions.split(df['my_str_col'], '-', 3).getItem(2)) to get the remaining items. spark.apache.org/docs/latest/api/sql/index.html Commented Oct 26, 2020 at 18:01
  • I found that if you are trying to assign one of the split items back to the original column, you have to rename the original column with withColumnRenamed() before the split in order to avoid an error apparently related to issues.apache.org/jira/browse/SPARK-14948.
    – Steve
    Commented Oct 27, 2020 at 19:59
  • How do you perform a split such that the first part of the split is the column name and the second part is the column value? Commented Jul 12, 2021 at 16:15

Here's a solution to the general case that doesn't require knowing the length of the array ahead of time, using collect(), or using UDFs. Unfortunately, this only works for Spark version 2.1 and above, because it requires the posexplode function.

Suppose you had the following DataFrame:

df = spark.createDataFrame(
    [
        [1, 'A, B, C, D'], 
        [2, 'E, F, G'], 
        [3, 'H, I'], 
        [4, 'J']
    ]
    , ["num", "letters"]
)
df.show()
#+---+----------+
#|num|   letters|
#+---+----------+
#|  1|A, B, C, D|
#|  2|   E, F, G|
#|  3|      H, I|
#|  4|         J|
#+---+----------+

Split the letters column and then use posexplode to explode the resultant array along with the position in the array. Next use pyspark.sql.functions.expr to grab the element at index pos in this array.

import pyspark.sql.functions as f

df.select(
        "num",
        f.split("letters", ", ").alias("letters"),
        f.posexplode(f.split("letters", ", ")).alias("pos", "val")
    )\
    .show()
#+---+------------+---+---+
#|num|     letters|pos|val|
#+---+------------+---+---+
#|  1|[A, B, C, D]|  0|  A|
#|  1|[A, B, C, D]|  1|  B|
#|  1|[A, B, C, D]|  2|  C|
#|  1|[A, B, C, D]|  3|  D|
#|  2|   [E, F, G]|  0|  E|
#|  2|   [E, F, G]|  1|  F|
#|  2|   [E, F, G]|  2|  G|
#|  3|      [H, I]|  0|  H|
#|  3|      [H, I]|  1|  I|
#|  4|         [J]|  0|  J|
#+---+------------+---+---+

Now we create two new columns from this result. The first is the name of our new column, which will be a concatenation of letter and the index in the array. The second column will be the value at the corresponding index in the array. We get the latter by exploiting the functionality of pyspark.sql.functions.expr, which allows us to use column values as parameters.

df.select(
        "num",
        f.split("letters", ", ").alias("letters"),
        f.posexplode(f.split("letters", ", ")).alias("pos", "val")
    )\
    .drop("val")\
    .select(
        "num",
        f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
        f.expr("letters[pos]").alias("val")
    )\
    .show()
#+---+-------+---+
#|num|   name|val|
#+---+-------+---+
#|  1|letter0|  A|
#|  1|letter1|  B|
#|  1|letter2|  C|
#|  1|letter3|  D|
#|  2|letter0|  E|
#|  2|letter1|  F|
#|  2|letter2|  G|
#|  3|letter0|  H|
#|  3|letter1|  I|
#|  4|letter0|  J|
#+---+-------+---+

Now we can just groupBy the num and pivot the DataFrame. Putting that all together, we get:

df.select(
        "num",
        f.split("letters", ", ").alias("letters"),
        f.posexplode(f.split("letters", ", ")).alias("pos", "val")
    )\
    .drop("val")\
    .select(
        "num",
        f.concat(f.lit("letter"),f.col("pos").cast("string")).alias("name"),
        f.expr("letters[pos]").alias("val")
    )\
    .groupBy("num").pivot("name").agg(f.first("val"))\
    .show()
#+---+-------+-------+-------+-------+
#|num|letter0|letter1|letter2|letter3|
#+---+-------+-------+-------+-------+
#|  1|      A|      B|      C|      D|
#|  3|      H|      I|   null|   null|
#|  2|      E|      F|      G|   null|
#|  4|      J|   null|   null|   null|
#+---+-------+-------+-------+-------+
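
A side note on performance (a sketch, not something benchmarked here): pivot() without an explicit list of values makes Spark run an extra job to collect the distinct names first, so if you know, or can cheaply compute, the maximum array length, passing the values explicitly may help. The variant below also uses the val column from posexplode directly instead of the letters[pos] expression; max_parts is an assumed upper bound.

import pyspark.sql.functions as f

max_parts = 4  # assumption: no array is longer than 4 elements
pivot_values = ["letter{}".format(i) for i in range(max_parts)]

df.select(
        "num",
        f.posexplode(f.split("letters", ", ")).alias("pos", "val")
    )\
    .select(
        "num",
        f.concat(f.lit("letter"), f.col("pos").cast("string")).alias("name"),
        "val"
    )\
    .groupBy("num").pivot("name", pivot_values).agg(f.first("val"))\
    .show()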
  • FYI I tried this with 3909 elements to split on ~1.7M original rows and it was too slow / not completing after an hour Commented Jan 25, 2022 at 23:02

Here's another approach, in case you want to split a string on a delimiter.

import pyspark.sql.functions as f

df = spark.createDataFrame([("1:a:2001",),("2:b:2002",),("3:c:2003",)],["value"])
df.show()
+--------+
|   value|
+--------+
|1:a:2001|
|2:b:2002|
|3:c:2003|
+--------+

df_split = df.select(f.split(df.value,":")).rdd.flatMap(
              lambda x: x).toDF(schema=["col1","col2","col3"])

df_split.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   a|2001|
|   2|   b|2002|
|   3|   c|2003|
+----+----+----+

I don't think this transition back and forth to RDDs is going to slow you down... Also, don't worry about the schema specification at the end: it's optional. You can omit it to generalize the solution to data with an unknown number of columns.
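
For illustration, a sketch of the schema-less variant (assuming every row splits into the same number of parts; the column names then default to _1, _2, and so on):

df_split = df.select(f.split(df.value,":")).rdd.flatMap(
              lambda x: x).toDF()
df_split.show()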

  • How can I do this in Scala? I get stuck with the flatMap lambda function Commented Mar 28, 2021 at 4:31
  • Pay attention: the pattern is given as a regular expression, hence you need to escape special characters with \ Commented Jun 4, 2021 at 13:21
  • If you don't want to refer back to df inside your expression, you can pass the name of the column to split, i.e. df.select(f.split("value",":"))... Commented Jun 28, 2021 at 9:00
  • @moshebeeri You saved me!
    – diman82
    Commented Jul 11, 2021 at 19:26
  • What if there were more than one column ("value")? How would flatMap behave in this case? Commented Aug 18, 2022 at 12:07

Instead of Column.getItem(i) we can use Column[i].
Also, enumerate is useful when there are many new columns to create.

from pyspark.sql import functions as F

  • Keep parent column:

    for i, c in enumerate(['new_1', 'new_2']):
        df = df.withColumn(c, F.split('my_str_col', '-')[i])
    

    or

    new_cols = ['new_1', 'new_2']
    df = df.select('*', *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)])
    
  • Replace parent column:

    for i, c in enumerate(['new_1', 'new_2']):
        df = df.withColumn(c, F.split('my_str_col', '-')[i])
    df = df.drop('my_str_col')
    

    or

    new_cols = ['new_1', 'new_2']
    df = df.select(
        *[c for c in df.columns if c != 'my_str_col'],
        *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)]
    )
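
For reference, a minimal end-to-end sketch of the "replace parent column" variant above, using the question's sample data (the new_1/new_2 names are just placeholders):

from pyspark.sql import functions as F

df = spark.createDataFrame([(18, '856-yygrm'), (201, '777-psgdg')], ['col1', 'my_str_col'])

new_cols = ['new_1', 'new_2']
df = df.select(
    *[c for c in df.columns if c != 'my_str_col'],
    *[F.split('my_str_col', '-')[i].alias(c) for i, c in enumerate(new_cols)]
)
df.show()
# Expected: col1 plus new_1 (856 / 777) and new_2 (yygrm / psgdg), with my_str_col dropped.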
    

I understand your pain. Using split() can work, but it can also break on messier data.

Let's take your df and make a slight change to it:

df = spark.createDataFrame([('1:"a:3":2001',),('2:"b":2002',),('3:"c":2003',)],["value"]) 

df.show()

+------------+
|       value|
+------------+
|1:"a:3":2001|
|  2:"b":2002|
|  3:"c":2003|
+------------+

If you try to apply split() to this as outlined above:

from pyspark.sql.functions import split

df_split = df.select(split(df.value,":")).rdd.flatMap(
              lambda x: x).toDF(schema=["col1","col2","col3"]).show()

you will get

IllegalStateException: Input row doesn't have expected number of values required by the schema. 3 fields are required while 4 values are provided.

So, is there a more elegant way of addressing this? I was so happy to have it pointed out to me. pyspark.sql.functions.from_csv() is your friend.

Taking my above example df:

from pyspark.sql.functions import from_csv

# Define a column schema to apply with from_csv()
col_schema = ["col1 INTEGER","col2 STRING","col3 INTEGER"]
schema_str = ",".join(col_schema)

# define the separator because it isn't a ','
options = {'sep': ":"}

# create a df from the value column using schema and options
df_csv = df.select(from_csv(df.value, schema_str, options).alias("value_parsed"))
df_csv.show()

+--------------+
|  value_parsed|
+--------------+
|[1, a:3, 2001]|
|  [2, b, 2002]|
|  [3, c, 2003]|
+--------------+

Then we can easily flatten the df to put the values in columns:

df2 = df_csv.select("value_parsed.*").toDF("col1","col2","col3")
df2.show()

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1| a:3|2001|
|   2|   b|2002|
|   3|   c|2003|
+----+----+----+

No breaks. Data correctly parsed. Life is good. Have a beer.

  • Using this regex in the split() method should also do the trick: [:](?=(?:[^"]*"[^"]*")*[^"]*$) (see the sketch below)
    – Mohana B C
    Commented Jun 17, 2022 at 3:20
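
A minimal sketch of that regex approach against the same example df (note that, unlike from_csv(), it keeps the surrounding double quotes in the middle field):

from pyspark.sql.functions import split

# Split on ':' only when it is followed by an even number of double quotes,
# i.e. only when the ':' sits outside a quoted section.
pattern = r'[:](?=(?:[^"]*"[^"]*")*[^"]*$)'
df.select(split(df.value, pattern).alias("parts")).show(truncate=False)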
