Skip to content

Commit 4bb422d

Browse files
BUG: Preserve extension dtypes in MultiIndex during concat (pandas-dev#58421)
1 parent 56847c5 commit 4bb422d

File tree

3 files changed

+90
-5
lines changed

3 files changed

+90
-5
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ MultiIndex
712712
- :func:`MultiIndex.get_level_values` accessing a :class:`DatetimeIndex` does not carry the frequency attribute along (:issue:`58327`, :issue:`57949`)
713713
- Bug in :class:`DataFrame` arithmetic operations in case of unaligned MultiIndex columns (:issue:`60498`)
714714
- Bug in :class:`DataFrame` arithmetic operations with :class:`Series` in case of unaligned MultiIndex (:issue:`61009`)
715+
- Fixed a bug where extension dtypes like ``timestamp[pyarrow]`` were not preserved when building ``MultiIndex`` levels during ``pd.concat`` operations. (:issue:`58421`)
715716
-
716717

717718
I/O

pandas/core/reshape/concat.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from pandas.core.dtypes.common import (
2424
is_bool,
25+
is_extension_array_dtype,
2526
is_scalar,
2627
)
2728
from pandas.core.dtypes.concat import concat_compat
@@ -36,6 +37,7 @@
3637
factorize_from_iterables,
3738
)
3839
import pandas.core.common as com
40+
from pandas.core.construction import array
3941
from pandas.core.indexes.api import (
4042
Index,
4143
MultiIndex,
@@ -819,7 +821,20 @@ def _get_sample_object(
819821

820822

821823
def _concat_indexes(indexes) -> Index:
822-
return indexes[0].append(indexes[1:])
824+
# try to preserve extension types such as timestamp[pyarrow]
825+
values = []
826+
for idx in indexes:
827+
values.extend(idx._values if hasattr(idx, "_values") else idx)
828+
829+
# use the first index as a sample to infer the desired dtype
830+
sample = indexes[0]
831+
try:
832+
# this helps preserve extension types like timestamp[pyarrow]
833+
arr = array(values, dtype=sample.dtype)
834+
except Exception:
835+
arr = array(values) # fallback
836+
837+
return Index(arr)
823838

824839

825840
def validate_unique_levels(levels: list[Index]) -> None:
@@ -876,14 +891,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
876891

877892
concat_index = _concat_indexes(indexes)
878893

879-
# these go at the end
880894
if isinstance(concat_index, MultiIndex):
881895
levels.extend(concat_index.levels)
882896
codes_list.extend(concat_index.codes)
883897
else:
884-
codes, categories = factorize_from_iterable(concat_index)
885-
levels.append(categories)
886-
codes_list.append(codes)
898+
# handle the case where the resulting index is a flat Index
899+
# but contains tuples (i.e., a collapsed MultiIndex)
900+
if isinstance(concat_index[0], tuple):
901+
# retrieve the original dtypes
902+
original_dtypes = [lvl.dtype for lvl in indexes[0].levels]
903+
904+
unzipped = list(zip(*concat_index))
905+
for i, level_values in enumerate(unzipped):
906+
# reconstruct each level using original dtype
907+
arr = array(level_values, dtype=original_dtypes[i])
908+
level_codes, _ = factorize_from_iterable(arr)
909+
levels.append(ensure_index(arr))
910+
codes_list.append(level_codes)
911+
else:
912+
# simple indexes factorize directly
913+
codes, categories = factorize_from_iterable(concat_index)
914+
values = getattr(concat_index, "_values", concat_index)
915+
if is_extension_array_dtype(values):
916+
levels.append(values)
917+
else:
918+
levels.append(categories)
919+
codes_list.append(codes)
887920

888921
if len(names) == len(levels):
889922
names = list(names)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
3+
import pandas as pd
4+
5+
schema = {
6+
"id": "int64[pyarrow]",
7+
"time": "timestamp[s][pyarrow]",
8+
"value": "float[pyarrow]",
9+
}
10+
11+
12+
@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
13+
def test_concat_preserves_pyarrow_timestamp(dtype):
14+
dfA = (
15+
pd.DataFrame(
16+
[
17+
(0, "2021-01-01 00:00:00", 5.3),
18+
(1, "2021-01-01 00:01:00", 5.4),
19+
(2, "2021-01-01 00:01:00", 5.4),
20+
(3, "2021-01-01 00:02:00", 5.5),
21+
],
22+
columns=schema,
23+
)
24+
.astype(schema)
25+
.set_index(["id", "time"])
26+
)
27+
28+
dfB = (
29+
pd.DataFrame(
30+
[
31+
(1, "2022-01-01 08:00:00", 6.3),
32+
(2, "2022-01-01 08:01:00", 6.4),
33+
(3, "2022-01-01 08:02:00", 6.5),
34+
],
35+
columns=schema,
36+
)
37+
.astype(schema)
38+
.set_index(["id", "time"])
39+
)
40+
41+
df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])
42+
43+
# check whether df.index is multiIndex
44+
assert isinstance(df.index, pd.MultiIndex), (
45+
f"Expected MultiIndex, but received {type(df.index)}"
46+
)
47+
48+
# Verifying special dtype timestamp[s][pyarrow] stays intact after concat
49+
assert df.index.levels[2].dtype == dtype, (
50+
f"Expected {dtype}, but received {df.index.levels[2].dtype}"
51+
)

0 commit comments

Comments
 (0)