# Copyright 2022 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Groups splitting methods.
Module contains `Splitter` core class and `split` method which are
intended to solve group splitting problems, primarily for A/B/.. tests.
Group splitting tasks usually include the following parameters: number of groups,
group sizes, and splitting algorithm.
Currently, group splitting problems can be solved using data provided
in the form of either pandas or Spark (with some restrictions) dataframes.
"""
from __future__ import annotations

import os
from typing import Optional, Union

import yaml

from ambrosia import types
from ambrosia.tools import log, type_checks
from ambrosia.tools.ab_abstract_component import ABMetaClass, ABToolAbstract

from .handlers import data_shape, handle_full_split, split_data_handler
# Bound applied to ``part_of_table`` in ``Splitter.run``: a share above this value
# is replaced by ``1 - part_of_table`` before the group size is computed from it.
SPLITTING_BOUND_CONST: float = 0.5
class Splitter(yaml.YAMLObject, ABToolAbstract, metaclass=ABMetaClass):
    """
    Unit for creating experimental groups from batch data.

    Split your data into groups of selected size with respect to:
        - Stratification columns
        - Metric distance of objects in feature space
        - Set of passed ids

    Parameters
    ----------
    dataframe : PassedDataType, optional
        Dataframe or string name of .csv table which
        contains data used for groups split.
    id_column : ColumnNameType, optional
        Name of id column which is used in hash split.
    groups_size : int, optional
        Size of the splitted groups.
    test_group_ids : IndicesType, optional
        Ids of objects which are in B(test) group.
        Used in tasks of post experiment A(control) group pick up.
    fit_columns : ColumnNamesType, optional
        List of columns names which values will be interpreted as
        coordinates of points in multidimensional space during metric split.
    strat_columns : ColumnNamesType, optional
        Columns for stratification.
        https://en.wikipedia.org/wiki/Stratified_sampling

    Attributes
    ----------
    dataframe : PassedDataType
        Pandas or Spark dataframe with split data.
    id_column : ColumnNameType
        Name of id column which is used in hash split.
    groups_size : int
        Split size of groups.
    test_group_ids : IndicesType
        Ids of objects which are in B(test) group.
    fit_columns : ColumnNamesType
        List of columns names used for metric split.
    strat_columns : ColumnNamesType
        Stratification columns names.

    Examples
    --------
    Our development team decided to add onboarding to the mobile app.
    Already knowing the required group size, we would like to select
    users for groups A and B respectively. Using the splitter class,
    this task could be done in the following way:

    >>> splitter = Splitter(dataframe=dataframe)
    >>> splitter.run(groups_size=1000, method='hash', salt='onboarding')

    Suppose now, we know that people of different ages and from several
    countries use our application, so we would like to take this into
    account during split. To do this, you might use stratification,
    which can be easily applied by passing only one additional parameter:

    >>> splitter = Splitter(dataframe=dataframe, strat_columns=['age', 'country'])
    >>> splitter.run(groups_size=1000, method='hash', salt='onboarding')

    If we have fixed users for the testing group,
    this can be specified as a parameter:

    >>> splitter = Splitter(dataframe=dataframe, strat_columns=['age', 'country'])
    >>> splitter.run(method='hash',
    ...              salt='onboarding',
    ...              test_group_ids=B_group_id
    ... )

    Notes
    -----
    Main methods for split:

        Simple:
            - Randomly chosen groups (via ``np.random.choice``).
        Hash:
            - Using hashing of identifiers and distribution
              by buckets, selects the desired buckets for groups formation.
        Metric:
            - For a fixed reference group or a randomly selected one,
              other groups are selected using the nearest neighbor method
              (for desired list of columns passed in ``fit_columns`` parameter).

    Constructors:

    >>> # Empty constructor
    >>> splitter = Splitter()
    >>> # Some data
    >>> splitter = Splitter(dataframe=df,
    ...                     id_column='my_id_column',
    ...                     strat_columns=['gender', 'age'],
    ...                     test_group_ids=ids_for_B_group
    ... )

    Setters:

    >>> splitter.set_dataframe(dataframe)
    >>> # You can pass string for pd.read_csv
    >>> splitter.set_dataframe('name_of_table.csv')
    >>> # Other setters
    >>> splitter.set_group_size(1000)
    >>> splitter.set_strat_columns(['age', 'region'])

    Run:

    >>> splitter.run(method='hash', groups_size=10000)
    >>> splitter.run(method='metric',
    ...              test_group_ids=b_group,
    ...              id_column='id',
    ...              strat_columns=['age', 'city'],
    ...              fit_columns=['metric_history_column', 'other_metric'],
    ...              method_metric='fast',  # It is used as kwarg
    ...              norm='l2'  # It is used as kwarg
    ... )

    Load from yaml config:

    >>> config = '''
    ... !splitter # <--- this is yaml tag (important!)
    ...     groups_size:
    ...         1000
    ...     id_column:
    ...         id
    ...     strat_columns:
    ...         - age
    ...         - country
    ... '''
    >>> splitter = yaml.load(config, Loader=yaml.Loader)
    >>> # Or use the implemented function
    >>> splitter = load_from_config(config)
    """

    yaml_tag = "!splitter"

    @type_checks.check_type_decorator(type_checks.check_type_dataframe)
    def set_dataframe(self, dataframe: Optional[types.PassedDataType]) -> None:
        """Store the dataframe (or .csv table name) used for the split."""
        self.__df = dataframe

    @type_checks.check_type_decorator(type_checks.check_type_id_column)
    def set_id_column(self, id_column: Optional[str]) -> None:
        """Store the name of the id column."""
        self.__id_column = id_column

    @type_checks.check_type_decorator(type_checks.check_type_group_size)
    def set_group_size(self, groups_size: Optional[int]) -> None:
        """Store the default size of the groups."""
        self.__groups_size = groups_size

    @type_checks.check_type_decorator(type_checks.check_type_test_group_ids)
    def set_test_group_ids(self, test_group_ids: types.IndicesType) -> None:
        """Store the ids of objects that belong to the B(test) group."""
        self.__test_group_ids = test_group_ids

    @type_checks.check_type_decorator(type_checks.check_type_fit_columns)
    def set_fit_columns(self, fit_columns: types.ColumnNamesType) -> None:
        """Store the column names used by the metric-based split methods."""
        self.__fit_columns = fit_columns

    @type_checks.check_type_decorator(type_checks.check_type_strat_columns)
    def set_strat_columns(self, strat_columns: types.ColumnNamesType) -> None:
        """Store the column names used for stratification."""
        self.__strat_columns = strat_columns

    def __init__(
        self,
        dataframe: Optional[types.PassedDataType] = None,
        id_column: Optional[types.ColumnNameType] = None,
        groups_size: Optional[int] = None,
        test_group_ids: Optional[types.IndicesType] = None,
        fit_columns: Optional[types.ColumnNamesType] = None,
        strat_columns: Optional[types.ColumnNamesType] = None,
    ):
        """
        Splitter class constructor to initialize the object.

        Every argument is routed through its setter, so the setter
        decorators are applied to constructor arguments as well.
        """
        self.set_dataframe(dataframe)
        self.set_id_column(id_column)
        self.set_group_size(groups_size)
        self.set_test_group_ids(test_group_ids)
        self.set_fit_columns(fit_columns)
        self.set_strat_columns(strat_columns)

    def __getstate__(self):
        """
        Get the state of the object to serialize.

        Only ``id_column``, ``groups_size``, ``fit_columns`` and
        ``strat_columns`` are included; the dataframe and the test group
        ids are not part of the returned state.
        """
        return {
            "id_column": self.__id_column,
            "groups_size": self.__groups_size,
            "fit_columns": self.__fit_columns,
            "strat_columns": self.__strat_columns,
        }

    @classmethod
    def from_yaml(cls, loader: yaml.Loader, node: yaml.Node):
        """
        Build a ``Splitter`` from a yaml mapping node tagged with ``yaml_tag``.

        The mapping keys are passed to the constructor as keyword arguments.
        """
        # deep=True makes PyYAML build nested nodes (e.g. the ``strat_columns``
        # list) right away. With the default shallow construction nested
        # collections are filled in only after this method returns, so the
        # setters would receive still-empty containers.
        kwargs = loader.construct_mapping(node, deep=True)
        return cls(**kwargs)

    def run(
        self,
        method: str,
        dataframe: Optional[types.PassedDataType] = None,
        id_column: Optional[types.ColumnNameType] = None,
        groups_size: Optional[int] = None,
        part_of_table: Optional[float] = None,
        groups_number: int = 2,
        test_group_ids: Optional[types.IndicesType] = None,
        strat_columns: Optional[types.ColumnNamesType] = None,
        salt: Optional[str] = None,
        fit_columns: Optional[types.ColumnNamesType] = None,
        **kwargs,
    ) -> types.SplitterResult:
        """
        Perform a split into groups with selected or saved parameters.

        Parameters
        ----------
        method : str
            Split method, for example ``"hash"``.
        dataframe : PassedDataType, optional
            Dataframe or string name of .csv table which
            contains data used for groups split.
        id_column : ColumnNameType, optional
            Name of id column which is used in hash split.
            Falls back to the stored value when ``None``.
        groups_size : int, optional
            Size of the splitted groups.
        part_of_table : float, optional
            Split factor(for group A) for tasks of dataframe full split.
            If is not ``None``, then overrides ``groups_size`` parameter
            during the split and limits ``groups_number`` to ``2``.
        groups_number : int, default: ``2``
            Number of groups to be splitted.
        test_group_ids : IndicesType, optional
            Ids of objects which are in B(test) group.
            Used in tasks of post experiment A(control) group pick up.
            Falls back to the stored value when ``None``.
        strat_columns : ColumnNamesType, optional
            Columns for stratification.
            https://en.wikipedia.org/wiki/Stratified_sampling
            Falls back to the stored value when ``None``.
        salt : str, optional
            Salt for hashing in hash-split.
        fit_columns : ColumnNamesType, optional
            List of columns names which values will be interpreted as
            coordinates of points in multidimensional space during metric split.
            Only used for the ``"metric"`` and ``"dim_decrease"`` methods.
        **kwargs : Dict
            Other keyword arguments, forwarded to the split handler.

        Returns
        -------
        groups : pd.DataFrame
            Returns a dataframe with groups and label column.
            Dataframe will contain all columns of the original dataframe.

        Other Parameters
        ----------------
        threads : int, default : ``1``
            Number of threads used for calculations.
        """
        method = type_checks.check_split_method_value(method)
        dataframe = type_checks.check_type_dataframe(dataframe)
        id_column = type_checks.check_type_id_column(id_column)
        groups_size = type_checks.check_type_group_size(groups_size)
        test_group_ids = type_checks.check_type_test_group_ids(test_group_ids)
        fit_columns = type_checks.check_type_fit_columns(fit_columns)
        strat_columns = type_checks.check_type_strat_columns(strat_columns)

        # Pairs of (stored value, passed value); the final choice between
        # them is delegated to ``_prepare_arguments``.
        arguments_choice: types._PrepareArgumentsType = {
            "dataframe": (self.__df, dataframe),
        }

        # These three are resolved here: an explicitly passed value wins,
        # otherwise the value stored on the instance is used.
        if strat_columns is None:
            strat_columns = self.__strat_columns
        if test_group_ids is None:
            test_group_ids = self.__test_group_ids
        if id_column is None:
            id_column = self.__id_column

        if test_group_ids is not None:
            arguments_choice["group_b_indices"] = (None, test_group_ids)
        else:
            arguments_choice["groups_size"] = (self.__groups_size, groups_size)

        full_split: bool = part_of_table is not None
        if full_split:
            # Group size will be set later
            arguments_choice["groups_size"] = (self.__groups_size, 0)
            if groups_size is not None:
                log.info_log("Groups size variable ignored because part splitting variable set")
            if groups_number > 2:
                groups_number = 2
                log.info_log("Groups number was set to 2 because part splitting variable set")

        if method in ("metric", "dim_decrease"):
            # For methods use metric/cluster/unsupervised approach
            arguments_choice["fit_columns"] = (self.__fit_columns, fit_columns)

        chosen_args: types._UsageArgumentsType = Splitter._prepare_arguments(arguments_choice)

        if full_split:
            # The group size is computed from the smaller of the two shares.
            if part_of_table <= SPLITTING_BOUND_CONST:
                split_part = part_of_table
            else:
                split_part = 1 - part_of_table
            chosen_args["groups_size"] = round(split_part * data_shape(chosen_args["dataframe"]))

        chosen_args["split_method"] = method
        chosen_args["id_column"] = id_column
        chosen_args["strat_columns"] = strat_columns
        chosen_args["salt"] = salt
        chosen_args["groups_number"] = groups_number

        groups: types.SplitterResult = split_data_handler(**chosen_args, **kwargs)
        if full_split:
            return handle_full_split(chosen_args["dataframe"], groups, part_of_table, id_column)
        return groups
def load_from_config(yaml_config: Union[str, os.PathLike], loader: type = yaml.Loader) -> Splitter:
    """
    Restore a ``Splitter`` class instance from a yaml config.

    Parameters
    ----------
    yaml_config : str or os.PathLike
        Either the yaml text itself or the location of a config file.
        A string is treated as a file name when it ends with ``.yaml``
        or ``.yml``, for example: "config.yaml"; a path-like object is
        always opened as a file. Any other value is handed to
        ``yaml.load`` unchanged.
    loader : type, default: ``yaml.Loader``
        PyYAML loader class used for parsing, e.g. ``yaml.SafeLoader``.
        NOTE(review): confirm the ``!splitter`` tag is registered for the
        loader you choose.

    Returns
    -------
    splitter : Splitter
        Object constructed by ``yaml.load`` from the config.

    Notes
    -----
    SECURITY: the default ``yaml.Loader`` can construct arbitrary Python
    objects. Do not use it on configs that come from an untrusted source.
    """
    if isinstance(yaml_config, os.PathLike):
        config_path = os.fspath(yaml_config)
    elif isinstance(yaml_config, str) and yaml_config.endswith((".yaml", ".yml")):
        config_path = yaml_config
    else:
        config_path = None

    if config_path is None:
        # Not a file reference: parse the passed value directly.
        return yaml.load(yaml_config, Loader=loader)

    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.load(file, Loader=loader)
def split(
    method: str,
    dataframe: Optional[types.PassedDataType] = None,
    id_column: Optional[types.ColumnNameType] = None,
    groups_size: Optional[int] = None,
    part_of_table: Optional[float] = None,
    groups_number: int = 2,
    test_group_ids: Optional[types.IndicesType] = None,
    strat_columns: Optional[types.ColumnNamesType] = None,
    salt: Optional[str] = None,
    fit_columns: Optional[types.ColumnNamesType] = None,
    threads: int = 1,
    **kwargs,
) -> types.SplitterResult:
    """
    Functional shortcut for a one-off split with the ``Splitter`` class.

    A ``Splitter`` instance is built from the data-related arguments and
    its ``run`` method is called right away with the remaining ones.

    Parameters
    ----------
    method : str
        Split method, for example ``"hash"``.
    dataframe : PassedDataType, optional
        Dataframe or string name of .csv table which
        contains data used for groups split.
    id_column : ColumnNameType, optional
        Name of id column which is used in hash split.
    groups_size : int, optional
        Size of the splitted groups.
    part_of_table : float, optional
        Split factor(for group A) for tasks of dataframe full split.
        If is not ``None``, then overrides ``groups_size`` parameter
        during the split.
    groups_number : int, default : ``2``
        Number of groups to be splitted.
    test_group_ids : IndicesType, optional
        Ids of objects which are in B(test) group.
        Used in tasks of post experiment A(control) group pick up.
    strat_columns : ColumnNamesType, optional
        Columns for stratification.
        https://en.wikipedia.org/wiki/Stratified_sampling
    salt : str, optional
        Salt for hashing in hash-split.
    fit_columns : ColumnNamesType, optional
        List of columns names which values will be interpreted as
        coordinates of points in multidimensional space during metric split.
    threads : int, default : ``1``
        Number of threads used for calculations.
    **kwargs : Dict
        Other keyword arguments, forwarded to ``Splitter.run``.

    Returns
    -------
    groups : pd.DataFrame
        Returns a dataframe with groups and label column.
        Dataframe will contain all columns of the original dataframe.
    """
    # Everything describing the data is stored on the instance ...
    splitter = Splitter(
        dataframe=dataframe,
        id_column=id_column,
        groups_size=groups_size,
        fit_columns=fit_columns,
        test_group_ids=test_group_ids,
        strat_columns=strat_columns,
    )
    # ... while the options of this particular run are passed to ``run``.
    run_options = {
        "salt": salt,
        "threads": threads,
        "part_of_table": part_of_table,
        "groups_number": groups_number,
    }
    return splitter.run(method, **run_options, **kwargs)