From 70a6a7e0e163771e87dcef00e57c4ef8eb16ec97 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Fri, 10 May 2019 15:23:25 +0100 Subject: [PATCH 01/52] Implement schema for run query filters --- faculty/clients/experiment.py | 122 ++++++++++++++++++++++++++++++- tests/clients/test_experiment.py | 42 +++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1c76b2f1..b98d59e0 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -14,8 +14,9 @@ from collections import namedtuple from enum import Enum +import uuid -from marshmallow import fields, post_load +from marshmallow import fields, post_load, utils as marshmallow_utils from marshmallow_enum import EnumField from faculty.clients.base import BaseClient, BaseSchema, Conflict @@ -43,6 +44,12 @@ def __init__(self, message, experiment_id): self.experiment_id = experiment_id +class RunQueryFilterValidation(Exception): + def __init__(self, message, value): + super(RunQueryFilterValidation, self).__init__(message) + self.value = value + + class ExperimentRunStatus(Enum): RUNNING = "running" FINISHED = "finished" @@ -99,6 +106,119 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) +SingleFilter = namedtuple("SingleFilter", ["by", "key", "operator", "value"]) + +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) + + +class SingleFilterOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" + LESS_THAN = "lt" + LESS_THAN_OR_EQUAL_TO = "le" + GREATER_THAN = "gt" + GREATER_THAN_OR_EQUAL_TO = "ge" + + +class SingleFilterBy(Enum): + PROJECT_ID = "projectId" + EXPERIMENT_ID = "experimentId" + RUN_ID = "runId" + DELETED_AT = "deletedAt" + TAG = "tag" + PARAM = "param" + METRIC = "metric" + + +class CompoundFilterOperator(Enum): + AND = "and" + OR = "or" + + +class SingleFilterValueField(fields.Field): + """ + 
Field that serialises/deserialises a run filter. + """ + + def _is_valid_uuid(self, value, obj): + return ( + isinstance(value, uuid.UUID) + and ( + obj.by == SingleFilterBy.PROJECT_ID + or obj.by == SingleFilterBy.RUN_ID + ) + ) + + def _is_valid_experiment_id(self, value, obj): + return ( + isinstance(value, int) + and obj.by == SingleFilterBy.EXPERIMENT_ID + ) + + def _is_directly_stringifiable(self, value, obj): + return ( + self._is_valid_uuid(value, obj) + or self._is_valid_experiment_id(value, obj) + or obj.by == SingleFilterBy.TAG + or obj.by == SingleFilterBy.PARAM + or obj.by == SingleFilterBy.METRIC + ) + + def _deserialize(self, value, attr, obj, **kwargs): + pass + + def _serialize(self, value, attr, obj, **kwargs): + if self._is_directly_stringifiable(value, obj): + return str(value) + elif obj.by == SingleFilterBy.DELETED_AT: + return marshmallow_utils.from_iso_datetime(str(value)) + else: + raise RunQueryFilterValidation( + "Validation error serialising run query filter", + value + ) + + +class FilterField(fields.Field): + """ + Field that serialises/deserialises a run filter. 
+ """ + + def _deserialize(self, value, attr, obj, **kwargs): + if value is None: + return None + elif isinstance(value, SingleFilter): + return SingleFilterSchema().load(value) + else: + return CompoundFilterSchema().load(value) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + if isinstance(value, SingleFilter): + return SingleFilterSchema().dump(value) + else: + return CompoundFilterSchema().dump(value) + + +class SingleFilterSchema(BaseSchema): + by = EnumField(SingleFilterBy, by_value=True, required=True) + operator = EnumField(SingleFilterOperator, by_value=True, required=True) + value = SingleFilterValueField(required=True) + + +class CompoundFilterSchema(BaseSchema): + operator = EnumField( + CompoundFilterOperator, by_value=True, required=True) + conditions = fields.List(FilterField()) + + +class QueryRunsSchema(BaseSchema): + filter = FilterField(required=True) + sort = fields.String() + page = fields.String() + class MetricSchema(BaseSchema): key = fields.String(required=True) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 7b5455f0..d2aad3d7 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -23,6 +23,8 @@ from faculty.clients.base import Conflict from faculty.clients.experiment import ( CreateRunSchema, + CompoundFilter, + CompoundFilterOperator, DeleteExperimentRunsResponse, DeleteExperimentRunsResponseSchema, Experiment, @@ -47,8 +49,12 @@ Param, ParamConflict, ParamSchema, + QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, + SingleFilter, + SingleFilterBy, + SingleFilterOperator, Tag, TagSchema, ) @@ -227,6 +233,42 @@ ], } +# TODO: testing query stuff + + +def test_query_runs_schema(): + queryRunsObj = { + "filter": CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.EXPERIMENT_ID, + None, SingleFilterOperator.EQUAL_TO, 7), + None + ] + ), + "sort": 
"sort", + "page": "page" + } + expected_json = { + "filter": { + "operator": "and", + "conditions": [ + { + "by": "experimentId", + "operator": "eq", + "value": '7' + }, + None + ] + }, + "sort": "sort", + "page": "page" + } + data = QueryRunsSchema().dump(queryRunsObj) + assert data == expected_json +# TODO: testing query stuff + def test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) From 27930c18eddcdd0eed111e76ad51a5add87b7a6d Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 11:24:36 +0100 Subject: [PATCH 02/52] Add sort and page fields serialisation and corresponding test --- faculty/clients/experiment.py | 46 ++++++++++++++++++++++++-------- tests/clients/test_experiment.py | 24 +++++++++-------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index b98d59e0..700c55d9 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -204,6 +204,7 @@ def _serialize(self, value, attr, obj, **kwargs): class SingleFilterSchema(BaseSchema): by = EnumField(SingleFilterBy, by_value=True, required=True) + key = fields.String() operator = EnumField(SingleFilterOperator, by_value=True, required=True) value = SingleFilterValueField(required=True) @@ -214,10 +215,42 @@ class CompoundFilterSchema(BaseSchema): conditions = fields.List(FilterField()) +Sort = namedtuple("Sort",["by", "key", "order"]) + + +class SortBy(Enum): + STARTED_AT = "startedAt" + RUN_NUMBER = "runNumber" + DURATION = "duration" + TAG = "tag" + PARAM = "param" + METRIC = "metric" + + +class SortOrder(Enum): + ASC = "asc" + DESC = "desc" + + +class SortSchema(BaseSchema): + by = EnumField(SortBy, by_value=True, required=True) + key = fields.String() + order = EnumField(SortOrder, by_value=True, required=True) + + +class PageSchema(BaseSchema): + start = fields.Integer(required=True) + limit = fields.Integer(required=True) + + @post_load + def make_page(self, data): + 
return Page(**data) + + class QueryRunsSchema(BaseSchema): filter = FilterField(required=True) - sort = fields.String() - page = fields.String() + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) class MetricSchema(BaseSchema): @@ -302,15 +335,6 @@ class ExperimentRunInfoSchema(BaseSchema): ended_at = fields.DateTime(data_key="endedAt", missing=None) -class PageSchema(BaseSchema): - start = fields.Integer(required=True) - limit = fields.Integer(required=True) - - @post_load - def make_page(self, data): - return Page(**data) - - class PaginationSchema(BaseSchema): start = fields.Integer(required=True) size = fields.Integer(required=True) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index d2aad3d7..98b0f4ce 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -55,6 +55,9 @@ SingleFilter, SingleFilterBy, SingleFilterOperator, + Sort, + SortBy, + SortOrder, Tag, TagSchema, ) @@ -234,42 +237,41 @@ } # TODO: testing query stuff - - def test_query_runs_schema(): queryRunsObj = { "filter": CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[ SingleFilter( - SingleFilterBy.EXPERIMENT_ID, - None, SingleFilterOperator.EQUAL_TO, 7), + SingleFilterBy.TAG, + "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), None ] ), - "sort": "sort", - "page": "page" + "sort": [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC)], + "page": PAGE } expected_json = { "filter": { "operator": "and", "conditions": [ { - "by": "experimentId", + "by": "tag", + "key": "tag_key", "operator": "eq", - "value": '7' + "value": "tag_value" }, None ] }, - "sort": "sort", - "page": "page" + "sort": [{"by": "param", "key": "param_key", "order": "asc"}, {"by": "runNumber", "key": None, "order": "desc"}], + "page": PAGE_BODY } data = QueryRunsSchema().dump(queryRunsObj) assert data == expected_json # TODO: testing query stuff - def 
test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) assert data == EXPERIMENT From 0b2033b7c4ce8f4dc93a49db2147cf2747aab701 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 12:41:33 +0100 Subject: [PATCH 03/52] Add QueryRun named tuple --- faculty/clients/experiment.py | 1 + tests/clients/test_experiment.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 700c55d9..2ace5683 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -246,6 +246,7 @@ class PageSchema(BaseSchema): def make_page(self, data): return Page(**data) +QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) class QueryRunsSchema(BaseSchema): filter = FilterField(required=True) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 98b0f4ce..552ce2a5 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -49,6 +49,7 @@ Param, ParamConflict, ParamSchema, + QueryRuns, QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, @@ -236,10 +237,10 @@ ], } -# TODO: testing query stuff + def test_query_runs_schema(): - queryRunsObj = { - "filter": CompoundFilter( + queryRunsObj = QueryRuns( + CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[ SingleFilter( @@ -248,10 +249,10 @@ def test_query_runs_schema(): None ] ), - "sort": [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ + [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC)], - "page": PAGE - } + PAGE + ) expected_json = { "filter": { "operator": "and", @@ -270,7 +271,7 @@ def test_query_runs_schema(): } data = QueryRunsSchema().dump(queryRunsObj) assert data == expected_json -# TODO: testing query stuff + def test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) From 52ad53e6ee0a0ca385ea3667faeadae3b798d67b 
Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 12:02:38 +0100 Subject: [PATCH 04/52] Reorder schemas --- faculty/clients/experiment.py | 184 +++++++++++++++++----------------- 1 file changed, 92 insertions(+), 92 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 2ace5683..abe66357 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -136,86 +136,7 @@ class CompoundFilterOperator(Enum): OR = "or" -class SingleFilterValueField(fields.Field): - """ - Field that serialises/deserialises a run filter. - """ - - def _is_valid_uuid(self, value, obj): - return ( - isinstance(value, uuid.UUID) - and ( - obj.by == SingleFilterBy.PROJECT_ID - or obj.by == SingleFilterBy.RUN_ID - ) - ) - - def _is_valid_experiment_id(self, value, obj): - return ( - isinstance(value, int) - and obj.by == SingleFilterBy.EXPERIMENT_ID - ) - - def _is_directly_stringifiable(self, value, obj): - return ( - self._is_valid_uuid(value, obj) - or self._is_valid_experiment_id(value, obj) - or obj.by == SingleFilterBy.TAG - or obj.by == SingleFilterBy.PARAM - or obj.by == SingleFilterBy.METRIC - ) - - def _deserialize(self, value, attr, obj, **kwargs): - pass - - def _serialize(self, value, attr, obj, **kwargs): - if self._is_directly_stringifiable(value, obj): - return str(value) - elif obj.by == SingleFilterBy.DELETED_AT: - return marshmallow_utils.from_iso_datetime(str(value)) - else: - raise RunQueryFilterValidation( - "Validation error serialising run query filter", - value - ) - - -class FilterField(fields.Field): - """ - Field that serialises/deserialises a run filter. 
- """ - - def _deserialize(self, value, attr, obj, **kwargs): - if value is None: - return None - elif isinstance(value, SingleFilter): - return SingleFilterSchema().load(value) - else: - return CompoundFilterSchema().load(value) - - def _serialize(self, value, attr, obj, **kwargs): - if value is None: - return None - if isinstance(value, SingleFilter): - return SingleFilterSchema().dump(value) - else: - return CompoundFilterSchema().dump(value) - - -class SingleFilterSchema(BaseSchema): - by = EnumField(SingleFilterBy, by_value=True, required=True) - key = fields.String() - operator = EnumField(SingleFilterOperator, by_value=True, required=True) - value = SingleFilterValueField(required=True) - - -class CompoundFilterSchema(BaseSchema): - operator = EnumField( - CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField()) - - -Sort = namedtuple("Sort",["by", "key", "order"]) +Sort = namedtuple("Sort", ["by", "key", "order"]) class SortBy(Enum): @@ -232,12 +153,6 @@ class SortOrder(Enum): DESC = "desc" -class SortSchema(BaseSchema): - by = EnumField(SortBy, by_value=True, required=True) - key = fields.String() - order = EnumField(SortOrder, by_value=True, required=True) - - class PageSchema(BaseSchema): start = fields.Integer(required=True) limit = fields.Integer(required=True) @@ -248,12 +163,6 @@ def make_page(self, data): QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) -class QueryRunsSchema(BaseSchema): - filter = FilterField(required=True) - sort = fields.List(fields.Nested(SortSchema)) - page = fields.Nested(PageSchema, missing=None) - - class MetricSchema(BaseSchema): key = fields.String(required=True) value = fields.Float(required=True) @@ -390,6 +299,97 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) +class SingleFilterValueField(fields.Field): + """ + Field that serialises/deserialises a run filter. 
+ """ + + def _is_valid_uuid(self, value, obj): + return ( + isinstance(value, uuid.UUID) + and ( + obj.by == SingleFilterBy.PROJECT_ID + or obj.by == SingleFilterBy.RUN_ID + ) + ) + + def _is_valid_experiment_id(self, value, obj): + return ( + isinstance(value, int) + and obj.by == SingleFilterBy.EXPERIMENT_ID + ) + + def _is_directly_stringifiable(self, value, obj): + return ( + self._is_valid_uuid(value, obj) + or self._is_valid_experiment_id(value, obj) + or obj.by == SingleFilterBy.TAG + or obj.by == SingleFilterBy.PARAM + or obj.by == SingleFilterBy.METRIC + ) + + def _deserialize(self, value, attr, obj, **kwargs): + pass + + def _serialize(self, value, attr, obj, **kwargs): + if self._is_directly_stringifiable(value, obj): + return str(value) + elif obj.by == SingleFilterBy.DELETED_AT: + return marshmallow_utils.from_iso_datetime(str(value)) + else: + raise RunQueryFilterValidation( + "Validation error serialising run query filter", + value + ) + + +class FilterField(fields.Field): + """ + Field that serialises/deserialises a run filter. 
+ """ + + def _deserialize(self, value, attr, obj, **kwargs): + if value is None: + return None + elif isinstance(value, SingleFilter): + return SingleFilterSchema().load(value) + else: + return CompoundFilterSchema().load(value) + + def _serialize(self, value, attr, obj, **kwargs): + if value is None: + return None + if isinstance(value, SingleFilter): + return SingleFilterSchema().dump(value) + else: + return CompoundFilterSchema().dump(value) + + +class SingleFilterSchema(BaseSchema): + by = EnumField(SingleFilterBy, by_value=True, required=True) + key = fields.String() + operator = EnumField(SingleFilterOperator, by_value=True, required=True) + value = SingleFilterValueField(required=True) + + +class CompoundFilterSchema(BaseSchema): + operator = EnumField( + CompoundFilterOperator, by_value=True, required=True) + conditions = fields.List(FilterField()) + + +class SortSchema(BaseSchema): + by = EnumField(SortBy, by_value=True, required=True) + key = fields.String() + order = EnumField(SortOrder, by_value=True, required=True) + + +class QueryRunsSchema(BaseSchema): + filter = FilterField(required=True) + sort = fields.List(fields.Nested(SortSchema)) + page = fields.Nested(PageSchema, missing=None) + + class MetricHistorySchema(BaseSchema): history = fields.Nested(MetricSchema, many=True, required=True) From 7e9c221e7e07fadbdff673475b9e4274b46d7afd Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 12:44:05 +0100 Subject: [PATCH 05/52] Rearrange query runs --- faculty/clients/experiment.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index abe66357..08241644 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -153,6 +153,9 @@ class SortOrder(Enum): DESC = "desc" +QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) + + class PageSchema(BaseSchema): start = fields.Integer(required=True) limit = 
fields.Integer(required=True) @@ -161,7 +164,6 @@ class PageSchema(BaseSchema): def make_page(self, data): return Page(**data) -QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) class MetricSchema(BaseSchema): key = fields.String(required=True) From c740eee5b29238d0777504cd5dd252fdc1e9a7b6 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 14:34:54 +0100 Subject: [PATCH 06/52] Parametrise tests and test project id filter --- faculty/clients/experiment.py | 1 + tests/clients/test_experiment.py | 119 ++++++++++++++++++++++--------- 2 files changed, 85 insertions(+), 35 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 08241644..b5e3de70 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -339,6 +339,7 @@ def _serialize(self, value, attr, obj, **kwargs): elif obj.by == SingleFilterBy.DELETED_AT: return marshmallow_utils.from_iso_datetime(str(value)) else: + print(value) raise RunQueryFilterValidation( "Validation error serialising run query filter", value diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 552ce2a5..bef258b2 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -238,41 +238,6 @@ } -def test_query_runs_schema(): - queryRunsObj = QueryRuns( - CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), - None - ] - ), - [Sort(SortBy.PARAM, "param_key", SortOrder.ASC), \ - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC)], - PAGE - ) - expected_json = { - "filter": { - "operator": "and", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value" - }, - None - ] - }, - "sort": [{"by": "param", "key": "param_key", "order": "asc"}, {"by": "runNumber", "key": None, "order": "desc"}], - "page": PAGE_BODY - } - data = 
QueryRunsSchema().dump(queryRunsObj) - assert data == expected_json - - def test_experiment_schema(): data = ExperimentSchema().load(EXPERIMENT_BODY) assert data == EXPERIMENT @@ -374,6 +339,90 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} +# check that all three fields are nullable +# check all types of filter +# check all types of sort + +PROJECT_ID_FILTER = SingleFilter( + by=SingleFilterBy.PROJECT_ID, key="k", operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "key": "k", + "operator": "eq", + "value": str(PROJECT_ID) +} + +TAG_FILTER = SingleFilter( + by=SingleFilterBy.TAG, key="tag_key", operator=SingleFilterOperator.EQUAL_TO, value="tag_value") +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value" +} + +AND_FILTER = CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), + None + ] +) +AND_FILTER_BODY = { + "operator": "and", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value" + }, + None + ] +} + +MULTI_SORT = [ + Sort(SortBy.PARAM, "param_key", SortOrder.ASC), + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC) +] +MULTI_SORT_BODY = [ + {"by": "param", "key": "param_key", "order": "asc"}, + {"by": "runNumber", "key": None, "order": "desc"}] + + +@pytest.mark.parametrize( + "pfilter,pfilter_body", + [ + [None, None], + [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], + [TAG_FILTER, TAG_FILTER_BODY], + [AND_FILTER, AND_FILTER_BODY] + ] +) +@pytest.mark.parametrize( + "psort,psort_body", + [ + [None, None], + [MULTI_SORT, MULTI_SORT_BODY] + ] +) +def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): + queryRunsObj = { + "filter": pfilter, + "sort": psort, + "page": PAGE + } + expected_json = { + "filter": pfilter_body, + "sort": 
psort_body, + "page": PAGE_BODY + } + data = QueryRunsSchema().dump(queryRunsObj) + assert data == expected_json + + @pytest.mark.parametrize("description", [None, "experiment description"]) @pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) def test_experiment_client_create(mocker, description, artifact_location): From b04bc42bee6d090fe8fd314d94e8f82b07253cbd Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 15:07:53 +0100 Subject: [PATCH 07/52] Complete parameterising RunQuery tests for all different filters --- faculty/clients/experiment.py | 2 +- tests/clients/test_experiment.py | 55 ++++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index b5e3de70..81bb02c8 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -337,7 +337,7 @@ def _serialize(self, value, attr, obj, **kwargs): if self._is_directly_stringifiable(value, obj): return str(value) elif obj.by == SingleFilterBy.DELETED_AT: - return marshmallow_utils.from_iso_datetime(str(value)) + return marshmallow_utils.from_iso_datetime(str(value)).isoformat() else: print(value) raise RunQueryFilterValidation( diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index bef258b2..c107f774 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -75,6 +75,7 @@ LAST_UPDATED_AT_STRING = "2018-03-10T11:32:30.172Z" DELETED_AT = datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC) DELETED_AT_STRING = "2018-03-10T11:37:42.482Z" +DELETED_AT_STRING_PYTHON = "2018-03-10T11:37:42.482000+00:00" EXPERIMENT = Experiment( id=EXPERIMENT_ID, @@ -344,14 +345,41 @@ def test_experiment_run_data_schema_multiple(): # check all types of sort PROJECT_ID_FILTER = SingleFilter( - by=SingleFilterBy.PROJECT_ID, key="k", operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) + by=SingleFilterBy.PROJECT_ID, key=None, 
operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) PROJECT_ID_FILTER_BODY = { "by": "projectId", - "key": "k", + "key": None, "operator": "eq", "value": str(PROJECT_ID) } +EXPERIMENT_ID_FILTER = SingleFilter( + by=SingleFilterBy.EXPERIMENT_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=1) +EXPERIMENT_ID_BODY = { + "by": "experimentId", + "key": None, + "operator": "eq", + "value": "1" +} + +RUN_ID_FILTER = SingleFilter( + by=SingleFilterBy.RUN_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=EXPERIMENT_RUN_ID) +RUN_ID_BODY = { + "by": "runId", + "key": None, + "operator": "eq", + "value": str(EXPERIMENT_RUN_ID) +} + +DELETED_AT_FILTER = SingleFilter( + by=SingleFilterBy.DELETED_AT, key=None, operator=SingleFilterOperator.EQUAL_TO, value=DELETED_AT) +DELETED_AT_BODY = { + "by": "deletedAt", + "key": None, + "operator": "eq", + "value": DELETED_AT_STRING_PYTHON +} + TAG_FILTER = SingleFilter( by=SingleFilterBy.TAG, key="tag_key", operator=SingleFilterOperator.EQUAL_TO, value="tag_value") TAG_FILTER_BODY = { @@ -361,6 +389,24 @@ def test_experiment_run_data_schema_multiple(): "value": "tag_value" } +PARAM_FILTER = SingleFilter( + by=SingleFilterBy.PARAM, key="param_key", operator=SingleFilterOperator.EQUAL_TO, value="param_value") +PARAM_FILTER_BODY = { + "by": "param", + "key": "param_key", + "operator": "eq", + "value": "param_value" +} + +METRIC_FILTER = SingleFilter( + by=SingleFilterBy.METRIC, key="metric_key", operator=SingleFilterOperator.EQUAL_TO, value="metric_value") +METRIC_FILTER_BODY = { + "by": "metric", + "key": "metric_key", + "operator": "eq", + "value": "metric_value" +} + AND_FILTER = CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[ @@ -397,7 +443,12 @@ def test_experiment_run_data_schema_multiple(): [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], + [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], + [RUN_ID_FILTER, RUN_ID_BODY], + [DELETED_AT_FILTER, DELETED_AT_BODY], [TAG_FILTER, 
TAG_FILTER_BODY], + [PARAM_FILTER, PARAM_FILTER_BODY], + [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY] ] ) From e98fdf7db179ae84e8c96ea178e5f516f4de4721 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 15:25:18 +0100 Subject: [PATCH 08/52] Add sort test cases --- tests/clients/test_experiment.py | 48 +++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c107f774..17cf94e6 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -429,6 +429,46 @@ def test_experiment_run_data_schema_multiple(): ] } +OR_FILTER = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), + None + ] +) +OR_FILTER_BODY = { + "operator": "or", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value" + }, + None + ] +} + +RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] +RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] + +STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] +STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] + +DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] +DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] + +PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] +PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] + +TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] +TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] + +METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] +METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] + MULTI_SORT = [ Sort(SortBy.PARAM, "param_key", SortOrder.ASC), Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC) @@ -448,7 
+488,7 @@ def test_experiment_run_data_schema_multiple(): [DELETED_AT_FILTER, DELETED_AT_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], - [METRIC_FILTER, METRIC_FILTER_BODY], + [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY] ] ) @@ -456,6 +496,12 @@ def test_experiment_run_data_schema_multiple(): "psort,psort_body", [ [None, None], + [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], + [STARTED_AT_SORT, STARTED_AT_SORT_BODY], + [DURATION_SORT, DURATION_SORT_BODY], + [PARAM_SORT, PARAM_SORT_BODY], + [TAG_SORT, TAG_SORT_BODY], + [METRIC_SORT, METRIC_SORT_BODY], [MULTI_SORT, MULTI_SORT_BODY] ] ) From 0625d1bda535bdd7c2de14de2ce04fb4dbf87800 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 15:37:30 +0100 Subject: [PATCH 09/52] Fix formatting --- faculty/clients/experiment.py | 18 ++---- tests/clients/test_experiment.py | 95 +++++++++++++++++++------------- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 81bb02c8..5c2a3c99 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -307,18 +307,14 @@ class SingleFilterValueField(fields.Field): """ def _is_valid_uuid(self, value, obj): - return ( - isinstance(value, uuid.UUID) - and ( - obj.by == SingleFilterBy.PROJECT_ID - or obj.by == SingleFilterBy.RUN_ID - ) + return isinstance(value, uuid.UUID) and ( + obj.by == SingleFilterBy.PROJECT_ID + or obj.by == SingleFilterBy.RUN_ID ) def _is_valid_experiment_id(self, value, obj): return ( - isinstance(value, int) - and obj.by == SingleFilterBy.EXPERIMENT_ID + isinstance(value, int) and obj.by == SingleFilterBy.EXPERIMENT_ID ) def _is_directly_stringifiable(self, value, obj): @@ -341,8 +337,7 @@ def _serialize(self, value, attr, obj, **kwargs): else: print(value) raise RunQueryFilterValidation( - "Validation error serialising run query filter", - value + "Validation error serialising run query filter", value ) @@ 
-376,8 +371,7 @@ class SingleFilterSchema(BaseSchema): class CompoundFilterSchema(BaseSchema): - operator = EnumField( - CompoundFilterOperator, by_value=True, required=True) + operator = EnumField(CompoundFilterOperator, by_value=True, required=True) conditions = fields.List(FilterField()) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 17cf94e6..fecccf6b 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -345,66 +345,82 @@ def test_experiment_run_data_schema_multiple(): # check all types of sort PROJECT_ID_FILTER = SingleFilter( - by=SingleFilterBy.PROJECT_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=PROJECT_ID) + SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID +) PROJECT_ID_FILTER_BODY = { "by": "projectId", "key": None, "operator": "eq", - "value": str(PROJECT_ID) + "value": str(PROJECT_ID), } EXPERIMENT_ID_FILTER = SingleFilter( - by=SingleFilterBy.EXPERIMENT_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=1) + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 +) EXPERIMENT_ID_BODY = { "by": "experimentId", "key": None, "operator": "eq", - "value": "1" + "value": "1", } RUN_ID_FILTER = SingleFilter( - by=SingleFilterBy.RUN_ID, key=None, operator=SingleFilterOperator.EQUAL_TO, value=EXPERIMENT_RUN_ID) + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + EXPERIMENT_RUN_ID, +) RUN_ID_BODY = { "by": "runId", "key": None, "operator": "eq", - "value": str(EXPERIMENT_RUN_ID) + "value": str(EXPERIMENT_RUN_ID), } DELETED_AT_FILTER = SingleFilter( - by=SingleFilterBy.DELETED_AT, key=None, operator=SingleFilterOperator.EQUAL_TO, value=DELETED_AT) + SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT +) DELETED_AT_BODY = { "by": "deletedAt", "key": None, "operator": "eq", - "value": DELETED_AT_STRING_PYTHON + "value": DELETED_AT_STRING_PYTHON, } TAG_FILTER = SingleFilter( - 
by=SingleFilterBy.TAG, key="tag_key", operator=SingleFilterOperator.EQUAL_TO, value="tag_value") + SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" +) TAG_FILTER_BODY = { "by": "tag", "key": "tag_key", "operator": "eq", - "value": "tag_value" + "value": "tag_value", } PARAM_FILTER = SingleFilter( - by=SingleFilterBy.PARAM, key="param_key", operator=SingleFilterOperator.EQUAL_TO, value="param_value") + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.EQUAL_TO, + "param_value", +) PARAM_FILTER_BODY = { "by": "param", "key": "param_key", "operator": "eq", - "value": "param_value" + "value": "param_value", } METRIC_FILTER = SingleFilter( - by=SingleFilterBy.METRIC, key="metric_key", operator=SingleFilterOperator.EQUAL_TO, value="metric_value") + SingleFilterBy.METRIC, + "metric_key", + SingleFilterOperator.EQUAL_TO, + "metric_value", +) METRIC_FILTER_BODY = { "by": "metric", "key": "metric_key", "operator": "eq", - "value": "metric_value" + "value": "metric_value", } AND_FILTER = CompoundFilter( @@ -412,9 +428,12 @@ def test_experiment_run_data_schema_multiple(): conditions=[ SingleFilter( SingleFilterBy.TAG, - "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), - None - ] + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], ) AND_FILTER_BODY = { "operator": "and", @@ -423,10 +442,10 @@ def test_experiment_run_data_schema_multiple(): "by": "tag", "key": "tag_key", "operator": "eq", - "value": "tag_value" + "value": "tag_value", }, - None - ] + None, + ], } OR_FILTER = CompoundFilter( @@ -434,9 +453,12 @@ def test_experiment_run_data_schema_multiple(): conditions=[ SingleFilter( SingleFilterBy.TAG, - "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value"), - None - ] + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], ) OR_FILTER_BODY = { "operator": "or", @@ -445,10 +467,10 @@ def test_experiment_run_data_schema_multiple(): "by": "tag", "key": "tag_key", "operator": "eq", - 
"value": "tag_value" + "value": "tag_value", }, - None - ] + None, + ], } RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] @@ -471,11 +493,12 @@ def test_experiment_run_data_schema_multiple(): MULTI_SORT = [ Sort(SortBy.PARAM, "param_key", SortOrder.ASC), - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC) + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), ] MULTI_SORT_BODY = [ {"by": "param", "key": "param_key", "order": "asc"}, - {"by": "runNumber", "key": None, "order": "desc"}] + {"by": "runNumber", "key": None, "order": "desc"}, +] @pytest.mark.parametrize( @@ -489,8 +512,8 @@ def test_experiment_run_data_schema_multiple(): [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], - [AND_FILTER, AND_FILTER_BODY] - ] + [AND_FILTER, AND_FILTER_BODY], + ], ) @pytest.mark.parametrize( "psort,psort_body", @@ -502,19 +525,15 @@ def test_experiment_run_data_schema_multiple(): [PARAM_SORT, PARAM_SORT_BODY], [TAG_SORT, TAG_SORT_BODY], [METRIC_SORT, METRIC_SORT_BODY], - [MULTI_SORT, MULTI_SORT_BODY] - ] + [MULTI_SORT, MULTI_SORT_BODY], + ], ) def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): - queryRunsObj = { - "filter": pfilter, - "sort": psort, - "page": PAGE - } + queryRunsObj = QueryRuns(pfilter, psort, PAGE) expected_json = { "filter": pfilter_body, "sort": psort_body, - "page": PAGE_BODY + "page": PAGE_BODY, } data = QueryRunsSchema().dump(queryRunsObj) assert data == expected_json From b1988b7b21e9457648156cef1780b9c9d20c1d21 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 17:14:05 +0100 Subject: [PATCH 10/52] Implement validation of RunQuery --- faculty/clients/experiment.py | 12 +- tests/clients/test_experiment.py | 393 ++++++++++++++++--------------- 2 files changed, 209 insertions(+), 196 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 5c2a3c99..43b0c83a 100644 --- a/faculty/clients/experiment.py +++ 
b/faculty/clients/experiment.py @@ -136,8 +136,16 @@ class CompoundFilterOperator(Enum): OR = "or" -Sort = namedtuple("Sort", ["by", "key", "order"]) - +_Sort = namedtuple("_Sort", ["by", "key", "order"]) + +class Sort(_Sort): + def __new__(self, by, key, order): + if by in [SortBy.STARTED_AT, SortBy.RUN_NUMBER, SortBy.DURATION] and key is not None: + raise ValueError("key must be none for type {}".format(by)) + elif by in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] and key is None: + raise ValueError("key must not be none for type {}".format(by)) + self = super(Sort, self).__init__(by, key, order) + return self class SortBy(Enum): STARTED_AT = "startedAt" diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index fecccf6b..b37b486e 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -344,200 +344,205 @@ def test_experiment_run_data_schema_multiple(): # check all types of filter # check all types of sort -PROJECT_ID_FILTER = SingleFilter( - SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID -) -PROJECT_ID_FILTER_BODY = { - "by": "projectId", - "key": None, - "operator": "eq", - "value": str(PROJECT_ID), -} - -EXPERIMENT_ID_FILTER = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 -) -EXPERIMENT_ID_BODY = { - "by": "experimentId", - "key": None, - "operator": "eq", - "value": "1", -} - -RUN_ID_FILTER = SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - EXPERIMENT_RUN_ID, -) -RUN_ID_BODY = { - "by": "runId", - "key": None, - "operator": "eq", - "value": str(EXPERIMENT_RUN_ID), -} - -DELETED_AT_FILTER = SingleFilter( - SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT -) -DELETED_AT_BODY = { - "by": "deletedAt", - "key": None, - "operator": "eq", - "value": DELETED_AT_STRING_PYTHON, -} - -TAG_FILTER = SingleFilter( - SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" -) 
-TAG_FILTER_BODY = { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", -} - -PARAM_FILTER = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.EQUAL_TO, - "param_value", -) -PARAM_FILTER_BODY = { - "by": "param", - "key": "param_key", - "operator": "eq", - "value": "param_value", -} - -METRIC_FILTER = SingleFilter( - SingleFilterBy.METRIC, - "metric_key", - SingleFilterOperator.EQUAL_TO, - "metric_value", -) -METRIC_FILTER_BODY = { - "by": "metric", - "key": "metric_key", - "operator": "eq", - "value": "metric_value", -} - -AND_FILTER = CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), - None, - ], -) -AND_FILTER_BODY = { - "operator": "and", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, - None, - ], -} - -OR_FILTER = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), - None, - ], -) -OR_FILTER_BODY = { - "operator": "or", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, - None, - ], -} - -RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] -RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] - -STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] -STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] - -DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] -DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] - -PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] -PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] - -TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] -TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", 
"order": "desc"}] - -METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] -METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] - -MULTI_SORT = [ - Sort(SortBy.PARAM, "param_key", SortOrder.ASC), - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), -] -MULTI_SORT_BODY = [ - {"by": "param", "key": "param_key", "order": "asc"}, - {"by": "runNumber", "key": None, "order": "desc"}, -] - - -@pytest.mark.parametrize( - "pfilter,pfilter_body", - [ - [None, None], - [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], - [RUN_ID_FILTER, RUN_ID_BODY], - [DELETED_AT_FILTER, DELETED_AT_BODY], - [TAG_FILTER, TAG_FILTER_BODY], - [PARAM_FILTER, PARAM_FILTER_BODY], - [METRIC_FILTER, METRIC_FILTER_BODY], - [AND_FILTER, AND_FILTER_BODY], - ], -) -@pytest.mark.parametrize( - "psort,psort_body", - [ - [None, None], - [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], - [STARTED_AT_SORT, STARTED_AT_SORT_BODY], - [DURATION_SORT, DURATION_SORT_BODY], - [PARAM_SORT, PARAM_SORT_BODY], - [TAG_SORT, TAG_SORT_BODY], - [METRIC_SORT, METRIC_SORT_BODY], - [MULTI_SORT, MULTI_SORT_BODY], - ], -) -def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): - queryRunsObj = QueryRuns(pfilter, psort, PAGE) - expected_json = { - "filter": pfilter_body, - "sort": psort_body, - "page": PAGE_BODY, - } - data = QueryRunsSchema().dump(queryRunsObj) - assert data == expected_json - +# PROJECT_ID_FILTER = SingleFilter( +# SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID +# ) +# PROJECT_ID_FILTER_BODY = { +# "by": "projectId", +# "key": None, +# "operator": "eq", +# "value": str(PROJECT_ID), +# } + +# EXPERIMENT_ID_FILTER = SingleFilter( +# SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 +# ) +# EXPERIMENT_ID_BODY = { +# "by": "experimentId", +# "key": None, +# "operator": "eq", +# "value": "1", +# } + +# RUN_ID_FILTER = SingleFilter( +# SingleFilterBy.RUN_ID, +# None, +# 
SingleFilterOperator.EQUAL_TO, +# EXPERIMENT_RUN_ID, +# ) +# RUN_ID_BODY = { +# "by": "runId", +# "key": None, +# "operator": "eq", +# "value": str(EXPERIMENT_RUN_ID), +# } + +# DELETED_AT_FILTER = SingleFilter( +# SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT +# ) +# DELETED_AT_BODY = { +# "by": "deletedAt", +# "key": None, +# "operator": "eq", +# "value": DELETED_AT_STRING_PYTHON, +# } + +# TAG_FILTER = SingleFilter( +# SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" +# ) +# TAG_FILTER_BODY = { +# "by": "tag", +# "key": "tag_key", +# "operator": "eq", +# "value": "tag_value", +# } + +# PARAM_FILTER = SingleFilter( +# SingleFilterBy.PARAM, +# "param_key", +# SingleFilterOperator.EQUAL_TO, +# "param_value", +# ) +# PARAM_FILTER_BODY = { +# "by": "param", +# "key": "param_key", +# "operator": "eq", +# "value": "param_value", +# } + +# METRIC_FILTER = SingleFilter( +# SingleFilterBy.METRIC, +# "metric_key", +# SingleFilterOperator.EQUAL_TO, +# "metric_value", +# ) +# METRIC_FILTER_BODY = { +# "by": "metric", +# "key": "metric_key", +# "operator": "eq", +# "value": "metric_value", +# } + +# AND_FILTER = CompoundFilter( +# operator=CompoundFilterOperator.AND, +# conditions=[ +# SingleFilter( +# SingleFilterBy.TAG, +# "tag_key", +# SingleFilterOperator.EQUAL_TO, +# "tag_value", +# ), +# None, +# ], +# ) +# AND_FILTER_BODY = { +# "operator": "and", +# "conditions": [ +# { +# "by": "tag", +# "key": "tag_key", +# "operator": "eq", +# "value": "tag_value", +# }, +# None, +# ], +# } + +# OR_FILTER = CompoundFilter( +# operator=CompoundFilterOperator.OR, +# conditions=[ +# SingleFilter( +# SingleFilterBy.TAG, +# "tag_key", +# SingleFilterOperator.EQUAL_TO, +# "tag_value", +# ), +# None, +# ], +# ) +# OR_FILTER_BODY = { +# "operator": "or", +# "conditions": [ +# { +# "by": "tag", +# "key": "tag_key", +# "operator": "eq", +# "value": "tag_value", +# }, +# None, +# ], +# } + +# RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, 
None, SortOrder.ASC)] +# RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] + +# STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] +# STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] + +# DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] +# DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] + +# PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] +# PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] + +# TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] +# TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] + +# METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] +# METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] + +# MULTI_SORT = [ +# Sort(SortBy.PARAM, "param_key", SortOrder.ASC), +# Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), +# ] +# MULTI_SORT_BODY = [ +# {"by": "param", "key": "param_key", "order": "asc"}, +# {"by": "runNumber", "key": None, "order": "desc"}, +# ] + + +# @pytest.mark.parametrize( +# "pfilter,pfilter_body", +# [ +# [None, None], +# [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], +# [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], +# [RUN_ID_FILTER, RUN_ID_BODY], +# [DELETED_AT_FILTER, DELETED_AT_BODY], +# [TAG_FILTER, TAG_FILTER_BODY], +# [PARAM_FILTER, PARAM_FILTER_BODY], +# [METRIC_FILTER, METRIC_FILTER_BODY], +# [AND_FILTER, AND_FILTER_BODY], +# ], +# ) +# @pytest.mark.parametrize( +# "psort,psort_body", +# [ +# [None, None], +# [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], +# [STARTED_AT_SORT, STARTED_AT_SORT_BODY], +# [DURATION_SORT, DURATION_SORT_BODY], +# [PARAM_SORT, PARAM_SORT_BODY], +# [TAG_SORT, TAG_SORT_BODY], +# [METRIC_SORT, METRIC_SORT_BODY], +# [MULTI_SORT, MULTI_SORT_BODY], +# ], +# ) +# def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): +# queryRunsObj = QueryRuns(pfilter, psort, PAGE) +# expected_json = { +# "filter": 
pfilter_body, +# "sort": psort_body, +# "page": PAGE_BODY, +# } +# data = QueryRunsSchema().dump(queryRunsObj) +# assert data == expected_json + +def test_sort_validation(mocker): + with pytest.raises( + ValueError, match="key must be none for type {}".format(SortBy.RUN_NUMBER) + ): + Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) @pytest.mark.parametrize("description", [None, "experiment description"]) @pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) From 761e970d9acd32b982c9c3bafa875d11315a2629 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 13 May 2019 17:35:50 +0100 Subject: [PATCH 11/52] Fix validation in Sort constructor --- faculty/clients/experiment.py | 7 +- tests/clients/test_experiment.py | 396 +++++++++++++++---------------- 2 files changed, 202 insertions(+), 201 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 43b0c83a..816b3925 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -138,15 +138,17 @@ class CompoundFilterOperator(Enum): _Sort = namedtuple("_Sort", ["by", "key", "order"]) + class Sort(_Sort): - def __new__(self, by, key, order): + def __new__(cls, by, key, order): if by in [SortBy.STARTED_AT, SortBy.RUN_NUMBER, SortBy.DURATION] and key is not None: raise ValueError("key must be none for type {}".format(by)) elif by in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] and key is None: raise ValueError("key must not be none for type {}".format(by)) - self = super(Sort, self).__init__(by, key, order) + self = super(Sort, cls).__new__(cls, by, key, order) return self + class SortBy(Enum): STARTED_AT = "startedAt" RUN_NUMBER = "runNumber" @@ -343,7 +345,6 @@ def _serialize(self, value, attr, obj, **kwargs): elif obj.by == SingleFilterBy.DELETED_AT: return marshmallow_utils.from_iso_datetime(str(value)).isoformat() else: - print(value) raise RunQueryFilterValidation( "Validation error serialising run query filter", value ) diff --git 
a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index b37b486e..4192c6f4 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -340,210 +340,210 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} -# check that all three fields are nullable -# check all types of filter -# check all types of sort - -# PROJECT_ID_FILTER = SingleFilter( -# SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID -# ) -# PROJECT_ID_FILTER_BODY = { -# "by": "projectId", -# "key": None, -# "operator": "eq", -# "value": str(PROJECT_ID), -# } - -# EXPERIMENT_ID_FILTER = SingleFilter( -# SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 -# ) -# EXPERIMENT_ID_BODY = { -# "by": "experimentId", -# "key": None, -# "operator": "eq", -# "value": "1", -# } - -# RUN_ID_FILTER = SingleFilter( -# SingleFilterBy.RUN_ID, -# None, -# SingleFilterOperator.EQUAL_TO, -# EXPERIMENT_RUN_ID, -# ) -# RUN_ID_BODY = { -# "by": "runId", -# "key": None, -# "operator": "eq", -# "value": str(EXPERIMENT_RUN_ID), -# } - -# DELETED_AT_FILTER = SingleFilter( -# SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT -# ) -# DELETED_AT_BODY = { -# "by": "deletedAt", -# "key": None, -# "operator": "eq", -# "value": DELETED_AT_STRING_PYTHON, -# } - -# TAG_FILTER = SingleFilter( -# SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" -# ) -# TAG_FILTER_BODY = { -# "by": "tag", -# "key": "tag_key", -# "operator": "eq", -# "value": "tag_value", -# } - -# PARAM_FILTER = SingleFilter( -# SingleFilterBy.PARAM, -# "param_key", -# SingleFilterOperator.EQUAL_TO, -# "param_value", -# ) -# PARAM_FILTER_BODY = { -# "by": "param", -# "key": "param_key", -# "operator": "eq", -# "value": "param_value", -# } - -# METRIC_FILTER = SingleFilter( -# SingleFilterBy.METRIC, -# "metric_key", -# SingleFilterOperator.EQUAL_TO, -# "metric_value", -# ) -# 
METRIC_FILTER_BODY = { -# "by": "metric", -# "key": "metric_key", -# "operator": "eq", -# "value": "metric_value", -# } - -# AND_FILTER = CompoundFilter( -# operator=CompoundFilterOperator.AND, -# conditions=[ -# SingleFilter( -# SingleFilterBy.TAG, -# "tag_key", -# SingleFilterOperator.EQUAL_TO, -# "tag_value", -# ), -# None, -# ], -# ) -# AND_FILTER_BODY = { -# "operator": "and", -# "conditions": [ -# { -# "by": "tag", -# "key": "tag_key", -# "operator": "eq", -# "value": "tag_value", -# }, -# None, -# ], -# } - -# OR_FILTER = CompoundFilter( -# operator=CompoundFilterOperator.OR, -# conditions=[ -# SingleFilter( -# SingleFilterBy.TAG, -# "tag_key", -# SingleFilterOperator.EQUAL_TO, -# "tag_value", -# ), -# None, -# ], -# ) -# OR_FILTER_BODY = { -# "operator": "or", -# "conditions": [ -# { -# "by": "tag", -# "key": "tag_key", -# "operator": "eq", -# "value": "tag_value", -# }, -# None, -# ], -# } - -# RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] -# RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] - -# STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] -# STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] - -# DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] -# DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] - -# PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] -# PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] - -# TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] -# TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] - -# METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] -# METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] - -# MULTI_SORT = [ -# Sort(SortBy.PARAM, "param_key", SortOrder.ASC), -# Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), -# ] -# MULTI_SORT_BODY = [ -# {"by": "param", "key": "param_key", "order": "asc"}, -# {"by": 
"runNumber", "key": None, "order": "desc"}, -# ] - - -# @pytest.mark.parametrize( -# "pfilter,pfilter_body", -# [ -# [None, None], -# [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], -# [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], -# [RUN_ID_FILTER, RUN_ID_BODY], -# [DELETED_AT_FILTER, DELETED_AT_BODY], -# [TAG_FILTER, TAG_FILTER_BODY], -# [PARAM_FILTER, PARAM_FILTER_BODY], -# [METRIC_FILTER, METRIC_FILTER_BODY], -# [AND_FILTER, AND_FILTER_BODY], -# ], -# ) -# @pytest.mark.parametrize( -# "psort,psort_body", -# [ -# [None, None], -# [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], -# [STARTED_AT_SORT, STARTED_AT_SORT_BODY], -# [DURATION_SORT, DURATION_SORT_BODY], -# [PARAM_SORT, PARAM_SORT_BODY], -# [TAG_SORT, TAG_SORT_BODY], -# [METRIC_SORT, METRIC_SORT_BODY], -# [MULTI_SORT, MULTI_SORT_BODY], -# ], -# ) -# def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): -# queryRunsObj = QueryRuns(pfilter, psort, PAGE) -# expected_json = { -# "filter": pfilter_body, -# "sort": psort_body, -# "page": PAGE_BODY, -# } -# data = QueryRunsSchema().dump(queryRunsObj) -# assert data == expected_json +PROJECT_ID_FILTER = SingleFilter( + SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID +) +PROJECT_ID_FILTER_BODY = { + "by": "projectId", + "key": None, + "operator": "eq", + "value": str(PROJECT_ID), +} + +EXPERIMENT_ID_FILTER = SingleFilter( + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 +) +EXPERIMENT_ID_BODY = { + "by": "experimentId", + "key": None, + "operator": "eq", + "value": "1", +} + +RUN_ID_FILTER = SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + EXPERIMENT_RUN_ID, +) +RUN_ID_BODY = { + "by": "runId", + "key": None, + "operator": "eq", + "value": str(EXPERIMENT_RUN_ID), +} + +DELETED_AT_FILTER = SingleFilter( + SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT +) +DELETED_AT_BODY = { + "by": "deletedAt", + "key": None, + "operator": "eq", + "value": 
DELETED_AT_STRING_PYTHON, +} + +TAG_FILTER = SingleFilter( + SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" +) +TAG_FILTER_BODY = { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", +} + +PARAM_FILTER = SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.EQUAL_TO, + "param_value", +) +PARAM_FILTER_BODY = { + "by": "param", + "key": "param_key", + "operator": "eq", + "value": "param_value", +} + +METRIC_FILTER = SingleFilter( + SingleFilterBy.METRIC, + "metric_key", + SingleFilterOperator.EQUAL_TO, + "metric_value", +) +METRIC_FILTER_BODY = { + "by": "metric", + "key": "metric_key", + "operator": "eq", + "value": "metric_value", +} + +AND_FILTER = CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], +) +AND_FILTER_BODY = { + "operator": "and", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", + }, + None, + ], +} + +OR_FILTER = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], +) +OR_FILTER_BODY = { + "operator": "or", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", + }, + None, + ], +} + +RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] +RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] + +STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] +STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] + +DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] +DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] + +PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] +PARAM_SORT_BODY = [{"by": "param", "key": 
"param_key", "order": "desc"}] + +TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] +TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] + +METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] +METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] + +MULTI_SORT = [ + Sort(SortBy.PARAM, "param_key", SortOrder.ASC), + Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), +] +MULTI_SORT_BODY = [ + {"by": "param", "key": "param_key", "order": "asc"}, + {"by": "runNumber", "key": None, "order": "desc"}, +] + + +@pytest.mark.parametrize( + "pfilter,pfilter_body", + [ + [None, None], + [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], + [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], + [RUN_ID_FILTER, RUN_ID_BODY], + [DELETED_AT_FILTER, DELETED_AT_BODY], + [TAG_FILTER, TAG_FILTER_BODY], + [PARAM_FILTER, PARAM_FILTER_BODY], + [METRIC_FILTER, METRIC_FILTER_BODY], + [AND_FILTER, AND_FILTER_BODY], + ], +) +@pytest.mark.parametrize( + "psort,psort_body", + [ + [None, None], + [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], + [STARTED_AT_SORT, STARTED_AT_SORT_BODY], + [DURATION_SORT, DURATION_SORT_BODY], + [PARAM_SORT, PARAM_SORT_BODY], + [TAG_SORT, TAG_SORT_BODY], + [METRIC_SORT, METRIC_SORT_BODY], + [MULTI_SORT, MULTI_SORT_BODY], + ], +) +def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): + queryRunsObj = QueryRuns(pfilter, psort, PAGE) + s = Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC) + expected_json = { + "filter": pfilter_body, + "sort": psort_body, + "page": PAGE_BODY, + } + data = QueryRunsSchema().dump(queryRunsObj) + assert data == expected_json + def test_sort_validation(mocker): with pytest.raises( - ValueError, match="key must be none for type {}".format(SortBy.RUN_NUMBER) + ValueError, + match="key must be none for type {}".format(SortBy.RUN_NUMBER) ): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) + @pytest.mark.parametrize("description", [None, "experiment description"]) 
@pytest.mark.parametrize("artifact_location", [None, "s3://mybucket"]) def test_experiment_client_create(mocker, description, artifact_location): From 5ba48da697f9a1981e0f9615c53200c9267f5920 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 13 May 2019 18:11:53 +0100 Subject: [PATCH 12/52] Add validation for filters --- faculty/clients/experiment.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 816b3925..9c77e8a7 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -106,7 +106,16 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) -SingleFilter = namedtuple("SingleFilter", ["by", "key", "operator", "value"]) +_SingleFilter = namedtuple("_SingleFilter", ["by", "key", "operator", "value"]) + +class SingleFilter(_SingleFilter): + def __new__(cls, by, key, operator, value): + if by.has_key() and key is None: + raise ValueError("key must not be none for a {} filter".format(by)) + elif not by.has_key() and key is not None: + raise ValueError("key must be none for a {} filter".format(by)) + return super(SingleFilter, cls).__new__(cls, by, key, operator, value) + CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) @@ -130,6 +139,9 @@ class SingleFilterBy(Enum): PARAM = "param" METRIC = "metric" + def has_key(self): + return self in [SingleFilterBy.TAG, SingleFilterBy.PARAM, SingleFilterBy.METRIC] + class CompoundFilterOperator(Enum): AND = "and" @@ -141,12 +153,11 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): - if by in [SortBy.STARTED_AT, SortBy.RUN_NUMBER, SortBy.DURATION] and key is not None: + if by.has_key() and key is None: raise ValueError("key must be none for type {}".format(by)) - elif by in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] and key is None: - raise ValueError("key must not be 
none for type {}".format(by)) - self = super(Sort, cls).__new__(cls, by, key, order) - return self + elif not by.has_key() and key is not None: + raise ValueError("key must be none for type {}".format(by)) + return super(Sort, cls).__new__(cls, by, key, order) class SortBy(Enum): @@ -157,6 +168,9 @@ class SortBy(Enum): PARAM = "param" METRIC = "metric" + def has_key(self): + return self in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] + class SortOrder(Enum): ASC = "asc" From 38993eebf37f47790c432288e36e4e1157b2d507 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 09:06:02 +0100 Subject: [PATCH 13/52] Reformat code and address warning --- faculty/clients/experiment.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 9c77e8a7..b09c0f27 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -108,11 +108,12 @@ class ExperimentRunStatus(Enum): _SingleFilter = namedtuple("_SingleFilter", ["by", "key", "operator", "value"]) + class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): - if by.has_key() and key is None: + if by.filter_has_key() and key is None: raise ValueError("key must not be none for a {} filter".format(by)) - elif not by.has_key() and key is not None: + elif not by.filter_has_key() and key is not None: raise ValueError("key must be none for a {} filter".format(by)) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -139,8 +140,12 @@ class SingleFilterBy(Enum): PARAM = "param" METRIC = "metric" - def has_key(self): - return self in [SingleFilterBy.TAG, SingleFilterBy.PARAM, SingleFilterBy.METRIC] + def filter_has_key(self): + return self in { + SingleFilterBy.TAG, + SingleFilterBy.PARAM, + SingleFilterBy.METRIC + } class CompoundFilterOperator(Enum): @@ -153,9 +158,9 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): - if 
by.has_key() and key is None: + if by.filter_has_key() and key is None: raise ValueError("key must be none for type {}".format(by)) - elif not by.has_key() and key is not None: + elif not by.filter_has_key() and key is not None: raise ValueError("key must be none for type {}".format(by)) return super(Sort, cls).__new__(cls, by, key, order) @@ -168,8 +173,8 @@ class SortBy(Enum): PARAM = "param" METRIC = "metric" - def has_key(self): - return self in [SortBy.TAG, SortBy.PARAM, SortBy.METRIC] + def filter_has_key(self): + return self in {SortBy.TAG, SortBy.PARAM, SortBy.METRIC} class SortOrder(Enum): From 282dd2605028311ffb79b454001afe542e0fdad3 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 09:08:03 +0100 Subject: [PATCH 14/52] Remove unused variable --- faculty/clients/experiment.py | 2 +- tests/clients/test_experiment.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index b09c0f27..eef581c8 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -144,7 +144,7 @@ def filter_has_key(self): return self in { SingleFilterBy.TAG, SingleFilterBy.PARAM, - SingleFilterBy.METRIC + SingleFilterBy.METRIC, } diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 4192c6f4..5e6275f2 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -526,7 +526,6 @@ def test_experiment_run_data_schema_multiple(): ) def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): queryRunsObj = QueryRuns(pfilter, psort, PAGE) - s = Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC) expected_json = { "filter": pfilter_body, "sort": psort_body, @@ -539,7 +538,7 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): def test_sort_validation(mocker): with pytest.raises( ValueError, - match="key must be none for type {}".format(SortBy.RUN_NUMBER) + match="key must be 
none for type {}".format(SortBy.RUN_NUMBER), ): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) From 7c414f91498e25fa1d398d47905605b932f59bda Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 09:10:32 +0100 Subject: [PATCH 15/52] Rename helper methods --- faculty/clients/experiment.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index eef581c8..a5f51bc3 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -111,9 +111,9 @@ class ExperimentRunStatus(Enum): class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): - if by.filter_has_key() and key is None: + if by.needs_key() and key is None: raise ValueError("key must not be none for a {} filter".format(by)) - elif not by.filter_has_key() and key is not None: + elif not by.needs_key() and key is not None: raise ValueError("key must be none for a {} filter".format(by)) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -140,7 +140,7 @@ class SingleFilterBy(Enum): PARAM = "param" METRIC = "metric" - def filter_has_key(self): + def needs_key(self): return self in { SingleFilterBy.TAG, SingleFilterBy.PARAM, @@ -158,9 +158,9 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): - if by.filter_has_key() and key is None: + if by.needs_key() and key is None: raise ValueError("key must be none for type {}".format(by)) - elif not by.filter_has_key() and key is not None: + elif not by.needs_key() and key is not None: raise ValueError("key must be none for type {}".format(by)) return super(Sort, cls).__new__(cls, by, key, order) @@ -173,7 +173,7 @@ class SortBy(Enum): PARAM = "param" METRIC = "metric" - def filter_has_key(self): + def needs_key(self): return self in {SortBy.TAG, SortBy.PARAM, SortBy.METRIC} From 491cd1743477cd9b1604e54b748d2aed40c00e58 Mon Sep 17 00:00:00 2001 From: Elias Benussi 
Date: Tue, 14 May 2019 09:16:48 +0100 Subject: [PATCH 16/52] Use in instead of single element equality --- faculty/clients/experiment.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a5f51bc3..1c3c42be 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -337,8 +337,7 @@ class SingleFilterValueField(fields.Field): def _is_valid_uuid(self, value, obj): return isinstance(value, uuid.UUID) and ( - obj.by == SingleFilterBy.PROJECT_ID - or obj.by == SingleFilterBy.RUN_ID + obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID} ) def _is_valid_experiment_id(self, value, obj): @@ -350,9 +349,12 @@ def _is_directly_stringifiable(self, value, obj): return ( self._is_valid_uuid(value, obj) or self._is_valid_experiment_id(value, obj) - or obj.by == SingleFilterBy.TAG - or obj.by == SingleFilterBy.PARAM - or obj.by == SingleFilterBy.METRIC + or obj.by + in { + SingleFilterBy.TAG, + SingleFilterBy.PARAM, + SingleFilterBy.METRIC, + } ) def _deserialize(self, value, attr, obj, **kwargs): From 8c371186b845f8dbabcf4c859e5c9aeffd046f34 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 10:08:44 +0100 Subject: [PATCH 17/52] Add tests to filter and sort validation --- faculty/clients/experiment.py | 12 +++++++---- tests/clients/test_experiment.py | 34 +++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1c3c42be..61199285 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -112,9 +112,11 @@ class ExperimentRunStatus(Enum): class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): if by.needs_key() and key is None: - raise ValueError("key must not be none for a {} filter".format(by)) + raise ValueError( + "key must not be none for filter type {}".format(by) + ) elif not 
by.needs_key() and key is not None: - raise ValueError("key must be none for a {} filter".format(by)) + raise ValueError("key must be none for filter type {}".format(by)) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -159,9 +161,11 @@ class CompoundFilterOperator(Enum): class Sort(_Sort): def __new__(cls, by, key, order): if by.needs_key() and key is None: - raise ValueError("key must be none for type {}".format(by)) + raise ValueError( + "key must not be none for sort type {}".format(by) + ) elif not by.needs_key() and key is not None: - raise ValueError("key must be none for type {}".format(by)) + raise ValueError("key must be none for sort type {}".format(by)) return super(Sort, cls).__new__(cls, by, key, order) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 5e6275f2..c300cd62 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -535,12 +535,44 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): assert data == expected_json +def test_single_filter_validation(mocker): + with pytest.raises( + ValueError, + match="key must be none for filter type {}".format( + SingleFilterBy.PROJECT_ID + ), + ): + SingleFilter( + SingleFilterBy.PROJECT_ID, + "invalid_key", + SingleFilterOperator.EQUAL_TO, + PROJECT_ID, + ) + with pytest.raises( + ValueError, + match="key must not be none for filter type {}".format( + SingleFilterBy.TAG + ), + ): + SingleFilter( + SingleFilterBy.TAG, + None, + SingleFilterOperator.EQUAL_TO, + "tag_value", + ) + + def test_sort_validation(mocker): with pytest.raises( ValueError, - match="key must be none for type {}".format(SortBy.RUN_NUMBER), + match="key must be none for sort type {}".format(SortBy.RUN_NUMBER), ): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) + with pytest.raises( + ValueError, + match="key must not be none for sort type {}".format(SortBy.TAG), + ): + Sort(SortBy.TAG, None, SortOrder.ASC) 
@pytest.mark.parametrize("description", [None, "experiment description"]) From 9f794a9a33c9de89ff980db6fc1cb843b9c43d5c Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Tue, 14 May 2019 11:54:27 +0100 Subject: [PATCH 18/52] Add query_runs and test --- faculty/clients/experiment.py | 46 ++++++++++++++++++++++++++++++++ tests/clients/test_experiment.py | 46 +++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 61199285..5ccb83c0 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -708,6 +708,52 @@ def list_runs( endpoint, ListExperimentRunsResponseSchema(), params=query_params ) + def query_runs( + self, + project_id, + filter=None, + sort=None, + start=None, + limit=None, + ): + """Query experiment runs. + + This method returns pages of runs. If less than the full number of runs + for the job is returned, the ``next`` page of the returned response + object will not be ``None``: + + >>> response = client.query_runs(project_id) + >>> response.pagination.next + Page(start=10, limit=10) + + Get all experiment runs by making successive calls to ``query_runs``, + passing the ``start`` and ``limit`` of the ``next`` page each time + until ``next`` is returned as ``None``. + + Parameters + ---------- + project_id : uuid.UUID + filter: either a SingleFilter object or a CompoundFilter object, optional + To filter runs of experiments with the given filter. By default, runs + from all experiments are returned. + sort: List[Sort], optional + Runs are ordered using sorting elements lexicographically. By default, + experiment runs are sorted by their startedAt value. + start : int, optional + The (zero-indexed) starting point of runs to retrieve. + limit : int, optional + The maximum number of runs to retrieve. 
+ + Returns + ------- + ListExperimentRunsResponse + """ + endpoint = "/project/{}/run/query".format(project_id) + # runs_query = QueryRuns(filter, sort, Page(start, limit)) + payload = QueryRunsSchema().dump(QueryRuns(filter, sort, Page(start, limit))) + return self._post(endpoint, ListExperimentRunsResponseSchema(), json=payload) + + def log_run_data( self, project_id, run_id, metrics=None, params=None, tags=None ): diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c300cd62..da6a2c5b 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -536,7 +536,7 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): def test_single_filter_validation(mocker): - with pytest.raises( + with pytest.raises( ValueError, match="key must be none for filter type {}".format( SingleFilterBy.PROJECT_ID @@ -908,6 +908,50 @@ def test_experiment_client_list_runs_page(mocker): ) +def test_experiment_client_query_runs(mocker): + mocker.patch.object( + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ) + response_schema_mock = mocker.patch( + "faculty.clients.experiment.ListExperimentRunsResponseSchema" + ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump + + test_filter = SingleFilter( + SingleFilterBy.EXPERIMENT_ID, + None, + SingleFilterOperator.EQUAL_TO, + "2" + ) + test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] + + client = ExperimentClient(mocker.Mock()) + list_result = client.query_runs( + PROJECT_ID, + filter=test_filter, + sort=test_sort, + start=20, + limit=10 + ) + + assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE + + request_schema_mock.assert_called_once_with() + dump_mock.assert_called_once_with( + QueryRuns(test_filter, test_sort, Page(20, 10)) + ) + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + 
"/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, + ) + + + def test_log_run_data(mocker): mocker.patch.object(ExperimentClient, "_patch_raw") run_data_schema_mock = mocker.patch( From dae0209ce1aef6f118d0184c334a439f21092b71 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 14 May 2019 11:59:36 +0100 Subject: [PATCH 19/52] Reformat code correctly --- faculty/clients/experiment.py | 26 ++++++++++++-------------- tests/clients/test_experiment.py | 16 ++++------------ 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 5ccb83c0..2aa64d21 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -709,12 +709,7 @@ def list_runs( ) def query_runs( - self, - project_id, - filter=None, - sort=None, - start=None, - limit=None, + self, project_id, filter=None, sort=None, start=None, limit=None ): """Query experiment runs. @@ -733,12 +728,12 @@ def query_runs( Parameters ---------- project_id : uuid.UUID - filter: either a SingleFilter object or a CompoundFilter object, optional - To filter runs of experiments with the given filter. By default, runs - from all experiments are returned. + filter: either SingleFilter object or CompoundFilter object, optional + To filter runs of experiments with the given filter. By default, + runs from all experiments are returned. sort: List[Sort], optional - Runs are ordered using sorting elements lexicographically. By default, - experiment runs are sorted by their startedAt value. + Runs are ordered using sorting elements lexicographically. By + default, experiment runs are sorted by their startedAt value. start : int, optional The (zero-indexed) starting point of runs to retrieve. 
limit : int, optional @@ -750,9 +745,12 @@ def query_runs( """ endpoint = "/project/{}/run/query".format(project_id) # runs_query = QueryRuns(filter, sort, Page(start, limit)) - payload = QueryRunsSchema().dump(QueryRuns(filter, sort, Page(start, limit))) - return self._post(endpoint, ListExperimentRunsResponseSchema(), json=payload) - + payload = QueryRunsSchema().dump( + QueryRuns(filter, sort, Page(start, limit)) + ) + return self._post( + endpoint, ListExperimentRunsResponseSchema(), json=payload + ) def log_run_data( self, project_id, run_id, metrics=None, params=None, tags=None diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index da6a2c5b..569f2cec 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -536,7 +536,7 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): def test_single_filter_validation(mocker): - with pytest.raises( + with pytest.raises( ValueError, match="key must be none for filter type {}".format( SingleFilterBy.PROJECT_ID @@ -921,27 +921,20 @@ def test_experiment_client_query_runs(mocker): dump_mock = request_schema_mock.return_value.dump test_filter = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, - None, - SingleFilterOperator.EQUAL_TO, - "2" + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2" ) test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] client = ExperimentClient(mocker.Mock()) list_result = client.query_runs( - PROJECT_ID, - filter=test_filter, - sort=test_sort, - start=20, - limit=10 + PROJECT_ID, filter=test_filter, sort=test_sort, start=20, limit=10 ) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE request_schema_mock.assert_called_once_with() dump_mock.assert_called_once_with( - QueryRuns(test_filter, test_sort, Page(20, 10)) + QueryRuns(test_filter, test_sort, Page(20, 10)) ) response_schema_mock.assert_called_once_with() ExperimentClient._post.assert_called_once_with( @@ -951,7 +944,6 @@ def 
test_experiment_client_query_runs(mocker): ) - def test_log_run_data(mocker): mocker.patch.object(ExperimentClient, "_patch_raw") run_data_schema_mock = mocker.patch( From 5767a87775cd3755c6007f7883d116117cda736d Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Tue, 14 May 2019 15:04:32 +0100 Subject: [PATCH 20/52] Modify deserialise function in SingleFilterValueField --- faculty/clients/experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 2aa64d21..df745d5e 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -361,8 +361,8 @@ def _is_directly_stringifiable(self, value, obj): } ) - def _deserialize(self, value, attr, obj, **kwargs): - pass + def _deserialize(self, value, attr, data, **kwargs): + return value def _serialize(self, value, attr, obj, **kwargs): if self._is_directly_stringifiable(value, obj): From faabf2a635f98ed26eb3b5a404cb8eeb5ac3acb8 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Wed, 15 May 2019 12:03:23 +0100 Subject: [PATCH 21/52] Only create page if both start and limit are not None --- faculty/clients/experiment.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index df745d5e..1c436d94 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -744,9 +744,11 @@ def query_runs( ListExperimentRunsResponse """ endpoint = "/project/{}/run/query".format(project_id) - # runs_query = QueryRuns(filter, sort, Page(start, limit)) + page = None + if start is not None and limit is not None: + page = Page(start, limit) payload = QueryRunsSchema().dump( - QueryRuns(filter, sort, Page(start, limit)) + QueryRuns(filter, sort, page) ) return self._post( endpoint, ListExperimentRunsResponseSchema(), json=payload From cbfd76163153079ce248d489e1cba7a09d041c33 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 
14:47:06 +0100 Subject: [PATCH 22/52] Throw a RunQueryFilterValidation error during serialisation if None is passed to conditions --- faculty/clients/experiment.py | 11 +++++++--- tests/clients/test_experiment.py | 35 ++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1c436d94..741396bd 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -380,7 +380,7 @@ class FilterField(fields.Field): Field that serialises/deserialises a run filter. """ - def _deserialize(self, value, attr, obj, **kwargs): + def _deserialize(self, value, attr, data, **kwargs): if value is None: return None elif isinstance(value, SingleFilter): @@ -389,8 +389,13 @@ def _deserialize(self, value, attr, obj, **kwargs): return CompoundFilterSchema().load(value) def _serialize(self, value, attr, obj, **kwargs): - if value is None: + print(type(obj)) + if value is None and isinstance(obj, QueryRuns): return None + elif value is None: + raise RunQueryFilterValidation( + "Validation error serialising a None filter", value + ) if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) else: @@ -406,7 +411,7 @@ class SingleFilterSchema(BaseSchema): class CompoundFilterSchema(BaseSchema): operator = EnumField(CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField()) + conditions = fields.List(FilterField(skip_if=None)) class SortSchema(BaseSchema): diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 569f2cec..0bd43920 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -53,6 +53,7 @@ QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, + RunQueryFilterValidation, SingleFilter, SingleFilterBy, SingleFilterOperator, @@ -353,7 +354,7 @@ def test_experiment_run_data_schema_multiple(): EXPERIMENT_ID_FILTER = 
SingleFilter( SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 ) -EXPERIMENT_ID_BODY = { +EXPERIMENT_ID_FILTER_BODY = { "by": "experimentId", "key": None, "operator": "eq", @@ -428,7 +429,6 @@ def test_experiment_run_data_schema_multiple(): SingleFilterOperator.EQUAL_TO, "tag_value", ), - None, ], ) AND_FILTER_BODY = { @@ -440,7 +440,6 @@ def test_experiment_run_data_schema_multiple(): "operator": "eq", "value": "tag_value", }, - None, ], } @@ -453,7 +452,6 @@ def test_experiment_run_data_schema_multiple(): SingleFilterOperator.EQUAL_TO, "tag_value", ), - None, ], ) OR_FILTER_BODY = { @@ -465,7 +463,6 @@ def test_experiment_run_data_schema_multiple(): "operator": "eq", "value": "tag_value", }, - None, ], } @@ -502,13 +499,14 @@ def test_experiment_run_data_schema_multiple(): [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_BODY], + [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], [RUN_ID_FILTER, RUN_ID_BODY], [DELETED_AT_FILTER, DELETED_AT_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY], + [OR_FILTER, OR_FILTER_BODY], ], ) @pytest.mark.parametrize( @@ -562,6 +560,31 @@ def test_single_filter_validation(mocker): ) +def test_compound_filter_validation(mocker): + with pytest.raises( + RunQueryFilterValidation, + match="Validation error serialising a None filter" + ): + filter = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None + ], + ) + + queryRunsObj = QueryRuns(filter, None, None) + QueryRunsSchema().dump(queryRunsObj) + + + + + def test_sort_validation(mocker): with pytest.raises( ValueError, From 752f2605576c267b02e13fcdeece2f989c86d9c1 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 15:07:42 +0100 Subject: [PATCH 23/52] Combine AND and OR filter in 
test_experiment --- tests/clients/test_experiment.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 0bd43920..00b92b75 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -451,7 +451,18 @@ def test_experiment_run_data_schema_multiple(): "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value", - ), + ), CompoundFilter( + operator=CompoundFilterOperator.AND, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + ], + ) ], ) OR_FILTER_BODY = { @@ -463,6 +474,17 @@ def test_experiment_run_data_schema_multiple(): "operator": "eq", "value": "tag_value", }, + { + "operator": "and", + "conditions": [ + { + "by": "tag", + "key": "tag_key", + "operator": "eq", + "value": "tag_value", + }, + ], + } ], } From 09deb7ab1a052b28730ee3c9078c0c47c9ec5564 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 15:12:53 +0100 Subject: [PATCH 24/52] Rename RUN_ID_BODY AND DELETED_AT_BODY to RUN_ID_FILTER_BODY and DELETED_AT_FILTER_BODY --- tests/clients/test_experiment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 00b92b75..f1d09822 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -367,7 +367,7 @@ def test_experiment_run_data_schema_multiple(): SingleFilterOperator.EQUAL_TO, EXPERIMENT_RUN_ID, ) -RUN_ID_BODY = { +RUN_ID_FILTER_BODY = { "by": "runId", "key": None, "operator": "eq", @@ -377,7 +377,7 @@ def test_experiment_run_data_schema_multiple(): DELETED_AT_FILTER = SingleFilter( SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT ) -DELETED_AT_BODY = { +DELETED_AT_FILTER_BODY = { "by": "deletedAt", "key": None, "operator": "eq", @@ -522,8 +522,8 @@ def 
test_experiment_run_data_schema_multiple(): [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], - [RUN_ID_FILTER, RUN_ID_BODY], - [DELETED_AT_FILTER, DELETED_AT_BODY], + [RUN_ID_FILTER, RUN_ID_FILTER_BODY], + [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_FILTER, PARAM_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], From dde11e7cdeec533ebd923d45d255aa5f05de210c Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Wed, 15 May 2019 16:12:10 +0100 Subject: [PATCH 25/52] Reformat code using black --- faculty/clients/experiment.py | 4 +--- tests/clients/test_experiment.py | 36 +++++++++++++------------------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 741396bd..3847e98d 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -752,9 +752,7 @@ def query_runs( page = None if start is not None and limit is not None: page = Page(start, limit) - payload = QueryRunsSchema().dump( - QueryRuns(filter, sort, page) - ) + payload = QueryRunsSchema().dump(QueryRuns(filter, sort, page)) return self._post( endpoint, ListExperimentRunsResponseSchema(), json=payload ) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index f1d09822..c39cb511 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -428,18 +428,13 @@ def test_experiment_run_data_schema_multiple(): "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value", - ), + ) ], ) AND_FILTER_BODY = { "operator": "and", "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, + {"by": "tag", "key": "tag_key", "operator": "eq", "value": "tag_value"} ], } @@ -451,18 +446,18 @@ def test_experiment_run_data_schema_multiple(): "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value", - ), + ), CompoundFilter( 
operator=CompoundFilterOperator.AND, conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ) ], - ) + ), ], ) OR_FILTER_BODY = { @@ -482,9 +477,9 @@ def test_experiment_run_data_schema_multiple(): "key": "tag_key", "operator": "eq", "value": "tag_value", - }, + } ], - } + }, ], } @@ -585,7 +580,7 @@ def test_single_filter_validation(mocker): def test_compound_filter_validation(mocker): with pytest.raises( RunQueryFilterValidation, - match="Validation error serialising a None filter" + match="Validation error serialising a None filter", ): filter = CompoundFilter( operator=CompoundFilterOperator.OR, @@ -596,7 +591,7 @@ def test_compound_filter_validation(mocker): SingleFilterOperator.EQUAL_TO, "tag_value", ), - None + None, ], ) @@ -604,9 +599,6 @@ def test_compound_filter_validation(mocker): QueryRunsSchema().dump(queryRunsObj) - - - def test_sort_validation(mocker): with pytest.raises( ValueError, From f076fe5139c66e42cff0ff3cdfccb3bc81036ce2 Mon Sep 17 00:00:00 2001 From: Elias Benussi <4412300+eliasbenussi@users.noreply.github.com> Date: Wed, 22 May 2019 12:44:09 +0200 Subject: [PATCH 26/52] Update faculty/clients/experiment.py Co-Authored-By: Andrew Crozier --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 3847e98d..a3b8c77c 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -733,7 +733,7 @@ def query_runs( Parameters ---------- project_id : uuid.UUID - filter: either SingleFilter object or CompoundFilter object, optional + filter: SingleFilter or CompoundFilter, optional To filter runs of experiments with the given filter. By default, runs from all experiments are returned. 
sort: List[Sort], optional From 11bbdfafbb7367e87d0819e31b5ca8cc26c7eb4a Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Wed, 22 May 2019 11:54:42 +0100 Subject: [PATCH 27/52] Reduce duplication --- tests/clients/test_experiment.py | 53 ++++---------------------------- 1 file changed, 6 insertions(+), 47 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c39cb511..c88deaee 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -421,65 +421,24 @@ def test_experiment_run_data_schema_multiple(): } AND_FILTER = CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ) - ], + operator=CompoundFilterOperator.AND, conditions=[TAG_FILTER] ) -AND_FILTER_BODY = { - "operator": "and", - "conditions": [ - {"by": "tag", "key": "tag_key", "operator": "eq", "value": "tag_value"} - ], -} +AND_FILTER_BODY = {"operator": "and", "conditions": [TAG_FILTER_BODY]} OR_FILTER = CompoundFilter( operator=CompoundFilterOperator.OR, conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), + TAG_FILTER, CompoundFilter( - operator=CompoundFilterOperator.AND, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ) - ], + operator=CompoundFilterOperator.AND, conditions=[PARAM_FILTER] ), ], ) OR_FILTER_BODY = { "operator": "or", "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - }, - { - "operator": "and", - "conditions": [ - { - "by": "tag", - "key": "tag_key", - "operator": "eq", - "value": "tag_value", - } - ], - }, + TAG_FILTER_BODY, + {"operator": "and", "conditions": [PARAM_FILTER_BODY]}, ], } From 5a094740eb374d19f671ffb63afff47d55e5f9fb Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Wed, 22 May 2019 11:54:54 +0100 
Subject: [PATCH 28/52] Remove print statement --- faculty/clients/experiment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a3b8c77c..0a0228e9 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -389,7 +389,6 @@ def _deserialize(self, value, attr, data, **kwargs): return CompoundFilterSchema().load(value) def _serialize(self, value, attr, obj, **kwargs): - print(type(obj)) if value is None and isinstance(obj, QueryRuns): return None elif value is None: From a9eea182f8efaba2b1ff832cc209dc6721abe26d Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Thu, 23 May 2019 14:22:31 +0100 Subject: [PATCH 29/52] Replace RunQueryFilterValidation with marshmallow ValidationError, utilise marshmallow serialize methods for SingleFilterValueField serialisation --- faculty/clients/experiment.py | 89 +++++++++++++++++--------------- tests/clients/test_experiment.py | 64 +++++++++++++++++------ 2 files changed, 96 insertions(+), 57 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 0a0228e9..06331820 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -14,13 +14,18 @@ from collections import namedtuple from enum import Enum -import uuid -from marshmallow import fields, post_load, utils as marshmallow_utils +from marshmallow import fields, post_load, ValidationError from marshmallow_enum import EnumField from faculty.clients.base import BaseClient, BaseSchema, Conflict +error_messages = { + "invalid_param_value": "Invalid param value", + "invalid_filter_type": "Invalid filter type", + "invalid_none_filter": "A none filter", +} + class ExperimentNameConflict(Exception): def __init__(self, name): @@ -44,12 +49,6 @@ def __init__(self, message, experiment_id): self.experiment_id = experiment_id -class RunQueryFilterValidation(Exception): - def __init__(self, message, value): - super(RunQueryFilterValidation, 
self).__init__(message) - self.value = value - - class ExperimentRunStatus(Enum): RUNNING = "running" FINISHED = "finished" @@ -111,12 +110,24 @@ class ExperimentRunStatus(Enum): class SingleFilter(_SingleFilter): def __new__(cls, by, key, operator, value): - if by.needs_key() and key is None: + if isinstance(by, SingleFilterBy) and by.needs_key() and key is None: raise ValueError( "key must not be none for filter type {}".format(by) ) - elif not by.needs_key() and key is not None: + elif ( + isinstance(by, SingleFilterBy) + and not by.needs_key() + and key is not None + ): raise ValueError("key must be none for filter type {}".format(by)) + elif by == SingleFilterBy.PARAM and operator.is_numeric_operator(): + if not (isinstance(value, float) or isinstance(value, int)): + raise ValueError( + ( + "value can not be type {}. It has to be either an int " + + "or a float" + ).format(type(value)) + ) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -132,6 +143,14 @@ class SingleFilterOperator(Enum): GREATER_THAN = "gt" GREATER_THAN_OR_EQUAL_TO = "ge" + def is_numeric_operator(self): + return self in { + SingleFilterOperator.LESS_THAN, + SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, + SingleFilterOperator.GREATER_THAN, + SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, + } + class SingleFilterBy(Enum): PROJECT_ID = "projectId" @@ -339,40 +358,30 @@ class SingleFilterValueField(fields.Field): Field that serialises/deserialises a run filter. 
""" - def _is_valid_uuid(self, value, obj): - return isinstance(value, uuid.UUID) and ( - obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID} - ) - - def _is_valid_experiment_id(self, value, obj): - return ( - isinstance(value, int) and obj.by == SingleFilterBy.EXPERIMENT_ID - ) - - def _is_directly_stringifiable(self, value, obj): - return ( - self._is_valid_uuid(value, obj) - or self._is_valid_experiment_id(value, obj) - or obj.by - in { - SingleFilterBy.TAG, - SingleFilterBy.PARAM, - SingleFilterBy.METRIC, - } - ) - def _deserialize(self, value, attr, data, **kwargs): return value def _serialize(self, value, attr, obj, **kwargs): - if self._is_directly_stringifiable(value, obj): - return str(value) + if obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: + field = fields.UUID() + elif obj.by == SingleFilterBy.EXPERIMENT_ID: + field = fields.Integer() elif obj.by == SingleFilterBy.DELETED_AT: - return marshmallow_utils.from_iso_datetime(str(value)).isoformat() + field = fields.DateTime() + elif obj.by == SingleFilterBy.TAG: + field = fields.Str() + elif obj.by == SingleFilterBy.PARAM: + if isinstance(obj.value, int) or isinstance(obj.value, float): + field = fields.Float() + elif isinstance(obj.value, str): + field = fields.Str() + else: + raise ValidationError(error_messages["invalid_param_value"]) + elif obj.by == SingleFilterBy.METRIC: + field = fields.Float() else: - raise RunQueryFilterValidation( - "Validation error serialising run query filter", value - ) + raise ValidationError(error_messages["invalid_filter_type"]) + return field._serialize(value, attr, obj, **kwargs) class FilterField(fields.Field): @@ -392,9 +401,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None and isinstance(obj, QueryRuns): return None elif value is None: - raise RunQueryFilterValidation( - "Validation error serialising a None filter", value - ) + raise ValidationError(error_messages["invalid_none_filter"]) if isinstance(value, SingleFilter): 
return SingleFilterSchema().dump(value) else: diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c88deaee..acee5b57 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -53,10 +53,10 @@ QueryRunsSchema, RestoreExperimentRunsResponse, RestoreExperimentRunsResponseSchema, - RunQueryFilterValidation, SingleFilter, SingleFilterBy, SingleFilterOperator, + SingleFilterSchema, Sort, SortBy, SortOrder, @@ -358,7 +358,7 @@ def test_experiment_run_data_schema_multiple(): "by": "experimentId", "key": None, "operator": "eq", - "value": "1", + "value": 1, } RUN_ID_FILTER = SingleFilter( @@ -394,13 +394,26 @@ def test_experiment_run_data_schema_multiple(): "value": "tag_value", } -PARAM_FILTER = SingleFilter( +PARAM_NUM_FILTER = SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, + 1.0, +) +PARAM_NUM_FILTER_BODY = { + "by": "param", + "key": "param_key", + "operator": "ge", + "value": 1.0, +} + +PARAM_TEXT_FILTER = SingleFilter( SingleFilterBy.PARAM, "param_key", SingleFilterOperator.EQUAL_TO, "param_value", ) -PARAM_FILTER_BODY = { +PARAM_TEXT_FILTER_BODY = { "by": "param", "key": "param_key", "operator": "eq", @@ -408,16 +421,13 @@ def test_experiment_run_data_schema_multiple(): } METRIC_FILTER = SingleFilter( - SingleFilterBy.METRIC, - "metric_key", - SingleFilterOperator.EQUAL_TO, - "metric_value", + SingleFilterBy.METRIC, "metric_key", SingleFilterOperator.EQUAL_TO, 1.0 ) METRIC_FILTER_BODY = { "by": "metric", "key": "metric_key", "operator": "eq", - "value": "metric_value", + "value": 1.0, } AND_FILTER = CompoundFilter( @@ -430,7 +440,7 @@ def test_experiment_run_data_schema_multiple(): conditions=[ TAG_FILTER, CompoundFilter( - operator=CompoundFilterOperator.AND, conditions=[PARAM_FILTER] + operator=CompoundFilterOperator.AND, conditions=[PARAM_TEXT_FILTER] ), ], ) @@ -438,7 +448,7 @@ def test_experiment_run_data_schema_multiple(): "operator": "or", 
"conditions": [ TAG_FILTER_BODY, - {"operator": "and", "conditions": [PARAM_FILTER_BODY]}, + {"operator": "and", "conditions": [PARAM_TEXT_FILTER_BODY]}, ], } @@ -479,7 +489,8 @@ def test_experiment_run_data_schema_multiple(): [RUN_ID_FILTER, RUN_ID_FILTER_BODY], [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], [TAG_FILTER, TAG_FILTER_BODY], - [PARAM_FILTER, PARAM_FILTER_BODY], + [PARAM_NUM_FILTER, PARAM_NUM_FILTER_BODY], + [PARAM_TEXT_FILTER, PARAM_TEXT_FILTER_BODY], [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY], [OR_FILTER, OR_FILTER_BODY], @@ -509,6 +520,17 @@ def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): assert data == expected_json +def test_single_filter_value_field_validation(mocker): + with pytest.raises(ValidationError): + singleFilterObj = SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.GREATER_THAN, + True, + ) + SingleFilterSchema().dump(singleFilterObj) + + def test_single_filter_validation(mocker): with pytest.raises( ValueError, @@ -534,13 +556,23 @@ def test_single_filter_validation(mocker): SingleFilterOperator.EQUAL_TO, "tag_value", ) + with pytest.raises( + ValueError, + match=( + "value can not be type {}. 
It has to be either an int " + + "or a float" + ).format(type("param_value")), + ): + SingleFilter( + SingleFilterBy.PARAM, + "param_key", + SingleFilterOperator.GREATER_THAN, + "param_value", + ) def test_compound_filter_validation(mocker): - with pytest.raises( - RunQueryFilterValidation, - match="Validation error serialising a None filter", - ): + with pytest.raises(ValidationError): filter = CompoundFilter( operator=CompoundFilterOperator.OR, conditions=[ From 6b74cffae182965ec1905753eceee6b2a8a8b5ef Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Thu, 23 May 2019 14:28:35 +0100 Subject: [PATCH 30/52] Serialize single filter value field as number when SingleFilterBy is PARAM --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 06331820..1f34ace3 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -372,7 +372,7 @@ def _serialize(self, value, attr, obj, **kwargs): field = fields.Str() elif obj.by == SingleFilterBy.PARAM: if isinstance(obj.value, int) or isinstance(obj.value, float): - field = fields.Float() + field = fields.Number() elif isinstance(obj.value, str): field = fields.Str() else: From 9de8cab152249f2f7dc348cc251c952adade67ea Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Thu, 23 May 2019 17:10:26 +0100 Subject: [PATCH 31/52] Rename variable --- tests/clients/test_experiment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index acee5b57..28374c83 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -481,7 +481,7 @@ def test_experiment_run_data_schema_multiple(): @pytest.mark.parametrize( - "pfilter,pfilter_body", + "filter,filter_body", [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], @@ -509,10 +509,10 @@ def test_experiment_run_data_schema_multiple(): 
[MULTI_SORT, MULTI_SORT_BODY], ], ) -def test_query_runs_schema(mocker, pfilter, psort, pfilter_body, psort_body): - queryRunsObj = QueryRuns(pfilter, psort, PAGE) +def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): + queryRunsObj = QueryRuns(filter, psort, PAGE) expected_json = { - "filter": pfilter_body, + "filter": filter_body, "sort": psort_body, "page": PAGE_BODY, } From 1fb2461dca3bc5c62f86a8eb68281ee27da9c69d Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Thu, 23 May 2019 17:13:28 +0100 Subject: [PATCH 32/52] Rephrase docs for better clarity --- faculty/clients/experiment.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1f34ace3..1aaabe65 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -743,8 +743,9 @@ def query_runs( To filter runs of experiments with the given filter. By default, runs from all experiments are returned. sort: List[Sort], optional - Runs are ordered using sorting elements lexicographically. By - default, experiment runs are sorted by their startedAt value. + Runs are order using the conditions in sort. The relative + importance of each condition gradually decreases in order. + By default, experiment runs are sorted by their startedAt value. start : int, optional The (zero-indexed) starting point of runs to retrieve. 
limit : int, optional From 463ebadf4b8e6cd0b46212f3ed50526828c0c964 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 11:14:43 +0100 Subject: [PATCH 33/52] Implement list_runs in terms of query_runs --- faculty/clients/experiment.py | 46 ++++++++++++++------ tests/clients/test_experiment.py | 73 +++++++++++++++++++------------- 2 files changed, 76 insertions(+), 43 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1aaabe65..6b38c255 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -694,10 +694,11 @@ def list_runs( ------- ListExperimentRunsResponse """ - if lifecycle_stage is not None: - raise NotImplementedError("lifecycle_stage is not supported.") - query_params = [] + experiment_ids_filter = None + lifecycle_filter = None + filter = None + if experiment_ids is not None: if len(experiment_ids) == 0: return ListExperimentRunsResponse( @@ -706,18 +707,37 @@ def list_runs( start=0, size=0, previous=None, next=None ), ) - for experiment_id in experiment_ids: - query_params.append(("experimentId", experiment_id)) + experiment_id_filters = [ + SingleFilter( + SingleFilterBy.EXPERIMENT_ID, + None, + SingleFilterOperator.EQUAL_TO, + value, + ) + for value in experiment_ids + ] + experiment_ids_filter = CompoundFilter( + CompoundFilterOperator.OR, experiment_id_filters + ) + if lifecycle_stage is not None: + lifecycle_filter = SingleFilter( + SingleFilterBy.DELETED_AT, + None, + SingleFilterOperator.DEFINED, + lifecycle_stage == LifecycleStage.DELETED, + ) - if start is not None: - query_params.append(("start", start)) - if limit is not None: - query_params.append(("limit", limit)) + if experiment_ids_filter is not None and lifecycle_filter is not None: + filter = CompoundFilter( + CompoundFilterOperator.AND, + [experiment_ids_filter, lifecycle_filter], + ) + elif experiment_ids_filter is not None: + filter = experiment_ids_filter + elif lifecycle_filter is not None: + filter = 
lifecycle_filter - endpoint = "/project/{}/run".format(project_id) - return self._get( - endpoint, ListExperimentRunsResponseSchema(), params=query_params - ) + return self.query_runs(project_id, filter, None, start, limit) def query_runs( self, project_id, filter=None, sort=None, start=None, limit=None diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 28374c83..c8822a08 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -869,70 +869,83 @@ def test_restore_experiment_runs_response_schema_invalid(mocker): def test_experiment_client_list_runs_all(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump client = ExperimentClient(mocker.Mock()) list_result = client.list_runs(PROJECT_ID) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[], + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, ) def test_experiment_client_list_runs_experiments_filter(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_schema_mock = mocker.patch( + 
"faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump client = ExperimentClient(mocker.Mock()) list_result = client.list_runs(PROJECT_ID, experiment_ids=[123, 456]) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("experimentId", 123), ("experimentId", 456)], - ) - - -def test_experiment_client_list_runs_experiments_filter_empty(mocker): - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) - assert list_result == ListExperimentRunsResponse( - runs=[], - pagination=Pagination(start=0, size=0, previous=None, next=None), + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + "/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, ) def test_experiment_client_list_runs_page(mocker): mocker.patch.object( - ExperimentClient, "_get", return_value=LIST_EXPERIMENT_RUNS_RESPONSE + ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) + request_schema_mock = mocker.patch( + "faculty.clients.experiment.QueryRunsSchema" + ) + dump_mock = request_schema_mock.return_value.dump client = ExperimentClient(mocker.Mock()) list_result = client.list_runs(PROJECT_ID, start=20, limit=10) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - schema_mock.assert_called_once_with() - ExperimentClient._get.assert_called_once_with( - "/project/{}/run".format(PROJECT_ID), - schema_mock.return_value, - params=[("start", 20), ("limit", 10)], + response_schema_mock.assert_called_once_with() + ExperimentClient._post.assert_called_once_with( + 
"/project/{}/run/query".format(PROJECT_ID), + response_schema_mock.return_value, + json=dump_mock.return_value, + ) + + +def test_experiment_client_list_runs_experiments_filter_empty(mocker): + client = ExperimentClient(mocker.Mock()) + list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) + + assert list_result == ListExperimentRunsResponse( + runs=[], + pagination=Pagination(start=0, size=0, previous=None, next=None), ) From 16c6cd5aa45e3dce047db69e3a8e7348f87528ab Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 11:20:35 +0100 Subject: [PATCH 34/52] Update docs --- faculty/clients/experiment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 6b38c255..6b1c3223 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -685,6 +685,9 @@ def list_runs( To filter runs of experiments with the given IDs only. If an empty list is passed, a result with an empty list of runs is returned. By default, runs from all experiments are returned. + lifecycle_stage: LifecycleStage, optional + To filter runs of experiments in a specific lifecycle stage only. + By default, runs in any stage are returned. start : int, optional The (zero-indexed) starting point of runs to retrieve. 
limit : int, optional From d00dbeeceb46fd57b2926f6e8ef7fda105097fc7 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 13:02:47 +0100 Subject: [PATCH 35/52] Update delete and restore runs to use filter objects --- faculty/clients/experiment.py | 44 +++++++++++++++++++------------- tests/clients/test_experiment.py | 44 +++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 6b1c3223..08a4d472 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -891,15 +891,19 @@ def delete_runs(self, project_id, run_ids=None): deleted_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_id, + ) + for run_id in run_ids + ] + run_ids_filter = CompoundFilter( + CompoundFilterOperator.OR, run_id_filters + ) + payload = {"filter": run_ids_filter} return self._post( endpoint, DeleteExperimentRunsResponseSchema(), json=payload @@ -932,15 +936,19 @@ def restore_runs(self, project_id, run_ids=None): restored_run_ids=[], conflicted_run_ids=[] ) else: - payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_id)} - for run_id in run_ids - ], - } - } + run_id_filters = [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_id, + ) + for run_id in run_ids + ] + run_ids_filter = CompoundFilter( + CompoundFilterOperator.OR, run_id_filters + ) + payload = {"filter": run_ids_filter} return self._post( endpoint, RestoreExperimentRunsResponseSchema(), json=payload diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index c8822a08..929df4f2 100644 --- 
a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -1216,13 +1216,23 @@ def test_delete_runs(mocker): ) expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, + "filter": CompoundFilter( + CompoundFilterOperator.OR, + [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[0], + ), + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[1], + ), ], - } + ) } ExperimentClient._post.assert_called_once_with( @@ -1276,13 +1286,23 @@ def test_restore_runs(mocker): ) expected_payload = { - "filter": { - "operator": "or", - "conditions": [ - {"by": "runId", "operator": "eq", "value": str(run_ids[0])}, - {"by": "runId", "operator": "eq", "value": str(run_ids[1])}, + "filter": CompoundFilter( + CompoundFilterOperator.OR, + [ + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[0], + ), + SingleFilter( + SingleFilterBy.RUN_ID, + None, + SingleFilterOperator.EQUAL_TO, + run_ids[1], + ), ], - } + ) } ExperimentClient._post.assert_called_once_with( From 4110c91bd9e9438c312f275f8d97c71ef122444f Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 14:30:26 +0100 Subject: [PATCH 36/52] Unnest condition --- faculty/clients/experiment.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 08a4d472..905a4967 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -20,7 +20,7 @@ from faculty.clients.base import BaseClient, BaseSchema, Conflict -error_messages = { +_ERROR_MESSAGES = { "invalid_param_value": "Invalid param value", "invalid_filter_type": "Invalid filter type", "invalid_none_filter": "A none filter", @@ -120,14 +120,17 @@ def __new__(cls, 
by, key, operator, value): and key is not None ): raise ValueError("key must be none for filter type {}".format(by)) - elif by == SingleFilterBy.PARAM and operator.is_numeric_operator(): - if not (isinstance(value, float) or isinstance(value, int)): - raise ValueError( - ( - "value can not be type {}. It has to be either an int " - + "or a float" - ).format(type(value)) - ) + elif ( + by == SingleFilterBy.PARAM + and operator.is_numeric_operator() + and not (isinstance(value, float) or isinstance(value, int)) + ): + raise ValueError( + ( + "value can not be type {}. It has to be either an int " + + "or a float" + ).format(type(value)) + ) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -376,11 +379,11 @@ def _serialize(self, value, attr, obj, **kwargs): elif isinstance(obj.value, str): field = fields.Str() else: - raise ValidationError(error_messages["invalid_param_value"]) + raise ValidationError(_ERROR_MESSAGES["invalid_param_value"]) elif obj.by == SingleFilterBy.METRIC: field = fields.Float() else: - raise ValidationError(error_messages["invalid_filter_type"]) + raise ValidationError(_ERROR_MESSAGES["invalid_filter_type"]) return field._serialize(value, attr, obj, **kwargs) @@ -401,7 +404,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None and isinstance(obj, QueryRuns): return None elif value is None: - raise ValidationError(error_messages["invalid_none_filter"]) + raise ValidationError(_ERROR_MESSAGES["invalid_none_filter"]) if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) else: @@ -715,9 +718,9 @@ def list_runs( SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, - value, + experiment_id, ) - for value in experiment_ids + for experiment_id in experiment_ids ] experiment_ids_filter = CompoundFilter( CompoundFilterOperator.OR, experiment_id_filters From 7a83f8cbc1ffb3874da63536031abe82901df8a8 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Tue, 28 May 2019 16:20:38 
+0100 Subject: [PATCH 37/52] Parametrise single filter construction tests --- faculty/clients/experiment.py | 2 +- tests/clients/test_experiment.py | 49 +++++++++++++++----------------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 905a4967..ad83779a 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -127,7 +127,7 @@ def __new__(cls, by, key, operator, value): ): raise ValueError( ( - "value can not be type {}. It has to be either an int " + "invalid type {}. Value has to be either an int " + "or a float" ).format(type(value)) ) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 929df4f2..e7a8e1e5 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -531,44 +531,41 @@ def test_single_filter_value_field_validation(mocker): SingleFilterSchema().dump(singleFilterObj) -def test_single_filter_validation(mocker): - with pytest.raises( - ValueError, - match="key must be none for filter type {}".format( - SingleFilterBy.PROJECT_ID - ), - ): - SingleFilter( +@pytest.mark.parametrize( + "by,key,op,value,err_msg", + [ + [ SingleFilterBy.PROJECT_ID, "invalid_key", SingleFilterOperator.EQUAL_TO, PROJECT_ID, - ) - with pytest.raises( - ValueError, - match="key must not be none for filter type {}".format( - SingleFilterBy.TAG - ), - ): - SingleFilter( + "key must be none for filter type {}".format( + SingleFilterBy.PROJECT_ID + ), + ], + [ SingleFilterBy.TAG, None, SingleFilterOperator.EQUAL_TO, "tag_value", - ) - with pytest.raises( - ValueError, - match=( - "value can not be type {}. It has to be either an int " - + "or a float" - ).format(type("param_value")), - ): - SingleFilter( + "key must not be none for filter type {}".format( + SingleFilterBy.TAG + ), + ], + [ SingleFilterBy.PARAM, "param_key", SingleFilterOperator.GREATER_THAN, "param_value", - ) + "invalid type {}. 
Value has to be either an int or a float".format( + type("param_value") + ), + ], + ], +) +def test_single_filter_validation(mocker, by, key, op, value, err_msg): + with pytest.raises(ValueError, match=err_msg): + SingleFilter(by, key, op, value) def test_compound_filter_validation(mocker): From 093d9775effcab68ca1a9630cba8869c6c466b71 Mon Sep 17 00:00:00 2001 From: Elias Benussi Date: Mon, 3 Jun 2019 11:17:24 +0100 Subject: [PATCH 38/52] Serialise as Boolean if operator is defined --- faculty/clients/experiment.py | 4 +++- tests/clients/test_experiment.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index ad83779a..a1d6cae5 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -365,7 +365,9 @@ def _deserialize(self, value, attr, data, **kwargs): return value def _serialize(self, value, attr, obj, **kwargs): - if obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: + if obj.operator == SingleFilterOperator.DEFINED: + field = fields.Boolean() + elif obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: field = fields.UUID() elif obj.by == SingleFilterBy.EXPERIMENT_ID: field = fields.Integer() diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index e7a8e1e5..bd6b169e 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -384,6 +384,16 @@ def test_experiment_run_data_schema_multiple(): "value": DELETED_AT_STRING_PYTHON, } +DELETED_FILTER = SingleFilter( + SingleFilterBy.DELETED_AT, None, SingleFilterOperator.DEFINED, True +) +DELETED_FILTER_BODY = { + "by": "deletedAt", + "key": None, + "operator": "defined", + "value": True, +} + TAG_FILTER = SingleFilter( SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" ) @@ -488,6 +498,7 @@ def test_experiment_run_data_schema_multiple(): [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], [RUN_ID_FILTER, 
RUN_ID_FILTER_BODY], [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], + [DELETED_FILTER, DELETED_FILTER_BODY], [TAG_FILTER, TAG_FILTER_BODY], [PARAM_NUM_FILTER, PARAM_NUM_FILTER_BODY], [PARAM_TEXT_FILTER, PARAM_TEXT_FILTER_BODY], From 499f1dc17d506dfaa6990bb670e79bb6cc0142a4 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Mon, 3 Jun 2019 14:48:01 +0100 Subject: [PATCH 39/52] Add parameterised tests for single filter schema and sort schema --- tests/clients/test_experiment.py | 237 +++++++++++++++++++------------ 1 file changed, 144 insertions(+), 93 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index bd6b169e..15ed62ef 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -60,6 +60,7 @@ Sort, SortBy, SortOrder, + SortSchema, Tag, TagSchema, ) @@ -97,6 +98,7 @@ "deletedAt": DELETED_AT_STRING, } +RUN_ID = uuid4() RUN_STARTED_AT = datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC) RUN_STARTED_AT_NO_TIMEZONE = datetime(2018, 3, 10, 11, 39, 12, 110000) RUN_STARTED_AT_STRING_PYTHON = "2018-03-10T11:39:12.110000+00:00" @@ -351,49 +353,6 @@ def test_experiment_run_data_schema_multiple(): "value": str(PROJECT_ID), } -EXPERIMENT_ID_FILTER = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, 1 -) -EXPERIMENT_ID_FILTER_BODY = { - "by": "experimentId", - "key": None, - "operator": "eq", - "value": 1, -} - -RUN_ID_FILTER = SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - EXPERIMENT_RUN_ID, -) -RUN_ID_FILTER_BODY = { - "by": "runId", - "key": None, - "operator": "eq", - "value": str(EXPERIMENT_RUN_ID), -} - -DELETED_AT_FILTER = SingleFilter( - SingleFilterBy.DELETED_AT, None, SingleFilterOperator.EQUAL_TO, DELETED_AT -) -DELETED_AT_FILTER_BODY = { - "by": "deletedAt", - "key": None, - "operator": "eq", - "value": DELETED_AT_STRING_PYTHON, -} - -DELETED_FILTER = SingleFilter( - SingleFilterBy.DELETED_AT, None, 
SingleFilterOperator.DEFINED, True -) -DELETED_FILTER_BODY = { - "by": "deletedAt", - "key": None, - "operator": "defined", - "value": True, -} - TAG_FILTER = SingleFilter( SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" ) @@ -404,19 +363,6 @@ def test_experiment_run_data_schema_multiple(): "value": "tag_value", } -PARAM_NUM_FILTER = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, - 1.0, -) -PARAM_NUM_FILTER_BODY = { - "by": "param", - "key": "param_key", - "operator": "ge", - "value": 1.0, -} - PARAM_TEXT_FILTER = SingleFilter( SingleFilterBy.PARAM, "param_key", @@ -430,16 +376,6 @@ def test_experiment_run_data_schema_multiple(): "value": "param_value", } -METRIC_FILTER = SingleFilter( - SingleFilterBy.METRIC, "metric_key", SingleFilterOperator.EQUAL_TO, 1.0 -) -METRIC_FILTER_BODY = { - "by": "metric", - "key": "metric_key", - "operator": "eq", - "value": 1.0, -} - AND_FILTER = CompoundFilter( operator=CompoundFilterOperator.AND, conditions=[TAG_FILTER] ) @@ -465,21 +401,9 @@ def test_experiment_run_data_schema_multiple(): RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] -STARTED_AT_SORT = [Sort(SortBy.STARTED_AT, None, SortOrder.DESC)] -STARTED_AT_SORT_BODY = [{"by": "startedAt", "key": None, "order": "desc"}] - DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] -PARAM_SORT = [Sort(SortBy.PARAM, "param_key", SortOrder.DESC)] -PARAM_SORT_BODY = [{"by": "param", "key": "param_key", "order": "desc"}] - -TAG_SORT = [Sort(SortBy.TAG, "tag_key", SortOrder.DESC)] -TAG_SORT_BODY = [{"by": "tag", "key": "tag_key", "order": "desc"}] - -METRIC_SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.DESC)] -METRIC_SORT_BODY = [{"by": "metric", "key": "metric_key", "order": "desc"}] - MULTI_SORT = [ Sort(SortBy.PARAM, "param_key", 
SortOrder.ASC), Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), @@ -491,32 +415,19 @@ def test_experiment_run_data_schema_multiple(): @pytest.mark.parametrize( - "filter,filter_body", + "filter, filter_body", [ [None, None], [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [EXPERIMENT_ID_FILTER, EXPERIMENT_ID_FILTER_BODY], - [RUN_ID_FILTER, RUN_ID_FILTER_BODY], - [DELETED_AT_FILTER, DELETED_AT_FILTER_BODY], - [DELETED_FILTER, DELETED_FILTER_BODY], - [TAG_FILTER, TAG_FILTER_BODY], - [PARAM_NUM_FILTER, PARAM_NUM_FILTER_BODY], - [PARAM_TEXT_FILTER, PARAM_TEXT_FILTER_BODY], - [METRIC_FILTER, METRIC_FILTER_BODY], [AND_FILTER, AND_FILTER_BODY], [OR_FILTER, OR_FILTER_BODY], ], ) @pytest.mark.parametrize( - "psort,psort_body", + "psort, psort_body", [ [None, None], - [RUN_NUMBER_SORT, RUN_NUMBER_SORT_BODY], - [STARTED_AT_SORT, STARTED_AT_SORT_BODY], [DURATION_SORT, DURATION_SORT_BODY], - [PARAM_SORT, PARAM_SORT_BODY], - [TAG_SORT, TAG_SORT_BODY], - [METRIC_SORT, METRIC_SORT_BODY], [MULTI_SORT, MULTI_SORT_BODY], ], ) @@ -531,6 +442,146 @@ def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): assert data == expected_json +@pytest.mark.parametrize( + "by,key,value,by_body,value_body", + [ + [ + SingleFilterBy.PROJECT_ID, + None, + PROJECT_ID, + "projectId", + str(PROJECT_ID), + ], + [ + SingleFilterBy.EXPERIMENT_ID, + None, + EXPERIMENT_ID, + "experimentId", + EXPERIMENT_ID, + ], + [SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)], + [ + SingleFilterBy.DELETED_AT, + None, + DELETED_AT, + "deletedAt", + DELETED_AT_STRING_PYTHON, + ], + [SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"], + [ + SingleFilterBy.PARAM, + "param-key", + "param-text-value", + "param", + "param-text-value", + ], + [SingleFilterBy.PARAM, "param-key", 1, "param", 1], + [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ], +) +@pytest.mark.parametrize( + "operator,operator_body", + [ + [SingleFilterOperator.EQUAL_TO, "eq"], + 
[SingleFilterOperator.NOT_EQUAL_TO, "ne"], + ], +) +def test_single_filter_schema_equality_operators( + mocker, by, key, value, by_body, value_body, operator, operator_body +): + singleFilterObj = SingleFilter(by, key, operator, value) + expected_json = { + "by": by_body, + "key": key, + "operator": operator_body, + "value": value_body, + } + data = SingleFilterSchema().dump(singleFilterObj) + assert data == expected_json + + +@pytest.mark.parametrize( + "by,key,value,by_body,value_body", + [ + [ + SingleFilterBy.DELETED_AT, + None, + DELETED_AT, + "deletedAt", + DELETED_AT_STRING_PYTHON, + ], + [SingleFilterBy.PARAM, "param-key", 1, "param", 1], + [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ], +) +@pytest.mark.parametrize( + "operator,operator_body", + [ + [SingleFilterOperator.LESS_THAN, "lt"], + [SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"], + [SingleFilterOperator.GREATER_THAN, "gt"], + [SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"], + ], +) +def test_single_filter_schema_relational_operators( + mocker, by, key, value, by_body, value_body, operator, operator_body +): + singleFilterObj = SingleFilter(by, key, operator, value) + expected_json = { + "by": by_body, + "key": key, + "operator": operator_body, + "value": value_body, + } + data = SingleFilterSchema().dump(singleFilterObj) + assert data == expected_json + + +@pytest.mark.parametrize( + "by,key,by_body", + [ + [SingleFilterBy.PROJECT_ID, None, "projectId"], + [SingleFilterBy.EXPERIMENT_ID, None, "experimentId"], + [SingleFilterBy.RUN_ID, None, "runId"], + [SingleFilterBy.DELETED_AT, None, "deletedAt"], + [SingleFilterBy.TAG, "tag-key", "tag"], + [SingleFilterBy.PARAM, "param-key", "param"], + [SingleFilterBy.METRIC, "metric-key", "metric"], + ], +) +def test_single_filter_schema_defined_operator(mocker, by, key, by_body): + singleFilterObj = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) + expected_json = { + "by": by_body, + "key": key, + "operator": "defined", + 
"value": True, + } + data = SingleFilterSchema().dump(singleFilterObj) + assert data == expected_json + + +@pytest.mark.parametrize( + "by,key,by_body", + [ + [SortBy.STARTED_AT, None, "startedAt"], + [SortBy.RUN_NUMBER, None, "runNumber"], + [SortBy.DURATION, None, "duration"], + [SortBy.TAG, "tag-key", "tag"], + [SortBy.PARAM, "param-key", "param"], + [SortBy.METRIC, "metric-key", "metric"], + ], +) +@pytest.mark.parametrize( + "order, order_body", [[SortOrder.ASC, "asc"], [SortOrder.DESC, "desc"]] +) +def test_sort_schema(mocker, by, key, by_body, order, order_body): + sortObj = Sort(by, key, order) + expected_json = {"by": by_body, "key": key, "order": order_body} + data = SortSchema().dump(sortObj) + assert data == expected_json + + def test_single_filter_value_field_validation(mocker): with pytest.raises(ValidationError): singleFilterObj = SingleFilter( From 2192deaa6146f5e3f104d144f3e351378c436872 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Fri, 7 Jun 2019 13:20:19 +0100 Subject: [PATCH 40/52] A bit of refactoring of the filter serialisation code * Use marshmallow's `field.fail()` mechanism * Map filter types to field types with dict * Factor out param filter type polymorphism --- faculty/clients/experiment.py | 74 ++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a1d6cae5..71d20c32 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -20,12 +20,6 @@ from faculty.clients.base import BaseClient, BaseSchema, Conflict -_ERROR_MESSAGES = { - "invalid_param_value": "Invalid param value", - "invalid_filter_type": "Invalid filter type", - "invalid_none_filter": "A none filter", -} - class ExperimentNameConflict(Exception): def __init__(self, name): @@ -356,37 +350,61 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) +class ParamValueField(fields.Field): + """Field that 
passes through strings or numbers.""" + + default_error_messages = { + "unsupported_type": "Param values must be of type str, int or float." + } + + def _determine_field(self, value): + if isinstance(value, str): + return fields.String() + elif isinstance(value, int) or isinstance(value, float): + return fields.Number() + else: + self.fail("unsupported_type") + + def _deserialize(self, value, attr, data, **kwargs): + field = self._determine_field(value) + return field._deserialize(value, attr, data, **kwargs) + + def _serialize(self, value, attr, obj, **kwargs): + field = self._determine_field(value) + return field._serialize(value, attr, obj, **kwargs) + + class SingleFilterValueField(fields.Field): """ Field that serialises/deserialises a run filter. """ + default_error_messages = { + "invalid_filter_operator": "Invalid filter operator." + } + + FILTER_BY_FIELD_MAPPING = { + SingleFilterBy.PROJECT_ID: fields.UUID, + SingleFilterBy.RUN_ID: fields.UUID, + SingleFilterBy.EXPERIMENT_ID: fields.Integer, + SingleFilterBy.DELETED_AT: fields.DateTime, + SingleFilterBy.TAG: fields.String, + SingleFilterBy.PARAM: ParamValueField, + SingleFilterBy.METRIC: fields.Number, + } + def _deserialize(self, value, attr, data, **kwargs): return value def _serialize(self, value, attr, obj, **kwargs): if obj.operator == SingleFilterOperator.DEFINED: - field = fields.Boolean() - elif obj.by in {SingleFilterBy.PROJECT_ID, SingleFilterBy.RUN_ID}: - field = fields.UUID() - elif obj.by == SingleFilterBy.EXPERIMENT_ID: - field = fields.Integer() - elif obj.by == SingleFilterBy.DELETED_AT: - field = fields.DateTime() - elif obj.by == SingleFilterBy.TAG: - field = fields.Str() - elif obj.by == SingleFilterBy.PARAM: - if isinstance(obj.value, int) or isinstance(obj.value, float): - field = fields.Number() - elif isinstance(obj.value, str): - field = fields.Str() - else: - raise ValidationError(_ERROR_MESSAGES["invalid_param_value"]) - elif obj.by == SingleFilterBy.METRIC: - field = 
fields.Float() + field_cls = fields.Boolean else: - raise ValidationError(_ERROR_MESSAGES["invalid_filter_type"]) - return field._serialize(value, attr, obj, **kwargs) + try: + field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] + except KeyError: + self.fail("invalid_filter_operator") + return field_cls()._serialize(value, attr, obj, **kwargs) class FilterField(fields.Field): @@ -394,6 +412,8 @@ class FilterField(fields.Field): Field that serialises/deserialises a run filter. """ + default_error_messages = {"invalid_none_filter": "A none filter"} + def _deserialize(self, value, attr, data, **kwargs): if value is None: return None @@ -406,7 +426,7 @@ def _serialize(self, value, attr, obj, **kwargs): if value is None and isinstance(obj, QueryRuns): return None elif value is None: - raise ValidationError(_ERROR_MESSAGES["invalid_none_filter"]) + self.fail("invalid_none_filter") if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) else: From b838919e4b02fe890f1a59ecc0587eaf0c725a87 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Fri, 7 Jun 2019 13:45:48 +0100 Subject: [PATCH 41/52] Use new OptionalField wrapper instead of coupling FilterField implementation to context in which it is used --- faculty/clients/experiment.py | 40 ++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 71d20c32..07c4839f 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -407,30 +407,50 @@ def _serialize(self, value, attr, obj, **kwargs): return field_cls()._serialize(value, attr, obj, **kwargs) +class OptionalField(fields.Field): + """Wrap another field, passing through Nones.""" + + def __init__(self, nested, *args, **kwargs): + self.nested = nested + super().__init__(*args, **kwargs) + + def _deserialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._deserialize(value, *args, 
**kwargs) + + def _serialize(self, value, *args, **kwargs): + if value is None: + return None + else: + return self.nested._serialize(value, *args, **kwargs) + + class FilterField(fields.Field): """ Field that serialises/deserialises a run filter. """ - default_error_messages = {"invalid_none_filter": "A none filter"} + default_error_messages = { + "invalid_filter_type": "Unsupported filter type." + } def _deserialize(self, value, attr, data, **kwargs): - if value is None: - return None - elif isinstance(value, SingleFilter): + # TODO: fix this - the isinstance check won't work as the filter hasn't + # yet been deserialised + if isinstance(value, SingleFilter): return SingleFilterSchema().load(value) else: return CompoundFilterSchema().load(value) def _serialize(self, value, attr, obj, **kwargs): - if value is None and isinstance(obj, QueryRuns): - return None - elif value is None: - self.fail("invalid_none_filter") if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) - else: + elif isinstance(value, CompoundFilter): return CompoundFilterSchema().dump(value) + else: + self.fail("invalid_filter_type") class SingleFilterSchema(BaseSchema): @@ -452,7 +472,7 @@ class SortSchema(BaseSchema): class QueryRunsSchema(BaseSchema): - filter = FilterField(required=True) + filter = OptionalField(FilterField()) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) From d547ba2df344a58404587ab7f3864642786e8cdf Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Fri, 7 Jun 2019 13:49:01 +0100 Subject: [PATCH 42/52] Remove unneeded _serialize methods These methods were unused, untested, and in some cases broken.
--- faculty/clients/experiment.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 07c4839f..a8f19f59 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -350,27 +350,20 @@ def make_restore_runs_response(self, data): return RestoreExperimentRunsResponse(**data) -class ParamValueField(fields.Field): +class ParamFilterValueField(fields.Field): """Field that passes through strings or numbers.""" default_error_messages = { "unsupported_type": "Param values must be of type str, int or float." } - def _determine_field(self, value): + def _serialize(self, value, attr, obj, **kwargs): if isinstance(value, str): - return fields.String() + field = fields.String() elif isinstance(value, int) or isinstance(value, float): - return fields.Number() + field = fields.Number() else: self.fail("unsupported_type") - - def _deserialize(self, value, attr, data, **kwargs): - field = self._determine_field(value) - return field._deserialize(value, attr, data, **kwargs) - - def _serialize(self, value, attr, obj, **kwargs): - field = self._determine_field(value) return field._serialize(value, attr, obj, **kwargs) @@ -389,13 +382,10 @@ class SingleFilterValueField(fields.Field): SingleFilterBy.EXPERIMENT_ID: fields.Integer, SingleFilterBy.DELETED_AT: fields.DateTime, SingleFilterBy.TAG: fields.String, - SingleFilterBy.PARAM: ParamValueField, + SingleFilterBy.PARAM: ParamFilterValueField, SingleFilterBy.METRIC: fields.Number, } - def _deserialize(self, value, attr, data, **kwargs): - return value - def _serialize(self, value, attr, obj, **kwargs): if obj.operator == SingleFilterOperator.DEFINED: field_cls = fields.Boolean @@ -436,14 +426,6 @@ class FilterField(fields.Field): "invalid_filter_type": "Unsupported filter type." 
} - def _deserialize(self, value, attr, data, **kwargs): - # TODO: fix this - the isinstance check won't work as the filter hasn't - # yet been deserialised - if isinstance(value, SingleFilter): - return SingleFilterSchema().load(value) - else: - return CompoundFilterSchema().load(value) - def _serialize(self, value, attr, obj, **kwargs): if isinstance(value, SingleFilter): return SingleFilterSchema().dump(value) From cdb81e2184ac074d0dde50f147a1b32b6d0089bb Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Fri, 7 Jun 2019 14:18:17 +0100 Subject: [PATCH 43/52] Remove parameter skip_if in CompoundFilterSchema --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index a8f19f59..e9605fef 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -444,7 +444,7 @@ class SingleFilterSchema(BaseSchema): class CompoundFilterSchema(BaseSchema): operator = EnumField(CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField(skip_if=None)) + conditions = fields.List(FilterField()) class SortSchema(BaseSchema): From f2cd19a864c61c6b6b0345f039d10c3219932068 Mon Sep 17 00:00:00 2001 From: Hailey Fong Date: Fri, 7 Jun 2019 14:40:00 +0100 Subject: [PATCH 44/52] Remove unused import --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index e9605fef..6acface8 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,7 +15,7 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load, ValidationError +from marshmallow import fields, post_load from marshmallow_enum import EnumField from faculty.clients.base import BaseClient, BaseSchema, Conflict From d0d4dc41f885e3e7335ea04d9895a102be7973aa Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: 
Fri, 7 Jun 2019 14:26:09 +0100 Subject: [PATCH 45/52] Do some test refactoring --- tests/clients/test_experiment.py | 243 ++++++++++++++++--------------- 1 file changed, 127 insertions(+), 116 deletions(-) diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index 15ed62ef..0c6beb1c 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -417,25 +417,25 @@ def test_experiment_run_data_schema_multiple(): @pytest.mark.parametrize( "filter, filter_body", [ - [None, None], - [PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY], - [AND_FILTER, AND_FILTER_BODY], - [OR_FILTER, OR_FILTER_BODY], + (None, None), + (PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY), + (AND_FILTER, AND_FILTER_BODY), + (OR_FILTER, OR_FILTER_BODY), ], ) @pytest.mark.parametrize( - "psort, psort_body", + "sort, sort_body", [ - [None, None], - [DURATION_SORT, DURATION_SORT_BODY], - [MULTI_SORT, MULTI_SORT_BODY], + (None, None), + (DURATION_SORT, DURATION_SORT_BODY), + (MULTI_SORT, MULTI_SORT_BODY), ], ) -def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): - queryRunsObj = QueryRuns(filter, psort, PAGE) +def test_query_runs_schema(mocker, filter, sort, filter_body, sort_body): + queryRunsObj = QueryRuns(filter, sort, PAGE) expected_json = { "filter": filter_body, - "sort": psort_body, + "sort": sort_body, "page": PAGE_BODY, } data = QueryRunsSchema().dump(queryRunsObj) @@ -443,160 +443,154 @@ def test_query_runs_schema(mocker, filter, psort, filter_body, psort_body): @pytest.mark.parametrize( - "by,key,value,by_body,value_body", + "by, key, value, by_body, value_body", [ - [ + ( SingleFilterBy.PROJECT_ID, None, PROJECT_ID, "projectId", str(PROJECT_ID), - ], - [ + ), + ( SingleFilterBy.EXPERIMENT_ID, None, EXPERIMENT_ID, "experimentId", EXPERIMENT_ID, - ], - [SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)], - [ + ), + (SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)), + ( SingleFilterBy.DELETED_AT, None, 
DELETED_AT, "deletedAt", DELETED_AT_STRING_PYTHON, - ], - [SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"], - [ + ), + (SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"), + ( SingleFilterBy.PARAM, "param-key", "param-text-value", "param", "param-text-value", - ], - [SingleFilterBy.PARAM, "param-key", 1, "param", 1], - [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ), + (SingleFilterBy.PARAM, "param-key", 1, "param", 1), + (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), ], ) @pytest.mark.parametrize( - "operator,operator_body", + "operator, operator_body", [ - [SingleFilterOperator.EQUAL_TO, "eq"], - [SingleFilterOperator.NOT_EQUAL_TO, "ne"], + (SingleFilterOperator.EQUAL_TO, "eq"), + (SingleFilterOperator.NOT_EQUAL_TO, "ne"), ], ) def test_single_filter_schema_equality_operators( - mocker, by, key, value, by_body, value_body, operator, operator_body + by, key, value, by_body, value_body, operator, operator_body ): - singleFilterObj = SingleFilter(by, key, operator, value) + filter = SingleFilter(by, key, operator, value) expected_json = { "by": by_body, "key": key, "operator": operator_body, "value": value_body, } - data = SingleFilterSchema().dump(singleFilterObj) + data = SingleFilterSchema().dump(filter) assert data == expected_json @pytest.mark.parametrize( - "by,key,value,by_body,value_body", + "by, key, value, by_body, value_body", [ - [ + ( SingleFilterBy.DELETED_AT, None, DELETED_AT, "deletedAt", DELETED_AT_STRING_PYTHON, - ], - [SingleFilterBy.PARAM, "param-key", 1, "param", 1], - [SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0], + ), + (SingleFilterBy.PARAM, "param-key", 1, "param", 1), + (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), ], ) @pytest.mark.parametrize( - "operator,operator_body", + "operator, operator_body", [ - [SingleFilterOperator.LESS_THAN, "lt"], - [SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"], - [SingleFilterOperator.GREATER_THAN, "gt"], - 
[SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"], + (SingleFilterOperator.LESS_THAN, "lt"), + (SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"), + (SingleFilterOperator.GREATER_THAN, "gt"), + (SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"), ], ) def test_single_filter_schema_relational_operators( - mocker, by, key, value, by_body, value_body, operator, operator_body + by, key, value, by_body, value_body, operator, operator_body ): - singleFilterObj = SingleFilter(by, key, operator, value) + filter = SingleFilter(by, key, operator, value) expected_json = { "by": by_body, "key": key, "operator": operator_body, "value": value_body, } - data = SingleFilterSchema().dump(singleFilterObj) + data = SingleFilterSchema().dump(filter) assert data == expected_json @pytest.mark.parametrize( - "by,key,by_body", + "by, key, by_body", [ - [SingleFilterBy.PROJECT_ID, None, "projectId"], - [SingleFilterBy.EXPERIMENT_ID, None, "experimentId"], - [SingleFilterBy.RUN_ID, None, "runId"], - [SingleFilterBy.DELETED_AT, None, "deletedAt"], - [SingleFilterBy.TAG, "tag-key", "tag"], - [SingleFilterBy.PARAM, "param-key", "param"], - [SingleFilterBy.METRIC, "metric-key", "metric"], + (SingleFilterBy.PROJECT_ID, None, "projectId"), + (SingleFilterBy.EXPERIMENT_ID, None, "experimentId"), + (SingleFilterBy.RUN_ID, None, "runId"), + (SingleFilterBy.DELETED_AT, None, "deletedAt"), + (SingleFilterBy.TAG, "tag-key", "tag"), + (SingleFilterBy.PARAM, "param-key", "param"), + (SingleFilterBy.METRIC, "metric-key", "metric"), ], ) -def test_single_filter_schema_defined_operator(mocker, by, key, by_body): - singleFilterObj = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) +def test_single_filter_schema_defined_operator(by, key, by_body): + filter = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) expected_json = { "by": by_body, "key": key, "operator": "defined", "value": True, } - data = SingleFilterSchema().dump(singleFilterObj) + data = SingleFilterSchema().dump(filter) 
assert data == expected_json @pytest.mark.parametrize( - "by,key,by_body", + "by, value, message", [ - [SortBy.STARTED_AT, None, "startedAt"], - [SortBy.RUN_NUMBER, None, "runNumber"], - [SortBy.DURATION, None, "duration"], - [SortBy.TAG, "tag-key", "tag"], - [SortBy.PARAM, "param-key", "param"], - [SortBy.METRIC, "metric-key", "metric"], + (SingleFilterBy.PROJECT_ID, "invalid", "Not a valid UUID."), + (SingleFilterBy.EXPERIMENT_ID, "string", "Not a valid integer."), + (SingleFilterBy.RUN_ID, "invalid", "Not a valid UUID."), + ( + SingleFilterBy.DELETED_AT, + "invalid", + "cannot be formatted as a datetime", + ), + (SingleFilterBy.METRIC, "invalid", "Not a valid number."), + (SingleFilterBy.PARAM, None, "must be of type str, int or float"), ], ) -@pytest.mark.parametrize( - "order, order_body", [[SortOrder.ASC, "asc"], [SortOrder.DESC, "desc"]] -) -def test_sort_schema(mocker, by, key, by_body, order, order_body): - sortObj = Sort(by, key, order) - expected_json = {"by": by_body, "key": key, "order": order_body} - data = SortSchema().dump(sortObj) - assert data == expected_json - - -def test_single_filter_value_field_validation(mocker): - with pytest.raises(ValidationError): - singleFilterObj = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.GREATER_THAN, - True, - ) - SingleFilterSchema().dump(singleFilterObj) +def test_single_filter_invalid_value(by, value, message): + filter = SingleFilter( + by, + "key" if by.needs_key() else None, + SingleFilterOperator.EQUAL_TO, + value, + ) + with pytest.raises(ValidationError, match=message): + SingleFilterSchema().dump(filter) @pytest.mark.parametrize( - "by,key,op,value,err_msg", + "by, key, operator, value, message", [ - [ + ( SingleFilterBy.PROJECT_ID, "invalid_key", SingleFilterOperator.EQUAL_TO, @@ -604,8 +598,8 @@ def test_single_filter_value_field_validation(mocker): "key must be none for filter type {}".format( SingleFilterBy.PROJECT_ID ), - ], - [ + ), + ( SingleFilterBy.TAG, None, 
SingleFilterOperator.EQUAL_TO, @@ -613,8 +607,8 @@ def test_single_filter_value_field_validation(mocker): "key must not be none for filter type {}".format( SingleFilterBy.TAG ), - ], - [ + ), + ( SingleFilterBy.PARAM, "param_key", SingleFilterOperator.GREATER_THAN, @@ -622,43 +616,60 @@ def test_single_filter_value_field_validation(mocker): "invalid type {}. Value has to be either an int or a float".format( type("param_value") ), - ], + ), ], ) -def test_single_filter_validation(mocker, by, key, op, value, err_msg): - with pytest.raises(ValueError, match=err_msg): - SingleFilter(by, key, op, value) - - -def test_compound_filter_validation(mocker): +def test_single_filter_validation(by, key, operator, value, message): + with pytest.raises(ValueError, match=message): + SingleFilter(by, key, operator, value) + + +def test_compound_filter_validation(): + filter = CompoundFilter( + operator=CompoundFilterOperator.OR, + conditions=[ + SingleFilter( + SingleFilterBy.TAG, + "tag_key", + SingleFilterOperator.EQUAL_TO, + "tag_value", + ), + None, + ], + ) + run_query = QueryRuns(filter, None, None) with pytest.raises(ValidationError): - filter = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", - ), - None, - ], - ) + QueryRunsSchema().dump(run_query) + - queryRunsObj = QueryRuns(filter, None, None) - QueryRunsSchema().dump(queryRunsObj) +@pytest.mark.parametrize( + "by, key, by_body", + [ + (SortBy.STARTED_AT, None, "startedAt"), + (SortBy.RUN_NUMBER, None, "runNumber"), + (SortBy.DURATION, None, "duration"), + (SortBy.TAG, "tag-key", "tag"), + (SortBy.PARAM, "param-key", "param"), + (SortBy.METRIC, "metric-key", "metric"), + ], +) +@pytest.mark.parametrize( + "order, order_body", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema(by, key, by_body, order, order_body): + sort = Sort(by, key, order) + expected_json = {"by": by_body, 
"key": key, "order": order_body} + data = SortSchema().dump(sort) + assert data == expected_json -def test_sort_validation(mocker): - with pytest.raises( - ValueError, - match="key must be none for sort type {}".format(SortBy.RUN_NUMBER), - ): +def test_sort_validate_no_key(): + with pytest.raises(ValueError, match="key must be none"): Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) - with pytest.raises( - ValueError, - match="key must not be none for sort type {}".format(SortBy.TAG), - ): + + +def test_sort_validate_has_key(): + with pytest.raises(ValueError, match="key must not be none"): Sort(SortBy.TAG, None, SortOrder.ASC) From 0e7277e93d53fb36c80894d76c03dee65fe99422 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 12:06:01 +0100 Subject: [PATCH 46/52] Minor cleanups --- faculty/clients/experiment.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 80859b2f..1e177e40 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -117,13 +117,12 @@ def __new__(cls, by, key, operator, value): raise ValueError("key must be none for filter type {}".format(by)) elif ( by == SingleFilterBy.PARAM - and operator.is_numeric_operator() + and operator.is_numeric() and not (isinstance(value, float) or isinstance(value, int)) ): raise ValueError( ( - "invalid type {}. Value has to be either an int " - + "or a float" + "invalid type {}. 
Value has to be either an int or a float" ).format(type(value)) ) return super(SingleFilter, cls).__new__(cls, by, key, operator, value) @@ -141,7 +140,7 @@ class SingleFilterOperator(Enum): GREATER_THAN = "gt" GREATER_THAN_OR_EQUAL_TO = "ge" - def is_numeric_operator(self): + def is_numeric(self): return self in { SingleFilterOperator.LESS_THAN, SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, From 2ab75fbe242ec5ca7004efe7d1190f2c8c72f415 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 13:04:52 +0100 Subject: [PATCH 47/52] WIP: Different filter implementations --- faculty/clients/experiment.py | 229 ++++++++++++++++++---------------- setup.py | 1 + 2 files changed, 125 insertions(+), 105 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 1e177e40..84850361 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,8 +15,9 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load +from marshmallow import fields, post_load, post_dump from marshmallow_enum import EnumField +from marshmallow_oneofschema import OneOfSchema from faculty.clients.base import BaseClient, BaseSchema, Conflict @@ -100,38 +101,14 @@ class ExperimentRunStatus(Enum): "RestoreExperimentRunsResponse", ["restored_run_ids", "conflicted_run_ids"] ) -_SingleFilter = namedtuple("_SingleFilter", ["by", "key", "operator", "value"]) - -class SingleFilter(_SingleFilter): - def __new__(cls, by, key, operator, value): - if isinstance(by, SingleFilterBy) and by.needs_key() and key is None: - raise ValueError( - "key must not be none for filter type {}".format(by) - ) - elif ( - isinstance(by, SingleFilterBy) - and not by.needs_key() - and key is not None - ): - raise ValueError("key must be none for filter type {}".format(by)) - elif ( - by == SingleFilterBy.PARAM - and operator.is_numeric() - and not (isinstance(value, float) or isinstance(value, int)) - ): - raise ValueError( - 
( - "invalid type {}. Value has to be either an int or a float" - ).format(type(value)) - ) - return super(SingleFilter, cls).__new__(cls, by, key, operator, value) - - -CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) +class DiscreteOperator(Enum): + DEFINED = "defined" + EQUAL_TO = "eq" + NOT_EQUAL_TO = "ne" -class SingleFilterOperator(Enum): +class ContinuousOperator(Enum): DEFINED = "defined" EQUAL_TO = "eq" NOT_EQUAL_TO = "ne" @@ -140,35 +117,21 @@ class SingleFilterOperator(Enum): GREATER_THAN = "gt" GREATER_THAN_OR_EQUAL_TO = "ge" - def is_numeric(self): - return self in { - SingleFilterOperator.LESS_THAN, - SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, - SingleFilterOperator.GREATER_THAN, - SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, - } +class LogicalOperator(Enum): + AND = "and" + OR = "or" -class SingleFilterBy(Enum): - PROJECT_ID = "projectId" - EXPERIMENT_ID = "experimentId" - RUN_ID = "runId" - DELETED_AT = "deletedAt" - TAG = "tag" - PARAM = "param" - METRIC = "metric" - - def needs_key(self): - return self in { - SingleFilterBy.TAG, - SingleFilterBy.PARAM, - SingleFilterBy.METRIC, - } +ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) +ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) +RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) +DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", "value"]) +TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) +ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) +MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) -class CompoundFilterOperator(Enum): - AND = "and" - OR = "or" +CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) _Sort = namedtuple("_Sort", ["by", "key", "order"]) @@ -375,34 +338,34 @@ def _serialize(self, value, attr, obj, **kwargs): return field._serialize(value, attr, obj, **kwargs) -class SingleFilterValueField(fields.Field): 
- """ - Field that serialises/deserialises a run filter. - """ - - default_error_messages = { - "invalid_filter_operator": "Invalid filter operator." - } - - FILTER_BY_FIELD_MAPPING = { - SingleFilterBy.PROJECT_ID: fields.UUID, - SingleFilterBy.RUN_ID: fields.UUID, - SingleFilterBy.EXPERIMENT_ID: fields.Integer, - SingleFilterBy.DELETED_AT: fields.DateTime, - SingleFilterBy.TAG: fields.String, - SingleFilterBy.PARAM: ParamFilterValueField, - SingleFilterBy.METRIC: fields.Number, - } - - def _serialize(self, value, attr, obj, **kwargs): - if obj.operator == SingleFilterOperator.DEFINED: - field_cls = fields.Boolean - else: - try: - field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] - except KeyError: - self.fail("invalid_filter_operator") - return field_cls()._serialize(value, attr, obj, **kwargs) +# class SingleFilterValueField(fields.Field): +# """ +# Field that serialises/deserialises a run filter. +# """ + +# default_error_messages = { +# "invalid_filter_operator": "Invalid filter operator." +# } + +# FILTER_BY_FIELD_MAPPING = { +# SingleFilterBy.PROJECT_ID: fields.UUID, +# SingleFilterBy.RUN_ID: fields.UUID, +# SingleFilterBy.EXPERIMENT_ID: fields.Integer, +# SingleFilterBy.DELETED_AT: fields.DateTime, +# SingleFilterBy.TAG: fields.String, +# SingleFilterBy.PARAM: ParamFilterValueField, +# SingleFilterBy.METRIC: fields.Number, +# } + +# def _serialize(self, value, attr, obj, **kwargs): +# if obj.operator == SingleFilterOperator.DEFINED: +# field_cls = fields.Boolean +# else: +# try: +# field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] +# except KeyError: +# self.fail("invalid_filter_operator") +# return field_cls()._serialize(value, attr, obj, **kwargs) class OptionalField(fields.Field): @@ -425,34 +388,90 @@ def _serialize(self, value, *args, **kwargs): return self.nested._serialize(value, *args, **kwargs) -class FilterField(fields.Field): - """ - Field that serialises/deserialises a run filter. 
- """ +# class FilterField(fields.Field): +# """ +# Field that serialises/deserialises a run filter. +# """ - default_error_messages = { - "invalid_filter_type": "Unsupported filter type." - } +# default_error_messages = { +# "invalid_filter_type": "Unsupported filter type." +# } - def _serialize(self, value, attr, obj, **kwargs): - if isinstance(value, SingleFilter): - return SingleFilterSchema().dump(value) - elif isinstance(value, CompoundFilter): - return CompoundFilterSchema().dump(value) - else: - self.fail("invalid_filter_type") +# def _serialize(self, value, attr, obj, **kwargs): +# if isinstance(value, SingleFilter): +# return SingleFilterSchema().dump(value) +# elif isinstance(value, CompoundFilter): +# return CompoundFilterSchema().dump(value) +# else: +# self.fail("invalid_filter_type") -class SingleFilterSchema(BaseSchema): - by = EnumField(SingleFilterBy, by_value=True, required=True) - key = fields.String() - operator = EnumField(SingleFilterOperator, by_value=True, required=True) - value = SingleFilterValueField(required=True) +class ProjectIdFilterSchema(BaseSchema): + operator = EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.UUID(required=True) + by = fields.Constant("projectId", dump_only=True) + + +class ExperimentIdFilterSchema(BaseSchema): + operator = EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.Integer(required=True) + by = fields.Constant("experimentId", dump_only=True) + + +class RunIdFilterSchema(BaseSchema): + operator = EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.UUID(required=True) + by = fields.Constant("runId", dump_only=True) + + +class DeletedAtFilterSchema(BaseSchema): + operator = EnumField(ContinuousOperator, by_value=True, required=True) + value = fields.DateTime(required=True) + by = fields.Constant("deletedAt", dump_only=True) + + +class TagFilterSchema(BaseSchema): + key = fields.String(required=True) + operator = 
EnumField(DiscreteOperator, by_value=True, required=True) + value = fields.String(required=True) + by = fields.Constant("tag", dump_only=True) + + +class ParamFilterSchema(BaseSchema): + key = fields.String(required=True) + operator = EnumField(ContinuousOperator, by_value=True, required=True) + value = ParamFilterValueField(required=True) + by = fields.Constant("param", dump_only=True) + + +class MetricFilterSchema(BaseSchema): + key = fields.String(required=True) + operator = EnumField(ContinuousOperator, by_value=True, required=True) + value = fields.Float(required=True) + by = fields.Constant("metric", dump_only=True) class CompoundFilterSchema(BaseSchema): - operator = EnumField(CompoundFilterOperator, by_value=True, required=True) - conditions = fields.List(FilterField()) + operator = EnumField(LogicalOperator, by_value=True, required=True) + conditions = fields.List(fields.Nested("FilterSchema")) + + +class FilterSchema(OneOfSchema): + type_schemas = { + "ProjectIdFilter": ProjectIdFilterSchema, + "ExperimentIdFilter": ExperimentIdFilterSchema, + "RunIdFilter": RunIdFilterSchema, + "DeletedAtFilter": DeletedAtFilterSchema, + "TagFilter": TagFilterSchema, + "ParamFilter": ParamFilterSchema, + "MetricFilter": MetricFilterSchema, + "CompoundFilter": CompoundFilterSchema, + } + + def dump(self, *args, **kwargs): + data = super(FilterSchema, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} class SortSchema(BaseSchema): @@ -462,7 +481,7 @@ class SortSchema(BaseSchema): class QueryRunsSchema(BaseSchema): - filter = OptionalField(FilterField()) + filter = OptionalField(fields.Nested(FilterSchema)) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) diff --git a/setup.py b/setup.py index a8772681..1e427358 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ # compatible version of python-dateutil is available 
"marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", + "marshmallow-oneofschema", "boto3", "botocore", ], From 51b448cb03db127a20627dfcadc3bb6c13e1e71c Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Tue, 11 Jun 2019 16:53:08 +0100 Subject: [PATCH 48/52] Complete filter and sort object refactor and tidy tests --- faculty/clients/experiment.py | 271 ++++++----- tests/clients/test_experiment.py | 752 ++++++++++++++----------------- 2 files changed, 472 insertions(+), 551 deletions(-) diff --git a/faculty/clients/experiment.py b/faculty/clients/experiment.py index 84850361..cd243803 100644 --- a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -15,7 +15,7 @@ from collections import namedtuple from enum import Enum -from marshmallow import fields, post_load, post_dump +from marshmallow import fields, post_load, pre_dump, ValidationError from marshmallow_enum import EnumField from marshmallow_oneofschema import OneOfSchema @@ -102,13 +102,7 @@ class ExperimentRunStatus(Enum): ) -class DiscreteOperator(Enum): - DEFINED = "defined" - EQUAL_TO = "eq" - NOT_EQUAL_TO = "ne" - - -class ContinuousOperator(Enum): +class ComparisonOperator(Enum): DEFINED = "defined" EQUAL_TO = "eq" NOT_EQUAL_TO = "ne" @@ -133,31 +127,12 @@ class LogicalOperator(Enum): CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) - -_Sort = namedtuple("_Sort", ["by", "key", "order"]) - - -class Sort(_Sort): - def __new__(cls, by, key, order): - if by.needs_key() and key is None: - raise ValueError( - "key must not be none for sort type {}".format(by) - ) - elif not by.needs_key() and key is not None: - raise ValueError("key must be none for sort type {}".format(by)) - return super(Sort, cls).__new__(cls, by, key, order) - - -class SortBy(Enum): - STARTED_AT = "startedAt" - RUN_NUMBER = "runNumber" - DURATION = "duration" - TAG = "tag" - PARAM = "param" - METRIC = "metric" - - def needs_key(self): - return self in {SortBy.TAG, SortBy.PARAM, SortBy.METRIC} 
+StartedAtSort = namedtuple("StartedAtSort", ["order"]) +RunNumberSort = namedtuple("RunNumberSort", ["order"]) +DurationSort = namedtuple("DurationSort", ["order"]) +TagSort = namedtuple("TagSort", ["key", "order"]) +ParamSort = namedtuple("ParamSort", ["key", "order"]) +MetricSort = namedtuple("MetricSort", ["key", "order"]) class SortOrder(Enum): @@ -165,7 +140,7 @@ class SortOrder(Enum): DESC = "desc" -QueryRuns = namedtuple("QueryRuns", ["filter", "sort", "page"]) +RunQuery = namedtuple("RunQuery", ["filter", "sort", "page"]) class PageSchema(BaseSchema): @@ -338,36 +313,6 @@ def _serialize(self, value, attr, obj, **kwargs): return field._serialize(value, attr, obj, **kwargs) -# class SingleFilterValueField(fields.Field): -# """ -# Field that serialises/deserialises a run filter. -# """ - -# default_error_messages = { -# "invalid_filter_operator": "Invalid filter operator." -# } - -# FILTER_BY_FIELD_MAPPING = { -# SingleFilterBy.PROJECT_ID: fields.UUID, -# SingleFilterBy.RUN_ID: fields.UUID, -# SingleFilterBy.EXPERIMENT_ID: fields.Integer, -# SingleFilterBy.DELETED_AT: fields.DateTime, -# SingleFilterBy.TAG: fields.String, -# SingleFilterBy.PARAM: ParamFilterValueField, -# SingleFilterBy.METRIC: fields.Number, -# } - -# def _serialize(self, value, attr, obj, **kwargs): -# if obj.operator == SingleFilterOperator.DEFINED: -# field_cls = fields.Boolean -# else: -# try: -# field_cls = self.FILTER_BY_FIELD_MAPPING[obj.by] -# except KeyError: -# self.fail("invalid_filter_operator") -# return field_cls()._serialize(value, attr, obj, **kwargs) - - class OptionalField(fields.Field): """Wrap another field, passing through Nones.""" @@ -388,75 +333,112 @@ def _serialize(self, value, *args, **kwargs): return self.nested._serialize(value, *args, **kwargs) -# class FilterField(fields.Field): -# """ -# Field that serialises/deserialises a run filter. 
-# """ +class FilterValueField(fields.Field): + def __init__(self, other_field_type, *args, **kwargs): + self.other_field_type = other_field_type + super(FilterValueField, self).__init__(*args, **kwargs) -# default_error_messages = { -# "invalid_filter_type": "Unsupported filter type." -# } + def _serialize(self, value, attr, obj, **kwargs): + if obj.operator == ComparisonOperator.DEFINED: + field_cls = fields.Boolean + else: + field_cls = self.other_field_type + return field_cls()._serialize(value, attr, obj, **kwargs) -# def _serialize(self, value, attr, obj, **kwargs): -# if isinstance(value, SingleFilter): -# return SingleFilterSchema().dump(value) -# elif isinstance(value, CompoundFilter): -# return CompoundFilterSchema().dump(value) -# else: -# self.fail("invalid_filter_type") + +def _validate_discrete(operator): + if operator not in { + ComparisonOperator.DEFINED, + ComparisonOperator.EQUAL_TO, + ComparisonOperator.NOT_EQUAL_TO, + }: + raise ValidationError({"operator": "Not a discrete operator."}) class ProjectIdFilterSchema(BaseSchema): - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.UUID(required=True) + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) by = fields.Constant("projectId", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class ExperimentIdFilterSchema(BaseSchema): - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.Integer(required=True) + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Integer) by = fields.Constant("experimentId", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class RunIdFilterSchema(BaseSchema): - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.UUID(required=True) + operator = 
EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.UUID) by = fields.Constant("runId", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class DeletedAtFilterSchema(BaseSchema): - operator = EnumField(ContinuousOperator, by_value=True, required=True) - value = fields.DateTime(required=True) + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.DateTime) by = fields.Constant("deletedAt", dump_only=True) class TagFilterSchema(BaseSchema): - key = fields.String(required=True) - operator = EnumField(DiscreteOperator, by_value=True, required=True) - value = fields.String(required=True) + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.String) by = fields.Constant("tag", dump_only=True) + @pre_dump + def check_operator(self, obj): + _validate_discrete(obj.operator) + return obj + class ParamFilterSchema(BaseSchema): - key = fields.String(required=True) - operator = EnumField(ContinuousOperator, by_value=True, required=True) - value = ParamFilterValueField(required=True) + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(ParamFilterValueField) by = fields.Constant("param", dump_only=True) + @pre_dump + def check_operator(self, obj): + if isinstance(obj.value, str): + _validate_discrete(obj.operator) + return obj + class MetricFilterSchema(BaseSchema): - key = fields.String(required=True) - operator = EnumField(ContinuousOperator, by_value=True, required=True) - value = fields.Float(required=True) + key = fields.String() + operator = EnumField(ComparisonOperator, by_value=True) + value = FilterValueField(fields.Float) by = fields.Constant("metric", dump_only=True) class CompoundFilterSchema(BaseSchema): - operator = EnumField(LogicalOperator, by_value=True, required=True) + operator = EnumField(LogicalOperator, by_value=True) 
conditions = fields.List(fields.Nested("FilterSchema")) -class FilterSchema(OneOfSchema): +class OneOfSchemaWithoutType(OneOfSchema): + def dump(self, *args, **kwargs): + data = super(OneOfSchemaWithoutType, self).dump(*args, **kwargs) + # Remove the type field added by marshmallow-oneofschema + return {k: v for k, v in data.items() if k != "type"} + + +class FilterSchema(OneOfSchemaWithoutType): type_schemas = { "ProjectIdFilter": ProjectIdFilterSchema, "ExperimentIdFilter": ExperimentIdFilterSchema, @@ -468,19 +450,52 @@ class FilterSchema(OneOfSchema): "CompoundFilter": CompoundFilterSchema, } - def dump(self, *args, **kwargs): - data = super(FilterSchema, self).dump(*args, **kwargs) - # Remove the type field added by marshmallow-oneofschema - return {k: v for k, v in data.items() if k != "type"} + +class StartedAtSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("startedAt", dump_only=True) + + +class RunNumberSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("runNumber", dump_only=True) + + +class DurationSortSchema(BaseSchema): + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("duration", dump_only=True) + + +class TagSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("tag", dump_only=True) + + +class ParamSortSchema(BaseSchema): + key = fields.String() + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("param", dump_only=True) -class SortSchema(BaseSchema): - by = EnumField(SortBy, by_value=True, required=True) +class MetricSortSchema(BaseSchema): key = fields.String() - order = EnumField(SortOrder, by_value=True, required=True) + order = EnumField(SortOrder, by_value=True) + by = fields.Constant("metric", dump_only=True) -class QueryRunsSchema(BaseSchema): +class SortSchema(OneOfSchemaWithoutType): + type_schemas = { + "StartedAtSort": StartedAtSortSchema, + 
"RunNumberSort": RunNumberSortSchema, + "DurationSort": DurationSortSchema, + "TagSort": TagSortSchema, + "ParamSort": ParamSortSchema, + "MetricSort": MetricSortSchema, + } + + +class RunQuerySchema(BaseSchema): filter = OptionalField(fields.Nested(FilterSchema)) sort = fields.List(fields.Nested(SortSchema)) page = fields.Nested(PageSchema, missing=None) @@ -785,29 +800,21 @@ def list_runs( ), ) experiment_id_filters = [ - SingleFilter( - SingleFilterBy.EXPERIMENT_ID, - None, - SingleFilterOperator.EQUAL_TO, - experiment_id, - ) + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, experiment_id) for experiment_id in experiment_ids ] experiment_ids_filter = CompoundFilter( - CompoundFilterOperator.OR, experiment_id_filters + LogicalOperator.OR, experiment_id_filters ) if lifecycle_stage is not None: - lifecycle_filter = SingleFilter( - SingleFilterBy.DELETED_AT, - None, - SingleFilterOperator.DEFINED, + lifecycle_filter = DeletedAtFilter( + ComparisonOperator.DEFINED, lifecycle_stage == LifecycleStage.DELETED, ) if experiment_ids_filter is not None and lifecycle_filter is not None: filter = CompoundFilter( - CompoundFilterOperator.AND, - [experiment_ids_filter, lifecycle_filter], + LogicalOperator.AND, [experiment_ids_filter, lifecycle_filter] ) elif experiment_ids_filter is not None: filter = experiment_ids_filter @@ -856,7 +863,7 @@ def query_runs( page = None if start is not None and limit is not None: page = Page(start, limit) - payload = QueryRunsSchema().dump(QueryRuns(filter, sort, page)) + payload = RunQuerySchema().dump(RunQuery(filter, sort, page)) return self._post( endpoint, ListExperimentRunsResponseSchema(), json=payload ) @@ -975,18 +982,11 @@ def delete_runs(self, project_id, run_ids=None): ) else: run_id_filters = [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_id, - ) + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) for run_id in run_ids ] - run_ids_filter = CompoundFilter( - CompoundFilterOperator.OR, 
run_id_filters - ) - payload = {"filter": run_ids_filter} + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, DeleteExperimentRunsResponseSchema(), json=payload @@ -1020,18 +1020,11 @@ def restore_runs(self, project_id, run_ids=None): ) else: run_id_filters = [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_id, - ) + RunIdFilter(ComparisonOperator.EQUAL_TO, run_id) for run_id in run_ids ] - run_ids_filter = CompoundFilter( - CompoundFilterOperator.OR, run_id_filters - ) - payload = {"filter": run_ids_filter} + filter = CompoundFilter(LogicalOperator.OR, run_id_filters) + payload = {"filter": FilterSchema().dump(filter)} return self._post( endpoint, RestoreExperimentRunsResponseSchema(), json=payload diff --git a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index d37578f4..d6db90b0 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -22,49 +22,58 @@ from faculty.clients.base import Conflict from faculty.clients.experiment import ( - CreateRunSchema, + ComparisonOperator, CompoundFilter, - CompoundFilterOperator, + CreateRunSchema, DeleteExperimentRunsResponse, DeleteExperimentRunsResponseSchema, + DeletedAtFilter, + DurationSort, Experiment, ExperimentClient, - ExperimentNameConflict, ExperimentDeleted, + ExperimentIdFilter, + ExperimentNameConflict, ExperimentRun, ExperimentRunDataSchema, ExperimentRunSchema, ExperimentRunStatus, ExperimentSchema, + FilterSchema, LifecycleStage, ListExperimentRunsResponse, ListExperimentRunsResponseSchema, + LogicalOperator, Metric, MetricDataPoint, - MetricSchema, + MetricFilter, MetricHistory, MetricHistorySchema, + MetricSchema, + MetricSort, Page, PageSchema, Pagination, PaginationSchema, Param, ParamConflict, + ParamFilter, ParamSchema, - QueryRuns, - QueryRunsSchema, + ParamSort, + ProjectIdFilter, RestoreExperimentRunsResponse, 
RestoreExperimentRunsResponseSchema, - SingleFilter, - SingleFilterBy, - SingleFilterOperator, - SingleFilterSchema, - Sort, - SortBy, + RunIdFilter, + RunNumberSort, + RunQuery, + RunQuerySchema, SortOrder, SortSchema, + StartedAtSort, Tag, + TagFilter, TagSchema, + TagSort, ) PROJECT_ID = uuid4() @@ -348,334 +357,315 @@ def test_experiment_run_data_schema_multiple(): assert data == {"tags": [TAG_BODY, OTHER_TAG_BODY]} -PROJECT_ID_FILTER = SingleFilter( - SingleFilterBy.PROJECT_ID, None, SingleFilterOperator.EQUAL_TO, PROJECT_ID -) +PROJECT_ID_FILTER = ProjectIdFilter(ComparisonOperator.EQUAL_TO, PROJECT_ID) PROJECT_ID_FILTER_BODY = { "by": "projectId", - "key": None, "operator": "eq", "value": str(PROJECT_ID), } -TAG_FILTER = SingleFilter( - SingleFilterBy.TAG, "tag_key", SingleFilterOperator.EQUAL_TO, "tag_value" -) +TAG_FILTER = TagFilter("tag-key", ComparisonOperator.EQUAL_TO, "tag-value") TAG_FILTER_BODY = { "by": "tag", - "key": "tag_key", + "key": "tag-key", "operator": "eq", - "value": "tag_value", + "value": "tag-value", } -PARAM_TEXT_FILTER = SingleFilter( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.EQUAL_TO, - "param_value", -) -PARAM_TEXT_FILTER_BODY = { - "by": "param", - "key": "param_key", - "operator": "eq", - "value": "param_value", -} -AND_FILTER = CompoundFilter( - operator=CompoundFilterOperator.AND, conditions=[TAG_FILTER] -) -AND_FILTER_BODY = {"operator": "and", "conditions": [TAG_FILTER_BODY]} - -OR_FILTER = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - TAG_FILTER, - CompoundFilter( - operator=CompoundFilterOperator.AND, conditions=[PARAM_TEXT_FILTER] - ), - ], -) -OR_FILTER_BODY = { - "operator": "or", - "conditions": [ - TAG_FILTER_BODY, - {"operator": "and", "conditions": [PARAM_TEXT_FILTER_BODY]}, - ], -} +DEFINED_TEST_CASES = [ + (ComparisonOperator.DEFINED, False, "defined", False), + (ComparisonOperator.DEFINED, True, "defined", True), + (ComparisonOperator.DEFINED, 0, "defined", False), + 
(ComparisonOperator.DEFINED, 1, "defined", True), +] -RUN_NUMBER_SORT = [Sort(SortBy.RUN_NUMBER, None, SortOrder.ASC)] -RUN_NUMBER_SORT_BODY = [{"by": "runNumber", "key": None, "order": "asc"}] -DURATION_SORT = [Sort(SortBy.DURATION, None, SortOrder.DESC)] -DURATION_SORT_BODY = [{"by": "duration", "key": None, "order": "desc"}] +def discrete_test_cases(value, expected): + return DEFINED_TEST_CASES + [ + (ComparisonOperator.EQUAL_TO, value, "eq", expected), + (ComparisonOperator.NOT_EQUAL_TO, value, "ne", expected), + ] -MULTI_SORT = [ - Sort(SortBy.PARAM, "param_key", SortOrder.ASC), - Sort(SortBy.RUN_NUMBER, None, SortOrder.DESC), -] -MULTI_SORT_BODY = [ - {"by": "param", "key": "param_key", "order": "asc"}, - {"by": "runNumber", "key": None, "order": "desc"}, -] + +def continuous_test_cases(value, expected): + return discrete_test_cases(value, expected) + [ + (ComparisonOperator.GREATER_THAN, value, "gt", expected), + (ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value, "ge", expected), + (ComparisonOperator.LESS_THAN, value, "lt", expected), + (ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value, "le", expected), + ] @pytest.mark.parametrize( - "filter, filter_body", - [ - (None, None), - (PROJECT_ID_FILTER, PROJECT_ID_FILTER_BODY), - (AND_FILTER, AND_FILTER_BODY), - (OR_FILTER, OR_FILTER_BODY), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases(PROJECT_ID, str(PROJECT_ID)), ) +def test_filter_schema_project_id( + operator, value, expected_operator, expected_value +): + filter = ProjectIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "projectId", + "operator": expected_operator, + "value": expected_value, + } + + @pytest.mark.parametrize( - "sort, sort_body", - [ - (None, None), - (DURATION_SORT, DURATION_SORT_BODY), - (MULTI_SORT, MULTI_SORT_BODY), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases(EXPERIMENT_ID, EXPERIMENT_ID), ) -def 
test_query_runs_schema(mocker, filter, sort, filter_body, sort_body): - queryRunsObj = QueryRuns(filter, sort, PAGE) - expected_json = { - "filter": filter_body, - "sort": sort_body, - "page": PAGE_BODY, +def test_filter_schema_experiment_id( + operator, value, expected_operator, expected_value +): + filter = ExperimentIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "experimentId", + "operator": expected_operator, + "value": expected_value, } - data = QueryRunsSchema().dump(queryRunsObj) - assert data == expected_json @pytest.mark.parametrize( - "by, key, value, by_body, value_body", - [ - ( - SingleFilterBy.PROJECT_ID, - None, - PROJECT_ID, - "projectId", - str(PROJECT_ID), - ), - ( - SingleFilterBy.EXPERIMENT_ID, - None, - EXPERIMENT_ID, - "experimentId", - EXPERIMENT_ID, - ), - (SingleFilterBy.RUN_ID, None, RUN_ID, "runId", str(RUN_ID)), - ( - SingleFilterBy.DELETED_AT, - None, - DELETED_AT, - "deletedAt", - DELETED_AT_STRING_PYTHON, - ), - (SingleFilterBy.TAG, "tag-key", "tag-value", "tag", "tag-value"), - ( - SingleFilterBy.PARAM, - "param-key", - "param-text-value", - "param", - "param-text-value", - ), - (SingleFilterBy.PARAM, "param-key", 1, "param", 1), - (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases(RUN_ID, str(RUN_ID)), ) +def test_filter_schema_run_id( + operator, value, expected_operator, expected_value +): + filter = RunIdFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "runId", + "operator": expected_operator, + "value": expected_value, + } + + @pytest.mark.parametrize( - "operator, operator_body", - [ - (SingleFilterOperator.EQUAL_TO, "eq"), - (SingleFilterOperator.NOT_EQUAL_TO, "ne"), - ], + "operator, value, expected_operator, expected_value", + continuous_test_cases(DELETED_AT, DELETED_AT_STRING_PYTHON), ) -def test_single_filter_schema_equality_operators( - by, key, 
value, by_body, value_body, operator, operator_body +def test_filter_schema_deleted_at( + operator, value, expected_operator, expected_value ): - filter = SingleFilter(by, key, operator, value) - expected_json = { - "by": by_body, - "key": key, - "operator": operator_body, - "value": value_body, + filter = DeletedAtFilter(operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "deletedAt", + "operator": expected_operator, + "value": expected_value, } - data = SingleFilterSchema().dump(filter) - assert data == expected_json @pytest.mark.parametrize( - "by, key, value, by_body, value_body", - [ - ( - SingleFilterBy.DELETED_AT, - None, - DELETED_AT, - "deletedAt", - DELETED_AT_STRING_PYTHON, - ), - (SingleFilterBy.PARAM, "param-key", 1, "param", 1), - (SingleFilterBy.METRIC, "metric-key", 2.0, "metric", 2.0), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases("tag-value", "tag-value"), ) +def test_filter_schema_tag(operator, value, expected_operator, expected_value): + filter = TagFilter("tag-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "tag", + "key": "tag-key", + "operator": expected_operator, + "value": expected_value, + } + + @pytest.mark.parametrize( - "operator, operator_body", - [ - (SingleFilterOperator.LESS_THAN, "lt"), - (SingleFilterOperator.LESS_THAN_OR_EQUAL_TO, "le"), - (SingleFilterOperator.GREATER_THAN, "gt"), - (SingleFilterOperator.GREATER_THAN_OR_EQUAL_TO, "ge"), - ], + "operator, value, expected_operator, expected_value", + discrete_test_cases("param-value", "param-value") + + continuous_test_cases(123.2, 123.2), ) -def test_single_filter_schema_relational_operators( - by, key, value, by_body, value_body, operator, operator_body +def test_filter_schema_param( + operator, value, expected_operator, expected_value ): - filter = SingleFilter(by, key, operator, value) - expected_json = { - "by": by_body, - "key": key, - "operator": operator_body, - 
"value": value_body, + filter = ParamFilter("param-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "param", + "key": "param-key", + "operator": expected_operator, + "value": expected_value, } - data = SingleFilterSchema().dump(filter) - assert data == expected_json @pytest.mark.parametrize( - "by, key, by_body", - [ - (SingleFilterBy.PROJECT_ID, None, "projectId"), - (SingleFilterBy.EXPERIMENT_ID, None, "experimentId"), - (SingleFilterBy.RUN_ID, None, "runId"), - (SingleFilterBy.DELETED_AT, None, "deletedAt"), - (SingleFilterBy.TAG, "tag-key", "tag"), - (SingleFilterBy.PARAM, "param-key", "param"), - (SingleFilterBy.METRIC, "metric-key", "metric"), - ], + "operator, value, expected_operator, expected_value", + continuous_test_cases(45.6, 45.6), ) -def test_single_filter_schema_defined_operator(by, key, by_body): - filter = SingleFilter(by, key, SingleFilterOperator.DEFINED, True) - expected_json = { - "by": by_body, - "key": key, - "operator": "defined", - "value": True, +def test_filter_schema_metric( + operator, value, expected_operator, expected_value +): + filter = MetricFilter("metric-key", operator, value) + data = FilterSchema().dump(filter) + assert data == { + "by": "metric", + "key": "metric-key", + "operator": expected_operator, + "value": expected_value, } - data = SingleFilterSchema().dump(filter) - assert data == expected_json @pytest.mark.parametrize( - "by, value, message", + "filter_type", + [ProjectIdFilter, ExperimentIdFilter, RunIdFilter, DeletedAtFilter], +) +def test_filter_schema_invalid_value_no_key(filter_type): + filter = filter_type(ComparisonOperator.EQUAL_TO, "invalid") + with pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [(ParamFilter, None), (MetricFilter, "invalid")] +) +def test_filter_schema_invalid_value_with_key(filter_type, value): + filter = filter_type("key", ComparisonOperator.EQUAL_TO, value) + with 
pytest.raises(ValidationError): + FilterSchema().dump(filter) + + +@pytest.mark.parametrize( + "filter_type, value", [ - (SingleFilterBy.PROJECT_ID, "invalid", "Not a valid UUID."), - (SingleFilterBy.EXPERIMENT_ID, "string", "Not a valid integer."), - (SingleFilterBy.RUN_ID, "invalid", "Not a valid UUID."), - ( - SingleFilterBy.DELETED_AT, - "invalid", - "cannot be formatted as a datetime", - ), - (SingleFilterBy.METRIC, "invalid", "Not a valid number."), - (SingleFilterBy.PARAM, None, "must be of type str, int or float"), + (ProjectIdFilter, PROJECT_ID), + (ExperimentIdFilter, EXPERIMENT_ID), + (RunIdFilter, RUN_ID), ], ) -def test_single_filter_invalid_value(by, value, message): - filter = SingleFilter( - by, - "key" if by.needs_key() else None, - SingleFilterOperator.EQUAL_TO, - value, - ) - with pytest.raises(ValidationError, match=message): - SingleFilterSchema().dump(filter) +@pytest.mark.parametrize( + "operator", + [ + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, + ], +) +def test_filter_schema_invalid_operator_no_key(filter_type, value, operator): + filter = filter_type(operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) @pytest.mark.parametrize( - "by, key, operator, value, message", + "filter_type, value", + [(TagFilter, "tag-value"), (ParamFilter, "param-string-value")], +) +@pytest.mark.parametrize( + "operator", [ - ( - SingleFilterBy.PROJECT_ID, - "invalid_key", - SingleFilterOperator.EQUAL_TO, - PROJECT_ID, - "key must be none for filter type {}".format( - SingleFilterBy.PROJECT_ID - ), - ), - ( - SingleFilterBy.TAG, - None, - SingleFilterOperator.EQUAL_TO, - "tag_value", - "key must not be none for filter type {}".format( - SingleFilterBy.TAG - ), - ), - ( - SingleFilterBy.PARAM, - "param_key", - SingleFilterOperator.GREATER_THAN, - "param_value", - "invalid type {}. 
Value has to be either an int or a float".format( - type("param_value") - ), - ), + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL_TO, ], ) -def test_single_filter_validation(by, key, operator, value, message): - with pytest.raises(ValueError, match=message): - SingleFilter(by, key, operator, value) +def test_filter_schema_invalid_operator_with_key(filter_type, value, operator): + filter = filter_type("key", operator, value) + with pytest.raises(ValidationError, match="Not a discrete operator"): + FilterSchema().dump(filter) -def test_compound_filter_validation(): +@pytest.mark.parametrize( + "operator, expected_operator", + [(LogicalOperator.AND, "and"), (LogicalOperator.OR, "or")], +) +def test_filter_schema_compound(operator, expected_operator): + filter = CompoundFilter(operator, [PROJECT_ID_FILTER, TAG_FILTER]) + data = FilterSchema().dump(filter) + assert data == { + "operator": expected_operator, + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + } + + +def test_filter_schema_nested(): filter = CompoundFilter( - operator=CompoundFilterOperator.OR, - conditions=[ - SingleFilter( - SingleFilterBy.TAG, - "tag_key", - SingleFilterOperator.EQUAL_TO, - "tag_value", + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.AND, [PROJECT_ID_FILTER, TAG_FILTER] + ), + CompoundFilter( + LogicalOperator.OR, [TAG_FILTER, PROJECT_ID_FILTER] ), - None, ], ) - run_query = QueryRuns(filter, None, None) - with pytest.raises(ValidationError): - QueryRunsSchema().dump(run_query) + data = FilterSchema().dump(filter) + assert data == { + "operator": "and", + "conditions": [ + { + "operator": "and", + "conditions": [PROJECT_ID_FILTER_BODY, TAG_FILTER_BODY], + }, + { + "operator": "or", + "conditions": [TAG_FILTER_BODY, PROJECT_ID_FILTER_BODY], + }, + ], + } @pytest.mark.parametrize( - "by, key, by_body", + "sort_type, by", [ - (SortBy.STARTED_AT, None, 
"startedAt"), - (SortBy.RUN_NUMBER, None, "runNumber"), - (SortBy.DURATION, None, "duration"), - (SortBy.TAG, "tag-key", "tag"), - (SortBy.PARAM, "param-key", "param"), - (SortBy.METRIC, "metric-key", "metric"), + (StartedAtSort, "startedAt"), + (RunNumberSort, "runNumber"), + (DurationSort, "duration"), ], ) @pytest.mark.parametrize( - "order, order_body", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] ) -def test_sort_schema(by, key, by_body, order, order_body): - sort = Sort(by, key, order) - expected_json = {"by": by_body, "key": key, "order": order_body} +def test_sort_schema_no_tag(sort_type, by, order, expected_order): + sort = sort_type(order) data = SortSchema().dump(sort) - assert data == expected_json + assert data == {"by": by, "order": expected_order} -def test_sort_validate_no_key(): - with pytest.raises(ValueError, match="key must be none"): - Sort(SortBy.RUN_NUMBER, "invalid_number", SortOrder.ASC) +@pytest.mark.parametrize( + "sort_type, by", + [(TagSort, "tag"), (ParamSort, "param"), (MetricSort, "metric")], +) +@pytest.mark.parametrize( + "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] +) +def test_sort_schema_with_tag(sort_type, by, order, expected_order): + sort = sort_type("sort-key", order) + data = SortSchema().dump(sort) + assert data == {"by": by, "key": "sort-key", "order": expected_order} + +def test_run_query_schema(mocker): + mocker.patch.object(FilterSchema, "dump") + mocker.patch.object(SortSchema, "dump") + mocker.patch.object(PageSchema, "dump") -def test_sort_validate_has_key(): - with pytest.raises(ValueError, match="key must not be none"): - Sort(SortBy.TAG, None, SortOrder.ASC) + filter = mocker.Mock() + sorts = [mocker.Mock(), mocker.Mock()] + page = mocker.Mock() + + run_query = RunQuery(filter, sorts, page) + data = RunQuerySchema().dump(run_query) + + assert data == { + "filter": FilterSchema.dump.return_value, + 
"sort": [SortSchema.dump.return_value, SortSchema.dump.return_value], + "page": PageSchema.dump.return_value, + } + + +def test_run_query_schema_defaults(): + run_query = RunQuery(None, None, None) + data = RunQuerySchema().dump(run_query) + assert data == {"filter": None, "sort": None, "page": None} @pytest.mark.parametrize("description", [None, "experiment description"]) @@ -900,11 +890,16 @@ def test_list_runs_schema(mocker): assert data == LIST_EXPERIMENT_RUNS_RESPONSE -def test_page_schema(): +def test_page_schema_load(): data = PageSchema().load(PAGE_BODY) assert data == PAGE +def test_page_schema_dump(): + data = PageSchema().dump(PAGE) + assert data == PAGE_BODY + + def test_pagination_schema(): data = PaginationSchema().load(PAGINATION_BODY) assert data == PAGINATION @@ -942,85 +937,46 @@ def test_restore_experiment_runs_response_schema_invalid(mocker): RestoreExperimentRunsResponseSchema().load({}) -def test_experiment_client_list_runs_all(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.QueryRunsSchema" - ) - dump_mock = request_schema_mock.return_value.dump - - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/query".format(PROJECT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) - - -def test_experiment_client_list_runs_experiments_filter(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" - ) - request_schema_mock = mocker.patch( - 
"faculty.clients.experiment.QueryRunsSchema" - ) - dump_mock = request_schema_mock.return_value.dump +def test_experiment_client_list_runs(mocker): + mocker.patch.object(ExperimentClient, "query_runs") client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[123, 456]) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/query".format(PROJECT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) - - -def test_experiment_client_list_runs_page(mocker): - mocker.patch.object( - ExperimentClient, "_post", return_value=LIST_EXPERIMENT_RUNS_RESPONSE - ) - response_schema_mock = mocker.patch( - "faculty.clients.experiment.ListExperimentRunsResponseSchema" + response = client.list_runs( + PROJECT_ID, + experiment_ids=[123, 456], + lifecycle_stage=LifecycleStage.DELETED, + start=20, + limit=10, + ) + + assert response == ExperimentClient.query_runs.return_value + expected_filter = CompoundFilter( + LogicalOperator.AND, + [ + CompoundFilter( + LogicalOperator.OR, + [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 123), + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 456), + ], + ), + DeletedAtFilter(ComparisonOperator.DEFINED, True), + ], ) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.QueryRunsSchema" + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, expected_filter, None, 20, 10 ) - dump_mock = request_schema_mock.return_value.dump - client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, start=20, limit=10) - assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - - response_schema_mock.assert_called_once_with() - ExperimentClient._post.assert_called_once_with( - "/project/{}/run/query".format(PROJECT_ID), - response_schema_mock.return_value, - json=dump_mock.return_value, - ) +def 
test_experiment_client_list_runs_defaults(mocker): + mocker.patch.object(ExperimentClient, "query_runs") -def test_experiment_client_list_runs_experiments_filter_empty(mocker): client = ExperimentClient(mocker.Mock()) - list_result = client.list_runs(PROJECT_ID, experiment_ids=[]) + response = client.list_runs(PROJECT_ID) - assert list_result == ListExperimentRunsResponse( - runs=[], - pagination=Pagination(start=0, size=0, previous=None, next=None), + assert response == ExperimentClient.query_runs.return_value + ExperimentClient.query_runs.assert_called_once_with( + PROJECT_ID, None, None, None, None ) @@ -1031,32 +987,26 @@ def test_experiment_client_query_runs(mocker): response_schema_mock = mocker.patch( "faculty.clients.experiment.ListExperimentRunsResponseSchema" ) - request_schema_mock = mocker.patch( - "faculty.clients.experiment.QueryRunsSchema" - ) - dump_mock = request_schema_mock.return_value.dump + request_dump_mock = mocker.patch.object(RunQuerySchema, "dump") - test_filter = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2" - ) - test_sort = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] + filter = mocker.Mock() + sort = mocker.Mock() client = ExperimentClient(mocker.Mock()) list_result = client.query_runs( - PROJECT_ID, filter=test_filter, sort=test_sort, start=20, limit=10 + PROJECT_ID, filter, sort, start=20, limit=10 ) assert list_result == LIST_EXPERIMENT_RUNS_RESPONSE - request_schema_mock.assert_called_once_with() - dump_mock.assert_called_once_with( - QueryRuns(test_filter, test_sort, Page(20, 10)) + request_dump_mock.assert_called_once_with( + RunQuery(filter, sort, Page(20, 10)) ) response_schema_mock.assert_called_once_with() ExperimentClient._post.assert_called_once_with( "/project/{}/run/query".format(PROJECT_ID), response_schema_mock.return_value, - json=dump_mock.return_value, + json=request_dump_mock.return_value, ) @@ -1279,41 +1229,28 @@ def test_delete_runs(mocker): mocker.patch.object( 
ExperimentClient, "_post", return_value=DELETE_EXPERIMENT_RUNS_RESPONSE ) - schema_mock = mocker.patch( + response_schema_mock = mocker.patch( "faculty.clients.experiment.DeleteExperimentRunsResponseSchema" ) + filter_dump_mock = mocker.patch.object(FilterSchema, "dump") run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) - assert ( - client.delete_runs(PROJECT_ID, run_ids) - == DELETE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": CompoundFilter( - CompoundFilterOperator.OR, - [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[0], - ), - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[1], - ), - ], - ) - } - + response = client.delete_runs(PROJECT_ID, run_ids) + + assert response == DELETE_EXPERIMENT_RUNS_RESPONSE + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) ExperimentClient._post.assert_called_once_with( "/project/{}/run/delete/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, ) @@ -1334,13 +1271,15 @@ def test_delete_runs_no_run_ids(mocker): def test_delete_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + client = ExperimentClient(mocker.Mock()) + response = client.delete_runs(PROJECT_ID, run_ids=[]) - assert client.delete_runs( - PROJECT_ID, run_ids=[] - ) == DeleteExperimentRunsResponse( + assert response == DeleteExperimentRunsResponse( deleted_run_ids=[], conflicted_run_ids=[] ) + ExperimentClient._post.assert_not_called() def test_restore_runs(mocker): @@ -1349,41 +1288,28 @@ def test_restore_runs(mocker): "_post", return_value=RESTORE_EXPERIMENT_RUNS_RESPONSE, ) - schema_mock = mocker.patch( + 
response_schema_mock = mocker.patch( "faculty.clients.experiment.RestoreExperimentRunsResponseSchema" ) + filter_dump_mock = mocker.patch.object(FilterSchema, "dump") run_ids = [uuid4(), uuid4()] client = ExperimentClient(mocker.Mock()) - assert ( - client.restore_runs(PROJECT_ID, run_ids) - == RESTORE_EXPERIMENT_RUNS_RESPONSE - ) - - expected_payload = { - "filter": CompoundFilter( - CompoundFilterOperator.OR, - [ - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[0], - ), - SingleFilter( - SingleFilterBy.RUN_ID, - None, - SingleFilterOperator.EQUAL_TO, - run_ids[1], - ), - ], - ) - } - + response = client.restore_runs(PROJECT_ID, run_ids) + + assert response == RESTORE_EXPERIMENT_RUNS_RESPONSE + expected_filter = CompoundFilter( + LogicalOperator.OR, + [ + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[0]), + RunIdFilter(ComparisonOperator.EQUAL_TO, run_ids[1]), + ], + ) + filter_dump_mock.assert_called_once_with(expected_filter) ExperimentClient._post.assert_called_once_with( "/project/{}/run/restore/query".format(PROJECT_ID), - schema_mock.return_value, - json=expected_payload, + response_schema_mock.return_value, + json={"filter": filter_dump_mock.return_value}, ) @@ -1404,10 +1330,12 @@ def test_restore_runs_no_run_ids(mocker): def test_restore_runs_empty_list(mocker): + mocker.patch.object(ExperimentClient, "_post") + client = ExperimentClient(mocker.Mock()) + response = client.restore_runs(PROJECT_ID, run_ids=[]) - assert client.restore_runs( - PROJECT_ID, run_ids=[] - ) == RestoreExperimentRunsResponse( + assert response == RestoreExperimentRunsResponse( restored_run_ids=[], conflicted_run_ids=[] ) + ExperimentClient._post.assert_not_called() From 305efd2e7f2e0a09188f3ed35f9775323e900651 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 14:10:54 +0100 Subject: [PATCH 49/52] Fix test names --- tests/clients/test_experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/clients/test_experiment.py b/tests/clients/test_experiment.py index d6db90b0..3a3eb8ff 100644 --- a/tests/clients/test_experiment.py +++ b/tests/clients/test_experiment.py @@ -624,7 +624,7 @@ def test_filter_schema_nested(): @pytest.mark.parametrize( "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] ) -def test_sort_schema_no_tag(sort_type, by, order, expected_order): +def test_sort_schema_no_key(sort_type, by, order, expected_order): sort = sort_type(order) data = SortSchema().dump(sort) assert data == {"by": by, "order": expected_order} @@ -637,7 +637,7 @@ def test_sort_schema_no_tag(sort_type, by, order, expected_order): @pytest.mark.parametrize( "order, expected_order", [(SortOrder.ASC, "asc"), (SortOrder.DESC, "desc")] ) -def test_sort_schema_with_tag(sort_type, by, order, expected_order): +def test_sort_schema_with_key(sort_type, by, order, expected_order): sort = sort_type("sort-key", order) data = SortSchema().dump(sort) assert data == {"by": by, "key": "sort-key", "order": expected_order} From 9654c5c1d4d06c024443fee5be4fd14d23a922c4 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 16:29:14 +0100 Subject: [PATCH 50/52] Use beta version on Python 2.7 / 3.4 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1e427358..1e04d352 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ # compatible version of python-dateutil is available "marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", - "marshmallow-oneofschema", + "marshmallow-oneofschema>=2.0.0b2", "boto3", "botocore", ], From 9b45e8c26cd426d6aff884125f82f11f08d228b6 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 16:29:38 +0100 Subject: [PATCH 51/52] Fix Python 2 compatibility --- faculty/clients/experiment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git
a/faculty/clients/experiment.py +++ b/faculty/clients/experiment.py @@ -318,7 +318,7 @@ class OptionalField(fields.Field): def __init__(self, nested, *args, **kwargs): self.nested = nested - super().__init__(*args, **kwargs) + super(OptionalField, self).__init__(*args, **kwargs) def _deserialize(self, value, *args, **kwargs): if value is None: From e420fb4c77517f8d516fc4b6edb57b59d731e876 Mon Sep 17 00:00:00 2001 From: Andrew Crozier Date: Wed, 12 Jun 2019 17:00:12 +0100 Subject: [PATCH 52/52] Pin oneofschema version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1e04d352..8d4d473f 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ # compatible version of python-dateutil is available "marshmallow[reco]>=3.0.0rc3", "marshmallow_enum", - "marshmallow-oneofschema>=2.0.0b2", + "marshmallow-oneofschema==2.0.0b2", "boto3", "botocore", ],