I would like to add two columns (cat_a, cat_b) to the DataFrame df using the .assign() method, but I can't get the code to work.

import numpy as np
import pandas as pd
np.random.seed(999)
num = 10
df = pd.DataFrame({'id': np.random.choice(range(1000, 10000), num, replace=False),
                   'sex': np.random.choice(list('MF'), num, replace=True),
                   'year': np.random.randint(1980, 1990, num)})
print(df)

     id sex  year
0  3461   F  1983
1  8663   M  1988
2  6615   M  1986
3  5336   M  1982
4  3756   F  1984
5  8653   F  1989
6  9362   M  1985
7  3944   M  1981
8  3334   F  1986
9  6135   F  1988

These should be the values of the new columns cat_a and cat_b:

# cat_a
list(map(lambda y: 'A' if y <= 1985 else 'B', df.year))
['A', 'B', 'B', 'A', 'A', 'B', 'A', 'A', 'B', 'B']

# cat_b
list(map(lambda s, y: 1 if s == 'M' and y <= 1985 else (2 if s == 'M' else (3 if y < 1985 else 4)), df.sex, df.year))
[3, 2, 2, 1, 3, 4, 1, 1, 4, 4]

Trying the syntax of the .assign() method with constant values works (.assign() returns a new DataFrame, so the result is printed directly):

print(df.assign(cat_a = 'AB', cat_b = 1234))

     id sex  year cat_a  cat_b
0  3461   F  1983    AB   1234
1  8663   M  1988    AB   1234
2  6615   M  1986    AB   1234
3  5336   M  1982    AB   1234
4  3756   F  1984    AB   1234
5  8653   F  1989    AB   1234
6  9362   M  1985    AB   1234
7  3944   M  1981    AB   1234
8  3334   F  1986    AB   1234
9  6135   F  1988    AB   1234

Replacing the dummy values with lambdas gives an error:

df.assign(cat_a = lambda x: 'A' if x.year <= 1985 else 'B',
          cat_b = lambda x: 1 if x.sex == 'M' and x.year <= 1985 
                              else (2 if x.sex == 'M'
                                      else (3 if x.year < 1985
                                              else 4
                                           )
                                   )
         )

Any suggestions how to get the code working would be very welcome!
I have workarounds but I would like to get my results with the .assign() method.

1 Answer

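The lambdas fail because a callable passed to .assign() receives the whole DataFrame, so x.year <= 1985 is a boolean Series, and a Series cannot be used as the condition of a plain if/else (pandas raises ValueError: The truth value of a Series is ambiguous). Use a vectorized solution with numpy.where and numpy.select instead: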

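# Boolean masks, evaluated for the whole column at once
m1 = df.year <= 1985
m2 = df.sex == 'M'

# Two-way choice: 'A' where m1 is True, 'B' otherwise
a = np.where(m1, 'A', 'B')
# Multi-way choice: the first matching condition wins, otherwise the default
b = np.select([m1 & m2, ~m1 & m2, m1 & ~m2], [1,2,3], default=4)

df = df.assign(cat_a = a, cat_b = b)
print(df)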
     id sex  year cat_a  cat_b
0  3461   F  1983     A      3
1  8663   M  1988     B      2
2  6615   M  1986     B      2
3  5336   M  1982     A      1
4  3756   F  1984     A      3
5  8653   F  1989     B      4
6  9362   M  1985     A      1
7  3944   M  1981     A      1
8  3334   F  1986     B      4
9  6135   F  1988     B      4
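
If you want to keep the logic inside .assign() itself, one option is to pass callables that return whole arrays instead of single values. The following is only a sketch under a few assumptions: the four-row frame is typed in by hand from the first rows of the printed table so the block runs on its own, and the names out and error_message are mine, not from the question.

```python
import numpy as np
import pandas as pd

# First four rows of the example frame, typed in by hand.
df = pd.DataFrame({'id': [3461, 8663, 6615, 5336],
                   'sex': ['F', 'M', 'M', 'M'],
                   'year': [1983, 1988, 1986, 1982]})

# A callable passed to assign() receives the whole DataFrame, so
# x.year <= 1985 is a boolean Series; a plain `if` on it raises ValueError.
error_message = ''
try:
    df.assign(cat_a=lambda x: 'A' if x.year <= 1985 else 'B')
except ValueError as err:
    error_message = str(err)

# Returning an array from the callable makes the choice element-wise.
out = df.assign(
    cat_a=lambda x: np.where(x.year <= 1985, 'A', 'B'),
    cat_b=lambda x: np.select(
        [(x.year <= 1985) & (x.sex == 'M'),
         (x.year > 1985) & (x.sex == 'M'),
         (x.year <= 1985) & (x.sex != 'M')],
        [1, 2, 3], default=4),
)
print(out)
```

The masks are recomputed inside each lambda here, which costs a little compared with building m1 and m2 once, but it keeps everything in a single .assign() call and also works in the middle of a method chain.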

Verify against the map-based logic from the question:

a = list(map(lambda y: 'A' if y <= 1985 else 'B', df.year))
b = list(map(lambda s, y: 1 if s == 'M' and y <= 1985 else (2 if s == 'M' else (3 if y < 1985 else 4)), df.sex, df.year))

df = df.assign(cat_a = a, cat_b = b)
print (df)
     id sex  year cat_a  cat_b
0  3461   F  1983     A      3
1  8663   M  1988     B      2
2  6615   M  1986     B      2
3  5336   M  1982     A      1
4  3756   F  1984     A      3
5  8653   F  1989     B      4
6  9362   M  1985     A      1
7  3944   M  1981     A      1
8  3334   F  1986     B      4
9  6135   F  1988     B      4
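
The comparison can also be done programmatically rather than by eye. This is a sketch: the sex and year columns are typed in from the printed table, the helper names cat_b_map and cat_b_select are mine, and the one-row edge frame is hypothetical data that is not in the question. It suggests the two versions agree on the sample but can differ for a non-M row with year exactly 1985, because the map version tests y < 1985 there while the np.select version reuses year <= 1985.

```python
import numpy as np
import pandas as pd

# sex and year columns copied from the printed example table.
df = pd.DataFrame({
    'sex': list('FMMMFFMMFF'),
    'year': [1983, 1988, 1986, 1982, 1984, 1989, 1985, 1981, 1986, 1988],
})

def cat_b_map(frame):
    # The nested conditional from the question, applied row by row.
    return list(map(
        lambda s, y: 1 if s == 'M' and y <= 1985
        else (2 if s == 'M' else (3 if y < 1985 else 4)),
        frame.sex, frame.year))

def cat_b_select(frame):
    # The vectorized version from this answer.
    m1 = frame.year <= 1985
    m2 = frame.sex == 'M'
    return np.select([m1 & m2, ~m1 & m2, m1 & ~m2], [1, 2, 3], default=4).tolist()

# True when both versions give the same labels on the sample data.
same_on_sample = cat_b_map(df) == cat_b_select(df)
print(same_on_sample)

# Hypothetical boundary row: not 'M', year exactly 1985.
edge = pd.DataFrame({'sex': ['F'], 'year': [1985]})
print(cat_b_map(edge), cat_b_select(edge))
```

If the strict comparison in the question is intended, the third mask would need to be (frame.year < 1985) & ~m2 instead of m1 & ~m2.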

Performance is interesting: for small DataFrames (up to about 1k rows) the mapping is faster, while for larger DataFrames the numpy solution is the better choice:

[Plot: perfplot timings of mapping and vec against len(df), log-log axes]

import numpy as np
import pandas as pd
import perfplot

np.random.seed(999)

def mapping(df):
    a = list(map(lambda y: 'A' if y <= 1985 else 'B', df.year))
    b = list(map(lambda s, y: 1 if s == 'M' and y <= 1985 else (2 if s == 'M' else (3 if y < 1985 else 4)), df.sex, df.year))

    return df.assign(cat_a = a, cat_b = b)

def vec(df):
    m1 = df.year <= 1985
    m2 = df.sex == 'M'
    a = np.where(m1, 'A', 'B')
    b = np.select([m1 & m2, ~m1 & m2, m1 & ~m2], [1,2,3], default=4)
    return df.assign(cat_a = a, cat_b = b)

def make_df(n):
    df = pd.DataFrame({'id': np.random.choice(range(10, 1000000), n, replace=False),
                   'sex': np.random.choice(list('MF'), n, replace=True),
                   'year': np.random.randint(1980, 1990, n)})
    return df

perfplot.show(
    setup=make_df,
    kernels=[mapping, vec],
    n_range=[2**k for k in range(2, 18)],
    logx=True,
    logy=True,
    equality_check=False,  # rows may appear in different order
    xlabel='len(df)')
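
If perfplot is not installed, a rough comparison is possible with the standard-library timeit module. This is only a sketch: it drops the id column, uses np.random.default_rng instead of the seeded global state, and checks just two sizes, all of which are my simplifications. The numbers depend on the machine and library versions, so no expected timings are shown.

```python
import timeit

import numpy as np
import pandas as pd

def mapping(df):
    a = list(map(lambda y: 'A' if y <= 1985 else 'B', df.year))
    b = list(map(lambda s, y: 1 if s == 'M' and y <= 1985
                 else (2 if s == 'M' else (3 if y < 1985 else 4)),
                 df.sex, df.year))
    return df.assign(cat_a=a, cat_b=b)

def vec(df):
    m1 = df.year <= 1985
    m2 = df.sex == 'M'
    a = np.where(m1, 'A', 'B')
    b = np.select([m1 & m2, ~m1 & m2, m1 & ~m2], [1, 2, 3], default=4)
    return df.assign(cat_a=a, cat_b=b)

rng = np.random.default_rng(999)

def make_df(n):
    # Random test frame with only the two columns the functions read.
    return pd.DataFrame({'sex': rng.choice(list('MF'), n),
                         'year': rng.integers(1980, 1990, n)})

timings = {}
for n in (100, 10_000):
    frame = make_df(n)
    for func in (mapping, vec):
        # Best of 3 repeats, 5 calls each, reported as seconds per call.
        best = min(timeit.repeat(lambda: func(frame), number=5, repeat=3))
        timings[(func.__name__, n)] = best / 5
        print(f'{func.__name__:8s} n={n:6d}  {best / 5:.6f} s per call')
```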