|
22 | 22 |
|
23 | 23 | from pandas.core.dtypes.common import (
|
24 | 24 | is_bool,
|
| 25 | + is_extension_array_dtype, |
25 | 26 | is_scalar,
|
26 | 27 | )
|
27 | 28 | from pandas.core.dtypes.concat import concat_compat
|
|
36 | 37 | factorize_from_iterables,
|
37 | 38 | )
|
38 | 39 | import pandas.core.common as com
|
| 40 | +from pandas.core.construction import array |
39 | 41 | from pandas.core.indexes.api import (
|
40 | 42 | Index,
|
41 | 43 | MultiIndex,
|
@@ -819,7 +821,20 @@ def _get_sample_object(
|
819 | 821 |
|
820 | 822 |
|
821 | 823 | def _concat_indexes(indexes) -> Index:
|
822 |
| - return indexes[0].append(indexes[1:]) |
| 824 | + # try to preserve extension types such as timestamp[pyarrow] |
| 825 | + values = [] |
| 826 | + for idx in indexes: |
| 827 | + values.extend(idx._values if hasattr(idx, "_values") else idx) |
| 828 | + |
| 829 | + # use the first index as a sample to infer the desired dtype |
| 830 | + sample = indexes[0] |
| 831 | + try: |
| 832 | + # this helps preserve extension types like timestamp[pyarrow] |
| 833 | + arr = array(values, dtype=sample.dtype) |
| 834 | + except Exception: |
| 835 | + arr = array(values) # fallback |
| 836 | + |
| 837 | + return Index(arr) |
823 | 838 |
|
824 | 839 |
|
825 | 840 | def validate_unique_levels(levels: list[Index]) -> None:
|
@@ -876,14 +891,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
|
876 | 891 |
|
877 | 892 | concat_index = _concat_indexes(indexes)
|
878 | 893 |
|
879 |
| - # these go at the end |
880 | 894 | if isinstance(concat_index, MultiIndex):
|
881 | 895 | levels.extend(concat_index.levels)
|
882 | 896 | codes_list.extend(concat_index.codes)
|
883 | 897 | else:
|
884 |
| - codes, categories = factorize_from_iterable(concat_index) |
885 |
| - levels.append(categories) |
886 |
| - codes_list.append(codes) |
| 898 | + # handle the case where the resulting index is a flat Index |
| 899 | + # but contains tuples (i.e., a collapsed MultiIndex) |
| 900 | + if isinstance(concat_index[0], tuple): |
| 901 | + # retrieve the original dtypes |
| 902 | + original_dtypes = [lvl.dtype for lvl in indexes[0].levels] |
| 903 | + |
| 904 | + unzipped = list(zip(*concat_index)) |
| 905 | + for i, level_values in enumerate(unzipped): |
| 906 | + # reconstruct each level using original dtype |
| 907 | + arr = array(level_values, dtype=original_dtypes[i]) |
| 908 | + level_codes, _ = factorize_from_iterable(arr) |
| 909 | + levels.append(ensure_index(arr)) |
| 910 | + codes_list.append(level_codes) |
| 911 | + else: |
| 912 | + # simple indexes factorize directly |
| 913 | + codes, categories = factorize_from_iterable(concat_index) |
| 914 | + values = getattr(concat_index, "_values", concat_index) |
| 915 | + if is_extension_array_dtype(values): |
| 916 | + levels.append(values) |
| 917 | + else: |
| 918 | + levels.append(categories) |
| 919 | + codes_list.append(codes) |
887 | 920 |
|
888 | 921 | if len(names) == len(levels):
|
889 | 922 | names = list(names)
|
|
0 commit comments