I wrote a Polars plugin that does this: https://github.com/tylerriccio33/pl-horizontal/tree/v0.1.4
```python
import polars as pl
from pl_horizontal import arg_max_horizontal

# Return the column name of the maximum value per row
df = pl.DataFrame({
    "a": [1, 2, None],
    "b": [3, None, 1],
    "c": [2, 1, 4],
})
res = df.select(arg_max_horizontal(pl.all(), return_colname=True))
print(res)
assert res.to_series().to_list() == ["b", "a", "c"]
```
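To make the per-row semantics of that assertion explicit, here is a plain-Python reference sketch. It is not the plugin's code: `arg_max_colname_row` is a hypothetical helper. The null handling matches the example, but the tie-breaking (leftmost column wins) and the all-null result (`None`) are assumptions, not documented plugin behaviour.

```python
from typing import Optional, Sequence


def arg_max_colname_row(
    names: Sequence[str], row: Sequence[Optional[float]]
) -> Optional[str]:
    """Return the name of the largest non-null value in `row`.

    Assumptions (not confirmed for the plugin): ties go to the leftmost
    column, and an all-null row returns None.
    """
    best_name: Optional[str] = None
    best_val: Optional[float] = None
    for name, val in zip(names, row):
        if val is None:
            continue  # nulls never win, as in the example above
        if best_val is None or val > best_val:  # strict '>' keeps the leftmost tie
            best_name, best_val = name, val
    return best_name


# Same data as the example above, read row by row
names = ["a", "b", "c"]
rows = [(1, 3, 2), (2, None, 1), (None, 1, 4)]
print([arg_max_colname_row(names, r) for r in rows])  # ['b', 'a', 'c']
```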
Benchmarking it against the list concat method gives the following results:
[Benchmark chart: arg_max_horizontal vs. the list concat method]
It also has the added benefit of working nicely with LazyFrames, where the column names aren't known beforehand.
EDIT: I just added your example to my test suite :)
```python
import polars as pl
from pl_horizontal import arg_max_horizontal


def test_arg_max_so() -> None:
    # https://stackoverflow.com/questions/77967334/getting-min-max-column-name-in-polars/
    df = pl.DataFrame(
        {
            "a": [1, 8, 3],
            "b": [4, 5, None],
        }
    )
    res = df.with_columns(max=arg_max_horizontal(pl.all(), return_colname=True))
    assert res["max"].to_list() == ["b", "a", "a"]
```