Skip to content

Sqlite string selection #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/graphnet/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def __init__(
loss_weight_default_value: Optional[float] = None,
seed: Optional[int] = None,
labels: Optional[Dict[str, Any]] = None,
use_super_selection: bool = False,
):
"""Construct Dataset.

Expand Down Expand Up @@ -311,6 +312,10 @@ def __init__(
NOTE: DEPRECATED Use `data_representation` instead.
# DEPRECATION: REMOVE AT 2.0 LAUNCH
# See https://github.com/graphnet-team/graphnet/issues/647
use_super_selection: If True, the string selection is handled by
the query function of the dataset class, rather than
pd.DataFrame.query. Defaults to False and should
only be used with sqlite.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
Expand Down Expand Up @@ -354,6 +359,7 @@ def __init__(
self._data_representation = deepcopy(data_representation)
self._labels = labels
self._string_column = data_representation._detector.string_index_name
self._use_super_selection = use_super_selection

if node_truth is not None:
assert isinstance(node_truth_table, str)
Expand Down Expand Up @@ -404,6 +410,7 @@ def __init__(
self,
index_column=index_column,
seed=seed,
use_super_selection=self._use_super_selection,
)

if self._labels is not None:
Expand Down Expand Up @@ -677,10 +684,13 @@ def _create_graph(
"""
# Convert truth to dict
if len(truth.shape) == 1:
truth = truth.reshape(1, -1)
truth_dict = {
key: truth[:, index] for index, key in enumerate(self._truth)
}
truth_dict = {
key: truth[0][index] for index, key in enumerate(self._truth)
}
else:
truth_dict = {
key: truth[:, index] for index, key in enumerate(self._truth)
}

# Define custom labels
labels_dict = self._get_labels(truth_dict)
Expand Down
36 changes: 26 additions & 10 deletions src/graphnet/data/utilities/string_selection_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ def __init__(
index_column: str,
seed: Optional[int] = None,
use_cache: bool = True,
use_super_selection: bool = False,
):
"""Construct `StringSelectionResolver`."""
self._dataset = dataset
self._index_column = index_column
self._seed = seed
self._use_cache = use_cache

self._use_super_selection = use_super_selection
# Base class constructor
if self._use_super_selection:
self._use_cache = False
super().__init__(name=__name__, class_name=self.__class__.__name__)

# Public method(s)
Expand Down Expand Up @@ -214,19 +217,32 @@ def _query_selection_from_dataset(self, selection: str) -> pd.DataFrame:
df_values = self._load_values_cache(values_cache_path)

else:
df_values = pd.DataFrame(
data=self._dataset.query_table(
self._dataset.truth_table,
list(variables),
),
columns=list(variables),
)
if self._use_super_selection:
df_values = pd.DataFrame(
data=self._dataset.query_table(
self._dataset.truth_table,
list(variables),
selection=selection,
).tolist(),
columns=list(variables),
)

else:
df_values = pd.DataFrame(
data=self._dataset.query_table(
self._dataset.truth_table,
list(variables),
).tolist(),
columns=list(variables),
)

# (Opt.) Cache indices.
if self._use_cache and not os.path.exists(values_cache_path):
self._save_values_cache(df_values, values_cache_path)

df_selection = df_values.query(selection)
if not self._use_super_selection:
df_selection = df_values.query(selection)
else:
df_selection = df_values
return df_selection

def _get_random_state(self, selection: str) -> Optional[int]:
Expand Down
1 change: 1 addition & 0 deletions src/graphnet/utilities/config/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class DatasetConfig(BaseConfig):

data_representation: Any = None
labels: Optional[Dict[str, Any]] = None
use_super_selection: bool = False

def __init__(self, **data: Any) -> None:
"""Construct `DataConfig`.
Expand Down