diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 5e20866f..460c8a4e 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,8 +15,9 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load +from marshmallow import fields, post_load, pre_dump, ValidationError from marshmallow_enum import EnumField +from marshmallow_oneofschema import OneOfSchema from faculty.clients.base import BaseClient, BaseSchema, Conflict @@ -100,6 +101,57 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) + +class ComparisonOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL_TO = "le" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL_TO = "ge" + + +class LogicalOperator(Enum): + AND = "and" + OR = "or" + + +ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) +ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) +RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) +DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", "value"]) +TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) +ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) +MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) + +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) + +StartedAtSort = namedtuple("StartedAtSort", ["order"]) +RunNumberSort = namedtuple("RunNumberSort", ["order"]) +DurationSort = namedtuple("DurationSort", ["order"]) +TagSort = namedtuple("TagSort", ["key", "order"]) +ParamSort = namedtuple("ParamSort", ["key", "order"]) +MetricSort = namedtuple("MetricSort", ["key", "order"]) + + +class SortOrder(Enum): + ASC = "asc" + DESC = "desc" + + +RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) + + +class PageSchema(BaseSchema): + start = fields.Integer(required=True) + limit = fields.Integer(required=True) + + @post_load + def make_page(self, data): + return Page(**data) + + MetricDataPoint = namedtuple("Metric", ["value", "timestamp", "step"]) MetricHistory = namedtuple( @@ -190,15 +242,6 @@ class ExperimentRunInfoSchema(BaseSchema): ended_at = fields.DateTime(data_key="endedAt", missing=None) -class PageSchema(BaseSchema): - start = fields.Integer(required=True) - limit = fields.Integer(required=True) - - @post_load - def make_page(self, data): - return Page(**data) - - class PaginationSchema(BaseSchema): start = fields.Integer(required=True) size = fields.Integer(required=True) @@ -253,6 +296,211 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) +class ParamFilterValueField(fields.Field): + """Field that passes through strings or numbers.""" + + default_error_messages = { + "unsupported_type": "Param values must be of type str, int or float." + } + + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, str): + field = fields.String() + elif isinstance(value, int) or isinstance(value, float): + field = fields.Number() + else: + self.fail("unsupported_type") + return field._serialize(value, attr, obj, **kwargs) + + +class OptionalField(fields.Field): + """Wrap another field, passing through Nones.""" + + def __init__(self, nested, *args, **kwargs): + self.nested = nested + super(OptionalField, self).__init__(*args, **kwargs) + + def _deserialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._deserialize(value, *args, **kwargs) + + def _serialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._serialize(value, *args, **kwargs) + + +class FilterValueField(fields.Field): + def __init__(self, other_field_type, *args, **kwargs): + self.other_field_type = other_field_type + super(FilterValueField, self).__init__(*args, **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + if obj.operator == ComparisonOperator.DEFINED: + field_cls = fields.Boolean + else: + field_cls = self.other_field_type + return field_cls()._serialize(value, attr, obj, **kwargs) + + +def _validate_discrete(operator): + if operator not in { + ComparisonOperator.DEFINED, + ComparisonOperator.EQUAL_TO, + ComparisonOperator.NOT_EQUAL_TO, + }: + raise ValidationError({"operator": "Not a discrete operator."}) + + +class ProjectIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) + by = fields.Constant("projectId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class ExperimentIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Integer) + by = fields.Constant("experimentId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class RunIdFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) + by = fields.Constant("runId", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class DeletedAtFilterSchema(BaseSchema): + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.DateTime) + by = fields.Constant("deletedAt", dump_only=True) + + +class TagFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.String) + by = fields.Constant("tag", dump_only=True) + + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + + +class ParamFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(ParamFilterValueField) + by = fields.Constant("param", dump_only=True) + + @pre_dump + def check_operator(self, obj): + if isinstance(obj.value, str): + _validate_discrete(obj.operator) + return obj + + +class MetricFilterSchema(BaseSchema): + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Float) + by = fields.Constant("metric", dump_only=True) + + +class CompoundFilterSchema(BaseSchema): + operator = EnumField(LogicalOperator, by_value=True) + conditions = fields.List(fields.Nested("FilterSchema")) + + +class OneOfSchemaWithoutType(OneOfSchema): + def dump(self, *args, **kwargs): + data = super(OneOfSchemaWithoutType, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} + + +class FilterSchema(OneOfSchemaWithoutType): + type_schemas = { + "ProjectIdFilter": ProjectIdFilterSchema, + "ExperimentIdFilter": ExperimentIdFilterSchema, + "RunIdFilter": RunIdFilterSchema, + "DeletedAtFilter": DeletedAtFilterSchema, + "TagFilter": TagFilterSchema, + "ParamFilter": ParamFilterSchema, + "MetricFilter": MetricFilterSchema, + "CompoundFilter": CompoundFilterSchema, + } + + +class StartedAtSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("startedAt", dump_only=True) + + +class RunNumberSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("runNumber", dump_only=True) + + +class DurationSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("duration", dump_only=True) + + +class TagSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("tag", dump_only=True) + + +class ParamSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("param", dump_only=True) + + +class MetricSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("metric", dump_only=True) + + +class SortSchema(OneOfSchemaWithoutType): + type_schemas = { + "StartedAtSort": StartedAtSortSchema, + "RunNumberSort": RunNumberSortSchema, + "DurationSort": DurationSortSchema, + "TagSort": TagSortSchema, + "ParamSort": ParamSortSchema, + "MetricSort": MetricSortSchema, + } + + +class RunQuerySchema(BaseSchema): + filter = OptionalField(fields.Nested(FilterSchema)) + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) + + class MetricDataPointSchema(BaseSchema): """Deserialise a data point from the metric history endpoint. @@ -526,6 +774,9 @@ def list_runs( To filter runs of experiments with the given IDs only. If an empty list is passed, a result with an empty list of runs is returned. By default, runs from all experiments are returned. + lifecycle_stage: LifecycleStage, optional + To filter runs of experiments in a specific lifecycle stage only. + By default, runs in any stage are returned. start : int, optional The (zero-indexed) starting point of runs to retrieve. limit : int, optional @@ -535,10 +786,11 @@ def list_runs( ------- ListExperimentRunsResponse """ - if lifecycle_stage is not None: - raise NotImplementedError("lifecycle_stage is not supported.") - query_params = [] + experiment_ids_filter = None + lifecycle_filter = None + filter = None + if experiment_ids is not None: if len(experiment_ids) == 0: return ListExperimentRunsResponse( @@ -547,17 +799,73 @@ def list_runs( start=0, size=0, previous=None, next=None ), ) - for experiment_id in experiment_ids: - query_params.append(("experimentId", experiment_id)) + experiment_id_filters = [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id) + for experiment_id in experiment_ids + ] + experiment_ids_filter = CompoundFilter( + LogicalOperator.OR, experiment_id_filters + ) + if lifecycle_stage is not None: + lifecycle_filter = DeletedAtFilter( + ComparisonOperator.DEFINED, + lifecycle_stage == LifecycleStage.DELETED, + ) + + if experiment_ids_filter is not None and lifecycle_filter is not None: + filter = CompoundFilter( + LogicalOperator.AND, [experiment_ids_filter, lifecycle_filter] + ) + elif experiment_ids_filter is not None: + filter = experiment_ids_filter + elif lifecycle_filter is not None: + filter = lifecycle_filter - if start is not None: - query_params.append(("start", start)) - if limit is not None: - query_params.append(("limit", limit)) + return self.query_runs(project_id, filter, None, start, limit) - endpoint = "/project/{}/run".format(project_id) - return self._get( - endpoint, ListExperimentRunsResponseSchema(), params=query_params + def query_runs( + self, project_id, filter=None, sort=None, start=None, limit=None + ): + """Query experiment runs. + + This method returns pages of runs. If less than the full number of runs + for the job is returned, the ``next`` page of the returned response + object will not be ``None``: + + >>> response = client.query_runs(project_id) + >>> response.pagination.next + Page(start=10, limit=10) + + Get all experiment runs by making successive calls to ``query_runs``, + passing the ``start`` and ``limit`` of the ``next`` page each time + until ``next`` is returned as ``None``. + + Parameters + ---------- + project_id : uuid.UUID + filter: SingleFilter or CompoundFilter, optional + To filter runs of experiments with the given filter. By default, + runs from all experiments are returned. + sort: List[Sort], optional + Runs are order using the conditions in sort. The relative + importance of each condition gradually decreases in order. + By default, experiment runs are sorted by their startedAt value. + start : int, optional + The (zero-indexed) starting point of runs to retrieve. + limit : int, optional + The maximum number of runs to retrieve. + + Returns + ------- + ListExperimentRunsResponse + """ + endpoint = "/project/{}/run/query".format(project_id) + page = None + if start is not None and limit is not None: + page = Page(start, limit) + payload = RunQuerySchema().dump(RunQuery(filter, sort, page)) + return self._post( + endpoint, ListExperimentRunsResponseSchema(), json=payload ) def log_run_data( @@ -673,15 +981,12 @@ def delete_runs(self, project_id, run_ids=None): deleted_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) + for run_id in run_ids + ] + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, DeleteExperimentRunsResponseSchema(), json=payload @@ -714,15 +1019,12 @@ def restore_runs(self, project_id, run_ids=None): restored_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) + for run_id in run_ids + ] + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, RestoreExperimentRunsResponseSchema(), json=payload diff --git a/setup.py b/setup.py index a8772681..8d4d473f 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ # compatible version of python-dateutil is available "marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", + "marshmallow-oneofschema==2.0.0b2", "boto3", "botocore", ], diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 2441829b..3a3eb8ff 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -22,37 +22,58 @@ from faculty.clients.base import Conflict from faculty.clients.experiment import ( + ComparisonOperator, + CompoundFilter, CreateRunSchema, DeleteExperimentRunsResponse, DeleteExperimentRunsResponseSchema, + DeletedAtFilter, + DurationSort, Experiment, ExperimentClient, - ExperimentNameConflict, ExperimentDeleted, + ExperimentIdFilter, + ExperimentNameConflict, ExperimentRun, ExperimentRunDataSchema, ExperimentRunSchema, ExperimentRunStatus, ExperimentSchema, + FilterSchema, LifecycleStage, ListExperimentRunsResponse, ListExperimentRunsResponseSchema, + LogicalOperator, Metric, MetricDataPoint, - MetricSchema, + MetricFilter, MetricHistory, MetricHistorySchema, + MetricSchema, + MetricSort, Page, PageSchema, Pagination, PaginationSchema, Param, ParamConflict, + ParamFilter, ParamSchema, + ParamSort, + ProjectIdFilter, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, + RunIdFilter, + RunNumberSort, + RunQuery, + RunQuerySchema, + SortOrder, + SortSchema, + StartedAtSort, Tag, + TagFilter, TagSchema, + TagSort, ) PROJECT_ID = uuid4() @@ -67,6 +88,7 @@ LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC) DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" +DELETED_AT_STRING_PYTHON = "2018-03-10T11:37:42.482000+00:00" EXPERIMENT = Experiment( id=EXPERIMENT_ID, @@ -87,6 +109,7 @@ "deletedAt": DELETED_AT_STRING, } +RUN_ID = uuid4() RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" @@ -334,6 +357,317 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} +PROJECT_ID_FILTER = ProjectIdFilter(ComparisonOperator.EQUAL_TO, PROJECT_ID) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "operator": "eq", + "value": str(PROJECT_ID), +} + +TAG_FILTER = TagFilter("tag-key", ComparisonOperator.EQUAL_TO, "tag-value") +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag-key", + "operator": "eq", + "value": "tag-value", +} + + +DEFINED_TEST_CASES = [ + (ComparisonOperator.DEFINED, False, "defined", False), + (ComparisonOperator.DEFINED, True, "defined", True), + (ComparisonOperator.DEFINED, 0, "defined", False), + (ComparisonOperator.DEFINED, 1, "defined", True), +] + + +def discrete_test_cases(value, expected): + return DEFINED_TEST_CASES + [ + (ComparisonOperator.EQUAL_TO, value, "eq", expected), + (ComparisonOperator.NOT_EQUAL_TO, value, "ne", expected), + ] + + +def continuous_test_cases(value, expected): + return discrete_test_cases(value, expected) + [ + (ComparisonOperator.GREATER_THAN, value, "gt", expected), + (ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value, "ge", expected), + (ComparisonOperator.LESS_THAN, value, "lt", expected), + (ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value, "le", expected), + ] + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(PROJECT_ID, str(PROJECT_ID)), +) +def test_filter_schema_project_id( + operator, value, expected_operator, expected_value +): + filter = ProjectIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "projectId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(EXPERIMENT_ID, EXPERIMENT_ID), +) +def test_filter_schema_experiment_id( + operator, value, expected_operator, expected_value +): + filter = ExperimentIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "experimentId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases(RUN_ID, str(RUN_ID)), +) +def test_filter_schema_run_id( + operator, value, expected_operator, expected_value +): + filter = RunIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "runId", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + continuous_test_cases(DELETED_AT, DELETED_AT_STRING_PYTHON), +) +def test_filter_schema_deleted_at( + operator, value, expected_operator, expected_value +): + filter = DeletedAtFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "deletedAt", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases("tag-value", "tag-value"), +) +def test_filter_schema_tag(operator, value, expected_operator, expected_value): + filter = TagFilter("tag-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "tag", + "key": "tag-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + discrete_test_cases("param-value", "param-value") + + continuous_test_cases(123.2, 123.2), +) +def test_filter_schema_param( + operator, value, expected_operator, expected_value +): + filter = ParamFilter("param-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "param", + "key": "param-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "operator, value, expected_operator, expected_value", + continuous_test_cases(45.6, 45.6), +) +def test_filter_schema_metric( + operator, value, expected_operator, expected_value +): + filter = MetricFilter("metric-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "metric", + "key": "metric-key", + "operator": expected_operator, + "value": expected_value, + } + + +@pytest.mark.parametrize( + "filter_type", + [ProjectIdFilter, ExperimentIdFilter, RunIdFilter, DeletedAtFilter], +) +def test_filter_schema_invalid_value_no_key(filter_type): + filter = filter_type(ComparisonOperator.EQUAL_TO, "invalid") + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [(ParamFilter, None), (MetricFilter, "invalid")] +) +def test_filter_schema_invalid_value_with_key(filter_type, value): + filter = filter_type("key", ComparisonOperator.EQUAL_TO, value) + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", + [ + (ProjectIdFilter, PROJECT_ID), + (ExperimentIdFilter, EXPERIMENT_ID), + (RunIdFilter, RUN_ID), + ], +) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_no_key(filter_type, value, operator): + filter = filter_type(operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", + [(TagFilter, "tag-value"), (ParamFilter, "param-string-value")], +) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_with_key(filter_type, value, operator): + filter = filter_type("key", operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "operator, expected_operator", + [(LogicalOperator.AND, "and"), (LogicalOperator.OR, "or")], +) +def test_filter_schema_compound(operator, expected_operator): + filter = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER]) + data = FilterSchema().dump(filter) + assert data == { + "operator": expected_operator, + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + } + + +def test_filter_schema_nested(): + filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER] + ), + CompoundFilter( + LogicalOperator.OR, [TAG_FILTER, PROJECT_ID_FILTER] + ), + ], + ) + data = FilterSchema().dump(filter) + assert data == { + "operator": "and", + "conditions": [ + { + "operator": "and", + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + }, + { + "operator": "or", + "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY], + }, + ], + } + + +@pytest.mark.parametrize( + "sort_type, by", + [ + (StartedAtSort, "startedAt"), + (RunNumberSort, "runNumber"), + (DurationSort, "duration"), + ], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_no_key(sort_type, by, order, expected_order): + sort = sort_type(order) + data = SortSchema().dump(sort) + assert data == {"by": by, "order": expected_order} + + +@pytest.mark.parametrize( + "sort_type, by", + [(TagSort, "tag"), (ParamSort, "param"), (MetricSort, "metric")], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_with_key(sort_type, by, order, expected_order): + sort = sort_type("sort-key", order) + data = SortSchema().dump(sort) + assert data == {"by": by, "key": "sort-key", "order": expected_order} + + +def test_run_query_schema(mocker): + mocker.patch.object(FilterSchema, "dump") + mocker.patch.object(SortSchema, "dump") + mocker.patch.object(PageSchema, "dump") + + filter = mocker.Mock() + sorts = [mocker.Mock(), mocker.Mock()] + page = mocker.Mock() + + run_query = RunQuery(filter, sorts, page) + data = RunQuerySchema().dump(run_query) + + assert data == { + "filter": FilterSchema.dump.return_value, + "sort": [SortSchema.dump.return_value, SortSchema.dump.return_value], + "page": PageSchema.dump.return_value, + } + + +def test_run_query_schema_defaults(): + run_query = RunQuery(None, None, None) + data = RunQuerySchema().dump(run_query) + assert data == {"filter": None, "sort": None, "page": None} + + @pytest.mark.parametrize("description", [None, "experiment description"]) @pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) def test_experiment_client_create(mocker, description, artifact_location): @@ -556,11 +890,16 @@ def test_list_runs_schema(mocker): assert data == LIST_EXPERIMENT_RUNS_RESPONSE -def test_page_schema(): +def test_page_schema_load(): data = PageSchema().load(PAGE_BODY) assert data == PAGE +def test_page_schema_dump(): + data = PageSchema().dump(PAGE) + assert data == PAGE_BODY + + def test_pagination_schema(): data = PaginationSchema().load(PAGINATION_BODY) assert data == PAGINATION @@ -598,72 +937,76 @@ def test_restore_experiment_runs_response_schema_invalid(mocker): RestoreExperimentRunsResponseSchema().load({}) -def test_experiment_client_list_runs_all(mocker): - mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) +def test_experiment_client_list_runs(mocker): + mocker.patch.object(ExperimentClient, "query_runs") client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[], + response = client.list_runs( + PROJECT_ID, + experiment_ids=[123, 456], + lifecycle_stage=LifecycleStage.DELETED, + start=20, + limit=10, ) - -def test_experiment_client_list_runs_experiments_filter(mocker): - mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + assert response == ExperimentClient.query_runs.return_value + expected_filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.OR, + [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 123), + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 456), + ], + ), + DeletedAtFilter(ComparisonOperator.DEFINED, True), + ], ) - schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, expected_filter, None, 20, 10 ) - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[123, 456]) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("experimentId", 123), ("experimentId", 456)], - ) +def test_experiment_client_list_runs_defaults(mocker): + mocker.patch.object(ExperimentClient, "query_runs") -def test_experiment_client_list_runs_experiments_filter_empty(mocker): client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) + response = client.list_runs(PROJECT_ID) - assert list_result == ListExperimentRunsResponse( - runs=[], - pagination=Pagination(start=0, size=0, previous=None, next=None), + assert response == ExperimentClient.query_runs.return_value + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, None, None, None, None ) -def test_experiment_client_list_runs_page(mocker): +def test_experiment_client_query_runs(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_dump_mock = mocker.patch.object(RunQuerySchema, "dump") + + filter = mocker.Mock() + sort = mocker.Mock() client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, start=20, limit=10) + list_result = client.query_runs( + PROJECT_ID, filter, sort, start=20, limit=10 + ) + assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("start", 20), ("limit", 10)], + request_dump_mock.assert_called_once_with( + RunQuery(filter, sort, Page(20, 10)) + ) + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=request_dump_mock.return_value, ) @@ -886,31 +1229,28 @@ def test_delete_runs(mocker): mocker.patch.object( ExperimentClient, "_post", return_value=DELETE_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" ) + filter_dump_mock = mocker.patch.object(FilterSchema, "dump") run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) - assert ( - client.delete_runs(PROJECT_ID, run_ids) - == DELETE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, - ], - } - } + response = client.delete_runs(PROJECT_ID, run_ids) + assert response == DELETE_EXPERIMENT_RUNS_RESPONSE + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) ExperimentClient._post.assert_called_once_with( "/project/{}/run/delete/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, ) @@ -931,13 +1271,15 @@ def test_delete_runs_no_run_ids(mocker): def test_delete_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + client = ExperimentClient(mocker.Mock()) + response = client.delete_runs(PROJECT_ID, run_ids=[]) - assert client.delete_runs( - PROJECT_ID, run_ids=[] - ) == DeleteExperimentRunsResponse( + assert response == DeleteExperimentRunsResponse( deleted_run_ids=[], conflicted_run_ids=[] ) + ExperimentClient._post.assert_not_called() def test_restore_runs(mocker): @@ -946,31 +1288,28 @@ def test_restore_runs(mocker): "_post", return_value=RESTORE_EXPERIMENT_RUNS_RESPONSE, ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" ) + filter_dump_mock = mocker.patch.object(FilterSchema, "dump") run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) - assert ( - client.restore_runs(PROJECT_ID, run_ids) - == RESTORE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, - ], - } - } + response = client.restore_runs(PROJECT_ID, run_ids) + assert response == RESTORE_EXPERIMENT_RUNS_RESPONSE + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) ExperimentClient._post.assert_called_once_with( "/project/{}/run/restore/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, ) @@ -991,10 +1330,12 @@ def test_restore_runs_no_run_ids(mocker): def test_restore_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + client = ExperimentClient(mocker.Mock()) + response = client.restore_runs(PROJECT_ID, run_ids=[]) - assert client.restore_runs( - PROJECT_ID, run_ids=[] - ) == RestoreExperimentRunsResponse( + assert response == RestoreExperimentRunsResponse( restored_run_ids=[], conflicted_run_ids=[] ) + ExperimentClient._post.assert_not_called()