
Commit e8eabea

Fix pandas-dev#58421: Index[timestamp[pyarrow]].union with itself returns object dtype

1 parent 56847c5 · commit e8eabea
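A minimal repro of the behavior the title describes (assumes pyarrow is installed; the construction below is illustrative rather than copied from the issue):

```python
import pandas as pd

# Build an Index backed by an Arrow timestamp dtype.
idx = pd.Index(pd.to_datetime(["2021-01-01", "2021-01-02"])).astype(
    "timestamp[s][pyarrow]"
)

# Before this fix, unioning the index with itself could fall back to
# object dtype; the expected result keeps timestamp[s][pyarrow].
print(idx.union(idx).dtype)
```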

File tree

2 files changed: +86 -5 lines

pandas/core/reshape/concat.py (+40 -5)
```diff
@@ -47,6 +47,10 @@
 )
 from pandas.core.internals import concatenate_managers

+from pandas.core.dtypes.common import is_extension_array_dtype
+
+from pandas.core.construction import array
+
 if TYPE_CHECKING:
     from collections.abc import (
         Callable,
```
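These two imports are pandas internals; their public mirrors show what they do. A quick sketch using `pandas.api.types.is_extension_array_dtype` and `pd.array` (the public counterpart of `pandas.core.construction.array`), assuming pyarrow is installed:

```python
import pandas as pd
from pandas.api.types import is_extension_array_dtype  # public mirror of the internal helper

# True for extension-backed data such as Arrow dtypes...
print(is_extension_array_dtype(pd.array(["a", "b"], dtype="string[pyarrow]")))  # True
# ...False for a plain numpy-backed Index.
print(is_extension_array_dtype(pd.Index([1, 2, 3])))  # False
```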
```diff
@@ -819,7 +823,20 @@ def _get_sample_object(


 def _concat_indexes(indexes) -> Index:
-    return indexes[0].append(indexes[1:])
+    # try to preserve extension types such as timestamp[pyarrow]
+    values = []
+    for idx in indexes:
+        values.extend(idx._values if hasattr(idx, "_values") else idx)
+
+    # use the first index as a sample to infer the desired dtype
+    sample = indexes[0]
+    try:
+        # this helps preserve extension types like timestamp[pyarrow]
+        arr = array(values, dtype=sample.dtype)
+    except Exception:
+        arr = array(values)  # fallback
+
+    return Index(arr)


 def validate_unique_levels(levels: list[Index]) -> None:
```
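The idea behind the new `_concat_indexes`: flatten all index values, rebuild them with the first index's dtype so extension types survive, and fall back to plain inference if that raises. A sketch of the same pattern with the public `pd.array` (assumes pyarrow is installed; `_values` is the private attribute the patch itself uses):

```python
import pandas as pd

dtype = "timestamp[s][pyarrow]"
a = pd.Index(pd.to_datetime(["2021-01-01"])).astype(dtype)
b = pd.Index(pd.to_datetime(["2022-01-01"])).astype(dtype)

# Flatten both indexes and reconstruct with the first index's dtype,
# mirroring what the patched _concat_indexes does.
values = list(a._values) + list(b._values)
arr = pd.array(values, dtype=a.dtype)
print(pd.Index(arr).dtype)  # timestamp[s][pyarrow]
```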
```diff
@@ -876,14 +893,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiIndex:

     concat_index = _concat_indexes(indexes)

-    # these go at the end
     if isinstance(concat_index, MultiIndex):
         levels.extend(concat_index.levels)
         codes_list.extend(concat_index.codes)
     else:
-        codes, categories = factorize_from_iterable(concat_index)
-        levels.append(categories)
-        codes_list.append(codes)
+        # handle the case where the resulting index is a flat Index
+        # but contains tuples (i.e., a collapsed MultiIndex)
+        if isinstance(concat_index[0], tuple):
+            # retrieve the original dtypes
+            original_dtypes = [lvl.dtype for lvl in indexes[0].levels]
+
+            unzipped = list(zip(*concat_index))
+            for i, level_values in enumerate(unzipped):
+                # reconstruct each level using original dtype
+                arr = array(level_values, dtype=original_dtypes[i])
+                level_codes, _ = factorize_from_iterable(arr)
+                levels.append(ensure_index(arr))
+                codes_list.append(level_codes)
+        else:
+            # simple indexes factorize directly
+            codes, categories = factorize_from_iterable(concat_index)
+            values = getattr(concat_index, "_values", concat_index)
+            if is_extension_array_dtype(values):
+                levels.append(values)
+            else:
+                levels.append(categories)
+            codes_list.append(codes)

     if len(names) == len(levels):
         names = list(names)
```
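The tuple branch covers a MultiIndex that collapsed into a flat Index of tuples: `zip(*...)` transposes the tuples into per-level columns, each rebuilt with its original dtype and factorized into codes. Illustrated below with the public `pd.factorize` standing in for the internal `factorize_from_iterable`:

```python
import pandas as pd

flat = [(0, "2021-01-01"), (1, "2021-01-01"), (1, "2021-01-02")]  # collapsed tuples
unzipped = list(zip(*flat))  # [(0, 1, 1), ('2021-01-01', '2021-01-01', '2021-01-02')]

# Each transposed column becomes one level: codes point into the uniques.
codes, uniques = pd.factorize(list(unzipped[0]))
print(codes)    # [0 1 1]
print(uniques)  # [0 1]
```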
New test file (+46 -0)

```diff
@@ -0,0 +1,46 @@
+import pandas as pd
+import pytest
+
+
+schema = {
+    "id": "int64[pyarrow]",
+    "time": "timestamp[s][pyarrow]",
+    "value": "float[pyarrow]",
+}
+
+@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
+def test_concat_preserves_pyarrow_timestamp(dtype):
+    dfA = (
+        pd.DataFrame(
+            [
+                (0, "2021-01-01 00:00:00", 5.3),
+                (1, "2021-01-01 00:01:00", 5.4),
+                (2, "2021-01-01 00:01:00", 5.4),
+                (3, "2021-01-01 00:02:00", 5.5),
+            ],
+            columns=schema,
+        )
+        .astype(schema)
+        .set_index(["id", "time"])
+    )
+
+    dfB = (
+        pd.DataFrame(
+            [
+                (1, "2022-01-01 08:00:00", 6.3),
+                (2, "2022-01-01 08:01:00", 6.4),
+                (3, "2022-01-01 08:02:00", 6.5),
+            ],
+            columns=schema,
+        )
+        .astype(schema)
+        .set_index(["id", "time"])
+    )
+
+    df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])
+
+    # check whether df.index is a MultiIndex
+    assert isinstance(df.index, pd.MultiIndex), f"Expected MultiIndex, but received {type(df.index)}"
+
+    # verify the timestamp[s][pyarrow] dtype stays intact after concat
+    assert df.index.levels[2].dtype == dtype, f"Expected {dtype}, but received {df.index.levels[2].dtype}"
```
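For reference, the `levels[2]` lookup works because `pd.concat(..., keys=..., names=...)` prepends the keys as a new outer level, shifting `id` and `time` to positions 1 and 2. A minimal sketch of that index shape, with hypothetical toy frames:

```python
import pandas as pd

dfA = pd.DataFrame({"value": [5.3]}, index=pd.Index([0], name="id"))
dfB = pd.DataFrame({"value": [6.3]}, index=pd.Index([1], name="id"))

out = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])
print(out.index.names)  # ['run', 'id'] -- the keys become the new outer level
```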
