Is there a better way to return only the elements (pl.element()) of a Polars list column that match an item in a given list?

While it works, I get the following warning, which leads me to believe there's probably a more concise/better way: The predicate 'col("").is_in([Series])' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the group_by operation would. This behavior is experimental and may be subject to change

import polars as pl

terms = ['a', 'z']

(pl.LazyFrame({'a':['x y z']})
   .select(pl.col('a')
             .str.split(' ')
             .list.eval(pl.when(pl.element().is_in(terms))
                          .then(pl.element())
                          .otherwise(None))
             .list.drop_nulls()
             .list.join(' ')
           )
   .fetch()
 )

For posterity's sake, it replaces my previous attempt using .map_elements():

import polars as pl
import re

terms = ['a', 'z']

(pl.LazyFrame({'a':['x y z']})
   .select(pl.col('a')
             # no .str.split() here: re.findall below needs the whole string, not a list
             .map_elements(lambda x: ' '.join(list(set(re.findall('|'.join(terms), x)))),
                           return_dtype = pl.Utf8)
           )
   .fetch()
 )


Answer by Thomas (accepted):

@jqurious and @Dean MacGregor were exactly right; I just wanted to post a solution that explains the differences succinctly:

import polars as pl

terms = ['a', 'z']

(pl.LazyFrame({'a':['x a y zebra']})
   .with_columns(only_whole_terms = pl.col('a')
                                      .str.split(' ')
                                      .list.set_intersection(terms),
                 each_term = pl.col('a').str.extract_all('|'.join(terms)),
                )
   .fetch()
)

shape: (1, 3)
┌─────────────┬──────────────────┬─────────────────┐
│ a           ┆ only_whole_terms ┆ each_term       │
│ ---         ┆ ---              ┆ ---             │
│ str         ┆ list[str]        ┆ list[str]       │
╞═════════════╪══════════════════╪═════════════════╡
│ x a y zebra ┆ ["a"]            ┆ ["a", "z", "a"] │
└─────────────┴──────────────────┴─────────────────┘

Also, this closely related question adds a bit more.

Answer by Dean MacGregor:

In addition to the tricks that @jqurious listed in the comments, you could also do a regex extract. This started simple but got a bit clunky as I tried different things. The good thing about the Rust regex engine is that it is very fast; the bad thing is that it doesn't support look-arounds, and working around that makes the expression clunky.

Without look-arounds, to avoid taking the z from zebra, I had to include the space before and after each term in the match. There is no space before the first word or after the last word, so I concatenate a space onto each end of the column. Additionally, because each match consumes its trailing space, two consecutive terms (e.g. "z z") would otherwise share one space and the second would be missed. So I replace every single space with a double space, and the double spaces are turned back into single spaces after the extract step.

import polars as pl

terms = ['a', 'z', 'x']
termsre = "(" + "|".join([f" {x} " for x in terms]) + ")"
(pl.LazyFrame({'a':['x y z z zebra a', 'x y z', 'a b c']})
 .with_columns(
     b = (pl.lit(" ") + pl.col('a')
       .str.replace_all(" ", "  ") + pl.lit(" "))
       .str.extract_all(termsre)
       .list.join('')
       .str.replace_all("  "," ")
       .str.strip_chars()
 )
 .collect()
)
shape: (3, 2)
┌─────────────────┬─────────┐
│ a               ┆ b       │
│ ---             ┆ ---     │
│ str             ┆ str     │
╞═════════════════╪═════════╡
│ x y z z zebra a ┆ x z z a │
│ x y z           ┆ x z     │
│ a b c           ┆ a       │
└─────────────────┴─────────┘
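The need for the space doubling can be reproduced outside Polars with Python's re module, which, like the Rust engine, returns non-overlapping leftmost matches for this pattern. This is only an illustration of the padding trick under that assumption; extract_terms and double_spaces are hypothetical names, not part of the answer.

```python
import re

terms = ['a', 'z', 'x']
# same shape as termsre in the answer: "( a | z | x )"; re.escape guards non-letter terms
pattern = "(" + "|".join(f" {re.escape(t)} " for t in terms) + ")"

def extract_terms(text: str, double_spaces: bool) -> str:
    """Hypothetical helper mirroring the pad / extract / join / collapse steps."""
    body = text.replace(" ", "  ") if double_spaces else text
    padded = " " + body + " "
    matches = re.findall(pattern, padded)  # each match consumes its trailing space
    return "".join(matches).replace("  ", " ").strip()

print(extract_terms("x y z z zebra a", double_spaces=False))  # second "z" is missed
print(extract_terms("x y z z zebra a", double_spaces=True))
```

With single spaces, the first " z " match consumes the space that the next " z " would need, so only one z is found.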

Side note: fetch is meant for debugging on a limited number of rows; you generally want to use collect.